In [1]:
from train import train_vqvae, DEVICE
from modules.vqvae import VQVAE
from dataset import train_dl, test_dl, NumpyDataset, codebook_transform, batch_size
from torch.utils.data import DataLoader
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

def show(img):
    """
    Render a channel-first image tensor with matplotlib and hide both axes.

    Args:
        img: a torch tensor whose axes are ordered (channel, height, width);
            the axes are permuted to (height, width, channel) before plotting.
            The tensor may require grad or live on an accelerator: it is
            detached and copied to the CPU before being converted to numpy.

    Returns:
        None. The image is drawn on the current matplotlib axes.
    """
    # Tensor.numpy() refuses tensors that require grad or are not on the CPU,
    # so normalise the tensor first; both calls are no-ops for plain CPU tensors.
    np_img = img.detach().cpu().numpy()
    image_artist = plt.imshow(np.transpose(np_img, (1, 2, 0)), interpolation='nearest')
    image_artist.axes.get_xaxis().set_visible(False)
    image_artist.axes.get_yaxis().set_visible(False)

In [2]:
# Number of passes over the training loader made by the training cell.
EPOCHS = 2

# Build the VQ-VAE and move its parameters to the device exported by train.py.
model = VQVAE(
    latent_dim=128,
    res_h_dim=32,
    num_embeddings=512,
    embedding_dim=64,
    beta=0.25,
)
model.to(DEVICE)

# Adam over every model parameter.
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# One entry per completed epoch: the value returned by train_vqvae.
training_reconstruction_loss = []
for epoch in range(EPOCHS):
    # Banner uses a 1-based epoch number.
    print(f"=======================EPOCH = {epoch + 1}======================")
    epoch_loss = train_vqvae(dl=train_dl, model=model, optim=optim)
    training_reconstruction_loss.append(epoch_loss)
    print(f"Reconstruction loss: {epoch_loss}")

batch    0/354 	 |current loss: 1.367112
batch   25/354 	 |current loss: 0.695352
batch   50/354 	 |current loss: 0.356504
batch   75/354 	 |current loss: 0.252510
batch  100/354 	 |current loss: 0.244847
batch  125/354 	 |current loss: 0.157223
batch  150/354 	 |current loss: 0.143243
batch  175/354 	 |current loss: 0.166342
batch  200/354 	 |current loss: 0.186239
batch  225/354 	 |current loss: 0.135978
batch  250/354 	 |current loss: 0.263740
batch  275/354 	 |current loss: 0.202205
batch  300/354 	 |current loss: 0.156482
batch  325/354 	 |current loss: 0.116243
batch  350/354 	 |current loss: 0.127772
Reconstruction loss: 0.2808905892958075
batch    0/354 	 |current loss: 0.099904
batch   25/354 	 |current loss: 0.140721
batch   50/354 	 |current loss: 0.216068
batch   75/354 	 |current loss: 0.151931
batch  100/354 	 |current loss: 0.213867
batch  125/354 	 |current loss: 0.207343


In [None]:
# Load one batch from the test loader, keep its first element (index 0),
# and move it to the same device as the model.
test_real = next(iter(test_dl))[0].to(DEVICE)

# This is an inference-only pass, so disable autograd: no graph is built and the
# outputs do not require grad, which lets them be converted to numpy for plotting.
# NOTE(review): the model is left in whatever train/eval mode it is already in — confirm
# whether model.eval() is wanted here.
with torch.no_grad():
    pre_conv = model.pre_quantization_conv(model.encoder(test_real))  # encoder, then pre-quantization conv
    _, test_quantized, _, _ = model.vector_quantizer(pre_conv)
    test_reconstructions = model.decoder(test_quantized)

In [None]:
# show reconstructed images
# The decoder output may still be attached to the autograd graph; detach it before
# building the grid so the numpy conversion inside show() cannot fail on a
# grad-requiring tensor.
show(torchvision.utils.make_grid(test_reconstructions.detach().cpu()))

In [None]:
# Original test images, tiled into a single grid for comparison with the reconstructions
real_grid = torchvision.utils.make_grid(test_real.cpu())
show(real_grid)

In [None]:
# Take the first image of the first test batch and restore a leading batch axis of size 1.
test_input = next(iter(test_dl))
test_input = test_input[0][0].unsqueeze(0)
print(test_input.shape)
test_input = test_input.to(DEVICE)

# Inference-only forward pass: run it without autograd so no graph is kept alive.
with torch.no_grad():
    test_encoded = model.pre_quantization_conv(model.encoder(test_input))
    # The vector quantizer returns four values; the second is fed to the decoder
    # and the fourth holds the codebook indices plotted below.
    _, test_encoded, encodings, indices = model.vector_quantizer(test_encoded)
    decoded = model.decoder(test_encoded)

# Plot codebook index
# view(64, 64) requires exactly 4096 indices for this single image.
# NOTE(review): presumably the latent grid is 64x64 — confirm against the encoder.
plot_image = indices.view(64, 64)
print(torch.unique(indices.to('cpu')))
plot_image = plot_image.to('cpu')
detached_image = plot_image.detach().numpy()

# From here on test_input is a 2-D numpy array: channel 0 of the single input image.
test_input = test_input[0][0].cpu().detach().numpy()
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.suptitle('Real vs Codebook index')
ax1.imshow(test_input)
ax2.imshow(detached_image)

In [None]:
# Convert the codebook indices back to quantized vectors, decode them, and plot the
# result next to the index map.
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.suptitle('Codebook indice vs quantized')

quantized_from_indices = model.vector_quantizer.get_quantized(indices)
decoded_from_indices = model.decoder(quantized_from_indices)

# First decoded image as a numpy array on the CPU.
first_decoded = decoded_from_indices[0].to('cpu').detach().numpy()
index_map = indices.cpu().view(64, 64).detach().numpy()

# Right panel: slice 1 along the first axis of the decoded image; left panel: index map.
ax2.imshow(first_decoded[1])
ax1.imshow(index_map)

In [None]:
real = next(iter(test_dl))
codebook_data = []

# Iterate over the loader itself. Calling next(iter(test_dl), 'end') inside a
# `while True` loop creates a fresh iterator on every pass, so it always yields the
# first batch, never returns the sentinel, and the loop never terminates.
# torch.no_grad(): this is inference only, so no autograd graph is needed.
with torch.no_grad():
    for batch in test_dl:
        # Only the first image of each batch is encoded; unsqueeze restores a batch axis of 1.
        current_input = batch[0][0].unsqueeze(0).to(DEVICE)
        encoded = model.pre_quantization_conv(model.encoder(current_input))
        # The fourth value returned by the vector quantizer holds the codebook indices.
        _, encoded, encodings, indices = model.vector_quantizer(encoded)

        # view(64, 64) requires exactly 4096 indices per image.
        # NOTE(review): presumably the latent grid is 64x64 — confirm against the encoder.
        plot_image = indices.view(64, 64)
        print(torch.unique(indices.to('cpu')))
        plot_image = plot_image.to('cpu')
        detached_image = plot_image.detach().numpy()

        codebook_data.append(detached_image)

# NOTE(review): targets has a single entry while data has one entry per batch —
# verify this is what NumpyDataset expects.
codebook_set = NumpyDataset(data=codebook_data, targets=[1], transform=codebook_transform)
codebook_loader = DataLoader(codebook_set, batch_size=batch_size)