# Load the trained VQVAE model and the anime dataset

In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import argparse
from models.vqvae import VQVAE
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torchvision.utils import make_grid
import numpy as np
import utils
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def load_model(path):
    """Rebuild a VQVAE from a checkpoint file and place it on ``device``.

    The checkpoint is read as a dict with a ``"hyperparameters"`` entry (the
    VQVAE constructor arguments) and a ``"model"`` entry (the state dict).

    Args:
        path: Path to the ``.pth`` checkpoint file.

    Returns:
        ``(model, data)``: the reconstructed ``VQVAE`` on ``device`` with the
        saved weights loaded, and the raw checkpoint dict.
    """
    # map_location=device remaps every saved tensor onto the active device,
    # wherever it was saved (e.g. a "cuda:1" that does not exist here). A
    # single call therefore covers both the CUDA and the CPU-only case.
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True. If the checkpoint stores non-tensor objects beyond
    # plain dicts/numbers, loading may need adjusting -- confirm against the
    # installed version.
    data = torch.load(path, map_location=device)

    params = data["hyperparameters"]

    model = VQVAE(
        params['n_hiddens'],
        params['n_residual_hiddens'],
        params['n_half_conv_layers'],
        params['n_residual_layers'],
        params['n_embeddings'],
        params['embedding_dim'],
        params['beta'],
    ).to(device)

    model.load_state_dict(data['model'])

    return model, data


# Checkpoint produced by a training run under ./results.
model_filename = "./results/anime 2023-06-28 04.42.54/12000.pth"
model, vqvae_data = load_model(model_filename)

# Datasets and loaders for the "anime" data; only training_loader is used below.
(
    training_data,
    validation_data,
    training_loader,
    validation_loader,
    x_train_var,
) = utils.load_data_and_data_loaders("anime", batch_size=32)

# Generate encoding indices for the training set (one row of 64 codebook indices per image)

In [2]:
# Encode every training image into its grid of VQ codebook indices.
# This is pure inference:
# - eval() switches off any train-only layer behaviour.
# - no_grad() stops an autograd graph being built for every batch.
# - Each batch's indices are moved to the CPU, so the full dataset's indices
#   do not accumulate in GPU memory before the final concat.
model.eval()
encoding_indices_dataset = []
with torch.no_grad():
    for x, _labels in tqdm(training_loader):
        x: torch.Tensor = x.to(device)  # [32, 3, 64, 64] for a full batch
        embedding_loss, x_recon, encoding_indices, perplexity = model(x)
        # One row per image. The row length (64 for this model) is inferred
        # rather than hard-coded, so a different latent size also works. The
        # last, smaller batch still yields one row per image.
        encoding_indices = encoding_indices.view(x.shape[0], -1)
        encoding_indices_dataset.append(encoding_indices.cpu())

encoding_indices_dataset_tensor = torch.concat(encoding_indices_dataset)
print("Concat encoding indices of training data, shape:", encoding_indices_dataset_tensor.shape)

100%|██████████| 658/658 [00:14<00:00, 46.74it/s]

Concat encoding indices of training data, shape: torch.Size([21051, 64])





In [3]:
# Write the training-set indices to disk as a .npy array, creating the folder if needed.
os.makedirs("./data/encoding_indices/", exist_ok=True)
np.save(
    "./data/encoding_indices/anime.npy",
    encoding_indices_dataset_tensor.cpu(),
)

In [4]:
# Reload the saved file once (the original read it from disk twice) and inspect it.
saved_encoding_indices = np.load("./data/encoding_indices/anime.npy")
print(saved_encoding_indices)
print(saved_encoding_indices.shape)
# Round-trip check: the file must hold exactly what was encoded above.
assert saved_encoding_indices.shape == tuple(encoding_indices_dataset_tensor.shape)

[[286  99  99 ... 130 190 377]
 [445 149 116 ... 472 360 316]
 [154 138  99 ...  69 319 306]
 ...
 [460 212 264 ... 504 314 363]
 [389 498 102 ... 286 299 210]
 [235 132 105 ... 357 453 286]]
(21051, 64)
