# <center>This `.ipynb` file contains the code for generating samples by utilizing latent features of the training images as a conditioning mechanism</center>

### 1. Import the required libraries

In [None]:
import torch
from torch import nn

from torchvision import transforms

import sys
import os
import random
import numpy as np

import matplotlib.pyplot as plt
from tqdm import tqdm

sys.path.insert(0, '..')
from pfiles.unet_cond_base import UNet
from pfiles.vqvae import VQVAE
from pfiles.linear_noise_scheduler import LinearNoiseScheduler

### 2. Define the device

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device is:', device)

### 3. Set different hyperparameters

In [None]:
# Fix every random number generator used by this notebook so runs are repeatable.
seed = 765

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# `device` is a torch.device, not a string: comparing it with 'cuda' directly
# is never true, so the device *type* has to be checked instead.
if device.type == 'cuda':
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
# Settings passed to LinearNoiseScheduler further down.
num_timesteps = 1000  # number of diffusion steps
beta_start = 1e-4     # first value of the beta schedule
beta_end = 2e-2       # last value of the beta schedule

In [None]:
# Batch size used when walking through the training set for conditioning.
select_batch_size = 1
# Channel count passed to VQVAE(im_channels=...).
rgb_input = 3
# Channel count of the latent the UNet and the classifier operate on.
z_channels = 16
# Number of clusters (classifier outputs / UNet conditioning classes).
# change it to 10, 11, 12, 13, 15, or 16 for other partitions
n_clusters = 14

# Side length used to derive the latent size in cond_sample.
image_size = 128
# Number of images drawn per cond_sample call.
num_samples = 1

### 4. Load the dataset

In [None]:
# Location of the .npy file holding the image slices.
dir_src = '/project/dsc-is/nono/Documents/kpc/dat0'
data_src = 'slice128_Block2_11K.npy'
data_path = os.path.join(dir_src, data_src)

print(data_path)

# Keep only index 0 along the second axis of the stored 5-D array; the four
# remaining axes are unpacked below as (sample, height, width, channel).
kpc_dataset = np.load(data_path)[:, 0, :, :, :]

print(kpc_dataset.shape)
N_SAMPLE, HEIGHT, WIDTH, CHANNELS = kpc_dataset.shape

In [None]:
# Cut the sample indices into 11 contiguous folds: the last fold is held out
# as the test set and every other index belongs to the training set.
index_range = np.arange(N_SAMPLE)
split = np.array_split(index_range, 11)
test_dataset = split[-1]
training_dataset = np.setdiff1d(index_range, test_dataset)

In [None]:
# Report how many sample indices ended up in each split.
print(f'Length of the training dataset: {len(training_dataset)}')
print(f'Length of the test dataset: {len(test_dataset)}')

### 5. Custom functions for extracting batches of samples from the dataset

In [None]:
def make_batch_list(idx, n_batch=10, batch_size=None, shuffle=True):
    """Split a 1-D collection of sample indices into a list of batches.

    Args:
        idx: 1-D array-like of sample indices. It is never modified.
        n_batch: Number of batches to produce when ``batch_size`` is None.
        batch_size: Target batch size. When given it overrides ``n_batch``:
            the number of batches becomes ``len(idx) // batch_size`` (at
            least 1). ``np.array_split`` spreads any remainder over the
            leading batches, so some batches may hold more than
            ``batch_size`` items.
        shuffle: When True the indices are randomly permuted (using NumPy's
            global random state) before splitting.

    Returns:
        A list of NumPy arrays that together contain every element of ``idx``.
    """
    if shuffle:
        # Permute a copy so the caller's index array keeps its order.
        idx = np.random.permutation(idx)
    if batch_size is not None:
        # Guard against 0 sections (batch_size > len(idx)), which would make
        # np.array_split raise.
        n_batch = max(1, len(idx) // batch_size)
    return np.array_split(idx, n_batch)

In [None]:
# Shared torchvision transform applied to every sample before batching.
transform = transforms.ToTensor()


def generate_batch(idx, kpc_dataset):
    """Build one batch tensor from the samples selected by ``idx``.

    Args:
        idx: Iterable of indices into ``kpc_dataset``.
        kpc_dataset: Indexable collection of images; each ``kpc_dataset[i]``
            is passed through the module-level ``transform``
            (presumably an H x W x C array - confirm against the data file).

    Returns:
        A tensor with the transformed samples stacked along a new leading
        dimension, in the order given by ``idx``.
    """
    samples = [transform(kpc_dataset[sample_idx]) for sample_idx in idx]
    return torch.stack(samples, dim=0)

### 6. Set up directory for saving models

In [None]:
# Directory the model checkpoints are read from / written to.
task_name = 'models_14'

# Create the directory if needed; exist_ok avoids the check-then-create race
# and makes re-running the cell harmless.
os.makedirs(task_name, exist_ok=True)

### 7. Neural network for deep learning-based clustering

In [None]:
class Classifier(nn.Module):
    """Convolutional network mapping a latent feature map to cluster scores.

    The network takes an input with ``z_channels`` channels, applies three
    stride-2 convolutions (the first two followed by batch normalisation) and
    a final 4x4 convolution that outputs ``n_clusters`` channels. Every
    convolution, including the last one, is followed by a LeakyReLU.
    """

    def __init__(self):
        super().__init__()

        # (name, layer) pairs in execution order. The names are kept stable
        # because they are part of the parameter names stored in checkpoints.
        stages = (
            ('conv1', nn.Conv2d(in_channels=z_channels, out_channels=128, kernel_size=4, stride=2, padding=1)),
            ('bnor1', nn.BatchNorm2d(num_features=128, affine=True, track_running_stats=True)),
            ('lrel1', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
            ('conv2', nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1)),
            ('bnor2', nn.BatchNorm2d(num_features=128, affine=True, track_running_stats=True)),
            ('lrel2', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
            ('conv3', nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1)),
            ('lrel3', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
            ('conv4', nn.Conv2d(in_channels=128, out_channels=n_clusters, kernel_size=4, stride=1, padding=0)),
            ('lrel4', nn.LeakyReLU(negative_slope=0.1, inplace=True)),
        )

        self.classifier = nn.Sequential()
        for stage_name, stage in stages:
            self.classifier.add_module(stage_name, stage)

    def forward(self, lat):
        """Run the latent ``lat`` through the layer stack and return the scores."""
        return self.classifier(lat)

### 8. Instantiate the `UNet`, `VQVAE`, and `Classifier` architectures

In [None]:
# Conditional UNet that predicts the noise in latent space.
model = UNet(im_channels=z_channels, cls=n_clusters).to(device)

# Restore the fine-tuned weights, then switch to inference mode.
unet_ckpt_path = os.path.join(task_name, 'unet_finetuning2_ckpt_20250128_70_14.pth')
model.load_state_dict(torch.load(unet_ckpt_path, map_location=device, weights_only=True))
model.eval()

# Report success only after the checkpoint has actually been loaded.
print('Loaded unet finetuning2 checkpoint')

In [None]:
# VQ-VAE used to encode training images to latents and decode samples back.
vq_vae = VQVAE(im_channels=rgb_input).to(device)

# Restore the pretrained autoencoder weights, then switch to inference mode.
vqvae_ckpt_path = os.path.join('../kpc_ldm', 'vqvae_autoencoder_ckpt.pth')
vq_vae.load_state_dict(torch.load(vqvae_ckpt_path, map_location=device, weights_only=True))
vq_vae.eval()

# Report success only after the checkpoint has actually been loaded.
print('Loaded vq_vae checkpoint')

In [None]:
# Classifier that assigns a cluster label to each encoded training image.
model_cl = Classifier().to(device)

# Restore the fine-tuned weights, then switch to inference mode so the
# BatchNorm layers use their stored running statistics.
classifier_ckpt_path = os.path.join(task_name, 'classifier_finetuning2_ckpt_20250128_70_14.pth')
model_cl.load_state_dict(torch.load(classifier_ckpt_path, map_location=device, weights_only=True))
model_cl.eval()

# Report success only after the checkpoint has actually been loaded.
print('Loaded model_cl finetuning2 checkpoint')

### 9. Custom function to conditionally generate samples

In [None]:
def cond_sample(model, scheduler, vq_vae, cond_input=None):
    """Generate images by reverse diffusion in latent space.

    Starting from Gaussian noise, the function runs ``num_timesteps`` reverse
    steps with ``model`` as the noise predictor and decodes the final latent
    with ``vq_vae``.

    Args:
        model: Noise-prediction network, called as ``model(xt, t, cond)``.
        scheduler: Object providing ``sample_prev_timestep(xt, noise_pred, t)``
            which returns the previous-step latent and an x0 estimate.
        vq_vae: Autoencoder whose ``decode`` maps the final latent to images.
        cond_input: Conditioning passed to ``model`` at every step. When
            omitted, the module-level ``training_cond_input`` is used, which
            is how existing callers supply it.

    Returns:
        The output of ``vq_vae.decode`` for the final latent.

    Notes:
        Reads the module-level settings ``image_size``, ``num_samples``,
        ``z_channels``, ``num_timesteps`` and ``device``. No gradient context
        is set here; callers are expected to wrap the call in
        ``torch.no_grad()``.
    """
    if cond_input is None:
        # Backward-compatible fallback to the global set by the caller's loop.
        cond_input = training_cond_input

    # The latent is 4x smaller than the image per side (two halvings) -
    # presumably matching the VQVAE's downsampling stages; TODO confirm.
    n_downsamples = 2
    im_size = image_size // 2 ** n_downsamples

    # Noise is drawn on the CPU and then moved, so the values depend only on
    # the CPU generator seeded via torch.manual_seed.
    xt = torch.randn((num_samples, z_channels, im_size, im_size)).to(device)

    # total= is given explicitly because a reversed range has no len(), which
    # would leave the progress bar without a percentage / ETA.
    for i in tqdm(reversed(range(num_timesteps)), total=num_timesteps):
        # One timestep index per sample in the batch.
        t = torch.full((xt.shape[0],), i, dtype=torch.long, device=device)
        noise_pred = model(xt, t, cond_input)

        # The x0 estimate returned by the scheduler is not needed here.
        xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))

    # Only the fully denoised latent (after step 0) is decoded to image space.
    return vq_vae.decode(xt)

### 10. Instantiate `linear` scheduler

In [None]:
# Noise scheduler built from the diffusion settings defined above; it is
# passed to cond_sample for the reverse-diffusion steps.
scheduler = LinearNoiseScheduler(
    num_timesteps=num_timesteps,
    beta_start=beta_start,
    beta_end=beta_end,
)

### 11. Prepare to generate images

In [None]:
# Walk through the training indices in their original order (no shuffling),
# grouped according to select_batch_size.
training_batch_list = make_batch_list(
    training_dataset,
    batch_size=select_batch_size,
    shuffle=False,
)

### 12. Generate images

In [None]:
# Folder the generated images are written to; created up front because
# plt.savefig does not create missing directories.
output_dir = 'gen_14'
os.makedirs(output_dir, exist_ok=True)

for i, batch_idx in enumerate(training_batch_list):
    # Everything below is inference only, so no autograd graph is needed for
    # the encoder, the classifier or the sampler.
    with torch.no_grad():
        training_xxx = generate_batch(batch_idx, kpc_dataset).to(device)

        # Encode the training image(s) and let the classifier pick a cluster.
        training_im, _ = vq_vae.encode(training_xxx)
        training_out_cl = model_cl(training_im)

        # Kept at module level on purpose: cond_sample reads this name as its
        # conditioning input.
        training_cond_input = torch.argmax(training_out_cl.reshape((-1, n_clusters)), dim=1)

        # The conditional UNet is bound to `model` above; there is no
        # `cond_unet` name in this notebook.
        gen_ims = cond_sample(model=model, scheduler=scheduler, vq_vae=vq_vae)

    # Clip to the displayable [0, 1] range and move to the CPU for plotting.
    gen_ims = torch.clamp(gen_ims, min=0., max=1.).detach().cpu()
    print(i+1)

    fig = plt.figure(figsize=(1, 1))
    # squeeze() drops the batch axis (num_samples is 1); permute reorders the
    # remaining axes from channel-first to channel-last for imshow.
    plt.imshow(gen_ims.squeeze().permute(1, 2, 0))
    plt.axis(False)
    plt.savefig(os.path.join(output_dir, f'{i+1}.jpg'), dpi=300, bbox_inches='tight', pad_inches=0)
    plt.show()
    # Release the figure so thousands of iterations do not accumulate figures.
    plt.close(fig)