# In [None]:  (Jupyter cell prompt left over from the notebook export)
import subprocess
import sys
import os # 'os' is needed for script logic, so we import it here.

def install_and_import(package, import_name=None):
    """
    Ensure a package is importable, installing it with pip if necessary.

    Args:
        package: Distribution name handed to ``pip install`` (e.g. "Pillow").
        import_name: Module name to import (e.g. "PIL"). Defaults to
            ``package`` when omitted.

    Side effects:
        Prints progress messages. If the installation or the follow-up
        import fails, prints manual-install instructions and terminates the
        process with ``sys.exit(1)``.
    """
    # Function-scope import so this block does not depend on file-level imports.
    import importlib

    if import_name is None:
        import_name = package

    try:
        importlib.import_module(import_name)
    except ImportError:
        pass
    else:
        print(f"{import_name} is already installed.")
        return

    print(f"{import_name} not found. Attempting to install {package}...")
    try:
        # Use sys.executable so pip installs into the interpreter that is running us.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        print(f"Successfully installed {package}.")
        # The import system caches directory listings; without this call a
        # package installed while the interpreter is running may not be found.
        importlib.invalidate_caches()
        importlib.import_module(import_name)
    except Exception as e:
        # Deliberately broad: any failure here means the script cannot continue.
        print(f"Error: Failed to install {package}. {e}")
        print(f"Please install {package} manually by running: pip install {package}")
        sys.exit(1)

def setup_dependencies():
    """
    Make sure every third-party package this script needs can be imported.

    Walks the requirement list in order and hands each entry to
    install_and_import.
    """
    # (pip distribution name, importable module name); None means
    # "use the distribution name as the module name".
    requirements = (
        ("torch", None),
        ("numpy", "numpy"),
        ("einops", "einops"),
        ("Pillow", "PIL"),
        ("torchvision", "torchvision"),
    )

    print("Checking dependencies...")
    for pip_name, module_name in requirements:
        install_and_import(pip_name, module_name)
    print("All dependencies are set up.")

# --- Run Dependency Setup ---
# This call sits at module level, so it runs whenever this file is executed
# or imported. It must stay above the third-party imports below so that
# missing packages are installed before those imports are attempted.
setup_dependencies()

# --- Main Script Imports ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import time
import zipfile
import urllib.request
import glob
from PIL import Image
from torchvision import transforms

# --- Helper Functions & Classes ---

class DiceLoss(nn.Module):
    """
    Soft Dice loss, a common loss function for segmentation tasks.

    Measures the overlap between the predicted mask and the ground truth.
    The overlap is computed over all elements of the batch at once, not
    per sample.

    Args:
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when prediction and target are both empty.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """
        Return ``1 - dice`` for raw logits ``inputs`` against ``targets``.

        ``inputs`` and ``targets`` must have the same number of elements.
        """
        # Apply sigmoid to get probabilities, then flatten. reshape (unlike
        # view) also accepts non-contiguous tensors such as transposed inputs.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        return 1 - dice

def dice_coefficient(inputs, targets, smooth=1.0):
    """
    Calculate the Dice coefficient (a metric, not a loss).

    Args:
        inputs: Raw logits; a sigmoid and a 0.5 threshold turn them into
            binary predictions.
        targets: Ground-truth mask with the same number of elements.
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when prediction and target are both empty.

    Returns:
        The Dice score as a Python float, computed over all elements at once.
    """
    # Apply sigmoid and threshold to get binary predictions.
    preds = (torch.sigmoid(inputs) > 0.5).float()

    # reshape (unlike view) also accepts non-contiguous tensors.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)

    intersection = (preds * targets).sum()
    dice = (2. * intersection + smooth) / (preds.sum() + targets.sum() + smooth)

    return dice.item()

# --- U-Net Model Architecture ---
#
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Convolution => BatchNorm => ReLU) stages."""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # Fall back to the output width when no intermediate width is given.
        mid_channels = mid_channels or out_channels

        stages = []
        for stage_in, stage_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            stages.append(nn.Conv2d(stage_in, stage_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(stage_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool followed by a DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """
    Decoder stage: upscale, concatenate with the skip connection, DoubleConv.

    Args:
        in_channels: Channel count after concatenation (skip + upsampled).
        out_channels: Channel count produced by the stage.
        bilinear: If True, upsample with bilinear interpolation; otherwise
            use a learned transposed convolution that halves the channels.
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Interpolation keeps the channel count, so the DoubleConv
            # narrows through in_channels // 2 instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        Upsample ``x1`` and fuse it with the skip connection ``x2``.

        Both tensors are indexed as (batch, channels, height, width).
        """
        x1 = self.up(x1)

        # Max-pooling floors odd sizes, so the upsampled map can be smaller
        # than the skip connection. Pad (or crop, for negative differences)
        # x1 so that the concatenation below is always valid.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        # x2 is the skip connection; concatenate along the channel dimension.
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """Final 1x1 convolution that maps feature channels to output channels."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """
    A standard U-Net: four encoder stages, four decoder stages with skip
    connections, and a 1x1 output convolution.
    By default it produces 1 output channel (for a binary mask).
    """
    def __init__(self, n_channels=1, n_classes=1, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Channel widths are halved at the bottleneck and in the decoder when
        # bilinear upsampling is used.
        factor = 1 if not bilinear else 2

        # Encoder path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder path.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Per-pixel classifier.
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Run the encoder, remembering every intermediate map as a skip connection.
        skips = [self.inc(x)]
        for encoder_stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder_stage(skips[-1]))

        # The deepest map seeds the decoder; the remaining maps are consumed
        # from deepest to shallowest.
        x = skips.pop()
        for decoder_stage in (self.up1, self.up2, self.up3, self.up4):
            x = decoder_stage(x, skips.pop())

        return self.outc(x)

# --- Vision Transformer (ViT) Segmentation Model ---
#
class PatchEmbedding(nn.Module):
    """
    Cuts the image into non-overlapping square patches and projects each
    patch to an embedding vector.
    """
    def __init__(self, img_size=128, patch_size=16, in_channels=1, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # A convolution whose kernel and stride both equal the patch size
        # performs the patching and the linear embedding in a single step.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # (B, E, H_p, W_p) -> (B, E, N_p) -> (B, N_p, E)
        return self.proj(x).flatten(2).transpose(1, 2)

class ViTSeg(nn.Module):
    """
    A simplified Vision Transformer for segmentation.

    Uses a Transformer encoder over patch tokens followed by a simple
    convolutional decoder that upscales the features back to image size.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        n_classes: Number of output channels (logit maps).
        embed_dim: Token embedding width.
        depth: Number of Transformer encoder layers.
        n_heads: Attention heads per encoder layer.
        mlp_dim: Hidden width of each encoder layer's feed-forward block.
    """
    def __init__(self, img_size=128, patch_size=16, in_channels=1, n_classes=1,
                 embed_dim=768, depth=12, n_heads=12, mlp_dim=3072):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.n_patches = (img_size // patch_size) ** 2

        # 1. Patch Embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # 2. Positional Embedding (learned, one vector per patch)
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches, embed_dim))

        # 3. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads, dim_feedforward=mlp_dim, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # 4. Simple Convolutional Decoder
        # Tokens are reshaped (B, N_p, E) -> (B, E, H_p, W_p) before decoding.
        self.patch_h = img_size // patch_size
        self.patch_w = img_size // patch_size

        # Three fixed 2x bilinear upsampling steps, then a final resize to
        # exactly (img_size, img_size). For patch_size=16 that last step is
        # the same 2x upsample as before; for any other patch size (or an
        # img_size that is not a multiple of patch_size) it guarantees the
        # logits still match the input resolution instead of silently
        # coming out at 16 * (img_size // patch_size).
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=True),
        )
        # Final conv to get to the number of classes
        self.out_conv = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x):
        """Return logits with spatial size (img_size, img_size)."""
        # 1. Patch Embedding
        x = self.patch_embed(x)  # (B, N_p, E)

        # 2. Add Positional Embedding (broadcast over the batch)
        x = x + self.pos_embed

        # 3. Transformer Encoder
        x = self.transformer_encoder(x)  # (B, N_p, E)

        # 4. Decoder
        # Reshape for decoder: (B, N_p, E) -> (B, E, H_p, W_p).
        # reshape tolerates the non-contiguous layout left by transpose.
        x = x.transpose(1, 2)
        x = x.reshape(x.shape[0], self.embed_dim, self.patch_h, self.patch_w)

        # Upsample back to image resolution
        x = self.decoder(x)

        # Final output conv
        logits = self.out_conv(x)

        return logits


# --- Data Download and Loading ---

def download_and_extract_dataset():
    """
    Download and extract the 2D brain tumor segmentation dataset from Zenodo.

    Uses Zenodo record 12735702. The archive is downloaded into the current
    directory, extracted there, and the zip file is then deleted.

    Returns:
        The path of the extracted dataset folder, or None if the download
        or the extraction failed. If the folder already exists, it is
        returned immediately without touching the network.
    """
    # Function-scope import so this block does not depend on file-level imports.
    import shutil

    dataset_url = "https://zenodo.org/records/12735702/files/brain-tumor-mri-dataset.zip?download=1"
    data_dir = "./"
    zip_path = os.path.join(data_dir, "brain-tumor-mri-dataset.zip")
    # This is the folder name *inside* the zip file
    extract_path = os.path.join(data_dir, "brain-tumor-mri-dataset")

    if os.path.exists(extract_path):
        print(f"Dataset already found at {extract_path}")
        return extract_path

    print("Downloading dataset (156MB)... (This may take a few minutes)")
    try:
        # Context managers ensure the response and the file are both closed.
        with urllib.request.urlopen(dataset_url) as response:
            if response.status != 200:
                print(f"Error downloading dataset: HTTP Status {response.status}")
                return None

            # Stream to disk in chunks instead of holding the whole archive
            # in memory at once.
            with open(zip_path, 'wb') as out_file:
                shutil.copyfileobj(response, out_file)
        print("Download complete.")
    except Exception as e:
        print(f"Error downloading dataset: {e}")
        # Clean up partial download if it exists
        if os.path.exists(zip_path):
            os.remove(zip_path)
        return None

    print("Extracting dataset...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
        print("Extraction complete.")
    except Exception as e:
        print(f"Error extracting dataset: {e}")
        # extract_path did not exist before this call, so anything there now
        # is a partial extraction. Remove it; otherwise the next run would
        # take the "already found" shortcut with an incomplete dataset.
        shutil.rmtree(extract_path, ignore_errors=True)
        # The archive is re-downloaded on the next run anyway.
        try:
            os.remove(zip_path)
        except OSError:
            pass
        return None

    # Clean up the zip file
    try:
        os.remove(zip_path)
        print(f"Removed zip file: {zip_path}")
    except OSError as e:
        print(f"Error removing zip file: {e}")

    return extract_path

# We need a custom dataset class to apply different transforms
# to images (normalize) and masks (don't normalize)
class CustomDataset(Dataset):
    """
    Loads the downloaded 2D MRI dataset from Zenodo record 12735702.

    Expects ``<data_dir>/images/*.png`` and ``<data_dir>/masks/*.png``.
    Images and masks are paired by position after sorting both file lists,
    and separate transforms are applied to each.

    Args:
        data_dir: Dataset root folder.
        image_transform: Optional callable applied to each grayscale image.
        mask_transform: Optional callable applied to each grayscale mask.

    Raises:
        FileNotFoundError: If no PNG images or no PNG masks are found.
        ValueError: If the number of images and masks differ.
    """
    def __init__(self, data_dir, image_transform=None, mask_transform=None):
        image_dir = os.path.join(data_dir, "images")
        mask_dir = os.path.join(data_dir, "masks")

        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.png")))
        self.image_transform = image_transform
        self.mask_transform = mask_transform

        # Sanity check
        if not self.image_paths or not self.mask_paths:
            print(f"Error: No images or masks found in {data_dir}. Check directory structure.")
            print(f"Looking in: {image_dir} and {mask_dir}")
            raise FileNotFoundError(f"Dataset files not found in {image_dir} or {mask_dir}")

        # Explicit raise instead of assert: asserts are stripped under
        # `python -O`, which would let mismatched pairs through silently.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError("Number of images and masks do not match!")
        print(f"Found {len(self.image_paths)} image/mask pairs.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``."""
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            # convert() returns a new, fully loaded image, so the underlying
            # file handle can be closed right away by the context manager.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("L")  # Convert to grayscale
            with Image.open(mask_path) as raw_mask:
                mask = raw_mask.convert("L")    # Convert to grayscale
        except Exception as e:
            print(f"Error loading image/mask at index {idx}: {e}")
            print(f"Image path: {image_path}, Mask path: {mask_path}")
            # Return a dummy pair to avoid crashing the loader.
            # NOTE(review): the 128x128 size is hard-coded; it only batches
            # correctly if the transforms also produce 128x128 tensors.
            return torch.zeros(1, 128, 128), torch.zeros(1, 128, 128)

        # Apply transforms
        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        # Binarize mask (some masks might have anti-aliasing).
        # NOTE(review): this needs mask_transform to have produced a tensor;
        # with mask_transform=None the mask is still a PIL image here.
        mask = (mask > 0.5).float()

        return image, mask

# --- Training and Evaluation Logic ---

def train_model(model, dataloader, criterion, optimizer, device, num_epochs=3):
    """
    Run a simple training loop.

    Args:
        model: Module to train; moved to ``device`` and put in train mode.
        dataloader: Sized iterable yielding ``(images, masks)`` batches.
        criterion: Callable ``criterion(outputs, masks)`` returning the loss.
        optimizer: Optimizer built over ``model``'s parameters.
        device: Device that the model and each batch are moved to.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        The same ``model`` instance after training.
    """
    model.to(device)
    model.train()

    total_steps = len(dataloader)

    for epoch in range(num_epochs):
        epoch_start = time.time()
        running_loss = 0.0
        batches_seen = 0  # only batches that actually contributed a loss

        for i, (images, masks) in enumerate(dataloader):
            # Skip empty batches
            if images.shape[0] == 0:
                continue

            images = images.to(device)
            masks = masks.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_seen += 1

            if (i + 1) % 20 == 0 or i == total_steps - 1:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], Loss: {loss.item():.4f}")

        # Average over the batches that were processed. An empty dataloader
        # (or one where every batch was skipped) reports NaN instead of
        # raising ZeroDivisionError.
        epoch_loss = running_loss / batches_seen if batches_seen else float("nan")
        epoch_time = time.time() - epoch_start
        print(f"Epoch [{epoch+1}/{num_epochs}], Avg. Loss: {epoch_loss:.4f}, Time: {epoch_time:.2f}s")

    print("Finished Training.")
    return model

def evaluate_model(model, dataloader, device):
    """
    Compute the average per-batch Dice coefficient of ``model`` over
    ``dataloader``, with gradients disabled and the model in eval mode.
    """
    model.to(device)
    model.eval()

    batch_scores = []
    with torch.no_grad():
        for batch_images, batch_masks in dataloader:
            # Empty batches contribute nothing to the total.
            if batch_images.shape[0] == 0:
                continue

            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            predictions = model(batch_images)
            batch_scores.append(dice_coefficient(predictions, batch_masks))

    num_batches = len(dataloader)

    # An empty dataloader has no meaningful average.
    if num_batches == 0:
        print("Warning: Empty validation dataloader.")
        return 0.0

    return sum(batch_scores, 0.0) / num_batches

# --- Main Comparison Function ---

def _train_and_score(title, short_name, model, optimizer, criterion,
                     train_loader, val_loader, device, num_epochs):
    """Train one model, then return (validation Dice, training seconds)."""
    print(f"\n--- Training {title} ---")
    started = time.time()
    model = train_model(model, train_loader, criterion, optimizer, device, num_epochs=num_epochs)
    elapsed = time.time() - started

    print(f"Evaluating {short_name}...")
    return evaluate_model(model, val_loader, device), elapsed


def main():
    """
    Train a U-Net and a small ViT segmentation model on the same data split
    and print a table comparing validation Dice score and training time.

    Returns early (after printing a message) if the dataset cannot be
    downloaded or loaded.
    """
    # --- 1. Configuration ---
    # Hyperparameters
    NUM_EPOCHS = 3
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-4
    IMG_SIZE = 128

    # ViT specific parameters
    # Note: These are small to make the model runnable on most systems.
    # Real ViTs are much larger (e.g., embed_dim=768, depth=12)
    VIT_PATCH_SIZE = 16
    VIT_EMBED_DIM = 256  # Reduced from 768 for demo
    VIT_DEPTH = 4        # Reduced from 12 for demo
    VIT_HEADS = 8        # Reduced from 12 for demo
    VIT_MLP_DIM = 1024   # Reduced from 3072 for demo

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 2. Data ---
    print("Checking for dataset...")
    data_dir = download_and_extract_dataset()
    if data_dir is None:
        print("Failed to get dataset. Exiting.")
        return

    # Images: resize, convert to tensor, then normalize to the [-1, 1] range.
    data_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # Masks: resize and convert to tensor only. Nearest-neighbour resampling
    # keeps label values intact; the default bilinear mode would blend
    # foreground and background along mask boundaries.
    mask_transform = transforms.Compose([
        transforms.Resize(
            (IMG_SIZE, IMG_SIZE),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.ToTensor()
    ])

    # Create full dataset
    try:
        full_dataset = CustomDataset(data_dir, image_transform=data_transform, mask_transform=mask_transform)
    except FileNotFoundError as e:
        print(e)
        return

    if len(full_dataset) == 0:
        print("Dataset loaded 0 samples. Exiting.")
        return

    # Split into training and validation using a simple 80/20 split
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # Check if dataset is large enough
    if train_size == 0 or val_size == 0:
        print("Dataset is too small to split. Using all data for training.")
        train_dataset = full_dataset
        val_dataset = full_dataset  # Not ideal, but prevents crash
    else:
        train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

    print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples.")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # --- 3. Models ---
    print("Initializing models...")
    # U-Net
    model_unet = UNet(n_channels=1, n_classes=1)

    # Vision Transformer
    model_vit = ViTSeg(
        img_size=IMG_SIZE,
        patch_size=VIT_PATCH_SIZE,
        embed_dim=VIT_EMBED_DIM,
        depth=VIT_DEPTH,
        n_heads=VIT_HEADS,
        mlp_dim=VIT_MLP_DIM
    )

    # --- 4. Loss and Optimizers ---
    criterion = DiceLoss()
    optimizer_unet = optim.Adam(model_unet.parameters(), lr=LEARNING_RATE)
    optimizer_vit = optim.Adam(model_vit.parameters(), lr=LEARNING_RATE)

    # --- 5./6. Train and evaluate both models on the same loaders ---
    dice_unet, time_unet = _train_and_score(
        "U-Net", "U-Net", model_unet, optimizer_unet, criterion,
        train_loader, val_loader, device, NUM_EPOCHS,
    )
    dice_vit, time_vit = _train_and_score(
        "Vision Transformer (ViT-Seg)", "ViT-Seg", model_vit, optimizer_vit, criterion,
        train_loader, val_loader, device, NUM_EPOCHS,
    )

    # --- 7. Comparison ---
    print("\n--- Comparison Results ---")
    print(f"Trained on {len(train_dataset)} images for {NUM_EPOCHS} epochs.")
    print("-" * 30)
    print("| Model     | Avg. Val Dice Score | Training Time |")
    print("|-----------|---------------------|---------------|")
    # Pad every cell to its header's width so the columns line up whatever
    # the number of digits in the timings.
    for name, dice, seconds in (("U-Net", dice_unet, time_unet), ("ViT-Seg", dice_vit, time_vit)):
        time_text = f"{seconds:.2f}s"
        print(f"| {name:<9} | {dice:<19.4f} | {time_text:<13} |")
    print("-" * 30)
    print("\nNote: These results are from a short training run on small models.")
    print("For a real study, you would need more epochs, larger models,")
    print("and rigorous hyperparameter tuning.")


# Run the comparison only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()