## 1. Import the required libraries

In [None]:
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = 'False'
import sys
import datetime
import configparser

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import re
import random
from tqdm import tqdm
from torchinfo import summary

import torch
from torch import optim
from torch.backends import cudnn

import torchvision
from torchvision import transforms

sys.path.insert(0, '../MODULES')
from DENOISING_DIFFUSION_PYTORCH import Unet, GaussianDiffusion

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

## 2. Write device-agnostic code

In [None]:
# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print('Device is:', device)

## 3. Set a date stamp for naming saved files

In [None]:
# Local-time date stamp (YYYYMMDD) embedded in the names of saved artifacts.
stamp = f'{datetime.datetime.now():%Y%m%d}'
print(stamp)

## 4. Define hyperparameters

In [None]:
# All hyperparameters live in one config section. configparser stores every
# value as a string, so later cells convert them back to their real types.
config = configparser.ConfigParser()
config.read_dict({
    'diffusion_model': {
        'rand_seed': 76543,
        'n_pix': 128,
        'batch_size': 10,
        'learning_rate': 1e-4,
        'h_dim': 64,
        'dim_mults': [1, 2, 4, 8],
        'self_condition': False,
        'timesteps': 1000,
    },
})

# Echo the effective configuration into the notebook output.
config.write(sys.stdout)

## 5. Seed random number generators and read hyperparameters

In [None]:
# Seed every random number generator the notebook uses so runs are reproducible.
rand_seed = int(config['diffusion_model']['rand_seed'])

random.seed(rand_seed)
np.random.seed(rand_seed)
# torch.manual_seed seeds the CPU generator and all devices. The former extra
# torch.random.manual_seed call was the same function and has been dropped.
torch.manual_seed(rand_seed)

if device == 'cuda':
    # Redundant with torch.manual_seed above, but kept to be explicit.
    torch.cuda.manual_seed(rand_seed)
    torch.cuda.manual_seed_all(rand_seed)

# cuBLAS only gives deterministic matmuls with this workspace setting. It is
# read when the cuBLAS handle is created; no matmul has run at this point in
# the notebook, so setting it here is early enough. An existing user value wins.
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick non-deterministic kernels
# Bug fix: the original `torch.use_deterministic_algorithms = True` rebound the
# function to a bool instead of calling it, so determinism was never enabled.
# warn_only=True reports ops without a deterministic implementation instead of
# raising, so training does not abort on them.
torch.use_deterministic_algorithms(True, warn_only=True)

In [None]:
# Convert the string-valued config entries back to numbers.
# n_pix -> GaussianDiffusion image_size; h_dim -> Unet base width (dim).
n_pix         = config['diffusion_model'].getint('n_pix')
batch_size    = config['diffusion_model'].getint('batch_size')
learning_rate = config['diffusion_model'].getfloat('learning_rate')
h_dim         = config['diffusion_model'].getint('h_dim')

In [None]:
# dim_mults is stored in the config as the string form of a Python list,
# e.g. "[1, 2, 4, 8]". Pull out the bracketed contents and convert each entry
# to int. Raw strings avoid invalid-escape warnings for "\[" and "\s", which
# are SyntaxWarnings on Python 3.12+.
match = re.search(r'\[(.*)\]', config['diffusion_model']['dim_mults'])
if match is None:
    # Fail with a clear message instead of "'NoneType' is not subscriptable".
    raise ValueError(
        "dim_mults must be a bracketed list such as '[1, 2, 4, 8]', got "
        f"{config['diffusion_model']['dim_mults']!r}"
    )
dim_mults = [int(i) for i in re.split(r',\s*', match[1])]

In [None]:
# getboolean parses 'False' / 'no' / 'off' / '0' correctly. The previous
# bool(<str>) was True for any non-empty string, including 'False', so
# self-conditioning could never be switched off from the config.
self_condition = config['diffusion_model'].getboolean('self_condition')
timesteps      = int(config['diffusion_model']['timesteps'])

## 6. Load dataset

In [None]:
# Location of the pre-sliced training images (.npy array).
dir_src = '/project/dsc-is/nono/Documents/kpc/dat0'
data_src = 'slice128_Block2_11K.npy'

print(os.path.join(dir_src, data_src))

# The stored array is 5-D; keep only index 0 along its second axis.
pix_src = np.load(os.path.join(dir_src, data_src))[:, 0, :, :, :]
print(pix_src.shape)

# NOTE(review): presumably (samples, height, width, channels), i.e. channel-last
# as ToTensor expects in the batching helper -- confirm against the data.
n_sample, nx, ny, nc = pix_src.shape

In [None]:
# randomly show 20 images from pix_src

# Preview 20 samples drawn at random (with replacement) from pix_src on a
# 2 x 10 grid. This draws from NumPy's global RNG, so it shifts the state
# that the later batch shuffling starts from.
fig, axis = plt.subplots(nrows=2, ncols=10, figsize=(20, 4))

for ax in axis.ravel()[:20]:
    ax.imshow(pix_src[np.random.randint(len(pix_src))])
    ax.axis(False)

## 7. Instantiate Unet

In [None]:
# Denoising U-Net; width and depth come from the config, and the flash
# attention option is always enabled.
model = Unet(
    dim=h_dim,
    dim_mults=dim_mults,
    self_condition=self_condition,
    flash_attn=True,
)

## 8. Visualize Unet

In [None]:
# Derive the probe shapes from the configured batch size, the channel count of
# the loaded data and the configured image size, instead of the hard-coded
# (10, 3, 128, 128) that silently drifts when the config or data change.
# NOTE(review): the (batch_size,) input is presumably the per-sample diffusion
# time step expected by Unet.forward -- confirm.
summary(model, input_size=[(batch_size, nc, n_pix, n_pix), (batch_size,)])

## 9. Instantiate diffusion model

In [None]:
# Wrap the U-Net in the Gaussian diffusion process (linear beta schedule,
# noise-prediction objective) and move the whole module to `device`.
diffusion = GaussianDiffusion(
    model,
    image_size=n_pix,
    timesteps=timesteps,
    objective='pred_noise',
    beta_schedule='linear',
    # NOTE(review): with auto_normalize off, the input images presumably must
    # already be in the range the model expects -- confirm for pix_src.
    auto_normalize=False,
).to(device)

## 10. Set up optimizer

In [None]:
# AdamW (AMSGrad variant) over every parameter of the diffusion wrapper,
# which includes the U-Net.
optim1 = optim.AdamW(params=diffusion.parameters(), lr=learning_rate, amsgrad=True)

## 11. Custom class to record training history

In [None]:
class HistDict:
    """Column-oriented store of recorded metrics.

    Each registered key maps to a list of values. ``append`` adds values,
    ``mean`` summarizes them, and ``DataFrame`` / ``read_tsv`` round-trip the
    history through pandas.
    """

    def __init__(self, keys):
        """Register ``keys``, each starting with an empty value list."""
        # Copy so that later mutation of the caller's list cannot desynchronize
        # ``keys`` from ``values``.
        self.keys = list(keys)
        self.values = {kk: [] for kk in self.keys}

    def append(self, dict_hist):
        """Append ``dict_hist[k]`` to the list of every key ``k`` in ``dict_hist``.

        Raises:
            KeyError: if ``dict_hist`` contains a key that was not registered.
        """
        for kk, vv in dict_hist.items():
            self.values[kk].append(vv)

    def mean(self, keys=None):
        """Return ``{key: mean of its values, rounded to 6 decimals}``.

        Args:
            keys: keys to summarize; defaults to all registered keys.
        """
        if keys is None:
            keys = self.keys
        return {kk: np.round(np.mean(self.values[kk]), 6) for kk in keys}

    def __getitem__(self, key):
        """Return the list of values recorded for ``key``."""
        return self.values[key]

    def DataFrame(self):
        """Return the history as a pandas DataFrame with one column per key."""
        return pd.DataFrame.from_dict(self.values)

    def read_tsv(self, filepath):
        """Replace the history with the contents of a tab-separated file.

        The file's first column is dropped. This matches files written by
        ``DataFrame().to_csv(path, sep='\\t')``, whose first column is the index.
        """
        tmp = pd.read_csv(filepath, delimiter='\t').iloc[:, 1:]
        dict_tmp = tmp.to_dict(orient='list')
        # Replace rather than merge. Keys absent from the file would otherwise
        # keep stale values, and DataFrame() would fail on mismatched lengths.
        self.keys = list(dict_tmp.keys())
        self.values = dict_tmp

## 12. Custom function to create a list of batches

In [None]:
def make_batch_list(idx, n_batch=10, batch_size=None, shuffle=True):
    """Split sample indices into a list of batches.

    Args:
        idx: 1-D array (or sequence) of sample indices. It is not modified.
        n_batch: number of batches to produce; ignored when ``batch_size`` is given.
        batch_size: target batch size. The batch count becomes
            ``len(idx) // batch_size`` (at least 1). ``np.array_split``
            spreads the remainder, so some batches hold extra indices.
        shuffle: if True, shuffle the indices with NumPy's global RNG first.

    Returns:
        List of NumPy arrays, one per batch.
    """
    if shuffle:
        # permutation returns a shuffled copy instead of reordering the
        # caller's array in place. For 1-D input it consumes the same RNG
        # draws as np.random.shuffle, so the batch order is unchanged.
        idx = np.random.permutation(idx)
    if batch_size is not None:
        # Guard against 0 batches when len(idx) < batch_size, which
        # np.array_split rejects.
        n_batch = max(1, len(idx) // batch_size)
    return np.array_split(idx, n_batch)

## 13. Custom function to assemble a batch tensor from sample indices

In [None]:
# Per-sample conversion from a NumPy image to a torch tensor (torchvision
# ToTensor: HWC ndarray -> CHW tensor).
transform_pix = transforms.Compose([transforms.ToTensor()])

def generate_batch(idx, pix_src):
    """Build one batch tensor from the samples of ``pix_src`` at positions ``idx``.

    Each sample is passed through ``transform_pix``, and the results are
    stacked along a new leading batch dimension.
    """
    return torch.stack([transform_pix(pix_src[ii]) for ii in idx], dim=0)

## 14. Train DDPM

In [None]:
def train_diff(t_epoch, t_print, hist_tt=None):
    """Train the global ``diffusion`` model with ``optim1`` for ``t_epoch`` epochs.

    Each epoch reshuffles all ``n_sample`` indices of ``pix_src`` into batches
    of about ``batch_size``, takes one optimizer step per batch, and records the
    epoch's mean training loss.

    Args:
        t_epoch: number of epochs to run in this call.
        t_print: print the mean loss every ``t_print`` epochs; 0 or None
            disables the periodic print.
        hist_tt: optional HistDict with keys ``'tt'`` and ``'loss_trn'`` to
            extend, e.g. when resuming. A fresh one is created if None.

    Returns:
        The HistDict, with one ``tt`` / ``loss_trn`` entry appended per epoch.
    """
    print('Training starts at', datetime.datetime.now().strftime('%H:%M'), '(24-hour format)')
    key_trn = ['loss_trn']

    if hist_tt is None:
        hist_tt = HistDict(['tt'] + key_trn)

    # Continue epoch numbering after any history already recorded, so resumed
    # runs do not restart 'tt' at 0 and produce duplicate epoch indices.
    tt_offset = len(hist_tt['tt'])

    for tt in range(t_epoch):
        diffusion.train()
        idx_trn = np.arange(n_sample)
        batch_list = make_batch_list(idx_trn, batch_size=batch_size)
        hist_batch = HistDict(key_trn)

        for idx_tmp in tqdm(batch_list):
            xxx_tmp = generate_batch(idx_tmp, pix_src).to(device)
            # diffusion(x) returns a scalar loss; it is backpropagated below.
            loss_tmp = diffusion(xxx_tmp)

            optim1.zero_grad()
            loss_tmp.backward()
            optim1.step()

            hist_batch.append({'loss_trn': loss_tmp.item()})

        hist_trn = hist_batch.mean()
        hist_tt.append({'tt': tt_offset + tt})
        hist_tt.append(hist_trn)

        # Progress is reported relative to this call (1..t_epoch).
        if t_print and (tt + 1) % t_print == 0:
            print(f'Epoch: {(tt + 1)}/{t_epoch}','|','Training loss:', np.round(hist_trn['loss_trn'], 4))

    print('Training finishes at', datetime.datetime.now().strftime('%H:%M'), '(24-hour format)')
    return hist_tt

In [None]:
# Output directory for the checkpoint and history files, plus an empty
# history that records the epoch index and the training loss.
dir_save = '../P1'
key_trn = ['loss_trn']
hist_tt = HistDict(keys=['tt', *key_trn])

In [None]:
# Run 150 epochs, printing the mean training loss every 10 epochs.
t_epoch, t_print = 150, 10

hist_tt = train_diff(t_epoch, t_print, hist_tt=hist_tt)

## 15. Save history and diffusion model

In [None]:
# File names encode the date stamp and the epoch count (t_epoch) of this run.
path_model = os.path.join(dir_save, f'model_ddpm.{stamp}.{t_epoch}.ckpt')
path_hist = os.path.join(dir_save, f'hist_ddpm.{stamp}.{t_epoch}.tsv')

In [None]:
# Create the output directory if needed. torch.save and to_csv do not create
# missing parent directories, and a failure here would come after the whole
# training run.
os.makedirs(dir_save, exist_ok=True)

# Only the wrapped U-Net's weights (diffusion.model) are saved -- not any state
# held by the GaussianDiffusion wrapper itself, and not the optimizer state.
print('saving', path_model)
torch.save(diffusion.model.state_dict(), path_model)
print('saving', path_hist)
# The DataFrame index is written as the first column; HistDict.read_tsv drops it.
hist_tt.DataFrame().to_csv(path_hist, sep='\t')