## Setup and Imports

In [None]:
! pip install numpy nibabel matplotlib pandas scipy tqdm plotly optuna SimpleITK hiddenlayer torch

In [None]:
import torch

In [None]:
# ! pip install causal-conv1d mamba-ssm

In [None]:
# ! pip uninstall torch

In [None]:
# ! pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126

In [None]:
# must be done after installing pytorch
# ! pip install nnunetv2 

In [None]:
# Optional: install LightM-UNet from source (uncomment to use).
# NOTE(review): every `!` line runs in its own subshell, so `! cd <dir>` does not
# change the directory seen by the next line; `! pip install -e .` would then
# install from the notebook's working directory. Pass the path explicitly
# (below) or use the `%cd` magic instead.
# ! git clone https://github.com/MrBlankness/LightM-UNet
# %pip install -e LightM-UNet/lightm-unet

In [None]:
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import pandas as pd
import torch
from scipy.ndimage import zoom, rotate
from tqdm import tqdm
import torch.nn.functional as F
from matplotlib.widgets import Slider
import plotly.graph_objects as go
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
from torch.optim import Adam
from concurrent.futures import ThreadPoolExecutor, as_completed
import optuna
from optuna.pruners import MedianPruner
from torch.cuda.amp import autocast, GradScaler
import SimpleITK as sitk
import subprocess
import shutil
import json


## Dataset Import

In [None]:

# Data locations. Set AIMS_TBI_IMAGES_DIR / AIMS_TBI_MASKS_DIR to run on another
# machine; the defaults are the original local paths.
images_dir = os.environ.get("AIMS_TBI_IMAGES_DIR", r"F:\aims_tbi\normalized_T1_scans")
masks_dir = os.environ.get("AIMS_TBI_MASKS_DIR", r"F:\aims_tbi\resampled_1mm_Lesion_masks")

# Count only *.nii.gz files -- the same filter LesionDataset applies -- so these
# numbers match what training will actually load; other files are ignored.
n_images = sum(fname.endswith(".nii.gz") for fname in os.listdir(images_dir))
n_masks = sum(fname.endswith(".nii.gz") for fname in os.listdir(masks_dir))
print(f"Number of processed images: {n_images}")
print(f"Number of processed masks: {n_masks}")
if n_images != n_masks:
    # Images and masks are paired by sorted filename, so unequal counts misalign pairs.
    print(f"⚠️ Image/mask count mismatch: {n_images} images vs {n_masks} masks")


In [None]:
# Define scan parameters
scan_id = 'scan_0001'
start_slice = 110   # first slice index along the third array axis
num_slices = 5

# Load the .nii.gz files
image_path = os.path.join(images_dir, f"{scan_id}_T1_normalized.nii.gz")
mask_path = os.path.join(masks_dir, f"{scan_id}_Lesion_resampled_1mm.nii.gz")

# Load image and mask using nibabel
image_nii = nib.load(image_path)
mask_nii = nib.load(mask_path)

image_array = image_nii.get_fdata().astype(np.float32)
mask_array = mask_nii.get_fdata().astype(np.uint8)

print(f"Loaded {scan_id}:")
print(f"Image shape: {image_array.shape}")
print(f"Mask shape: {mask_array.shape}")

# The overlay indexes both volumes with the same slice index, so they must share a grid.
if image_array.shape != mask_array.shape:
    raise ValueError(
        f"{scan_id}: image shape {image_array.shape} != mask shape {mask_array.shape}; "
        "the lesion overlay would be misaligned."
    )

# Create visualization: row 0 = T1 only, row 1 = T1 + lesion overlay.
# squeeze=False keeps `axes` 2-D even when num_slices == 1.
fig, axes = plt.subplots(2, num_slices, figsize=(num_slices * 3, 10), squeeze=False)
n_axial = image_array.shape[2]

for i in range(num_slices):
    slice_idx = start_slice + i
    ax_img, ax_overlay = axes[:, i]

    # Hide axes up front so skipped panels render blank instead of as empty tick boxes.
    ax_img.axis('off')
    ax_overlay.axis('off')

    # Reject negative indices too: numpy would silently wrap them to the far end.
    if not 0 <= slice_idx < n_axial:
        print(f"⚠️ Slice {slice_idx} is out of bounds (max: {n_axial-1}), skipping...")
        continue

    # Extract slices (transpose for proper orientation)
    image_slice = image_array[:, :, slice_idx].T
    mask_slice = mask_array[:, :, slice_idx].T

    # Row 1: Processed T1 image only
    ax_img.imshow(image_slice, cmap='gray', origin='lower')
    ax_img.set_title(f'Processed T1\n(Normalized) Slice {slice_idx}', fontsize=10)

    # Row 2: Processed T1 + Lesion overlay
    ax_overlay.imshow(image_slice, cmap='gray', origin='lower')
    if np.any(mask_slice > 0):  # Only overlay if there are lesions in this slice
        ax_overlay.imshow(mask_slice, cmap='Reds', alpha=0.6, origin='lower')
    ax_overlay.set_title(f'Processed T1 + Lesion\nSlice {slice_idx}', fontsize=10)

# Add the suptitle before laying out and reserve the top band for it; running
# tight_layout first (as before) ignores the title, which can then overlap row 0.
fig.suptitle(f'Processed Images from NIfTI Files - {scan_id}', fontsize=16, y=0.98)
fig.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()

# Print intensity statistics
print(f"\n=== Statistics from NIfTI Files for {scan_id} ===")

# Image stats (brain voxels only)
brain_mask = image_array != 0  # Background is 0 after normalization
brain_voxels = image_array[brain_mask]

print("Processed T1 Image (from NIfTI file):")
if brain_voxels.size > 0:
    print(f"  Mean: {np.mean(brain_voxels):.6f}")
    print(f"  Std: {np.std(brain_voxels):.6f}")
    print(f"  Min: {np.min(brain_voxels):.4f}")
    print(f"  Max: {np.max(brain_voxels):.4f}")
else:
    # np.min / np.max raise on an empty array; report instead of crashing.
    print("  ⚠️ No non-zero (brain) voxels found; intensity statistics skipped.")
print(f"  Shape: {image_array.shape}")

# Lesion statistics
lesion_voxels = np.count_nonzero(mask_array)
total_voxels = mask_array.size
lesion_percentage = (lesion_voxels / total_voxels) * 100

print("\nLesion Mask (from NIfTI file):")
print(f"  Lesion voxels: {lesion_voxels:,}")
print(f"  Total voxels: {total_voxels:,}")
print(f"  Lesion percentage: {lesion_percentage:.4f}%")
print(f"  Shape: {mask_array.shape}")

### Verify the mask values are limited to 0 and 1

In [None]:
# Sanity check (disabled; uncomment to run): print every mask file whose voxel
# values are not exactly {0, 1}.
# NOTE(review): check the raw float data from get_fdata(). Casting to uint8 first
# truncates fractional values (e.g. 0.5 -> 0, 1.7 -> 1), e.g. from interpolating
# resampling, and hides exactly the masks this check is meant to catch.
# Loop-local names use a `check_` prefix so running this cell does not overwrite
# `mask_path` / `mask_array` from the visualization cell.
# for fname in sorted(os.listdir(masks_dir)):
#     if fname.endswith("_Lesion_resampled_1mm.nii.gz"):
#         check_path = os.path.join(masks_dir, fname)
#         check_values = np.unique(nib.load(check_path).get_fdata())
#         if np.any((check_values != 0) & (check_values != 1)):
#             print(f"{fname}: {check_values}")

## 3D U-Net Setup

### Custom Dataset

In [None]:
class LesionDataset(Dataset):
    """3-D lesion-segmentation dataset yielding ``(image, mask)`` patch pairs.

    Images and masks are paired by position after sorting the ``*.nii.gz``
    filenames of each directory, so both folders must hold one file per scan
    with names that sort into the same order.

    Args:
        images_dir: Folder containing the input volumes (``*.nii.gz``).
        masks_dir: Folder containing the matching lesion masks (``*.nii.gz``).
        patch_size: ``(depth, height, width)`` of the random crop taken from
            each volume. Volumes smaller than the patch on any axis are
            returned whole.
        transform: Optional callable ``transform(image, mask) -> (image, mask)``
            applied to the float32 tensors (leading channel dim included)
            after cropping. It was previously stored but never applied.

    Raises:
        ValueError: if the two folders contain different numbers of volumes.

    Each item is a pair of float32 tensors with a leading channel dim of 1.
    """

    def __init__(self, images_dir, masks_dir, patch_size=(64, 64, 64), transform=None):
        self.images = sorted([f for f in os.listdir(images_dir) if f.endswith('.nii.gz')])
        self.masks = sorted([f for f in os.listdir(masks_dir) if f.endswith('.nii.gz')])
        # Pairing is positional, so one missing/extra file would silently shift every
        # later image onto the wrong mask. Fail fast instead.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but {len(self.masks)} "
                f"masks in {masks_dir!r}; sorted-order pairing would misalign them."
            )
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.patch_size = patch_size
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def _random_crop(self, image, mask):
        """Cut the same random ``patch_size`` window out of ``image`` and ``mask``.

        Both arrays are ``(C, D, H, W)``. The crop is applied only when every
        spatial dim is at least the patch size (an axis equal to the patch is
        kept whole); otherwise both arrays are returned unchanged.
        """
        D, H, W = image.shape[1:]
        # p_* rather than pd/ph/pw so the local does not shadow the pandas alias `pd`.
        p_d, p_h, p_w = self.patch_size
        if D >= p_d and H >= p_h and W >= p_w:
            d = np.random.randint(0, D - p_d + 1)
            h = np.random.randint(0, H - p_h + 1)
            w = np.random.randint(0, W - p_w + 1)
            window = (slice(None), slice(d, d + p_d), slice(h, h + p_h), slice(w, w + p_w))
            image, mask = image[window], mask[window]
        return image, mask

    def __getitem__(self, idx):
        img_path = os.path.join(self.images_dir, self.images[idx])
        mask_path = os.path.join(self.masks_dir, self.masks[idx])
        image = nib.load(img_path).get_fdata().astype(np.float32)
        mask = nib.load(mask_path).get_fdata().astype(np.float32)
        # Add channel dimension
        image = np.expand_dims(image, axis=0)
        mask = np.expand_dims(mask, axis=0)
        image, mask = self._random_crop(image, mask)
        image, mask = torch.from_numpy(image), torch.from_numpy(mask)
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return image, mask

### 3D U-Net Model

In [None]:
class UNet3D(nn.Module):
    """Four-level 3-D U-Net that outputs per-voxel probabilities.

    Encoder: four double-conv stages of width f, 2f, 4f, 8f (f = init_features),
    each followed by 2x max-pooling, then a 16f bottleneck. Decoder: for each
    level, a stride-2 transposed conv, concatenation with the matching encoder
    output (skip connection), and a double-conv. A 1x1x1 conv followed by a
    sigmoid maps to ``out_channels`` probability maps.

    Spatial input sizes must be divisible by 16 (four pooling steps) so that
    the upsampled maps line up with the skip connections.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        init_features: Width f of the first encoder stage.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=32):
        super().__init__()
        widths = [init_features * 2 ** k for k in range(5)]  # f, 2f, 4f, 8f, 16f

        # Registered as encoder1, pool1, ..., encoder4, pool4 -- same names and
        # order as before, so state_dict keys and init order are unchanged.
        c_in = in_channels
        for level in range(1, 5):
            setattr(self, f"encoder{level}", self._block(c_in, widths[level - 1]))
            setattr(self, f"pool{level}", nn.MaxPool3d(2))
            c_in = widths[level - 1]

        self.bottleneck = self._block(widths[3], widths[4])

        # Registered as up4, decoder4, ..., up1, decoder1. The decoder's input is
        # 2x its output width because of the concatenated skip connection.
        for level in range(4, 0, -1):
            setattr(self, f"up{level}",
                    nn.ConvTranspose3d(widths[level], widths[level - 1], 2, stride=2))
            setattr(self, f"decoder{level}", self._block(widths[level], widths[level - 1]))

        self.conv = nn.Conv3d(widths[0], out_channels, kernel_size=1)

    def _block(self, in_channels, out_channels):
        """Two (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages; the spatial size is kept."""
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv3d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        encoders = (self.encoder1, self.encoder2, self.encoder3, self.encoder4)
        pools = (self.pool1, self.pool2, self.pool3, self.pool4)
        ups = (self.up1, self.up2, self.up3, self.up4)
        decoders = (self.decoder1, self.decoder2, self.decoder3, self.decoder4)

        skips = []
        for encode, pool in zip(encoders, pools):
            x = encode(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        # Deepest level first: up4/decoder4 pair with encoder4's output, and so on.
        for up, decode, skip in zip(reversed(ups), reversed(decoders), reversed(skips)):
            x = decode(torch.cat((up(x), skip), dim=1))

        return torch.sigmoid(self.conv(x))

### Dice Metric Functions

In [None]:
def dice_score(pred, target, epsilon=1e-6):
    """Hard Dice coefficient between a thresholded prediction and a target.

    ``pred`` is binarised at 0.5. All elements (any batch dim included) are
    pooled into one score. ``epsilon`` is added to numerator and denominator,
    so two empty masks score 1.0 instead of 0/0.

    Returns:
        The Dice coefficient as a Python float (not differentiable).
    """
    binary_pred = pred.gt(0.5).float()
    truth = target.float()
    overlap = torch.sum(binary_pred * truth)
    total = torch.sum(binary_pred) + torch.sum(truth)
    return ((2. * overlap + epsilon) / (total + epsilon)).item()

def dice_per_class(pred, target, epsilon=1e-6):
    """Hard Dice for the lesion class, the background class, and their mean.

    The prediction is binarised at 0.5. Background Dice is the lesion Dice of
    the complemented prediction and target.

    Returns:
        Tuple ``(dice_lesion, dice_bg, dice_overall)`` of Python floats, where
        ``dice_overall`` is the unweighted mean of the first two.
    """
    hard_pred = (pred > 0.5).float()
    truth = target.float()

    lesion = dice_score(hard_pred, truth, epsilon)
    background = dice_score(1 - hard_pred, 1 - truth, epsilon)
    return lesion, background, (lesion + background) / 2

### Loss Function (Dice + BCE for Imbalance)

In [None]:
class DiceBCELoss(nn.Module):
    """Soft Dice loss plus binary cross-entropy, for imbalanced binary masks.

    Expects ``pred`` to already hold probabilities in [0, 1] (e.g. the sigmoid
    output of UNet3D) and ``target`` to hold 0/1 labels with the same number
    of elements.

    The Dice term is computed on the raw probabilities ("soft" Dice), so it is
    differentiable. The previous version reused the thresholded, ``.item()``-
    converted ``dice_score`` metric. Its Dice term was therefore a constant
    Python float: it changed the reported loss value but contributed no
    gradient, so the network was effectively trained with BCE alone.

    Args:
        smooth: Additive smoothing for the Dice ratio. It keeps the term finite
            (and equal to 0) when prediction and target are both empty.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.bce = nn.BCELoss()
        self.smooth = smooth

    def forward(self, pred, target):
        # Flatten everything (batch included) so Dice is pooled over the whole batch.
        # reshape (unlike view) also accepts non-contiguous inputs.
        pred = pred.reshape(-1)
        target = target.reshape(-1).to(pred.dtype)
        intersection = (pred * target).sum()
        dice = (2.0 * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)
        return (1.0 - dice) + self.bce(pred, target)

### Training Loop

In [None]:
import gc

# Release memory before (re)training: first collect unreachable Python objects,
# then hand PyTorch's cached-but-unused CUDA blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Reproducibility: seed torch (weight init, DataLoader shuffling) and NumPy
# (LesionDataset draws its crop offsets from np.random).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Not used by the loop below, which trains in full precision: nn.BCELoss (inside
# DiceBCELoss) raises under autocast. Kept so the name stays defined for AMP experiments.
scaler = GradScaler()

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset and DataLoader
dataset = LesionDataset(images_dir, masks_dir, patch_size=(64, 64, 64))
train_size = int(0.85 * len(dataset))
val_size = len(dataset) - train_size
if train_size == 0 or val_size == 0:
    # Both loaders are averaged over below; an empty split would divide by zero.
    raise ValueError(
        f"Need at least one training and one validation scan, got {len(dataset)} "
        f"({train_size} train / {val_size} val)."
    )
# A dedicated seeded generator keeps the train/val split identical across runs.
train_ds, val_ds = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
# NOTE: val_ds goes through the same LesionDataset.__getitem__, so validation also
# sees random patch_size crops; metrics fluctuate between epochs even for a fixed model.
val_loader = DataLoader(val_ds, batch_size=1)

# Model, Loss, Optimizer
model = UNet3D(init_features=16).to(device)
criterion = DiceBCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training
num_epochs = 50
checkpoint_path = "3dunet_model_new_preprocessing.pth"
# -inf (not 0) so the first epoch always writes a checkpoint, even if its lesion Dice is 0.
best_val_dice = float("-inf")

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    dice_lesion, dice_bg, dice_overall = 0.0, 0.0, 0.0
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
            d_lesion, d_bg, d_overall = dice_per_class(outputs, masks)
            dice_lesion += d_lesion
            dice_bg += d_bg
            dice_overall += d_overall
    val_loss /= len(val_loader)
    dice_lesion /= len(val_loader)
    dice_bg /= len(val_loader)
    dice_overall /= len(val_loader)

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Dice Lesion={dice_lesion:.4f}, Dice No Lesion={dice_bg:.4f}, Dice Overall={dice_overall:.4f}")

    # Save best model (selected on mean validation lesion Dice)
    if dice_lesion > best_val_dice:
        best_val_dice = dice_lesion
        torch.save(model.state_dict(), checkpoint_path)
        print("Saved best model.")
        print("Saved best model.")

### Inference Example

In [None]:
# # Load model for inference
# # NOTE(review): aligned with the training cell, which saves the weights of
# # UNet3D(init_features=16) to "3dunet_model_new_preprocessing.pth". The earlier
# # version built UNet3D() (default init_features=32) and read "best_3dunet_model.pth",
# # so loading would fail (missing file and mismatched parameter shapes).
# model = UNet3D(init_features=16).to(device)
# model.load_state_dict(torch.load("3dunet_model_new_preprocessing.pth", map_location=device))
# model.eval()

# # Example inference on a single scan
# # NOTE(review): dataset[0] is a random patch_size crop (LesionDataset crops any
# # volume at least as large as the patch), not the whole scan. np.eye(4) is also
# # a placeholder affine, so the saved mask will not overlay the source volume in
# # a viewer. TODO: predict on the full volume and reuse the source image's affine.
# with torch.no_grad():
#     img, _ = dataset[0]
#     img = img.unsqueeze(0).to(device)
#     pred = model(img)
#     pred_mask = (pred > 0.5).cpu().numpy().astype(np.uint8)[0,0]
#     # Save as NIfTI
#     nib.save(nib.Nifti1Image(pred_mask, np.eye(4)), "predicted_mask.nii.gz")