In [1]:
# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

In [2]:
# import tarfile  # Add this import statement
# import kagglehub
# dschettler8845_brats_2021_task1_path = kagglehub.dataset_download('dschettler8845/brats-2021-task1')

# print('Data source import complete.')


In [3]:

# zip_file = tarfile.open("/root/.cache/kagglehub/datasets/dschettler8845/brats-2021-task1/versions/1/BraTS2021_Training_Data.tar")
# zip_file.extractall("/data")
# zip_file.close()

# data preprocessing

In [1]:
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import nibabel as nib
from scipy import ndimage

class BraTS2021Dataset(Dataset):
    """BraTS 2021 dataset"""
    def __init__(self, data_dir, patch_size=(128, 128, 128), transform=None, seed=42):
        self.data_dir = data_dir
        self.patch_size = patch_size
        self.transform = transform
        
        np.random.seed(seed)
        torch.manual_seed(seed)
        
        # Get subject folders - look for all directories in the data_dir
        print(f"Looking for data in: {data_dir}")
        subject_dirs = [f for f in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, f))]
        
        if not subject_dirs:
            # Try looking for .nii.gz files directly
            nifti_files = glob.glob(os.path.join(data_dir, "**", "*.nii.gz"), recursive=True)
            if nifti_files:
                # Extract unique subject IDs from filenames
                subject_ids = set()
                for file in nifti_files:
                    basename = os.path.basename(file)
                    # Extract subject ID (assuming format like "BraTS2021_00000_t1.nii.gz")
                    subject_id = '_'.join(basename.split('_')[:-1])  # Remove modality suffix
                    if '_seg' in basename:  # Handle segmentation files
                        subject_id = subject_id.replace('_seg', '')
                    subject_ids.add(subject_id)
                
                # Create a list of subject paths
                self.subjects = [os.path.join(data_dir, subject_id) for subject_id in subject_ids]
                print(f"Found {len(self.subjects)} subjects based on .nii.gz files")
            else:
                raise ValueError(f"No subjects found in {data_dir}. Please check the data path and structure.")
        else:
            self.subjects = [os.path.join(data_dir, d) for d in subject_dirs]
            print(f"Found {len(self.subjects)} subject directories")
    
    def __len__(self):
        return len(self.subjects)
    
    def __getitem__(self, idx):
        subject_path = self.subjects[idx]
        subject_id = os.path.basename(subject_path)
        
        # Try to locate the MRI files
        def find_file(pattern):
            matches = glob.glob(os.path.join(self.data_dir, "**", pattern), recursive=True)
            if matches:
                return matches[0]
            return None
        
        # Try different possible file patterns
        t1_path = find_file(f"{subject_id}*t1.nii.gz") or find_file(f"*{subject_id}*t1.nii.gz")
        t1ce_path = find_file(f"{subject_id}*t1ce.nii.gz") or find_file(f"*{subject_id}*t1ce.nii.gz")
        t2_path = find_file(f"{subject_id}*t2.nii.gz") or find_file(f"*{subject_id}*t2.nii.gz")
        flair_path = find_file(f"{subject_id}*flair.nii.gz") or find_file(f"*{subject_id}*flair.nii.gz")
        seg_path = find_file(f"{subject_id}*seg.nii.gz") or find_file(f"*{subject_id}*seg.nii.gz")
        
        if not all([t1_path, t1ce_path, t2_path, flair_path, seg_path]):
            print(f"Missing files for subject {subject_id}. Found: {t1_path}, {t1ce_path}, {t2_path}, {flair_path}, {seg_path}")
            # Return a dummy sample if files are missing
            # This prevents crashes but you should check your data
            image = np.zeros((4, *self.patch_size), dtype=np.float32)
            mask = np.zeros((3, *self.patch_size), dtype=np.float32)
            return torch.from_numpy(image).float(), torch.from_numpy(mask).float()
        
        # Load data
        t1 = self.load_and_normalize(t1_path)
        t1ce = self.load_and_normalize(t1ce_path)
        t2 = self.load_and_normalize(t2_path)
        flair = self.load_and_normalize(flair_path)
        seg = self.load_nifti_volume(seg_path)
        
        # Stack modalities
        image = np.stack([t1, t1ce, t2, flair], axis=0)  # (4, H, W, D)
        
        # Create segmentation masks for the three tumor sub-regions
        mask_et = (seg == 4).astype(np.float32)
        mask_wt = ((seg == 1) | (seg == 2) | (seg == 4)).astype(np.float32)
        mask_tc = ((seg == 1) | (seg == 4)).astype(np.float32)
        
        mask = np.stack([mask_et, mask_wt, mask_tc], axis=0)  # (3, H, W, D)
        
        # Random crop for data augmentation
        image, mask = self.random_crop_3d(image, mask)
        
        # Apply augmentations
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        
        return torch.from_numpy(image).float(), torch.from_numpy(mask).float()
    
    def load_nifti_volume(self, filepath):
        """Load a NIfTI volume"""
        try:
            nifti = nib.load(filepath)
            volume = nifti.get_fdata().astype(np.float32)
            return volume
        except Exception as e:
            print(f"Error loading {filepath}: {e}")
            # Return an empty volume as placeholder
            return np.zeros((240, 240, 155), dtype=np.float32)
    
    def load_and_normalize(self, filepath):
        """Load and normalize a volume"""
        volume = self.load_nifti_volume(filepath)
        return self.normalize(volume)
    
    def normalize(self, volume):
        """Normalize volume to have zero mean and unit variance (only for non-zero voxels)"""
        mask = volume > 0
        if np.sum(mask) > 0:
            mean = np.mean(volume[mask])
            std = np.std(volume[mask])
            if std > 0:
                volume[mask] = (volume[mask] - mean) / std
        return volume
    
    def random_crop_3d(self, image, mask):
        """Random crop a 3D patch"""
        c, h, w, d = image.shape
        ph, pw, pd = self.patch_size
        
        # If image is smaller than patch size, pad it
        if h < ph or w < pw or d < pd:
            # Calculate padding sizes
            pad_h = max(0, ph - h)
            pad_w = max(0, pw - w)
            pad_d = max(0, pd - d)
            
            # Pad image and mask
            image_pad = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w), (0, pad_d)), mode='constant')
            mask_pad = np.pad(mask, ((0, 0), (0, pad_h), (0, pad_w), (0, pad_d)), mode='constant')
            
            h, w, d = h + pad_h, w + pad_w, d + pad_d
            image = image_pad
            mask = mask_pad
        
        # Get random start indices
        h_start = np.random.randint(0, h - ph + 1)
        w_start = np.random.randint(0, w - pw + 1)
        d_start = np.random.randint(0, d - pd + 1)
        
        # Extract patch
        image_patch = image[:, h_start:h_start+ph, w_start:w_start+pw, d_start:d_start+pd]
        mask_patch = mask[:, h_start:h_start+ph, w_start:w_start+pw, d_start:d_start+pd]
        
        return image_patch, mask_patch


# Data augmentation classes
class RandomFlip:
    """Random flip augmentation"""
    def __init__(self, flip_prob=0.5):
        self.flip_prob = flip_prob
    
    def __call__(self, image, mask):
        # Flip along depth dimension (axis 3)
        if np.random.rand() < self.flip_prob:
            image = np.flip(image, axis=3).copy()
            mask = np.flip(mask, axis=3).copy()
        
        # Flip along height dimension (axis 1)
        if np.random.rand() < self.flip_prob:
            image = np.flip(image, axis=1).copy()
            mask = np.flip(mask, axis=1).copy()
        
        # Flip along width dimension (axis 2)
        if np.random.rand() < self.flip_prob:
            image = np.flip(image, axis=2).copy()
            mask = np.flip(mask, axis=2).copy()
        
        return image, mask

class RandomIntensityShift:
    """Random intensity shift augmentation"""
    def __init__(self, shift_range=(-0.1, 0.1)):
        self.shift_range = shift_range
    
    def __call__(self, image, mask):
        for i in range(image.shape[0]):
            shift = np.random.uniform(*self.shift_range)
            image[i] = image[i] + shift
        return image, mask

class RandomIntensityScale:
    """Random intensity scale augmentation"""
    def __init__(self, scale_range=(0.9, 1.1)):
        self.scale_range = scale_range
    
    def __call__(self, image, mask):
        for i in range(image.shape[0]):
            scale = np.random.uniform(*self.scale_range)
            image[i] = image[i] * scale
        return image, mask

class Compose:
    """Compose multiple transforms"""
    def __init__(self, transforms):
        self.transforms = transforms
    
    def __call__(self, image, mask):
        for transform in self.transforms:
            image, mask = transform(image, mask)
        return image, mask
def create_dataloaders(data_dir, batch_size=1, patch_size=(128, 128, 128), train_ratio=0.8, num_workers=4, seed=42):
    """Create train and validation dataloaders for BraTS 2021 dataset"""
    # Set random seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Create transforms
    train_transform = Compose([
        RandomFlip(flip_prob=0.5),
        RandomIntensityShift(shift_range=(-0.1, 0.1)),
        RandomIntensityScale(scale_range=(0.9, 1.1))
    ])
    
    # Create the dataset
    full_dataset = BraTS2021Dataset(
        data_dir=data_dir,
        patch_size=patch_size,
        transform=None,  # We'll apply transforms later for train/val
        seed=seed
    )
    
    # Split into train and validation
    dataset_size = len(full_dataset)
    train_size = int(train_ratio * dataset_size)
    val_size = dataset_size - train_size
    
    print(f"Total dataset size: {dataset_size}")
    print(f"Training set size: {train_size}")
    print(f"Validation set size: {val_size}")
    
    if dataset_size == 0:
        raise ValueError("Dataset is empty. Please check your data path and structure.")
    
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], 
        generator=torch.Generator().manual_seed(seed)
    )
    
    # Create custom Dataset classes to apply different transforms
    class TransformDataset(Dataset):
        def __init__(self, dataset, transform=None):
            self.dataset = dataset
            self.transform = transform
            
        def __len__(self):
            return len(self.dataset)
            
        def __getitem__(self, idx):
            image, mask = self.dataset[idx]
            
            # Convert torch tensors back to numpy for transforms
            image_np = image.numpy()
            mask_np = mask.numpy()
            
            if self.transform:
                image_np, mask_np = self.transform(image_np, mask_np)
                
            return torch.from_numpy(image_np), torch.from_numpy(mask_np)
    
    # Apply transforms
    train_dataset_with_transform = TransformDataset(train_dataset, train_transform)
    val_dataset_with_transform = TransformDataset(val_dataset, None)
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset_with_transform,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset_with_transform,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    return train_loader, val_loader

# model architecture

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
import math
from typing import Sequence, Tuple, Type, Union

def ensure_tuple_rep(val, dim):
    """
    Ensures input is a tuple of length dim by repeating the value if needed
    """
    if isinstance(val, (list, tuple)):
        if len(val) == dim:
            return tuple(val)
        else:
            raise ValueError(f"Length of input {len(val)} doesn't match requested length {dim}")
    else:
        return tuple(val for _ in range(dim))

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Tensor initialization with truncated normal distribution"""
    def norm_cdf(x):
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor

def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Truncated normal distribution initialization"""
    if not std > 0:
        raise ValueError("Standard deviation should be greater than zero.")
    if a >= b:
        raise ValueError("Minimum cutoff value (a) should be smaller than maximum cutoff value (b).")
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

class DropPath(nn.Module):
    """Stochastic drop paths per sample for residual blocks"""
    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep
        
        if not (0 <= drop_prob <= 1):
            raise ValueError("Drop path prob should be between 0 and 1.")
            
    def drop_path(self, x, drop_prob=0.0, training=False, scale_by_keep=True):
        if drop_prob == 0.0 or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor
        
    def forward(self, x):
        return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding implementation
    """
    def __init__(
        self, patch_size, in_chans, embed_dim, norm_layer=None, spatial_dims=3
    ):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.spatial_dims = spatial_dims
        
        if spatial_dims == 3:
            self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        elif spatial_dims == 2:
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        else:
            raise ValueError("spatial_dims must be 2 or 3")
            
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x)
        if self.norm is not None:
            if self.spatial_dims == 3:
                x = x.permute(0, 2, 3, 4, 1)
                x = self.norm(x)
                x = x.permute(0, 4, 1, 2, 3)
            else:
                x = x.permute(0, 2, 3, 1)
                x = self.norm(x)
                x = x.permute(0, 3, 1, 2)
        return x

class MLPBlock(nn.Module):
    """MLP Block"""
    def __init__(self, hidden_size, mlp_dim, dropout_rate=0.0, act="GELU", dropout_mode="swin"):
        super().__init__()
        if act == "GELU":
            self.activation = nn.GELU()
        elif act == "ReLU":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation: {act}")
            
        self.dropout_mode = dropout_mode
        self.linear1 = nn.Linear(hidden_size, mlp_dim)
        self.linear2 = nn.Linear(mlp_dim, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        if self.dropout_mode == "swin":
            x = self.dropout(x)
            x = self.linear2(x)
            x = self.dropout(x)
        else:
            x = self.linear2(x)
            x = self.dropout(x)
        return x

def window_partition(x, window_size):
    """
    Window partition operation
    Args:
        x: input tensor
        window_size: local window size
    """
    x_shape = x.size()
    if len(x_shape) == 5:
        b, d, h, w, c = x_shape
        x = x.view(
            b,
            d // window_size[0],
            window_size[0],
            h // window_size[1],
            window_size[1],
            w // window_size[2],
            window_size[2],
            c,
        )
        windows = (
            x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
        )
    elif len(x_shape) == 4:
        b, h, w, c = x.shape
        x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
    return windows

def window_reverse(windows, window_size, dims):
    """
    Window reverse operation
    Args:
        windows: windows tensor
        window_size: local window size
        dims: dimension values
    """
    if len(dims) == 4:
        b, d, h, w = dims
        x = windows.view(
            b,
            d // window_size[0],
            h // window_size[1],
            w // window_size[2],
            window_size[0],
            window_size[1],
            window_size[2],
            -1,
        )
        x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
    elif len(dims) == 3:
        b, h, w = dims
        x = windows.view(b, h // window_size[0], w // window_size[1], window_size[0], window_size[1], -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x

def get_window_size(x_size, window_size, shift_size=None):
    """
    Computing window size based on input size
    """
    use_window_size = list(window_size)
    if shift_size is not None:
        use_shift_size = list(shift_size)
    for i in range(len(x_size)):
        if x_size[i] <= window_size[i]:
            use_window_size[i] = x_size[i]
            if shift_size is not None:
                use_shift_size[i] = 0

    if shift_size is None:
        return tuple(use_window_size)
    else:
        return tuple(use_window_size), tuple(use_shift_size)

def compute_mask(dims, window_size, shift_size, device):
    """
    Computing region masks
    """
    cnt = 0
    
    if len(dims) == 3:
        d, h, w = dims
        img_mask = torch.zeros((1, d, h, w, 1), device=device)
        for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
            for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
                for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
                    img_mask[:, d, h, w, :] = cnt
                    cnt += 1
    elif len(dims) == 2:
        h, w = dims
        img_mask = torch.zeros((1, h, w, 1), device=device)
        for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
            for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
                img_mask[:, h, w, :] = cnt
                cnt += 1
                
    mask_windows = window_partition(img_mask, window_size)
    mask_windows = mask_windows.squeeze(-1)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    
    return attn_mask

class WindowAttention(nn.Module):
    """
    Window based multi-head self attention
    """
    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        
        if len(self.window_size) == 3:
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(
                    (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
                    num_heads,
                )
            )
            coords_d = torch.arange(self.window_size[0])
            coords_h = torch.arange(self.window_size[1])
            coords_w = torch.arange(self.window_size[2])
            coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
            coords_flatten = torch.flatten(coords, 1)
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            relative_coords[:, :, 0] += self.window_size[0] - 1
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 2] += self.window_size[2] - 1
            relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
            relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        elif len(self.window_size) == 2:
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
            )
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
            coords_flatten = torch.flatten(coords, 1)
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            relative_coords[:, :, 0] += self.window_size[0] - 1
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x, mask=None):
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index[:n, :n].reshape(-1)
        ].reshape(n, n, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        
        if mask is not None:
            nw = mask.shape[0]
            attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
            
        attn = self.attn_drop(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block
    """
    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        shift_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer="GELU",
        norm_layer=nn.LayerNorm,
        use_checkpoint=False,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint
        
        if min(self.window_size) <= min(self.shift_size):
            self.shift_size = tuple(0 for i in self.window_size)
            
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLPBlock(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")
        
    def forward_part1(self, x, mask_matrix):
        x_shape = x.size()
        if len(x_shape) == 5:
            b, d, h, w, c = x.shape
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            pad_l = pad_t = pad_d0 = 0
            pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
            pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
            pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
            _, dp, hp, wp, _ = x.shape
            dims = [b, dp, hp, wp]
        elif len(x_shape) == 4:
            b, h, w, c = x.shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            pad_l = pad_t = 0
            pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
            pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
            _, hp, wp, _ = x.shape
            dims = [b, hp, wp]
            
        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
            
        x_windows = window_partition(shifted_x, window_size)
        attn_windows = self.attn(x_windows, mask=attn_mask)
        attn_windows = attn_windows.view(-1, *(window_size + (c,)))
        shifted_x = window_reverse(attn_windows, window_size, dims)
        
        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
        else:
            x = shifted_x
            
        if len(x_shape) == 5:
            if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
                x = x[:, :d, :h, :w, :].contiguous()
        elif len(x_shape) == 4:
            if pad_r > 0 or pad_b > 0:
                x = x[:, :h, :w, :].contiguous()
                
        return x
        
    def forward_part2(self, x):
        return self.drop_path(self.mlp(self.norm2(x)))
        
    def forward(self, x, mask_matrix):
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
        else:
            x = self.forward_part1(x, mask_matrix)
        x = shortcut + self.drop_path(x)
        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)
        return x

class PatchMerging(nn.Module):
    """
    Patch merging layer
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm, spatial_dims=3):
        super().__init__()
        self.dim = dim
        if spatial_dims == 3:
            self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(8 * dim)
        elif spatial_dims == 2:
            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(4 * dim)
            
    def forward(self, x):
        x_shape = x.size()
        if len(x_shape) == 5:
            b, d, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2))
            x0 = x[:, 0::2, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, 0::2, :]
            x3 = x[:, 0::2, 0::2, 1::2, :]
            x4 = x[:, 1::2, 0::2, 1::2, :]
            x5 = x[:, 0::2, 1::2, 0::2, :]
            x6 = x[:, 0::2, 0::2, 1::2, :]
            x7 = x[:, 1::2, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
        elif len(x_shape) == 4:
            b, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
            x0 = x[:, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, :]
            x3 = x[:, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3], -1)
            
        x = self.norm(x)
        x = self.reduction(x)
        return x

class BasicLayer(nn.Module):
    """
    Basic Swin Transformer layer in one stage
    """
    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size,
        drop_path,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.no_shift = tuple(0 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=self.window_size,
                    shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint,
                )
                for i in range(depth)
            ]
        )
        
        self.downsample = downsample
        if self.downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
            
    def forward(self, x):
        x_shape = x.size()
        if len(x_shape) == 5:
            b, c, d, h, w = x_shape
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            x = x.permute(0, 2, 3, 4, 1)  # (B, D, H, W, C)
            dp = int(np.ceil(d / window_size[0])) * window_size[0]
            hp = int(np.ceil(h / window_size[1])) * window_size[1]
            wp = int(np.ceil(w / window_size[2])) * window_size[2]
            attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
            
            for blk in self.blocks:
                x = blk(x, attn_mask)
                
            x = x.view(b, d, h, w, -1)
            if self.downsample is not None:
                x = self.downsample(x)
            x = x.permute(0, 4, 1, 2, 3)  # (B, C, D, H, W)
            
        elif len(x_shape) == 4:
            b, c, h, w = x_shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            x = x.permute(0, 2, 3, 1)  # (B, H, W, C)
            hp = int(np.ceil(h / window_size[0])) * window_size[0]
            wp = int(np.ceil(w / window_size[1])) * window_size[1]
            attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
            
            for blk in self.blocks:
                x = blk(x, attn_mask)
                
            x = x.view(b, h, w, -1)
            if self.downsample is not None:
                x = self.downsample(x)
            x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
            
        return x

class SwinTransformer(nn.Module):
    """
    Swin Transformer backbone
    """
    def __init__(
        self,
        in_chans,
        embed_dim,
        window_size,
        patch_size,
        depths,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        patch_norm=False,
        use_checkpoint=False,
        spatial_dims=3,
    ):
        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.patch_size = patch_size
        
        self.patch_embed = PatchEmbed(
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
            spatial_dims=spatial_dims,
        )
        
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        
        self.layers1 = nn.ModuleList()
        self.layers2 = nn.ModuleList()
        self.layers3 = nn.ModuleList()
        self.layers4 = nn.ModuleList()
        
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=self.window_size,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                downsample=PatchMerging,
                use_checkpoint=use_checkpoint,
            )
            
            if i_layer == 0:
                self.layers1.append(layer)
            elif i_layer == 1:
                self.layers2.append(layer)
            elif i_layer == 2:
                self.layers3.append(layer)
            elif i_layer == 3:
                self.layers4.append(layer)
                
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        
    def proj_out(self, x, normalize=False):
        if normalize:
            x_shape = x.size()
            if len(x_shape) == 5:
                n, ch, d, h, w = x_shape
                x = x.permute(0, 2, 3, 4, 1)  # (n, d, h, w, ch)
                x = F.layer_norm(x, [ch])
                x = x.permute(0, 4, 1, 2, 3)  # (n, ch, d, h, w)
            elif len(x_shape) == 4:
                n, ch, h, w = x_shape
                x = x.permute(0, 2, 3, 1)  # (n, h, w, ch)
                x = F.layer_norm(x, [ch])
                x = x.permute(0, 3, 1, 2)  # (n, ch, h, w)
        return x
        
    def forward(self, x, normalize=True):
        x0 = self.patch_embed(x)
        x0 = self.pos_drop(x0)
        x0_out = self.proj_out(x0, normalize)
        x1 = self.layers1[0](x0.contiguous())
        x1_out = self.proj_out(x1, normalize)
        x2 = self.layers2[0](x1.contiguous())
        x2_out = self.proj_out(x2, normalize)
        x3 = self.layers3[0](x2.contiguous())
        x3_out = self.proj_out(x3, normalize)
        x4 = self.layers4[0](x3.contiguous())
        x4_out = self.proj_out(x4, normalize)
        
        return [x0_out, x1_out, x2_out, x3_out, x4_out]

class UnetrBasicBlock(nn.Module):
    """
    A basic block for UNETR with debug logs
    """
    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        norm_name="instance",
        res_block=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        if norm_name == "instance":
            if spatial_dims == 3:
                self.norm = nn.InstanceNorm3d(out_channels)
            elif spatial_dims == 2:
                self.norm = nn.InstanceNorm2d(out_channels)
        elif norm_name == "batch":
            if spatial_dims == 3:
                self.norm = nn.BatchNorm3d(out_channels)
            elif spatial_dims == 2:
                self.norm = nn.BatchNorm2d(out_channels)
        else:
            raise ValueError(f"Unsupported normalization: {norm_name}")
            
        if spatial_dims == 3:
            self.conv1 = nn.Conv3d(
                in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2
            )
            if res_block:
                self.conv2 = nn.Conv3d(
                    out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2
                )
                self.conv3 = nn.Conv3d(
                    out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2
                )
                
            if stride > 1:
                self.residual = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride)
            elif in_channels != out_channels:
                self.residual = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1)
            else:
                self.residual = nn.Identity()
        elif spatial_dims == 2:
            self.conv1 = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2
            )
            if res_block:
                self.conv2 = nn.Conv2d(
                    out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2
                )
                self.conv3 = nn.Conv2d(
                    out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2
                )
                
            if stride > 1:
                self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
            elif in_channels != out_channels:
                self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            else:
                self.residual = nn.Identity()
                
        self.res_block = res_block
        self.activation = nn.ReLU(inplace=True)
        
    def forward(self, x):
        # print(f"UnetrBasicBlock - Input shape: {x.shape}, in_channels: {self.in_channels}, out_channels: {self.out_channels}")
        
        res = self.residual(x)
        # print(f"UnetrBasicBlock - After residual shape: {res.shape}")
        
        x = self.conv1(x)
        # print(f"UnetrBasicBlock - After conv1 shape: {x.shape}")
        
        x = self.norm(x)
        x = self.activation(x)
        
        if self.res_block:
            x = self.conv2(x)
            # print(f"UnetrBasicBlock - After conv2 shape: {x.shape}")
            x = self.norm(x)
            x = self.activation(x)
            
            x = self.conv3(x)
            # print(f"UnetrBasicBlock - After conv3 shape: {x.shape}")
            x = self.norm(x)
            
        x = x + res
        x = self.activation(x)
        # print(f"UnetrBasicBlock - Output shape: {x.shape}")
        return x
class UnetrUpBlock(nn.Module):
    """
    An upsampling module for UNETR with debug logs
    """
    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size,
        upsample_kernel_size,
        norm_name="instance",
        res_block=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        if spatial_dims == 3:
            self.transp_conv = nn.ConvTranspose3d(
                in_channels, out_channels, kernel_size=upsample_kernel_size, stride=upsample_kernel_size
            )
        elif spatial_dims == 2:
            self.transp_conv = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=upsample_kernel_size, stride=upsample_kernel_size
            )
            
        # The important fix: we need to account for the concatenated channels
        # Previously we used in_channels + out_channels which led to a channel mismatch
        self.conv_block = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=out_channels * 2,  # This is after concatenation
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        
    def forward(self, x, skip):
        # Add debug info
        # print(f"UnetrUpBlock - Input x shape: {x.shape}, skip shape: {skip.shape}")
        
        x = self.transp_conv(x)
        # print(f"UnetrUpBlock - After transpose conv shape: {x.shape}")
        
        x = torch.cat((x, skip), dim=1)
        # print(f"UnetrUpBlock - After concatenation shape: {x.shape}")
        
        x = self.conv_block(x)
        # print(f"UnetrUpBlock - Output shape: {x.shape}")
        return x
class UnetOutBlock(nn.Module):
    """
    A final output block for UNETR
    """
    def __init__(self, spatial_dims, in_channels, out_channels):
        super().__init__()
        
        if spatial_dims == 3:
            self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        elif spatial_dims == 2:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            
    def forward(self, x):
        return self.conv(x)
class SwinUNETR(nn.Module):
    """
    Swin UNETR based on: "Hatamizadeh et al.,
    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
    <https://arxiv.org/abs/2201.01266>"
    With added debug logging
    """
    def __init__(
        self,
        img_size,
        in_channels,
        out_channels,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        feature_size=24,
        norm_name="instance",
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        normalize=True,
        use_checkpoint=False,
        spatial_dims=3,
    ):
        super().__init__()
        
        # Print configuration for debugging
        # print(f"SwinUNETR - Initializing with img_size={img_size}, in_channels={in_channels}, out_channels={out_channels}")
        # print(f"SwinUNETR - feature_size={feature_size}, depths={depths}, num_heads={num_heads}")
        
        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(2, spatial_dims)
        window_size = ensure_tuple_rep(7, spatial_dims)
        
        if not (spatial_dims == 2 or spatial_dims == 3):
            raise ValueError("spatial dimension should be 2 or 3.")
            
        for m, p in zip(img_size, patch_size):
            for i in range(5):
                if m % np.power(p, i + 1) != 0:
                    raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
                    
        if not (0 <= drop_rate <= 1):
            raise ValueError("dropout rate should be between 0 and 1.")
            
        if not (0 <= attn_drop_rate <= 1):
            raise ValueError("attention dropout rate should be between 0 and 1.")
            
        if not (0 <= dropout_path_rate <= 1):
            raise ValueError("drop path rate should be between 0 and 1.")
            
        if feature_size % 12 != 0:
            raise ValueError("feature_size should be divisible by 12.")
            
        self.normalize = normalize
        
        self.swinViT = SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=window_size,
            patch_size=patch_size,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dropout_path_rate,
            norm_layer=nn.LayerNorm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
        )
        
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.encoder2 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.encoder3 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=2 * feature_size,
            out_channels=2 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.encoder4 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=4 * feature_size,
            out_channels=4 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.encoder10 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=16 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        
        # Fix the dimensions for decoder blocks
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=8 * feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=8 * feature_size,  # Changed from 'feature_size * 8'
            out_channels=4 * feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=4 * feature_size,
            out_channels=2 * feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=2 * feature_size,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.decoder1 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )
        
        self.out = UnetOutBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=out_channels
        )
        
    def forward(self, x_in):
        # Print input shape
        # print(f"SwinUNETR - Input shape: {x_in.shape}")
        
        hidden_states_out = self.swinViT(x_in, self.normalize)
        
        # Print shapes for debugging
        # for i, hidden_state in enumerate(hidden_states_out):
            # print(f"SwinUNETR - Hidden state {i} shape: {hidden_state.shape}")
        
        enc0 = self.encoder1(x_in)
        # print(f"SwinUNETR - enc0 shape: {enc0.shape}")
        
        enc1 = self.encoder2(hidden_states_out[0])
        # print(f"SwinUNETR - enc1 shape: {enc1.shape}")
        
        enc2 = self.encoder3(hidden_states_out[1])
        # print(f"SwinUNETR - enc2 shape: {enc2.shape}")
        
        enc3 = self.encoder4(hidden_states_out[2])
        # print(f"SwinUNETR - enc3 shape: {enc3.shape}")
        
        dec4 = self.encoder10(hidden_states_out[4])
        # print(f"SwinUNETR - dec4 (bottleneck) shape: {dec4.shape}")
        
        dec3 = self.decoder5(dec4, hidden_states_out[3])
        # print(f"SwinUNETR - dec3 shape: {dec3.shape}")
        
        dec2 = self.decoder4(dec3, enc3)
        # print(f"SwinUNETR - dec2 shape: {dec2.shape}")
        
        dec1 = self.decoder3(dec2, enc2)
        # print(f"SwinUNETR - dec1 shape: {dec1.shape}")
        
        dec0 = self.decoder2(dec1, enc1)
        # print(f"SwinUNETR - dec0 shape: {dec0.shape}")
        
        out = self.decoder1(dec0, enc0)
        # print(f"SwinUNETR - Final decoder output shape: {out.shape}")
        
        logits = self.out(out)
        # print(f"SwinUNETR - Final output shape: {logits.shape}")
        
        return logits

# metrics

In [3]:
def dice_coefficient_for_metrics(y_pred, y_true, smooth=1e-5, ignore_empty=True):
    """
    Calculate Dice coefficient for metrics reporting (with thresholding)
    This version detaches gradients and is used for metrics, not loss
    """
    # Apply sigmoid and threshold - note this breaks gradients but is OK for metrics
    y_pred = (torch.sigmoid(y_pred) > 0.5).float()
    
    # Compute for each class
    batch_size = y_pred.shape[0]
    num_classes = y_pred.shape[1]
    dice_scores = torch.zeros(batch_size, num_classes, device=y_pred.device)
    
    for b in range(batch_size):
        for c in range(num_classes):
            # Get current class
            pred = y_pred[b, c, ...].bool()  # Binary prediction
            true = y_true[b, c, ...].bool()  # Binary ground truth
            
            # Count positive pixels
            true_sum = torch.sum(true)
            
            if true_sum > 0:
                # Compute intersection (only count pixels where both are True)
                intersection = torch.sum(torch.masked_select(true, pred))
                # Compute Dice
                dice_scores[b, c] = (2.0 * intersection) / (true_sum + torch.sum(pred))
            else:
                # Handle empty ground truth
                if ignore_empty:
                    dice_scores[b, c] = float('nan')
                else:
                    pred_sum = torch.sum(pred)
                    if pred_sum <= 0:
                        dice_scores[b, c] = 1.0  # Both empty - perfect match
                    else:
                        dice_scores[b, c] = 0.0  # Empty ground truth but prediction has values
    
    return dice_scores

def dice_loss_function(y_pred, y_true, smooth=1e-5):
    """
    Differentiable Dice loss function for training
    Uses soft Dice which preserves gradients
    """
    # Apply sigmoid but NO thresholding to maintain gradients
    y_pred_sigmoid = torch.sigmoid(y_pred)
    
    # Reshape for reduction
    batch_size = y_pred.shape[0]
    num_classes = y_pred.shape[1]
    
    # Flatten prediction and target tensors for easier operations
    y_pred_flat = y_pred_sigmoid.view(batch_size, num_classes, -1)
    y_true_flat = y_true.view(batch_size, num_classes, -1)
    
    # Calculate intersection and union using differentiable operations
    intersection = torch.sum(y_pred_flat * y_true_flat, dim=2)
    pred_sum = torch.sum(y_pred_flat, dim=2)
    true_sum = torch.sum(y_true_flat, dim=2)
    
    # Calculate dice coefficient with smoothing
    dice = (2. * intersection + smooth) / (pred_sum + true_sum + smooth)
    
    # Convert to loss (1 - Dice)
    loss = 1.0 - dice
    
    # Handle empty ground truth cases
    empty_gt = (true_sum == 0)
    if torch.any(empty_gt):
        # Set loss to 0 for empty ground truth if prediction is also empty
        empty_pred = (pred_sum < smooth)
        loss[empty_gt & empty_pred] = 0.0
    
    # Return mean loss
    return torch.mean(loss)

def multiclass_dice_loss(y_pred, y_true, smooth=1e-5):
    """
    Multiclass Dice loss that preserves gradients for backpropagation
    """
    return dice_loss_function(y_pred, y_true, smooth)

def multiclass_dice_coefficient(y_pred, y_true, smooth=1e-5):
    """
    Calculate average Dice coefficient across all classes for metrics reporting
    """
    # Calculate per-class Dice (using the non-differentiable version for metrics)
    dice_scores = dice_coefficient_for_metrics(y_pred, y_true, smooth)
    
    # Average over batches for each class
    class_scores = torch.nanmean(dice_scores, dim=0)  # Shape: [C]
    
    # Average over classes for mean Dice
    overall_dice = torch.nanmean(class_scores)
    
    # Replace NaNs with 1.0 for printing (assuming perfect score for empty regions)
    display_scores = class_scores.clone()
    display_scores = torch.where(torch.isnan(display_scores), torch.ones_like(display_scores), display_scores)
    
    et_score = display_scores[0].item()
    wt_score = display_scores[1].item()
    tc_score = display_scores[2].item()
    
    # For overall display, use the nanmean but replace NaN with 1.0
    overall_display = overall_dice.item() if not torch.isnan(overall_dice) else 1.0
    
    print(f"Class Dice - ET: {et_score:.4f}, WT: {wt_score:.4f}, TC: {tc_score:.4f}, Avg: {overall_display:.4f}")
    
    return overall_dice

def class_wise_dice_coefficient(y_pred, y_true, smooth=1e-5):
    """
    Calculate class-wise Dice scores for reporting
    """
    # Calculate per-class Dice (using the non-differentiable version for metrics)
    dice_scores = dice_coefficient_for_metrics(y_pred, y_true, smooth)
    
    # Average over batches for each class
    class_scores = torch.nanmean(dice_scores, dim=0)  # Shape: [C]
    
    # Replace NaNs with 1.0 for class-wise metrics
    class_scores = torch.where(torch.isnan(class_scores), torch.ones_like(class_scores), class_scores)
    
    # Convert to list
    return class_scores.cpu().detach().tolist()

# model training

In [4]:
def train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=800, save_dir='checkpoints'):
    """Train the Swin UNETR model with class-wise Dice score tracking"""


    import time
    import os
    import math
    import torch
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Create save directory if it doesn't exist
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    best_val_loss = float('inf')
    best_val_dice = 0.0
    
    # For tracking metrics history
    history = {
        'train_loss': [], 'train_dice': [], 'train_dice_et': [], 'train_dice_wt': [], 'train_dice_tc': [],
        'val_loss': [], 'val_dice': [], 'val_dice_et': [], 'val_dice_wt': [], 'val_dice_tc': []
    }
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        train_dice = 0
        train_class_dice = {'ET': 0, 'WT': 0, 'TC': 0}
        
        # For handling NaN values properly
        train_dice_values = []
        train_class_dice_values = {'ET': [], 'WT': [], 'TC': []}
        
        train_start = time.time()
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = multiclass_dice_loss(outputs, targets)
            loss.backward()
            optimizer.step()
            
            # Calculate Dice score (detach to avoid unnecessary gradient computation)
            with torch.no_grad():
                dice_score = multiclass_dice_coefficient(outputs.detach(), targets)
                
                # Calculate class-wise Dice scores (ET, WT, TC)
                class_scores = class_wise_dice_coefficient(outputs.detach(), targets)
                
                # Store values for proper averaging (handling NaN)
                if not torch.isnan(dice_score):
                    train_dice_values.append(dice_score.item())
                
                train_class_dice_values['ET'].append(class_scores[0])
                train_class_dice_values['WT'].append(class_scores[1])
                train_class_dice_values['TC'].append(class_scores[2])
            
            train_loss += loss.item()
            train_dice += dice_score.item() if not torch.isnan(dice_score) else 0
            
            # For tracking running average (needed for batch display)
            train_class_dice['ET'] += class_scores[0] if not math.isnan(class_scores[0]) else 0
            train_class_dice['WT'] += class_scores[1] if not math.isnan(class_scores[1]) else 0
            train_class_dice['TC'] += class_scores[2] if not math.isnan(class_scores[2]) else 0
            
            # Print status
            if (batch_idx + 1) % 10 == 0:
                batch_dice = dice_score.item() if not torch.isnan(dice_score) else 0
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_loader)}, '
                      f'Loss: {loss.item():.4f}, Dice: {batch_dice:.4f}')
        
        train_time = time.time() - train_start
        train_loss /= len(train_loader)
        
        # Properly average Dice scores ignoring NaNs
        train_dice = float(torch.nanmean(torch.tensor(train_dice_values))) if train_dice_values else 0.0
        
        # Average class-wise scores properly
        for key in train_class_dice:
            values = [v for v in train_class_dice_values[key] if not math.isnan(v)]
            train_class_dice[key] = sum(values) / len(values) if values else 1.0  # Use 1.0 for empty regions
        
        # Validation
        model.eval()
        val_loss = 0
        val_dice = 0
        val_class_dice = {'ET': 0, 'WT': 0, 'TC': 0}
        
        # For handling NaN values properly
        val_dice_values = []
        val_class_dice_values = {'ET': [], 'WT': [], 'TC': []}
        
        val_start = time.time()
        
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = multiclass_dice_loss(outputs, targets)
                dice_score = multiclass_dice_coefficient(outputs, targets)
                
                # Calculate class-wise Dice scores (ET, WT, TC)
                class_scores = class_wise_dice_coefficient(outputs, targets)
                
                # Store values for proper averaging (handling NaN)
                if not torch.isnan(dice_score):
                    val_dice_values.append(dice_score.item())
                
                val_class_dice_values['ET'].append(class_scores[0])
                val_class_dice_values['WT'].append(class_scores[1])
                val_class_dice_values['TC'].append(class_scores[2])
                
                val_loss += loss.item()
                val_dice += dice_score.item() if not torch.isnan(dice_score) else 0
                
                # For tracking running average
                val_class_dice['ET'] += class_scores[0] if not math.isnan(class_scores[0]) else 0
                val_class_dice['WT'] += class_scores[1] if not math.isnan(class_scores[1]) else 0
                val_class_dice['TC'] += class_scores[2] if not math.isnan(class_scores[2]) else 0
        
        val_time = time.time() - val_start
        val_loss /= len(val_loader)
        
        # Properly average Dice scores ignoring NaNs
        val_dice = float(torch.nanmean(torch.tensor(val_dice_values))) if val_dice_values else 0.0
        
        # Average class-wise scores properly
        for key in val_class_dice:
            values = [v for v in val_class_dice_values[key] if not math.isnan(v)]
            val_class_dice[key] = sum(values) / len(values) if values else 1.0  # Use 1.0 for empty regions
        
        # Update scheduler
        scheduler.step()
        
        # Print status with class-wise scores
        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, '
              f'Train ET: {train_class_dice["ET"]:.4f}, Train WT: {train_class_dice["WT"]:.4f}, Train TC: {train_class_dice["TC"]:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, '
              f'Val ET: {val_class_dice["ET"]:.4f}, Val WT: {val_class_dice["WT"]:.4f}, Val TC: {val_class_dice["TC"]:.4f}, '
              f'Train Time: {train_time:.2f}s, Val Time: {val_time:.2f}s')
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_dice'].append(train_dice)
        history['train_dice_et'].append(train_class_dice['ET'])
        history['train_dice_wt'].append(train_class_dice['WT'])
        history['train_dice_tc'].append(train_class_dice['TC'])
        history['val_loss'].append(val_loss)
        history['val_dice'].append(val_dice)
        history['val_dice_et'].append(val_class_dice['ET'])
        history['val_dice_wt'].append(val_class_dice['WT'])
        history['val_dice_tc'].append(val_class_dice['TC'])
        
        # Save best model based on validation Dice score
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            model_path = os.path.join(save_dir, f'swin_unetr_best_dice.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_dice': train_dice,
                'val_dice': val_dice,
                'train_class_dice': train_class_dice,
                'val_class_dice': val_class_dice,
                'best_val_dice': best_val_dice,
                'history': history
            }, model_path)
            print(f'Saved best model with validation Dice score: {best_val_dice:.4f}')
        
        # Also save based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model_path = os.path.join(save_dir, f'swin_unetr_best_loss.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_dice': train_dice,
                'val_dice': val_dice,
                'train_class_dice': train_class_dice,
                'val_class_dice': val_class_dice,
                'best_val_loss': best_val_loss,
                'history': history
            }, model_path)
            print(f'Saved best model with validation loss: {best_val_loss:.4f}')
        
        # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            model_path = os.path.join(save_dir, f'swin_unetr_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_dice': train_dice,
                'val_dice': val_dice,
                'train_class_dice': train_class_dice,
                'val_class_dice': val_class_dice,
                'best_val_loss': best_val_loss,
                'best_val_dice': best_val_dice,
                'history': history
            }, model_path)
    
    # Save final history for plotting
    import json
    # Convert float values to be JSON serializable
    history_json = {k: [float(val) for val in v] for k, v in history.items()}
    with open(os.path.join(save_dir, 'training_history.json'), 'w') as f:
        json.dump(history_json, f)
    
    return model, history

In [5]:
def validate_model_config(img_size, feature_size, depths, num_heads):
    """
    Validates the model configuration to catch potential issues early
    """
    # Convert img_size to tuple if it's not already
    if not isinstance(img_size, tuple):
        img_size = tuple(ensure_tuple_rep(img_size, 3))
    
    patch_size = (2, 2, 2)
    
    # Check divisibility of image size by 2^(num_layers+1)
    max_stages = len(depths) + 1  # +1 for initial patch embedding
    for dim, ps in zip(img_size, patch_size):
        for i in range(max_stages):
            if dim % (ps ** i) != 0:
                return False, f"Image dimension {dim} is not divisible by {ps}^{i}"
    
    # Check feature size divisibility
    if feature_size % 12 != 0:
        return False, f"Feature size {feature_size} is not divisible by 12"
    
    # Check lengths
    if len(depths) != len(num_heads):
        return False, f"Length of depths {len(depths)} doesn't match length of num_heads {len(num_heads)}"
    
    # Calculate expected tensor shapes
    shapes = []
    curr_shape = list(img_size)
    
    # Initial patch embedding divides dimensions by patch_size
    for i in range(3):
        curr_shape[i] = curr_shape[i] // patch_size[i]
    
    shapes.append(curr_shape.copy())
    
    # For each stage, dimensions are halved
    for _ in range(len(depths)):
        for i in range(3):
            curr_shape[i] = curr_shape[i] // 2
        shapes.append(curr_shape.copy())
    
    print("Expected feature map shapes:")
    for i, shape in enumerate(shapes):
        feature_channels = feature_size * (2 ** min(i, len(depths)-1))
        # print(f"Stage {i}: {shape} with {feature_channels} channels")
    
    return True, "Configuration is valid"
    
def main():
    # Parameters
    batch_size = 2
    patch_size = (128,128,128)
    feature_size = 48
    depths = (2, 2, 2, 2)
    num_heads = (3, 6, 12, 24)
    lr = 0.001
    num_epochs = 10
    data_dir = "workspace/data"
    save_dir = "checkpoints_1"
    train_ratio = 0.8
    
    # Validate model configuration
    valid, message = validate_model_config(patch_size, feature_size, depths, num_heads)
    if not valid:
        print(f"ERROR - Invalid model configuration: {message}")
        print("Please adjust the model parameters before continuing.")
        return
    else:
        print(f"Model configuration validated: {message}")
    
    # Create dataloaders
    train_loader, val_loader = create_dataloaders(
        data_dir=data_dir, 
        batch_size=batch_size, 
        patch_size=patch_size, 
        train_ratio=train_ratio
    )
    
    # Create model
    model = SwinUNETR(
        img_size=patch_size,
        in_channels=4,  # T1, T1ce, T2, FLAIR
        out_channels=3,  # ET, WT, TC
        feature_size=feature_size,
        depths=depths,
        num_heads=num_heads,
        norm_name="instance",
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        normalize=True,
        use_checkpoint=False,
        spatial_dims=3
    )
    
    # Print total parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total model parameters: {total_params:,}")
    
    # Create optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    # Train model
    model,history = train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs, save_dir)
    
    print("Training complete!")

In [None]:
torch.multiprocessing.set_sharing_strategy('file_system')
if __name__ == "__main__":
    main()

Expected feature map shapes:
Model configuration validated: Configuration is valid
Looking for data in: workspace/data
Found 1251 subject directories
Total dataset size: 1251
Training set size: 1000
Validation set size: 251
Total model parameters: 84,842,997
Class Dice - ET: 0.0153, WT: 0.0985, TC: 0.0422, Avg: 0.0520
Class Dice - ET: 0.0291, WT: 0.1958, TC: 0.0714, Avg: 0.0987
Class Dice - ET: 0.0025, WT: 0.1233, TC: 0.0116, Avg: 0.0458
Class Dice - ET: 0.0053, WT: 0.2806, TC: 0.0176, Avg: 0.1012
Class Dice - ET: 0.0300, WT: 0.5807, TC: 0.1050, Avg: 0.2386
Class Dice - ET: 0.0123, WT: 0.3417, TC: 0.0744, Avg: 0.1428
Class Dice - ET: 0.0098, WT: 0.1862, TC: 0.0379, Avg: 0.0780
Class Dice - ET: 0.0000, WT: 0.2358, TC: 0.0002, Avg: 0.0787


In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

def visualize_training_history(history_path):
    """
    Load and visualize training history from JSON file
    
    Args:
        history_path: Path to the training_history.json file
    """
    # Load the history
    with open(history_path, 'r') as f:
        history = json.load(f)
    
    # Create figure with subplots
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))
    
    # Plot loss
    axs[0, 0].plot(history['train_loss'], label='Train Loss')
    axs[0, 0].plot(history['val_loss'], label='Val Loss')
    axs[0, 0].set_title('Loss')
    axs[0, 0].set_xlabel('Epoch')
    axs[0, 0].set_ylabel('Loss')
    axs[0, 0].legend()
    
    # Plot average Dice score
    axs[0, 1].plot(history['train_dice'], label='Train Dice')
    axs[0, 1].plot(history['val_dice'], label='Val Dice')
    axs[0, 1].set_title('Average Dice Score')
    axs[0, 1].set_xlabel('Epoch')
    axs[0, 1].set_ylabel('Dice Score')
    axs[0, 1].legend()
    
    # Plot training class-wise Dice scores
    axs[1, 0].plot(history['train_dice_et'], label='ET')
    axs[1, 0].plot(history['train_dice_wt'], label='WT')
    axs[1, 0].plot(history['train_dice_tc'], label='TC')
    axs[1, 0].set_title('Training Class-wise Dice Scores')
    axs[1, 0].set_xlabel('Epoch')
    axs[1, 0].set_ylabel('Dice Score')
    axs[1, 0].legend()
    
    # Plot validation class-wise Dice scores
    axs[1, 1].plot(history['val_dice_et'], label='ET')
    axs[1, 1].plot(history['val_dice_wt'], label='WT')
    axs[1, 1].plot(history['val_dice_tc'], label='TC')
    axs[1, 1].set_title('Validation Class-wise Dice Scores')
    axs[1, 1].set_xlabel('Epoch')
    axs[1, 1].set_ylabel('Dice Score')
    axs[1, 1].legend()
    
    plt.tight_layout()
    plt.savefig('training_history_plot.png')
    plt.show()

# Example usage
visualize_training_history('/workspace/checkpoints/training_history.json')

In [None]:
import torch
import os
import numpy as np
from tqdm import tqdm

def load_and_test_model(model_path, val_loader, device=None):
    """
    Load a trained model and test it on validation data
    
    Args:
        model_path: Path to the saved model checkpoint
        val_loader: Validation data loader
        device: Device to run the model on (default: None, will use CUDA if available)
    
    Returns:
        Dictionary containing evaluation metrics
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print(f"Using device: {device}")
    
    # Load the model
    checkpoint = torch.load(model_path, map_location=device)
    
    # Print training info from checkpoint
    print(f"Model trained for {checkpoint['epoch']} epochs")
    print(f"Best validation Dice score: {checkpoint.get('best_val_dice', 'N/A')}")
    
    # Create a new model instance
     # Import your model class
    
    # Use the same model parameters as during training
    model = SwinUNETR(
        img_size=(128, 128, 128),  # Assuming these were your parameters
        in_channels=4,
        out_channels=3,
        feature_size=48,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        norm_name="instance",
    ).to(device)
    
    # Load the state dictionary
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Set to evaluation mode
    model.eval()
    
    # Initialize metrics
    val_dice = 0
    class_dice = {'ET': 0, 'WT': 0, 'TC': 0}
    
    # Define the dice coefficient functions (assuming they're already defined)
    
    # Test the model
    print("Testing model on validation data...")
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            
            # Calculate average Dice score
            dice_score = multiclass_dice_coefficient(outputs, targets)
            val_dice += dice_score.item()
            
            # Calculate class-wise Dice scores
            scores = class_wise_dice_coefficient(outputs, targets)
            class_dice['ET'] += scores[0]
            class_dice['WT'] += scores[1]
            class_dice['TC'] += scores[2]
    
    # Average over validation dataset
    val_dice /= len(val_loader)
    for key in class_dice:
        class_dice[key] /= len(val_loader)
    
    # Print results
    print(f"Average Dice score: {val_dice:.4f}")
    print(f"Class-wise Dice scores:")
    print(f"  ET: {class_dice['ET']:.4f}")
    print(f"  WT: {class_dice['WT']:.4f}")
    print(f"  TC: {class_dice['TC']:.4f}")
    
    # Return metrics
    return {
        'avg_dice': val_dice,
        'class_dice': class_dice
    }

# Create a function to compare different models
def compare_models(model_paths, val_loader, device=None):
    """
    Compare multiple models on validation data
    
    Args:
        model_paths: List of paths to saved model checkpoints
        val_loader: Validation data loader
        device: Device to run the models on
    """
    results = {}
    
    for path in model_paths:
        model_name = os.path.basename(path)
        print(f"\nEvaluating model: {model_name}")
        results[model_name] = load_and_test_model(path, val_loader, device)
    
    # Print comparison table
    print("\nModel Comparison:")
    print("=" * 80)
    print(f"{'Model Name':<25} {'Avg Dice':<10} {'ET':<10} {'WT':<10} {'TC':<10}")
    print("-" * 80)
    
    for model_name, metrics in results.items():
        dice = metrics['avg_dice']
        class_dice = metrics['class_dice']
        print(f"{model_name:<25} {dice:<10.4f} {class_dice['ET']:<10.4f} {class_dice['WT']:<10.4f} {class_dice['TC']:<10.4f}")
    
    print("=" * 80)

# Example usage
model_paths = [
    '/workspace/checkpoints/swin_unetr_best_dice.pth',
    '/workspace/checkpoints/swin_unetr_best_loss.pth',
    '/workspace/checkpoints/swin_unetr_epoch_10.pth'
]

# First, create your validation loader using the same function you used for training
# This assumes you have already defined the necessary functions  # Import your data loader function

train_loader, val_loader = create_dataloaders(
    data_dir="workspace/data",
    batch_size=2,
    patch_size=(128, 128, 128),
    train_ratio=0.8
)

# Compare the models
compare_models(model_paths, val_loader)

In [None]:
def visualize_predictions(model_path, val_loader, num_samples=3, device=None):
    """
    Visualize predictions from a trained model on validation data
    
    Args:
        model_path: Path to the saved model checkpoint
        val_loader: Validation data loader
        num_samples: Number of samples to visualize
        device: Device to run the model on
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load the model
    checkpoint = torch.load(model_path, map_location=device)
    
    # Create a new model instance
 # Import your model class
    
    model = SwinUNETR(
        img_size=(128, 128, 128),
        in_channels=4,
        out_channels=3,
        feature_size=48,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        norm_name="instance",
    ).to(device)
    
    # Load the state dictionary
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Set to evaluation mode
    model.eval()
    
    # Get samples from validation loader
    samples = []
    with torch.no_grad():
        for inputs, targets in val_loader:
            samples.append((inputs, targets))
            if len(samples) >= num_samples:
                break
    
    # Create visualizations
    import matplotlib.pyplot as plt
    
    for i, (inputs, targets) in enumerate(samples):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        
        # Apply sigmoid to convert outputs to probabilities
        outputs = torch.sigmoid(outputs)
        
        # Convert to numpy for visualization
        inputs_np = inputs.cpu().numpy()[0]  # [C, H, W, D]
        targets_np = targets.cpu().numpy()[0]  # [3, H, W, D]
        outputs_np = outputs.cpu().numpy()[0]  # [3, H, W, D]
        
        # Find a good slice to visualize (middle slice with tumor)
        tumor_area_per_slice = np.sum(targets_np[1], axis=(0, 1))  # Use WT for finding tumor area
        good_slice = np.argmax(tumor_area_per_slice)
        
        # Create figure
        fig, axes = plt.subplots(3, 4, figsize=(16, 12))
        
        # Show input modalities
        modality_names = ['T1', 'T1ce', 'T2', 'FLAIR']
        for j in range(4):
            axes[0, j].imshow(inputs_np[j, :, :, good_slice], cmap='gray')
            axes[0, j].set_title(f'Input: {modality_names[j]}')
            axes[0, j].axis('off')
        
        # Show target segmentations
        class_names = ['ET', 'WT', 'TC']
        for j in range(3):
            axes[1, j].imshow(targets_np[j, :, :, good_slice], cmap='hot')
            axes[1, j].set_title(f'Target: {class_names[j]}')
            axes[1, j].axis('off')
        
        # Show predicted segmentations
        for j in range(3):
            axes[2, j].imshow(outputs_np[j, :, :, good_slice], cmap='hot')
            axes[2, j].set_title(f'Prediction: {class_names[j]}')
            axes[2, j].axis('off')
        
        # Show overlay of all predictions on FLAIR
        flair = inputs_np[3, :, :, good_slice]
        flair_norm = (flair - flair.min()) / (flair.max() - flair.min())
        
        # Create RGB overlay
        overlay = np.zeros((flair.shape[0], flair.shape[1], 3))
        overlay[:, :, 0] = outputs_np[0, :, :, good_slice]  # ET in red
        overlay[:, :, 1] = outputs_np[1, :, :, good_slice]  # WT in green
        overlay[:, :, 2] = outputs_np[2, :, :, good_slice]  # TC in blue
        
        # Show overlay
        axes[1, 3].imshow(flair_norm, cmap='gray')
        axes[1, 3].imshow(overlay, alpha=0.5)
        axes[1, 3].set_title('Overlay on FLAIR')
        axes[1, 3].axis('off')
        
        # Empty plot
        axes[2, 3].axis('off')
        
        # Set title for the whole figure
        plt.suptitle(f'Sample {i+1} - Slice {good_slice}', fontsize=16)
        plt.tight_layout()
        plt.savefig(f'sample_{i+1}_predictions.png')
        plt.show()

# Example usage
visualize_predictions('/workspace/checkpoints/swin_unetr_best_dice.pth', val_loader, num_samples=3)

In [None]:
if __name__ == "__main__":
    # 1. Visualize training history
    visualize_training_history('/workspace/checkpoints/training_history.json')
    
    # 2. Create validation loader
    train_loader, val_loader = create_dataloaders(
        data_dir="workspace/data",
        batch_size=1,
        patch_size=(128, 128, 128),
        train_ratio=0.8
    )
    
    # 3. Compare different model checkpoints
    model_paths = [
        '/workspace/checkpoints/swin_unetr_best_dice.pth',
        '/workspace/checkpoints/swin_unetr_best_loss.pth',
        '/workspace/checkpoints/swin_unetr_epoch_10.pth'
    ]
    compare_models(model_paths, val_loader)
    
    # 4. Visualize predictions from the best model
    visualize_predictions('/workspace/checkpoints/swin_unetr_best_dice.pth', val_loader, num_samples=3)

In [None]:
def plot_training_history(history):
    """Plot the training history with class-wise metrics"""
    import matplotlib.pyplot as plt
    
    # Create figure with subplots
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))
    
    # Plot loss
    axs[0, 0].plot(history['train_loss'], label='Train Loss')
    axs[0, 0].plot(history['val_loss'], label='Val Loss')
    axs[0, 0].set_title('Loss')
    axs[0, 0].set_xlabel('Epoch')
    axs[0, 0].set_ylabel('Loss')
    axs[0, 0].legend()
    
    # Plot average Dice score
    axs[0, 1].plot(history['train_dice'], label='Train Dice')
    axs[0, 1].plot(history['val_dice'], label='Val Dice')
    axs[0, 1].set_title('Average Dice Score')
    axs[0, 1].set_xlabel('Epoch')
    axs[0, 1].set_ylabel('Dice Score')
    axs[0, 1].legend()
    
    # Plot training class-wise Dice scores
    axs[1, 0].plot(history['train_dice_et'], label='ET')
    axs[1, 0].plot(history['train_dice_wt'], label='WT')
    axs[1, 0].plot(history['train_dice_tc'], label='TC')
    axs[1, 0].set_title('Training Class-wise Dice Scores')
    axs[1, 0].set_xlabel('Epoch')
    axs[1, 0].set_ylabel('Dice Score')
    axs[1, 0].legend()
    
    # Plot validation class-wise Dice scores
    axs[1, 1].plot(history['val_dice_et'], label='ET')
    axs[1, 1].plot(history['val_dice_wt'], label='WT')
    axs[1, 1].plot(history['val_dice_tc'], label='TC')
    axs[1, 1].set_title('Validation Class-wise Dice Scores')
    axs[1, 1].set_xlabel('Epoch')
    axs[1, 1].set_ylabel('Dice Score')
    axs[1, 1].legend()
    
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

In [None]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def visualize_tumor_classes(seg_path, slice_idx=None):
    """
    Visualize the three tumor classes (ET, WT, TC) from a BraTS segmentation mask
    Args:
        seg_path: Path to the segmentation mask (.nii.gz)
        slice_idx: Slice index to visualize (if None, will find the middle slice with tumor)
    """
    # Load segmentation mask
    seg_nii = nib.load(seg_path)
    seg_data = seg_nii.get_fdata()
    
    # Create binary masks for each class
    # In BraTS: label 1 = necrotic tumor core, label 2 = peritumoral edema, label 4 = enhancing tumor
    mask_et = (seg_data == 4).astype(np.float32)  # Enhancing Tumor
    mask_wt = ((seg_data == 1) | (seg_data == 2) | (seg_data == 4)).astype(np.float32)  # Whole Tumor
    mask_tc = ((seg_data == 1) | (seg_data == 4)).astype(np.float32)  # Tumor Core
    
    # If slice_idx is not provided, find a good slice to display
    if slice_idx is None:
        # Find the slice with the largest tumor area
        tumor_area_per_slice = np.sum(mask_wt, axis=(0, 1))
        slice_idx = np.argmax(tumor_area_per_slice)
    
    # Plot the three classes
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Define colors
    color_et = np.array([[0, 0, 0, 0], [1, 0, 0, 1]])  # Red for ET
    color_wt = np.array([[0, 0, 0, 0], [0, 1, 0, 1]])  # Green for WT
    color_tc = np.array([[0, 0, 0, 0], [0, 0, 1, 1]])  # Blue for TC
    
    cmap_et = ListedColormap(color_et)
    cmap_wt = ListedColormap(color_wt)
    cmap_tc = ListedColormap(color_tc)
    
    # Plot each class
    axes[0].imshow(mask_et[:, :, slice_idx], cmap=cmap_et)
    axes[0].set_title('Enhancing Tumor (ET)')
    axes[0].axis('off')
    
    axes[1].imshow(mask_wt[:, :, slice_idx], cmap=cmap_wt)
    axes[1].set_title('Whole Tumor (WT)')
    axes[1].axis('off')
    
    axes[2].imshow(mask_tc[:, :, slice_idx], cmap=cmap_tc)
    axes[2].set_title('Tumor Core (TC)')
    axes[2].axis('off')
    
    plt.suptitle(f'Tumor Classes Visualization (Slice {slice_idx})')
    plt.tight_layout()
    plt.show()

    # Let's also create a combined visualization with an MRI background
    fig, ax = plt.subplots(figsize=(10, 10))
    
    # Create a combined RGB image
    rgb_img = np.zeros((seg_data.shape[0], seg_data.shape[1], 3))
    rgb_img[:, :, 0] = mask_et[:, :, slice_idx]  # Red channel - ET
    rgb_img[:, :, 1] = mask_wt[:, :, slice_idx]  # Green channel - WT
    rgb_img[:, :, 2] = mask_tc[:, :, slice_idx]  # Blue channel - TC
    
    ax.imshow(rgb_img)
    ax.set_title('Combined Tumor Classes (Red=ET, Green=WT, Blue=TC)')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Use the correct path based on your directory structure
visualize_tumor_classes('workspace/data/BraTS2021_00000/BraTS2021_00000_seg.nii.gz')

In [None]:
# For a more informative visualization with MRI background
def visualize_tumor_with_background(seg_path, flair_path, slice_idx=None):
    # Load segmentation mask
    seg_nii = nib.load(seg_path)
    seg_data = seg_nii.get_fdata()
    
    # Load FLAIR for background
    flair_nii = nib.load(flair_path)
    flair_data = flair_nii.get_fdata()
    
    # Create binary masks for each class
    mask_et = (seg_data == 4).astype(np.float32)
    mask_wt = ((seg_data == 1) | (seg_data == 2) | (seg_data == 4)).astype(np.float32)
    mask_tc = ((seg_data == 1) | (seg_data == 4)).astype(np.float32)
    
    # If slice_idx is not provided, find a good slice to display
    if slice_idx is None:
        tumor_area_per_slice = np.sum(mask_wt, axis=(0, 1))
        slice_idx = np.argmax(tumor_area_per_slice)
    
    # Normalize FLAIR for background
    flair_slice = flair_data[:, :, slice_idx]
    flair_min, flair_max = np.min(flair_slice), np.max(flair_slice)
    if flair_max > flair_min:  # Avoid division by zero
        flair_norm = (flair_slice - flair_min) / (flair_max - flair_min)
    else:
        flair_norm = flair_slice
    
    # Plot with MRI background
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))
    
    # Background FLAIR only
    axes[0, 0].imshow(flair_norm, cmap='gray')
    axes[0, 0].set_title('FLAIR MRI')
    axes[0, 0].axis('off')
    
    # Overlay ET on FLAIR
    axes[0, 1].imshow(flair_norm, cmap='gray')
    axes[0, 1].imshow(mask_et[:, :, slice_idx], cmap='hot', alpha=0.5)
    axes[0, 1].set_title('Enhancing Tumor (ET)')
    axes[0, 1].axis('off')
    
    # Overlay WT on FLAIR
    axes[1, 0].imshow(flair_norm, cmap='gray')
    axes[1, 0].imshow(mask_wt[:, :, slice_idx], cmap='winter', alpha=0.5)
    axes[1, 0].set_title('Whole Tumor (WT)')
    axes[1, 0].axis('off')
    
    # Overlay TC on FLAIR
    axes[1, 1].imshow(flair_norm, cmap='gray')
    axes[1, 1].imshow(mask_tc[:, :, slice_idx], cmap='autumn', alpha=0.5)
    axes[1, 1].set_title('Tumor Core (TC)')
    axes[1, 1].axis('off')
    
    plt.suptitle(f'Tumor Segmentation Visualization (Slice {slice_idx})', fontsize=16)
    plt.tight_layout()
    plt.show()

# Example usage with the correct paths
visualize_tumor_with_background(
    'workspace/data/BraTS2021_00000/BraTS2021_00000_seg.nii.gz',
    'workspace/data/BraTS2021_00000/BraTS2021_00000_flair.nii.gz'
)