In [1]:
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from skimage.transform import resize
import matplotlib.pyplot as plt
import datetime
import json
from torch.cuda.amp import GradScaler, autocast

# -------------------------------
# 1. Verify GPU Availability
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU devices: {torch.cuda.get_device_name(0)}")
    if torch.cuda.device_count() > 1:
        print(f"Additional GPUs: {[torch.cuda.get_device_name(i) for i in range(1, torch.cuda.device_count())]}")

# -------------------------------
# 2. Define Custom Metrics
# -------------------------------
def dice_coefficient(y_true, y_pred, smooth=1):
    """
    Calculates the Dice Coefficient.
    """
    y_true_f = y_true.view(-1)
    y_pred_f = (torch.sigmoid(y_pred) > 0.5).float().view(-1)  # Apply sigmoid and threshold at 0.5
    intersection = (y_true_f * y_pred_f).sum()
    return (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)

def iou_metric(y_true, y_pred, smooth=1):
    """
    Calculates the Intersection over Union (IoU) metric.
    """
    y_true_f = y_true.view(-1)
    y_pred_f = (torch.sigmoid(y_pred) > 0.5).float().view(-1)  # Apply sigmoid and threshold at 0.5
    intersection = (y_true_f * y_pred_f).sum()
    union = y_true_f.sum() + y_pred_f.sum() - intersection
    return (intersection + smooth) / (union + smooth)

def f1_score(y_true, y_pred, threshold=0.5, eps=1e-7):
    """
    Calculates the F1 Score using Precision and Recall.
    """
    y_pred = (torch.sigmoid(y_pred) > threshold).float()  # Apply sigmoid and threshold at 0.5
    y_true_f = y_true.view(-1)
    y_pred_f = y_pred.view(-1)
    true_positives = (y_true_f * y_pred_f).sum()
    precision = true_positives / (y_pred_f.sum() + eps)
    recall = true_positives / (y_true_f.sum() + eps)
    return 2 * (precision * recall) / (precision + recall + eps)

# -------------------------------
# 3. Define the U-Net Model
# -------------------------------
class UNet(nn.Module):
    def __init__(self, in_channels=4, out_channels=1):
        super(UNet, self).__init__()
        
        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True)
            )
        
        def upconv_block(in_ch, out_ch):
            return nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        
        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)
        self.bottleneck = conv_block(512, 1024)
        
        # Decoder
        self.up4 = upconv_block(1024, 512)
        self.dec4 = conv_block(1024, 512)
        self.up3 = upconv_block(512, 256)
        self.dec3 = conv_block(512, 256)
        self.up2 = upconv_block(256, 128)
        self.dec2 = conv_block(256, 128)
        self.up1 = upconv_block(128, 64)
        self.dec1 = conv_block(128, 64)
        
        # Output
        self.out_conv = nn.Conv2d(64, out_channels, 1)
        # No sigmoid here; BCEWithLogitsLoss applies it internally
        
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(nn.MaxPool2d(2)(e1))
        e3 = self.enc3(nn.MaxPool2d(2)(e2))
        e4 = self.enc4(nn.MaxPool2d(2)(e3))
        b = self.bottleneck(nn.MaxPool2d(2)(e4))
        
        # Decoder
        d4 = self.up4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)
        d3 = self.up3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        
        out = self.out_conv(d1)
        return out

# -------------------------------
# 4. Grad-CAM Implementation
# -------------------------------
class GradCAM:
    def __init__(self, model, layer_name):
        self.model = model
        self.layer_name = layer_name
        self.gradients = None
        self.activations = None
        
    def save_gradient(self, grad):
        self.gradients = grad
        
    def forward_hook(self, module, input, output):
        self.activations = output
        output.register_hook(self.save_gradient)
        
    def __call__(self, x):
        self.model.eval()
        self.gradients = None
        self.activations = None
        
        # Register hook
        target_layer = dict(self.model.named_modules())[self.layer_name]
        hook = target_layer.register_forward_hook(self.forward_hook)
        
        # Forward pass
        x = x.to(device)
        output = self.model(x)
        
        # Backward pass
        self.model.zero_grad()
        loss = output.mean()
        loss.backward()
        
        # Compute Grad-CAM
        pooled_grads = torch.mean(self.gradients, dim=[0, 2, 3])
        heatmap = torch.mean(self.activations * pooled_grads[None, :, None, None], dim=1)
        heatmap = nn.functional.relu(heatmap)
        heatmap /= torch.max(heatmap) + 1e-8
        
        hook.remove()
        return heatmap.cpu().detach().numpy()

def overlay_gradcam(img, heatmap, alpha=0.4):
    """
    Overlays the Grad-CAM heatmap on the input image.
    """
    heatmap = resize(heatmap[0], img.shape[:2], preserve_range=True)
    heatmap = np.uint8(255 * heatmap)
    jet = plt.get_cmap('jet')
    heatmap = jet(heatmap)
    heatmap = heatmap[..., :3]  # Remove alpha channel
    superimposed_img = heatmap * alpha + img[..., :3]
    superimposed_img = np.clip(superimposed_img, 0, 1)
    return superimposed_img

# -------------------------------
# 5. Define the Dataset
# -------------------------------
class BraTSDataset(Dataset):
    def __init__(self, h5_files, data_dir, dim=(128, 128)):
        self.h5_files = h5_files
        self.data_dir = data_dir
        self.dim = dim
        
    def __len__(self):
        return len(self.h5_files)
        
    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.h5_files[idx])
        try:
            with h5py.File(file_path, 'r') as hf:
                if 'image' in hf.keys() and 'mask' in hf.keys():
                    image = hf['image'][:]
                    mask = hf['mask'][:]
                else:
                    raise KeyError(f"Unexpected keys in {file_path}: {list(hf.keys())}")

            if mask.ndim > 2:
                mask = np.mean(mask, axis=-1)

            if image.shape[:2] != self.dim:
                image = resize(
                    image, 
                    (*self.dim, image.shape[2]), 
                    preserve_range=True, 
                    anti_aliasing=True
                )
            
            if mask.shape != self.dim:
                mask = resize(
                    mask, 
                    self.dim, 
                    preserve_range=True, 
                    order=0, 
                    anti_aliasing=False
                )

            image_max = np.max(image)
            if image_max > 0:
                image = image.astype(np.float32) / image_max
            else:
                image = image.astype(np.float32)

            mask_max = np.max(mask)
            if mask_max > 0:
                mask = mask.astype(np.float32) / mask_max
            else:
                mask = mask.astype(np.float32)

            image = torch.from_numpy(image).permute(2, 0, 1).float()
            mask = torch.from_numpy(mask).unsqueeze(0).float()
            
            return image, mask
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
            return torch.zeros((4, *self.dim)), torch.zeros((1, *self.dim))

# -------------------------------
# 6. Prepare the Data
# -------------------------------
data_dir = '/kaggle/input/brats2020-training-data/BraTS2020_training_data/content/data'
h5_files = [f for f in os.listdir(data_dir) if f.endswith('.h5')]

if not h5_files:
    raise FileNotFoundError(f"No .h5 files found in the directory: {data_dir}")

train_files, val_files = train_test_split(h5_files, test_size=0.2, random_state=42)

print(f"Total samples: {len(h5_files)}")
print(f"Training samples: {len(train_files)}")
print(f"Validation samples: {len(val_files)}")

batch_size = 8
image_dim = (128, 128)

train_dataset = BraTSDataset(train_files, data_dir, dim=image_dim)
val_dataset = BraTSDataset(val_files, data_dir, dim=image_dim)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# -------------------------------
# 7. Build and Initialize the Model
# -------------------------------
model = UNet(in_channels=4, out_channels=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()
scaler = GradScaler()

# -------------------------------
# 8. Define Training and Validation Loops
# -------------------------------
def train_epoch(model, loader, optimizer, criterion, scaler):
    model.train()
    running_loss = 0.0
    metrics = {'dice': 0.0, 'iou': 0.0, 'f1': 0.0}
    count = 0
    
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, masks)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item() * images.size(0)
        with torch.no_grad():
            metrics['dice'] += dice_coefficient(masks, outputs).item() * images.size(0)
            metrics['iou'] += iou_metric(masks, outputs).item() * images.size(0)
            metrics['f1'] += f1_score(masks, outputs).item() * images.size(0)
        count += images.size(0)
    
    epoch_loss = running_loss / count
    metrics = {k: v / count for k, v in metrics.items()}
    return epoch_loss, metrics

def validate_epoch(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    metrics = {'dice': 0.0, 'iou': 0.0, 'f1': 0.0}
    count = 0
    
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            running_loss += loss.item() * images.size(0)
            metrics['dice'] += dice_coefficient(masks, outputs).item() * images.size(0)
            metrics['iou'] += iou_metric(masks, outputs).item() * images.size(0)
            metrics['f1'] += f1_score(masks, outputs).item() * images.size(0)
            count += images.size(0)
    
    epoch_loss = running_loss / count
    metrics = {k: v / count for k, v in metrics.items()}
    return epoch_loss, metrics

# -------------------------------
# 9. Training Loop
# -------------------------------
epochs = 10
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': [], 'train_dice': [], 'val_dice': [], 
           'train_iou': [], 'val_iou': [], 'train_f1': [], 'val_f1': []}

for epoch in range(epochs):
    try:
        train_loss, train_metrics = train_epoch(model, train_loader, optimizer, criterion, scaler)
        val_loss, val_metrics = validate_epoch(model, val_loader, criterion)
        
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f}, Dice: {train_metrics['dice']:.4f}, IoU: {train_metrics['iou']:.4f}, F1: {train_metrics['f1']:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Dice: {val_metrics['dice']:.4f}, IoU: {val_metrics['iou']:.4f}, F1: {val_metrics['f1']:.4f}")
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_dice'].append(train_metrics['dice'])
        history['val_dice'].append(val_metrics['dice'])
        history['train_iou'].append(train_metrics['iou'])
        history['val_iou'].append(val_metrics['iou'])
        history['train_f1'].append(train_metrics['f1'])
        history['val_f1'].append(val_metrics['f1'])
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'model-unet.best.pth')
            print("Saved best model to 'model-unet.best.pth'")
    except Exception as e:
        print(f"An error occurred during epoch {epoch+1}: {e}")
        break

# -------------------------------
# 10. Save the Final Model
# -------------------------------
try:
    torch.save(model.state_dict(), 'final_model_unet.pth')
    print("Model training complete and saved as 'final_model_unet.pth'.")
except Exception as e:
    print(f"Error saving final model: {e}")

# -------------------------------
# 11. Save Training History
# -------------------------------
try:
    with open('training_history.json', 'w') as f:
        json.dump(history, f)
    print("Training history saved as 'training_history.json'.")
except Exception as e:
    print(f"Error saving training history: {e}")

# -------------------------------
# 12. Plot Training History
# -------------------------------
def plot_training_history(history, metrics=['loss', 'dice', 'iou', 'f1']):
    """
    Plots training and validation metrics over epochs.
    """
    for metric in metrics:
        try:
            plt.figure(figsize=(8, 6))
            plt.plot(history[f'train_{metric}'], label=f'Training {metric}')
            plt.plot(history[f'val_{metric}'], label=f'Validation {metric}')
            plt.title(f'Training and Validation {metric.capitalize()}')
            plt.xlabel('Epoch')
            plt.ylabel(metric.capitalize())
            plt.legend()
            plt.grid(True)
            plt.savefig(f'{metric}_plot.png')
            plt.close()
            print(f"{metric.capitalize()} plot saved as '{metric}_plot.png'.")
        except Exception as e:
            print(f"Error plotting {metric}: {e}")

plot_training_history(history)

# -------------------------------
# 13. Preprocess Image for Prediction
# -------------------------------
def preprocess_image(file_path, image_dim=(128, 128)):
    """
    Loads and preprocesses an image and its corresponding mask from an h5 file.
    """
    try:
        with h5py.File(file_path, 'r') as hf:
            if 'image' in hf.keys() and 'mask' in hf.keys():
                image = hf['image'][:]
                mask = hf['mask'][:]
            else:
                raise KeyError(f"Unexpected keys in {file_path}: {list(hf.keys())}")

        print(f"Mask stats for {file_path}: min={mask.min()}, max={mask.max()}, mean={mask.mean()}")

        if mask.ndim > 2:
            mask = np.mean(mask, axis=-1)

        if image.shape[:2] != image_dim:
            image = resize(image, (*image_dim, image.shape[2]), preserve_range=True, anti_aliasing=True)
        if mask.shape != image_dim:
            mask = resize(mask, image_dim, preserve_range=True, order=0, anti_aliasing=False)

        image_max = np.max(image)
        if image_max > 0:
            image = image.astype(np.float32) / image_max

        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)

        return image, mask
    except Exception as e:
        print(f"Error preprocessing {file_path}: {e}")
        return None, None

# -------------------------------
# 14. Predict and Visualize with Grad-CAM
# -------------------------------
def predict_and_visualize(file_paths, model, image_dim=(128, 128)):
    """
    Loads, preprocesses, predicts, and visualizes results with Grad-CAM.
    """
    model.eval()
    grad_cam = GradCAM(model, 'dec1')
    
    for i, file_path in enumerate(file_paths):
        try:
            image, true_mask = preprocess_image(file_path, image_dim)
            if image is None or true_mask is None:
                continue

            img_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
            with torch.no_grad():
                pred_mask = torch.sigmoid(model(img_tensor)).cpu().numpy()[0]  # Apply sigmoid to logits
                pred_mask_binary = (pred_mask > 0.5).astype(np.float32)  # Threshold to get binary mask

            # Generate Grad-CAM heatmap
            heatmap = grad_cam(img_tensor)
            superimposed_img = overlay_gradcam(image, heatmap)

            # Visualize
            plt.figure(figsize=(16, 6))

            plt.subplot(1, 4, 1)
            plt.imshow(image[..., 0], cmap='gray')
            plt.title('Input Image (Channel 0)')
            plt.axis('off')

            plt.subplot(1, 4, 2)
            plt.imshow(true_mask, cmap='gray')
            plt.title('True Mask')
            plt.axis('off')

            plt.subplot(1, 4, 3)
            plt.imshow(pred_mask_binary.squeeze(), cmap='gray')
            plt.title('Predicted Mask')
            plt.axis('off')

            plt.subplot(1, 4, 4)
            plt.imshow(superimposed_img)
            plt.title('Grad-CAM Overlay')
            plt.axis('off')

            plt.savefig(f'visualization_{i+1}.png')
            plt.close()
            print(f"Visualization for sample {i+1} saved as 'visualization_{i+1}.png'")
        except Exception as e:
            print(f"Error visualizing {file_path}: {e}")

# -------------------------------
# 15. Specify File Paths and Execute
# -------------------------------
files_to_predict = [
    "/kaggle/input/brats2020-training-data/BraTS2020_training_data/content/data/volume_100_slice_100.h5"
]
try:
    predict_and_visualize(files_to_predict, model)
except Exception as e:
    print(f"Error during prediction and visualization: {e}")

# -------------------------------
# 16. List Output Files
# -------------------------------
print("Files in working directory:", os.listdir('/kaggle/working'))

Using device: cuda
GPU devices: Tesla T4
Additional GPUs: ['Tesla T4']
Total samples: 57195
Training samples: 45756
Validation samples: 11439


  scaler = GradScaler()
  with autocast():


Epoch 1/10
Train Loss: 0.0184, Dice: 0.7216, IoU: 0.6082, F1: 0.7124
Val Loss: 0.0086, Dice: 0.8119, IoU: 0.7083, F1: 0.8057
Saved best model to 'model-unet.best.pth'
Epoch 2/10
Train Loss: 0.0078, Dice: 0.8321, IoU: 0.7357, F1: 0.8225
Val Loss: 0.0068, Dice: 0.8440, IoU: 0.7526, F1: 0.8330
Saved best model to 'model-unet.best.pth'
Epoch 3/10
Train Loss: 0.0063, Dice: 0.8609, IoU: 0.7746, F1: 0.8494
Val Loss: 0.0063, Dice: 0.8568, IoU: 0.7716, F1: 0.8454
Saved best model to 'model-unet.best.pth'
Epoch 4/10
Train Loss: 0.0056, Dice: 0.8728, IoU: 0.7920, F1: 0.8626
Val Loss: 0.0052, Dice: 0.8796, IoU: 0.8019, F1: 0.8682
Saved best model to 'model-unet.best.pth'
Epoch 5/10
Train Loss: 0.0051, Dice: 0.8853, IoU: 0.8090, F1: 0.8752
Val Loss: 0.0048, Dice: 0.8888, IoU: 0.8156, F1: 0.8778
Saved best model to 'model-unet.best.pth'
Epoch 6/10
Train Loss: 0.0047, Dice: 0.8938, IoU: 0.8211, F1: 0.8833
Val Loss: 0.0047, Dice: 0.8854, IoU: 0.8108, F1: 0.8748
Saved best model to 'model-unet.best.pth