In [2]:
# Cell 1: Install dependencies
%pip install torchio --quiet



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:

# Cell 2: Imports
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
import numpy as np
import torchio as tio
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [4]:

# Cell 3: ISLES Dataset Loader (3D volumes for UNet)
class ISLESDataset3D(Dataset):
    """BIDS-style ISLES dataset yielding (DWI, FLAIR) inputs and a binary lesion mask.

    For every ``sub-*`` folder under ``root_dir`` the loader looks for:
      * ``<sub>/ses-0001/dwi/*_dwi.nii.gz``
      * ``<sub>/ses-0001/anat/*_FLAIR.nii.gz``
      * ``derivatives/<sub>/ses-0001/*.nii.gz`` (first match, by name, is used as the mask)
    Subjects missing any of the three files/folders are skipped.

    Each item is ``(x, y)``:
      * ``x``: float32 tensor ``(2, *target_shape)``; channel 0 = DWI, channel 1 = FLAIR,
        each z-scored over its cropped region.
      * ``y``: float32 tensor ``(1, *target_shape)`` with values in {0, 1} (mask > 0).
    ``target_shape`` is ``crop_shape`` with every axis rounded up to a multiple of
    ``pad_multiple``; with the defaults it equals ``(64, 64, 16)``.
    """

    def __init__(self, root_dir, crop_shape=(64, 64, 16), pad_multiple=16):
        """
        Args:
            root_dir: dataset root containing ``sub-*`` folders and ``derivatives/``.
            crop_shape: maximum spatial extent kept from each volume (cropped from
                the array origin).
            pad_multiple: every output axis is zero-padded up to a multiple of this.
        """
        self.crop_shape = tuple(crop_shape)
        # Pad every sample to one fixed shape (not "whatever the crop left, rounded up").
        # That keeps DWI/FLAIR/mask stackable and samples batchable, even when a
        # volume is smaller than crop_shape along some axis.
        self.target_shape = tuple(-(-c // pad_multiple) * pad_multiple for c in self.crop_shape)
        self.samples = []
        mask_root = os.path.join(root_dir, "derivatives")
        # sorted(): os.listdir order is arbitrary; sorting makes sample order reproducible.
        for subject in sorted(os.listdir(root_dir)):
            if not subject.startswith("sub-"):
                continue
            ses_dir = os.path.join(root_dir, subject, "ses-0001")
            dwi = self._find_file(os.path.join(ses_dir, "dwi"), "_dwi.nii.gz")
            flair = self._find_file(os.path.join(ses_dir, "anat"), "_FLAIR.nii.gz")
            # NOTE(review): any .nii.gz in the derivatives session folder is taken as
            # the lesion mask — confirm the folder only holds the mask.
            mask = self._find_file(os.path.join(mask_root, subject, "ses-0001"), ".nii.gz")
            if dwi and flair and mask:
                self.samples.append({"dwi": dwi, "flair": flair, "mask": mask})
        print(f"Total 3D samples: {len(self.samples)}")

    @staticmethod
    def _find_file(directory, suffix):
        """Return the first file (by name) in ``directory`` ending in ``suffix``.

        Returns None if the directory does not exist or has no such file.
        """
        if not os.path.isdir(directory):
            return None
        matches = sorted(name for name in os.listdir(directory) if name.endswith(suffix))
        return os.path.join(directory, matches[0]) if matches else None

    @staticmethod
    def _fit_volume(arr, crop_shape, target_shape, normalize=False):
        """Crop ``arr`` to at most ``crop_shape``, optionally z-score it, then zero-pad.

        The crop starts at the array origin. Padding is at the high end of each axis,
        up to ``target_shape``. Normalisation uses only the cropped voxels, so the
        padding does not skew the mean/std; after z-scoring, the zero padding sits
        at the mean.
        """
        region = arr[tuple(slice(0, c) for c in crop_shape)]
        if normalize:
            region = (region - region.mean()) / (region.std() + 1e-5)
        pad = [(0, t - s) for s, t in zip(region.shape, target_shape)]
        return np.pad(region, pad, mode="constant")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, crop/normalise/pad and return ``(x, y)`` for sample ``idx``."""
        sample = self.samples[idx]
        # NOTE(review): all three volumes are cropped with the same voxel indices, which
        # assumes DWI, FLAIR and mask share one voxel grid — no resampling happens
        # here; confirm for this dataset. The crop is taken from the array origin,
        # not the image centre.
        dwi = self._fit_volume(nib.load(sample["dwi"]).get_fdata(),
                               self.crop_shape, self.target_shape, normalize=True)
        flair = self._fit_volume(nib.load(sample["flair"]).get_fdata(),
                                 self.crop_shape, self.target_shape, normalize=True)
        mask = self._fit_volume(nib.load(sample["mask"]).get_fdata(),
                                self.crop_shape, self.target_shape)
        x = np.stack([dwi, flair], axis=0).astype(np.float32)
        y = (mask > 0).astype(np.float32)
        return torch.from_numpy(x), torch.from_numpy(y).unsqueeze(0)


In [5]:

# Cell 4: DataLoader
# Config: dataset root (relative to the notebook) and batch size.
# batch_size=1 keeps 3D volumes within memory.
DATA_DIR = 'data'
BATCH_SIZE = 1

dataset = ISLESDataset3D(DATA_DIR)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


Total 3D samples: 248


In [6]:

# Cell 5: 3D UNet model
class UNet3D(nn.Module):
    """Three-level 3D U-Net that outputs per-voxel logits (no activation applied).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        init_features: width of the first encoder stage. It doubles at each of the
            three pooling steps, so the bottleneck has ``8 * init_features`` channels.
    """

    def __init__(self, in_channels=2, out_channels=1, init_features=16):
        super().__init__()
        f = init_features
        # Registration order fixes both the state_dict keys and the order in which
        # parameters draw from the RNG at initialisation — keep it stable.
        self.encoder1 = UNet3D._block(in_channels, f)
        self.pool1 = nn.MaxPool3d(2)
        self.encoder2 = UNet3D._block(f, 2 * f)
        self.pool2 = nn.MaxPool3d(2)
        self.encoder3 = UNet3D._block(2 * f, 4 * f)
        self.pool3 = nn.MaxPool3d(2)

        self.bottleneck = UNet3D._block(4 * f, 8 * f)

        # Each decoder takes the upsampled map concatenated with its skip connection,
        # so its input width is twice its output width.
        self.up3 = nn.ConvTranspose3d(8 * f, 4 * f, 2, stride=2)
        self.decoder3 = UNet3D._block(8 * f, 4 * f)
        self.up2 = nn.ConvTranspose3d(4 * f, 2 * f, 2, stride=2)
        self.decoder2 = UNet3D._block(4 * f, 2 * f)
        self.up1 = nn.ConvTranspose3d(2 * f, f, 2, stride=2)
        self.decoder1 = UNet3D._block(2 * f, f)

        self.conv = nn.Conv3d(f, out_channels, kernel_size=1)

    @staticmethod
    def _crop_to_match(src, tgt):
        """Centre-crop every dim of ``src`` after the first two to ``tgt``'s size there.

        Intended for ``src`` at least as large as ``tgt`` on each cropped axis. It is
        used when MaxPool3d floors an odd size, which makes the upsampled map smaller
        than its skip connection.
        """
        window = []
        for have, want in zip(src.shape[2:], tgt.shape[2:]):
            start = (have - want) // 2
            window.append(slice(start, start + want))
        return src[(Ellipsis, *window)]

    @staticmethod
    def _block(in_channels, features):
        """Two stages of Conv3d(3, padding=1) -> BatchNorm3d -> ReLU; spatial size is kept."""
        layers = []
        for stage_in in (in_channels, features):
            layers += [
                nn.Conv3d(stage_in, features, 3, padding=1),
                nn.BatchNorm3d(features),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the encoder/decoder path and return logits from the final 1x1x1 conv."""
        skips = []
        h = x
        for encode, pool in ((self.encoder1, self.pool1),
                             (self.encoder2, self.pool2),
                             (self.encoder3, self.pool3)):
            h = encode(h)
            skips.append(h)
            h = pool(h)

        h = self.bottleneck(h)

        # Deepest skip first: (up3, decoder3, enc3), then level 2, then level 1.
        stages = zip((self.up3, self.up2, self.up1),
                     (self.decoder3, self.decoder2, self.decoder1),
                     reversed(skips))
        for upsample, decode, skip in stages:
            h = upsample(h)
            h = decode(torch.cat((h, self._crop_to_match(skip, h)), dim=1))
        return self.conv(h)
    


In [None]:
# Cell 6: Model, Loss, Optimizer
SEED = 42
# Seed torch's global RNG before building the model. This makes weight
# initialisation reproducible, and also the DataLoader(shuffle=True) order, which
# draws from the global generator when none is passed. GPU kernels may still be
# non-deterministic.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet3D(in_channels=2, out_channels=1, init_features=8).to(device)
# The model emits raw logits, so use the numerically stable sigmoid+BCE combination.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 1


: 

In [None]:

# Cell 7: Training Loop
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        out = model(x)
        # When an input axis is not divisible by 8 (three MaxPool3d(2) steps), the
        # UNet output is smaller than its input. Centre-crop the target with the same
        # rule the network uses for its skip connections, so the voxels line up.
        if out.shape[2:] != y.shape[2:]:
            y = UNet3D._crop_to_match(y, out)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(loader):.4f}")


In [None]:

# Cell 8: Prediction and Visualization (middle slice)
model.eval()
with torch.no_grad():
    x, y = next(iter(loader))  # loader shuffles, so this is a random sample
    logits = model(x.to(device)).cpu()

# Centre-crop input and target to the prediction's spatial size, using the same rule
# as the UNet skip connections, so all panels show the same voxels.
# Channel 1 of x is FLAIR (the dataset stacks [dwi, flair]).
flair = UNet3D._crop_to_match(x[:, 1:2], logits)[0, 0].numpy()
mask = UNet3D._crop_to_match(y, logits)[0, 0].numpy()
pred = torch.sigmoid(logits)[0, 0].numpy()

mid = pred.shape[2] // 2  # middle index along the last spatial axis
panels = (
    ('FLAIR', flair[:, :, mid]),
    ('Mask', mask[:, :, mid]),
    ('Predicted', pred[:, :, mid] > 0.5),  # binarised at probability 0.5
)
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')
fig.tight_layout()
plt.show()