# <center>This `.ipynb` notebook trains the VQ-VAE autoencoder with reconstruction, codebook, LPIPS perceptual, and adversarial (discriminator) losses</center>

### 1. Import the required libraries

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.utils import make_grid

import os
import random
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from pfiles.lpips import LPIPS
from pfiles.vqvae import VQVAE
from pfiles.discriminator import Discriminator

### 2. Define the device

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('device is:', device)

### 3. Set the random seed and hyperparameters

In [None]:
# Seed every random-number generator the notebook relies on so that runs
# (batch shuffling, weight initialisation) are reproducible.
seed = 765

for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

if device == 'cuda':
    # Explicitly (re)seed the CUDA generators as well: current device and all devices.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
num_epochs = 20          # number of passes over the training indices
rgb_input = 3            # number of (RGB) input image channels
select_batch_size = 16   # target number of samples per batch

In [None]:
# Relative weights of the individual loss terms used in the training loop.
codebook_weight = 1        # VQ codebook loss
commitment_beta = 0.2      # VQ commitment loss
perceptual_weight = 1      # LPIPS perceptual loss
disc_weight = 0.5          # adversarial (generator and discriminator) losses

# Optimisation schedule.
disc_step_start = 1000     # adversarial losses are only used once step_count exceeds this
acc_steps = 1              # batches whose gradients are accumulated per optimizer step

# Global step counter, incremented once per batch by the training loop.
step_count = 0

### 4. Load the dataset and create the train/test split

In [None]:
dir_src = '/project/dsc-is/nono/Documents/kpc/dat0'
data_src = 'slice128_Block2_11K.npy'

dataset_file = os.path.join(dir_src, data_src)
print(dataset_file)

# Keep only index 0 along axis 1 of the stored array; the remaining four axes
# are unpacked below as (N_SAMPLE, HEIGHT, WIDTH, CHANNELS).
kpc_dataset = np.load(dataset_file)[:, 0, :, :, :]

print(kpc_dataset.shape)
N_SAMPLE, HEIGHT, WIDTH, CHANNELS = kpc_dataset.shape

In [None]:
# Cut the sample indices into 11 contiguous, near-equal chunks: the last chunk
# is held out for testing, every other index (sorted) is used for training.
index_range = np.arange(N_SAMPLE)
split = np.array_split(index_range, 11)
test_dataset = split[-1]
training_dataset = np.setdiff1d(index_range, test_dataset)

In [None]:
# Report how many sample indices ended up in each subset.
for message, subset in (
    ('Length of the training dataset:', training_dataset),
    ('Length of the test dataset:', test_dataset),
):
    print(message, len(subset))

### 5. Custom functions for extracting batches of samples from the dataset

In [None]:
def make_batch_list(idx, n_batch=10, batch_size=None, shuffle=True):
    """Split an array of sample indices into a list of index batches.

    Args:
        idx: 1-D array (or sequence) of dataset indices to partition. It is
            NOT modified; when ``shuffle`` is True a shuffled copy is split.
        n_batch: Number of batches to produce. Ignored when ``batch_size``
            is given.
        batch_size: Target batch size. When set, the number of batches is
            ``len(idx) // batch_size`` (but at least 1). Leftover samples are
            spread over the batches by ``np.array_split``, so batches can be
            larger than ``batch_size``.
        shuffle: Randomly permute the indices (NumPy global RNG) first.

    Returns:
        List of 1-D index arrays whose concatenation is the (possibly
        shuffled) ``idx``.
    """
    if shuffle:
        # permutation() shuffles a copy, so the caller's array (e.g. the
        # global training index list) is left untouched.
        idx = np.random.permutation(idx)
    if batch_size is not None:
        # array_split raises ValueError for 0 sections, which would happen
        # whenever fewer than batch_size indices are available.
        n_batch = max(1, len(idx) // batch_size)
    return np.array_split(idx, n_batch)

In [None]:
transform = transforms.ToTensor()

def generate_batch(idx, kpc_dataset):
    """Build one batch tensor from the samples at positions ``idx``.

    Each ``kpc_dataset[i]`` is converted with ``transform`` (torchvision
    ``ToTensor``), and the per-sample tensors are stacked along a new leading
    batch dimension.
    """
    samples = [transform(kpc_dataset[i]) for i in idx]
    return torch.stack(samples, dim=0)

### 6. Set up the directory for saving model checkpoints

In [None]:
# Directory (relative to the working directory) that receives the checkpoints.
task_name = 'kpc_ldm'

# makedirs(exist_ok=True) is idempotent when the cell is re-run and avoids
# the check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(task_name, exist_ok=True)

### 7. Instantiate the `VQVAE` autoencoder, the `LPIPS` perceptual model, and the `Discriminator`

In [None]:
# Autoencoder being trained; use the channel-count hyperparameter rather than
# repeating the literal 3 so the two stay in sync.
model = VQVAE(im_channels=rgb_input).to(device)
model.train()

# LPIPS is used only as a fixed loss function and is never given to an
# optimizer, so freeze its weights: backward passes through it then do not
# allocate/accumulate unused .grad tensors on them.
# (assumes LPIPS is an nn.Module, as its .eval()/.to() usage suggests)
lpips_model = LPIPS().eval().to(device)
lpips_model.requires_grad_(False)

discrim = Discriminator(im_channels=rgb_input).to(device)
discrim.train()

# setting up additional hyperparameters
recon_criterion = nn.MSELoss()
# MSE against real/fake targets, i.e. a least-squares GAN objective.
disc_criterion = nn.MSELoss()

optimizer_g = optim.Adam(model.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discrim.parameters(), lr=0.0001, betas=(0.5, 0.999))

### 8. Train the autoencoder

In [None]:
for epoch_idx in range(num_epochs):
    batch_list = make_batch_list(training_dataset, batch_size=select_batch_size)

    # Per-epoch logs of each loss term, recorded BEFORE the 1/acc_steps
    # gradient-accumulation scaling so every logged term is on the same scale.
    recon_losses = []
    codebook_losses = []
    perceptual_losses = []
    disc_losses = []
    gen_losses = []
    losses = []

    optimizer_d.zero_grad()
    optimizer_g.zero_grad()

    for idx_tmp in tqdm(batch_list):
        step_count += 1
        # Adversarial terms are only enabled after disc_step_start steps.
        use_disc = step_count > disc_step_start

        im = generate_batch(idx_tmp, kpc_dataset).to(device)

        # ------------------- Generator (autoencoder) step -------------------
        output, z, quantize_losses = model(im)

        recon_loss = recon_criterion(output, im)
        recon_losses.append(recon_loss.item())
        codebook_losses.append(codebook_weight * quantize_losses['codebook_loss'].item())

        g_loss = (recon_loss
                  + codebook_weight * quantize_losses['codebook_loss']
                  + commitment_beta * quantize_losses['commitment_loss'])

        if use_disc:
            # Freeze D for the generator pass: otherwise g_loss.backward()
            # writes generator-objective gradients into D's .grad buffers and
            # optimizer_d.step() below would apply them to the discriminator.
            discrim.requires_grad_(False)
            disc_fake_pred = discrim(output)
            disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
            gen_losses.append(disc_weight * disc_fake_loss.item())
            g_loss = g_loss + disc_weight * disc_fake_loss

        lpips_loss = torch.mean(lpips_model(output, im))
        perceptual_losses.append(perceptual_weight * lpips_loss.item())
        g_loss = g_loss + perceptual_weight * lpips_loss

        losses.append(g_loss.item())
        # Scale exactly once for gradient accumulation (the LPIPS term was
        # previously divided by acc_steps twice).
        (g_loss / acc_steps).backward()

        # ------------------------ Discriminator step ------------------------
        if use_disc:
            discrim.requires_grad_(True)
            # Detach so the discriminator loss does not reach the autoencoder.
            disc_fake_pred = discrim(output.detach())
            disc_real_pred = discrim(im)

            disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))

            disc_loss = disc_weight * (disc_fake_loss + disc_real_loss) / 2
            disc_losses.append(disc_loss.item())
            (disc_loss / acc_steps).backward()

            if step_count % acc_steps == 0:
                optimizer_d.step()
                optimizer_d.zero_grad()

        if step_count % acc_steps == 0:
            optimizer_g.step()
            optimizer_g.zero_grad()

    # Flush gradients left over from a trailing, incomplete accumulation
    # window. When the last batch already triggered a step there is nothing
    # to apply, and stepping again could make a spurious update (e.g. Adam
    # momentum acting on zeroed gradients).
    if step_count % acc_steps != 0:
        optimizer_d.step()
        optimizer_d.zero_grad()
        optimizer_g.step()
        optimizer_g.zero_grad()

    if len(disc_losses) > 0:
        print('Finished epoch: {} | Recon loss: {:.4f} | Perceptual loss: {:.4f} | Codebook: {:.4f} | G loss: {:.4f} | '
              'D loss: {:.4f}'.format(epoch_idx + 1, np.mean(recon_losses), np.mean(perceptual_losses),
                                                     np.mean(codebook_losses), np.mean(gen_losses), np.mean(disc_losses)))

    else:
        print('Finished epoch: {} | Recon loss: {:.4f} | Perceptual loss: {:.4f} | Codebook: {:.4f}'.format(epoch_idx + 1,
                                                     np.mean(recon_losses), np.mean(perceptual_losses),
                                                     np.mean(codebook_losses)))

print('Done training...')

### 9. Save the trained autoencoder

In [None]:
# Persist the autoencoder's weights (state_dict only) inside the task directory.
checkpoint_path = os.path.join(task_name, 'vqvae_autoencoder_ckpt.pth')
torch.save(model.state_dict(), checkpoint_path)