# Medical-Optimized Breast Tissue Segmentation and BYOL Augmentation Demo

This notebook demonstrates the medically optimized breast-tissue segmentation and BYOL augmentations implemented in `train_byol_mammo.py`. It visualizes the extracted tiles, the high-frequency energy detection used to catch micro-calcifications, and the medically appropriate augmentation transforms.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import cv2
from scipy import ndimage
from skimage import morphology, measure, filters
import random
from tqdm import tqdm
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader

# Import BYOL transforms
from lightly.transforms.byol_transform import BYOLTransform

# Seed every RNG this notebook draws from (Python `random` for image/tile
# sampling, NumPy, and torch for the tensor augmentations) so reruns match.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Configuration - Same as train_byol_mammo.py
DATA_DIR = Path("./split_images/training")
TILE_SIZE = 256
TILE_STRIDE = 128
NUM_SAMPLES = 10
MIN_BREAST_RATIO = 0.1  # Lowered for micro-calcifications (updated from 0.3)
MIN_FREQ_ENERGY = 0.01  # New: minimum high-frequency energy for calcification detection

In [None]:
def compute_frequency_energy(image_patch: np.ndarray) -> float:
    """
    Score how much high-frequency structure a patch contains.

    Uses a Laplacian-of-Gaussian response (3x3 Gaussian, sigma 1.0, followed
    by a 3x3 Laplacian) as a detector for micro-calcifications and other fine
    structures. The score is ``var(LoG response) / (mean(gray) + 1e-8)``, i.e.
    the response variance normalized by the patch's mean intensity.
    SAME FUNCTION AS IN train_byol_mammo.py

    NOTE(review): because of the mean-intensity normalization, near-black
    patches with even slight noise can produce large scores; keep that in
    mind when choosing MIN_FREQ_ENERGY.

    Args:
        image_patch: 2-D grayscale patch, or a 3-D patch converted with
            ``cv2.COLOR_RGB2GRAY``.

    Returns:
        The normalized high-frequency energy as a Python float.
    """
    # Reduce colour input to one channel; 2-D input is used as a copy.
    gray = (
        cv2.cvtColor(image_patch, cv2.COLOR_RGB2GRAY)
        if image_patch.ndim == 3
        else image_patch.copy()
    )

    smoothed = cv2.GaussianBlur(gray.astype(np.float32), (3, 3), 1.0)
    log_response = cv2.Laplacian(smoothed, cv2.CV_32F, ksize=3)

    spread = np.var(log_response)
    brightness = np.mean(gray)
    return float(spread / (brightness + 1e-8))


def segment_breast_tissue(image_array: np.ndarray) -> np.ndarray:
    """
    Segment breast tissue from mammogram using morphological operations.
    SAME FUNCTION AS IN train_byol_mammo.py (identical output; the
    largest-component search is done with one ``np.bincount`` pass instead
    of one full-image comparison per label).

    Pipeline: 3x3 Gaussian blur -> Otsu threshold -> 3x3 elliptical opening
    -> hole filling -> keep the largest connected component -> 7x7
    elliptical closing.

    Args:
        image_array: 2-D grayscale image, or a 3-D image converted with
            ``cv2.COLOR_RGB2GRAY``. NOTE(review): OpenCV's Otsu mode only
            accepts certain integer depths (8-bit always works) - confirm
            behaviour for 16-bit mammograms.

    Returns:
        Boolean mask with the input's height/width, True on the retained
        breast region.
    """
    if len(image_array.shape) == 3:
        gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    else:
        gray = image_array.copy()

    # Gentle blur to preserve medical details
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)

    # Otsu thresholding for breast tissue segmentation
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Minimal morphological operations to preserve detail
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

    # Fill holes
    filled = ndimage.binary_fill_holes(opened).astype(np.uint8) * 255

    # Keep largest connected component (main breast tissue)
    num_labels, labels = cv2.connectedComponents(filled)
    if num_labels > 1:
        # Pixel count of every label in a single O(N) pass (label 0 is the
        # background). np.argmax returns the first maximum, so ties resolve
        # to the lowest label exactly as the former per-label loop did.
        component_sizes = np.bincount(labels.ravel(), minlength=num_labels)
        largest_label = 1 + int(np.argmax(component_sizes[1:]))
        mask = (labels == largest_label).astype(np.uint8) * 255
    else:
        mask = filled

    # Gentle closing to smooth boundaries
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask > 0

In [None]:
def extract_breast_tiles_with_freq_energy(image_array, breast_mask, tile_size, stride, min_breast_ratio=0.1, min_freq_energy=0.01):
    """
    Extract tiles containing sufficient breast tissue OR high-frequency content (micro-calcifications).
    SAME LOGIC AS IN train_byol_mammo.py

    Tile origins lie on a grid with step ``stride``. When that grid would
    leave the bottom/right border uncovered, one extra tile flush with the
    border is added.

    Args:
        image_array: Image to tile (2-D, or H x W x C).
        breast_mask: Boolean mask aligned with ``image_array``.
        tile_size: Side length of the square tiles, in pixels.
        stride: Step between consecutive tile origins, in pixels.
        min_breast_ratio: Keep a tile whose mask coverage reaches this value...
        min_freq_energy: ...or whose ``compute_frequency_energy`` score reaches this.

    Returns:
        List of ``(tile_image, (x, y), breast_ratio, freq_energy)`` tuples in
        row-major order. ``tile_image`` is a slice (view) of ``image_array``.
        Coverage is always divided by ``tile_size ** 2``, even for a clipped
        tile from an image smaller than one tile.
    """
    def _grid_starts(extent):
        # Regular origins, plus a final border-aligned origin if needed.
        starts = list(range(0, max(1, extent - tile_size + 1), stride))
        if starts[-1] + tile_size < extent:
            starts.append(extent - tile_size)
        return starts

    img_h, img_w = image_array.shape[:2]
    row_starts = _grid_starts(img_h)
    col_starts = _grid_starts(img_w)
    tile_area = tile_size * tile_size

    kept = []
    for top in row_starts:
        for left in col_starts:
            window = (slice(top, top + tile_size), slice(left, left + tile_size))
            coverage = np.sum(breast_mask[window]) / tile_area
            patch = image_array[window]
            energy = compute_frequency_energy(patch)
            # Either criterion alone is enough to keep the tile.
            if coverage >= min_breast_ratio or energy >= min_freq_energy:
                kept.append((patch, (left, top), coverage, energy))
    return kept


def create_medical_transforms(input_size: int):
    """
    Create BYOL transforms optimized for medical imaging.
    SAME FUNCTION AS IN train_byol_mammo.py

    Both views share one skeleton:
    - ToTensor;
    - horizontal flip (p=0.5);
    - +/-7 degree rotation with zero fill;
    - brightness/contrast jitter with no saturation or hue change;
    - ``T.Resize(input_size)``;
    - per-channel normalization with mean 0.5 / std 0.5.

    View 2 uses stronger jitter (0.15 vs 0.1). Before resizing, it also
    applies a small random affine (5% translation, 0.95-1.05 scale) and a
    mild 3x3 Gaussian blur.

    Args:
        input_size: Size handed to ``T.Resize`` (smaller edge when an int).

    Returns:
        A lightly ``BYOLTransform`` wrapping the two view pipelines.
    """
    def _view_pipeline(jitter, extra_ops):
        return T.Compose([
            T.ToTensor(),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=7, fill=0),  # small rotations to preserve anatomy
            T.ColorJitter(brightness=jitter, contrast=jitter, saturation=0, hue=0),  # intensity only, no colour
            *extra_ops,
            T.Resize(input_size, antialias=True),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # grayscale replicated to 3 channels
        ])

    lighter_view = _view_pipeline(0.1, [])
    stronger_view = _view_pipeline(0.15, [
        T.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=0),
        T.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),
    ])

    return BYOLTransform(
        view_1_transform=lighter_view,
        view_2_transform=stronger_view,
    )

In [None]:
# Get sample images from the dataset.
# Path.glob yields files in filesystem-dependent order; sorting makes the
# seeded random.sample below pick the same images on every machine.
image_paths = sorted(DATA_DIR.glob("*.png"))
print(f"Found {len(image_paths)} images in {DATA_DIR}")

# Always bind `sample_paths` so later cells can rely on it existing.
sample_paths = []
if len(image_paths) == 0:
    print("❌ No images found! Make sure the path is correct and contains .png files")
else:
    # Select random sample
    sample_paths = random.sample(image_paths, min(NUM_SAMPLES, len(image_paths)))
    print(f"✅ Processing {len(sample_paths)} sample images to demonstrate medical BYOL pipeline")

In [None]:
# Run segmentation + tile selection on every sampled image and collect one
# result record per image.
def _process_sample_image(index, img_path):
    """Segment one image, select its tiles, print stats, return its record."""
    print(f"\nProcessing image {index}: {img_path.name}")

    with Image.open(img_path) as img:
        img_array = np.array(img)
    print(f"  Image shape: {img_array.shape}")

    breast_mask = segment_breast_tissue(img_array)
    mask_pixels = breast_mask.shape[0] * breast_mask.shape[1]
    breast_percentage = (np.sum(breast_mask) / mask_pixels) * 100
    print(f"  Breast tissue: {breast_percentage:.1f}% of image")

    # Tiles are kept for breast coverage OR frequency energy
    tiles = extract_breast_tiles_with_freq_energy(
        img_array, breast_mask, TILE_SIZE, TILE_STRIDE,
        MIN_BREAST_RATIO, MIN_FREQ_ENERGY
    )

    # Split by the criterion that selected each tile; frequency tiles are
    # only those that did NOT already qualify on breast coverage.
    breast_tiles, freq_tiles = [], []
    for tile in tiles:
        ratio, energy = tile[2], tile[3]
        if ratio >= MIN_BREAST_RATIO:
            breast_tiles.append(tile)
        if ratio < MIN_BREAST_RATIO and energy >= MIN_FREQ_ENERGY:
            freq_tiles.append(tile)

    print(f"  Generated {len(tiles)} total tiles:")
    print(f"    - {len(breast_tiles)} tiles by breast tissue ratio (≥{MIN_BREAST_RATIO:.1%})")
    print(f"    - {len(freq_tiles)} tiles by frequency energy (≥{MIN_FREQ_ENERGY:.3f})")

    return {
        'path': img_path,
        'image': img_array,
        'mask': breast_mask,
        'tiles': tiles,
        'breast_tiles': breast_tiles,
        'freq_tiles': freq_tiles,
        'breast_percentage': breast_percentage,
    }


results = []

if 'sample_paths' in locals() and len(sample_paths) > 0:
    progress = tqdm(sample_paths, desc="Processing images with medical pipeline")
    for number, path in enumerate(progress, start=1):
        results.append(_process_sample_image(number, path))

    print(f"\n✅ Completed processing {len(results)} images with medical-optimized pipeline")
else:
    print("❌ No sample images to process")

In [None]:
def _to_display_rgb(image):
    """Return a fresh H x W x 3 uint8, C-contiguous copy of *image* for cv2 drawing."""
    img = np.asarray(image)
    if img.dtype != np.uint8:
        # Rescale non-8-bit images (e.g. 16-bit PNGs) into 0-255 so that the
        # 8-bit rectangle colours are actually visible.
        peak = float(img.max()) if img.size else 0.0
        if peak > 0:
            img = np.clip(img.astype(np.float64) / peak * 255.0, 0, 255)
        else:
            img = np.zeros(img.shape)
        img = img.astype(np.uint8)
    if img.ndim == 2:
        img = np.stack([img, img, img], axis=-1)
    elif img.shape[2] == 4:
        img = img[..., :3]  # drop alpha so drawn colours are not transparent
    # np.array copies, so drawing never mutates the caller's image.
    return np.array(img, dtype=np.uint8, order='C')


def display_medical_segmentation_pipeline(results, max_images=5):
    """
    Display the medical-optimized segmentation pipeline results.

    Draws one row per image, with at most ``max_images`` rows and never more
    than ``len(results)``. Each row has four panels:
    1. the original image;
    2. the breast mask;
    3. a per-pixel frequency-energy heatmap of the selected tiles;
    4. the selected-tile overlay (green = kept for breast coverage, red =
       kept only for frequency energy).

    Args:
        results: Records from the processing cell, with keys ``'path'``,
            ``'image'``, ``'mask'``, ``'tiles'``, ``'breast_tiles'``,
            ``'freq_tiles'`` and ``'breast_percentage'``.
        max_images: Upper bound on the number of rows to draw.
    """
    if not results:
        print("❌ No results to display")
        return

    n_rows = min(max_images, len(results))
    if n_rows <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(20, 5 * n_rows), squeeze=False)

    for row, result in zip(axes, results):
        image = result['image']

        # Original image
        row[0].imshow(image, cmap='gray')
        row[0].set_title(f"Original\n{result['path'].name}")
        row[0].axis('off')

        # Breast mask (updated segmentation)
        row[1].imshow(result['mask'], cmap='gray')
        row[1].set_title(f"Breast Mask\n{result['breast_percentage']:.1f}% breast tissue")
        row[1].axis('off')

        # Frequency energy heatmap: each pixel holds the highest energy of
        # any selected tile covering it. np.maximum works per pixel, so a
        # tile no longer floods its whole area with a neighbour's maximum.
        if result['tiles']:
            freq_map = np.zeros(image.shape[:2])
            for _, (x, y), _, freq_energy in result['tiles']:
                region = freq_map[y:y + TILE_SIZE, x:x + TILE_SIZE]
                np.maximum(region, freq_energy, out=region)

            row[2].imshow(freq_map, cmap='hot', alpha=0.7)
            row[2].imshow(image, cmap='gray', alpha=0.3)
            row[2].set_title("Frequency Energy\n(Micro-calcification detection)")
        row[2].axis('off')

        # Tile overlay: green for breast tissue tiles, red for frequency tiles.
        # cv2.rectangle corners are inclusive, hence the -1.
        overlay = _to_display_rgb(image)
        for _, (x, y), _, _ in result['breast_tiles']:
            cv2.rectangle(overlay, (x, y), (x + TILE_SIZE - 1, y + TILE_SIZE - 1), (0, 255, 0), 2)
        for _, (x, y), _, _ in result['freq_tiles']:
            cv2.rectangle(overlay, (x, y), (x + TILE_SIZE - 1, y + TILE_SIZE - 1), (255, 0, 0), 2)

        row[3].imshow(overlay)
        row[3].set_title(f"Selected Tiles\n🟢 Breast ({len(result['breast_tiles'])}) 🔴 Freq ({len(result['freq_tiles'])})")
        row[3].axis('off')

    plt.tight_layout()
    plt.show()

# Render the pipeline overview for up to the first five processed images
if results:
    _preview_count = min(5, len(results))
    display_medical_segmentation_pipeline(results, max_images=_preview_count)

In [None]:
def _tile_to_uint8(tile_img):
    """Convert a tile to uint8 for PIL without truncation or wrap-around."""
    tile = np.asarray(tile_img)
    if tile.dtype == np.uint8 or tile.size == 0:
        return tile.astype(np.uint8)
    if np.issubdtype(tile.dtype, np.floating) and tile.max() <= 1.0:
        # Float tiles normalised to [0, 1]: scale into the 8-bit range.
        return (np.clip(tile, 0.0, 1.0) * 255).astype(np.uint8)
    if np.issubdtype(tile.dtype, np.integer) and tile.min() >= 0 and tile.max() <= 255:
        # Integer tiles that already fit in 8 bits convert losslessly.
        return tile.astype(np.uint8)
    # Wider-range tiles (e.g. 16-bit PNGs): rescale by the tile maximum
    # instead of astype(np.uint8), which would wrap values modulo 256.
    peak = float(tile.max())
    if peak <= 0:
        return np.zeros(tile.shape, dtype=np.uint8)
    return np.clip(tile.astype(np.float64) / peak * 255.0, 0, 255).astype(np.uint8)


def display_byol_augmentations(results, num_augmentations=6):
    """
    Display the medical-optimized BYOL augmentations on sample tiles.

    Picks the first result with more than three tiles. From it, samples up
    to two breast-coverage tiles and up to two frequency-energy tiles; if
    both lists are empty, it samples up to three arbitrary tiles. At most
    three of the sampled tiles are shown.

    For each tile, one figure shows the transform input followed by
    ``num_augmentations`` augmented views. The first
    ``num_augmentations // 2`` views come from View 1 and the rest from
    View 2. Views are drawn on a 4-column grid.

    Args:
        results: Records from the processing cell (needs ``'tiles'``,
            ``'breast_tiles'`` and ``'freq_tiles'``).
        num_augmentations: Number of augmented views to draw per tile.
    """
    if not results:
        print("❌ No results to display")
        return

    # Create the medical transforms
    transform = create_medical_transforms(TILE_SIZE)

    # Select a result with tiles
    result = next((r for r in results if len(r['tiles']) > 3), None)
    if result is None:
        print("❌ No results with sufficient tiles found")
        return

    # Select a few interesting tiles (mix of breast tissue and frequency energy tiles)
    sample_tiles = []
    if len(result['breast_tiles']) > 0:
        sample_tiles.extend(random.sample(result['breast_tiles'], min(2, len(result['breast_tiles']))))
    if len(result['freq_tiles']) > 0:
        sample_tiles.extend(random.sample(result['freq_tiles'], min(2, len(result['freq_tiles']))))
    if not sample_tiles:
        sample_tiles = random.sample(result['tiles'], min(3, len(result['tiles'])))

    n_augs = max(0, num_augmentations)
    n_cols = 4
    n_cells = 1 + n_augs  # first cell is the un-augmented input
    n_rows = -(-n_cells // n_cols)  # ceiling division
    n_view1 = n_augs // 2

    for i, (tile_img, _coords, breast_ratio, freq_energy) in enumerate(sample_tiles[:3]):
        # Convert to grayscale then RGB (same as in training pipeline)
        pil_tile = Image.fromarray(_tile_to_uint8(tile_img))
        if pil_tile.mode != 'L':
            pil_tile = pil_tile.convert('L')
        pil_tile = pil_tile.convert('RGB')  # Replicate grayscale channel

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows), squeeze=False)
        flat_axes = axes.ravel()
        fig.suptitle(f"BYOL Medical Augmentations - Tile {i+1} (Breast: {breast_ratio:.1%}, Freq: {freq_energy:.3f})", fontsize=14)

        # Show exactly what the transforms receive, on the same absolute
        # 0-255 scale as the augmented views (no per-image autoscaling).
        flat_axes[0].imshow(np.asarray(pil_tile))
        flat_axes[0].set_title("Original Tile")
        flat_axes[0].axis('off')

        for j in range(1, n_cells):
            view1, view2 = transform(pil_tile)
            if j <= n_view1:
                display_tensor, title = view1, f"View 1 Aug #{j}"
            else:
                display_tensor, title = view2, f"View 2 Aug #{j - n_view1}"

            # Undo Normalize(mean=0.5, std=0.5) and clamp for display
            display_img = torch.clamp(display_tensor * 0.5 + 0.5, 0, 1)

            ax = flat_axes[j]
            ax.imshow(display_img.permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis('off')

        # Hide grid cells left over when n_cells is not a multiple of n_cols
        for ax in flat_axes[n_cells:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

# Explain what each BYOL view does, then render the augmentation grids
if results:
    _intro_lines = (
        "🎯 Demonstrating Medical-Optimized BYOL Augmentations",
        "View 1: Lighter augmentations (horizontal flip, ±7° rotation, mild brightness/contrast)",
        "View 2: Stronger augmentations (+ translation, scaling, mild blur)",
        "Note: No strong color jitter/solarization to preserve medical details\n",
    )
    for _line in _intro_lines:
        print(_line)

    display_byol_augmentations(results)

In [None]:
# Comprehensive summary of the medical-optimized pipeline
if results:
    total_tiles = sum(len(result['tiles']) for result in results)
    total_breast_tiles = sum(len(result['breast_tiles']) for result in results)
    total_freq_tiles = sum(len(result['freq_tiles']) for result in results)
    avg_tiles_per_image = total_tiles / len(results)  # `results` is non-empty in this branch
    avg_breast_percentage = np.mean([result['breast_percentage'] for result in results])

    # Shares of the selected tiles; guarded because images may yield zero tiles.
    breast_tile_pct = total_breast_tiles / total_tiles * 100 if total_tiles else 0.0
    freq_tile_pct = total_freq_tiles / total_tiles * 100 if total_tiles else 0.0
    # Neighbouring tiles share (size - stride) pixels, i.e. an overlap of
    # 1 - stride/size; stride/size alone only matches it when stride == size / 2.
    overlap_pct = max(0.0, 1 - TILE_STRIDE / TILE_SIZE) * 100

    print("🏥 === MEDICAL-OPTIMIZED BYOL PIPELINE SUMMARY ===")
    print("📊 Dataset Statistics:")
    print(f"   • Total images processed: {len(results)}")
    print(f"   • Total tiles generated: {total_tiles:,}")
    print(f"   • Average tiles per image: {avg_tiles_per_image:.1f}")
    print(f"   • Average breast tissue percentage: {avg_breast_percentage:.1f}%")

    print("\n🎯 Tile Selection Strategy:")
    print(f"   • Breast tissue tiles (≥{MIN_BREAST_RATIO:.1%} tissue): {total_breast_tiles:,} ({breast_tile_pct:.1f}%)")
    print(f"   • Frequency energy tiles (≥{MIN_FREQ_ENERGY:.3f} energy): {total_freq_tiles:,} ({freq_tile_pct:.1f}%)")
    print(f"   • Tile size: {TILE_SIZE}×{TILE_SIZE} pixels")
    print(f"   • Tile stride: {TILE_STRIDE} pixels ({overlap_pct:.0f}% overlap)")

    print("\n🔬 Medical Improvements vs Original:")
    print(f"   ✅ Lowered breast ratio threshold: 0.3 → {MIN_BREAST_RATIO} (captures peripheral regions)")
    print("   ✅ Added frequency energy detection: micro-calcification sensitivity")
    print("   ✅ Gentle segmentation: preserves medical details")
    print("   ✅ Grayscale-appropriate preprocessing: L→RGB replication")

    print("\n🎛️ BYOL Augmentation Optimizations:")
    print("   ✅ Medical-safe rotations: ±7° (preserves anatomy)")
    print("   ✅ Mild brightness/contrast: no color distortion")
    print("   ✅ Light blur: preserves calcification details")
    print("   ✅ No solarization/strong color jitter: medical data integrity")

    # NOTE(review): the lines below describe train_byol_mammo.py, not anything
    # executed in this notebook - keep them in sync with the training script.
    print("\n⚡ A100 Performance Optimizations:")
    print("   ✅ Mixed precision training: autocast + GradScaler")
    print("   ✅ Per-step momentum updates: better convergence")
    print("   ✅ Optimized hyperparameters: LR=3e-4, WD=1e-4 (batch=8)")
    print("   ✅ Multi-label classification ready: [mass, calcification]")

    # Distribution visualization over every selected tile
    breast_ratios = [tile[2] for result in results for tile in result['tiles']]
    freq_energies = [tile[3] for result in results for tile in result['tiles']]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Breast ratio distribution
    ax1.hist(breast_ratios, bins=20, alpha=0.7, edgecolor='black', color='green')
    ax1.axvline(MIN_BREAST_RATIO, color='red', linestyle='--', label=f'Threshold: {MIN_BREAST_RATIO:.1%}')
    ax1.set_xlabel('Breast Tissue Ratio')
    ax1.set_ylabel('Number of Tiles')
    ax1.set_title('Distribution of Breast Tissue Ratios')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Frequency energy distribution
    ax2.hist(freq_energies, bins=20, alpha=0.7, edgecolor='black', color='orange')
    ax2.axvline(MIN_FREQ_ENERGY, color='red', linestyle='--', label=f'Threshold: {MIN_FREQ_ENERGY:.3f}')
    ax2.set_xlabel('Frequency Energy (LoG variance)')
    ax2.set_ylabel('Number of Tiles')
    ax2.set_title('Distribution of Frequency Energy (Calcification Detection)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("\n🚀 Ready for A100 training with: sbatch submit_byol.sbatch")

else:
    print("❌ No results to summarize - please check the data directory path")