In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from swin import SwinUNet2D
from mae_ssim_loss import MAE_SSIM_Loss
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import numpy as np
from glob import glob
import cv2
import matplotlib.pyplot as plt
import os
from scipy.ndimage import gaussian_filter, map_coordinates

# Enable cuDNN's auto-tuner so it benchmarks candidate convolution algorithms
# and reuses the fastest one (see the PyTorch cuDNN backend notes).
# NOTE(review): presumably chosen because inputs have a fixed size — confirm;
# with varying input shapes this setting can slow things down.
torch.backends.cudnn.benchmark = True





Using device: cuda
Input: torch.Size([1, 3, 256, 256])
Output: torch.Size([1, 1, 256, 256])


In [2]:
class MyCustomDataset(Dataset):
    """Paired MR/CT dataset backed by ``.npy`` files, with optional augmentation.

    Items are indexed by the sorted ``*.npy`` files found in ``train_path``.
    For each of them the array with the same file name is loaded from
    ``test_path``. Every geometric augmentation draws its random parameters
    once and applies them to both arrays so the pair stays aligned.
    """

    def __init__(self, train_path, test_path, image_size=(256, 256), aug=True):
        """Index the input files.

        Args:
            train_path: Directory containing the input (MR) ``*.npy`` files.
            test_path: Directory containing the target (CT) ``*.npy`` files;
                file names must match those in ``train_path``.
            image_size: Size handed to ``cv2.resize`` (OpenCV reads it as
                ``(width, height)``), or ``None`` to keep the stored size.
            aug: When True, ``__getitem__`` applies random augmentation.
        """
        self.train_path = train_path
        self.test_path = test_path
        self.image_size = image_size
        self.aug = aug

        self.train_images = sorted(glob(os.path.join(train_path, '*.npy')))

    def __len__(self):
        """Return the number of input files found in ``train_path``."""
        return len(self.train_images)

    def rotate(self, mr, ct):
        """Rotate both images by one random angle drawn from [-10, 10) degrees."""
        # Use the real array shape: it is valid for non-square images and when
        # image_size is None.
        h, w = mr.shape[:2]
        # OpenCV expects the rotation centre as (x, y), i.e. (column, row).
        center = (w // 2, h // 2)
        angle = np.random.uniform(-10, 10)
        rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale=1.0)
        mr_rotated = cv2.warpAffine(mr, rotation_matrix, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        ct_rotated = cv2.warpAffine(ct, rotation_matrix, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        return mr_rotated, ct_rotated

    def hflip(self, mr, ct):
        """Mirror both images around the vertical axis (flip code 1)."""
        mr_flipped = cv2.flip(mr, 1)
        ct_flipped = cv2.flip(ct, 1)
        return mr_flipped, ct_flipped

    def scaled(self, mr, ct):
        """Zoom both images by one random factor in [0.9, 1.1) and restore the input size.

        Zoomed-in images are centre-cropped; zoomed-out images are padded
        symmetrically with reflected borders.
        """
        h, w = mr.shape[:2]

        # Random scale factor between 0.9 (zoom out) and 1.1 (zoom in)
        scale = np.random.uniform(0.9, 1.1)

        # Resize image
        scaled_mr = cv2.resize(mr, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        scaled_ct = cv2.resize(ct, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

        # Get new size
        new_h, new_w = scaled_mr.shape[:2]

        if scale >= 1.0:
            # Scaled image is at least as large as the target: crop the centre.
            top = max((new_h - h) // 2, 0)
            left = max((new_w - w) // 2, 0)
            bottom = top + h
            right = left + w
            scaled_mr = scaled_mr[top:bottom, left:right]
            scaled_ct = scaled_ct[top:bottom, left:right]
        else:
            # Scaled image is smaller: pad back to the target size, putting
            # the odd leftover pixel (if any) on the bottom/right side.
            pad_top = (h - new_h) // 2
            pad_bottom = h - new_h - pad_top
            pad_left = (w - new_w) // 2
            pad_right = w - new_w - pad_left

            scaled_mr = cv2.copyMakeBorder(scaled_mr, pad_top, pad_bottom, pad_left, pad_right, borderType=cv2.BORDER_REFLECT_101)
            scaled_ct = cv2.copyMakeBorder(scaled_ct, pad_top, pad_bottom, pad_left, pad_right, borderType=cv2.BORDER_REFLECT_101)

        return scaled_mr, scaled_ct

    def translate(self, mr, ct):
        """Shift both images by one random integer offset of up to 10% per axis."""
        h, w = mr.shape[:2]

        # Max shift: 10% of width and height
        max_shift_x = int(0.1 * w)
        max_shift_y = int(0.1 * h)

        # Random shifts in x and y directions (upper bound is exclusive, hence +1)
        tx = np.random.randint(-max_shift_x, max_shift_x + 1)
        ty = np.random.randint(-max_shift_y, max_shift_y + 1)

        # Create translation matrix
        translation_matrix = np.float32([[1, 0, tx], [0, 1, ty]])

        # Apply translation
        mr_translated = cv2.warpAffine(mr, translation_matrix, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        ct_translated = cv2.warpAffine(ct, translation_matrix, (w, h), borderMode=cv2.BORDER_REFLECT_101)

        return mr_translated, ct_translated

    def elastic_deformation(self, image, dx, dy):
        """Resample ``image`` at pixel positions displaced by ``(dx, dy)``.

        Args:
            image: 2-D array, or 3-D array with channels on the last axis.
            dx: Per-pixel displacement along columns, same shape as ``image.shape[:2]``.
            dy: Per-pixel displacement along rows, same shape as ``image.shape[:2]``.

        Returns:
            Array with the same shape and dtype as ``image``, sampled with
            linear interpolation and reflected borders.
        """
        shape = image.shape[:2]

        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        # map_coordinates expects coordinates in (row, column) order.
        indices = np.reshape(y + dy, (-1)), np.reshape(x + dx, (-1))

        if image.ndim == 3:
            # Warp every channel with the same displacement field.
            channels = []
            for c in range(image.shape[2]):
                channel = map_coordinates(image[..., c], indices, order=1, mode='reflect').reshape(shape)
                channels.append(channel)
            return np.stack(channels, axis=-1).astype(image.dtype)
        else:
            return map_coordinates(image, indices, order=1, mode='reflect').reshape(shape).astype(image.dtype)

    def elastic(self, mr, ct, alpha=30, sigma=4):
        """Apply one random smooth displacement field to both images.

        Args:
            mr: Input image.
            ct: Target image; warped with the same field as ``mr``.
            alpha: Multiplier applied to the smoothed displacement field.
            sigma: Standard deviation of the Gaussian filter that smooths the field.
        """
        shape = mr.shape[:2]
        random_state = np.random.RandomState(None)

        # Uniform noise in [-1, 1), smoothed and then scaled by alpha.
        dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
        dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha

        mr_deformed = self.elastic_deformation(mr, dx, dy)
        ct_deformed = self.elastic_deformation(ct, dx, dy)

        return mr_deformed, ct_deformed

    def augment(self, mr, ct):
        """Randomly chain the augmentations; each one is applied independently."""
        # You can tweak these probabilities (0.5 = 50% chance)
        if np.random.rand() < 0.5:
            mr, ct = self.rotate(mr, ct)
        if np.random.rand() < 0.5:
            mr, ct = self.hflip(mr, ct)
        if np.random.rand() < 0.5:
            mr, ct = self.scaled(mr, ct)
        if np.random.rand() < 0.5:
            mr, ct = self.translate(mr, ct)
        if np.random.rand() < 0.2:
            mr, ct = self.elastic(mr, ct)

        return mr, ct

    def __getitem__(self, idx):
        """Load, resize and optionally augment the pair at ``idx``.

        Returns:
            Tuple ``(mr_tensor, ct_tensor)`` of float tensors in channel-first
            layout (a channel axis is added to 2-D arrays).
        """
        mr_path = self.train_images[idx]
        filename = os.path.basename(mr_path)

        # CT file (should match filename)
        ct_path = os.path.join(self.test_path, filename)
        mr_img = np.load(mr_path)
        ct_img = np.load(ct_path)

        if self.image_size is not None:
            # tuple() lets callers pass a list; cv2.resize reads it as (width, height).
            target_size = tuple(self.image_size)
            mr_img = cv2.resize(mr_img, target_size)
            ct_img = cv2.resize(ct_img, target_size)

        if self.aug:
            mr_img, ct_img = self.augment(mr_img, ct_img)

        # Single-channel images are 2-D at this point; add the channel axis.
        if mr_img.ndim == 2:
            mr_img = np.expand_dims(mr_img, axis=-1)
        if ct_img.ndim == 2:
            ct_img = np.expand_dims(ct_img, axis=-1)

        # HWC -> CHW
        mr_tensor = torch.from_numpy(mr_img).permute(2, 0, 1).float()
        ct_tensor = torch.from_numpy(ct_img).permute(2, 0, 1).float()

        return mr_tensor, ct_tensor


In [3]:
# Training split: MR inputs paired with CT targets, random augmentation on (default).
dataTrain = MyCustomDataset(
    train_path='../../ct_mr_stdscale//train/top//mr/',
    test_path='../../ct_mr_stdscale/train/top//ct/',
)
# Validation split: same pairing, augmentation disabled.
dataVal = MyCustomDataset(
    train_path='../../ct_mr_stdscale/val/top//mr/',
    test_path='../../ct_mr_stdscale/val/top//ct/',
    aug=False,
)

In [4]:
# Intensity window used as the default de-normalisation range.
# NOTE(review): presumably Hounsfield units — confirm against the preprocessing code.
CLIP_MIN, CLIP_MAX = -1024, 3000

def ct_denorm(norm_data, vmin=CLIP_MIN, vmax=CLIP_MAX):
    """Map values from the normalised range [-1, 1] back to [vmin, vmax].

    Args:
        norm_data: Array-like of normalised values (NumPy array, nested list
            or a plain Python number); -1 maps to ``vmin`` and +1 to ``vmax``.
        vmin: Lower bound of the target range.
        vmax: Upper bound of the target range.

    Returns:
        ``float32`` NumPy array (NumPy scalar for scalar input) with the same
        shape as the input.
    """
    # np.asarray lets scalars and lists through; the original called .astype
    # directly on the arithmetic result and failed for non-ndarray input.
    values = np.asarray(norm_data)
    data = (values + 1.0) / 2.0 * (vmax - vmin) + vmin
    return data.astype(np.float32)

In [5]:
# Training batches: 8 samples each, reshuffled every epoch.
train_loader = DataLoader(dataset=dataTrain, batch_size=8, shuffle=True)
# Validation batches: one sample at a time, in fixed order.
val_loader = DataLoader(dataset=dataVal, batch_size=1, shuffle=False)

In [6]:
# Single-channel in / single-channel out network. Fall back to the CPU when no
# GPU is present, matching the device selection done inside train_model.
model = SwinUNet2D(in_ch=1, out_ch=1, base_dim=32).to('cuda' if torch.cuda.is_available() else 'cpu')


In [7]:
def train_model(model, train_loader, val_loader, epochs=200, lr=1e-4, alpha=0.5, save_path='best_model_top.pth'):
    """Train ``model`` with Adam and keep the weights with the lowest validation loss.

    Each epoch runs one pass over ``train_loader`` followed by a gradient-free
    pass over ``val_loader``. Reported losses are the mean of the per-batch
    loss values. Whenever the validation mean improves, the model's
    ``state_dict`` is written to ``save_path``.

    Args:
        model: Network to optimise; moved to CUDA when available, else CPU.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        alpha: Forwarded to ``MAE_SSIM_Loss(alpha=...)``.
            NOTE(review): presumably the MAE/SSIM mixing weight — confirm.
        save_path: File that receives the best ``state_dict``.

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Fail fast with a clear message: the per-epoch averages divide by the
    # number of batches, so an empty loader would otherwise surface as a
    # ZeroDivisionError only after a full epoch of work.
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; nothing to train on.")
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty; cannot compute a validation loss.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    loss_fn = MAE_SSIM_Loss(alpha=alpha)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        # leave=False removes the finished bar so only the summary line remains.
        loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]", leave=False)
        for imgs, targets in loop:
            imgs, targets = imgs.to(device), targets.to(device)

            # Forward
            preds = model(imgs)
            loss = loss_fn(preds, targets)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        avg_train_loss = train_loss / len(train_loader)

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for imgs, targets in val_loader:
                imgs, targets = imgs.to(device), targets.to(device)
                preds = model(imgs)
                loss = loss_fn(preds, targets)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)

        print(f"Epoch [{epoch+1}/{epochs}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # --- Save best model ---
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(f"✅ Saved best model (Val Loss: {best_val_loss:.4f})")

    print("Training finished.")

In [None]:
# Start training with the default hyper-parameters of train_model.
train_model(model=model, train_loader=train_loader, val_loader=val_loader)


                                                                            

Epoch [1/200] | Train Loss: 0.2742 | Val Loss: 0.1607
✅ Saved best model (Val Loss: 0.1607)


                                                                             

Epoch [2/200] | Train Loss: 0.1256 | Val Loss: 0.1182
✅ Saved best model (Val Loss: 0.1182)


                                                                             

Epoch [3/200] | Train Loss: 0.1110 | Val Loss: 0.1134
✅ Saved best model (Val Loss: 0.1134)


                                                                             

Epoch [4/200] | Train Loss: 0.1084 | Val Loss: 0.1102
✅ Saved best model (Val Loss: 0.1102)


                                                                             

Epoch [5/200] | Train Loss: 0.1066 | Val Loss: 0.1075
✅ Saved best model (Val Loss: 0.1075)


                                                                             

Epoch [6/200] | Train Loss: 0.1051 | Val Loss: 0.1073
✅ Saved best model (Val Loss: 0.1073)


                                                                             

Epoch [7/200] | Train Loss: 0.1035 | Val Loss: 0.1053
✅ Saved best model (Val Loss: 0.1053)


                                                                             

Epoch [8/200] | Train Loss: 0.1023 | Val Loss: 0.1028
✅ Saved best model (Val Loss: 0.1028)


                                                                             

Epoch [9/200] | Train Loss: 0.1000 | Val Loss: 0.1027
✅ Saved best model (Val Loss: 0.1027)


                                                                              

Epoch [10/200] | Train Loss: 0.0988 | Val Loss: 0.1021
✅ Saved best model (Val Loss: 0.1021)


                                                                              

Epoch [11/200] | Train Loss: 0.0982 | Val Loss: 0.1004
✅ Saved best model (Val Loss: 0.1004)


                                                                              

Epoch [12/200] | Train Loss: 0.0972 | Val Loss: 0.1014


                                                                              

Epoch [13/200] | Train Loss: 0.0972 | Val Loss: 0.0999
✅ Saved best model (Val Loss: 0.0999)


Epoch [14/200]:  15%|█▌        | 37/245 [00:21<02:00,  1.72it/s, loss=0.0922]

Epoch [15/200] | Train Loss: 0.0963 | Val Loss: 0.0997


Epoch [16/200]:  62%|██████▏   | 151/245 [01:24<00:52,  1.77it/s, loss=0.108] 