<a href="https://colab.research.google.com/github/serag-ai/I-SynMed/blob/main/Paper_GrayScale_Conditional_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

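This notebook trains a class-conditional denoising diffusion probabilistic model (DDPM) with a Swin-transformer denoising network on grayscale images. It is based on the paper at
https://iopscience.iop.org/article/10.1088/1361-6560/acca5c/meta
and on the implementation in the GitHub repository cloned below.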

# Installation & Imports

In [None]:
# Fetch the reference implementation; the `diffusion.*` and `network.*` modules imported below come from this checkout.
!git clone https://github.com/shaoyanpan/2D-Medical-Denoising-Diffusion-Probabilistic-Model-.git

In [None]:
# Work from inside the cloned repository so its local packages are importable.
%cd /content/2D-Medical-Denoising-Diffusion-Probabilistic-Model-

In [None]:
# Third-party dependencies: timm and einops are imported below; monai is presumably needed by the cloned repo's modules.
# NOTE(review): versions are unpinned — pin them for reproducible runs.
!pip install timm
!pip install monai
!pip install einops

Mount Google Drive (optional) if your training data or checkpoints are stored there.

In [None]:
# Optional: only needed when data or checkpoints live on Google Drive.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
import PIL
import time
import torch
import torchvision
import torch.nn.functional as F
from einops import rearrange
from torch import nn
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader
import glob
import scipy.io
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
from random import randint
import random
import time
import re
from scipy import ndimage
from skimage import io
from skimage import transform
from natsort import natsorted
from skimage.transform import rotate, AffineTransform
from timm.models.layers import DropPath, to_3tuple, trunc_normal_
import  torchvision.transforms as transforms
#The diffusion module adpated from https://github.com/openai/guided-diffusion
from diffusion.Create_diffusion import *
from diffusion.resampler import *
from diffusion.normal_diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10

from network.Diffusion_model_transformer import *
from torchvision.datasets import ImageFolder


# Set up hyperparameters

In [23]:
# Colab form cell: the "@markdown" / "@param" comments are rendered as an input form, so keep them intact.
# @markdown \
# @markdown # Image Hyperparameters
# @markdown \
image_size = 256 # @param {type:"integer"}
# NOTE(review): `spacing` is not read anywhere else in this notebook.
spacing = (1, 1) # @param {type:"raw"}
# Number of image channels; the data loader converts every image to grayscale.
channels = 1# @param {type:"integer"}

# @markdown \
# @markdown # Diffusion Hyperparameters
# @markdown \
diffusion_steps = 1000 # @param {type:"integer"}
# Passed to create_gaussian_diffusion; in the guided-diffusion convention a learned
# variance requires the network to output 2*channels maps instead of channels.
learn_sigma=True# @param {type:"boolean"}




# @markdown \
# @markdown # Network Parameters
# @markdown Enter your network parameters here: num_channels is the initial number of channels in each block; channel_mult holds the per-block channel multipliers (in this case 128,128,256,256,512,512 for the first to the sixth block); attention_resolutions means we use the transformer blocks in the third to the sixth block; num_heads and window_size set the number of heads and the window size in each transformer block.
# @markdown \
# @markdown \
num_channels=128 # @param {type:"integer"}
channel_mult = (1, 1, 2, 2, 4, 4)# @param {type:"raw"}
attention_resolutions="64,32,16,8" # @param {type:"string"}
num_heads=[4,4,4,8,16,16]# @param {type:"raw"}
window_size = [[4,4],[4,4],[4,4],[8,8],[8,8],[4,4]]# @param {type:"raw"}
num_res_blocks = [2,2,1,1,1,1]# @param {type:"raw"}
# NOTE(review): the trailing comma makes this a 1-tuple that wraps the five [2,2]
# kernels. Kept as-is because SwinVITModel may index it that way — confirm against
# network/Diffusion_model_transformer.py before removing the comma.
sample_kernel=([2,2],[2,2],[2,2],[2,2],[2,2]),
use_scale_shift_norm=True# @param {type:"boolean"}
resblock_updown = False# @param {type:"boolean"}
# Parse the comma-separated attention resolutions into a list of ints for the model.
attention_ds = []
for res in attention_resolutions.split(","):
    attention_ds.append(int(res))


# @markdown \
# @markdown # Training Parameters
# @markdown \
N_EPOCHS = 1000 # @param {type:"integer"}
# Multiplied by `channels` in the settings cell below.
BATCH_SIZE_TRAIN = 4 # @param {type:"integer"}
class_cond = True# @param {type:"boolean"}
# Should match the number of class sub-folders in training_dt_dir (ImageFolder labels).
NUM_CLASSES=4# @param {type:"integer"}
learning_rate = 5e-6# @param {type:"number"}
weight_decay =  1e-4# @param {type:"number"}
training_dt_dir = "/content/path_to_data"# @param {type:"string"}
# Must end with '/': output file names are appended by plain string concatenation.
output_dir = "/content/tmp/"# @param {type:"string"}


In [25]:
# Fixed settings for the diffusion process and the network. They do not need
# tuning for image synthesis, so leave them unchanged.
sigma_small = False
noise_schedule = 'linear'
use_kl = False
predict_xstart = False
rescale_timesteps = True
rescale_learned_sigmas = True
use_checkpoint = False
img_size = (image_size,) * 2
# NOTE: rescales the form value in place — re-running this cell without first
# re-running the hyperparameter cell multiplies the batch size again.
BATCH_SIZE_TRAIN *= channels
# Passed as `timestep_respacing` to create_gaussian_diffusion (presumably
# guided-diffusion style respacing of the `diffusion_steps` schedule to 50 steps).
timestep_respacing = [50]
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Build the Diffusion process

In [26]:
# Gaussian diffusion process built by the cloned repo's factory, plus the sampler
# used by train() to draw one training timestep (and its weight) per image.
diffusion = create_gaussian_diffusion(
    steps=diffusion_steps,
    noise_schedule=noise_schedule,
    timestep_respacing=timestep_respacing,
    learn_sigma=learn_sigma,
    sigma_small=sigma_small,
    rescale_learned_sigmas=rescale_learned_sigmas,
    predict_xstart=predict_xstart,
    rescale_timesteps=rescale_timesteps,
    use_kl=use_kl,
)
schedule_sampler = UniformSampler(diffusion)


# Build the network

In [27]:
# Transformer-based (Swin-ViT) denoising network. Following the guided-diffusion
# convention, a learned variance (learn_sigma=True) needs twice as many output
# channels as input channels; otherwise the output matches the input channels.
model = SwinVITModel(
        image_size=(image_size,image_size),
        in_channels=channels,
        model_channels=num_channels,
        out_channels=(channels * 2 if learn_sigma else channels),
        sample_kernel=sample_kernel,
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=0,
        channel_mult=channel_mult,
        # Class labels are only embedded when training a class-conditional model.
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=False,
        num_heads=num_heads,
        window_size = window_size,
        num_head_channels=64,
        num_heads_upsample=-1,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=False,
    ).to(device)


In [None]:
# # In case you want to use a CNN (U-Net) backbone instead of the Swin-ViT model:
# # uncomment this cell. It overwrites `model`; the hard-coded in_channels=1 /
# # out_channels=2 assume single-channel images with learn_sigma=True.
# from network.Diffusion_model_Unet import *
# model = UNetModel(
#         image_size=image_size,
#         in_channels=1,
#         model_channels=num_channels,
#         out_channels=2,
#         num_res_blocks=num_res_blocks[0],
#         attention_resolutions=tuple(attention_ds),
#         dropout=0.,
#         sample_kernel=sample_kernel,
#         channel_mult=channel_mult,
#         num_classes=(NUM_CLASSES if class_cond else None),
#         use_checkpoint=False,
#         use_fp16=False,
#         num_heads=4,
#         num_head_channels=64,
#         num_heads_upsample=-1,
#         use_scale_shift_norm=use_scale_shift_norm,
#         resblock_updown=False,
#         use_new_attention_order=False,
#     ).to(device)

# Create the optimizer and get ready to train

In [None]:
# Report the model size, let cuDNN benchmark and pick the fastest kernels,
# and create the AdamW optimizer over all model parameters.
pytorch_total_params = sum(param.numel() for param in model.parameters())
print(f'parameter number is {pytorch_total_params}')
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Build the training function (one call = one training epoch)

In [29]:
# Here we explain the training process
def train(model, optimizer,data_loader1, loss_history):

    #1: set the model to training mode
    model.train()
    total_samples = len(data_loader1.dataset)
    loss_sum = []
    total_time = 0

    #2: Loop the whole dataset, x1 (traindata) is the image batch
    for i, (x1,labels) in enumerate(data_loader1):

        traindata = x1.to(device)
        trainlabel = labels.to(device)


        #3: extract random timestep for training
        t, weights = schedule_sampler.sample(traindata.shape[0], device)

        aa = time.time()

        #4: Optimize the TDM network
        optimizer.zero_grad()
        all_loss = diffusion.training_losses(model,traindata,t=t,model_kwargs={'y':trainlabel})
        loss = (all_loss["loss"] * weights).mean()
        loss.backward()
        loss_sum.append(loss.detach().cpu().numpy())
        optimizer.step()

        #5:print out the intermediate loss for every 100 batches
        total_time += time.time()-aa
        if i % 100 == 0:
            print('optimization time: '+ str(time.time()-aa))
            print('[' +  '{:5}'.format(i * BATCH_SIZE_TRAIN) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader1)) + '%)]  Loss: ' +
                  '{:6.7f}'.format(np.nanmean(loss_sum)))

    #6: print out the average loss for this epoch
    average_loss = np.nanmean(loss_sum)
    loss_history.append(average_loss)
    print("Total time per sample is: "+str(total_time))
    print('Averaged loss is: '+ str(average_loss))
    return average_loss

# Build the evaluation (sampling) functions

In [30]:
# evaluate() generates one sample per class (or 4 unconditional samples), shows them
# in a row and saves the raw arrays to `<path>test_example_epoch<epoch>.mat`.
num_sample = NUM_CLASSES if class_cond else 4
conditions = torch.arange(NUM_CLASSES).long().to(device)

def evaluate(model,epoch,path):
    """Sample ``num_sample`` images, plot them and save them as a MAT file.

    Args:
        model: denoising network to sample from.
        epoch: epoch number, used in the printout and the output file name.
        path: output directory prefix (must end with a path separator).
    """
    model.eval()
    start = time.time()
    # Labels are only valid for a class-conditional model.
    model_kwargs = {'y': conditions} if class_cond else {}
    with torch.no_grad():
        x_clean = diffusion.p_sample_loop(model, (num_sample, channels, image_size, image_size),
                                          clip_denoised=True, model_kwargs=model_kwargs)
    samples = x_clean.cpu().numpy()
    print('Generate for the epoch #'+str(epoch)+' result:')
    # squeeze=False keeps a 2-D axes array so indexing also works for a single sample.
    fig, axarr = plt.subplots(1, num_sample, figsize=(20, 20), squeeze=False)
    for ind in range(num_sample):
        axarr[0, ind].imshow(samples[ind, 0], 'gray')
    plt.show()
    print(str(time.time() - start))
    # Stored as a one-element list to keep the original MAT file layout.
    scipy.io.savemat(path + 'test_example_epoch' + str(epoch) + '.mat', {"img": [samples]})

In [31]:
num_sample = NUM_CLASSES if class_cond else 4
conditions = torch.arange(NUM_CLASSES).long().to(device)
def evaluate_fixed_noise(model,epoch,path,noise):
    """Sample from a given starting noise and save the results as one PNG grid.

    Passing the same ``noise`` every epoch starts sampling from the same initial
    state, which makes the saved grids easier to compare across epochs.

    Args:
        model: denoising network to sample from.
        epoch: epoch number, used in the output file name and the printout.
        path: output directory prefix (must end with a path separator).
        noise: starting noise of shape ``(num_sample, channels, image_size, image_size)``.
    """
    model.eval()
    # Labels are only valid for a class-conditional model.
    model_kwargs = {'y': conditions} if class_cond else {}
    with torch.no_grad():
        x_clean = diffusion.p_sample_loop(model, (num_sample, channels, image_size, image_size),
                                          clip_denoised=True, model_kwargs=model_kwargs, noise=noise)

    # All samples on a single row (one per class when class-conditional).
    grid = torchvision.utils.make_grid(x_clean, nrow=num_sample)
    torchvision.utils.save_image(grid, path + 'test_example_epoch' + str(epoch) + '.png', normalize=False)
    print('Generate for the epoch #'+str(epoch)+' result:')
    print('Generate for the epoch #'+str(epoch)+' result:')

# Start the training and testing

In [35]:

# Labelled training images (ImageFolder: one sub-folder per class), converted to
# grayscale, resized to image_size x image_size and turned into tensors.
training_set2 = ImageFolder(
    training_dt_dir,
    transform=transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
    ),
)
# A single dataset is used; to train on several, combine them here with
# torch.utils.data.ConcatDataset([training_set1, training_set2]).
train_dev_sets = training_set2


In [37]:
# Data loader parameters (the training set is reshuffled every epoch).
params = {'batch_size': BATCH_SIZE_TRAIN,
          'shuffle': True,
          'pin_memory': True,
          'drop_last': False}
train_loader1 = torch.utils.data.DataLoader(train_dev_sets, **params)
# Fixed starting noise for the periodic preview grids.
shape_ = (num_sample, channels, image_size, image_size)
fixed_noise = torch.randn(*shape_, device=device)

# Checkpoint, preview grids and loss curve are all written under `output_dir`.
path = output_dir
PATH = path + 'ViTRes1.pt'  # Use your own path
os.makedirs(path, exist_ok=True)
train_loss_history, test_loss_history = [], []

# Uncomment this when you resume the checkpoint
#model.load_state_dict(torch.load("/content/drive/MyDrive/mixed_DDPM_transformer/ViTRes1_26.pt"),strict=False)

# Preview and checkpoint every SAVE_EVERY epochs, and always after the last epoch
# so the final weights are never lost.
SAVE_EVERY = 5
train_loss = np.array([])
training_start = time.time()

for epoch in range(0, N_EPOCHS):
    print('Epoch:', epoch)
    start_time = time.time()

    average_loss = train(model, optimizer, train_loader1, train_loss_history)
    train_loss = np.append(train_loss, average_loss)
    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
    if epoch % SAVE_EVERY == 0 or epoch == N_EPOCHS - 1:
        evaluate_fixed_noise(model, epoch, path, noise=fixed_noise)
        print('Save the latest model')
        torch.save(model.state_dict(), PATH)
        # Also persist the loss curve so it survives an interrupted run.
        np.save(path + 'loss.npy', train_loss)
print('Total execution time:', '{:5.2f}'.format(time.time() - training_start), 'seconds')
np.save(path + 'loss.npy', train_loss)

# **Sampling**

In [None]:
# Reload the checkpoint written by the training loop (output_dir + 'ViTRes1.pt').
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
model.load_state_dict(torch.load(output_dir + 'ViTRes1.pt', map_location=device))

In [None]:
# One output folder per class label; exist_ok makes the cell safe to re-run.
for label in range(NUM_CLASSES):
    os.makedirs(f'/content/output_dir/{label}', exist_ok=True)

In [None]:
# Generate SAMPLES_PER_CLASS images for every class label, SAMPLING_ROUNDS times,
# saving each to /content/output_dir/<label>/<round>_<index>.png.
# Dedicated names keep the `num_sample` / `conditions` globals read by
# evaluate() and evaluate_fixed_noise() intact.
SAMPLES_PER_CLASS = 1
SAMPLING_ROUNDS = 1
model.eval()
for epoch in range(SAMPLING_ROUNDS):
    for label in range(NUM_CLASSES):
        os.makedirs(f'/content/output_dir/{label}', exist_ok=True)
        class_labels = torch.full((SAMPLES_PER_CLASS,), label, dtype=torch.long, device=device)
        # Labels are only valid for a class-conditional model.
        model_kwargs = {'y': class_labels} if class_cond else {}
        with torch.no_grad():
            x_clean = diffusion.p_sample_loop(model, (SAMPLES_PER_CLASS, channels, image_size, image_size),
                                              clip_denoised=True, model_kwargs=model_kwargs)
        for i, img in enumerate(x_clean):
            torchvision.utils.save_image(img, f'/content/output_dir/{label}/{epoch}_{i}.png', normalize=False)

        print(f"end of label: {label}")

    print(f"end of epoch: {epoch}")
