In [None]:
#library installations
%pip install pydicom SimpleITK numpy

In [None]:
import os
import pydicom
import numpy as np
import SimpleITK as sitk

def load_dicom_volume(folder_path):
    # Load all DICOM files in the folder
    dicoms = []
    for filename in os.listdir(folder_path):
        if filename.lower().endswith('.dcm'):
            dicom = pydicom.dcmread(os.path.join(folder_path, filename))
            dicoms.append(dicom)

    # Sort slices by ImagePositionPatient or InstanceNumber
    dicoms.sort(key=lambda x: float(x.ImagePositionPatient[2]) if 'ImagePositionPatient' in x else int(x.InstanceNumber))

    # Stack slices into 3D array
    image_stack = np.stack([d.pixel_array for d in dicoms])

    # Get spacing info
    try:
        spacing = list(map(float, dicoms[0].PixelSpacing))  # in-plane spacing
        slice_thickness = float(dicoms[0].SliceThickness)
        spacing.append(slice_thickness)
    except:
        spacing = [1.0, 1.0, 1.0]  # fallback if tags missing

    return image_stack, spacing

In [None]:
def resample_volume(volume, original_spacing, new_spacing=[1.0, 1.0, 1.0]):
    original_spacing = np.array(original_spacing[::-1])  # DICOM order: z, y, x
    new_spacing = np.array(new_spacing)
    
    resize_factor = original_spacing / new_spacing
    new_shape = np.round(np.array(volume.shape) * resize_factor).astype(int)

    volume_sitk = sitk.GetImageFromArray(volume)
    volume_sitk.SetSpacing(original_spacing.tolist())

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(new_spacing.tolist())
    resampler.SetSize([int(s) for s in new_shape[::-1]])
    resampler.SetInterpolator(sitk.sitkLinear)

    resampled = resampler.Execute(volume_sitk)
    return sitk.GetArrayFromImage(resampled)

In [None]:
def normalize_ct(volume, clip_min=-1000, clip_max=400):
    """Window a CT volume to [clip_min, clip_max] HU and rescale it linearly to [0, 1].

    Args:
        volume: array-like of Hounsfield units, any shape.
        clip_min: lower bound of the HU window (maps to 0).
        clip_max: upper bound of the HU window (maps to 1).

    Returns:
        float32 array of the same shape with values in [0, 1].

    Raises:
        ValueError: if ``clip_max <= clip_min`` -- the rescale would divide by zero
            or invert intensities.
    """
    if clip_max <= clip_min:
        raise ValueError(f"clip_max ({clip_max}) must be greater than clip_min ({clip_min})")
    windowed = np.clip(volume, clip_min, clip_max)
    scaled = (windowed - clip_min) / (clip_max - clip_min)
    return scaled.astype(np.float32)

In [None]:
def load_and_process_dicom(folder_path, new_spacing=(1.0, 1.0, 1.0), clip_min=-1000, clip_max=400):
    """Load a DICOM series, resample it, then window and normalize it.

    Args:
        folder_path: directory holding one DICOM series (read by ``load_dicom_volume``).
        new_spacing: target voxel spacing in mm passed to ``resample_volume``
            (default 1 mm isotropic).
        clip_min: lower HU bound passed to ``normalize_ct``.
        clip_max: upper HU bound passed to ``normalize_ct``.

    Returns:
        The normalized volume from ``normalize_ct``, shape (D, H, W).
    """
    volume, spacing = load_dicom_volume(folder_path)
    resampled = resample_volume(volume, spacing, new_spacing)
    return normalize_ct(resampled, clip_min=clip_min, clip_max=clip_max)

In [None]:
def create_montage_tensor(volume):
    """Turn a processed volume into a 5D montage tensor for a 3D CNN.

    Slices are chosen by ``get_10_montage_slices`` and each is passed through
    ``preprocess_slice``. NOTE(review): neither helper is defined in this
    notebook -- TODO confirm where they come from.

    Returns:
        float32 torch tensor of shape (1, 1, D, H, W): batch and channel axes
        in front of the stacked slice (depth) axis. H and W are whatever
        ``preprocess_slice`` returns (224x224 per the training-loop comments,
        unverified here).
    """
    depth_stack = np.stack([preprocess_slice(s) for s in get_10_montage_slices(volume)])
    depth, height, width = depth_stack.shape  # each preprocessed slice must be 2-D
    montage = depth_stack.reshape(1, 1, depth, height, width)
    return torch.tensor(montage, dtype=torch.float32)

In [None]:
import torch
from torch.utils.data import Dataset
import pandas as pd

class DicomMontageDataset(Dataset):
    """Map-style dataset yielding ``(montage_tensor, label)``, one item per CSV row.

    Each row names a patient; ``root_dir/<patient_id>`` is the folder passed to
    ``load_and_process_dicom``, whose result is turned into a montage by
    ``create_montage_tensor``.

    Args:
        csv_file: path to the labels CSV.
        root_dir: directory containing one sub-folder per patient.
        transform: optional callable applied to the tensor returned by
            ``create_montage_tensor`` before its leading axis is squeezed.
        id_col: CSV column holding the patient folder name.
        label_col: CSV column holding the integer class label.
    """

    def __init__(self, csv_file, root_dir, transform=None, id_col='patient_id', label_col='label'):
        # Read IDs as strings: numeric IDs would otherwise be parsed as ints, which
        # breaks os.path.join and drops leading zeros (e.g. folder "007" -> 7).
        self.labels_df = pd.read_csv(csv_file, dtype={id_col: str})
        self.root_dir = root_dir
        self.transform = transform
        self.id_col = id_col
        self.label_col = label_col

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        dicom_folder = os.path.join(self.root_dir, str(row[self.id_col]))
        label = int(row[self.label_col])

        # Process volume
        volume = load_and_process_dicom(dicom_folder)
        tensor = create_montage_tensor(volume)

        if self.transform:
            tensor = self.transform(tensor)

        # Drop the leading batch axis so the DataLoader can add its own.
        return tensor.squeeze(0), torch.tensor(label, dtype=torch.long)

In [None]:
# Partition the dataset into train : validation : test = 75% : 12.5% : 12.5%,
# stratified by label so every split keeps the class balance.

from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader, Subset, random_split

SPLIT_SEED = 42

# Full dataset and loader. Volumes are only loaded inside __getitem__.
dataset = DicomMontageDataset(csv_file='labels.csv', root_dir='dataset')
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)  # not used by the split pipeline below

# Labels come straight from the CSV. Iterating `dataset` instead would load,
# resample and montage every DICOM series just to read its label.
labels = dataset.labels_df['label'].to_numpy()
# StratifiedShuffleSplit only uses X for its length; the labels drive stratification.
all_idx = np.arange(len(labels))

# First split: 75% train, 25% temp (val + test)
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=SPLIT_SEED)
train_idx, temp_idx = next(sss1.split(X=all_idx, y=labels))

# Split temp 50/50 -> 12.5% val, 12.5% test of the full dataset.
# sss2 returns positions inside temp_idx; map them back to dataset indices.
temp_labels = labels[temp_idx]
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=SPLIT_SEED)
val_idx_rel, test_idx_rel = next(sss2.split(X=temp_idx, y=temp_labels))
train_idx = train_idx.tolist()
val_idx = temp_idx[val_idx_rel].tolist()
test_idx = temp_idx[test_idx_rel].tolist()

# Sanity check: the splits are disjoint and together cover the whole dataset.
assert len(set(train_idx) | set(val_idx) | set(test_idx)) == len(train_idx) + len(val_idx) + len(test_idx) == len(dataset)

# Create subsets
train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)
test_dataset = Subset(dataset, test_idx)

# Create DataLoaders (seeded generator => reproducible training shuffle order)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                          generator=torch.Generator().manual_seed(SPLIT_SEED))
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


In [None]:
# Create and train the model

from monai.networks.nets import SEResNet50
import torch.nn as nn
import torch


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over `loader`; trains when `optimizer` is given, otherwise evaluates.

    Returns:
        (mean per-sample loss, accuracy in percent).

    Raises:
        ValueError: if `loader` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            # Assumes criterion reduction='mean' (nn.CrossEntropyLoss default):
            # re-weight by batch size so a short last batch is not over-counted.
            loss_sum += loss.item() * batch_size
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, 100 * correct / total


# Seed weight initialisation so runs are comparable.
torch.manual_seed(42)

# SE-ResNet50 with 3D convolutions; num_classes=2 for binary classification.
# No pretrained weights are requested here, so the network starts from random init.
model = SEResNet50(spatial_dims=3, in_channels=1, num_classes=2)

# Use the GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

## Optionally freeze early layers to avoid fine-tuning the whole network
# for param in model.layer1.parameters():
#     param.requires_grad = False

num_epochs = 10

for epoch in range(num_epochs):
    # Inputs are expected as (B, 1, D, H, W) montage tensors from DicomMontageDataset.
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

    val_loss, val_acc = run_epoch(model, val_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%\n")


In [None]:
#Test the model 
def test_model(model, test_loader, criterion, device):
    model.eval()  # Set model to evaluation mode
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # Disable gradient computation for speed
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    avg_loss = test_loss / len(test_loader)
    accuracy = 100 * correct / total
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

# Final evaluation on the held-out test split; the cell displays (loss, accuracy).
test_model(model=model, test_loader=test_loader, criterion=criterion, device=device)
