In [1]:
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from skimage.transform import resize
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
class BrainTumorDataset(Dataset):
    """Dataset of multi-modal MRI volumes stored as one folder per patient.

    Each sub-directory ``<data_dir>/<patient_id>/`` is expected to contain the
    files ``<patient_id>_flair.nii``, ``<patient_id>_t1.nii``,
    ``<patient_id>_t1ce.nii``, ``<patient_id>_t2.nii`` and
    ``<patient_id>_seg.nii``.

    Args:
        data_dir: Directory containing the per-patient sub-directories.
        transform: Optional callable applied to the image tensor only (the
            label is returned untouched, so spatial transforms would
            de-synchronise image and label).
        label_map: Optional ``{source_value: target_value}`` dict applied to
            the segmentation after loading. Defaults to ``None`` (no
            remapping).
            NOTE(review): the file naming looks like BraTS, whose masks are
            commonly labelled {0, 1, 2, 4}; a 4-class ``CrossEntropyLoss``
            would then need ``label_map={4: 3}`` -- confirm against the data.
    """

    # Modalities are stacked in this order along the channel axis.
    MODALITIES = ("flair", "t1", "t1ce", "t2")

    def __init__(self, data_dir, transform=None, label_map=None):
        self.data_dir = data_dir
        self.transform = transform
        self.label_map = label_map
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> patient mapping (and therefore any seeded train/val split)
        # stable across runs and machines.
        self.patient_ids = sorted(
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d))
        )

    def __len__(self):
        """Return the number of patient directories found."""
        return len(self.patient_ids)

    def _load_volume(self, patient_id, suffix):
        """Load ``<patient_id>_<suffix>.nii`` as a float32 numpy array."""
        path = os.path.join(self.data_dir, patient_id, f"{patient_id}_{suffix}.nii")
        # Ask nibabel for float32 directly: the default float64 doubles the
        # memory per volume only to be down-cast again for the tensor.
        return nib.load(path).get_fdata(dtype=np.float32)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the patient at position ``idx``.

        ``image`` is a float32 tensor of the four modalities stacked on a new
        leading channel axis; ``label`` is an int64 tensor of the segmentation.
        The spatial axes are in whatever order the NIfTI files store them
        (presumably 240 x 240 x 155 for BraTS-style data -- TODO confirm).
        """
        patient_id = self.patient_ids[idx]

        # Stack the modalities along a new leading channel dimension.
        image = np.stack(
            [self._load_volume(patient_id, name) for name in self.MODALITIES],
            axis=0,
        )
        label = self._load_volume(patient_id, "seg")

        # Convert to PyTorch tensors (image is already float32, so no copy).
        image = torch.from_numpy(image)
        label = torch.tensor(label, dtype=torch.long)

        if self.label_map:
            # Compare against the original values so chained mappings
            # (e.g. {4: 3, 3: 1}) cannot cascade.
            remapped = label.clone()
            for source_value, target_value in self.label_map.items():
                remapped[label == source_value] = target_value
            label = remapped

        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:
# Root folder holding one sub-directory per patient.
data_dir = 'data/'  # Update this with your data directory
dataset = BrainTumorDataset(data_dir)

# Hold out 20% of the samples for validation, using a fixed random_state.
all_indices = np.arange(len(dataset))
train_indices, val_indices = train_test_split(
    all_indices,
    test_size=0.2,
    random_state=42,
)

# Index-based views over the same underlying dataset object.
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# Only the training batches are shuffled; validation order stays fixed.
loader_kwargs = dict(batch_size=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [4]:
class UNet(nn.Module):
    """3D U-Net with four down-sampling stages and skip connections.

    Encoder stages are separated by 2x max-pooling; each decoder stage
    up-samples with a stride-2 transposed convolution, concatenates the
    matching encoder feature map along the channel axis and fuses the result
    with a double-convolution block.

    Args:
        in_channels: Number of input channels (e.g. MRI modalities).
        out_channels: Number of output channels (class logits per voxel).
        base_channels: Width of the first encoder stage; every deeper stage
            doubles it. Defaults to 64. Lower it to reduce memory use.

    Input is ``[N, in_channels, A, B, C]`` and output is
    ``[N, out_channels, A, B, C]``. Each spatial size must be at least 16 so
    that four poolings leave at least one voxel; sizes need not be multiples
    of 16.
    """

    def __init__(self, in_channels, out_channels, base_channels=64):
        super(UNet, self).__init__()
        c1, c2, c3, c4, c5 = [base_channels * (2 ** level) for level in range(5)]

        # Contracting path.
        self.encoder1 = self.conv_block(in_channels, c1)
        self.encoder2 = self.conv_block(c1, c2)
        self.encoder3 = self.conv_block(c2, c3)
        self.encoder4 = self.conv_block(c3, c4)
        self.bottleneck = self.conv_block(c4, c5)

        # Halves every spatial dimension between encoder stages. Without it
        # all stages run at full resolution and the stride-2 up-convolutions
        # below produce tensors that cannot be concatenated with the skips.
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Expanding path: up-sample ...
        self.decoder4 = self.upconv_block(c5, c4)
        self.decoder3 = self.upconv_block(c4, c3)
        self.decoder2 = self.upconv_block(c3, c2)
        self.decoder1 = self.upconv_block(c2, c1)

        # ... then fuse the concatenated (up-sampled + skip) features, which
        # carry twice the stage width, back down to the stage width.
        self.fuse4 = self.conv_block(2 * c4, c4)
        self.fuse3 = self.conv_block(2 * c3, c3)
        self.fuse2 = self.conv_block(2 * c2, c2)
        self.fuse1 = self.conv_block(2 * c1, c1)

        self.final_conv = nn.Conv3d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3x3 same-padded convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def upconv_block(self, in_channels, out_channels):
        """Stride-2 transposed convolution (doubles spatial size) plus ReLU."""
        return nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _match_size(x, reference):
        """Zero-pad (or crop) the spatial dims of ``x`` to those of ``reference``.

        Pooling floors odd sizes, so after up-sampling a dimension can be one
        voxel short of the corresponding skip tensor.
        """
        pads = []
        # F.pad expects (before, after) pairs starting from the LAST dimension.
        for dim in range(x.dim() - 1, 1, -1):
            diff = reference.size(dim) - x.size(dim)
            pads.extend([diff // 2, diff - diff // 2])
        if any(pads):
            x = nn.functional.pad(x, pads)
        return x

    def _up_and_fuse(self, upconv, fuse, x, skip):
        """Up-sample ``x``, concatenate it with ``skip`` and fuse the result."""
        x = self._match_size(upconv(x), skip)
        return fuse(torch.cat((x, skip), dim=1))

    def forward(self, x):
        """Map ``[N, in_channels, A, B, C]`` to ``[N, out_channels, A, B, C]`` logits."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))
        bottleneck = self.bottleneck(self.pool(enc4))

        dec4 = self._up_and_fuse(self.decoder4, self.fuse4, bottleneck, enc4)
        dec3 = self._up_and_fuse(self.decoder3, self.fuse3, dec4, enc3)
        dec2 = self._up_and_fuse(self.decoder2, self.fuse2, dec3, enc2)
        dec1 = self._up_and_fuse(self.decoder1, self.fuse1, dec2, enc1)

        return self.final_conv(dec1)


In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU.
# (The previous expression selected 'cpu' in both branches.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set your batch size
# NOTE: the DataLoaders are constructed in an earlier cell with their own
# hard-coded batch size; this value is not consumed by them.
batch_size = 1  # Small batch size to reduce memory usage
accumulation_steps = 4  # Number of steps to accumulate gradients

# 4 input channels (one per modality) and 4 output channels (class logits).
model = UNet(in_channels=4, out_channels=4).to(device)
# CrossEntropyLoss expects raw logits [N, C, ...] and integer targets [N, ...].
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Using device: cpu


In [6]:
num_epochs = 10  # Adjust this as needed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # Sum of the un-scaled per-batch losses this epoch
    num_batches = len(train_loader)
    optimizer.zero_grad()  # Zero out gradients at the start

    # Wrap the train_loader with tqdm for progress indication
    with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as pbar:
        for i, (images, labels) in enumerate(train_loader):
            # Move the batch to the selected device
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass on the scaled loss so that the gradients summed
            # over `accumulation_steps` batches average rather than add up.
            (loss / accumulation_steps).backward()

            # Step once per full accumulation group, and also on the final
            # batch so a trailing partial group is not thrown away when the
            # gradients are zeroed at the start of the next epoch.
            is_last_batch = (i + 1) == num_batches
            if (i + 1) % accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                optimizer.zero_grad()  # Reset gradients after each step

            # Track the un-scaled loss so the reported values are per-batch
            running_loss += loss.item()

            # Update the progress bar with the running mean loss; this uses
            # the same definition as the end-of-epoch average below.
            pbar.update(1)
            pbar.set_postfix(loss=running_loss / (i + 1))

    # Print epoch loss at the end of each epoch (guard against an empty loader)
    avg_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")

Epoch 1/10:   0%|          | 0/148 [00:51<?, ?batch/s]


RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 9142272000 bytes.

In [25]:
def visualize_slice(image, label, prediction, slice_idx=60):
    """Plot one slice of the input, the ground truth and the prediction.

    All three volumes are sliced at ``slice_idx`` along the LAST spatial
    axis so the panels show the same plane. Previously the image was indexed
    on the last axis while label and prediction were indexed on the first.

    Args:
        image: Array indexed as ``[channel, a, b, c]``; channel 0 is shown
            (titled "FLAIR" -- assumes FLAIR is the first stacked modality).
        label: Array indexed as ``[a, b, c]``.
        prediction: Tensor of per-class scores ``[batch, class, a, b, c]``;
            the arg-max over the class axis of the first batch element is
            shown.
        slice_idx: Index along the last spatial axis. Defaults to 60.
    """
    img_slice = image[0, :, :, slice_idx]  # Show FLAIR modality
    label_slice = label[:, :, slice_idx]
    pred_slice = torch.argmax(prediction, dim=1)[0, :, :, slice_idx]

    fig, axes = plt.subplots(1, 3, figsize=(12, 6))

    axes[0].set_title("MRI Slice (FLAIR)")
    axes[0].imshow(img_slice, cmap="gray")

    axes[1].set_title("Ground Truth")
    axes[1].imshow(label_slice, cmap="jet")

    axes[2].set_title("Prediction")
    # Move to host memory first in case the prediction lives on the GPU.
    axes[2].imshow(pred_slice.cpu(), cmap="jet")

    plt.show()

In [None]:
# Save the model (weights only, not the optimizer state)
torch.save(model.state_dict(), "unet3d_brain_tumor.pth")

# Load the model
# The class defined in this notebook is `UNet`; `UNet3D` does not exist here.
model = UNet(in_channels=4, out_channels=4)
# map_location lets a checkpoint saved on one device load on another.
model.load_state_dict(torch.load("unet3d_brain_tumor.pth", map_location=device))
# The fresh instance starts on the CPU; move it to where the batches are sent.
model = model.to(device)

In [None]:
# Switch to evaluation mode and score every validation batch without
# tracking gradients; one loss line is printed per batch.
model.eval()
with torch.no_grad():
    for batch in val_loader:
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        val_loss = criterion(outputs, labels)

        print(f"Validation Loss: {val_loss.item():.4f}")

In [None]:
# Example visualization
# Shows the first sample of the last validation batch. `images` / `labels`
# are the batched loop variables from the validation cell (there is no
# notebook-level `image` / `label`), so drop the batch axis and move the
# data to host memory before converting to numpy.
visualize_slice(images[0].cpu().numpy(), labels[0].cpu().numpy(), outputs)