# <center>This `.ipynb` file contains the code for training the `cLDM` architecture</center>

### 1. Import the required libraries

In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import sys
import os
import datetime
import random

import numpy as np
from PIL import Image

from torchvision import datasets, transforms
from torchinfo import summary

sys.path.insert(0, '..')
from pfiles.unet_cond_base import UNet
from pfiles.vqvae import VQVAE
from pfiles.linear_noise_scheduler import LinearNoiseScheduler

### 2. Define a date stamp for naming the saved models

In [None]:
def timestamp():
    """Return the current local date formatted as ``YYYYMMDD``."""
    return datetime.datetime.now().strftime('%Y%m%d')

In [None]:
# Date stamp embedded in the checkpoint filenames written at the end of the notebook;
# the bare name on the last line echoes it as the cell output.
stmp = timestamp()
stmp

### 3. Define the device

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('device is:', device)

### 4. Custom functions for `marginal entropy`, `conditional entropy`, and `KL divergence`

In [None]:
# marginal entropy
def get_entropy_1D(xxx):
    return (-torch.sum(xxx * torch.log(xxx + 1e-8)))

# conditional entropy
def get_entropy_2D(xxx):
    """Return the entropy ``-sum_j p_j log p_j`` of each row of ``xxx`` (reduces dim 1).

    A 1e-8 offset inside the log keeps zero probabilities finite.
    """
    plogp = xxx * torch.log(xxx + 1e-8)
    return -plogp.sum(dim=1)

# KL divergence
def get_KLD_1D(ppp, qqq, batch_mean=True):
    """Return KL(p || q) = sum_j p_j * (log p_j - log q_j), computed per row.

    The sum runs over dim 1, so inputs are expected to be at least 2-D with one
    distribution per row. With ``batch_mean`` (default) the per-row values are
    averaged into a scalar; otherwise the per-row tensor is returned. 1e-8 is
    added inside both logs to avoid ``log(0)``.
    """
    log_p = torch.log(ppp + 1e-8)
    log_q = torch.log(qqq + 1e-8)
    per_row = torch.sum(ppp * log_p - ppp * log_q, dim=1)
    return torch.mean(per_row) if batch_mean else per_row

### 5. Set different hyperparameters

In [None]:
seed = 765

# Seed Python's, NumPy's and PyTorch's generators so runs are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

# On the GPU, also seed the CUDA generators (current device, then all devices).
if device == 'cuda':
    for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

In [None]:
# Diffusion schedule settings handed to LinearNoiseScheduler; num_timesteps is also
# the upper bound for the random timesteps drawn during training.
num_timesteps, beta_start, beta_end = 1000, 0.0001, 0.02

In [None]:
# Batch size for make_batch_list, image channels fed to the VQ-VAE, and channels
# of the latent consumed by the UNet and the clustering head.
select_batch_size, rgb_input, z_channels = 16, 3, 16
n_clusters = 14 # change it to 10, 11, 12, 13, 15, or 16 for other partitions

### 6. Load the dataset

In [None]:
dir_src = '/project/dsc-is/nono/Documents/kpc/dat0'
data_src = 'slice128_Block2_11K.npy'
dataset_path = os.path.join(dir_src, data_src)

print(dataset_path)

# The stored array is 5-D; keep only index 0 of axis 1. The remaining axes are
# unpacked below as (samples, height, width, channels).
kpc_dataset = np.load(dataset_path)[:, 0, :, :, :]

print(kpc_dataset.shape)
N_SAMPLE, HEIGHT, WIDTH, CHANNELS = kpc_dataset.shape

In [None]:
index_range = np.arange(N_SAMPLE)
# Cut the sample indices into 11 contiguous chunks: the last chunk is held out
# as the test set and everything else is used for training.
split = np.array_split(index_range, 11)
test_dataset = split[-1]
training_dataset = np.setdiff1d(index_range, test_dataset)

In [None]:
# Both index sets are 1-D arrays, so .size is the number of samples in each.
print('Length of the training dataset:', training_dataset.size)
print('Length of the test dataset:', test_dataset.size)

### 7. Custom functions for model metrics

In [None]:
class history():
    """Collect lists of scalar values per metric name and summarize them.

    ``keys`` fixes the tracked metric names and the order used by ``__str__``.
    """

    def __init__(self, keys):
        self.values = {name: [] for name in keys}
        self.keys = keys

    def append(self, dict_hist):
        """Append every value of ``dict_hist`` to the list stored under its key."""
        for name, value in dict_hist.items():
            self.values[name].append(value)

    def mean(self, keys=None):
        """Return ``{key: mean rounded to 4 decimals}`` for ``keys`` (default: all tracked keys)."""
        selected = self.keys if keys is None else keys
        return {name: np.round(np.mean(self.values[name]), 4) for name in selected}

    def __getitem__(self, key):
        """Return the raw list of values recorded for ``key``."""
        return self.values[key]

    def __str__(self):
        """Render ``key: mean`` pairs for all tracked keys, tab-separated."""
        means = self.mean(self.keys)
        return '\t'.join(name + ': ' + str(means[name]) for name in self.keys)

### 8. Custom functions for extracting batches of samples from the dataset

In [None]:
def make_batch_list(idx, n_batch=10, batch_size=None, shuffle=True):
    """Split an index array into a list of mini-batches.

    Args:
        idx: numpy array of sample indices. NOTE: when ``shuffle`` is true it is
            shuffled *in place* with ``np.random.shuffle`` (along its first axis),
            so the caller's array is reordered as a side effect.
        n_batch: number of batches to create; ignored when ``batch_size`` is given.
        batch_size: target batch size. The batch count becomes
            ``len(idx) // batch_size``, clamped to at least 1. ``np.array_split``
            spreads the remainder, so every batch holds at least ``batch_size``
            indices, except the single batch produced when ``idx`` has fewer than
            ``batch_size`` entries.
        shuffle: whether to shuffle ``idx`` before splitting.

    Returns:
        List of numpy arrays that together contain every index exactly once.
        Returns ``[]`` when ``batch_size`` is given and ``idx`` is empty.

    Raises:
        ValueError: if ``batch_size`` is given and is smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    if shuffle:
        np.random.shuffle(idx)
    if batch_size is not None:
        if len(idx) == 0:
            return []
        # When len(idx) < batch_size the integer division gives 0 sections, which
        # np.array_split rejects; emit one (smaller) batch instead.
        n_batch = max(1, len(idx) // batch_size)
    return np.array_split(idx, n_batch)

In [None]:
transform = transforms.ToTensor()

def generate_batch(idx, kpc_dataset):
    """Stack the samples at positions ``idx`` of ``kpc_dataset`` into one tensor.

    Each sample is passed through ``transform`` (``transforms.ToTensor``, which
    turns an HxWxC uint8 array into a CxHxW float tensor scaled to [0, 1]). The
    results are stacked along a new leading batch dimension.
    """
    return torch.stack([transform(kpc_dataset[i]) for i in idx], dim=0)

### 9. Apply transformations

In [None]:
np.round(kpc_dataset.mean(axis=(0, 1, 2)))  # rounded mean per channel over samples, height and width

In [None]:
ix, iy, nc = 128, 128, 3 # height, width, channels (sample shape used for an empty batch)

# Random small rotation / shift / rescale. ``fill`` is the RGB colour for pixels the
# warp uncovers -- presumably the rounded per-channel dataset mean; confirm against
# the per-channel mean cell above.
add_random_affine = transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=(130, 97, 154))

def generate_random_affine_batch(take_batch, data_src):
    """Return a randomly affine-augmented copy of the samples ``take_batch`` of ``data_src``.

    Each sample goes through ``PIL.Image.fromarray`` and ``add_random_affine``. The
    batch is then divided by 255 and returned as a float32 tensor laid out as
    (batch, channels, height, width). Assumes ``data_src`` holds uint8 HxWxC
    images, which ``Image.fromarray`` needs for RGB -- confirm for new datasets.

    The buffer shape is taken from the data itself rather than the hard-coded
    ``(ix, iy, nc)``, so images of any size work; ``(ix, iy, nc)`` is used only
    when ``take_batch`` is empty.
    """
    if len(take_batch) > 0:
        sample_shape = np.shape(data_src[take_batch[0]])
    else:
        sample_shape = (ix, iy, nc)

    tmp = np.empty((len(take_batch), *sample_shape))
    for a, i in enumerate(take_batch):
        # RandomAffine keeps the image size, so every result fits sample_shape.
        tmp[a] = add_random_affine(Image.fromarray(data_src[i]))
    return torch.tensor(tmp / 255.0, dtype=torch.float32).permute(0, 3, 1, 2)

### 10. Set up directory for saving models

In [None]:
# Output directory for the saved checkpoints. Derived from n_clusters (currently 14,
# i.e. 'models_14') so that changing the partition count does not write into
# another partition's folder.
task_name = f'models_{n_clusters}'

# exist_ok makes the cell safe to re-run. Unlike an exists()/mkdir() pair, this
# raises if a non-directory file already occupies the path instead of continuing
# until torch.save fails.
os.makedirs(task_name, exist_ok=True)

### 11. Instantiate the linear noise scheduler

In [None]:
# Noise scheduler whose add_noise() corrupts the latents during training.
scheduler = LinearNoiseScheduler(
    beta_start=beta_start,
    beta_end=beta_end,
    num_timesteps=num_timesteps,
)

### 12. Neural network for deep learning-based clustering

In [None]:
class Classifier(nn.Module):
    """Clustering head that maps a latent tensor to ``n_clusters`` score channels.

    Layer stack (kernel size 4 throughout, LeakyReLU(0.1) after every conv):
      conv1: z_channels -> 128, stride 2, padding 1, then BatchNorm
      conv2: 128 -> 128,        stride 2, padding 1, then BatchNorm
      conv3: 128 -> 128,        stride 2, padding 1
      conv4: 128 -> n_clusters, stride 1, padding 0
    For a 32x32 input the spatial size shrinks 32 -> 16 -> 8 -> 4 -> 1, giving an
    output of shape (batch, n_clusters, 1, 1). The final LeakyReLU acts on these
    scores directly. Sub-module names (conv1, bnor1, lrel1, ...) determine the
    state_dict keys of saved checkpoints, so they must stay unchanged.
    """

    def __init__(self):
        super().__init__()
        stages = [
            ('conv1', nn.Conv2d(in_channels=z_channels, out_channels=128, kernel_size=4, stride=2, padding=1)),
            ('bnor1', nn.BatchNorm2d(num_features=128, affine=True, track_running_stats=True)),
            ('lrel1', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
            ('conv2', nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1)),
            ('bnor2', nn.BatchNorm2d(num_features=128, affine=True, track_running_stats=True)),
            ('lrel2', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
            ('conv3', nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1)),
            ('lrel3', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
            ('conv4', nn.Conv2d(in_channels=128, out_channels=n_clusters, kernel_size=4, stride=1, padding=0)),
            ('lrel4', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
        ]
        self.classifier = nn.Sequential()
        for stage_name, stage in stages:
            self.classifier.add_module(stage_name, stage)

    def forward(self, lat):
        """Run the latent batch ``lat`` through the conv stack and return its scores."""
        return self.classifier(lat)

### 13. Visualize `Classifier` architecture

In [None]:
# batch_size, z_channels, latent_height, latent_width. The 32x32 latent size is what
# reduces the Classifier output to 1x1 -- assumed to match the VQ-VAE latent; confirm.
summary(Classifier(), input_size=(select_batch_size, z_channels, 32, 32))

### 14. Visualize `VQVAE` architecture

In [None]:
# batch_size, channels, height, width -- taken from the configured batch size,
# input channels and the loaded dataset's image size instead of literals.
summary(VQVAE(im_channels=rgb_input), input_size=(select_batch_size, rgb_input, HEIGHT, WIDTH))

### 15. Visualize `UNet` architecture

In [None]:
summary(UNet(im_channels=z_channels, cls=n_clusters),
        input_size=[(select_batch_size, z_channels, 32, 32), (select_batch_size,)])
# (batch_size, z_channels, latent_height, latent_width), (batch_size)

### 16. Instantiate the `UNet`, `VQVAE`, and `Classifier` architectures

In [None]:
# Denoising UNet trained in this notebook, conditioned on n_clusters classes.
model = UNet(im_channels=z_channels, cls=n_clusters).to(device)
model.train()

# Pre-trained VQ-VAE autoencoder, used only in eval mode; its weights come from a
# checkpoint on disk.
vq_vae = VQVAE(im_channels=rgb_input).to(device)
vq_vae.eval()
vq_vae.load_state_dict(torch.load(os.path.join('../kpc_ldm', 'vqvae_autoencoder_ckpt.pth'), map_location=device,
                                  weights_only=True))
# Print only after load_state_dict has returned, so the message never appears
# when loading fails.
print('Loaded vq_vae checkpoint')

In [None]:
# Trainable clustering head. Module.to() and Module.train() both return the module
# itself, so the calls can be chained; the last line echoes the model as cell output.
model_cl = Classifier().to(device).train()
model_cl

### 17. Prepare to train the `cLDM`

In [None]:
# Names of the loss terms logged per batch; the epoch history also tracks 'Epoch'.
key_loss = ['Loss', 'MSE', 'ME', 'CE', 'AF']
loss_hist = history(['Epoch', *key_loss])

# setting up additional hyperparameters
num_epochs = 600
learning_rate = 0.003
# A single optimizer updates both the UNet and the clustering head.
# NOTE(review): torch.optim.Adadelta defaults to lr=1.0, so 0.003 scales its
# updates down considerably -- confirm this value is intended.
optimizer = optim.Adadelta([*model.parameters(), *model_cl.parameters()], lr=learning_rate)
criterion = nn.MSELoss()

# Weights of the clustering terms in the training loss: marginal entropy (ME,
# subtracted), conditional entropy (CE, added), affine-consistency KL (AF, added).
l_me, l_ce, l_af = 0.1, 0.06, 0.04

### 18. Train the `cLDM`

In [None]:
# Re-seed every generator right before training so the run is reproducible even if
# earlier cells (which consume random numbers, e.g. for model init) were re-executed.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if device == 'cuda':
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# The VQ-VAE is used as a fixed encoder: freeze its weights (its forward passes
# below also run under torch.no_grad()).
for param in vq_vae.parameters():
    param.requires_grad = False

for epoch_idx in range(num_epochs):
    # Re-partition the training indices into batches each epoch. make_batch_list
    # shuffles training_dataset in place with np.random.shuffle.
    batch_list = make_batch_list(training_dataset, batch_size=select_batch_size)
    
    # Per-batch values of each loss term for this epoch; averaged into loss_hist below.
    loss_tt = history(key_loss)
    
    for idx_tmp in batch_list:
        optimizer.zero_grad()
        xxx_tmp = generate_batch(idx_tmp, kpc_dataset)
        im = xxx_tmp.to(device)

        # Encode the clean images; from here on `im` holds the VQ-VAE latents.
        with torch.no_grad():
            im, _ = vq_vae.encode(im)
    
        # Soft cluster assignments. NOTE(review): reshape((-1, n_clusters)) gives one
        # row per sample only if out_cl is spatially 1x1 (true for 32x32 latents) -- confirm.
        out_cl = model_cl(im)
        ppp_tmp = F.softmax(out_cl.reshape((-1, n_clusters)), dim=1)
        # Entropy of the batch-averaged assignment; subtracted in the loss, i.e. maximized.
        ppp_mean = torch.mean(ppp_tmp, dim=0, keepdim=True)
        entropy_marginal = get_entropy_1D(ppp_mean)
        # Mean per-sample assignment entropy; added in the loss, i.e. minimized.
        entropy_cond = torch.mean(get_entropy_2D(ppp_tmp))
        # Hard cluster label per sample, used as the UNet's conditioning input.
        # argmax is not differentiable, so the MSE term sends no gradient to model_cl.
        cond_input = torch.argmax(out_cl.reshape((-1, n_clusters)), dim=1)
        
        # Same samples after a random affine augmentation, encoded the same way.
        xxa_tmp = generate_random_affine_batch(idx_tmp, kpc_dataset)
        im_af = xxa_tmp.to(device)
        
        with torch.no_grad():
            im_af, _ = vq_vae.encode(im_af)
        
        # NOTE: model_cl is in train mode, so both calls per step update its BatchNorm running stats.
        out_cl_af = model_cl(im_af)
        ppa_tmp = F.softmax(out_cl_af.reshape((-1, n_clusters)), dim=1)
        
        # Augmentation-consistency term: batch-mean KL(original || augmented) assignments.
        loss_affine = get_KLD_1D(ppp_tmp, ppa_tmp)

        # Diffusion objective: noise the clean latents at random timesteps and
        # train the UNet, conditioned on cond_input, to predict that noise.
        noise = torch.randn_like(im).to(device)

        # Timesteps are drawn on the CPU generator and then moved to `device`.
        t = torch.randint(low=0, high=num_timesteps, size=(im.shape[0],)).to(device)

        noisy_im = scheduler.add_noise(im, noise, t)
        noise_pred = model(noisy_im, t, cond_input=cond_input)

        mse_loss = criterion(noise_pred, noise)
        # Total loss: MSE - l_me * H(marginal) + l_ce * H(conditional) + l_af * KL(affine).
        loss_tmp = mse_loss - l_me * entropy_marginal + l_ce * entropy_cond + l_af * loss_affine
        
        loss_tmp.backward()
        optimizer.step()

        loss_tt.append({'Loss': loss_tmp.item(), 'MSE': mse_loss.item(), 'ME': entropy_marginal.item(),
                        'CE': entropy_cond.item(), 'AF': loss_affine.item()})
    
    # Record the epoch number and the epoch-mean of every loss term.
    loss_hist.append({'Epoch': epoch_idx + 1})
    loss_hist.append(loss_tt.mean())
    
    print('Epoch:', epoch_idx + 1, '\t', str(loss_tt))
    
print('Done training...')

### 19. Save models after training

In [None]:
# Save the trained UNet and clustering-head weights into task_name. Filenames record
# the date stamp, the epoch count and the number of clusters.
for net, ckpt_name in (
    (model, f'unet_training_ckpt_{stmp}_{num_epochs}_{n_clusters}.pth'),
    (model_cl, f'classifier_training_ckpt_{stmp}_{num_epochs}_{n_clusters}.pth'),
):
    torch.save(net.state_dict(), os.path.join(task_name, ckpt_name))