# Install

In [None]:
!pip install monai
!pip install nibabel
!pip install imageio
!pip install natsort
!pip install lpips
!pip install kornia

# One-sided paired Medical Image Translation with Normalized Edge Priors

In [None]:
import nibabel as nib
import os, glob

import matplotlib.pyplot as plt
import numpy as np
import torch

from tqdm import tqdm
from monai.config import print_config
from monai.data import DataLoader, Dataset, CacheDataset, PatchDataset, SmartCacheDataset
from monai.inferers import SliceInferer,sliding_window_inference
from monai.utils import set_determinism, first
from monai.transforms import(
    Compose,
    Lambdad,
    LoadImaged,
    SaveImage,
    EnsureChannelFirstd,
    SqueezeDimd,
    RandSpatialCropSamplesd,
    ScaleIntensityRangePercentilesd,
    ScaleIntensityRanged,
    EnsureTyped,
    Resized,
    CropForegroundd,
    CenterSpatialCropd,
    RandZoomd
)
# print_config()
from datetime import date
# Run-wide settings: a compact date stamp, the compute device and the root
# folder for checkpoints.
today = date.today().isoformat().replace('-', '')  # e.g. '20240131'
gpu_device = torch.device('cuda:0')  # first CUDA device
weights_output_dir = 'Weights'

Get gamma images, define splits, slice them and write into a 2D folder

In [None]:
# Folder-name building blocks: one base name per image kind, plus the suffix
# that is appended to each of them when the 2D folder names are built.
MRs, CTs, MASKs = 'MR', 'CT', 'MASK'
suffix2d = '_2d_transformed'

# 2D processing

In [None]:
# Dictionary keys that every transform below operates on.
_KEYS_2D = ["SRC", "TGT", "MASK"]

# Per-sample pipeline: load the three files (image_only=False keeps the
# metadata next to the image), move the channel axis first, cast to float32.
transforms_2d = Compose(
    [
        LoadImaged(keys=_KEYS_2D, image_only=False),
        EnsureChannelFirstd(keys=_KEYS_2D),
        EnsureTyped(keys=_KEYS_2D, dtype=torch.float32),
    ]
)

# Data-loader settings shared by the cells below.
BATCH_SIZE, NUM_WORKERS = 14, 12

# Names of the folders holding the 2D slices (base name + common suffix).
# os.path.join with a single argument returns it unchanged, so plain
# concatenation yields the same strings.
MRs_sufx = MRs + suffix2d
CTs_sufx = CTs + suffix2d
MASKs_sufx = MASKs + suffix2d

# transforms_2d = Compose(
#     [
#         LoadImaged(keys=["SRC", "TGT"], image_only=False),
#         EnsureChannelFirstd(keys=["SRC", "TGT"]),
#         EnsureTyped(keys=["SRC", "TGT"], dtype=torch.float32),
#     ]
# )

# BATCH_SIZE=14
# NUM_WORKERS=12

# MRs_sufx = os.path.join(MRs + suffix2d)
# CTs_sufx = os.path.join(CTs + suffix2d)

- Train

In [None]:
# Training split: source (MR), target (CT) and mask slices are matched purely
# by their position in the sorted file lists.
fnames_train_A_2d = sorted(glob.glob(os.path.join(MRs_sufx, 'train_mr', '*.nii.gz')))
fnames_train_B_2d = sorted(glob.glob(os.path.join(CTs_sufx, 'train_ct', '*.nii.gz')))
fnames_train_C_2d = sorted(glob.glob(os.path.join(MASKs_sufx, 'train_mask', '*.nii.gz')))

# zip() stops at the shortest list without complaint, so a missing file would
# silently drop slices or pair an MR with the wrong CT/mask. Fail early instead.
if not fnames_train_A_2d:
    raise FileNotFoundError(
        'no training slices found under %s' % os.path.join(MRs_sufx, 'train_mr')
    )
if not (len(fnames_train_A_2d) == len(fnames_train_B_2d) == len(fnames_train_C_2d)):
    raise ValueError(
        'training file counts differ: %d MR, %d CT, %d MASK'
        % (len(fnames_train_A_2d), len(fnames_train_B_2d), len(fnames_train_C_2d))
    )

train_dic_2d = [{"SRC": img1, "TGT": img2, "MASK": img3} for (img1,img2,img3) in zip(
    fnames_train_A_2d,
    fnames_train_B_2d,
    fnames_train_C_2d
)]

# CacheDataset applies the transforms once and keeps the results in memory.
train_ds = CacheDataset(train_dic_2d, transforms_2d)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)


# fnames_train_A_2d = sorted(glob.glob(os.path.join(MRs_sufx, 'train_mr', '*.nii.gz')))
# fnames_train_B_2d = sorted(glob.glob(os.path.join(CTs_sufx, 'train_ct', '*.nii.gz')))
# train_dic_2d = [{"SRC": img1, "TGT": img2} for (img1,img2) in zip(
#     fnames_train_A_2d,
#     fnames_train_B_2d
# )]

# train_ds = CacheDataset(train_dic_2d, transforms_2d)
# train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

- Val

In [None]:
# Validation split: source (MR), target (CT) and mask slices are matched
# purely by their position in the sorted file lists.
fnames_val_A_2d = sorted(glob.glob(os.path.join(MRs_sufx, 'val_mr', '*.nii.gz')))
fnames_val_B_2d = sorted(glob.glob(os.path.join(CTs_sufx, 'val_ct', '*.nii.gz')))
fnames_val_C_2d = sorted(glob.glob(os.path.join(MASKs_sufx, 'val_mask', '*.nii.gz')))

# zip() stops at the shortest list without complaint, so a missing file would
# silently drop slices or pair an MR with the wrong CT/mask. Fail early instead.
if not fnames_val_A_2d:
    raise FileNotFoundError(
        'no validation slices found under %s' % os.path.join(MRs_sufx, 'val_mr')
    )
if not (len(fnames_val_A_2d) == len(fnames_val_B_2d) == len(fnames_val_C_2d)):
    raise ValueError(
        'validation file counts differ: %d MR, %d CT, %d MASK'
        % (len(fnames_val_A_2d), len(fnames_val_B_2d), len(fnames_val_C_2d))
    )

val_dic_2d = [{"SRC": img1, "TGT": img2, "MASK": img3} for (img1,img2,img3) in zip(
    fnames_val_A_2d,
    fnames_val_B_2d,
    fnames_val_C_2d
)]

# One sample per batch and no shuffling, so validation results are comparable
# from epoch to epoch.
val_ds = CacheDataset(val_dic_2d, transforms_2d)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)


# fnames_val_A_2d = sorted(glob.glob(os.path.join(MRs_sufx, 'val_mr', '*.nii.gz')))
# fnames_val_B_2d = sorted(glob.glob(os.path.join(CTs_sufx, 'val_ct', '*.nii.gz')))
# val_dic_2d = [{"SRC": img1, "TGT": img2} for (img1,img2) in zip(
#     fnames_val_A_2d,
#     fnames_val_B_2d
# )]

# val_ds = CacheDataset(val_dic_2d, transforms_2d)
# val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)


In [None]:
# Sanity check: pull one training batch and look at its first sample.
check_data = first(train_loader)
print("first patch's shape: ", check_data["SRC"].shape)


def _show_panel(slot, key, caption, lower):
    """Draw sample 0 / channel 0 of check_data[key] in column `slot` of a 1x4 grid.

    `lower` is the bottom of the grey-scale window; the top is always 1.
    Returns the matplotlib title object of the panel.
    """
    plt.subplot(1,4,slot)
    pixels = check_data[key][0,0,:,:].detach().cpu().numpy().squeeze()
    plt.imshow(pixels, vmin=lower, vmax=1, cmap="gray")
    return plt.title(caption)


plt.figure(figsize=(15,15))
_show_panel(1, "TGT", 'Target', -1)
_show_panel(2, "SRC", 'Source', -1)
_show_panel(3, "MASK", 'Mask', 0)


# Network training

In [None]:
from vjnetworks import Pix2Pix

In [None]:
class Options():
    """Hyper-parameters of the experiment, exposed as plain instance attributes."""

    def __init__(self):
        # model parameters (forwarded to the Pix2Pix constructor, except lambda_gan)
        model_settings = dict(
            in_channels=1,        # channels of the input image
            out_channels=1,       # channels of the generated image
            num_filters_d=128,    # number of filters in the discriminator
            num_layers_d=4,       # number of discriminator layers (i.e. the receptive field)
            num_d=2,              # passed to Pix2Pix as num_d
            num_res_units_G=10,   # passed to Pix2Pix as num_res_units_G
            lambda_gan=1,         # weight of the adversarial term in the generator loss
        )
        # training parameters
        training_settings = dict(
            num_epochs=200,       # upper bound of the epoch loop
            learning_rate=2e-4,   # learning rate handed to Adam
            lambda_identity=0,    # weight of the identity loss (0 disables it)
            lambda_bg=2.0,        # weight of the background term that pulls pixels outside the mask towards air
            lambda_NGF=20,        # weight of the NGF loss (best 20)
            alpha_NGF=.08,        # alpha argument of the NGF loss (best .15)
            lambda_l1=140,        # weight of the masked L1 loss (best 100 for pix2pix)
        )
        for group in (model_settings, training_settings):
            for name, value in group.items():
                setattr(self, name, value)


opt = Options()

In [None]:
import torch.optim
import torch.nn.functional as F
import monai.networks.nets as nets

# Network hyper-parameters are taken from `opt` and forwarded by keyword.
pix2pix_kwargs = {
    'in_channels': opt.in_channels,
    'out_channels': opt.out_channels,
    'num_d': opt.num_d,
    'num_layers_d': opt.num_layers_d,
    'num_filters_d': opt.num_filters_d,
    'num_res_units_G': opt.num_res_units_G,
}
pix2pix_model = Pix2Pix(**pix2pix_kwargs)

# A single Adam optimizer covers every parameter the model exposes.
optimizer = torch.optim.Adam(pix2pix_model.parameters(), lr=opt.learning_rate)

air = -1.0 # usually min value of normalizer

#### load checkpoints

In [None]:

# Restore the model weights from a previously saved checkpoint.
weights_path = 'Weights/Gamma_crop-n2-l4-f128_GAN1_L140.00_NGF20.00_a0.08_e0074.h5'

# map_location remaps the stored tensors onto the current device, so the file
# loads even if it was written from a different device.
checkpoint = torch.load(weights_path, map_location=gpu_device)

# strict=False tolerates key mismatches; report them instead of hiding them so
# a partially restored model does not go unnoticed.
load_result = pix2pix_model.load_state_dict(checkpoint['model'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('checkpoint/model key mismatch - missing: %s, unexpected: %s'
          % (list(load_result.missing_keys), list(load_result.unexpected_keys)))

# NOTE: this leaves the model in eval mode; call pix2pix_model.train() before
# resuming training.
pix2pix_model.to(gpu_device).eval()


#### suite

In [None]:
# The experiment name encodes the discriminator layout and the loss weights,
# so every configuration writes into its own folder.
architecture_tag = 'Gamma_crop-n%d-l%d-f%d' % (opt.num_d, opt.num_layers_d, opt.num_filters_d)
loss_tag = '_GAN%d_L%.2f_NGF%.2f_a%.2f' % (opt.lambda_gan, opt.lambda_l1, opt.lambda_NGF, opt.alpha_NGF)
EXPERIMENT_PREFIX = architecture_tag + loss_tag

# Checkpoints, validation previews and the CSV log all go here.
weights_dir = os.path.join(weights_output_dir, EXPERIMENT_PREFIX)
os.makedirs(weights_dir, exist_ok=True)

In [None]:
import imageio.v2 as imageio
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.exposure import match_histograms


In [None]:
import pandas as pd

# Per-epoch metrics are logged to a CSV file inside the experiment folder.
# The file name carries the same hyper-parameter tags as the folder.
csv_file = (
    'logs-n%d-l%d-f%d' % (opt.num_d, opt.num_layers_d, opt.num_filters_d)
    + '_GAN%d_L%.2f_NGF%.2f_a%.2f' % (opt.lambda_gan, opt.lambda_l1, opt.lambda_NGF, opt.alpha_NGF)
    + '.csv'
)
columns = ['epoch', 'psnr global','local psnr', 'loss 1', 'loss 2','G loss', 'D loss']

csv_path = os.path.join(weights_dir, csv_file)
if not os.path.exists(csv_path):
    # first run of this configuration: start an empty table and create the file
    df = pd.DataFrame(columns=columns)
    df.to_csv(csv_path, index=False)
else:
    # resumed run: continue the existing log
    df = pd.read_csv(csv_path)

def update_csv(epoch, psnr,lpsnr, loss1, loss2, gloss, dloss):
    """Append one epoch's metrics to the global log table and rewrite the CSV file.

    The seven arguments are rounded to 5 decimals and stored, in order, under
    the names listed in the module-level `columns`. The module-level `df` is
    replaced by the extended table, which is then written to
    `weights_dir/csv_file` without the index column.
    """
    global df
    values = (epoch, psnr, lpsnr, loss1, loss2, gloss, dloss)
    record = pd.DataFrame([[round(value, 5) for value in values]], columns=columns)
    df = pd.concat([df, record], ignore_index=True)
    df.to_csv(os.path.join(weights_dir, csv_file), index=False)

#### Sliding-window inference settings

In [None]:
# Settings handed to sliding_window_inference during validation:
# the 2D window size and the number of windows evaluated per forward pass.
sw_batch_size = BATCH_SIZE
roi_size = (256, 256)

In [None]:
import torch.nn.functional as F

def dilate2d(M, pixels=2):
    """Binarize a mask and grow its foreground by `pixels` in every direction.

    Values above 0.5 count as foreground. The dilation is a stride-1 max-pool
    with a square (2 * pixels + 1) window, padded so the spatial size is kept.

    Args:
        M: mask tensor in a layout accepted by F.max_pool2d.
        pixels: dilation radius in pixels.

    Returns:
        A float32 tensor of the same shape holding 0.0 / 1.0.
    """
    binary = M.gt(0.5).float()
    window = 2 * pixels + 1
    grown = F.max_pool2d(binary, kernel_size=window, stride=1, padding=pixels)
    return grown.gt(0.0).to(binary.dtype)

#### training

In [None]:
# Training / validation loop.
#
# Every epoch: one pass over train_loader optimising the generator and the
# discriminator, then a sliding-window validation pass that writes a preview
# PNG per validation sample and logs masked ("local") and full-image
# ("global") PSNR to the CSV file. Every 20 epochs the weights of the best
# epoch (by global PSNR) of that 20-epoch window are saved.
PSNRs = []      # mean global validation PSNR of every epoch run in this session
best_psnr = []  # the same values, restricted to the current 20-epoch window

# Resuming a previous run: the loop starts at epoch 76.
# Use `range(opt.num_epochs)` instead to train from scratch.
for epoch in np.arange(76, opt.num_epochs):
    # Validation (below) switches the generator to eval mode and the
    # checkpoint-loading cell leaves the whole model in eval mode, so training
    # mode has to be restored explicitly at the start of every epoch.
    pix2pix_model.train()

    loop = tqdm(train_loader)
    loop.set_description(f"Epoch [{epoch}/{opt.num_epochs}]")

    # During the first two epochs the target is the source itself
    # (auto-encoding warm-up); afterwards the real target image is used.
    real_B_is = "SRC" if epoch < 2 else "TGT"

    for batch_idx, real in enumerate(loop):
        # Transfer data to the device (CPU or GPU)
        real_A = real["SRC"].to(gpu_device)
        real_B = real[real_B_is].to(gpu_device)
        real_C = real["MASK"].to(gpu_device)
        real_C = dilate2d(real_C, pixels=2)

        # Forward pass
        fake_B, identity_B, pred_fake_B = pix2pix_model(real_A, real_B, is_training=True)
        generator_loss = 0.0

        # Terms that are switched off (lambda == 0) stay at zero, so the
        # progress bar and the CSV logging never hit an undefined name.
        adv_loss = torch.zeros((), device=gpu_device)
        NGF_loss = torch.zeros((), device=gpu_device)

        # Restrict the image losses to the (dilated) mask.
        real_Am = real_A*real_C
        real_Bm = real_B*real_C
        fake_Bm = fake_B*real_C
        identity_Bm = identity_B*real_C

        if opt.lambda_l1:
            l1_loss = pix2pix_model.compute_l1_loss(fake_Bm, real_Bm)
            generator_loss += opt.lambda_l1*l1_loss
        if opt.lambda_NGF:
            NGF_loss = pix2pix_model.compute_NGF_loss(fake_Bm, real_Am, opt.alpha_NGF)
            generator_loss += opt.lambda_NGF*NGF_loss
        if opt.lambda_gan:
            pred_fake_Bm = pix2pix_model.discriminator_B(fake_Bm)
            adv_loss = pix2pix_model.compute_adv_loss(pred_fake_Bm)
            generator_loss += opt.lambda_gan*adv_loss
        if opt.lambda_bg:
            # outside the mask the output is pulled towards the air value
            bg_loss = ((1.0 - real_C) * (fake_B - air).abs()).mean()
            generator_loss += opt.lambda_bg*bg_loss
        if opt.lambda_identity:
            identity_loss = pix2pix_model.compute_l1_loss(real_Bm, identity_Bm)
            generator_loss += opt.lambda_identity*identity_loss

        # discriminator (the fake is detached so this loss does not reach the generator)
        discriminator_loss = pix2pix_model.compute_discriminator_loss(real_Bm, fake_Bm.detach())

        # Backpropagation and optimization
        # NOTE(review): one optimizer steps generator and discriminator together,
        # and the adversarial term of generator_loss presumably also back-propagates
        # into the discriminator parameters - confirm this is intended.
        optimizer.zero_grad()
        generator_loss.backward()
        discriminator_loss.backward()
        optimizer.step()

        if opt.lambda_NGF:
            loop.set_postfix(losses=f"[ {(adv_loss * opt.lambda_gan).item():.2f} {(NGF_loss * opt.lambda_NGF).item():.5f}] ")

    # VALIDATION
    model = pix2pix_model.generator_A_to_B
    model.eval()

    with torch.no_grad():
        current_global_psnrs = []
        current_local_psnrs = []
        for i, real_val in enumerate(val_loader):
            real_A_val = real_val["SRC"].to(gpu_device)
            real_B_val = real_val["TGT"].to(gpu_device)
            real_C_val = real_val["MASK"].to(gpu_device)
            real_C_val = dilate2d(real_C_val, pixels=2)

            fake_B_val = sliding_window_inference(
                roi_size=roi_size,
                inputs=real_A_val,
                sw_batch_size=sw_batch_size,
                predictor=model,
                overlap=0.75,
                mode='gaussian',
                sigma_scale=0.5,
                device=gpu_device,
                padding_mode="replicate",
            )

            # saving val samples img
            plt.figure(figsize=(15,5))
            plt.subplot(1,3,1)
            plt.title("MR")
            plt.imshow(real_A_val[0,0,:,:,].cpu().detach().numpy().squeeze(), vmin=-1, vmax=1, cmap='gray')
            plt.subplot(1,3,2)
            plt.title("CT")
            plt.imshow(real_B_val[0,0,:,:,].cpu().detach().numpy().squeeze(), vmin=-1, vmax=1, cmap='gray')
            plt.subplot(1,3,3)
            plt.title("FAKE")
            plt.imshow(fake_B_val[0,0,:,:,].cpu().detach().numpy().squeeze(), vmin=-1, vmax=1, cmap='gray')
            # file name of the source slice without directory and extensions
            src_path = real_val['SRC_meta_dict']['filename_or_obj'][0]
            fprefix = os.path.basename(src_path).split('.')[0]
            fname_out = os.path.join(weights_dir, (fprefix+'_result_e%.3d.png' % epoch))
            plt.savefig(fname_out, bbox_inches='tight')
            plt.close()

            # local psnr on masked images: squared error averaged over mask pixels only,
            # with a data range of 2.0
            EPS = 1e-8
            num = ((fake_B_val - real_B_val)**2 * real_C_val).flatten(1).sum(1)
            den = real_C_val.flatten(1).sum(1).clamp_min(EPS)
            # named masked_mse so the imported skimage `mse` is not overwritten
            masked_mse = num / den
            psnr_masked = 20*torch.log10(torch.tensor(2.0, device=gpu_device)) - 10*torch.log10(masked_mse + EPS)
            val_psnr_masked = float(psnr_masked.mean())
            current_local_psnrs.append(val_psnr_masked)

            # global psnr over the whole image, one value per batch element
            P = fake_B_val.detach().cpu().numpy().astype(np.float32)
            G = real_B_val.detach().cpu().numpy().astype(np.float32)

            current_global_psnrs.extend([psnr(G[b, 0], P[b, 0], data_range=2.0) for b in range(P.shape[0])])

    PSNRs.append(np.mean(np.asarray(current_global_psnrs)))
    best_psnr.append(PSNRs[-1])
    print('average PSNR = %.3f' % PSNRs[-1])
    # the logged losses are those of the last training batch of the epoch
    update_csv(epoch, PSNRs[-1],np.mean(current_local_psnrs), (adv_loss * opt.lambda_gan).item(), (NGF_loss * opt.lambda_NGF).item(), generator_loss.item(), discriminator_loss.item())

    if PSNRs[-1] == max(best_psnr):
        # state_dict() hands out references to the live parameters; without a
        # copy the "best" weights would silently follow later training steps.
        model_weights = {
            name: tensor.detach().cpu().clone()
            for name, tensor in pix2pix_model.state_dict().items()
        }
        best_epoch = epoch

    if (epoch +1) %20 ==0:
        # close the current window and store its best epoch
        best_psnr = []
        torch.save({'model': model_weights}, os.path.join(weights_dir, EXPERIMENT_PREFIX+'_e%.4d.h5' % best_epoch))