<a href="https://colab.research.google.com/github/poornasandur/3D-F-CNN-BrainStruct/blob/master/BraTS_AfricaComplete_pipeline22222.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
!pip install monai
!pip install nibabel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import nibabel as nib
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm
import monai
from monai.losses import DiceLoss, DiceCELoss
from glob import glob
import json



In [12]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [13]:
# Configuration
class Config:
    data_dir = '/content/drive/MyDrive/BraTS-Africa'  # Update with your path
    batch_size = 2
    learning_rate = 1e-4
    epochs = 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = 4  # Background, ET, TC, WT
    img_size = (128, 128, 128)  # Reduced size for memory efficiency

    @classmethod
    def get_serializable_config(cls):
        return {
            'data_dir': cls.data_dir,
            'batch_size': cls.batch_size,
            'learning_rate': cls.learning_rate,
            'epochs': cls.epochs,
            'device': str(cls.device),
            'num_classes': cls.num_classes,
            'img_size': list(cls.img_size)
        }

    @classmethod
    def save(cls, path):
        with open(path, 'w') as f:
            json.dump(cls.get_serializable_config(), f, indent=4)

    @classmethod
    def load(cls, path):
        with open(path, 'r') as f:
            config = json.load(f)
        cls.data_dir = config['data_dir']
        cls.batch_size = config['batch_size']
        cls.learning_rate = config['learning_rate']
        cls.epochs = config['epochs']
        cls.device = torch.device(config['device'])
        cls.num_classes = config['num_classes']
        cls.img_size = tuple(config['img_size'])

In [14]:
class BraTSAfricaDataset(Dataset):
    def __init__(self, patient_dirs, mode='train', transform=None):
        self.patient_dirs = patient_dirs
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.patient_dirs)

    def __getitem__(self, idx):
        patient_dir = self.patient_dirs[idx]
        patient_id = os.path.basename(patient_dir)

        # Load modalities
        modalities = []
        for mod in ['t1n', 't1c', 't2w', 't2f']:
            img = nib.load(os.path.join(patient_dir, f"{patient_id}-{mod}.nii.gz")).get_fdata()
            img = self.preprocess(img)
            modalities.append(img)

        image = np.stack(modalities, axis=0)  # Shape: (4, H, W, D)

        if self.mode in ['train', 'val']:
            mask = nib.load(os.path.join(patient_dir, f"{patient_id}-seg.nii.gz")).get_fdata()
            mask = self.preprocess(mask, is_mask=True)
            mask = self.encode_mask(mask)  # Shape: (H, W, D, 4)
            mask = np.moveaxis(mask, -1, 0)  # Shape: (4, H, W, D)
            return torch.tensor(image, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)
        else:
            return torch.tensor(image, dtype=torch.float32)

    def preprocess(self, img, is_mask=False):
        img = self.crop_or_pad(img, Config.img_size)
        if not is_mask:
            img = self.normalize(img)
        return img

    def crop_or_pad(self, img, target_size):
        current_size = img.shape
        new_img = np.zeros(target_size, dtype=img.dtype)

        starts = [max(0, (current_size[i] - target_size[i]) // 2) for i in range(3)]
        ends = [min(current_size[i], starts[i] + target_size[i]) for i in range(3)]

        t_starts = [max(0, (target_size[i] - current_size[i]) // 2) for i in range(3)]
        t_ends = [min(target_size[i], t_starts[i] + current_size[i]) for i in range(3)]

        new_img[t_starts[0]:t_ends[0], t_starts[1]:t_ends[1], t_starts[2]:t_ends[2]] = \
            img[starts[0]:ends[0], starts[1]:ends[1], starts[2]:ends[2]]
        return new_img

    def normalize(self, img):
        eps = 1e-8
        return (img - img.mean()) / (img.std() + eps)

    def encode_mask(self, mask):
        encoded_mask = np.zeros((*mask.shape, Config.num_classes), dtype=np.float32)
        encoded_mask[..., 0] = (mask == 0)  # Background
        encoded_mask[..., 1] = (mask == 4)  # ET
        encoded_mask[..., 2] = np.logical_or(mask == 1, mask == 4)  # TC
        encoded_mask[..., 3] = (mask > 0)   # WT
        return encoded_mask

In [15]:
class UNet3D(nn.Module):
    def __init__(self, in_channels=4, out_channels=4, init_features=32):
        super().__init__()
        features = init_features

        self.encoder1 = self._block(in_channels, features, "enc1")
        self.pool1 = nn.MaxPool3d(2, 2)
        self.encoder2 = self._block(features, features*2, "enc2")
        self.pool2 = nn.MaxPool3d(2, 2)
        self.encoder3 = self._block(features*2, features*4, "enc3")
        self.pool3 = nn.MaxPool3d(2, 2)
        self.encoder4 = self._block(features*4, features*8, "enc4")
        self.pool4 = nn.MaxPool3d(2, 2)

        self.bottleneck = self._block(features*8, features*16, "bottleneck")

        self.upconv4 = nn.ConvTranspose3d(features*16, features*8, 2, 2)
        self.decoder4 = self._block(features*16, features*8, "dec4")
        self.upconv3 = nn.ConvTranspose3d(features*8, features*4, 2, 2)
        self.decoder3 = self._block(features*8, features*4, "dec3")
        self.upconv2 = nn.ConvTranspose3d(features*4, features*2, 2, 2)
        self.decoder2 = self._block(features*4, features*2, "dec2")
        self.upconv1 = nn.ConvTranspose3d(features*2, features, 2, 2)
        self.decoder1 = self._block(features*2, features, "dec1")

        self.conv = nn.Conv3d(features, out_channels, 1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return torch.sigmoid(self.conv(dec1))

    def _block(self, in_channels, features, name):
        return nn.Sequential(
            nn.Conv3d(in_channels, features, 3, padding=1, bias=False),
            nn.BatchNorm3d(features),
            nn.ReLU(inplace=True),
            nn.Conv3d(features, features, 3, padding=1, bias=False),
            nn.BatchNorm3d(features),
            nn.ReLU(inplace=True)
        )

In [16]:
def train(model, train_loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0.0

    for data, target in tqdm(train_loader, desc="Training"):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)

        # Convert one-hot target to class indices
        target_classes = torch.argmax(target, dim=1, keepdim=True)

        loss = loss_fn(output, target_classes)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def validate(model, val_loader, loss_fn, device):
    model.eval()
    total_loss = 0.0
    total_dice = 0.0

    with torch.no_grad():
        for data, target in tqdm(val_loader, desc="Validation"):
            data, target = data.to(device), target.to(device)
            output = model(data)

            target_classes = torch.argmax(target, dim=1, keepdim=True)
            loss = loss_fn(output, target_classes)
            total_loss += loss.item()

            preds = torch.argmax(output, dim=1)
            dice_score = monai.metrics.compute_dice(
                y_pred=preds,
                y=target_classes.squeeze(1),
                include_background=False
            )
            total_dice += dice_score.mean().item()

    return total_loss / len(val_loader), total_dice / len(val_loader)

In [17]:
def get_patient_dirs(data_dir):
    valid_dirs = []

    for subdir in ['51_OtherNeoplasms', '95_Glioma']:
        full_path = os.path.join(data_dir, subdir)
        if not os.path.exists(full_path):
            print(f"Warning: {subdir} not found")
            continue

        for patient in os.listdir(full_path):
            patient_dir = os.path.join(full_path, patient)
            if not os.path.isdir(patient_dir):
                continue

            # Check for required files
            required_files = [
                f"{patient}-t1n.nii.gz",
                f"{patient}-t1c.nii.gz",
                f"{patient}-t2w.nii.gz",
                f"{patient}-t2f.nii.gz",
                f"{patient}-seg.nii.gz"
            ]

            if all(os.path.exists(os.path.join(patient_dir, f)) for f in required_files):
                valid_dirs.append(patient_dir)

    if not valid_dirs:
        raise ValueError(f"No valid patient directories found in {data_dir}")

    print(f"Found {len(valid_dirs)} patient directories with complete data")
    return valid_dirs

In [18]:
def main():
    Config.save('config.json')

    # Get patient directories
    patient_dirs = get_patient_dirs(Config.data_dir)
    train_dirs, val_dirs = train_test_split(patient_dirs, test_size=0.2, random_state=42)

    # Create datasets
    train_dataset = BraTSAfricaDataset(train_dirs, 'train')
    val_dataset = BraTSAfricaDataset(val_dirs, 'val')

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Initialize model
    model = UNet3D(in_channels=4, out_channels=Config.num_classes).to(Config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.learning_rate)
    loss_fn = DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        lambda_dice=1.0,
        lambda_ce=1.0
    )

    # Training loop
    best_dice = 0.0
    for epoch in range(Config.epochs):
        print(f"\nEpoch {epoch+1}/{Config.epochs}")

        train_loss = train(model, train_loader, optimizer, loss_fn, Config.device)
        val_loss, val_dice = validate(model, val_loader, loss_fn, Config.device)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}")

        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Saved new best model with Dice: {best_dice:.4f}")

if __name__ == "__main__":
    main()

Found 7 patient directories with complete data

Epoch 1/100


Training: 100%|██████████| 3/3 [00:28<00:00,  9.49s/it]
Validation: 100%|██████████| 1/1 [00:15<00:00, 15.46s/it]


Train Loss: 2.2483 | Val Loss: 2.3123 | Val Dice: nan

Epoch 2/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.14s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.04s/it]


Train Loss: 2.2159 | Val Loss: 2.3055 | Val Dice: nan

Epoch 3/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.09s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.93s/it]


Train Loss: 2.1935 | Val Loss: 2.2945 | Val Dice: nan

Epoch 4/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.04s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.47s/it]


Train Loss: 2.1785 | Val Loss: 2.2775 | Val Dice: nan

Epoch 5/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.98s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.93s/it]


Train Loss: 2.1765 | Val Loss: 2.2540 | Val Dice: nan

Epoch 6/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.28s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.14s/it]


Train Loss: 2.1741 | Val Loss: 2.2273 | Val Dice: nan

Epoch 7/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.21s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.10s/it]


Train Loss: 2.1575 | Val Loss: 2.2019 | Val Dice: nan

Epoch 8/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.97s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.18s/it]


Train Loss: 2.1542 | Val Loss: 2.1781 | Val Dice: nan

Epoch 9/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.89s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.30s/it]


Train Loss: 2.1547 | Val Loss: 2.1629 | Val Dice: nan

Epoch 10/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.10s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.95s/it]


Train Loss: 2.1451 | Val Loss: 2.1622 | Val Dice: nan

Epoch 11/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.12s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.15s/it]


Train Loss: 2.1282 | Val Loss: 2.1567 | Val Dice: nan

Epoch 12/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.07s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.44s/it]


Train Loss: 2.1237 | Val Loss: 2.1378 | Val Dice: nan

Epoch 13/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.16s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.95s/it]


Train Loss: 2.1208 | Val Loss: 2.1329 | Val Dice: nan

Epoch 14/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.13s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.98s/it]


Train Loss: 2.1268 | Val Loss: 2.1309 | Val Dice: nan

Epoch 15/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.85s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.36s/it]


Train Loss: 2.1249 | Val Loss: 2.1236 | Val Dice: nan

Epoch 16/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.24s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.18s/it]


Train Loss: 2.1300 | Val Loss: 2.1197 | Val Dice: nan

Epoch 17/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.87s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.17s/it]


Train Loss: 2.1246 | Val Loss: 2.1182 | Val Dice: nan

Epoch 18/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.87s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.20s/it]


Train Loss: 2.1113 | Val Loss: 2.1182 | Val Dice: nan

Epoch 19/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.88s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.92s/it]


Train Loss: 2.1206 | Val Loss: 2.1152 | Val Dice: nan

Epoch 20/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.06s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.44s/it]


Train Loss: 2.1080 | Val Loss: 2.1149 | Val Dice: nan

Epoch 21/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.02s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.40s/it]


Train Loss: 2.1147 | Val Loss: 2.1140 | Val Dice: nan

Epoch 22/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.09s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.14s/it]


Train Loss: 2.1128 | Val Loss: 2.1134 | Val Dice: nan

Epoch 23/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.98s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.93s/it]


Train Loss: 2.1100 | Val Loss: 2.1137 | Val Dice: nan

Epoch 24/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.86s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.29s/it]


Train Loss: 2.1096 | Val Loss: 2.1150 | Val Dice: nan

Epoch 25/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.02s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.23s/it]


Train Loss: 2.1112 | Val Loss: 2.1154 | Val Dice: nan

Epoch 26/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.94s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.21s/it]


Train Loss: 2.1092 | Val Loss: 2.1124 | Val Dice: nan

Epoch 27/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.10s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.19s/it]


Train Loss: 2.0958 | Val Loss: 2.1092 | Val Dice: nan

Epoch 28/100


Training: 100%|██████████| 3/3 [00:22<00:00,  7.34s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.03s/it]


Train Loss: 2.1037 | Val Loss: 2.1083 | Val Dice: nan

Epoch 29/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.97s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.02s/it]


Train Loss: 2.0932 | Val Loss: 2.1068 | Val Dice: nan

Epoch 30/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.00s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.96s/it]


Train Loss: 2.1044 | Val Loss: 2.1045 | Val Dice: nan

Epoch 31/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.22s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.03s/it]


Train Loss: 2.0991 | Val Loss: 2.1008 | Val Dice: nan

Epoch 32/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.00s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.47s/it]


Train Loss: 2.1004 | Val Loss: 2.0991 | Val Dice: nan

Epoch 33/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.11s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.99s/it]


Train Loss: 2.0860 | Val Loss: 2.1008 | Val Dice: nan

Epoch 34/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.03s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.96s/it]


Train Loss: 2.0955 | Val Loss: 2.1022 | Val Dice: nan

Epoch 35/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.02s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.51s/it]


Train Loss: 2.0914 | Val Loss: 2.0998 | Val Dice: nan

Epoch 36/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.90s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.09s/it]


Train Loss: 2.0932 | Val Loss: 2.0985 | Val Dice: nan

Epoch 37/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.21s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.14s/it]


Train Loss: 2.0809 | Val Loss: 2.0953 | Val Dice: nan

Epoch 38/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.03s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.56s/it]


Train Loss: 2.0796 | Val Loss: 2.0929 | Val Dice: nan

Epoch 39/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.26s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.02s/it]


Train Loss: 2.0784 | Val Loss: 2.0929 | Val Dice: nan

Epoch 40/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.93s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.01s/it]


Train Loss: 2.0768 | Val Loss: 2.0936 | Val Dice: nan

Epoch 41/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.30s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.46s/it]


Train Loss: 2.0879 | Val Loss: 2.0922 | Val Dice: nan

Epoch 42/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.09s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.41s/it]


Train Loss: 2.0819 | Val Loss: 2.0886 | Val Dice: nan

Epoch 43/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.20s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.22s/it]


Train Loss: 2.0800 | Val Loss: 2.0833 | Val Dice: nan

Epoch 44/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.21s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.06s/it]


Train Loss: 2.0806 | Val Loss: 2.0853 | Val Dice: nan

Epoch 45/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.04s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.46s/it]


Train Loss: 2.0824 | Val Loss: 2.0843 | Val Dice: nan

Epoch 46/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.19s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.05s/it]


Train Loss: 2.0685 | Val Loss: 2.0834 | Val Dice: nan

Epoch 47/100


Training: 100%|██████████| 3/3 [00:22<00:00,  7.45s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.97s/it]


Train Loss: 2.0660 | Val Loss: 2.0864 | Val Dice: nan

Epoch 48/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.10s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.43s/it]


Train Loss: 2.0644 | Val Loss: 2.0812 | Val Dice: nan

Epoch 49/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.05s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.21s/it]


Train Loss: 2.0646 | Val Loss: 2.0769 | Val Dice: nan

Epoch 50/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.23s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.96s/it]


Train Loss: 2.0747 | Val Loss: 2.0739 | Val Dice: nan

Epoch 51/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.01s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.52s/it]


Train Loss: 2.0599 | Val Loss: 2.0731 | Val Dice: nan

Epoch 52/100


Training: 100%|██████████| 3/3 [00:22<00:00,  7.41s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.05s/it]


Train Loss: 2.0684 | Val Loss: 2.0738 | Val Dice: nan

Epoch 53/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.26s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.50s/it]


Train Loss: 2.0577 | Val Loss: 2.0729 | Val Dice: nan

Epoch 54/100


Training: 100%|██████████| 3/3 [00:22<00:00,  7.39s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.45s/it]


Train Loss: 2.0568 | Val Loss: 2.0709 | Val Dice: nan

Epoch 55/100


Training: 100%|██████████| 3/3 [00:22<00:00,  7.43s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.30s/it]


Train Loss: 2.0647 | Val Loss: 2.0700 | Val Dice: nan

Epoch 56/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.83s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.06s/it]


Train Loss: 2.0641 | Val Loss: 2.0673 | Val Dice: nan

Epoch 57/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.81s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.14s/it]


Train Loss: 2.0622 | Val Loss: 2.0664 | Val Dice: nan

Epoch 58/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.02s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.17s/it]


Train Loss: 2.0646 | Val Loss: 2.0668 | Val Dice: nan

Epoch 59/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.02s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.09s/it]


Train Loss: 2.0500 | Val Loss: 2.0630 | Val Dice: nan

Epoch 60/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.05s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.47s/it]


Train Loss: 2.0501 | Val Loss: 2.0644 | Val Dice: nan

Epoch 61/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.15s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.34s/it]


Train Loss: 2.0480 | Val Loss: 2.0666 | Val Dice: nan

Epoch 62/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.21s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.49s/it]


Train Loss: 2.0457 | Val Loss: 2.0646 | Val Dice: nan

Epoch 63/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.10s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.99s/it]


Train Loss: 2.0597 | Val Loss: 2.0624 | Val Dice: nan

Epoch 64/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.07s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.47s/it]


Train Loss: 2.0439 | Val Loss: 2.0636 | Val Dice: nan

Epoch 65/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.01s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.96s/it]


Train Loss: 2.0589 | Val Loss: 2.0600 | Val Dice: nan

Epoch 66/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.00s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.43s/it]


Train Loss: 2.0411 | Val Loss: 2.0540 | Val Dice: nan

Epoch 67/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.24s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.39s/it]


Train Loss: 2.0497 | Val Loss: 2.0512 | Val Dice: nan

Epoch 68/100


Training: 100%|██████████| 3/3 [00:22<00:00,  7.40s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.93s/it]


Train Loss: 2.0533 | Val Loss: 2.0518 | Val Dice: nan

Epoch 69/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.18s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.44s/it]


Train Loss: 2.0371 | Val Loss: 2.0535 | Val Dice: nan

Epoch 70/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.17s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.50s/it]


Train Loss: 2.0509 | Val Loss: 2.0515 | Val Dice: nan

Epoch 71/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.27s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.49s/it]


Train Loss: 2.0369 | Val Loss: 2.0465 | Val Dice: nan

Epoch 72/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.13s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.54s/it]


Train Loss: 2.0334 | Val Loss: 2.0444 | Val Dice: nan

Epoch 73/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.79s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.50s/it]


Train Loss: 2.0424 | Val Loss: 2.0478 | Val Dice: nan

Epoch 74/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.23s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.38s/it]


Train Loss: 2.0428 | Val Loss: 2.0461 | Val Dice: nan

Epoch 75/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.94s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.07s/it]


Train Loss: 2.0400 | Val Loss: 2.0440 | Val Dice: nan

Epoch 76/100


Training: 100%|██████████| 3/3 [00:22<00:00,  7.35s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.02s/it]


Train Loss: 2.0286 | Val Loss: 2.0458 | Val Dice: nan

Epoch 77/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.09s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.16s/it]


Train Loss: 2.0293 | Val Loss: 2.0440 | Val Dice: nan

Epoch 78/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.20s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.07s/it]


Train Loss: 2.0404 | Val Loss: 2.0443 | Val Dice: nan

Epoch 79/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.06s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.45s/it]


Train Loss: 2.0405 | Val Loss: 2.0421 | Val Dice: nan

Epoch 80/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.33s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.06s/it]


Train Loss: 2.0346 | Val Loss: 2.0405 | Val Dice: nan

Epoch 81/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.23s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.99s/it]


Train Loss: 2.0242 | Val Loss: 2.0378 | Val Dice: nan

Epoch 82/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.03s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.05s/it]


Train Loss: 2.0361 | Val Loss: 2.0389 | Val Dice: nan

Epoch 83/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.98s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.06s/it]


Train Loss: 2.0214 | Val Loss: 2.0395 | Val Dice: nan

Epoch 84/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.99s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.11s/it]


Train Loss: 2.0350 | Val Loss: 2.0381 | Val Dice: nan

Epoch 85/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.87s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.20s/it]


Train Loss: 2.0204 | Val Loss: 2.0351 | Val Dice: nan

Epoch 86/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.04s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.00s/it]


Train Loss: 2.0185 | Val Loss: 2.0327 | Val Dice: nan

Epoch 87/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.09s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.37s/it]


Train Loss: 2.0278 | Val Loss: 2.0330 | Val Dice: nan

Epoch 88/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.92s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.49s/it]


Train Loss: 2.0319 | Val Loss: 2.0331 | Val Dice: nan

Epoch 89/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.04s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.04s/it]


Train Loss: 2.0241 | Val Loss: 2.0324 | Val Dice: nan

Epoch 90/100


Training: 100%|██████████| 3/3 [00:20<00:00,  7.00s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.97s/it]


Train Loss: 2.0254 | Val Loss: 2.0285 | Val Dice: nan

Epoch 91/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.24s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.06s/it]


Train Loss: 2.0126 | Val Loss: 2.0264 | Val Dice: nan

Epoch 92/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.08s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.51s/it]


Train Loss: 2.0220 | Val Loss: 2.0271 | Val Dice: nan

Epoch 93/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.08s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.53s/it]


Train Loss: 2.0105 | Val Loss: 2.0247 | Val Dice: nan

Epoch 94/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.09s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.48s/it]


Train Loss: 2.0081 | Val Loss: 2.0228 | Val Dice: nan

Epoch 95/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.19s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.19s/it]


Train Loss: 2.0085 | Val Loss: 2.0232 | Val Dice: nan

Epoch 96/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.22s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.02s/it]


Train Loss: 2.0228 | Val Loss: 2.0206 | Val Dice: nan

Epoch 97/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.95s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.18s/it]


Train Loss: 2.0161 | Val Loss: 2.0203 | Val Dice: nan

Epoch 98/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.32s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.16s/it]


Train Loss: 2.0044 | Val Loss: 2.0194 | Val Dice: nan

Epoch 99/100


Training: 100%|██████████| 3/3 [00:20<00:00,  6.90s/it]
Validation: 100%|██████████| 1/1 [00:04<00:00,  4.99s/it]


Train Loss: 2.0193 | Val Loss: 2.0187 | Val Dice: nan

Epoch 100/100


Training: 100%|██████████| 3/3 [00:21<00:00,  7.03s/it]
Validation: 100%|██████████| 1/1 [00:05<00:00,  5.12s/it]

Train Loss: 2.0111 | Val Loss: 2.0193 | Val Dice: nan





In [20]:
import torch
import nibabel as nib
import numpy as np
from monai.inferers import sliding_window_inference
import os

class BraTSInferencePipeline:
    def __init__(self, model_path, config):
        """
        Initialize the inference pipeline

        Args:
            model_path: Path to trained model weights (.pth)
            config: Configuration dictionary with parameters
        """
        self.device = torch.device(config['device'])
        self.img_size = config['img_size']
        self.num_classes = config['num_classes']

        # Load model
        self.model = UNet3D(in_channels=4, out_channels=self.num_classes).to(self.device)
        self.model.load_state_dict(torch.load(model_path))
        self.model.eval()

        # Inference parameters
        self.sw_batch_size = 2
        self.roi_size = config['img_size']
        self.overlap = 0.5

    def preprocess(self, patient_dir):
        """
        Load and preprocess a patient's MRI volumes

        Args:
            patient_dir: Path to patient directory containing:
                - {patient_id}-t1n.nii.gz
                - {patient_id}-t1c.nii.gz
                - {patient_id}-t2w.nii.gz
                - {patient_id}-t2f.nii.gz

        Returns:
            torch.Tensor: Preprocessed image tensor (1, 4, D, H, W)
            nibabel.Nifti1Image: Original image for reference
        """
        patient_id = os.path.basename(patient_dir)
        modalities = []

        # Load each modality
        for mod in ['t1n', 't1c', 't2w', 't2f']:
            img_path = os.path.join(patient_dir, f"{patient_id}-{mod}.nii.gz")
            img = nib.load(img_path)
            data = img.get_fdata()

            # Normalize and resize
            data = self._normalize(data)
            data = self._crop_or_pad(data, self.img_size)
            modalities.append(data)

        # Stack modalities and add batch dimension
        image = np.stack(modalities, axis=0)  # (4, D, H, W)
        image = torch.from_numpy(image).float().unsqueeze(0)  # (1, 4, D, H, W)

        return image.to(self.device), img

    def _normalize(self, data):
        """Z-score normalization"""
        return (data - data.mean()) / (data.std() + 1e-8)

    def _crop_or_pad(self, data, target_shape):
        """Center crop or pad to target shape"""
        current_shape = data.shape
        new_data = np.zeros(target_shape, dtype=data.dtype)

        # Calculate crop/pad dimensions
        starts = [max(0, (current_shape[i] - target_shape[i]) // 2) for i in range(3)]
        ends = [min(current_shape[i], starts[i] + target_shape[i]) for i in range(3)]

        t_starts = [max(0, (target_shape[i] - current_shape[i]) // 2) for i in range(3)]
        t_ends = [min(target_shape[i], t_starts[i] + current_shape[i]) for i in range(3)]

        new_data[t_starts[0]:t_ends[0],
                 t_starts[1]:t_ends[1],
                 t_starts[2]:t_ends[2]] = data[starts[0]:ends[0],
                                              starts[1]:ends[1],
                                              starts[2]:ends[2]]
        return new_data

    def predict(self, image_tensor):
        """
        Run inference on preprocessed image

        Args:
            image_tensor: Input tensor (1, 4, D, H, W)

        Returns:
            np.ndarray: Predicted segmentation (D, H, W)
        """
        with torch.no_grad():
            # Use sliding window for large volumes
            output = sliding_window_inference(
                inputs=image_tensor,
                roi_size=self.roi_size,
                sw_batch_size=self.sw_batch_size,
                predictor=self.model,
                overlap=self.overlap,
                mode='gaussian'
            )

        # Convert to discrete labels
        pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()
        return pred

    def postprocess(self, prediction, reference_nifti):
        """
        Convert prediction to BraTS format and create NIfTI

        Args:
            prediction: (D, H, W) integer array
            reference_nifti: Original NIfTI image for header

        Returns:
            nibabel.Nifti1Image: Segmentation in BraTS format
        """
        # Convert classes back to BraTS labels
        brats_pred = np.zeros_like(prediction, dtype=np.uint8)
        brats_pred[prediction == 1] = 4  # ET
        brats_pred[prediction == 2] = 1  # TC (necrotic)
        brats_pred[prediction == 3] = 2  # WT (edema)

        return nib.Nifti1Image(brats_pred, reference_nifti.affine, reference_nifti.header)

    def process_patient(self, patient_dir, output_dir=None):
        """
        Complete pipeline for one patient

        Args:
            patient_dir: Path to patient data
            output_dir: Where to save segmentation (optional)

        Returns:
            tuple: (prediction_array, nifti_image)
        """
        # 1. Preprocess
        image_tensor, reference_nifti = self.preprocess(patient_dir)

        # 2. Predict
        pred = self.predict(image_tensor)

        # 3. Postprocess
        seg_nifti = self.postprocess(pred, reference_nifti)

        # 4. Save if needed
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            patient_id = os.path.basename(patient_dir)
            output_path = os.path.join(output_dir, f"{patient_id}-seg.nii.gz")
            nib.save(seg_nifti, output_path)
            print(f"Saved segmentation to {output_path}")

        return pred, seg_nifti


# Example Usage
if __name__ == "__main__":
    # Load your saved config
    config = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'img_size': (128, 128, 128),
        'num_classes': 4
    }

    # Initialize pipeline
    pipeline = BraTSInferencePipeline(
        model_path="best_model.pth",
        config=config
    )

    # Process a patient
    patient_dir = "/content/drive/MyDrive/BraTS-Africa/95_Glioma/BraTS-SSA-00002-000"
    prediction, seg_nifti = pipeline.process_patient(
        patient_dir,
        output_dir="./output_segmentations"
    )

    # Visualize slices
    import matplotlib.pyplot as plt
    plt.imshow(prediction[64], cmap='jet')  # Show middle slice
    plt.show()

FileNotFoundError: [Errno 2] No such file or directory: 'best_model.pth'