In [1]:
import cv2 
import matplotlib.pyplot as plt
import torch
from glob import glob
import os
import numpy as np
import torch.nn as nn 
from torchsummary import summary

In [2]:
from torchvision.utils import save_image


In [3]:
from tqdm import tqdm

In [4]:
from disciminator_model import Discriminator
from generator_model import Generator

In [5]:
# model = Model(3,3).to('cuda')

In [6]:
class DataSet(torch.utils.data.Dataset):
    def __init__(self, path_train,path_test,imgz=256): 
        self.data_train = glob(path_train + '/*' + '.jpg')
        self.data_test = glob(path_test + '/*' + '.jpg')
        self.imgz = imgz
        
    def __len__(self): 
        return len(self.data_train)

    def __getitem__(self,idx):
        #### train #####
        img_train = cv2.imread(self.data_train[idx])
        # print(self.data_train[idx])
        h,w,c = img_train.shape
        img_train_resize = cv2.resize(img_train,(self.imgz,self.imgz))
        img_train_resize = img_train_resize /255
        img_train_resize = np.transpose(img_train_resize, (2, 0, 1))  
        img_train_resize = torch.tensor(img_train_resize, dtype=torch.float32)
        #### test #####
        
        img_test = cv2.imread(self.data_test[idx])
        # print(self.data_test[idx])
        h,w,c = img_test.shape
        img_test_resize = cv2.resize(img_test,(self.imgz,self.imgz))
        img_test_resize = img_test_resize /255
        img_test_resize = np.transpose(img_test_resize, (2, 0, 1))  
        img_test_resize = torch.tensor(img_test_resize, dtype=torch.float32)
        return img_train_resize,img_test_resize
        

In [7]:
        

def train_fn(disc_H,disc_Z,gen_Z,gen_H,loader,opt_disc,opt_gen,l1,mse,d_scaler,g_scaler):
    loop = tqdm(loader,leave=True)
    
    for idx,(ct,mri) in enumerate(loop):
        ct = ct.to('cuda')
        mri = mri.to('cuda')
        #### Decriminator
        with torch.cuda.amp.autocast():
            fake_ct = gen_H(mri)
            D_ct_real = disc_H(ct)
            D_ct_fake = disc_H(fake_ct.detach())
            D_ct_real_loss = mse(D_ct_real, torch.ones_like(D_ct_real))
            D_ct_fake_loss = mse(D_ct_fake, torch.zeros_like(D_ct_fake))
            D_ct_loss = D_ct_real_loss + D_ct_fake_loss
            
            fake_mri = gen_Z(ct)
            D_mri_real = disc_Z(mri)
            D_mri_fake = disc_Z(fake_mri.detach())
            D_mri_real_loss = mse(D_mri_real, torch.ones_like(D_mri_real))
            D_mri_fake_loss = mse(D_mri_fake, torch.zeros_like(D_mri_fake))
            D_mri_loss = D_mri_real_loss + D_mri_fake_loss
            
            D_loss = (D_ct_loss + D_mri_loss)/2
        
        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()
        
        #### Generator
        with torch.cuda.amp.autocast():
            D_ct_fake = disc_H(fake_ct)
            D_mri_fake = disc_Z(fake_mri)
            loss_G_ct = mse(D_ct_fake,torch.ones_like(D_ct_fake))
            loss_G_mri = mse(D_mri_fake,torch.ones_like(D_mri_fake))
            
            # cycle loss
            cycle_mri = gen_Z(fake_ct)
            cycle_ct = gen_H(fake_mri)
            cycle_mri_loss = l1(mri,cycle_mri)
            cycle_ct_loss = l1(ct,cycle_ct)
            
            # identity loss
            identity_mri = gen_Z(mri)
            identity_ct = gen_H(ct)
            identity_mri_loss = l1(mri,identity_mri)
            identity_ct_loss = l1(ct,identity_ct)
            
            # add all together
            g_loss = (
                loss_G_ct + loss_G_mri + cycle_ct_loss*10 + cycle_mri_loss*10 + identity_ct_loss*0 + identity_mri_loss*0
            )
        opt_gen.zero_grad()
        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        
        if idx % 100 ==0:
            save_image(fake_ct,f"img_result/ct_{idx}.png")
            save_image(fake_mri,f"img_result/mri_{idx}.png")
    torch.save(gen_H.state_dict(), f"genH.pth")
    torch.save(gen_Z.state_dict(), f"genZ.pth")
    torch.save(disc_H.state_dict(), f"discH.pth")
    torch.save(disc_Z.state_dict(), f"discZ.pth")

    # print(f"✅ Models saved for epoch {epoch}")

In [8]:
# plt.imshow( torch.permute(test.__getitem__(0),(1,2,0)) )

In [9]:
train_dataset = DataSet('../../ct_mr_data/train/mr/','../../ct_mr_data/train/ct/')
test_dataset = DataSet('../../ct_mr_data/val//mr/','../../ct_mr_data/val/ct/')

# test_dataset = DataSet('../Dataset/images/trainB/','.jpg')

In [10]:
len(train_dataset)

15495

In [11]:
train_loader = torch.utils.data.DataLoader( 
    dataset=train_dataset, 
    batch_size=8, 
    shuffle=True, 
) 

test_loader = torch.utils.data.DataLoader( 
    dataset=test_dataset, 
    batch_size=1, 
) 

In [12]:
disc_H = Discriminator(in_channels=3).to('cuda')
disc_Z = Discriminator(in_channels=3).to('cuda')
gen_Z = Generator(img_channels=3,num_residuals=4).to('cuda')
geh_H = Generator(img_channels=3,num_residuals=4).to('cuda')
opt_disc = torch.optim.Adam(list(disc_H.parameters()) + list(disc_Z.parameters()),
                           lr=0.0002,
                            betas=(0.5,0.999)
                           )
opt_gen = torch.optim.Adam(list(gen_Z.parameters()) + list(geh_H.parameters()),
                           lr=0.0002,
                            betas=(0.5,0.999)
                           )

In [13]:
l1 = nn.L1Loss()
mse = nn.MSELoss()

In [14]:
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()


In [16]:
for epoch in range(100):
    train_fn(disc_H,disc_Z,gen_Z,geh_H,train_loader,opt_disc,opt_gen,l1,mse,d_scaler,g_scaler)
    

  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 1937/1937 [43:20<00:00,  1.34s/it]
100%|██████████| 1937/1937 [47:55<00:00,  1.48s/it]
100%|██████████| 1937/1937 [37:51<00:00,  1.17s/it] 
100%|██████████| 1937/1937 [32:25<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:21<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:21<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:21<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:20<00:00,  1.00s/it]
100%|██████████| 1937/1937 [32:21<00:00,  1.00s

KeyboardInterrupt: 

In [None]:
# num_epochs = 50
# model = Model(3,3).to('cuda')
# loss_fn =L1_SSIM_Loss()
# lr = 1e-2
# optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=0.9)

In [None]:
# import torch
# import numpy as np

# all_losses = []  # store mean loss for each epoch
# best_val_loss = float("inf")  # initialize with infinity
# best_model_path = "best_model.pth"

# for epoch in range(num_epochs):
#     print('Epoch:', epoch)
    
#     # -------------------- TRAINING --------------------
#     model.train()
#     train_losses = []
#     for b, (X, y) in enumerate(train_loader):
#         X, y = X.to('cuda'), y.to('cuda')
        
#         optimizer.zero_grad()
#         yHat = model(X)
#         loss = loss_fn(yHat, y)  # compute loss
#         loss.backward()
#         optimizer.step()
        
#         train_losses.append(loss.item())
    
#     epoch_loss = sum(train_losses) / len(train_losses)
#     all_losses.append(epoch_loss)
#     print('Train Loss Mean:', epoch_loss)
#     print(f'----------------------- End {epoch + 1} {b+1}/{len(train_loader)} ------------------')
    
#     # -------------------- EVALUATION --------------------
#     model.eval()
#     val_losses = []
#     with torch.no_grad():  # no gradient computation during evaluation
#         for X_val, y_val in test_loader:
#             X_val, y_val = X_val.to('cuda'), y_val.to('cuda')
#             yHat_val = model(X_val)
#             val_loss = loss_fn(yHat_val, y_val)
#             val_losses.append(val_loss.item())
    
#     if val_losses:
#         val_epoch_loss = sum(val_losses) / len(val_losses)
#         print('Validation Loss Mean:', val_epoch_loss)
        
#         # -------- Save best model --------
#         if val_epoch_loss < best_val_loss:
#             best_val_loss = val_epoch_loss
#             torch.save(model.state_dict(), 'best_model/' + best_model_path)
#             print(f"✅ Best model saved (val_loss={val_epoch_loss:.6f})")
    
# # -------------------- SAVE TRAINING LOSSES --------------------
# all_losses = np.array(all_losses)
# np.save('losses.npy', all_losses)
# print("Training losses saved to losses.npy")
# print(f"Best model path: {best_model_path}, Best val_loss={best_val_loss:.6f}")
