In [1]:
from unet.model import UNet3DfMRI
from unet.dataset import FMRI3DDataset

In [2]:

import torch
import os
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from tqdm import tqdm  
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


from unet.dataset import FMRI3DDataset
from unet.evaluate import evaluate_model


# -----------------------------
# Training 
# -----------------------------
def train_model(model: nn.Module, dataloader: DataLoader, device, epochs: int = 6, lr: float = 1e-4, writer: SummaryWriter = None, resume_from: int = 0):
    model.train()
    model.to(device)
    current_lr = lr
    optimizer = torch.optim.Adam(model.parameters(), lr=current_lr)
    lr_updated = False  # Flag to track lr change
    loss_fn = nn.MSELoss()

    global_step = 0

    save_dir = f"./saved_images/"
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(resume_from, epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        for batch_idx, (x, y) in enumerate(progress_bar):
            x, y = x.to(device), y.to(device)
            pred = model(x)

            # Delete extra padded dimension before computing loss
           

            # Create masks
            threshold = 0.3  
            bright_mask = (y > threshold).float()
            dark_mask = 1.0 - bright_mask 

            per_voxel_loss = loss_fn(pred, y)

            # Apply masks
            bright_loss = (per_voxel_loss * bright_mask).sum() / bright_mask.sum().clamp(min=1.0)
            dark_loss = (per_voxel_loss * dark_mask).sum() / dark_mask.sum().clamp(min=1.0)

            # Weighted total loss
            loss = 0.8 * bright_loss + 0.2 * dark_loss
            loss *= 1e2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

            #Scalar logging every 10 iterations
            if global_step % 10 == 0:
                writer.add_scalar("Loss/train", loss.item(), global_step)
                pred_np = pred.detach().cpu().numpy()
                y_np = y.detach().cpu().numpy()

                psnr_total = 0.0
                ssim_total = 0.0
                
                # Calculate over entire batch
                for i in range(pred_np.shape[0]):  
                    pred_vol = np.squeeze(pred_np[i])
                    y_vol = np.squeeze(y_np[i])

                    psnr = peak_signal_noise_ratio(y_vol, pred_vol, data_range=1.0)
                    ssim = structural_similarity(y_vol, pred_vol, data_range=1.0)

                    psnr_total += psnr
                    ssim_total += ssim

                avg_psnr = psnr_total / pred_np.shape[0]
                avg_ssim = ssim_total / pred_np.shape[0]

                writer.add_scalar("Train/Loss", loss.item(), global_step)
                writer.add_scalar("Train/PSNR", avg_psnr, global_step)
                writer.add_scalar("Train/SSIM", avg_ssim, global_step)

            # log every 1/20 epoch
            if batch_idx % (len(dataloader) // 20 + 1) == 0:
                with torch.no_grad():
                    x_slice = x[0, 0, :, :, x.shape[4] // 2].cpu().unsqueeze(0)
                    y_slice = y[0, 0, :, :, y.shape[4] // 2].cpu().unsqueeze(0)
                    pred_slice = pred[0, 0, :, :, pred.shape[4] // 2].cpu().unsqueeze(0)

                    grid = make_grid(torch.stack([x_slice, y_slice, pred_slice]), nrow=3, normalize=False)
                    writer.add_image(f"Epoch_{epoch+1}/Input_GT_Pred", grid, global_step)

            global_step += 1

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1}/{epochs}, Avg Loss: {avg_loss:.4f}")

        torch.save(model.state_dict(), f"./runs/checkpoints/final_run/model_weights_epoch{epoch + 1}.pth")


In [5]:
root_dir = "./data"
batch_size = 4
num_epochs = 6
learning_rate = 1e-4
writer = SummaryWriter('runs/final_train_adam/')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

full_dataset = FMRI3DDataset(root_dir)

train_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=8)

model = UNet3DfMRI()
print("Starting training...")

# Start from pretrained model
#model.load_state_dict(torch.load('/home/asharab/Documents/Masters/MRI_PROJECT/runs/checkpoints/final_run/model_weights_epoch2.pth'))

model = torch.jit.script(model)
train_model(model, train_loader, device, epochs=num_epochs, lr=learning_rate, writer=writer, resume_from=0)

writer.close()

Starting training...


                                                                            

Epoch 1/6, Avg Loss: 6.4702


                                                                             

Epoch 2/6, Avg Loss: 0.1062


                                                                             

Epoch 3/6, Avg Loss: 0.0588


                                                                             

Epoch 4/6, Avg Loss: 0.0498


                                                                             

Epoch 5/6, Avg Loss: 0.0326


                                                                              

Epoch 6/6, Avg Loss: 0.0222


In [6]:
from unet.evaluate import evaluate_model

validation_dataset_dir = "./validation"
batch_size = 2

full_dataset = FMRI3DDataset(validation_dataset_dir)
validation_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False)

evaluate_model(model, validation_loader, device)

                                                                                      

Evaluation Results — PSNR: 30.2361, SSIM: 0.8826




(np.float64(30.2360656600988), np.float64(0.8826445585828765))