# Protein Sub-Cellular Localization in Neurons

This notebook demonstrates the complete pipeline for processing and analyzing TIFF images from the OpenCell database to predict protein sub-cellular localization patterns.

## Pipeline Overview

1. **Environment Setup**
2. **Data Access & Sanity Checks**
3. **Image Preprocessing**
4. **Graph Construction**
5. **Model Training**
6. **Inference**
7. **Evaluation & Visualization**

## 1. Environment Setup

In [None]:
# Import required libraries
import os
import sys
import yaml
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Add parent directory to path
sys.path.append('..')

# Import pipeline modules
from utils.data_loader import TIFFDataLoader
from utils.preprocessor import ImagePreprocessor
from utils.graph_builder import GraphBuilder
from utils.visualizer import Visualizer

# Setup matplotlib
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

print("✓ Environment setup complete")

In [None]:
# Load configuration
with open('../config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("Configuration loaded:")
print(f"  Input directory: {config['data']['input_dir']}")
print(f"  Output directory: {config['data']['output_dir']}")
print(f"  Model type: {config['model']['type']}")
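The notebook reads a fixed set of keys from `config.yaml`. Below is a minimal sketch of a check that fails early when one of them is missing. It assumes the nested layout implied by the lookups in this notebook (`data.input_dir`, `data.output_dir`, `data.processed_dir`, `data.graph_dir`, `model.type`, `labels.class_names`, plus the `preprocessing`, `graph` and `visualization` sections). `REQUIRED_KEYS` and `check_config` are hypothetical helpers and are not part of the pipeline's `utils` package.

```python
from typing import Any, Mapping

# Hypothetical list of dotted key paths that this notebook dereferences.
REQUIRED_KEYS = [
    "data.input_dir",
    "data.output_dir",
    "data.processed_dir",
    "data.graph_dir",
    "model.type",
    "labels.class_names",
    "preprocessing",
    "graph",
    "visualization",
]


def check_config(config: Mapping[str, Any], required=REQUIRED_KEYS) -> list:
    """Return the dotted paths in `required` that are missing from `config`."""
    missing = []
    for path in required:
        node: Any = config
        for part in path.split("."):
            if isinstance(node, Mapping) and part in node:
                node = node[part]  # descend one level
            else:
                missing.append(path)
                break
    return missing
```

Call `check_config(config)` right after loading the file and raise if the returned list is non-empty. A typo in `config.yaml` then surfaces before any long-running step starts.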

## 2. Data Access & Sanity Checks

In [None]:
# Set your data directory
DATA_DIR = "../data/raw"  # Change this to your TIFF images directory

# Initialize data loader
loader = TIFFDataLoader(DATA_DIR)

# Scan directory for TIFF files
image_files = loader.scan_directory()

print(f"Found {len(image_files)} TIFF files")
print("\nFirst 5 files:")
for f in image_files[:5]:
    print(f"  {f.name}")
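`TIFFDataLoader.scan_directory()` belongs to the project's `utils` package, and its exact matching rules are not shown here. Below is a minimal sketch of an equivalent scan with `pathlib`. It assumes files use either the `.tif` or `.tiff` extension in any letter case and may sit in subdirectories. `find_tiffs` is a hypothetical name.

```python
from pathlib import Path

TIFF_SUFFIXES = {".tif", ".tiff"}


def find_tiffs(root, recursive: bool = True) -> list:
    """Return TIFF paths under `root`, sorted for a stable processing order."""
    root = Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"Data directory not found: {root}")
    pattern = "**/*" if recursive else "*"
    # Compare suffixes case-insensitively so IMG.TIF and img.tiff are both found.
    return sorted(
        p for p in root.glob(pattern)
        if p.is_file() and p.suffix.lower() in TIFF_SUFFIXES
    )
```

Sorting the paths keeps the processing order identical across runs and machines, which makes logs and saved outputs easier to compare.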

In [None]:
# Load all images with validation
images = loader.load_all(validate=True)

# Print summary statistics
loader.print_summary()
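`load_all(validate=True)` delegates validation to the loader, whose criteria are not documented here. The kinds of per-image checks a sanity pass usually covers can be sketched with NumPy alone. The 2-D/3-D shape rule and the checks below are assumptions, not the loader's actual rules, and `sanity_check` is a hypothetical helper.

```python
import numpy as np


def sanity_check(image: np.ndarray) -> list:
    """Return human-readable problems found in `image` (empty list means OK)."""
    problems = []
    if image.ndim not in (2, 3):
        # Assumption: single-plane (Y, X) or stacked (Z/C, Y, X) images only.
        problems.append(f"unexpected number of dimensions: {image.ndim}")
    if image.size == 0:
        problems.append("empty image")
        return problems
    if np.issubdtype(image.dtype, np.floating) and not np.all(np.isfinite(image)):
        problems.append("contains NaN or Inf values")
    elif image.min() == image.max():
        # A flat image usually means a failed acquisition or a blank channel.
        problems.append("constant image (no signal)")
    return problems
```

For example, `{name: sanity_check(img) for name, (img, _) in images.items()}` collects the problems for each loaded file, since each value in `images` is an (array, metadata) pair.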

In [None]:
# Visualize a sample image
# Create the visualizer outside the guard: later cells reuse it unconditionally
visualizer = Visualizer(config['visualization'])

if images:
    sample_name = list(images.keys())[0]
    sample_image, sample_metadata = images[sample_name]
    
    visualizer.visualize_image(sample_image, title=f"Sample: {sample_name}")
    
    print("\nImage metadata:")
    for key, value in sample_metadata.items():
        print(f"  {key}: {value}")

## 3. Image Preprocessing

In [None]:
# Initialize preprocessor
preprocessor = ImagePreprocessor(config['preprocessing'])

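# Extract image arrays (each value in `images` is an (array, metadata) tuple)
image_arrays = {k: v[0] for k, v in images.items()}
if not image_arrays:
    raise RuntimeError(f"No images loaded from {DATA_DIR}; check the path and file extensions")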

# Preprocess a single image for demonstration
sample_name = list(image_arrays.keys())[0]
original = image_arrays[sample_name]
processed = preprocessor.preprocess(original, sample_name)

print(f"Original shape: {original.shape}, dtype: {original.dtype}")
print(f"Processed shape: {processed.shape}, dtype: {processed.dtype}")
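The steps performed by `ImagePreprocessor` are set in `config['preprocessing']` and are not shown here. The summary at the end lists denoising, normalization and enhancement. Below is a minimal sketch of one common normalization choice, a percentile-based contrast stretch to [0, 1]. The 1st/99.8th percentile defaults are an illustrative assumption, `percentile_normalize` is a hypothetical helper, and the pipeline may normalize differently.

```python
import numpy as np


def percentile_normalize(image: np.ndarray, low: float = 1.0, high: float = 99.8) -> np.ndarray:
    """Map the `low`/`high` intensity percentiles to 0 and 1, clipping outliers."""
    img = image.astype(np.float32)
    lo, hi = np.percentile(img, [low, high])
    if hi <= lo:
        # Flat image: return zeros rather than dividing by zero.
        return np.zeros_like(img)
    out = np.clip((img - lo) / (hi - lo), 0.0, 1.0)
    return out.astype(np.float32, copy=False)
```

Percentiles are used instead of the raw minimum and maximum so that a few saturated or dead pixels do not compress the contrast of the rest of the image.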

In [None]:
# Visualize preprocessing results
visualizer.visualize_preprocessing(
    original, 
    processed,
    title=f"Preprocessing: {sample_name}"
)

In [None]:
# Preprocess all images
processed_images = preprocessor.preprocess_batch(
    image_arrays,
    output_dir=config['data']['processed_dir']
)

print(f"Preprocessed {len(processed_images)} images")
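The on-disk format that `preprocess_batch` writes to `output_dir` is not documented here. Below is a sketch of what a batch step with an output directory commonly amounts to. It assumes one `.npy` file per image, named after the source file's stem. `preprocess_and_save` is hypothetical, and the per-image function is called with `(image, name)` to match the `preprocessor.preprocess(original, sample_name)` call in the previous cells.

```python
from pathlib import Path
from typing import Callable, Dict

import numpy as np


def preprocess_and_save(images: Dict[str, np.ndarray],
                        preprocess: Callable[[np.ndarray, str], np.ndarray],
                        output_dir) -> Dict[str, np.ndarray]:
    """Apply `preprocess` to each image and save it as <output_dir>/<stem>.npy."""
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    results = {}
    for name, image in images.items():
        processed = preprocess(image, name)
        # Use the file stem so "cell_01.tif" is saved as "cell_01.npy".
        np.save(out_dir / f"{Path(name).stem}.npy", processed)
        results[name] = processed
    return results
```

Saving intermediates as `.npy` keeps the exact dtype and shape, so the graph-construction step can reload them with `np.load` without re-running preprocessing.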

## 4. Graph Construction

In [None]:
# Initialize graph builder
graph_builder = GraphBuilder(config['graph'])

# Build graph from sample image
sample_name = list(processed_images.keys())[0]
sample_image = processed_images[sample_name]

graph_data = graph_builder.build_graph(sample_image, sample_name)

print(f"Graph constructed:")
print(f"  Number of nodes: {graph_data['num_nodes']}")
print(f"  Node features shape: {graph_data['node_features'].shape}")
print(f"  Number of edges: {graph_data['edges'].shape[1]}")
print(f"  Edge features shape: {graph_data['edge_features'].shape}")
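`GraphBuilder` returns segments, node features and an edge array whose second dimension is the number of edges. This matches the `(2, num_edges)` edge-index layout used by libraries such as PyTorch Geometric. Below is a minimal sketch of deriving such a graph from a segmentation. It assumes a 2-D integer label image the same shape as the intensity image, one node per label, node features of mean intensity and pixel area, and an edge wherever two regions share a 4-connected border. `region_adjacency_graph` is hypothetical, and the real builder's features and connectivity rule may differ.

```python
import numpy as np


def region_adjacency_graph(image: np.ndarray, segments: np.ndarray) -> dict:
    """Build a region adjacency graph from a 2-D label image."""
    labels, inverse = np.unique(segments, return_inverse=True)
    idx = inverse.reshape(segments.shape)  # labels remapped to 0..num_nodes-1
    num_nodes = labels.size
    flat_idx = idx.ravel()
    area = np.bincount(flat_idx, minlength=num_nodes).astype(np.float64)
    total = np.bincount(flat_idx, weights=image.ravel().astype(np.float64),
                        minlength=num_nodes)
    node_features = np.stack([total / area, area], axis=1)  # [mean intensity, area]

    # Horizontally and vertically neighbouring pixel pairs (4-connectivity).
    pairs = np.concatenate([
        np.stack([idx[:, :-1].ravel(), idx[:, 1:].ravel()], axis=1),
        np.stack([idx[:-1, :].ravel(), idx[1:, :].ravel()], axis=1),
    ])
    pairs = pairs[pairs[:, 0] != pairs[:, 1]]  # keep only cross-region pairs
    # Store each undirected edge in both directions, deduplicated.
    both = np.concatenate([pairs, pairs[:, ::-1]])
    if both.size:
        edges = np.unique(both, axis=0).T
    else:
        edges = np.empty((2, 0), dtype=np.int64)
    return {"num_nodes": num_nodes, "node_features": node_features, "edges": edges}
```

Storing every edge in both directions is the usual convention for undirected graphs in message-passing GNNs. Each region then receives messages from all of its neighbours.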

In [None]:
# Visualize segmentation
visualizer.visualize_segmentation(
    sample_image,
    graph_data['segments'],
    title=f"Segmentation: {sample_name}"
)

In [None]:
# Build graphs for all images
graphs = graph_builder.build_batch(
    processed_images,
    output_dir=config['data']['graph_dir']
)

print(f"Built {len(graphs)} graphs")

## 5. Model Training

Training is usually run from the command line, where a long job does not depend on the notebook kernel staying alive.
It can also be launched from the notebook if preferred.

In [None]:
# Training command
print("To train the model, run:")
print(f"  python ../train.py --data_dir {config['data']['graph_dir']} --model_type gnn --epochs 100")
print("\nOr for quick testing with fewer epochs:")
print(f"  python ../train.py --data_dir {config['data']['graph_dir']} --model_type gnn --epochs 10")
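To launch training from the notebook instead, the same command can be run through `subprocess`. Below is a minimal sketch that uses only the flags printed above. `build_train_command` and `run_training` are hypothetical helpers. The relative `../train.py` path assumes the notebook sits one directory below the repository root, as the `sys.path.append('..')` in the setup cell implies.

```python
import subprocess
import sys


def build_train_command(graph_dir: str, model_type: str = "gnn", epochs: int = 100,
                        script: str = "../train.py") -> list:
    """Assemble the train.py invocation shown above as an argument list."""
    return [sys.executable, script,
            "--data_dir", str(graph_dir),
            "--model_type", model_type,
            "--epochs", str(epochs)]


def run_training(**kwargs) -> int:
    """Run train.py in a child process and return its exit code."""
    completed = subprocess.run(build_train_command(**kwargs), check=False)
    return completed.returncode
```

Using `sys.executable` runs the script with the same Python interpreter as the notebook kernel, so it sees the same installed packages. For example, `run_training(graph_dir=config['data']['graph_dir'], epochs=10)` mirrors the quick-test command.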

## 6. Inference

Run inference on the images in `DATA_DIR` using the trained model checkpoint (`best_model.pth`).

In [None]:
# Check if trained model exists
model_path = Path(config['data']['output_dir']) / 'models' / 'best_model.pth'

if model_path.exists():
    print(f"✓ Model found: {model_path}")
    
    # Import inference module
    from inference import InferenceEngine
    
    # Initialize inference engine
    engine = InferenceEngine(str(model_path), config)
    
    # Run inference
    results = engine.predict_from_directory(DATA_DIR)
    
    print(f"\nGenerated predictions for {len(results)} images")
    
    # Display sample predictions
    print("\nSample predictions:")
    for filename, result in list(results.items())[:5]:
        print(f"  {filename}:")
        print(f"    Predicted class: {result['class_name']}")
        print(f"    Confidence: {result['probabilities'].max():.4f}")
else:
    print(f"✗ Model not found: {model_path}")
    print("Please train the model first using train.py")
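The cell above reports confidence as the largest entry of `result['probabilities']`. Suppose the model emits one logit per class and `InferenceEngine` converts them with a softmax (an assumption not confirmed here). The prediction and its confidence can then be sketched as follows. `softmax` and `predict` are hypothetical helpers. The result keys mirror the `prediction`, `class_name` and `probabilities` fields this notebook reads, and `confidence` is an extra, illustrative key.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def predict(logits, class_names: list) -> dict:
    """Return the predicted index, class name, probabilities and confidence."""
    probabilities = softmax(np.asarray(logits, dtype=np.float64))
    index = int(np.argmax(probabilities))
    return {
        "prediction": index,
        "class_name": class_names[index],
        "probabilities": probabilities,
        "confidence": float(probabilities[index]),
    }
```

Softmax confidence is not calibrated by default. A high maximum probability indicates the model's preference, not the chance that the localization call is correct.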

## 7. Evaluation & Visualization

In [None]:
# If predictions are available, visualize results
if 'results' in locals() and results:
    from collections import Counter
    
    # Count predictions per class
    pred_counts = Counter([r['prediction'] for r in results.values()])
    class_names = config['labels']['class_names']
    
    # Plot distribution
    fig, ax = plt.subplots(figsize=(12, 6))
    
    classes = [class_names[p] for p in sorted(pred_counts.keys())]
    counts = [pred_counts[p] for p in sorted(pred_counts.keys())]
    
    ax.bar(classes, counts)
    ax.set_xlabel('Predicted Localization Class')
    ax.set_ylabel('Count')
    ax.set_title('Prediction Distribution')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()
    
    print("\nPredictions per class:")
    for class_name, count in zip(classes, counts):
        print(f"  {class_name}: {count}")
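The bar chart above only shows classes that received at least one prediction. Below is a minimal NumPy sketch that counts every class, including those never predicted. When ground-truth labels are available, it also builds a confusion matrix and per-class recall. This notebook does not load labels, so the sketch assumes integer class indices aligned with the predictions. `class_counts`, `confusion_matrix` and `per_class_recall` are hypothetical helpers.

```python
import numpy as np


def class_counts(predictions, num_classes: int) -> np.ndarray:
    """Predictions per class, including classes that were never predicted."""
    return np.bincount(np.asarray(predictions, dtype=np.int64), minlength=num_classes)


def confusion_matrix(y_true, y_pred, num_classes: int) -> np.ndarray:
    """Rows are true classes, columns are predicted classes."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)  # unbuffered add handles repeated pairs
    return cm


def per_class_recall(cm: np.ndarray) -> np.ndarray:
    """Fraction of each true class predicted correctly (NaN if the class is absent)."""
    support = cm.sum(axis=1)
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.diag(cm) / support
```

For example, `class_counts([r['prediction'] for r in results.values()], len(class_names))` gives one bar height per entry of `class_names`. Overall accuracy is `np.trace(cm) / cm.sum()`.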

## Summary

This notebook demonstrated the complete protein sub-cellular localization pipeline:

1. ✓ Loaded and validated TIFF images
2. ✓ Preprocessed images (denoising, normalization, enhancement)
3. ✓ Constructed graph representations
4. ✓ Prepared data for training
5. Model training (run via the command line)
6. Inference on all images in `DATA_DIR` (requires a trained model)
7. Evaluation and visualization (requires predictions from step 6)

### Next Steps

- Train the model: `python train.py --data_dir data/graphs`
- Run inference: `python inference.py --model_path outputs/models/best_model.pth --input_dir data/raw`
- Evaluate results: `python evaluate.py --predictions_dir outputs/results`

### Key Features

- ✓ Batch processing of multiple TIFF files
- ✓ Automatic graph construction from images
- ✓ Flexible model architectures (GNN, CNN)
- ✓ Complete evaluation pipeline
- ✓ Ubuntu + Jupyter Lab compatible