# Vanilla VQ-VAE applied on MNIST

## Dependencies

In [2]:
# List the GPUs on this machine together with their current load and memory usage.
!nvidia-smi

Wed Nov 23 19:40:37 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:61:00.0 Off |                    0 |
| N/A   56C    P0   209W / 300W |  21035MiB / 32510MiB |     92%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:62:00.0 Off |                    0 |
| N/A   29C    P0    42W / 300W |      0MiB / 32510MiB |      0%      Defaul

In [3]:
import os
import sys

# Expose only GPU 0 to this process. The variable is set before `import torch`
# so the restriction is in place before any CUDA initialisation can happen.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Checkout that provides the `src` package imported below. Defaults to the
# original location, but can be overridden with the VQVAE_VC_ROOT environment
# variable so the notebook also runs on machines with a different layout.
PROJECT_ROOT = os.environ.get("VQVAE_VC_ROOT", "/project/fdreyer/projects/vqvae-vc")

# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.optim import Adam
from src.models.encoders import ResNetEncoder
from src.models.quantizers import VanillaVectorQuantizer
from src.models.decoders import ResNetDecoder
from src.models.vqvae import VQVAE
from src.losses import VQVAELoss
from src.training import VQVAETrainer

# Always build a torch.device so that `device` has the same type whether or
# not CUDA is available (a bare "cpu" string would make the type depend on
# the machine the notebook runs on).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Using cuda device


## Load MNIST Dataset

In [4]:
# Both MNIST splits share the same root directory, download flag and
# tensor conversion; only the `train` flag differs between them.
mnist_options = {"root": "data", "download": True, "transform": ToTensor()}
train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)

In [5]:
# Print the repr of each split, one after the other.
train_summary = f"Train data: \n{train_dataset}"
test_summary = f"\nTest data: \n{test_dataset}"
print(train_summary, test_summary, sep="\n")

Train data: 
Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()

Test data: 
Dataset MNIST
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: ToTensor()


## Configure DataLoader

In [6]:
BATCH_SIZE = 64

# Train and test loaders use identical settings: mini-batches of BATCH_SIZE
# samples with shuffling enabled.
loader_options = dict(batch_size=BATCH_SIZE, shuffle=True)
train_dataloader = DataLoader(train_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

In [7]:
# Pull a single mini-batch from the training loader and display the shape
# of its input tensor as a sanity check.
first_batch = next(iter(train_dataloader))
X, y = first_batch
X.shape

torch.Size([64, 1, 28, 28])

## Instantiate VQ-VAE

In [8]:
# Name the sizes that are shared between the three sub-modules so they
# cannot drift apart when one of them is edited.
# NOTE(review): the meanings below are inferred from how the values are
# reused, not from the src.models signatures — confirm before relying on them.
IMAGE_CHANNELS = 1    # 1st encoder arg / 2nd decoder arg; the MNIST batch above has 1 channel
LATENT_CHANNELS = 32  # 2nd encoder arg, 1st quantizer arg and 1st decoder arg
CODEBOOK_SIZE = 100   # 2nd quantizer arg; presumably the number of codebook vectors

encoder = ResNetEncoder(IMAGE_CHANNELS, LATENT_CHANNELS).to(device)
quantizer = VanillaVectorQuantizer(LATENT_CHANNELS, CODEBOOK_SIZE).to(device)
decoder = ResNetDecoder(LATENT_CHANNELS, IMAGE_CHANNELS).to(device)
vqvae = VQVAE(encoder, quantizer, decoder).to(device)

In [9]:
# `beta` value handed to VQVAELoss (presumably the commitment-loss weight of
# the VQ-VAE objective — TODO confirm against src.losses).
BETA = 0.25

loss_fn = VQVAELoss(beta=BETA)
optimizer = Adam(vqvae.parameters())  # no hyper-parameters passed: Adam defaults apply
trainer = VQVAETrainer(
    vqvae,
    train_dataloader,
    loss_fn,
    optimizer,
    device=device,
)

In [10]:
# Value forwarded to trainer.train(); the training log prints "Epoch N"
# headers, so this is presumably the number of epochs — TODO confirm against
# src.training.
N_EPOCHS = 10
trainer.train(N_EPOCHS)

Epoch 1 
-------------------
Loss: 0.506821 [    0 / 60000]
Loss: 0.389691 [ 6400 / 60000]
Loss: 0.144186 [12800 / 60000]
Loss: 0.086042 [19200 / 60000]
Loss: 0.066518 [25600 / 60000]
Loss: 0.055518 [32000 / 60000]


KeyboardInterrupt: 