In [54]:
import sys
import os
import matplotlib.pyplot as plt

# Add the /KITE/unet/src folder to Python path so you can import pytorch_unet
src_path = '/Users/yamacomur/Desktop/KITE/unet/src'
if src_path not in sys.path:
    sys.path.append(src_path)


In [55]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.modules.loss import CrossEntropyLoss
from scipy.ndimage import zoom
from collections import defaultdict
import cv2 as cv
from pytorch_unet import UNet, DiceLoss_TUnet

In [56]:
import random
import numpy as np
import torch

def seed_everything(seed=1234):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

print("Seed is set")

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


Seed is set


In [57]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_unet import UNet, DiceLoss_TUnet 

# Initialize model
model = UNet(n_class=config["num_class"], f_size=config["feature_map_size"], task_no=config["num_task"])
model = model.to(device)

# Loss functions
dice_loss = DiceLoss_TUnet(n_classes=config["num_class"])
bce_loss = nn.CrossEntropyLoss(reduction="none")  # for pixel-wise loss

# Optimizer and scheduler
optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

print("UNet model, loss, and optimizer ready")


NameError: name 'config' is not defined

In [None]:
import os
import cv2 as cv
import torch
from torch.utils.data import Dataset
import numpy as np
from scipy.ndimage import zoom
from torchvision import transforms

class OCTDataset(Dataset):
    num_classes = 10
    def __init__(self, image_dir, label_dir, file_list_path, input_size=(224, 512), transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.input_size = input_size
        self.transform = transform or transforms.ToTensor()

        with open(file_list_path, 'r') as f:
            self.filenames = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]

        img_path = os.path.join(self.image_dir, filename)
        label_path = os.path.join(self.label_dir, filename)

        image = cv.imread(img_path, 0)  # grayscale
        # Load label
        label = cv.imread(label_path, 0)  # Grayscale
        
        assert label.min() >= 0 and label.max() < self.num_classes, f"Label values out of range: min={label.min()}, max={label.max()}"

        if image is None:
            raise FileNotFoundError(f"Image not found or unreadable: {img_path}")
        if label is None:
            raise FileNotFoundError(f"Label not found or unreadable: {label_path}")

        if image.shape != self.input_size:
            image = zoom(image, (self.input_size[0]/image.shape[0], self.input_size[1]/image.shape[1]), order=3)
        if label.shape != self.input_size:
            label = zoom(label, (self.input_size[0]/label.shape[0], self.input_size[1]/label.shape[1]), order=0)

        unique_values, counts = np.unique(label, return_counts=True)
        # print(f"File {filename} - Label unique values: {unique_values}, counts: {counts}") -> This part works, it's for debugging.

        
        if np.std(image) > 0:
            image = (image - np.min(image)) / (np.max(image) - np.min(image))
        else:
            image = image / 255.0

        image_tensor = torch.from_numpy(image).unsqueeze(0).float()  # shape: [1, H, W]
        label_tensor = torch.tensor(label, dtype=torch.long) 

        return image_tensor, label_tensor


In [59]:
#paths
image_dir = "/Users/yamacomur/Desktop/Spring 2025/COMP491/Data/duke_original/image"
label_dir = "/Users/yamacomur/Desktop/Spring 2025/COMP491/Data/duke_original/layer"
file_list_path = "/Users/yamacomur/Desktop/KITE/unet/contains_lesion/fold1/train.txt"
val_file_list_path = "/Users/yamacomur/Desktop/KITE/unet/contains_lesion/fold1/val.txt"

# datasets
train_dataset = OCTDataset(image_dir=image_dir, label_dir=label_dir, file_list_path=file_list_path)
val_dataset = OCTDataset(image_dir=image_dir, label_dir=label_dir, file_list_path=val_file_list_path)

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, worker_init_fn=seed_worker)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, worker_init_fn=seed_worker)

dataloaders = {
    "train": train_loader,
    "val": val_loader
}



In [60]:
from collections import defaultdict
import torch.nn.functional as F
import time
import copy

def calc_loss(pred, target, loss_fn, metrics, phase):
    bce = loss_fn(pred, target)
    loss = bce

    metrics['bce'] += bce.item() * target.size(0)
    metrics['loss'] += loss.item() * target.size(0)

    return loss

def print_metrics(metrics, epoch_samples, phase):
    print(f"--- {phase.upper()} ---")
    for k in metrics.keys():
        print(f"{k}: {metrics[k] / epoch_samples:.4f}")
    print("---------------")

def train_model(model, dataloaders, optimizer, loss_fn, scheduler, num_epochs=10):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device).float()
                labels = labels.to(device).long()

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, loss_fn, metrics, phase)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                epoch_samples += inputs.size(0)

            epoch_loss = metrics['loss'] / epoch_samples
            if phase == 'val':
                scheduler.step(epoch_loss)
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    print("Best model updated")

            print_metrics(metrics, epoch_samples, phase)

    print(f"\nTraining complete. Best val loss: {best_loss:.4f}")
    model.load_state_dict(best_model_wts)
    return model


In [61]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_unet import UNet, DiceLoss_TUnet

# Configuration
num_classes = 10          # foreground + background
num_tasks = 1                  # we're only doing segmentation
feature_map_size = 16
lr = 0.001

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
model = UNet(n_class=num_classes, f_size=feature_map_size, task_no=num_tasks)
model = model.to(device)

# Loss function
loss_fn = DiceLoss_TUnet(n_classes=num_classes)

# Optimizer & scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

# Better scheduler with faster response
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
)

print("UNet model, loss, and optimizer ready ")


UNet model, loss, and optimizer ready 


In [62]:
import time
from collections import defaultdict
import copy

def multiclass_loss(outputs, targets):
    # Cross-entropy loss
    ce_loss = F.cross_entropy(outputs, targets)
    
    # Dice loss for multi-class
    dice_loss_val = DiceLoss_TUnet(n_classes=num_classes)(outputs, targets)
    
    # If dice_loss returns per-class values, reduce to scalar
    if isinstance(dice_loss_val, torch.Tensor) and dice_loss_val.numel() > 1:
        dice_loss_val = dice_loss_val.mean()
    
    # Combined loss (adjust weights as needed)
    loss = 0.5 * ce_loss + 0.5 * dice_loss_val
    return loss

# Set as your loss function
loss_fn = multiclass_loss

def train_model(model, dataloaders, optimizer, loss_fn, scheduler, num_epochs=100):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 30)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # Forward pass
                    outputs = model(inputs)
                    
                    # Get loss - but ensure it's a scalar
                    # If your loss_fn returns a tensor, reduce it to a scalar
                    raw_loss = loss_fn(outputs, labels)
                    
                    # IMPORTANT FIX: Make sure the loss is a scalar
                    if isinstance(raw_loss, torch.Tensor) and raw_loss.numel() > 1:
                        loss = raw_loss.mean()  # Use mean to reduce to scalar
                    else:
                        loss = raw_loss
                    
                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()  # This will now work because loss is a scalar
                        optimizer.step()

                # Statistics
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                epoch_samples += batch_size

            # Print epoch statistics
            epoch_loss = running_loss / epoch_samples
            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f}")
            
            # Store history
            if phase == 'train':
                history['train_loss'].append(epoch_loss)
            else:
                history['val_loss'].append(epoch_loss)
                # Update scheduler
                scheduler.step(epoch_loss)
                # Check if this is the best model
                if epoch_loss < best_loss:
                    print("Saving best model...")
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

    # Load best weights
    model.load_state_dict(best_model_wts)
    print("\nTraining complete. Best val loss: {:.4f}".format(best_loss))
    return model

In [63]:
# Run training
trained_model = train_model(
    model=model,
    dataloaders=dataloaders,
    optimizer=optimizer,
    loss_fn=loss_fn,
    scheduler=scheduler,
    num_epochs=200 #CHANGE IT
)

# Save the trained model in TorchScript format
traced_model = torch.jit.trace(trained_model.cpu(), torch.rand(1, 1, 224, 512))
save_path = "/Users/yamacomur/Desktop/KITE/unet/notebooks/unet_traced.pt"
traced_model.save(save_path)

print(f" Trained UNet model saved to: {save_path}")



Epoch 1/200
------------------------------
Train Loss: 1.4823
Val Loss: 1.4198
Saving best model...

Epoch 2/200
------------------------------
Train Loss: 1.3886
Val Loss: 1.3057
Saving best model...

Epoch 3/200
------------------------------
Train Loss: 1.3037
Val Loss: 1.2271
Saving best model...

Epoch 4/200
------------------------------
Train Loss: 1.1875
Val Loss: 1.2355

Epoch 5/200
------------------------------
Train Loss: 1.1792
Val Loss: 1.1172
Saving best model...

Epoch 6/200
------------------------------
Train Loss: 1.1239
Val Loss: 1.0829
Saving best model...

Epoch 7/200
------------------------------
Train Loss: 1.0852
Val Loss: 1.0334
Saving best model...

Epoch 8/200
------------------------------
Train Loss: 1.0380
Val Loss: 0.9392
Saving best model...

Epoch 9/200
------------------------------
Train Loss: 0.9924
Val Loss: 0.8785
Saving best model...

Epoch 10/200
------------------------------
Train Loss: 0.9050
Val Loss: 0.8657
Saving best model...

Epoch 11/

In [None]:
import torch

model = model.cpu()
model.eval()

dummy_input = torch.rand(1, 1, 224, 512)

traced_model = torch.jit.trace(model, dummy_input)

save_path = "/Users/yamacomur/Desktop/KITE/unet/notebooks/unet_traced.pt"
traced_model.save(save_path)

print(f"Traced UNet model saved to: {save_path}")


Traced UNet model saved to: /Users/yamacomur/Desktop/KITE/unet/notebooks/unet_traced.pt
