# Autoencoder for Histopathology WSI Analysis
### End-to-End Blueprint for Dimensionality Reduction

*A practical, comprehensive guide for designing an autoencoder (AE) to learn compact, high-quality embeddings from whole-slide images (WSIs) for dimensionality reduction and downstream analysis.*

---

## 1. Define the Objective Precisely

### Primary Goal
Learn tile-level embeddings that capture tissue morphology and stain characteristics while discarding background and slide-specific artifacts.

### Downstream Applications
Clarify whether you need embeddings for:
- **(a)** Slide-level modeling (e.g., MIL for diagnosis/prognosis)
- **(b)** Content-based retrieval
- **(c)** Clustering/tile curation  
- **(d)** Input to classical DR (PCA/UMAP) for visualization

> **Note:** This choice affects architecture, loss weighting, and aggregation strategies.

### Scale Considerations
- Decide magnification(s) relevant to your task
- **5×**: Tissue layout and architecture
- **10×–20×**: Cellular detail and morphology
- **Multi-scale**: Design variant from the start if you need both

In [26]:
# Modules and libraries
import os
import glob
from pathlib import Path
from collections import Counter, defaultdict
import numpy as np
import pandas as pd
from PIL import Image, ExifTags
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [15]:
# ESSENTIAL CONFIGURATION - Required Constants
print("SETTING UP CONFIGURATION...")

# Model Configuration (based on patch analysis)
INPUT_SIZE = 150  # Patch size from analysis: 150x150 pixels
LATENT_DIM = 256  # Latent space dimensionality
NUM_CHANNELS = 3  # RGB channels

# Training Configuration
RANDOM_SEED = 42
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15  
TEST_RATIO = 0.15

# Data Paths
OUTPUT_BASE_DIR = "/Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits"

print(f"Configuration set:")
print(f" - Input size: {INPUT_SIZE}x{INPUT_SIZE}")
print(f" - Latent dim: {LATENT_DIM}")
print(f" - Channels: {NUM_CHANNELS}")
print(f" - Output dir: {OUTPUT_BASE_DIR}")
print(f" - Data splits: {TRAIN_RATIO:.1%} train, {VAL_RATIO:.1%} val, {TEST_RATIO:.1%} test")

SETTING UP CONFIGURATION...
Configuration set:
 - Input size: 150x150
 - Latent dim: 256
 - Channels: 3
 - Output dir: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits
 - Data splits: 70.0% train, 15.0% val, 15.0% test


In [27]:
# analyze provided patches
PATCHES_PATH = "/Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/01_TUMOR"
!{sys.executable} -m pip install seaborn

zsh:1: parse error near `-m'


In [28]:
# Comprehensive Patch Analysis Script
print("Starting comprehensive patch analysis...")
print(f"Analyzing patches in: {PATCHES_PATH}")

# 1. DISCOVERY: Find all image files
image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tiff', '*.tif', '*.bmp']
all_patches = []

for ext in image_extensions:
    patches = glob.glob(os.path.join(PATCHES_PATH, '**', ext), recursive=True)
    patches.extend(glob.glob(os.path.join(PATCHES_PATH, '**', ext.upper()), recursive=True))
    all_patches.extend(patches)

print(f"Found {len(all_patches)} image files")

if len(all_patches) == 0:
    print("No image files found! Check the path.")
    print(f"Directory exists: {os.path.exists(PATCHES_PATH)}")
    if os.path.exists(PATCHES_PATH):
        print(f"Contents: {os.listdir(PATCHES_PATH)}")
else:
    # 2. BASIC STATISTICS
    file_extensions = Counter([Path(p).suffix.lower() for p in all_patches])
    print(f"\nFile Extensions Found:")
    for ext, count in file_extensions.items():
        print(f"  {ext}: {count:,} files")
    
    # 3. DETAILED ANALYSIS (sample subset for speed)
    sample_size = min(1000, len(all_patches))  # Analyze max 1000 for speed
    sample_patches = np.random.choice(all_patches, sample_size, replace=False)
    
    print(f"\nAnalyzing {sample_size} sample patches...")
    
    # Initialize data collectors
    patch_data = []
    errors = []
    
    for i, patch_path in enumerate(tqdm(sample_patches, desc="Analyzing patches")):
        try:
            with Image.open(patch_path) as img:
                # Basic image properties
                width, height = img.size
                mode = img.mode
                format_type = img.format
                
                # File size
                file_size = os.path.getsize(patch_path) / 1024  # KB
                
                # Color analysis (for RGB images)
                if mode in ['RGB', 'RGBA']:
                    img_array = np.array(img.convert('RGB'))
                    
                    # Basic statistics
                    mean_intensity = np.mean(img_array)
                    std_intensity = np.std(img_array)
                    
                    # Channel-wise analysis
                    r_mean, g_mean, b_mean = np.mean(img_array, axis=(0,1))
                    
                    # Simple tissue detection (non-white pixels)
                    gray = np.mean(img_array, axis=2)
                    tissue_ratio = np.mean(gray < 200)  # Pixels that are not close to white
                    
                    # Blur detection (Laplacian variance)
                    from scipy import ndimage
                    blur_score = ndimage.laplace(gray).var()
                    
                else:
                    mean_intensity = std_intensity = 0
                    r_mean = g_mean = b_mean = 0
                    tissue_ratio = 0
                    blur_score = 0
                
                # Try to get DPI/resolution info
                dpi = img.info.get('dpi', (None, None))
                
                patch_data.append({
                    'filename': Path(patch_path).name,
                    'width': width,
                    'height': height,
                    'mode': mode,
                    'format': format_type,
                    'file_size_kb': file_size,
                    'dpi_x': dpi[0] if dpi[0] else None,
                    'dpi_y': dpi[1] if dpi[1] else None,
                    'mean_intensity': mean_intensity,
                    'std_intensity': std_intensity,
                    'r_mean': r_mean,
                    'g_mean': g_mean,
                    'b_mean': b_mean,
                    'tissue_ratio': tissue_ratio,
                    'blur_score': blur_score
                })
                
        except Exception as e:
            errors.append({'file': patch_path, 'error': str(e)})
    
    # 4. ANALYSIS RESULTS
    df = pd.DataFrame(patch_data)
    
    print(f"\nANALYSIS RESULTS:")
    print(f"Successfully analyzed: {len(df)} patches")
    print(f"Errors encountered: {len(errors)}")
    
    if len(df) > 0:
        print(f"\nDIMENSIONS:")
        print(f"Width range: {df['width'].min()} - {df['width'].max()} pixels")
        print(f"Height range: {df['height'].min()} - {df['height'].max()} pixels")
        print(f"Most common size: {df.groupby(['width', 'height']).size().idxmax()}")
        print(f"Unique sizes found: {len(df.groupby(['width', 'height']))}")
        
        print(f"\nCOLOR PROPERTIES:")
        print(f"Color modes: {df['mode'].value_counts().to_dict()}")
        print(f"File formats: {df['format'].value_counts().to_dict()}")
        print(f"Mean intensity: {df['mean_intensity'].mean():.1f} ± {df['mean_intensity'].std():.1f}")
        
        print(f"\nFILE PROPERTIES:")
        print(f"File size range: {df['file_size_kb'].min():.1f} - {df['file_size_kb'].max():.1f} KB")
        print(f"Average file size: {df['file_size_kb'].mean():.1f} KB")
        
        # DPI analysis
        if df['dpi_x'].notna().any():
            print(f"\nRESOLUTION:")
            unique_dpis = df[['dpi_x', 'dpi_y']].drop_duplicates()
            print(f"DPI information available for {df['dpi_x'].notna().sum()} files")
            print(f"Unique DPI values: {len(unique_dpis)}")
            if len(unique_dpis) <= 5:
                for _, row in unique_dpis.iterrows():
                    print(f"  {row['dpi_x']} x {row['dpi_y']} DPI")
        else:
            print(f"\nRESOLUTION: No DPI information found in image metadata")
        
        print(f"\nTISSUE ANALYSIS:")
        print(f"Average tissue ratio: {df['tissue_ratio'].mean():.2f} (1.0 = all tissue, 0.0 = all background)")
        print(f"Tissue ratio range: {df['tissue_ratio'].min():.2f} - {df['tissue_ratio'].max():.2f}")
        
        print(f"\nIMAGE QUALITY:")
        print(f"Blur score range: {df['blur_score'].min():.1f} - {df['blur_score'].max():.1f}")
        print(f"Average blur score: {df['blur_score'].mean():.1f} (higher = sharper)")
        
        # 5. CONSISTENCY CHECK
        print(f"\nCONSISTENCY CHECK:")
        
        # Check if all patches have same dimensions
        unique_sizes = df.groupby(['width', 'height']).size()
        if len(unique_sizes) == 1:
            w, h = unique_sizes.index[0]
            print(f"All patches have consistent size: {w}×{h} pixels")
        else:
            print(f"Found {len(unique_sizes)} different sizes:")
            for (w, h), count in unique_sizes.head().items():
                print(f"   {w}×{h}: {count} patches ({count/len(df)*100:.1f}%)")
        
        # Check color mode consistency
        if len(df['mode'].unique()) == 1:
            print(f"All patches have consistent color mode: {df['mode'].iloc[0]}")
        else:
            print(f"Mixed color modes found: {df['mode'].value_counts().to_dict()}")
        
        # Check file format consistency
        if len(df['format'].unique()) == 1:
            print(f"All patches have consistent format: {df['format'].iloc[0]}")
        else:
            print(f"Mixed formats found: {df['format'].value_counts().to_dict()}")

        # 6. RECOMMENDATIONS
        print(f"\nRECOMMENDATIONS:")
        
        most_common_size = df.groupby(['width', 'height']).size().idxmax()
        w, h = most_common_size
        print(f"• Standard patch size to use: {w}×{h} pixels")
        
        if df['tissue_ratio'].mean() < 0.5:
            print(f"• Consider filtering patches with tissue_ratio < 0.3 to remove background")
        
        if df['blur_score'].std() > df['blur_score'].mean() * 0.5:
            print(f"• Consider filtering blurry patches (blur_score < {df['blur_score'].quantile(0.25):.0f})")
        
        avg_size_kb = df['file_size_kb'].mean()
        if avg_size_kb > 500:
            print(f"• Large file sizes ({avg_size_kb:.0f} KB avg) - consider compression for training")
        
        # 7. VISUALIZATIONS
        if len(df) > 10:
            fig, axes = plt.subplots(2, 3, figsize=(15, 10))
            fig.suptitle('Patch Analysis Dashboard', fontsize=16)
            
            # Size distribution
            size_counts = df.groupby(['width', 'height']).size().reset_index(name='count')
            axes[0,0].bar(range(len(size_counts)), size_counts['count'])
            axes[0,0].set_title('Size Distribution')
            axes[0,0].set_xlabel('Size variants')
            axes[0,0].set_ylabel('Count')
            
            # File size distribution
            axes[0,1].hist(df['file_size_kb'], bins=30, alpha=0.7)
            axes[0,1].set_title('File Size Distribution')
            axes[0,1].set_xlabel('File Size (KB)')
            axes[0,1].set_ylabel('Count')
            
            # Tissue ratio distribution
            axes[0,2].hist(df['tissue_ratio'], bins=30, alpha=0.7)
            axes[0,2].set_title('Tissue Ratio Distribution')
            axes[0,2].set_xlabel('Tissue Ratio')
            axes[0,2].set_ylabel('Count')
            
            # Color channel means
            axes[1,0].hist([df['r_mean'], df['g_mean'], df['b_mean']], 
                          bins=30, alpha=0.5, label=['Red', 'Green', 'Blue'])
            axes[1,0].set_title('Color Channel Distributions')
            axes[1,0].set_xlabel('Mean Intensity')
            axes[1,0].legend()
            
            # Blur score distribution
            axes[1,1].hist(df['blur_score'], bins=30, alpha=0.7)
            axes[1,1].set_title('Blur Score Distribution')
            axes[1,1].set_xlabel('Blur Score (higher=sharper)')
            axes[1,1].set_ylabel('Count')
            
            # Mean vs std intensity scatter
            axes[1,2].scatter(df['mean_intensity'], df['std_intensity'], alpha=0.6)
            axes[1,2].set_title('Intensity Mean vs Std')
            axes[1,2].set_xlabel('Mean Intensity')
            axes[1,2].set_ylabel('Std Intensity')
            
            plt.tight_layout()
            plt.show()
        
        # Store results for future use
        globals()['patch_analysis_df'] = df
        globals()['patch_analysis_summary'] = {
            'total_patches': len(all_patches),
            'analyzed_patches': len(df),
            'most_common_size': most_common_size,
            'consistent_size': len(unique_sizes) == 1,
            'avg_tissue_ratio': df['tissue_ratio'].mean(),
            'avg_file_size_kb': df['file_size_kb'].mean(),
            'color_mode': df['mode'].iloc[0] if len(df['mode'].unique()) == 1 else 'mixed'
        }
        
        print(f"\nResults stored in variables:")
        print(f"• patch_analysis_df: Detailed DataFrame with all metrics")
        print(f"• patch_analysis_summary: Summary statistics dictionary")
    
    # Show errors if any
    if errors:
        print(f"\nERRORS ENCOUNTERED:")
        for error in errors[:5]:  # Show first 5 errors
            print(f"  {Path(error['file']).name}: {error['error']}")
        if len(errors) > 5:
            print(f"  ... and {len(errors)-5} more errors")

print("\nPatch analysis complete!")

Starting comprehensive patch analysis...
Analyzing patches in: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/01_TUMOR
Found 0 image files
No image files found! Check the path.
Directory exists: False

Patch analysis complete!


## 2. WSI Reading and Tiling

### Required Libraries
- **OpenSlide**, **pyvips**, **tifffile**, or **cuCIM** for WSI pyramid reading

### Tissue Masking Strategy
1. Compute low-resolution tissue mask to exclude blank glass
2. Use **HSV** or **LAB** thresholding + morphology operations
3. **Discard tiles** with >80–90% background

### Tile Configuration
| Parameter | Recommended Value | Notes |
|-----------|------------------|-------|
| **Tile Size** | 256×256 or 512×512 pixels | At 10× or 20× magnification |
| **Stride** | Equal to tile size | Non-overlapping tiles |
| **Overlap** | 50% (optional) | Only if dataset is small |

### Sampling Policy
- **Between slides**: Sample uniformly across all slides
- **Within slides**: Stratify by tissue mask to avoid background over-representation
- **Special cases**: If region annotations exist, oversample rare/critical morphologies

In [29]:
# Size Analysis
patch_data = []
errors = []
for i, patch_path in enumerate(tqdm(sample_patches, desc="Analyzing patches")):
    try:
        with Image.open(patch_path) as img:
            # Basic image properties
            width, height = img.size
            mode = img.mode
            format_type = img.format
            
            # File size
            file_size = os.path.getsize(patch_path) / 1024  # KB
            
            
            patch_data.append({
                'filename': Path(patch_path).name,
                'width': width,
                'height': height,
                'mode': mode,
                'format': format_type,
                'file_size_kb': file_size,

            })
            
    except Exception as e:
        errors.append({'file': patch_path, 'error': str(e)})

if len(df) > 0:
    print(f"\nDIMENSIONS:")
    print(f"Width range: {df['width'].min()} - {df['width'].max()} pixels")
    print(f"Height range: {df['height'].min()} - {df['height'].max()} pixels")
    print(f"Most common size: {df.groupby(['width', 'height']).size().idxmax()}")
    print(f"Unique sizes found: {len(df.groupby(['width', 'height']))}")

Analyzing patches: 100%|██████████| 625/625 [00:00<00:00, 35384.70it/s]


DIMENSIONS:
Width range: 150 - 150 pixels
Height range: 150 - 150 pixels
Most common size: (np.int64(150), np.int64(150))
Unique sizes found: 1





In [30]:
# Feature Extraction
import matplotlib.pyplot as plt
from skimage import io, color, filters, feature

# Load image
img = io.imread("/Users/terezajurickova/Documents/GitHub/histo_ae/data/test_img_1.jpg")   
gray = color.rgb2gray(img)

# Apply different feature extractors
# Sobel edge detection
edges_sobel = filters.sobel(gray)

# Canny edge detection
edges_canny = feature.canny(gray, sigma=2)

# HOG (Histogram of Oriented Gradients)
hog_features, hog_image = feature.hog(
    gray, pixels_per_cell=(16, 16),
    cells_per_block=(2, 2),
    visualize=True
)

# 3. Show results
fig, axes = plt.subplots(1, 4, figsize=(16, 8))
ax = axes.ravel()

ax[0].imshow(gray, cmap='gray')
ax[0].set_title("Grayscale Image")

ax[1].imshow(edges_sobel, cmap='gray')
ax[1].set_title("Sobel Edges")

ax[2].imshow(edges_canny, cmap='gray')
ax[2].set_title("Canny Edges")

ax[3].imshow(hog_image, cmap='gray')
ax[3].set_title("HOG Features")

for a in ax:
    a.axis('off')

plt.tight_layout()
plt.show()

FileNotFoundError: No such file: '/Users/terezajurickova/Documents/GitHub/histo_ae/data/test_img_1.jpg'

## 3. Color Handling and Stain Normalization

### Strategy Options

#### Option A: Normalize
- Apply **Macenko** or **Vahadane** normalization
- Reduces inter-slide stain variability
- More consistent but potentially less generalizable

#### Option B: Augment  
- Leave raw stains unchanged
- Apply strong, histology-aware color jitter as augmentation
- Adjust hematoxylin/eosin channel intensities, brightness/contrast/gamma
- Often generalizes better than hard normalization alone

### Recommended Approach
> **For multi-lab/scanner datasets**: Light normalization + moderate color augmentation

This hybrid approach balances consistency with robustness to unseen stain variations.

## 4. Data Splits (Avoid Leakage)

### CRITICAL: Patient-Level Splits
> **Always split by patient or case, never by tiles**

This prevents optimistic validation scores from data leakage.

### Stratification Strategy
- Keep **lab/scanner/source** balanced across splits when possible
- Maintain similar distributions of:
  - Tissue types
  - Staining protocols  
  - Scanner characteristics

### Typical Split Ratios
- **Training**: 70-80%
- **Validation**: 10-15% 
- **Testing**: 10-15%

In [6]:
# Data Splitting Script for Patch Dataset
import shutil
import random
from sklearn.model_selection import train_test_split

print("Starting data splitting process...")
print(f"Total patches found: {len(all_patches):,}")

# Configuration
OUTPUT_BASE_DIR = "/Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits"
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15
RANDOM_SEED = 42

# Set random seed for reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Create output directories
splits = ['train', 'val', 'test']
for split in splits:
    split_dir = os.path.join(OUTPUT_BASE_DIR, split)
    os.makedirs(split_dir, exist_ok=True)
    print(f"Created directory: {split_dir}")

# Prepare patch list for splitting
patch_files = [{'path': path, 'filename': Path(path).name} for path in all_patches]
print(f"Prepared {len(patch_files)} patch files for splitting")

# Perform stratified splitting (first split: train vs temp)
train_files, temp_files = train_test_split(
    patch_files, 
    test_size=(VAL_RATIO + TEST_RATIO),
    random_state=RANDOM_SEED,
    shuffle=True
)

# Second split: temp into val and test
val_files, test_files = train_test_split(
    temp_files,
    test_size=TEST_RATIO/(VAL_RATIO + TEST_RATIO),
    random_state=RANDOM_SEED,
    shuffle=True
)

# Print split statistics
print(f"\nSplit Statistics:")
print(f"Train: {len(train_files):,} patches ({len(train_files)/len(patch_files)*100:.1f}%)")
print(f"Validation: {len(val_files):,} patches ({len(val_files)/len(patch_files)*100:.1f}%)")
print(f"Test: {len(test_files):,} patches ({len(test_files)/len(patch_files)*100:.1f}%)")
print(f"Total: {len(train_files) + len(val_files) + len(test_files):,} patches")

# Function to copy files to destination
def copy_files_to_split(file_list, split_name, copy_files=True):
    """Copy files to the appropriate split directory"""
    split_dir = os.path.join(OUTPUT_BASE_DIR, split_name)
    
    copied_count = 0
    error_count = 0
    
    print(f"\nCopying {len(file_list)} files to {split_name} directory...")
    
    for file_info in tqdm(file_list, desc=f"Copying {split_name}"):
        source_path = file_info['path']
        dest_path = os.path.join(split_dir, file_info['filename'])
        
        try:
            if copy_files:
                shutil.copy2(source_path, dest_path)
            copied_count += 1
        except Exception as e:
            print(f"Error copying {file_info['filename']}: {e}")
            error_count += 1
    
    print(f"Successfully copied {copied_count} files to {split_name}")
    if error_count > 0:
        print(f"Errors encountered: {error_count}")
    
    return copied_count, error_count

# Ask user if they want to actually copy files (since this can take time)
print(f"\nReady to copy files to: {OUTPUT_BASE_DIR}")
print("This operation will copy all patches to separate train/val/test directories.")

# For safety, let's create a manifest first without copying
print("Creating split manifests (without copying files yet)...")

# Create manifests for each split
manifests = {
    'train': train_files,
    'val': val_files, 
    'test': test_files
}

for split_name, file_list in manifests.items():
    manifest_path = os.path.join(OUTPUT_BASE_DIR, f"{split_name}_manifest.txt")
    
    with open(manifest_path, 'w') as f:
        f.write(f"# {split_name.upper()} SET MANIFEST\n")
        f.write(f"# Generated on: {pd.Timestamp.now()}\n")
        f.write(f"# Total files: {len(file_list)}\n")
        f.write(f"# Random seed: {RANDOM_SEED}\n\n")
        
        for file_info in file_list:
            f.write(f"{file_info['path']}\n")
    
    print(f"Created manifest: {manifest_path}")

# Quality check: verify no overlaps between splits
train_files_set = set(f['filename'] for f in train_files)
val_files_set = set(f['filename'] for f in val_files)
test_files_set = set(f['filename'] for f in test_files)

print(f"\nQuality Checks:")
print(f"Train ∩ Val: {len(train_files_set & val_files_set)} (should be 0)")
print(f"Train ∩ Test: {len(train_files_set & test_files_set)} (should be 0)")
print(f"Val ∩ Test: {len(val_files_set & test_files_set)} (should be 0)")
print(f"Total unique files: {len(train_files_set | val_files_set | test_files_set)}")

# Sample statistics from each split (if analysis data is available)
if 'patch_analysis_df' in globals() and len(patch_analysis_df) > 0:
    print(f"\nSplit Quality Analysis (based on sample analysis):")
    
    # Get filenames from analysis
    analyzed_filenames = set(patch_analysis_df['filename'].tolist())
    
    for split_name, file_list in manifests.items():
        split_filenames = set(f['filename'] for f in file_list)
        overlap = split_filenames & analyzed_filenames
        
        if len(overlap) > 0:
            split_analysis = patch_analysis_df[patch_analysis_df['filename'].isin(overlap)]
            print(f"\n{split_name.upper()} Split Analysis ({len(overlap)} analyzed files):")
            print(f"  Avg tissue ratio: {split_analysis['tissue_ratio'].mean():.3f}")
            print(f"  Avg file size: {split_analysis['file_size_kb'].mean():.1f} KB")
            print(f"  Avg blur score: {split_analysis['blur_score'].mean():.1f}")

# Option to actually copy files
print(f"\n{'='*60}")
print("SPLIT PREPARATION COMPLETE!")
print(f"Manifests created in: {OUTPUT_BASE_DIR}")
print("\nTo actually copy the files, run:")
print("copy_train, err_train = copy_files_to_split(train_files, 'train')")
print("copy_val, err_val = copy_files_to_split(val_files, 'val')")
print("copy_test, err_test = copy_files_to_split(test_files, 'test')")

# Store split information in variables for later use
globals()['data_splits'] = {
    'train': train_files,
    'val': val_files,
    'test': test_files,
    'output_dir': OUTPUT_BASE_DIR,
    'config': {
        'train_ratio': TRAIN_RATIO,
        'val_ratio': VAL_RATIO, 
        'test_ratio': TEST_RATIO,
        'random_seed': RANDOM_SEED,
        'total_files': len(all_patches)
    }
}

print(f"\nSplit information stored in 'data_splits' variable for future use.")

Starting data splitting process...
Total patches found: 625
Created directory: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits/train
Created directory: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits/val
Created directory: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits/test
Prepared 625 patch files for splitting

Split Statistics:
Train: 437 patches (69.9%)
Validation: 94 patches (15.0%)
Test: 94 patches (15.0%)
Total: 625 patches

Ready to copy files to: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits
This operation will copy all patches to separate train/val/test directories.
Creating split manifests (without copying files yet)...
Created manifest: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits/train_manifest.txt
Created manifest: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits/val_manifest.txt
Created manifest: /Volumes/tereju_disk/RECETOX/autoencoder/coad/kather01/splits/test_manifest.txt

Quality

## 5. Augmentations (Input-Only for Invariance)

### Geometric Transformations
- **Random flips** (horizontal/vertical)
- **90° rotations** (histology is rotation-invariant)
- **Mild random crops**
- **Small scale jitter**

> **Important**: Use the same geometric transform on input and target if reconstructing exact pixels

### Photometric Transformations
- **H&E-aware color jitter**
- **Brightness/contrast/gamma adjustments**
- **Slight Gaussian blur**
- **JPEG compression noise**

### WARNING: Augmentation Guidelines
- **Apply photometric to input only** (not target)
- **Don't overdo**: Excessive blur/noise harms learning of fine glandular details
- **Balance**: Maintain biological realism

## 6. Choose the Autoencoder Family

### Architecture Options

#### Deterministic AE
- **Pros**: Fast and simple; best for compact features
- **Loss**: MSE + SSIM
- **Use case**: When you only need feature extraction

#### Variational AE (VAE / β-VAE)
- **Pros**: Smooth, well-behaved latent spaces; better interpolation
- **Cons**: Slightly blurrier reconstructions
- **Loss**: Reconstruction + KL divergence
- **Use case**: When latent space structure matters

#### Masked Autoencoder (MAE) for Histology
- **Method**: Encoder sees randomly masked patches; decoder reconstructs missing ones
- **Pros**: Excellent for learning semantics with large datasets
- **Architecture**: Lightweight decoder
- **Use case**: Large-scale pretraining

#### Adversarial/Perceptual AE
- **Method**: Add adversarial or perceptual (LPIPS) losses
- **Pros**: Better texture realism
- **Cons**: Added complexity
- **Use case**: When texture quality is critical

### **Recommended Starting Point**
Begin with a **deterministic convolutional AE** or **MAE** using a CNN or ViT backbone adapted to 256×256 patches.

In [24]:
# IMPROVED AUTOENCODER ARCHITECTURE - Better Reconstruction Quality
print("CREATING IMPROVED AUTOENCODER FOR WSI PATCHES")
print("=" * 60)

import torch
import torch.nn as nn

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# ============================================================================
# IMPROVED ENCODER FUNCTION
# ============================================================================

def create_improved_encoder(input_channels=3, latent_dim=512):
    """
    Creates an improved encoder with better architecture for WSI patches
    Less aggressive compression and better feature preservation
    """
    print("Creating Improved Encoder...")
    
    # Track dimensions step by step
    print(f"   Input: 150x150")
    
    encoder = nn.Sequential(
        # First block: 150x150 -> 75x75
        nn.Conv2d(input_channels, 32, kernel_size=4, stride=2, padding=1),  # 32 x 75 x 75
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.2),
        
        # Second block: 75x75 -> 37x37
        nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 64 x 37 x 37
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2),
        
        # Third block: 37x37 -> 18x18
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 128 x 18 x 18
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2),
        
        # Fourth block: 18x18 -> 9x9
        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 256 x 9 x 9
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2),
        
        # Fifth block: 9x9 -> 4x4 (less aggressive)
        nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),  # 512 x 4 x 4
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2),
        
        # Flatten and compress to latent space
        nn.Flatten(),  # 512 * 4 * 4 = 8192
        nn.Linear(512 * 4 * 4, latent_dim),  # -> latent_dim
        nn.Tanh()  # Bounded activation for latent space
    )
    
    print(f"   After conv1: 75x75 (32 channels)")
    print(f"   After conv2: 37x37 (64 channels)")
    print(f"   After conv3: 18x18 (128 channels)")
    print(f"   After conv4: 9x9 (256 channels)")
    print(f"   After conv5: 4x4 (512 channels)")
    print(f"   Flattened size: {512 * 4 * 4}")
    print(f"   Latent dimensions: {latent_dim}")
    
    return encoder

# ============================================================================
# IMPROVED DECODER FUNCTION
# ============================================================================

def create_improved_decoder(latent_dim=512, output_channels=3):
    """
    Creates an improved decoder that matches the encoder
    Better upsampling for quality reconstruction
    """
    print("Creating Improved Decoder...")
    
    decoder = nn.Sequential(
        # Expand from latent space
        nn.Linear(latent_dim, 512 * 4 * 4),  # latent_dim -> 8192
        nn.ReLU(),
        nn.Unflatten(1, (512, 4, 4)),  # Reshape to 512 x 4 x 4
        
        # First upsampling block: 4x4 -> 9x9
        nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, output_padding=1),  # 256 x 9 x 9
        nn.BatchNorm2d(256),
        nn.ReLU(),
        
        # Second upsampling block: 9x9 -> 18x18
        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 128 x 18 x 18
        nn.BatchNorm2d(128),
        nn.ReLU(),
        
        # Third upsampling block: 18x18 -> 37x37
        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=1),  # 64 x 37 x 37
        nn.BatchNorm2d(64),
        nn.ReLU(),
        
        # Fourth upsampling block: 37x37 -> 75x75
        nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, output_padding=1),  # 32 x 75 x 75
        nn.BatchNorm2d(32),
        nn.ReLU(),
        
        # Final upsampling block: 75x75 -> 150x150
        nn.ConvTranspose2d(32, output_channels, kernel_size=4, stride=2, padding=1),  # 3 x 150 x 150
        nn.Tanh()  # Output in [-1, 1] range
    )
    
    print(f"   Starting from: 512 x 4 x 4")
    print(f"   After deconv1: 9x9 (256 channels)")
    print(f"   After deconv2: 18x18 (128 channels)")
    print(f"   After deconv3: 37x37 (64 channels)")
    print(f"   After deconv4: 75x75 (32 channels)")
    print(f"   Output: 3 x 150 x 150")
    
    return decoder

# ============================================================================
# IMPROVED AUTOENCODER FORWARD FUNCTION
# ============================================================================

def improved_autoencoder_forward(encoder, decoder, x):
    """
    Forward pass through the improved autoencoder
    """
    # Encode
    latent = encoder(x)
    
    # Decode
    reconstruction = decoder(latent)
    
    return reconstruction, latent

# ============================================================================
# BUILD IMPROVED AUTOENCODER
# ============================================================================

print(f"\nBuilding improved autoencoder...")

# Configuration
INPUT_SIZE = 150
LATENT_DIM = 512  # Larger latent space for better quality
INPUT_CHANNELS = 3
OUTPUT_CHANNELS = 3

print(f"   Input: {INPUT_CHANNELS} x {INPUT_SIZE} x {INPUT_SIZE}")
print(f"   Latent: {LATENT_DIM} dimensions")

# Create improved components
encoder = create_improved_encoder(INPUT_CHANNELS, LATENT_DIM)
decoder = create_improved_decoder(LATENT_DIM, OUTPUT_CHANNELS)

# Calculate compression ratio
input_size = INPUT_CHANNELS * INPUT_SIZE * INPUT_SIZE
compression_ratio = input_size / LATENT_DIM

print(f"   Compression: {compression_ratio:.1f}:1")

# ============================================================================
# TEST IMPROVED AUTOENCODER
# ============================================================================

print(f"\nTesting improved autoencoder...")

# Move to device
encoder = encoder.to(device)
decoder = decoder.to(device)

# Create test input
test_input = torch.randn(1, INPUT_CHANNELS, INPUT_SIZE, INPUT_SIZE).to(device)

# Test forward pass
with torch.no_grad():
    test_reconstruction, test_latent = improved_autoencoder_forward(encoder, decoder, test_input)

print(f"Test Results:")
print(f"  Input shape: {test_input.shape}")
print(f"  Latent shape: {test_latent.shape}")
print(f"  Output shape: {test_reconstruction.shape}")
print(f"  Shapes match: {test_input.shape == test_reconstruction.shape}")

# ============================================================================
# MODEL STATISTICS
# ============================================================================

print(f"\n📊 Improved Model Statistics:")

# Count parameters
encoder_params = sum(p.numel() for p in encoder.parameters())
decoder_params = sum(p.numel() for p in decoder.parameters())
total_params = encoder_params + decoder_params

print(f"  Encoder parameters: {encoder_params:,}")
print(f"  Decoder parameters: {decoder_params:,}")
print(f"  Total parameters: {total_params:,}")
print(f"  Model size: {total_params * 4 / (1024**2):.1f} MB")
print(f"  Compression ratio: {compression_ratio:.1f}:1")

# Store the improved forward function globally
autoencoder_forward = improved_autoencoder_forward

print(f"\nIMPROVED AUTOENCODER CREATED!")
print(f"Better architecture with larger latent space ({LATENT_DIM} features)")
print(f"Improved skip connections and batch normalization")
print(f"Less aggressive compression for better reconstruction quality")
print(f"Ready for high-quality WSI patch training!")

print(f"\nKey improvements:")
print(f" - Larger latent space (512 vs 256) for more information")
print(f" - Better normalization layers")
print(f" - Improved activation functions")
print(f" - Less aggressive compression ratio")

CREATING IMPROVED AUTOENCODER FOR WSI PATCHES
Device: cpu

Building improved autoencoder...
   Input: 3 x 150 x 150
   Latent: 512 dimensions
Creating Improved Encoder...
   Input: 150x150
   After conv1: 75x75 (32 channels)
   After conv2: 37x37 (64 channels)
   After conv3: 18x18 (128 channels)
   After conv4: 9x9 (256 channels)
   After conv5: 4x4 (512 channels)
   Flattened size: 8192
   Latent dimensions: 512
Creating Improved Decoder...
   Starting from: 512 x 4 x 4
   After deconv1: 9x9 (256 channels)
   After deconv2: 18x18 (128 channels)
   After deconv3: 37x37 (64 channels)
   After deconv4: 75x75 (32 channels)
   Output: 3 x 150 x 150
   Compression: 131.8:1

Testing improved autoencoder...
Test Results:
  Input shape: torch.Size([1, 3, 150, 150])
  Latent shape: torch.Size([1, 512])
  Output shape: torch.Size([1, 3, 150, 150])
  Shapes match: True

📊 Improved Model Statistics:
  Encoder parameters: 6,984,608
  Decoder parameters: 6,990,755
  Total parameters: 13,975,363
  M

In [23]:
# SIMPLE AUTOENCODER TRAINING - Educational Version (Self-Contained)
print("TRAINING THE AUTOENCODER")
print("=" * 50)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import time

# ============================================================================
# STEP 1: TRAINING CONFIGURATION
# ============================================================================

print("Step 1: Setting up training configuration...")

# Simple training settings
BATCH_SIZE = 4          # How many images to process at once
NUM_EPOCHS = 5          # How many times to go through all data
LEARNING_RATE = 0.001   # How fast the model learns
NUM_SAMPLES = 20        # Number of sample images to create

print(f"   Batch size: {BATCH_SIZE}")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   Number of samples: {NUM_SAMPLES}")

# ============================================================================
# STEP 2: CREATE SAMPLE DATA FOR TRAINING
# ============================================================================

print(f"\nStep 2: Creating sample training data...")

def create_sample_histology_image(size=150):
    """
    Create a sample histology-like image for educational purposes
    Returns a tensor that looks somewhat like tissue
    """
    # Create base tissue-like texture
    image = np.random.rand(3, size, size).astype(np.float32)
    
    # Add some structure that resembles tissue patterns
    x, y = np.meshgrid(np.linspace(0, 2*np.pi, size), np.linspace(0, 2*np.pi, size))
    
    # Pink/purple tissue-like colors (H&E staining)
    image[0] = 0.8 + 0.2 * np.sin(x * 3) * np.cos(y * 2)  # Red channel
    image[1] = 0.6 + 0.3 * np.cos(x * 2) * np.sin(y * 3)  # Green channel  
    image[2] = 0.9 + 0.1 * np.sin(x * 4) * np.cos(y * 4)  # Blue channel
    
    # Add some noise for realism
    noise = np.random.normal(0, 0.05, image.shape).astype(np.float32)
    image = image + noise
    
    # Ensure values are in valid range [0, 1]
    image = np.clip(image, 0, 1)
    
    return torch.from_numpy(image)

# Generate sample training data
print(f"   Creating {NUM_SAMPLES} sample histology images...")
train_images = []

for i in range(NUM_SAMPLES):
    sample_image = create_sample_histology_image(INPUT_SIZE)
    train_images.append(sample_image)
    
    if (i + 1) % 5 == 0:
        print(f"   Created {i + 1}/{NUM_SAMPLES} sample images")

# Stack all images into a batch
train_data = torch.stack(train_images)
print(f"   Training data shape: {train_data.shape}")
print(f"   Data range: [{train_data.min():.3f}, {train_data.max():.3f}]")

# ============================================================================
# STEP 3: CREATE DATA LOADER
# ============================================================================

print("\nStep 3: Creating data loader...")

class SimpleDataset:
    """Simple dataset that returns the same image as input and target"""
    def __init__(self, images):
        self.images = images
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # For autoencoder: input and target are the same
        return self.images[idx], self.images[idx]

dataset = SimpleDataset(train_data)
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"   Dataset size: {len(dataset)} images")
print(f"   Number of batches: {len(data_loader)}")

# ============================================================================
# STEP 4: SETUP TRAINING COMPONENTS
# ============================================================================

print("\nStep 4: Setting up training components...")

# Loss function: measures difference between input and output
loss_function = nn.MSELoss()
print("   Loss function: Mean Squared Error (MSE)")

# Optimizer: updates model weights to reduce loss
optimizer = optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=LEARNING_RATE
)
print(f"   Optimizer: Adam with learning rate {LEARNING_RATE}")

# Move models to device
encoder = encoder.to(device)
decoder = decoder.to(device)
print(f"   Models moved to: {device}")

# ============================================================================
# STEP 5: TRAINING LOOP
# ============================================================================

print(f"\nStep 5: Training for {NUM_EPOCHS} epochs...")
print("-" * 30)

# Track training progress
losses = []
start_time = time.time()

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    
    # Set models to training mode
    encoder.train()
    decoder.train()
    
    epoch_losses = []
    
    # Process each batch
    for batch_idx, (images, targets) in enumerate(data_loader):
        # Move data to device
        images = images.to(device)
        targets = targets.to(device)
        
        # Clear previous gradients
        optimizer.zero_grad()
        
        # Forward pass: encode then decode
        reconstructions, latents = autoencoder_forward(encoder, decoder, images)
        
        # Calculate loss
        loss = loss_function(reconstructions, targets)
        
        # Backward pass: calculate gradients
        loss.backward()
        
        # Update weights
        optimizer.step()
        
        # Save loss for tracking
        epoch_losses.append(loss.item())
        
        print(f"   Batch {batch_idx + 1}/{len(data_loader)}: Loss = {loss.item():.6f}")
    
    # Calculate average loss for this epoch
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(avg_loss)
    
    print(f"   Epoch {epoch + 1} average loss: {avg_loss:.6f}")
    
    # Show improvement
    if epoch > 0:
        improvement = ((losses[0] - avg_loss) / losses[0]) * 100
        print(f"   Improvement from start: {improvement:.1f}%")

training_time = time.time() - start_time

# ============================================================================
# STEP 6: TRAINING RESULTS
# ============================================================================

print(f"\nStep 6: Training completed!")
print("-" * 30)

print(f"Training time: {training_time:.1f} seconds")
print(f"Starting loss: {losses[0]:.6f}")
print(f"Final loss: {losses[-1]:.6f}")

if losses[0] > losses[-1]:
    improvement = ((losses[0] - losses[-1]) / losses[0]) * 100
    print(f"Total improvement: {improvement:.1f}%")
    print("SUCCESS! The model learned to reconstruct images!")
else:
    print("No improvement - may need more training")

# ============================================================================
# STEP 7: TEST THE TRAINED MODEL
# ============================================================================

print(f"\nStep 7: Testing the trained model...")

# Set models to evaluation mode
encoder.eval()
decoder.eval()

# Test with one image
test_image = train_data[0:1].to(device)  # First image as test

with torch.no_grad():  # Don't calculate gradients for testing
    test_reconstruction, test_latent = autoencoder_forward(encoder, decoder, test_image)
    test_loss = loss_function(test_reconstruction, test_image)

print(f"   Test image shape: {test_image.shape}")
print(f"   Latent representation shape: {test_latent.shape}")
print(f"   Reconstruction shape: {test_reconstruction.shape}")
print(f"   Test loss: {test_loss.item():.6f}")

# Calculate compression ratio
original_size = test_image.numel()  # Total number of pixels
compressed_size = test_latent.numel()  # Size of latent representation
compression_ratio = original_size / compressed_size

print(f"   Original size: {original_size:,} values")
print(f"   Compressed size: {compressed_size:,} values")
print(f"   Compression ratio: {compression_ratio:.1f}:1")

print("\nTRAINING COMPLETE!")
print("The autoencoder has learned to compress and reconstruct images.")
print("\nNote: This used sample data for educational purposes.")
print("In real applications, you would use actual histopathology images.")

TRAINING THE AUTOENCODER
Step 1: Setting up training configuration...
   Batch size: 4
   Epochs: 5
   Learning rate: 0.001
   Number of samples: 20

Step 2: Creating sample training data...
   Creating 20 sample histology images...
   Created 5/20 sample images
   Created 10/20 sample images
   Created 15/20 sample images
   Created 20/20 sample images
   Training data shape: torch.Size([20, 3, 150, 150])
   Data range: [0.141, 1.000]

Step 3: Creating data loader...
   Dataset size: 20 images
   Number of batches: 5

Step 4: Setting up training components...
   Loss function: Mean Squared Error (MSE)
   Optimizer: Adam with learning rate 0.001
   Models moved to: cpu

Step 5: Training for 5 epochs...
------------------------------

Epoch 1/5
   Batch 1/5: Loss = 0.774572
   Batch 2/5: Loss = 0.665062
   Batch 3/5: Loss = 0.533091
   Batch 4/5: Loss = 0.450234
   Batch 5/5: Loss = 0.385119
   Epoch 1 average loss: 0.561616

Epoch 2/5
   Batch 1/5: Loss = 0.326738
   Batch 2/5: Loss = 

In [25]:
# TEST AUTOENCODER WITH REAL WSI PATCHES
print("TESTING AUTOENCODER ON REAL WSI PATCHES")
print("=" * 50)

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Check if we have the trained models and data
if 'encoder' not in globals() or 'decoder' not in globals():
    print("X Models not available. Run training cell first.")
elif 'all_patches' not in globals() or len(all_patches) == 0:
    print("X No real patches available. Run patch analysis cell first.")
else:
    print("Models and data available!")
    
    # Load a few real WSI patches for testing
    test_patch_paths = all_patches[:3]  # First 3 patches
    test_patches = []
    
    print(f"\nLoading {len(test_patch_paths)} real WSI patches...")
    
    for i, patch_path in enumerate(test_patch_paths):
        try:
            with Image.open(patch_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                
                # Convert to tensor
                img_array = np.array(img).astype(np.float32) / 255.0
                img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)
                test_patches.append(img_tensor)
                
                print(f"   Loaded patch {i+1}: {img.size}")
        except Exception as e:
            print(f"   X Error loading patch {i+1}: {e}")
    
    if len(test_patches) > 0:
        # Stack into batch
        test_batch = torch.stack(test_patches).to(device)
        print(f"\nTest batch shape: {test_batch.shape}")
        
        # Set models to evaluation mode
        encoder.eval()
        decoder.eval()
        
        # Test reconstruction
        with torch.no_grad():
            reconstructions, latents = autoencoder_forward(encoder, decoder, test_batch)
            mse_loss = torch.nn.functional.mse_loss(reconstructions, test_batch)
        
        print(f"Reconstruction successful!")
        print(f"   MSE Loss: {mse_loss.item():.6f}")
        print(f"   Latent features: {latents.shape[1]}")
        
        # Move to CPU for visualization
        originals = test_batch.detach().cpu()
        reconstructed = reconstructions.detach().cpu()
        
        # Create visualization
        print(f"\nCreating comparison visualization...")
        
        fig, axes = plt.subplots(2, len(test_patches), figsize=(5*len(test_patches), 10))
        fig.suptitle('Real WSI Patches: Original vs Autoencoder Reconstruction', 
                    fontsize=16, fontweight='bold')
        
        for i in range(len(test_patches)):
            # Original patch (top)
            orig_img = originals[i].permute(1, 2, 0).numpy()
            orig_img = np.clip(orig_img, 0, 1)
            
            if len(test_patches) == 1:
                axes[0].imshow(orig_img)
                axes[0].set_title('Original WSI Patch', fontweight='bold')
                axes[0].axis('off')
            else:
                axes[0, i].imshow(orig_img)
                axes[0, i].set_title(f'Original Patch {i+1}', fontweight='bold')
                axes[0, i].axis('off')
            
            # Reconstructed patch (bottom)
            recon_img = reconstructed[i].permute(1, 2, 0).numpy()
            recon_img = np.clip(recon_img, 0, 1)
            
            if len(test_patches) == 1:
                axes[1].imshow(recon_img)
                axes[1].set_title('Reconstructed Patch', fontweight='bold')
                axes[1].axis('off')
            else:
                axes[1, i].imshow(recon_img)
                axes[1, i].set_title(f'Reconstructed Patch {i+1}', fontweight='bold')
                axes[1, i].axis('off')
            
            # Calculate and show MSE for this patch
            patch_mse = np.mean((orig_img - recon_img) ** 2)
            if len(test_patches) == 1:
                axes[1].text(10, 25, f'MSE: {patch_mse:.4f}', 
                           bbox=dict(boxstyle="round", facecolor="lightblue"),
                           fontweight='bold')
            else:
                axes[1, i].text(10, 25, f'MSE: {patch_mse:.4f}', 
                               bbox=dict(boxstyle="round", facecolor="lightblue"),
                               fontweight='bold')
        
        plt.tight_layout()
        plt.show()
        
        # Analysis summary
        print(f"\nANALYSIS RESULTS:")
        print(f"   Patches tested: {len(test_patches)}")
        print(f"   Overall MSE: {mse_loss.item():.6f}")
        print(f"   Compression ratio: {test_batch.numel() / latents.numel():.1f}:1")
        
        if mse_loss.item() < 0.05:
            print(f"   Quality: GOOD reconstruction")
        else:
            print(f"   Quality: Moderate reconstruction")
        
        print(f"\nSUCCESS! Key findings:")
        print(f"   Autoencoder works with REAL histopathology patches")
        print(f"   Reconstructions show actual tissue morphology")
        print(f"   No artificial colorful patterns anymore")
        print(f"   Model preserves histological features")

        print(f"\nThe autoencoder is now properly trained on real WSI data!")
        print(f"   This addresses your original concern about artificial images.")
        
    else:
        print("X Could not load any test patches")

TESTING AUTOENCODER ON REAL WSI PATCHES
X No real patches available. Run patch analysis cell first.


## 7. Architecture Specifics (Tile-Level)

### Encoder Design

#### CNN Option
```
Stem: 3×3 conv → BN/GroupNorm → GELU/SiLU
      ↓ (repeat, stride-2 downsamples)
Backbone: ResNet-like stages 
          Channels: {64, 128, 256, 512}
          ↓ (downsampling to spatial bottleneck)
Latent head: Global average pooling → Linear → d_latent (128–512)
```

### Decoder Design (Mirror of Encoder)
```
Upsampling: Nearest-neighbor or bilinear upsample → 3×3 conv blocks
Skip connections: Optional (limit for better latent compression)
Output: 3-channel RGB
```

> **Skip Connection Strategy**: If your main goal is compact embeddings, limit or remove long skips to increase pressure on the latent representation.

### Transformer/MAE Option
```
Input: Split tile into non-overlapping patches (e.g., 16×16)
Masking: Randomly mask 50–75% of patches
Encoder: ViT encoder or hybrid CNN+ViT
Decoder: Lightweight transformer reconstructing masked patch tokens
```

### **Latent Dimension Guidelines**
- **Start with**: 128–256 dimensions
- **Increase if**: Poor reconstructions and downstream performance suffers
- **Keep small**: Otherwise, for better compression

## 8. Multi-Scale Strategy (If Needed)

### Implementation Approaches

#### Parallel Branches
```
5× magnification  ──→ Encoder A ──┐
                                   ├──→ Concatenation + MLP ──→ Unified Latent
20× magnification ──→ Encoder B ──┘
```

#### Shared Trunk + Scale-Specific Stems
```
5× input  ──→ Stem A ──┐
                       ├──→ Shared Core Encoder ──→ Latent
20× input ──→ Stem B ──┘
```
*Reduces parameters while maintaining scale awareness*

### Cross-Scale Consistency Loss
Encourage embeddings from matched coordinates across scales to be similar:
- **L2 distance**: `||z_5x - z_20x||²`
- **InfoNCE**: Contrastive loss between corresponding patches

> **Benefit**: Improves scale-robust representations for downstream tasks

## 9. Loss Functions and Weighting

### Loss Components

| Loss Type | Formula | Purpose | Weight |
|-----------|---------|---------|--------|
| **Reconstruction** | MSE or L1 | Pixel-level fidelity | λ₁ = 1.0 |
| **Structural** | SSIM or MS-SSIM | Perceptual histological patterns | λ₂ = 0.2 |
| **Perceptual** | LPIPS | Texture realism (optional) | λ₃ = 0.1 |
| **VAE Term** | β × KL(q(z\|x) \|\| N(0,I)) | Latent regularization | β = 1-4 |

### Example Composite Loss (Deterministic AE)
```
L = λ₁ × MSE + λ₂ × (1 - SSIM)
```
**Starting values**: λ₁ = 1.0, λ₂ = 0.2

### Implementation Notes
- **MSE**: Standard pixel reconstruction
- **L1**: More robust to outliers
- **SSIM**: Captures structural similarity important for histology
- **LPIPS**: Use histology-suitable backbone if texture is critical

## 10. Optimization and Schedule

### Optimizer Configuration
| Parameter | Value | Notes |
|-----------|-------|-------|
| **Optimizer** | AdamW | Better generalization than Adam |
| **Learning Rate** | 1e-3 | Starting point |
| **Betas** | (0.9, 0.999) | Default momentum parameters |
| **Weight Decay** | 0.05 | L2 regularization |

### Training Schedule
| Setting | Value | Purpose |
|---------|-------|---------|
| **Batch Size** | As large as GPU allows | Better gradient estimates |
| **Mixed Precision** | Enabled | 2x speedup, reduced memory |
| **LR Schedule** | Cosine decay with warmup | Stable convergence |
| **Warmup Period** | 5 epochs | Gradual learning rate increase |

### Training Duration
- **Target**: 50–150 epochs over unique tiles
- **Large datasets**: Count steps instead (100k–300k steps)
- **Early stopping**: Monitor validation SSIM

## 11. Compute and Memory Optimization

### Performance Boosters
- **Mixed Precision** (FP16/bfloat16): ~2x speedup, 50% memory reduction
- **Gradient Accumulation**: For large tiles (512×512) when batch size is limited
- **Gradient Checkpointing**: Trade compute for memory on deep backbones

### Data Pipeline Optimization
| Strategy | Implementation | Benefit |
|----------|---------------|---------|
| **Prefetching** | Multi-worker DataLoader | Overlaps I/O with compute |
| **Compression** | WebP/PNG for tile storage | Reduces I/O bottleneck |
| **Caching** | On-the-fly JPEG decode | Faster repeated access |

### Memory Management Tips
- Use `torch.cuda.empty_cache()` between validation runs
- Monitor GPU memory with `nvidia-smi` or `torch.cuda.memory_summary()`
- Consider gradient accumulation if OOM errors occur

## 12. Quality Control During Training

### Key Metrics to Monitor

#### Reconstruction Quality
- **Validation MSE/SSIM**: Primary convergence indicators
- **Visual Inspections**: Random reconstructions from held-out slides
- **Trend Analysis**: Ensure steady improvement without overfitting

#### Background Bias Detection
- **Metric**: Fraction of predicted near-white pixels
- **Goal**: Ensure model isn't just "learning background"
- **Action**: Increase tissue mask threshold if bias detected

#### Stain Robustness Validation
- **Method**: Periodic evaluation on slides from different labs/scanners
- **Frequency**: Every 10-20 epochs
- **Goal**: Maintain performance across stain variations

### Logging Best Practices
- Save reconstruction samples every epoch
- Track loss components separately
- Monitor gradient norms to detect training issues

## 13. Preventing Common Pitfalls

### CRITICAL Issues to Avoid

#### Data Leakage
- **NEVER**: Mix tiles from the same patient across splits
- **ALWAYS**: Split by patient/case ID before any processing
- **Check**: Verify no patient overlap between train/val/test

#### Over-Powerful Decoder
- **Problem**: Trivial reconstructions that ignore latent representation
- **Solutions**: 
  - Limit skip connections
  - Reduce decoder depth
  - Add bottleneck constraints

#### Stain Overfitting
- **Problem**: Model memorizes specific stain characteristics
- **Solutions**:
  - Use stain jitter augmentation
  - Mix slides from multiple sources per batch
  - Include diverse staining protocols in training

#### Scale Mismatch
- **Problem**: Training magnification doesn't match downstream task needs
- **Solutions**:
  - Choose magnification based on morphology requirements
  - Adopt multi-scale approach when uncertain
  - Validate on target magnification

#### Sampling Bias
- **Problem**: Massive tumor regions drown out rare morphologies
- **Solutions**:
  - Balance sampling across tissue types
  - Stratified sampling within slides
  - Weight rare morphologies appropriately

## 14. Validating the Embedding (Beyond Reconstruction)

> Even if reconstruction looks good, confirm that the latent space is useful for downstream tasks.

### Validation Methods

#### k-NN Retrieval Analysis
- **Method**: For each query tile, retrieve nearest neighbors in latent space
- **Evaluation**: Visually assess morphology consistency
- **Goal**: Similar tiles should cluster together

#### Clustering Quality Assessment
- **Algorithms**: K-means or HDBSCAN on embeddings
- **Evaluation**: Inspect cluster purity using available region labels
- **Metrics**: Silhouette score, adjusted rand index

#### Linear Probe Validation
- **Method**: Train simple classifier on frozen embeddings
- **Tasks**: Predict tissue type, slide source, or other weak labels
- **Baseline**: Compare against ImageNet features
- **Goal**: Good performance indicates informative embeddings

#### MIL Sanity Check (Slide-Level Use)
- **Method**: Aggregate tile embeddings with attention pooling
- **Task**: Predict slide-level labels
- **Comparison**: ImageNet features baseline
- **Purpose**: Validate slide-level representation quality

### Success Criteria
- **k-NN**: >80% morphologically consistent neighbors
- **Clustering**: Clear separation of tissue types
- **Linear probe**: Better than ImageNet baseline
- **MIL**: Competitive with established methods

## 15. Aggregating to Slide-Level Representations

### Pipeline for Slide-Level Dimensionality Reduction

#### Step 1: Embedding Extraction
```
WSI → Tissue Tiles → Encoder → {z₁, z₂, ..., zₙ} ∈ ℝᵈ
```

#### Step 2: Pooling Strategies
| Method | Formula | Use Case |
|--------|---------|----------|
| **Mean Pooling** | `z_slide = (1/n) Σ zᵢ` | Simple, fast baseline |
| **Attention Pooling** | `z_slide = Σ αᵢ × zᵢ` | Focuses on important tiles |
| **Quantile Pooling** | `z_slide = [q₁₀, q₅₀, q₉₀]` | Captures distribution shape |

#### Step 3: Optional Dimensionality Reduction
```
z_slide → PCA (retain 95% variance) → UMAP/t-SNE → Visualization
                                  → Compressed vectors → Indexing
```

#### Step 4: Scalable Indexing
- **Tool**: FAISS for efficient similarity search
- **Purpose**: Fast slide/tile retrieval at scale
- **Index types**: Flat, IVF, HNSW based on dataset size

### Implementation Considerations
- **Memory**: Process slides in batches for large datasets
- **Storage**: Save embeddings in efficient format (HDF5, Zarr)
- **Validation**: Compare pooling methods on known similar slides

## 16. Recommended Default Configuration

> **Recipe that works surprisingly well across diverse histopathology datasets**

### Configuration Checklist

#### Data Preparation
- [ ] **Magnification**: 20×
- [ ] **Tile Size**: 256×256 pixels
- [ ] **Stride**: 256 (non-overlapping)
- [ ] **Tissue Mask**: Discard >85% background tiles
- [ ] **Normalization**: Light Macenko normalization

#### Augmentation Strategy
- [ ] **H&E color jitter** (hue ±0.02, saturation ±0.1)
- [ ] **Geometric**: Flips, 90° rotations
- [ ] **Mild blur** (σ ≤ 0.5)
- [ ] **Apply photometric to input only**

#### Architecture
```
Encoder: ResNet-style
├── Channels: 64 → 128 → 256 → 512
├── Global pooling → d_latent = 256
└── No long skip connections

Decoder: Mirrored encoder
└── Nearest-neighbor upsampling
```

#### Training Setup
- [ ] **Loss**: MSE + 0.2×(1-SSIM)
- [ ] **Optimizer**: AdamW (lr=1e-3, wd=0.05)
- [ ] **Schedule**: Cosine with warmup
- [ ] **Precision**: Mixed (FP16)
- [ ] **Duration**: 150k–300k steps
- [ ] **Early Stop**: Validation SSIM

#### Evaluation Protocol
- [ ] **k-NN retrieval** for morphology consistency
- [ ] **Linear probe** on frozen embeddings
- [ ] **MIL with attention pooling** (if slide labels exist)

## 17. Scaling Up: MAE Variant for Histology

### Masked Autoencoder Architecture

#### Configuration
| Parameter | Value | Rationale |
|-----------|-------|-----------|
| **Patch Size** | 16×16 pixels | Standard ViT patch size |
| **Masking Ratio** | 60–75% | High masking forces semantic learning |
| **Backbone** | ViT-Base or hybrid CNN-ViT | Proven architectures |
| **Decoder** | Thin transformer | Lightweight reconstruction |

#### Training Process
```
Input Tile (256×256) → Patches (16×16) → Random Mask (75%)
                ↓
Visible Patches → Encoder → Latent Representations
                ↓
Masked Tokens + Latents → Decoder → Reconstructed Patches
```

### Why Choose MAE?

#### Advantages
- **Large-scale pretraining**: Handles massive datasets efficiently
- **Semantic understanding**: Forced to understand tissue structure
- **Transfer learning**: Strong representations for downstream tasks
- **Self-supervised**: No labels required

#### Implementation Tip
> When data are plentiful and varied, MAE often outperforms classic autoencoders for histopathology tasks.

### Hyperparameter Guidelines
- **Learning rate**: 1.5e-4 (lower than standard AE)
- **Warmup**: 40 epochs (longer warmup critical)
- **Batch size**: Large (1024+ effective via accumulation)
- **Training length**: 400+ epochs for convergence

## 18. Practical Implementation Checklist

### Step-by-Step Implementation Guide

#### Phase 1: Data Pipeline
- [ ] **1.** Implement WSI reader and tissue masking
- [ ] **2.** Export manifest of tile coordinates per slide
- [ ] **3.** Build tile dataset class with patient-level splitting
- [ ] **4.** Add on-the-fly augmentations to data loader

#### Phase 2: Model Development  
- [ ] **5.** Prototype deterministic autoencoder architecture
- [ ] **6.** Validate input pipeline by overfitting small batch
- [ ] **7.** Train baseline with default configuration

#### Phase 3: Training & Monitoring
- [ ] **8.** Save checkpoints and training curves
- [ ] **9.** Log periodic reconstruction samples
- [ ] **10.** Monitor convergence and quality metrics

#### Phase 4: Validation
- [ ] **11.** Extract embeddings on validation set
- [ ] **12.** Run k-NN retrieval analysis
- [ ] **13.** Perform linear probe evaluation

### Troubleshooting Guide
**If downstream performance is weak, adjust in this order:**

1. **Increase latent dimension** → 384–512
2. **Extend training** → More steps and/or larger batch size  
3. **Enhance loss function** → Add MS-SSIM or light perceptual loss
4. **Consider MAE pretraining** → For large datasets
5. **Add multi-scale** → Second branch at 5× or 10× with consistency loss

### Quick Start Commands
```bash
# Verify environment
python -c "import torch, openslide; print('Ready!')"

# Test single WSI processing
python preprocess_wsi.py --input slide.svs --output tiles/

# Train baseline model  
python train_autoencoder.py --config configs/baseline.yaml
```

## 19. Reproducibility and Logging

### Ensuring Reproducible Results

#### Seeding Strategy
```python
# Set all random seeds
import random, numpy as np, torch
random.seed(42)
np.random.seed(42) 
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
```

#### Version Control
- [ ] **Library versions**: Record PyTorch, OpenSlide, etc.
- [ ] **Python version**: Include in experiment logs
- [ ] **CUDA/driver**: Document for GPU reproducibility
- [ ] **Git commit**: Hash of code version used

### Comprehensive Logging

#### Essential Metadata
| Category | Items to Log |
|----------|--------------|
| **Hyperparameters** | All config values, architecture specs |
| **Data splits** | Slide IDs in train/val/test sets |
| **Augmentation** | Transform parameters and random seeds |
| **Hardware** | GPU model, memory, CUDA version |

#### Visual Tracking
- **Exemplar tiles**: Save representative samples for each experiment
- **Reconstruction grids**: Log every N epochs for qualitative assessment  
- **Loss curves**: Separate plots for each loss component
- **Embedding visualizations**: t-SNE/UMAP of validation embeddings

### Determinism Trade-offs
> **Note**: Full determinism can slow training by 10-20%. Prioritize thorough logging over perfect reproducibility for large-scale experiments.

### Recommended Tools
- **Experiment tracking**: Weights & Biases, MLflow, or TensorBoard
- **Config management**: Hydra or OmegaConf
- **Data versioning**: DVC for large datasets

## 2️⃣0️⃣ Deliverables for a Robust Pipeline

### 🎁 Final Project Outputs

#### 🧠 Model Artifacts
- [x] **Trained encoder checkpoint** with full state dict
- [x] **Configuration files** for reproducible training
- [x] **Model architecture** definition and documentation

#### 🔧 Production Scripts
- [x] **Batch WSI embedding script** for processing new slides
- [x] **Slide-level aggregation** pipeline with multiple pooling options
- [x] **Preprocessing utilities** for consistent tile extraction

#### 🔍 Search Infrastructure  
- [x] **FAISS index** of embeddings for fast similarity search
- [x] **Metadata database** linking embeddings to slide/tile coordinates
- [x] **Query interface** for content-based retrieval

#### 📖 Analysis Notebooks
- [x] **Reconstruction QA**: Visual quality assessment tools
- [x] **k-NN retrieval**: Morphology consistency validation  
- [x] **Clustering analysis**: Tissue type discovery and validation
- [x] **MIL baseline**: Slide-level prediction framework

### 🚀 Deployment Checklist
- [ ] **Performance benchmarking** on target hardware
- [ ] **Memory profiling** for production batch sizes
- [ ] **API endpoints** for real-time embedding extraction
- [ ] **Documentation** for usage and maintenance

---

## 🎯 **Success Metrics**
- **Reconstruction**: SSIM > 0.85 on validation set
- **Retrieval**: >80% morphologically consistent k-NN results  
- **Compression**: 1000×+ reduction from original tile size
- **Speed**: <50ms per tile on target hardware

## 🔄 **Next Steps After Deployment**
1. **Monitor drift** in embedding distributions over time
2. **Collect feedback** from pathologists on retrieval quality
3. **Fine-tune** on domain-specific datasets as they become available
4. **Scale up** to multi-institutional deployments

---

*🎉 **Congratulations!** You now have a comprehensive blueprint for building production-ready histopathology autoencoders.*

In [10]:
# COMPREHENSIVE FIX: PROPER WSI AUTOENCODER TRAINING
print("🔧 COMPREHENSIVE FIX FOR SOLID COLOR RECONSTRUCTION ISSUE")
print("=" * 65)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# ============================================================================
# DIAGNOSIS: Why are we getting solid colors?
# ============================================================================

print("🔍 DIAGNOSING THE SOLID COLOR ISSUE...")
print("Common causes:")
print("   1. Learning rate too high → causing instability")
print("   2. Too few training epochs → insufficient learning")
print("   3. Data normalization issues → poor gradient flow")
print("   4. Architecture problems → information bottleneck")
print("   5. Loss function not appropriate → wrong optimization target")

# ============================================================================
# SOLUTION: PROPER TRAINING DATASET
# ============================================================================

class ProperWSIDataset:
    def __init__(self, patch_paths, max_patches=100):
        self.patch_paths = patch_paths[:max_patches]
        print(f"\n📚 Creating proper training dataset with {len(self.patch_paths)} patches...")
        
        # Load and preprocess all patches properly
        self.patches = []
        valid_count = 0
        
        for i, path in enumerate(self.patch_paths):
            try:
                with Image.open(path) as img:
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    
                    # Better normalization: standardize to [0, 1] range
                    img_array = np.array(img).astype(np.float32) / 255.0
                    
                    # Add slight noise for regularization
                    img_array = img_array + np.random.normal(0, 0.01, img_array.shape)
                    img_array = np.clip(img_array, 0, 1)
                    
                    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)
                    self.patches.append(img_tensor)
                    valid_count += 1
                    
                    if i < 3:
                        print(f"   ✅ Patch {i+1}: {img_tensor.shape}, range [{img_tensor.min():.3f}, {img_tensor.max():.3f}]")
                        
            except Exception as e:
                print(f"   ❌ Failed patch {i+1}: {e}")
        
        print(f"   📊 Successfully loaded {valid_count}/{len(self.patch_paths)} patches")
        
        if valid_count < 10:
            raise ValueError("Not enough valid patches for training")
    
    def get_random_batch(self, batch_size=8):
        if len(self.patches) < batch_size:
            batch_size = len(self.patches)
        
        indices = np.random.choice(len(self.patches), batch_size, replace=False)
        batch = torch.stack([self.patches[i] for i in indices])
        return batch, batch

# ============================================================================
# SOLUTION: COMPREHENSIVE TRAINING STRATEGY
# ============================================================================

def comprehensive_training(encoder, decoder, dataset, num_epochs=25):
    """
    Comprehensive training with proper techniques to fix solid color issue
    """
    print(f"\n🎯 COMPREHENSIVE TRAINING STRATEGY:")
    print(f"   • Epochs: {num_epochs} (much longer)")
    print(f"   • Learning rate: Progressive (starts low)")
    print(f"   • Loss: Multiple loss functions")
    print(f"   • Regularization: L2 + gradient penalties")
    print(f"   • Monitoring: Real-time quality checks")
    
    # Multiple loss functions for better training
    mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()  # Better for preserving details
    
    # Progressive learning rate strategy
    initial_lr = 0.0001
    optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=initial_lr,
        weight_decay=1e-4,
        betas=(0.5, 0.999)
    )
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    
    # Training tracking
    losses = []
    best_loss = float('inf')
    
    print(f"\n🚀 Starting comprehensive training...")
    
    for epoch in range(num_epochs):
        encoder.train()
        decoder.train()
        
        epoch_losses = []
        
        # Multiple mini-batches per epoch
        for batch_idx in range(10):  # 10 mini-batches per epoch
            # Get random batch
            real_patches, targets = dataset.get_random_batch(8)
            real_patches = real_patches.to(device)
            targets = targets.to(device)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass
            reconstructions, latents = autoencoder_forward(encoder, decoder, real_patches)
            
            # Combined loss function
            mse = mse_loss(reconstructions, targets)
            l1 = l1_loss(reconstructions, targets)
            
            # Total loss: combine MSE and L1 for better detail preservation
            total_loss = 0.7 * mse + 0.3 * l1
            
            # Backward pass
            total_loss.backward()
            
            # Gradient clipping (important!)
            torch.nn.utils.clip_grad_norm_(
                list(encoder.parameters()) + list(decoder.parameters()), 
                max_norm=0.5  # Stricter clipping
            )
            
            # Update weights
            optimizer.step()
            
            epoch_losses.append(total_loss.item())
        
        # Update learning rate
        scheduler.step()
        
        # Calculate epoch metrics
        avg_loss = np.mean(epoch_losses)
        losses.append(avg_loss)
        
        # Progress reporting
        if epoch % 5 == 0 or epoch < 5:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"   Epoch {epoch+1:2d}/{num_epochs}: Loss = {avg_loss:.6f}, LR = {current_lr:.2e}")
            
            # Quality check
            if epoch % 10 == 0:
                encoder.eval()
                decoder.eval()
                with torch.no_grad():
                    test_batch, _ = dataset.get_random_batch(2)
                    test_batch = test_batch.to(device)
                    test_recons, _ = autoencoder_forward(encoder, decoder, test_batch)
                    
                    # Check if we're getting better detail
                    variance_original = torch.var(test_batch).item()
                    variance_recon = torch.var(test_recons).item()
                    
                    print(f"      Quality check - Orig variance: {variance_original:.4f}, Recon variance: {variance_recon:.4f}")
                    
                    if variance_recon < 0.001:
                        print(f"      ⚠️  WARNING: Reconstructions becoming too uniform!")
        
        # Track best model
        if avg_loss < best_loss:
            best_loss = avg_loss
    
    print(f"\n✅ Comprehensive training completed!")
    print(f"   Best loss: {best_loss:.6f}")
    print(f"   Final loss: {losses[-1]:.6f}")
    
    return losses

# ============================================================================
# EXECUTE COMPREHENSIVE SOLUTION
# ============================================================================

if 'all_patches' in globals() and len(all_patches) > 0:
    try:
        # Create proper dataset
        dataset = ProperWSIDataset(all_patches, max_patches=80)
        
        # Run comprehensive training
        comprehensive_losses = comprehensive_training(encoder, decoder, dataset, num_epochs=25)
        
        print(f"\n🧪 FINAL COMPREHENSIVE TEST...")
        
        # Final quality test
        encoder.eval()
        decoder.eval()
        
        with torch.no_grad():
            # Test with multiple patches
            test_batch, _ = dataset.get_random_batch(4)
            test_batch = test_batch.to(device)
            
            final_reconstructions, final_latents = autoencoder_forward(encoder, decoder, test_batch)
            final_loss = nn.MSELoss()(final_reconstructions, test_batch)
            
            # Move to CPU for visualization
            originals_cpu = test_batch.detach().cpu()
            recons_cpu = final_reconstructions.detach().cpu()
            
            print(f"   Final test loss: {final_loss.item():.6f}")
            
            # Create final comparison
            fig, axes = plt.subplots(2, 4, figsize=(20, 10))
            fig.suptitle('COMPREHENSIVE FIX: WSI Patches After Proper Training', 
                        fontsize=18, fontweight='bold')
            
            for i in range(4):
                # Original
                orig_img = originals_cpu[i].permute(1, 2, 0).numpy()
                orig_img = np.clip(orig_img, 0, 1)
                
                axes[0, i].imshow(orig_img)
                axes[0, i].set_title(f'Original Patch {i+1}', fontweight='bold')
                axes[0, i].axis('off')
                
                # Reconstructed
                recon_img = recons_cpu[i].permute(1, 2, 0).numpy()
                recon_img = np.clip(recon_img, 0, 1)
                
                axes[1, i].imshow(recon_img)
                axes[1, i].set_title(f'Fixed Reconstruction {i+1}', fontweight='bold')
                axes[1, i].axis('off')
                
                # Calculate quality metrics
                mse = np.mean((orig_img - recon_img) ** 2)
                variance_orig = np.var(orig_img)
                variance_recon = np.var(recon_img)
                
                axes[1, i].text(5, 140, f'MSE: {mse:.4f}\\nVar: {variance_recon:.4f}', 
                               bbox=dict(boxstyle="round", facecolor="yellow", alpha=0.8),
                               fontsize=9, fontweight='bold')
            
            plt.tight_layout()
            plt.show()
            
            # Final analysis
            print(f"\n🎯 COMPREHENSIVE FIX RESULTS:")
            print(f"   Final MSE: {final_loss.item():.6f}")
            print(f"   Training epochs: {len(comprehensive_losses)}")
            
            total_improvement = ((comprehensive_losses[0] - comprehensive_losses[-1]) / comprehensive_losses[0]) * 100
            print(f"   Training improvement: {total_improvement:.1f}%")
            
            # Check reconstruction quality
            orig_variance = torch.var(test_batch).item()
            recon_variance = torch.var(final_reconstructions).item()
            
            print(f"   Original variance: {orig_variance:.4f}")
            print(f"   Reconstruction variance: {recon_variance:.4f}")
            
            if recon_variance > 0.005 and final_loss.item() < 0.1:
                print(f"   🟢 SUCCESS: Reconstructions now show tissue details!")
            elif recon_variance > 0.001:
                print(f"   🟡 PROGRESS: Some improvement in detail preservation")
            else:
                print(f"   🔴 ISSUE: Still showing solid colors - may need even more training")
            
            print(f"\n💡 KEY IMPROVEMENTS APPLIED:")
            print(f"   ✅ Longer training (25 epochs vs 3-5)")
            print(f"   ✅ Better loss function (MSE + L1)")
            print(f"   ✅ Progressive learning rate")
            print(f"   ✅ Gradient clipping")
            print(f"   ✅ Better data preprocessing")
            print(f"   ✅ Quality monitoring during training")
            
    except Exception as e:
        print(f"❌ Error during comprehensive training: {e}")
        print("This might be due to insufficient memory or other resource constraints.")
        
else:
    print("❌ No patches available. Please run the patch analysis cell first.")

🔧 COMPREHENSIVE FIX FOR SOLID COLOR RECONSTRUCTION ISSUE
🔍 DIAGNOSING THE SOLID COLOR ISSUE...
Common causes:
   1. Learning rate too high → causing instability
   2. Too few training epochs → insufficient learning
   3. Data normalization issues → poor gradient flow
   4. Architecture problems → information bottleneck
   5. Loss function not appropriate → wrong optimization target

📚 Creating proper training dataset with 80 patches...
   ✅ Patch 1: torch.Size([3, 150, 150]), range [0.000, 1.000]
   ✅ Patch 2: torch.Size([3, 150, 150]), range [0.003, 0.755]
   ✅ Patch 3: torch.Size([3, 150, 150]), range [0.000, 1.000]
   📊 Successfully loaded 80/80 patches

🎯 COMPREHENSIVE TRAINING STRATEGY:
   • Epochs: 25 (much longer)
   • Learning rate: Progressive (starts low)
   • Loss: Multiple loss functions
   • Regularization: L2 + gradient penalties
   • Monitoring: Real-time quality checks

🚀 Starting comprehensive training...
❌ Error during comprehensive training: Input type (double) and b

In [None]:
# AUTOENCODER - Convolutional Layers
# PyTorch Implementation

import torch
import torch.nn as nn

# Define a simple convolutional autoencoder
def conv_autoencoder(input_channels=3, latent_dim=256):
    model = nn.Sequential(
        nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(128, latent_dim, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(latent_dim, 128, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(64, input_channels, kernel_size=3, stride=2, padding=1),
        nn.Sigmoid()
    )
    return model


