# Breast Cancer Identification - Phase 1: Data Preparation

This notebook prepares datasets for multi-modal breast cancer identification:
- BreakHis (Histopathology images)
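- CBIS-DDSM (Mammography images) — only the folder scaffold is created in this notebook; loading and splitting are not implemented yet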

**Run this in Google Colab for free GPU access!**

## 1. Setup & Environment Configuration

In [None]:
# Install required packages
!pip install -q torch torchvision pytorch-lightning timm albumentations opencv-python scikit-image pandas numpy pyyaml tqdm optuna wandb captum shap streamlit fastapi uvicorn

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import cv2
import os
import yaml
from pathlib import Path
from typing import Tuple, List, Dict
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f"Device: {device}")
gpu_label = torch.cuda.get_device_name(0) if has_cuda else 'No GPU'
print(f"GPU: {gpu_label}")

# Turn on cuDNN benchmark mode and report total GPU memory
if has_cuda:
    torch.backends.cudnn.benchmark = True
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Memory: {total_bytes / 1e9:.1f} GB")

## 2. Mount Google Drive & Download Datasets

In [None]:
# Decide where data lives: Google Drive when running in Colab, a temp folder otherwise.
# The import doubles as the Colab check - it raises ImportError outside Colab.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
except ImportError:
    in_colab, data_root = False, '/tmp/breast_cancer_data'
    print("✓ Running locally (not in Colab)")
else:
    in_colab, data_root = True, '/content/drive/MyDrive/breast_cancer_data'
    print("✓ Google Drive mounted")

# Make sure the data root exists before later cells write into it
os.makedirs(data_root, exist_ok=True)
print(f"Data root: {data_root}")

In [None]:
# BreakHis comes from Kaggle and needs Kaggle API credentials to download:
#   !kaggle datasets download -d ihebski/histopathologic-breast-cancer-images
# Nothing is downloaded in this cell; it only lays out the folder skeleton for the demo.
breakhis_path = os.path.join(data_root, 'BreakHis')
cbis_ddsm_path = os.path.join(data_root, 'CBIS-DDSM')

for folder in (
    f'{breakhis_path}/benign',
    f'{breakhis_path}/malignant',
    f'{cbis_ddsm_path}/images',
):
    os.makedirs(folder, exist_ok=True)

print(f"✓ Dataset directories created:")
print(f"  - BreakHis: {breakhis_path}")
print(f"  - CBIS-DDSM: {cbis_ddsm_path}")

## 3. Data Inspection & Metadata Parsing

In [None]:
def scan_dataset(dataset_path: str, modality: str) -> pd.DataFrame:
    """
    Scan a dataset folder and build a per-image metadata table.

    Only ``modality == 'histopathology'`` produces rows. It expects ``benign/`` and
    ``malignant/`` folders directly under ``dataset_path`` and lists the image files
    sitting directly inside each one (no recursion). Any other modality, or a missing
    class folder, contributes no rows.

    Args:
        dataset_path: Root folder of the dataset.
        modality: Imaging modality to scan.

    Returns:
        DataFrame with columns ``image_path``, ``label`` (0 = benign, 1 = malignant),
        ``class``, ``modality`` and ``magnification``. The columns are present even
        when nothing is found, and rows are ordered by class, then file name.
    """
    columns = ['image_path', 'label', 'class', 'modality', 'magnification']
    image_exts = ('.png', '.jpg', '.jpeg')
    records = []

    if modality == 'histopathology':
        # BreakHis structure: benign/ and malignant/ folders
        for label, class_name in enumerate(['benign', 'malignant']):
            class_path = os.path.join(dataset_path, class_name)
            if not os.path.isdir(class_path):
                continue
            # os.listdir order is arbitrary; sorting keeps row order (and therefore
            # any seeded split built from this table) stable between runs.
            for img_file in sorted(os.listdir(class_path)):
                # Compare case-insensitively so '.PNG' / '.JPG' are not skipped
                if not img_file.lower().endswith(image_exts):
                    continue
                records.append({
                    'image_path': os.path.join(class_path, img_file),
                    'label': label,
                    'class': class_name,
                    'modality': 'histopathology',
                    'magnification': 'unknown'  # Extract from filename if available
                })

    # Passing columns keeps the schema intact for an empty scan
    return pd.DataFrame(records, columns=columns)

# Build the BreakHis metadata table and preview its first rows
print("Scanning datasets...")
breakhis_df = scan_dataset(dataset_path=breakhis_path, modality='histopathology')
print(f"BreakHis: {len(breakhis_df)} images found")
print(breakhis_df.head())

In [None]:
# Bar chart of benign vs. malignant counts (left panel; the right panel is left blank)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))

if len(breakhis_df) > 0:
    class_counts = breakhis_df['class'].value_counts()
    breakhis_ax = axes[0]
    class_counts.plot(kind='bar', ax=breakhis_ax)
    breakhis_ax.set_title('BreakHis Class Distribution')
    breakhis_ax.set_ylabel('Count')

print("\n✓ Class distribution analyzed")
plt.tight_layout()
plt.show()

## 4. Albumentations Augmentation Pipelines

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

class DataAugmenter:
    """Create albumentations pipelines for training and for validation/test.

    Both pipelines end with the same ``Normalize`` + ``ToTensorV2`` steps; only the
    training pipeline applies random augmentations.
    """

    # Channel mean/std shared by both pipelines, kept in one place so they cannot drift
    NORM_MEAN = (0.485, 0.456, 0.406)
    NORM_STD = (0.229, 0.224, 0.225)

    @staticmethod
    def _random_resized_crop(image_size: int):
        """Build a square ``RandomResizedCrop`` regardless of albumentations version."""
        crop_scale = (0.8, 1.0)
        try:
            # Newer albumentations releases take the output shape as size=(height, width)
            return A.RandomResizedCrop(size=(image_size, image_size), scale=crop_scale)
        except TypeError:
            # Older releases do not know `size` and expect height=/width= instead
            return A.RandomResizedCrop(height=image_size, width=image_size, scale=crop_scale)

    @staticmethod
    def get_train_transforms(image_size: int = 224):
        """Return the training pipeline: random crop, flips, rotation, noise, blur.

        Args:
            image_size: Side length, in pixels, of the square output image.

        Returns:
            An ``A.Compose`` whose ``'image'`` output is a normalized tensor.
        """
        return A.Compose([
            DataAugmenter._random_resized_crop(image_size),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=30, p=0.7),
            A.GaussNoise(p=0.2),
            A.RandomBrightnessContrast(p=0.5),
            A.ElasticTransform(p=0.3),
            A.GaussianBlur(blur_limit=3, p=0.2),
            A.Normalize(mean=DataAugmenter.NORM_MEAN, std=DataAugmenter.NORM_STD),
            ToTensorV2()
        ], is_check_shapes=False)

    @staticmethod
    def get_val_transforms(image_size: int = 224):
        """Return the deterministic pipeline for validation/test: resize + normalize.

        Args:
            image_size: Side length, in pixels, of the square output image.

        Returns:
            An ``A.Compose`` whose ``'image'`` output is a normalized tensor.
        """
        return A.Compose([
            A.Resize(image_size, image_size),
            A.Normalize(mean=DataAugmenter.NORM_MEAN, std=DataAugmenter.NORM_STD),
            ToTensorV2()
        ])

# Build both pipelines once; later cells reuse them
augmenter = DataAugmenter()
train_transform = augmenter.get_train_transforms()
val_transform = augmenter.get_val_transforms()

print("✓ Augmentation pipelines created")
n_train_steps = len(train_transform.transforms)
n_val_steps = len(val_transform.transforms)
print(f"Train transforms: {n_train_steps} augmentations")
print(f"Val transforms: {n_val_steps} augmentations")

In [None]:
# Preview the training pipeline on a random-noise stand-in image
dummy_image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

for panel_no, ax in enumerate(axes, start=1):
    # Each call draws fresh random parameters, so every panel differs
    tensor_img = train_transform(image=dummy_image)['image']
    img = tensor_img.permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo + 1e-8)  # Rescale to [0, 1] for display
    ax.imshow(img)
    ax.set_title(f'Augmentation {panel_no}')
    ax.axis('off')

plt.suptitle('Sample Augmentations (10 different transforms)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("✓ Augmentation examples displayed")

## 5. Balanced Train/Val/Test Split

In [None]:
from sklearn.model_selection import train_test_split

def create_balanced_splits(df: pd.DataFrame, 
                          train_ratio: float = 0.7,
                          val_ratio: float = 0.15,
                          random_state: int = 42) -> Dict[str, pd.DataFrame]:
    """
    Create stratified train/val/test splits that keep the class distribution.

    The test share is whatever remains: ``1 - train_ratio - val_ratio``.
    Defaults give Train: 70% | Val: 15% | Test: 15%.

    Args:
        df: Table with a ``label`` column used for stratification.
        train_ratio: Fraction of rows for training, strictly between 0 and 1.
        val_ratio: Fraction of rows for validation; must leave room for a test set.
        random_state: Seed passed to both underlying splits.

    Returns:
        Dict with keys ``'train'``, ``'val'`` and ``'test'`` mapping to DataFrames.

    Raises:
        ValueError: If the ratios do not leave a non-empty train, val and test share.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")

    # Round away float noise: 1 - 0.7 evaluates to 0.30000000000000004, and
    # scikit-learn rounds test_size * n_samples up, which would push one extra
    # row out of the training set.
    holdout_ratio = round(1.0 - train_ratio, 10)
    test_ratio = round(holdout_ratio - val_ratio, 10)
    if val_ratio <= 0.0 or test_ratio <= 0.0:
        raise ValueError(
            f"val_ratio must be in (0, {holdout_ratio}) so that a test set remains, "
            f"got {val_ratio}"
        )

    # First split: train vs holdout (val + test)
    train_df, holdout_df = train_test_split(
        df,
        test_size=holdout_ratio,
        stratify=df['label'],
        random_state=random_state
    )

    # Second split: val vs test, sized from the requested ratios rather than a
    # fixed 50/50 so that val_ratio is actually honored
    test_share = round(test_ratio / holdout_ratio, 10)
    val_df, test_df = train_test_split(
        holdout_df,
        test_size=test_share,
        stratify=holdout_df['label'],
        random_state=random_state
    )

    return {'train': train_df, 'val': val_df, 'test': test_df}

# Split BreakHis and report the per-split class counts
if len(breakhis_df) == 0:
    print("⚠ No images found in BreakHis folder")
else:
    splits = create_balanced_splits(breakhis_df)

    print("Balanced Splits (BreakHis):")
    for split_name, split_df in splits.items():
        split_labels = split_df['label']
        benign = split_labels.eq(0).sum()
        malignant = split_labels.eq(1).sum()
        print(f"  {split_name.upper():5s}: {len(split_df):5d} images | Benign: {benign:4d} | Malignant: {malignant:4d}")

In [None]:
# Save splits to CSV for reproducibility
splits_dir = os.path.join(data_root, 'splits')
os.makedirs(splits_dir, exist_ok=True)

# `splits` is only defined when images were found, so every statement that reads it
# stays under this guard; reading it unguarded raises NameError on an empty dataset.
if len(breakhis_df) > 0:
    for split_name, split_df in splits.items():
        csv_path = os.path.join(splits_dir, f'breakhis_{split_name}.csv')
        split_df.to_csv(csv_path, index=False)
        print(f"✓ Saved {split_name}: {csv_path}")

    # Visualize split distribution
    split_stats = [
        {
            'Split': split_name,
            'Benign': int((split_df['label'] == 0).sum()),
            'Malignant': int((split_df['label'] == 1).sum()),
        }
        for split_name, split_df in splits.items()
    ]
    stats_df = pd.DataFrame(split_stats)

    fig, ax = plt.subplots(figsize=(10, 5))
    stats_df.set_index('Split')[['Benign', 'Malignant']].plot(kind='bar', ax=ax, rot=0)
    ax.set_title('Balanced Train/Val/Test Splits')
    ax.set_ylabel('Number of Images')
    plt.tight_layout()
    plt.show()

    print("✓ Splits visualization complete")
else:
    print("⚠ No splits to save or visualize (no BreakHis images found)")

## 6. PyTorch DataLoader Setup

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image as PILImage

class MedicalImageDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs from a metadata DataFrame.

    Each row must provide ``image_path`` and ``label``. When the file is missing or
    cannot be decoded, a random-noise 256x256 RGB image is returned instead of
    raising, so the demo runs without real data.
    """

    def __init__(self, dataframe: pd.DataFrame, transform=None):
        self.df = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _placeholder_image():
        """Random uint8 image of shape (256, 256, 3) used when no real image loads."""
        return np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]  # positional lookup, independent of index labels
        path = record['image_path']

        if not os.path.exists(path):
            # Missing file: fall back to the demo placeholder
            pixels = self._placeholder_image()
        else:
            loaded = cv2.imread(path)
            # imread returns None when the file cannot be decoded
            pixels = self._placeholder_image() if loaded is None else loaded
            pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)

        # Transforms follow the albumentations call convention: keyword in, dict out
        if self.transform:
            pixels = self.transform(image=pixels)['image']

        return pixels, torch.tensor(record['label'], dtype=torch.long)

# Wrap each split in a Dataset, then in a DataLoader
if len(breakhis_df) > 0:
    # Only the training split gets random augmentation; val/test share the fixed pipeline
    train_dataset = MedicalImageDataset(splits['train'], transform=train_transform)
    val_dataset = MedicalImageDataset(splits['val'], transform=val_transform)
    test_dataset = MedicalImageDataset(splits['test'], transform=val_transform)

    loader_kwargs = {'batch_size': 32, 'num_workers': 2}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    print(f"✓ DataLoaders created:")
    print(f"  Train: {len(train_loader)} batches")
    print(f"  Val: {len(val_loader)} batches")
    print(f"  Test: {len(test_loader)} batches")

In [None]:
# Show up to eight images from one training batch with their labels
if len(breakhis_df) > 0:
    batch_images, batch_labels = next(iter(train_loader))

    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    axes = axes.flatten()

    for ax, img, label in zip(axes, batch_images[:8], batch_labels[:8]):
        img = img.permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
        lo, hi = img.min(), img.max()
        img = (img - lo) / (hi - lo + 1e-8)  # Rescale to [0, 1] for display
        ax.imshow(img)
        if label == 1:
            class_name = 'Malignant'
        else:
            class_name = 'Benign'
        ax.set_title(f'{class_name} (L={label})')
        ax.axis('off')

    plt.suptitle('Sample Batch from Training Set', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

    print(f"Batch shape: {batch_images.shape}")
    print(f"Labels: {batch_labels.tolist()}")

## 7. Configuration Summary

In [None]:
# Summarize the data, model and training choices in one JSON-serializable dict
dataset_notes = {
    'BreakHis': f'{len(breakhis_df)} images (histopathology)',
    'CBIS-DDSM': 'To be added (mammography)',
}

augmentation_notes = [
    'RandomResizedCrop(224)',
    'HorizontalFlip(0.5)',
    'VerticalFlip(0.5)',
    'Rotate(30°)',
    'GaussNoise, BrightnessContrast',
    'ElasticTransform, GaussianBlur',
]

config_summary = {
    'data': {
        'datasets': dataset_notes,
        'augmentation': augmentation_notes,
        'splits': dict(train=0.7, val=0.15, test=0.15),
    },
    'models': dict(
        histopathology='EfficientNet-B0 (5.3M params)',
        mammography='MobileNet-V3 (5.4M params)',
    ),
    'training': dict(
        batch_size=32,
        epochs=50,
        mixed_precision=True,
        optimizer='AdamW',
        scheduler='OneCycleLR',
    ),
}

import json
print(json.dumps(config_summary, indent=2))

## 8. Save Metadata to Google Drive

In [None]:
# Persist the metadata table and the config summary under <data_root>/metadata
metadata_dir = os.path.join(data_root, 'metadata')
os.makedirs(metadata_dir, exist_ok=True)

# The metadata CSV is only written when the scan found images
if len(breakhis_df) > 0:
    metadata_path = os.path.join(metadata_dir, 'breakhis_metadata.csv')
    breakhis_df.to_csv(metadata_path, index=False)
    print(f"✓ Metadata saved to: {metadata_path}")

# The config is written unconditionally
config_path = os.path.join(metadata_dir, 'data_config.json')
with open(config_path, 'w') as config_file:
    config_file.write(json.dumps(config_summary, indent=2))
print(f"✓ Config saved to: {config_path}")

print("\n✓ Phase 1 Data Preparation Complete!")
print(f"\nDatasets ready at: {data_root}")