In this assignment we are going to explore vector quantization and several of its variants.

In [None]:
from pathlib import Path

import lightning as L
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

In [None]:
%load_ext tensorboard

In [None]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from data import *
from model import *
from tests import *
from vector_quantization import *

In [None]:
device_id = 7  # index of the GPU to train on
device = torch.device(f"cuda:{device_id}" if device_id >= 0 else "cpu")

data_path = Path("../data/08_RVQ")
data_path.mkdir(parents=True, exist_ok=True)

### 1. Training loop


First, we train a plain encoder-decoder model as a baseline to compare against.
The encoder-decoder model is implemented in `model.py`; it consists of simple convolutional residual blocks.
We train on the MNIST dataset, whose images have size `[1 x 28 x 28]`.
The encoder compresses them to size `[1 x 3 x 3]`, and the decoder reconstructs them back.

In this assignment we use Lightning, because it is a fast and easy way to organize the training loop, logging, checkpointing, and so on. The setup consists of three parts, most of which are implemented for you:
- `DataModule` handles data management: downloading, splitting into train and test sets, and creating the data loaders. It is implemented in `data.py`.
- `Trainer` manages the training loop and related machinery such as checkpointing and early stopping. It is pre-configured for you (see `get_trainer` below), so that
    - it stops training when the validation loss stops improving;
    - it runs validation on the whole validation set after every epoch and logs training metrics to TensorBoard every 10 steps;
    - it keeps the two best checkpoints according to the validation loss;
    - it shows a progress bar to monitor training.
- `Model` (`MNISTEncoderDecoder`) is a `LightningModule` that bundles the network with its training and validation logic. It is implemented in `model.py`; you will write part of it later.
    - It is initialized with a `quantizer` and a `vq_loss`, which we implement below.
    - Its method names are mostly self-explanatory.
    - It inherits the `self.log("name", metric)` method for convenient metric logging. Most of the logging is already implemented for you.
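Note that early stopping does not fire at the first uptick of the validation loss. `EarlyStopping(monitor="val/loss")` uses Lightning's defaults (`mode="min"`, `patience=3`, `min_delta=0.0`), so training stops once the loss has failed to beat its best value for three consecutive validation checks. The plain-Python sketch below mimics that rule on a list of per-epoch validation losses; the function name `stopping_epoch` is hypothetical, and this is a simplified model of the behaviour, not Lightning's actual code.

```python
def stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the 0-based index of the validation check at which a
    patience-based early-stopping rule fires, or None if it never fires.
    Simplified model of a "min"-mode rule; hypothetical helper."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember the new best value and reset the counter.
            best = loss
            wait = 0
        else:
            # No improvement over the best value seen so far.
            wait += 1
            if wait >= patience:
                return epoch
    return None


# The loss rises at epoch 3, but training continues until three
# consecutive checks fail to beat the best value (0.30).
losses = [0.50, 0.40, 0.30, 0.35, 0.32, 0.31, 0.29]
print(stopping_epoch(losses))  # → 5
```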

Let's instantiate these classes and train the baseline encoder-decoder.

In [None]:
datamodule = MNISTDataModule(data_dir=data_path, batch_size=256)

In [None]:
def get_trainer():
    earlystopping = EarlyStopping(monitor="val/loss")
    checkpoint = ModelCheckpoint(dirpath=data_path / "model", save_top_k=2, monitor="val/loss")

    trainer = L.Trainer(
        callbacks=[earlystopping, checkpoint],
        devices=[device_id],
        check_val_every_n_epoch=1,
        max_epochs=30,
        enable_progress_bar=True,
        enable_checkpointing=True,
        num_sanity_val_steps=0,
        default_root_dir=data_path / "model",
        log_every_n_steps=10,
        # gradient_clip_val=0.1,
    )
    return trainer

In [None]:
trainer = get_trainer()
model = MNISTEncoderDecoder(quantizer=None, vq_loss=None)
trainer.fit(model=model, datamodule=datamodule)

Now open TensorBoard and make sure that you can monitor the metrics and see the reconstructed images.

In [None]:
# %tensorboard --logdir {data_path}

### 2. Vector quantisation

In [None]:
# Assignment: go to vector_quantization.py
# Implement the missing methods in VectorQuantizer

In [None]:
assert test_vector_quantization()
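For reference, the core operation of vector quantization is a nearest-neighbour lookup: every encoder vector is replaced by the closest row of a learnable codebook (here measured by Euclidean distance). The sketch below shows this in plain PyTorch. The function name `quantize` and the `[N, D]` input layout are assumptions for illustration and need not match the interface expected by `vector_quantization.py` or `test_vector_quantization()`.

```python
import torch


def quantize(z_e: torch.Tensor, codebook: torch.Tensor):
    """Map each row of z_e [N, D] to its nearest codebook row [K, D].

    Returns the quantized vectors [N, D] and their integer indices [N].
    Hypothetical helper, not the notebook's API.
    """
    # Pairwise Euclidean distances between inputs and codebook entries: [N, K]
    distances = torch.cdist(z_e, codebook)
    # Index of the closest codebook vector for every input vector: [N]
    indices = distances.argmin(dim=1)
    # Look the chosen vectors up in the codebook: [N, D]
    z_q = codebook[indices]
    return z_q, indices


codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
z_e = torch.tensor([[0.9, 1.2], [-0.1, 0.1], [-0.8, 1.7]])
z_q, idx = quantize(z_e, codebook)
print(idx.tolist())  # → [1, 0, 2]
```

If the encoder output is a feature map of shape `[B, D, H, W]`, it has to be flattened to `[B*H*W, D]` before such a lookup, for example with `z.permute(0, 2, 3, 1).reshape(-1, D)`.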

Now that the `VectorQuantizer` class is implemented, we need to train it.
Training a vector quantizer is not straightforward, because picking the nearest codebook vector is not a differentiable operation.
Three losses are needed to make it work:
1. ***Reconstruction loss.***
This is the loss between the predicted image and the target image; in our case it is the MSE loss.
The tricky part is that its gradients must propagate through both the decoder and the encoder.
With vector quantization, the gradient is lost when we pick vectors from the codebook.
That's why we need to explicitly copy the gradients from the decoder input to the encoder output.
2. ***Quantization loss.***
This loss pulls the codebook vectors towards the vectors produced by the encoder.
It is the MSE loss between the encoder vectors and the quantized vectors, but it should propagate only to the quantizer (the codebook).
3. ***Commitment loss.***
This loss pushes the encoder to produce vectors closer to the codebook vectors.
It should propagate only to the encoder.
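In formulas, let $x$ be the input image, $z_e$ the encoder output, $z_q$ its quantized version, $D$ the decoder and $\mathrm{sg}[\cdot]$ the stop-gradient operator (identity in the forward pass, zero gradient in the backward pass). The three losses above can then be written as follows. The weight $\beta$ on the commitment term is an assumption: it is common in VQ-VAE setups, but this notebook does not specify one, and $\beta = 1$ gives the plain sum.

$$
\begin{aligned}
\hat{z} &= z_e + \mathrm{sg}[z_q - z_e] \\
\mathcal{L}_{\text{rec}} &= \operatorname{MSE}\big(D(\hat{z}),\, x\big) \\
\mathcal{L}_{\text{quant}} &= \operatorname{MSE}\big(\mathrm{sg}[z_e],\, z_q\big) \\
\mathcal{L}_{\text{commit}} &= \operatorname{MSE}\big(z_e,\, \mathrm{sg}[z_q]\big) \\
\mathcal{L} &= \mathcal{L}_{\text{rec}} + \mathcal{L}_{\text{quant}} + \beta\, \mathcal{L}_{\text{commit}}
\end{aligned}
$$

In the forward pass $\hat{z} = z_q$, so the decoder sees the quantized vectors; in the backward pass $\partial \mathcal{L}_{\text{rec}} / \partial z_e = \partial \mathcal{L}_{\text{rec}} / \partial \hat{z}$, which is exactly the gradient copy from decoder to encoder described in item 1.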

Let's write a loss class that combines the second and third terms of the total loss.
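In PyTorch, a stop-gradient is expressed with `.detach()`: the tensor keeps its value, but no gradient flows back through it. Below is a minimal sketch of the two VQ terms, assuming the loss receives the encoder output `z_e` and the quantized output `z_q` of the same shape. The function name `vq_losses` is hypothetical, and the actual signature of `VectorQuantizationLoss.forward` in `vector_quantization.py` may differ.

```python
import torch
import torch.nn.functional as F


def vq_losses(z_e: torch.Tensor, z_q: torch.Tensor):
    """Return (quantization_loss, commitment_loss) for encoder output z_e
    and codebook output z_q of the same shape. Hypothetical helper."""
    # Quantization loss: moves codebook vectors towards the (frozen) encoder output.
    quantization_loss = F.mse_loss(z_q, z_e.detach())
    # Commitment loss: moves the encoder output towards the (frozen) codebook vectors.
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    return quantization_loss, commitment_loss


z_e = torch.randn(8, 16, requires_grad=True)  # stands in for the encoder output
codebook = torch.randn(4, 16, requires_grad=True)
z_q = codebook[torch.cdist(z_e, codebook).argmin(dim=1)]

q_loss, c_loss = vq_losses(z_e, z_q)
q_loss.backward()
print(z_e.grad is None)  # → True: the quantization loss does not reach the encoder
```

Both terms have the same value in the forward pass; they differ only in which parameters receive the gradient.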

In [None]:
# TODO: Very nice picture

In [None]:
## Assignment: implement vector_quantization.VectorQuantizationLoss.forward method

In [None]:
assert test_vector_quantisation_loss()

Now for the tricky part: implementing the training step that uses the quantizer.
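The key trick in the training step with the quantizer is the straight-through estimator: the decoder receives the quantized vectors in the forward pass, while in the backward pass the gradient is copied unchanged to the encoder output, because the codebook lookup itself passes no gradient to the encoder. A minimal PyTorch sketch of that single step follows; the tensors are random stand-ins and the names are illustrative, unrelated to `model.py`.

```python
import torch

z_e = torch.randn(4, 3, requires_grad=True)  # stands in for the encoder output
z_q = torch.randn(4, 3)                      # stands in for the quantized vectors

# Straight-through estimator: the value equals z_q, the gradient flows to z_e.
z_st = z_e + (z_q - z_e).detach()

# Any downstream "decoder" loss; a weighted sum keeps the gradient easy to check.
weights = torch.arange(12, dtype=torch.float32).reshape(4, 3)
loss = (z_st * weights).sum()
loss.backward()

print(torch.allclose(z_st, z_q))       # → True: the forward pass uses z_q
print(torch.equal(z_e.grad, weights))  # → True: the gradient is copied to z_e
```

In a full model, `z_st` would be fed to the decoder in place of `z_q`, and the reconstruction loss would be added to the two VQ terms.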

In [None]:
## Assignment: implement model.MNISTEncoderDecoder.training_step_with_quantizer method

In [None]:
assert test_training_step()

Now we are ready to train the model with vector quantization.

In [None]:
trainer = get_trainer()
quantizer = VectorQuantizer(codebook_size=16, embedding_dim=16)
model = MNISTEncoderDecoder(quantizer=quantizer, vq_loss=VectorQuantizationLoss())
trainer.fit(model=model, datamodule=datamodule)

### 3. Residual vector quantisation

With vector quantization we have significantly restricted how the network can compress the images: previously each latent vector was a continuous embedding, now it is one of a limited set of integers (codebook indices).
We can increase the number of possible encodings while keeping the total number of codebook vectors the same.
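To make this concrete with the settings used later in this notebook: a single codebook of size $K$ offers $K$ distinct codes per latent vector, whereas $N$ residual codebooks of size $K$ store $N K$ vectors in total but offer $K^N$ index combinations.

$$
\begin{aligned}
\text{VQ } (K = 16):&\quad 16 \text{ stored vectors},\; 16 \text{ index choices per latent vector} \\
\text{RVQ } (K = 4,\ N = 4):&\quad N K = 4 \cdot 4 = 16 \text{ stored vectors},\; K^{N} = 4^{4} = 256 \text{ index combinations per latent vector}
\end{aligned}
$$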

For this purpose we use residual vector quantization (RVQ).
We use several quantizers, say $N$ of them.
The first quantizer encodes each vector as usual.
The second quantizer encodes the residual between the ground-truth vector and the first quantized vector.
The third quantizer encodes the residual between the ground-truth vector and the sum of the outputs of the first and second quantizers.
And so on.
Now, instead of encoding each vector by $1$ index from a single codebook, we encode it by $N$ integers, each representing an index into its own dedicated codebook.

![rvq.png](pictures/rvq.png)
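The residual scheme described above can be sketched in a few lines of PyTorch. This is an illustration under stated assumptions, not the `ResidualVectorQuantizer` you are asked to write: codebooks are plain tensors of shape `[K, D]`, inputs are `[N, D]`, and the names `nearest` and `residual_quantize` are hypothetical.

```python
from typing import Sequence

import torch


def nearest(x: torch.Tensor, codebook: torch.Tensor):
    """Return the nearest codebook rows for x [N, D] and their indices [N]."""
    indices = torch.cdist(x, codebook).argmin(dim=1)
    return codebook[indices], indices


def residual_quantize(z: torch.Tensor, codebooks: Sequence[torch.Tensor]):
    """Quantize z [N, D] with a sequence of codebooks, each of shape [K, D].

    Returns the sum of the chosen vectors [N, D] and the indices [N, n_codebooks].
    """
    residual = z
    z_q = torch.zeros_like(z)
    all_indices = []
    for codebook in codebooks:
        # Each stage quantizes what the previous stages have not yet explained.
        q, indices = nearest(residual, codebook)
        z_q = z_q + q
        residual = residual - q
        all_indices.append(indices)
    # One integer per codebook for every input vector.
    return z_q, torch.stack(all_indices, dim=1)


torch.manual_seed(0)
z = torch.randn(32, 16)
codebooks = [torch.randn(4, 16) for _ in range(4)]
z_q, codes = residual_quantize(z, codebooks)
print(codes.shape)  # → torch.Size([32, 4])
```

The reconstruction is simply the sum of the selected vectors from every codebook, so the decoder still receives a single `[N, D]` tensor.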

In [None]:
# Assignment: implement vector_quantization.ResidualVectorQuantizer

Now let's train the model and see how the predictions change.

In [None]:
trainer = get_trainer()
quantizer = ResidualVectorQuantizer(codebook_size=4, embedding_dim=16, n_codebooks=4)
model = MNISTEncoderDecoder(quantizer=quantizer, vq_loss=VectorQuantizationLoss())
trainer.fit(model=model, datamodule=datamodule)