In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import os
import numpy as np
import pydicom
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
import torchvision.models as models
warnings.filterwarnings('ignore')

print("🚀 Starting Femur Segmentation Training with TransUNet (ViT-B/16)")
print("=" * 60)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    print(f"✅ GPU Available: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️ Using CPU")

# Dataset paths
raw_path = '/kaggle/input/unet-dataset/data/raw'
mask_path = '/kaggle/input/unet-dataset/data/mask'
image_size = (224, 224)

class SegmentationMetrics:
    @staticmethod
    def dice_coefficient(y_true, y_pred, smooth=1e-5):
        y_true_flat = y_true.flatten()
        y_pred_flat = y_pred.flatten()
        intersection = np.sum(y_true_flat * y_pred_flat)
        return (2.0 * intersection + smooth) / (np.sum(y_true_flat) + np.sum(y_pred_flat) + smooth)
    
    @staticmethod
    def iou_score(y_true, y_pred, smooth=1e-5):
        y_true_flat = y_true.flatten()
        y_pred_flat = y_pred.flatten()
        intersection = np.sum(y_true_flat * y_pred_flat)
        union = np.sum(y_true_flat) + np.sum(y_pred_flat) - intersection
        return (intersection + smooth) / (union + smooth)
    
    @classmethod
    def evaluate_batch(cls, y_true_batch, y_pred_batch, threshold=0.5):
        dice_scores = []
        iou_scores = []
        
        for i in range(len(y_true_batch)):
            y_true = y_true_batch[i]
            y_pred = (y_pred_batch[i] > threshold).astype(np.float32)
            
            dice = cls.dice_coefficient(y_true, y_pred)
            iou = cls.iou_score(y_true, y_pred)
            
            dice_scores.append(dice)
            iou_scores.append(iou)
        
        return {
            'dice_mean': np.mean(dice_scores),
            'dice_std': np.std(dice_scores),
            'iou_mean': np.mean(iou_scores),
            'iou_std': np.std(iou_scores),
            'dice_scores': dice_scores,
            'iou_scores': iou_scores
        }

class FemurDataset(Dataset):
    def __init__(self, image_files, mask_files):
        self.image_files = image_files
        self.mask_files = mask_files
        print(f"📊 Dataset initialized with {len(image_files)} samples")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        try:
            # Load and process image
            img_path = self.image_files[idx]
            ds_img = pydicom.dcmread(img_path)
            img = ds_img.pixel_array.astype(np.float32)
            img = cv2.resize(img, image_size)
            
            # Normalize
            img_min, img_max = img.min(), img.max()
            if img_max > img_min:
                img = (img - img_min) / (img_max - img_min)
            else:
                img = np.zeros_like(img)
            
            # Convert to 3 channels for ViT
            img = np.stack([img] * 3, axis=0)
            
            # Load and process mask
            mask_path = self.mask_files[idx]
            ds_mask = pydicom.dcmread(mask_path)
            mask = ds_mask.pixel_array.astype(np.float32)
            mask = cv2.resize(mask, image_size, interpolation=cv2.INTER_NEAREST)
            mask = (mask > 0).astype(np.float32)
            mask = np.expand_dims(mask, axis=0)

            return torch.tensor(img), torch.tensor(mask)
        
        except Exception as e:
            print(f"❌ Error loading sample {idx}: {e}")
            return torch.zeros(3, *image_size), torch.zeros(1, *image_size)

def collect_file_pairs(raw_root, mask_root, ext=".dcm"):
    print(f"🔍 Scanning for {ext} files...")
    
    if not os.path.exists(raw_root) or not os.path.exists(mask_root):
        print(f"❌ Paths don't exist: {raw_root}, {mask_root}")
        return [], []
    
    image_paths, mask_paths = [], []
    cases = sorted(os.listdir(raw_root))
    
    for case in tqdm(cases, desc="Processing cases"):
        raw_dir = os.path.join(raw_root, case)
        mask_dir = os.path.join(mask_root, case.replace("-input", "-seg"))
        
        if not os.path.isdir(raw_dir) or not os.path.isdir(mask_dir): 
            continue

        raw_files = sorted([f for f in os.listdir(raw_dir) if f.endswith(ext)])
        mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith(ext)])

        raw_map = {os.path.splitext(f)[0]: os.path.join(raw_dir, f) for f in raw_files}
        mask_map = {os.path.splitext(f)[0]: os.path.join(mask_dir, f) for f in mask_files}

        common = set(raw_map) & set(mask_map)
        image_paths.extend([raw_map[k] for k in sorted(common)])
        mask_paths.extend([mask_map[k] for k in sorted(common)])

    print(f"✅ Found {len(image_paths)} matched pairs")
    return image_paths, mask_paths

class TransUNet(nn.Module):
    def __init__(self, vit_pretrained=True):
        super().__init__()
        
        # Load pretrained ViT-B/16
        self.vit = models.vit_b_16(weights='IMAGENET1K_V1' if vit_pretrained else None)
        self.vit.heads = nn.Identity()  # Remove classification head
        
        # Decoder blocks
        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )
        
        # Decoder layers
        self.up4 = nn.ConvTranspose2d(768, 512, 2, 2)
        self.dec4 = conv_block(768 + 512, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec3 = conv_block(512 + 256, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2 = conv_block(256 + 128, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = conv_block(128 + 64, 64)
        
        # Final layer
        self.final = nn.Conv2d(64, 1, 1)
        
        # Skip connection adapters
        self.skip_conv1 = nn.Conv2d(3, 64, 1)
        self.skip_conv2 = nn.Conv2d(768, 128, 1)
        self.skip_conv3 = nn.Conv2d(768, 256, 1)
        self.skip_conv4 = nn.Conv2d(768, 512, 1)
        
        print(f"🏗️ TransUNet initialized with {sum(p.numel() for p in self.parameters()):,} parameters")

    def forward(self, x):
        # ViT encoder
        vit_features = self.vit(x)  # [batch, num_patches, 768]
        
        # Reshape to spatial dimensions (14x14 for 224x224 input)
        b = x.shape[0]
        vit_features = vit_features[:, 1:, :].view(b, 768, 14, 14)  # Exclude CLS token
        
        # Generate skip connections
        skip1 = self.skip_conv1(x)
        skip2 = self.skip_conv2(vit_features)
        skip3 = self.skip_conv3(vit_features)
        skip4 = self.skip_conv4(vit_features)
        
        # Decoder with skip connections
        d4 = self.dec4(torch.cat([self.up4(vit_features), skip4], 1))
        d3 = self.dec3(torch.cat([self.up3(d4), skip3], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), skip2], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), skip1], 1))
        
        return torch.sigmoid(self.final(d1))

class DiceBCELoss(nn.Module):
    def __init__(self, dice_weight=0.5, bce_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.bce = nn.BCELoss()

    def forward(self, y_pred, y_true, smooth=1e-5):
        # Flatten tensors
        y_pred_flat = y_pred.view(-1)
        y_true_flat = y_true.view(-1)
        
        # Dice loss
        intersection = (y_pred_flat * y_true_flat).sum()
        dice = (2. * intersection + smooth) / (y_pred_flat.sum() + y_true_flat.sum() + smooth)
        dice_loss = 1 - dice
        
        # BCE loss
        bce_loss = self.bce(y_pred_flat, y_true_flat)
        
        return self.dice_weight * dice_loss + self.bce_weight * bce_loss

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5):
    print(f"🚀 Starting training for {num_epochs} epochs")
    print("=" * 60)
    
    train_losses, val_losses, dice_scores = [], [], []
    best_dice = 0.0
    metrics_calc = SegmentationMetrics()
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        
        for imgs, masks in train_pbar:
            imgs, masks = imgs.to(device), masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            train_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
        
        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0
        all_predictions = []
        all_ground_truths = []
        
        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
                imgs, masks = imgs.to(device), masks.to(device)
                outputs = model(imgs)
                loss = criterion(outputs, masks)
                
                val_loss += loss.item()
                all_predictions.extend(outputs.cpu().numpy())
                all_ground_truths.extend(masks.cpu().numpy())
        
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        # Calculate metrics
        metrics = metrics_calc.evaluate_batch(all_ground_truths, all_predictions)
        current_dice = metrics['dice_mean']
        dice_scores.append(current_dice)
        
        print(f"\nEpoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}")
        print(f"  Val Dice: {current_dice:.4f}")
        print(f"  Val IoU: {metrics['iou_mean']:.4f}")
        
        # Save best model
        if current_dice > best_dice:
            best_dice = current_dice
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"  ✅ New best model saved! (Dice: {best_dice:.4f})")
        
        print("-" * 60)
    
    return train_losses, val_losses, dice_scores

def visualize_results(model, val_loader, n_samples=6):
    model.eval()
    imgs, masks = next(iter(val_loader))
    with torch.no_grad():
        imgs_device = imgs.to(device)
        preds = model(imgs_device).cpu()
    
    fig, axes = plt.subplots(3, n_samples, figsize=(3*n_samples, 9))
    for i in range(min(n_samples, len(imgs))):
        img = imgs[i][0].numpy()
        mask = masks[i].squeeze().numpy()
        pred = preds[i].squeeze().numpy()
        
        axes[0, i].imshow(img, cmap='gray')
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')
        
        axes[1, i].imshow(mask, cmap='gray')
        axes[1, i].set_title('Ground Truth')
        axes[1, i].axis('off')
        
        axes[2, i].imshow(pred, cmap='gray')
        axes[2, i].set_title(f'Prediction\nDice: {SegmentationMetrics.dice_coefficient(mask, pred > 0.5):.3f}')
        axes[2, i].axis('off')
    
    plt.tight_layout()
    plt.show()

def plot_training_curves(train_losses, val_losses, dice_scores):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    ax1.plot(train_losses, label='Train Loss', color='blue')
    ax1.plot(val_losses, label='Val Loss', color='red')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training & Validation Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    ax2.plot(dice_scores, label='Dice Score', color='green')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Dice Score')
    ax2.set_title('Validation Dice Score')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Main execution
if __name__ == "__main__":
    try:
        # Load dataset
        print("📂 Loading dataset...")
        raw_files, mask_files = collect_file_pairs(raw_path, mask_path)
        
        if len(raw_files) == 0:
            print("❌ No data found! Please check your dataset paths.")
            exit()
        
        # Split data
        train_images, val_images, train_masks, val_masks = train_test_split(
            raw_files, mask_files, test_size=0.2, random_state=42
        )
        
        # Create datasets and loaders
        train_dataset = FemurDataset(train_images, train_masks)
        val_dataset = FemurDataset(val_images, val_masks)
        
        batch_size = 4 if torch.cuda.is_available() else 2
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        
        print(f"📊 Training samples: {len(train_dataset)}")
        print(f"📊 Validation samples: {len(val_dataset)}")
        
        # Initialize model, loss, and optimizer
        model = TransUNet(vit_pretrained=True).to(device)
        criterion = DiceBCELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
        
        # Train model
        train_losses, val_losses, dice_scores = train_model(
            model, train_loader, val_loader, criterion, optimizer, num_epochs=5
        )
        
        # Load best model and evaluate
        model.load_state_dict(torch.load('best_model.pth'))
        
        # Final evaluation
        print("\n🔬 Final Evaluation:")
        model.eval()
        all_preds, all_truths = [], []
        with torch.no_grad():
            for imgs, masks in val_loader:
                outputs = model(imgs.to(device))
                all_preds.extend(outputs.cpu().numpy())
                all_truths.extend(masks.numpy())
        
        final_metrics = SegmentationMetrics.evaluate_batch(all_truths, all_preds)
        print(f"Final Dice Score: {final_metrics['dice_mean']:.4f} ± {final_metrics['dice_std']:.4f}")
        print(f"Final IoU Score: {final_metrics['iou_mean']:.4f} ± {final_metrics['iou_std']:.4f}")
        
        # Visualize results
        plot_training_curves(train_losses, val_losses, dice_scores)
        visualize_results(model, val_loader)
        
        print("\n🎉 Training completed successfully!")
        print(f"Best model saved as 'best_model.pth'")
        
    except Exception as e:
        print(f"❌ Error occurred: {e}")
        import traceback
        traceback.print_exc()

🚀 Starting Femur Segmentation Training with TransUNet (ViT-B/16)
⚠️ Using CPU
📂 Loading dataset...
🔍 Scanning for .dcm files...


Processing cases: 100%|██████████| 8/8 [00:00<00:00, 30.60it/s]


✅ Found 4289 matched pairs
📊 Dataset initialized with 3431 samples
📊 Dataset initialized with 858 samples
📊 Training samples: 3431
📊 Validation samples: 858


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:02<00:00, 169MB/s]


🏗️ TransUNet initialized with 100,109,825 parameters
🚀 Starting training for 5 epochs


Epoch 1/5 [Train]:   0%|          | 0/1716 [00:01<?, ?it/s]

❌ Error occurred: too many indices for tensor of dimension 2



Traceback (most recent call last):
  File "/tmp/ipykernel_13/318753125.py", line 376, in <cell line: 0>
    train_losses, val_losses, dice_scores = train_model(
                                            ^^^^^^^^^^^^
  File "/tmp/ipykernel_13/318753125.py", line 243, in train_model
    outputs = model(imgs)
              ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_13/318753125.py", line 187, in forward
    vit_features = vit_features[:, 1:, :].view(b, 768, 14, 14)  # Exclude CLS token
                   ~~~~~~~~~~~~^^^^^^^^^^
IndexError: too many indices for tensor of dimension 2
