In [1]:
import tarfile
import os

# Path to the TAR file
tar_path = '/kaggle/input/brats-2021-task1/BraTS2021_Training_Data.tar'
extract_path = '/kaggle/working/BraTS2021_Training_Data'

# Extract the TAR file
with tarfile.open(tar_path, 'r') as tar:
    tar.extractall(path=extract_path)


In [2]:
import nibabel as nib
import numpy as np
import os


# Path to the TAR file
tar_path = '/kaggle/input/brats-2021-task1/BraTS2021_Training_Data.tar'
extract_path = '/kaggle/working/BraTS2021_Training_Data'



def load_nifti(file_path):
    """
    Load a NIfTI file and convert it to a NumPy array.
    """
    img = nib.load(file_path)
    return img.get_fdata()

def preprocess_modality(modality_path):
    """
    Normalize the modality data.
    """
    data = load_nifti(modality_path)
    # Normalize (z-score)
    mean, std = data.mean(), data.std()
    return (data - mean) / std





In [3]:
import torch
from torch.utils.data import Dataset

import os

# Dataset Class
class BraTSDataset(Dataset):
    def __init__(self, root_dir, num_classes, transform=None):
        self.root_dir = root_dir
        self.num_classes = num_classes
        self.transform = transform
        self.patients = [
            folder for folder in sorted(os.listdir(root_dir))
            if not folder.startswith('.')
        ]

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        patient_dir = os.path.join(self.root_dir, self.patients[idx])
    
        # Load modalities
        t1 = preprocess_modality(os.path.join(patient_dir, f"{self.patients[idx]}_t1.nii.gz"))
        t1ce = preprocess_modality(os.path.join(patient_dir, f"{self.patients[idx]}_t1ce.nii.gz"))
        t2 = preprocess_modality(os.path.join(patient_dir, f"{self.patients[idx]}_t2.nii.gz"))
        flair = preprocess_modality(os.path.join(patient_dir, f"{self.patients[idx]}_flair.nii.gz"))
    
        # Load segmentation and ensure valid range
        seg = load_nifti(os.path.join(patient_dir, f"{self.patients[idx]}_seg.nii.gz"))
        seg = torch.tensor(seg, dtype=torch.long)
        seg = torch.clamp(seg, min=0, max=self.num_classes - 1)  # Ensure valid range
    
        # One-hot encode
        seg_one_hot = F.one_hot(seg, num_classes=self.num_classes).permute(3, 0, 1, 2).float()
    
        # Combine modalities into a tensor
        image = np.stack([t1, t1ce, t2, flair], axis=0)
        image = torch.tensor(image, dtype=torch.float32)
    
        return image, seg_one_hot




In [4]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Conv3d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.projection(x)  # (B, embed_dim, D, H, W)
        self.d, self.h, self.w = x.shape[2:]  # Store dimensions for later use
        B, C, D, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, feedforward_dim):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, feedforward_dim),
            nn.ReLU(),
            nn.Linear(feedforward_dim, embed_dim)
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
    
    def forward(self, x):
        # Self-attention
        attn_output, _ = self.attention(x, x, x)
        x = self.layernorm1(x + attn_output)
        
        # Feedforward
        ff_output = self.feedforward(x)
        x = self.layernorm2(x + ff_output)
        return x

class TransformerSegmentation(nn.Module):
    def __init__(
        self,
        input_channels,
        patch_size,
        embed_dim,
        depth,
        heads,
        mlp_dim,
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
        num_classes=4,
    ):
        super().__init__()
        self.embedding = PatchEmbedding(
            in_channels=input_channels, patch_size=patch_size, embed_dim=embed_dim
        )
        self.transformer = nn.Sequential(
            *[
                TransformerBlock(embed_dim, heads, mlp_dim)
                for _ in range(depth)
            ]
        )
        self.output_projection = nn.Conv3d(embed_dim, num_classes, kernel_size=1)

    def forward(self, x):
        # Embed patches
        x = self.embedding(x)

        # Process with transformer
        x = self.transformer(x)

        # Reshape back to spatial dimensions
        B, num_patches, embed_dim = x.shape
        x = x.transpose(1, 2).reshape(B, embed_dim, self.embedding.d, self.embedding.h, self.embedding.w)
        x = self.output_projection(x)  # Final segmentation output
        return x



In [5]:
from torch.optim import Adam
import torch.nn.functional as F


class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, logits, targets):
        smooth = 1.0

        # Apply softmax to logits to get probabilities
        probs = torch.softmax(logits, dim=1)

        # Compute Dice loss
        intersection = (probs * targets).sum(dim=(2, 3, 4))
        union = probs.sum(dim=(2, 3, 4)) + targets.sum(dim=(2, 3, 4))
        dice = (2.0 * intersection + smooth) / (union + smooth)

        return 1 - dice.mean()


def train_transbts(model, train_loader, val_loader, epochs=20, lr=1e-4, device='cuda'):
    model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    criterion = CombinedLoss()
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            
            outputs = model(images)
            outputs = F.interpolate(outputs, size=masks.shape[1:], mode="trilinear", align_corners=False)
            loss = criterion(outputs, masks)
            train_loss += loss.item()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}")

In [6]:


       

# Training Function
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    epoch_loss = 0

    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)

        # Forward pass
        outputs = model(images)
        outputs = F.interpolate(outputs, size=masks.shape[2:], mode="trilinear", align_corners=False)
        
        # Compute loss
        loss = criterion(outputs, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(loader)

In [7]:


def evaluate(model, loader, criterion, device):
    model.eval()
    val_loss = 0
    dice_coeff = 0

    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            outputs = F.interpolate(outputs, size=masks.shape[2:], mode="trilinear", align_corners=False)

            # Compute validation loss
            val_loss += criterion(outputs, masks).item()

            # Dice coefficient
            probs = torch.softmax(outputs, dim=1)
            preds = (probs > 0.5).float()
            intersection = (preds * masks).sum()
            dice_coeff += (2.0 * intersection) / (preds.sum() + masks.sum() + 1e-5)

    return val_loss / len(loader), dice_coeff / len(loader)



In [8]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from torch.utils.data import Subset
import random


# Split the dataset into training and validation sets
# Define the number of classes in your segmentation task
num_classes = 4  # Replace 4 with the actual number of classes in your dataset

# Initialize the dataset
dataset = BraTSDataset(root_dir=extract_path, num_classes=num_classes)
train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)

# Select a subset of the dataset
#subset_size = 10  # Number of samples to use for debugging
#indices = random.sample(range(len(dataset)), subset_size)  # Randomly sample indices
#debug_dataset = Subset(dataset, indices)

train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

# Create DataLoaders
#train_loader = DataLoader(debug_dataset, batch_size=1, shuffle=True, num_workers=4)
#val_loader = DataLoader(debug_dataset, batch_size=1, shuffle=False, num_workers=4)


# Debug the DataLoader
for images, masks in train_loader:
    print("Image batch shape:", images.shape)
    print("Mask batch shape:", masks.shape)
    print("Mask unique values:", torch.unique(masks))
    break  # Debug only one batch



Image batch shape: torch.Size([1, 4, 240, 240, 155])
Mask batch shape: torch.Size([1, 4, 240, 240, 155])
Mask unique values: tensor([0., 1.])


In [9]:



from torch.optim import Adam

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerSegmentation(
    input_channels=4, patch_size=4, embed_dim=128, depth=6, heads=8, mlp_dim=512, num_classes=4
).to(device)

# Free GPU cache
torch.cuda.empty_cache()

# Train with reduced batch size
batch_size = 1


# Initialize optimizer and loss function
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = DiceLoss()

# Training loop
num_epochs = 13
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, dice_score = evaluate(model, val_loader, criterion, device)
    
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Dice Score: {dice_score:.4f}")


Epoch 1/13
Train Loss: 0.4740
Val Loss: 0.4189, Dice Score: 0.9930
Epoch 2/13
Train Loss: 0.4134
Val Loss: 0.4052, Dice Score: 0.9913
Epoch 3/13
Train Loss: 0.4024
Val Loss: 0.3720, Dice Score: 0.9934
Epoch 4/13
Train Loss: 0.3820
Val Loss: 0.3590, Dice Score: 0.9935
Epoch 5/13
Train Loss: 0.3653
Val Loss: 0.3563, Dice Score: 0.9938
Epoch 6/13
Train Loss: 0.3587
Val Loss: 0.3473, Dice Score: 0.9950
Epoch 7/13
Train Loss: 0.3462
Val Loss: 0.3600, Dice Score: 0.9913
Epoch 8/13
Train Loss: 0.3474
Val Loss: 0.3434, Dice Score: 0.9949
Epoch 9/13
Train Loss: 0.3423
Val Loss: 0.3778, Dice Score: 0.9922
Epoch 10/13
Train Loss: 0.3361
Val Loss: 0.3187, Dice Score: 0.9949
Epoch 11/13
Train Loss: 0.3308
Val Loss: 0.3258, Dice Score: 0.9947
Epoch 12/13
Train Loss: 0.3384
Val Loss: 0.3134, Dice Score: 0.9953
Epoch 13/13
Train Loss: 0.3200
Val Loss: 0.3405, Dice Score: 0.9942


In [10]:
torch.save(model.state_dict(), '/kaggle/working/transbts_model.pth')
print("Model saved!")


Model saved!


In [11]:
# Load the model
model.load_state_dict(torch.load('/kaggle/working/transbts_model.pth'))
model.eval()


  model.load_state_dict(torch.load('/kaggle/working/transbts_model.pth'))


TransformerSegmentation(
  (embedding): PatchEmbedding(
    (projection): Conv3d(4, 128, kernel_size=(4, 4, 4), stride=(4, 4, 4))
  )
  (transformer): Sequential(
    (0): TransformerBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (feedforward): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (layernorm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (feedforward): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_feat