# Breast Tissue Segmentation and Tiling

This notebook segments breast tissue from mammogram images and creates tiles only from the breast regions, ignoring background areas.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import cv2
from scipy import ndimage
from skimage import morphology, measure, filters
import random
from tqdm import tqdm

# Set random seed for reproducibility
# Both generators are seeded: the random.sample calls later in this notebook
# draw from the `random` module, and np.random is seeded as well so any
# NumPy-based randomness starts from the same state.
random.seed(42)
np.random.seed(42)

In [None]:
# Configuration
# Directory searched (non-recursively) for *.png mammograms.
# NOTE: absolute, machine-specific path -- adjust before running elsewhere.
DATA_DIR = Path("/home/pranaypalem/Downloads/VinDr_png_archive/split_images/training")
# Side length, in pixels, of the square tiles cut from each image.
TILE_SIZE = 256
# Step, in pixels, between neighbouring tile origins; at half of TILE_SIZE,
# adjacent tiles overlap by 50%.
TILE_STRIDE = 128
# Upper bound on how many images are randomly sampled for processing.
NUM_SAMPLES = 10
MIN_BREAST_AREA_RATIO = 0.3  # Minimum fraction of tile that should be breast tissue

In [None]:
def segment_breast_tissue(image_array):
    """
    Segment breast tissue from mammogram image using thresholding and morphological operations.

    Pipeline: grayscale conversion -> rescale to 8-bit (only when the input is
    not already uint8) -> Gaussian blur -> Otsu threshold -> morphological
    opening -> hole filling -> largest connected component -> morphological
    closing.

    Args:
        image_array: numpy array of the mammogram image, either 2-D grayscale
            or 3-D with a trailing channel axis (converted with RGB2GRAY).

    Returns:
        2-D boolean mask with the same height and width as the input, where
        True indicates breast tissue. The mask is all False when thresholding
        leaves no foreground component.
    """
    # Convert to grayscale if needed
    if image_array.ndim == 3:
        gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    else:
        gray = image_array.copy()

    # cv2.threshold's Otsu mode is documented for 8-bit single-channel images,
    # so other dtypes (e.g. 16-bit PNGs) are min-max rescaled to uint8 first.
    # uint8 input is left untouched.
    if gray.dtype != np.uint8:
        lo = float(gray.min())
        hi = float(gray.max())
        scale = 255.0 / (hi - lo) if hi > lo else 0.0
        gray = np.rint((gray.astype(np.float64) - lo) * scale).astype(np.uint8)

    # Apply Gaussian blur to reduce noise
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)

    # Use Otsu's thresholding to separate breast tissue from background
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Remove small noise using morphological opening
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

    # Fill holes in the breast tissue
    filled = ndimage.binary_fill_holes(opened).astype(np.uint8) * 255

    # Keep only the largest connected component (main breast tissue)
    num_labels, labels = cv2.connectedComponents(filled)
    if num_labels > 1:
        # Count the pixels of every label in a single pass instead of scanning
        # the whole label image once per label. Index 0 is the background, so
        # it is skipped; argmax keeps the first label on ties.
        component_sizes = np.bincount(labels.ravel(), minlength=num_labels)
        largest_label = 1 + int(np.argmax(component_sizes[1:]))
        mask = (labels == largest_label).astype(np.uint8) * 255
    else:
        # Only background was found; `filled` is already an empty mask
        mask = filled

    # Apply morphological closing to smooth the boundaries
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask > 0

In [None]:
def extract_breast_tiles(image_array, breast_mask, tile_size, stride, min_breast_ratio=0.3):
    """
    Extract tiles from breast regions only.

    Tiles are taken on a regular grid with the given stride. When the grid
    stops short of the bottom or right border, one extra row/column of tiles
    flush with that border is added so the whole image is covered.

    Args:
        image_array: Original mammogram image (2-D grayscale, or 3-D with a
            trailing channel axis)
        breast_mask: Binary mask of breast tissue, indexed like the first two
            axes of image_array. Any non-zero value counts as tissue, so both
            boolean and 0/255 masks give ratios in [0, 1].
        tile_size: Side length of the square tiles in pixels (must be > 0)
        stride: Step between neighbouring tile origins in pixels (must be > 0)
        min_breast_ratio: Minimum fraction of tile that should be breast tissue

    Returns:
        List of (tile_image, (x, y), breast_ratio) tuples in row-major order.
        tile_image is a slice of image_array (a view, not a copy, for numpy
        input), (x, y) is the tile's top-left corner, and breast_ratio is the
        fraction of the tile covered by the mask.

    Raises:
        ValueError: If tile_size or stride is not positive.

    Note:
        If the image is smaller than tile_size along an axis, the single tile
        along that axis is truncated to the image extent, while breast_ratio
        is still measured against the full tile_size * tile_size area.
    """
    if tile_size <= 0 or stride <= 0:
        raise ValueError("tile_size and stride must be positive")

    height, width = image_array.shape[:2]
    # Normalise the mask to booleans so the ratio does not depend on whether
    # tissue is encoded as True, 1 or 255.
    tissue = np.asarray(breast_mask) != 0
    tile_area = tile_size * tile_size

    def axis_positions(extent):
        # Regular grid along one axis, plus a final border-aligned position
        # when the grid would otherwise leave the far edge uncovered.
        positions = list(range(0, max(1, extent - tile_size + 1), stride))
        if positions[-1] + tile_size < extent:
            positions.append(extent - tile_size)
        return positions

    y_positions = axis_positions(height)
    x_positions = axis_positions(width)

    tiles = []
    for y in y_positions:
        for x in x_positions:
            # Calculate breast tissue ratio in this tile
            tile_mask = tissue[y:y+tile_size, x:x+tile_size]
            breast_ratio = np.sum(tile_mask) / tile_area

            # Only keep tiles with sufficient breast tissue
            if breast_ratio >= min_breast_ratio:
                # The same slice works for grayscale and multi-channel images
                tile_image = image_array[y:y+tile_size, x:x+tile_size]
                tiles.append((tile_image, (x, y), breast_ratio))

    return tiles

In [None]:
# Get sample images
# The listing is sorted because directory iteration order depends on the
# filesystem; the seeded random.sample below only selects the same files on
# every run/machine when the candidate list is in a stable order.
image_paths = sorted(DATA_DIR.glob("*.png"))
print(f"Found {len(image_paths)} images in {DATA_DIR}")

# Select random sample (all images when fewer than NUM_SAMPLES exist)
sample_paths = random.sample(image_paths, min(NUM_SAMPLES, len(image_paths)))
print(f"Processing {len(sample_paths)} sample images")

In [None]:
# Segment and tile every sampled image, collecting one record per image
results = []

for i, img_path in enumerate(tqdm(sample_paths, desc="Processing images")):
    print(f"\nProcessing image {i+1}: {img_path.name}")

    # Read the pixels into an array; the file is closed on leaving the block
    with Image.open(img_path) as img:
        img_array = np.array(img)
    print(f"  Image shape: {img_array.shape}")

    # Tissue mask and the share of the image it covers
    breast_mask = segment_breast_tissue(img_array)
    breast_area = breast_mask.sum()
    total_area = breast_mask.shape[0] * breast_mask.shape[1]
    breast_percentage = (breast_area / total_area) * 100
    print(f"  Breast tissue: {breast_percentage:.1f}% of image")

    # Tiles restricted to regions with enough breast tissue
    tiles = extract_breast_tiles(
        img_array,
        breast_mask,
        TILE_SIZE,
        TILE_STRIDE,
        MIN_BREAST_AREA_RATIO,
    )
    print(f"  Generated {len(tiles)} breast tissue tiles")

    # Keep everything the display and summary cells read later
    results.append(
        {
            'path': img_path,
            'image': img_array,
            'mask': breast_mask,
            'tiles': tiles,
            'breast_percentage': breast_percentage,
        }
    )

print(f"\nCompleted processing {len(results)} images")

In [None]:
def display_segmentation_results(results, max_images=5):
    """
    Display original images, breast masks, and tinted overlays.

    One row is drawn per result, with three columns: the original image, the
    binary breast mask, and the image with a red tint over the masked region.

    Args:
        results: List of dicts providing the keys 'path', 'image', 'mask',
            'tiles' and 'breast_percentage'.
        max_images: Maximum number of results (rows) to show, taken from the
            front of the list.
    """
    n_rows = min(max_images, len(results))
    if n_rows <= 0:
        print("No segmentation results to display")
        return

    # Size the grid to the rows actually drawn so no blank axes are left over;
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5*n_rows), squeeze=False)

    for row_axes, result in zip(axes, results[:n_rows]):
        image_ax, mask_ax, overlay_ax = row_axes

        # Original image
        image_ax.imshow(result['image'], cmap='gray')
        image_ax.set_title(f"Original\n{result['path'].name}")
        image_ax.axis('off')

        # Breast mask
        mask_ax.imshow(result['mask'], cmap='gray')
        mask_ax.set_title(f"Breast Mask\n{result['breast_percentage']:.1f}% breast tissue")
        mask_ax.axis('off')

        # Overlay
        overlay = result['image'].copy()
        if overlay.ndim == 2:
            overlay = np.stack([overlay, overlay, overlay], axis=-1)

        # Add red tint to breast regions. The addition is done in float:
        # adding 50 directly to uint8 data wraps around for values above 205,
        # which would turn the brightest pixels dark instead of saturating.
        mask = result['mask']
        tinted_red = overlay[mask, 0].astype(np.float64) + 50
        overlay[mask, 0] = np.minimum(255, tinted_red)

        overlay_ax.imshow(overlay)
        overlay_ax.set_title(f"Overlay\n{len(result['tiles'])} tiles generated")
        overlay_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Display segmentation results
# Shows the original image, mask and overlay for up to the first 5 entries of
# `results`.
display_segmentation_results(results, max_images=5)

In [None]:
def display_sample_tiles(results, tiles_per_image=6, max_images=5):
    """
    Display sample tiles from each processed image.

    For every shown image a random selection of its tiles is drawn in a grid
    three columns wide. Images without tiles are reported and skipped.

    Args:
        results: List of dicts providing 'path' and 'tiles', where 'tiles'
            holds (tile_image, coords, breast_ratio) tuples.
        tiles_per_image: Maximum number of tiles to draw per image. The grid
            grows by rows to fit this many tiles (2 rows for the default 6).
        max_images: Maximum number of results to show, taken from the front
            of the list.
    """
    n_cols = 3
    # Ceiling division so every requested tile gets an axis; at least one row
    # so the grid is valid even for tiles_per_image == 0.
    n_rows = max(1, (tiles_per_image + n_cols - 1) // n_cols)

    for result in results[:max_images]:
        tiles = result['tiles']
        if not tiles:
            print(f"No tiles generated for {result['path'].name}")
            continue

        # Select random tiles to display
        sample_tiles = random.sample(tiles, min(tiles_per_image, len(tiles)))

        # squeeze=False keeps `axes` two-dimensional for single-row grids
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), squeeze=False)
        fig.suptitle(f"Sample Tiles - {result['path'].name} ({len(tiles)} total tiles)", fontsize=14)

        flat_axes = axes.ravel()
        for ax, (tile_img, coords, breast_ratio) in zip(flat_axes, sample_tiles):
            ax.imshow(tile_img, cmap='gray')
            ax.set_title(f"Pos: {coords}\nBreast: {breast_ratio:.1%}")
            ax.axis('off')

        # Hide unused subplots
        for ax in flat_axes[len(sample_tiles):]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

# Display sample tiles
# Uses the default of up to 6 randomly chosen tiles per image, for the first
# 5 entries of `results`.
display_sample_tiles(results)

In [None]:
# Summary statistics
# Per-image tile counts are computed once and reused for the total and the
# histogram below.
tile_counts = [len(result['tiles']) for result in results]
total_tiles = sum(tile_counts)
avg_tiles_per_image = total_tiles / len(results) if results else 0
# np.mean of an empty list returns nan and emits a RuntimeWarning, so the
# empty case is guarded the same way as the tile average above.
avg_breast_percentage = (
    np.mean([result['breast_percentage'] for result in results]) if results else 0
)

print("=== Summary Statistics ===")
print(f"Total images processed: {len(results)}")
print(f"Total breast tissue tiles: {total_tiles}")
print(f"Average tiles per image: {avg_tiles_per_image:.1f}")
print(f"Average breast tissue percentage: {avg_breast_percentage:.1f}%")
print(f"Tile size: {TILE_SIZE}x{TILE_SIZE} pixels")
print(f"Minimum breast tissue ratio per tile: {MIN_BREAST_AREA_RATIO:.1%}")

# Distribution of tiles per image
plt.figure(figsize=(8, 5))
plt.hist(tile_counts, bins=10, alpha=0.7, edgecolor='black')
plt.xlabel('Number of Tiles per Image')
plt.ylabel('Frequency')
plt.title('Distribution of Breast Tissue Tiles per Image')
plt.grid(True, alpha=0.3)
plt.show()