# 00 – InstanSeg Model Training

This notebook allows you to:
1. **Prepare training datasets** from annotated images
2. **Fine-tune existing InstanSeg models** on custom data
3. **Train new InstanSeg models** from scratch (not yet implemented — the training cell raises `NotImplementedError` when `base_model` is `None`)
4. **Validate and export trained models** for use in other notebooks

**Use Cases**:
- Custom pancreatic tissue that doesn't segment well with pretrained models
- Specialized staining protocols (custom H&E, immunofluorescence panels)
- Domain adaptation for tissue-specific cell morphologies

**Requirements**:
- Annotated training images (masks with instance labels)
- InstanSeg installed: `pip install instanseg-torch[full]`
- GPU recommended for training

In [None]:
import os
import warnings
from pathlib import Path

warnings.filterwarnings('ignore')

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from skimage.color import label2rgb
from tqdm import tqdm

# Select the compute device: use CUDA when PyTorch reports one, else the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

# On GPU, also report the name and total memory (in GB) of device 0
if device == 'cuda':
    print(f'  GPU: {torch.cuda.get_device_name(0)}')
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f'  Memory: {_total_bytes / 1e9:.1f} GB')

# Import InstanSeg training utilities
# The import is guarded so that a missing package prints install instructions
# instead of a traceback. NOTE(review): on failure execution continues and
# `InstanSeg` stays undefined, so the training cell will later fail with a
# NameError when it calls `InstanSeg(base_model)`.
# NOTE(review): `Augmenter` is not referenced by any cell visible in this
# notebook - confirm whether it is still needed.
try:
    from instanseg import InstanSeg
    from instanseg.utils.augmentation import Augmenter
    print('InstanSeg imported successfully')
except ImportError as e:
    print(f'Error importing InstanSeg: {e}')
    print('Install with: pip install instanseg-torch[full]')

## Configuration

In [None]:
# ===== Dataset Configuration =====
# Root of the training data; images and instance masks live in sibling folders
dataset_dir = Path('../data/training')
images_dir = dataset_dir.joinpath('images')  # RGB images
masks_dir = dataset_dir.joinpath('masks')    # Instance label masks

# Ensure both data folders are present (no-op when they already exist)
for _data_dir in (images_dir, masks_dir):
    _data_dir.mkdir(parents=True, exist_ok=True)

# ===== Training Configuration =====
# Model to fine-tune, or None to train from scratch.
# Options: 'brightfield_nuclei', 'fluorescence_nuclei_and_cells', or None
base_model = 'brightfield_nuclei'

# Optimisation hyperparameters
epochs = 50
batch_size = 8
learning_rate = 1e-4
patch_size = 512  # Size of training patches
val_split = 0.2   # Fraction of data for validation

# Data augmentation switches and strengths
use_augmentation = True
augmentation_params = dict(
    rotation=True,
    flip=True,
    scale=(0.8, 1.2),
    elastic_deformation=True,
    brightness=0.2,
    contrast=0.2,
)

# ===== Output Configuration =====
# Folder that receives checkpoints and exported models
output_dir = Path('../models/instanseg_custom')
output_dir.mkdir(parents=True, exist_ok=True)

model_name = 'pancreas_brightfield_v1'
checkpoint_interval = 10  # Save checkpoint every N epochs

# Echo the key locations so the reader can confirm them at a glance
print(
    'Configuration set',
    f'Dataset: {dataset_dir}',
    f'Output: {output_dir / model_name}',
    sep='\n',
)

## Dataset Preparation

**Expected Format**:
- Images: RGB or grayscale images (PNG, TIFF, or JPG)
- Masks: 16-bit or 32-bit integer arrays where each unique value represents a cell instance, saved in the masks directory as `<image_stem>_mask.tif` or `<image_stem>.tif`
  - Background = 0
  - Cell 1 = 1, Cell 2 = 2, etc.

**Annotation Tools**:
- **QuPath**: Export instance segmentation masks
- **Cellpose**: Use GUI to annotate and export masks
- **CVAT/Label Studio**: Polygon annotations → convert to instance masks

In [None]:
def load_dataset(images_dir, masks_dir):
    """Pair every training image with its instance-label mask.

    Images are discovered in ``images_dir`` by extension: PNG, TIFF
    (``.tif``/``.tiff``) and JPEG (``.jpg``/``.jpeg``), i.e. the formats the
    notebook's dataset notes list as supported. For an image with stem
    ``name`` the mask is looked up in ``masks_dir`` in this order:
    ``name_mask.tif``, ``name.tif``, ``name_mask.tiff``, ``name.tiff``.

    Args:
        images_dir: Directory containing the images (``str`` or ``Path``).
        masks_dir: Directory containing the masks (``str`` or ``Path``).

    Returns:
        A list of ``{'image': Path, 'mask': Path}`` dicts, one per image that
        has a mask. Images are ordered by extension group (png, tif, tiff,
        jpg, jpeg) and alphabetically within each group. Images without a
        mask are reported with a printed warning and skipped.
    """
    images_dir = Path(images_dir)
    masks_dir = Path(masks_dir)

    # png/tif/tiff keep their original relative order; JPEG is appended so
    # existing datasets enumerate exactly as before.
    image_patterns = ('*.png', '*.tif', '*.tiff', '*.jpg', '*.jpeg')
    mask_extensions = ('.tif', '.tiff')

    image_files = []
    for pattern in image_patterns:
        image_files.extend(sorted(images_dir.glob(pattern)))

    dataset = []
    for img_path in image_files:
        # The "<stem>_mask" name takes precedence over the bare stem
        candidates = (
            masks_dir / f'{img_path.stem}{suffix}{ext}'
            for ext in mask_extensions
            for suffix in ('_mask', '')
        )
        mask_path = next((path for path in candidates if path.exists()), None)
        if mask_path is None:
            print(f'Warning: No mask found for {img_path.name}')
            continue

        dataset.append({'image': img_path, 'mask': mask_path})

    return dataset

def visualize_sample(dataset, idx=0):
    """Show one dataset entry as image, binary mask and coloured instances.

    Args:
        dataset: List of ``{'image': path, 'mask': path}`` dicts as returned
            by ``load_dataset``.
        idx: Index of the entry to display.

    Raises:
        OSError: If the image or the mask file cannot be decoded.
    """
    sample = dataset[idx]
    img = cv2.imread(str(sample['image']))
    mask = cv2.imread(str(sample['mask']), cv2.IMREAD_UNCHANGED)

    # cv2.imread reports failure by returning None instead of raising, which
    # would otherwise surface as an opaque error inside cvtColor/imshow.
    if img is None:
        raise OSError(f"Could not read image: {sample['image']}")
    if mask is None:
        raise OSError(f"Could not read mask: {sample['mask']}")

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # OpenCV loads colour images as BGR; matplotlib expects RGB
    axes[0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    axes[0].set_title('Image', fontsize=14, fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(mask > 0, cmap='gray')
    axes[1].set_title('Binary Mask', fontsize=14, fontweight='bold')
    axes[1].axis('off')

    colored_mask = label2rgb(mask, bg_label=0)
    axes[2].imshow(colored_mask)
    # Count distinct non-zero labels; subtracting 1 from the number of unique
    # values would under-count when the mask contains no background pixel.
    n_cells = int(np.count_nonzero(np.unique(mask)))
    axes[2].set_title(f'Instance Labels ({n_cells} cells)', fontsize=14, fontweight='bold')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

# Load dataset
dataset = load_dataset(images_dir, masks_dir)
print(f'Loaded {len(dataset)} image-mask pairs')

if len(dataset) > 0:
    print('\nDataset statistics:')
    total_cells = 0
    for sample in dataset:
        mask = cv2.imread(str(sample['mask']), cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None for unreadable files; report it rather than
        # silently counting the mask as containing zero cells.
        if mask is None:
            print(f"  Warning: could not read mask {sample['mask'].name}")
            continue
        # Distinct non-zero labels = cell instances (0 is background). This is
        # also correct for masks that contain no background pixel.
        n_cells = int(np.count_nonzero(np.unique(mask)))
        total_cells += n_cells
    print(f'  Total cells: {total_cells:,}')
    print(f'  Avg cells per image: {total_cells / len(dataset):.0f}')

    # Visualize first sample
    print('\nVisualizing first sample:')
    visualize_sample(dataset, idx=0)
else:
    print('\n⚠ No training data found!')
    print(f'Add annotated images to: {images_dir}')
    print(f'Add corresponding masks to: {masks_dir}')

## Data Augmentation

Augmentation improves model generalization by creating variations of training samples.

In [None]:
def create_augmenter(params):
    """Translate the augmentation settings dict into an ordered transform list.

    Each enabled (truthy) setting contributes one entry, always in the order
    rotation, flip, scale, elastic, brightness, contrast. On/off settings
    become a bare name (``'rotation'``, ``'flip'``, ``'elastic'``); settings
    that carry a value become a ``(name, value)`` tuple.

    Args:
        params: Mapping of setting name to flag or value. Missing or falsy
            settings are skipped.

    Returns:
        List of transform descriptors (strings and ``(name, value)`` tuples).
    """
    # (settings key, emitted transform name, whether the value is attached)
    spec = (
        ('rotation', 'rotation', False),
        ('flip', 'flip', False),
        ('scale', 'scale', True),
        ('elastic_deformation', 'elastic', False),
        ('brightness', 'brightness', True),
        ('contrast', 'contrast', True),
    )

    pipeline = []
    for key, tag, carries_value in spec:
        if not params.get(key):
            continue
        pipeline.append((tag, params[key]) if carries_value else tag)
    return pipeline

def augment_sample(img, mask, transforms):
    """Randomly rotate and/or flip an image together with its label mask.

    Only the ``'rotation'`` and ``'flip'`` entries of ``transforms`` are acted
    on; any other entries are ignored by this function. Each enabled
    operation fires when a fresh ``np.random.rand()`` draw exceeds 0.5, and
    the same geometric change is applied to both arrays so they stay aligned.

    Args:
        img: Image array; its first two dimensions are read as height, width.
        mask: Label mask matching ``img``.
        transforms: Container of transform descriptors, as produced by
            ``create_augmenter``.

    Returns:
        ``(img, mask)`` - the transformed pair, or the inputs themselves when
        nothing was applied.
    """
    # Rotation about the image centre by a random angle in [-180, 180)
    if 'rotation' in transforms and np.random.rand() > 0.5:
        angle = np.random.uniform(-180, 180)
        height, width = img.shape[:2]
        centre = (width / 2, height / 2)
        out_size = (width, height)
        rotation = cv2.getRotationMatrix2D(centre, angle, 1.0)
        img = cv2.warpAffine(img, rotation, out_size)
        # Nearest-neighbour sampling so label ids are never blended
        mask = cv2.warpAffine(mask, rotation, out_size, flags=cv2.INTER_NEAREST)

    if 'flip' not in transforms:
        return img, mask

    # flip code 0 = vertical, 1 = horizontal; each is decided independently
    for flip_code in (0, 1):
        if np.random.rand() > 0.5:
            img = cv2.flip(img, flip_code)
            mask = cv2.flip(mask, flip_code)

    return img, mask

# Test augmentation on first sample
if len(dataset) > 0 and use_augmentation:
    sample = dataset[0]
    img = cv2.imread(str(sample['image']))
    mask = cv2.imread(str(sample['mask']), cv2.IMREAD_UNCHANGED)

    # cv2.imread returns None for unreadable files; skip the preview instead
    # of failing on `None.copy()` below.
    if img is None or mask is None:
        print(f"⚠ Could not read {sample['image']} or {sample['mask']}; skipping augmentation preview")
    else:
        transforms = create_augmenter(augmentation_params)

        # Top row: augmented images; bottom row: the matching label masks
        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        for i in range(4):
            # Copies keep the loaded originals untouched between draws
            img_aug, mask_aug = augment_sample(img.copy(), mask.copy(), transforms)

            axes[0, i].imshow(cv2.cvtColor(img_aug, cv2.COLOR_BGR2RGB))
            axes[0, i].set_title(f'Augmented {i+1}', fontsize=12)
            axes[0, i].axis('off')

            colored = label2rgb(mask_aug, bg_label=0)
            axes[1, i].imshow(colored)
            axes[1, i].set_title(f'Mask {i+1}', fontsize=12)
            axes[1, i].axis('off')

        plt.tight_layout()
        plt.show()

## Model Training

**Note**: This cell is a training template: it splits the dataset and loads the base model, but it does not run a training loop and returns `None`. Adapt it to the training API of your installed InstanSeg version.

In [None]:
def train_instanseg_model(
    dataset,
    base_model=None,
    epochs=50,
    batch_size=8,
    learning_rate=1e-4,
    patch_size=512,
    val_split=0.2,
    output_dir=None,
    checkpoint_interval=10,
    seed=None,
):
    """
    Train InstanSeg model (template).

    This is a template - adjust based on InstanSeg's actual training API.
    Consult InstanSeg documentation for the exact training procedure.

    What the function currently does: prints the configuration, splits the
    dataset indices into train/validation, constructs ``InstanSeg(base_model)``
    and prints guidance. No training loop is run.

    Args:
        dataset: Sequence of samples; only its length is used here.
        base_model: Name passed to ``InstanSeg(...)``; falsy means "from
            scratch", which is not implemented.
        epochs, batch_size, learning_rate, patch_size: Reported in the
            printed configuration only.
        val_split: Fraction of samples assigned to validation; must satisfy
            ``0 <= val_split < 1``.
        output_dir, checkpoint_interval: Accepted for the future training
            loop; currently unused.
        seed: Optional seed for a reproducible train/validation split. When
            ``None`` the split is drawn from NumPy's global random state, as
            before.

    Returns:
        None - no model is trained by this template.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
        NotImplementedError: If ``base_model`` is falsy (train from scratch).
    """
    # An out-of-range fraction would silently yield an empty or negative-size
    # training set, so reject it up front.
    if not 0 <= val_split < 1:
        raise ValueError(f'val_split must be in [0, 1), got {val_split}')

    print('=' * 60)
    print('Training Configuration:')
    print(f'  Base model: {base_model or "from scratch"}')
    print(f'  Dataset size: {len(dataset)} samples')
    print(f'  Epochs: {epochs}')
    print(f'  Batch size: {batch_size}')
    print(f'  Learning rate: {learning_rate}')
    print(f'  Patch size: {patch_size}x{patch_size}')
    print('=' * 60)

    # Split dataset
    n_val = int(len(dataset) * val_split)
    n_train = len(dataset) - n_val

    if seed is None:
        indices = np.random.permutation(len(dataset))
    else:
        # A dedicated generator makes the split repeatable without touching
        # the global random state.
        indices = np.random.default_rng(seed).permutation(len(dataset))
    # Not consumed yet: reserved for the training loop to build its loaders
    train_indices = indices[:n_train]
    val_indices = indices[n_train:]

    print(f'\nTrain samples: {n_train}')
    print(f'Val samples: {n_val}')

    # Initialize or load model
    if base_model:
        print(f'\nLoading base model: {base_model}')
        model = InstanSeg(base_model)
    else:
        print('\nInitializing model from scratch')
        # Model architecture initialization would go here
        raise NotImplementedError(
            'Training from scratch requires model architecture definition. '
            'Consult InstanSeg documentation for model initialization.'
        )

    # Training loop
    print('\nStarting training...')

    # NOTE: InstanSeg may have a built-in training method like:
    # model.train(
    #     train_data=train_dataset,
    #     val_data=val_dataset,
    #     epochs=epochs,
    #     batch_size=batch_size,
    #     learning_rate=learning_rate,
    #     ...
    # )

    print('\n⚠ Training template requires InstanSeg training API.')
    print('Consult InstanSeg documentation for:')
    print('  - model.fit() or model.train() method')
    print('  - Dataset format (PyTorch Dataset/DataLoader)')
    print('  - Loss functions and optimizers')
    print('  - Checkpoint saving')

    return None

# Launch training only when enough samples are available; with fewer than
# 5 samples the run is skipped and the shortfall is reported instead.
if len(dataset) < 5:
    print('⚠ Insufficient training data (need at least 5 samples)')
    print(f'Current: {len(dataset)} samples')
else:
    # Forward every setting from the configuration cell unchanged
    trained_model = train_instanseg_model(
        dataset=dataset,
        base_model=base_model,
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        patch_size=patch_size,
        val_split=val_split,
        output_dir=output_dir,
        checkpoint_interval=checkpoint_interval,
    )

## Model Validation

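Test the trained model on validation samples. The cell below currently only checks that the trained model file exists; loading the model and running inference are left as placeholders.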

In [None]:
# Locate the trained model written by the training step
model_path = output_dir / model_name / 'final_model.pth'

if model_path.exists():
    print(f'Found model file: {model_path}')
    # TODO: load the weights once the InstanSeg loading API is confirmed, e.g.
    # model = InstanSeg.load(model_path)
    # Report honestly that nothing has been loaded yet, rather than claiming
    # success while the load call above is still commented out.
    print('⚠ Model loading is not implemented in this template; no model was loaded')

    # Run inference on validation samples
    # ... validation code ...
else:
    print(f'⚠ Model not found: {model_path}')
    print('Train a model first or specify an existing model path')

## Export for Production

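Save the trained model in a format compatible with notebook 01. The cell below currently only creates the export directory and prints the settings to use; copying the model files into it is left as a placeholder.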

In [None]:
# Export model
export_path = output_dir / f'{model_name}_export'
# parents=True matches the other mkdir calls in this notebook and lets this
# cell run even when the output directory has not been created yet.
export_path.mkdir(parents=True, exist_ok=True)

print(f'Exporting model to: {export_path}')
# This template creates the directory but copies nothing into it; warn so an
# empty folder is not mistaken for a usable exported model.
if not any(export_path.iterdir()):
    print('⚠ Export directory is empty: no model files have been written yet')
print('\nTo use in notebook 01, set:')
print(f"  instanseg_model = '{export_path}'")
print("  backend = 'instanseg'")

## Summary

This notebook provides a framework for training custom InstanSeg models. Key steps:

1. **Prepare annotated data** (images + instance masks)
2. **Configure training parameters** (epochs, batch size, learning rate)
3. **Train model** (fine-tune pretrained or train from scratch)
4. **Validate performance** on the held-out validation split
5. **Export model** for use in production notebooks

**Next Steps**:
- Collect more training data for better generalization
- Experiment with hyperparameters
- Test on diverse tissue samples
- Compare with pretrained models