# Protein Sub-Cellular Localization in Neurons
## Complete Pipeline for 4D TIFF Microscopy Image Analysis

This notebook demonstrates the complete end-to-end pipeline for analyzing protein sub-cellular localization in neuronal microscopy images.

### Pipeline Overview:
1. **Preprocessing**: Load and segment TIFF images
2. **Feature Extraction**: Extract biological and computational features
3. **Graph Construction**: Build graph representations
4. **Model Training**: Train Graph-CNN and VGG-16 models
5. **Evaluation**: Compute metrics and visualizations
6. **Inference**: Predict localization on new images

## 1. Setup and Imports

In [None]:
# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
import warnings
warnings.filterwarnings('ignore')

# Notebook display helpers (used by every cell that shows a saved figure)
from IPython.display import Image, display

# Deep learning
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Custom modules
import sys
sys.path.append('..')

from preprocessing.segmentation import DirectoryHandler, TIFFLoader, CellposeSegmenter
from preprocessing.feature_extraction import FeatureExtractor, FeatureStorage
from graph_construction.graph_builder import GraphConstructor, PyTorchGeometricConverter, GraphStorage
from models.graph_cnn import GraphCNN
from models.vgg16 import VGG16Classifier
from models.combined_model import CombinedModel
from models.trainer import ModelTrainer, ProteinLocalizationDataset, create_data_loaders
from visualization.plotters import SegmentationVisualizer, StatisticalPlotter, ColocalizationAnalyzer
from visualization.graph_viz import GraphVisualizer
from visualization.metrics import MetricsEvaluator

# Report that setup succeeded, plus the torch build and the available accelerator.
_device_label = 'GPU' if torch.cuda.is_available() else 'CPU'
for _status_line in (
    "✓ All imports successful",
    f"✓ PyTorch version: {torch.__version__}",
    f"✓ Device: {_device_label}",
):
    print(_status_line)

## 2. Configuration

In [None]:
# Directory paths. The defaults are the author's local paths; set the
# PROTEIN_LOC_INPUT_DIR / PROTEIN_LOC_OUTPUT_DIR environment variables to run
# the notebook elsewhere without editing code.
INPUT_DIR = os.environ.get("PROTEIN_LOC_INPUT_DIR", "/mnt/d/5TH_SEM/CELLULAR/input")
OUTPUT_DIR = os.environ.get("PROTEIN_LOC_OUTPUT_DIR", "/mnt/d/5TH_SEM/CELLULAR/output/output")
MODELS_DIR = os.path.join(OUTPUT_DIR, "models")
VIZ_DIR = os.path.join(OUTPUT_DIR, "visualizations")

# Create output directories (exist_ok makes re-running this cell safe)
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(VIZ_DIR, exist_ok=True)

# Model parameters
NUM_CLASSES = 10  # Adjust based on your dataset
BATCH_SIZE = 16
NUM_EPOCHS = 100
LEARNING_RATE = 0.001

# Seed the global numpy and torch RNGs so model initialisation and any
# global-RNG draws are reproducible across Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

print("✓ Configuration complete")

## 3. Data Loading and Preprocessing

In [None]:
# Discover every TIFF file under INPUT_DIR; later cells consume `tiff_files`.
print("Scanning for TIFF files...")
dir_handler = DirectoryHandler(INPUT_DIR)
tiff_files = dir_handler.scan_directory()

num_tiffs = len(tiff_files)
print(f"Found {num_tiffs} TIFF files")
if num_tiffs > 0:
    preview = tiff_files[:3]
    print(f"Sample files: {preview}")

In [None]:
# Load and segment images
loader = TIFFLoader()
segmenter = CellposeSegmenter(model_type='cyto2')
feature_extractor = FeatureExtractor()

# Reset the example outputs on every run. Downstream cells check
# `masks is not None` and `not features.empty`, so these sentinels make them
# skip cleanly instead of reusing stale values from an earlier kernel run.
image = None
masks = None
features = pd.DataFrame()

# Process first image as example
if len(tiff_files) > 0:
    example_file = tiff_files[0]
    print(f"\nProcessing example: {example_file}")

    # Load image (load_tiff returns None on failure)
    image = loader.load_tiff(example_file)
    if image is not None:
        print(f"✓ Loaded image shape: {image.shape}")

        # Segment (masks is None on failure)
        masks, seg_info = segmenter.segment_image(image)
        if masks is not None:
            print(f"✓ Segmentation complete: {seg_info['num_cells']} cells")

            # Extract per-cell features
            features = feature_extractor.extract_all_features(image, masks)
            print(f"✓ Extracted features: {features.shape}")
            print(f"  Feature columns: {list(features.columns)[:10]}...")
        else:
            print("✗ Segmentation failed; skipping feature extraction")
    else:
        print(f"✗ Could not load image: {example_file}")
else:
    print("No TIFF files found; skipping the example pipeline")

## 4. Visualization: Segmentation Results

In [None]:
# Visualize segmentation of the example image
seg_viz = SegmentationVisualizer(output_dir=VIZ_DIR)

if 'masks' in locals() and masks is not None:
    seg_viz.plot_segmentation_overlay(
        image, masks,
        title="Segmentation Overlay",
        filename="example_segmentation.png"
    )

    # Display in notebook. `Image`/`display` are imported in the setup cell so
    # later cells can use them even when this branch is skipped.
    display(Image(os.path.join(VIZ_DIR, "example_segmentation.png")))

## 5. Graph Construction

In [None]:
# Construct a proximity graph over the segmented cells of the example image
graph_constructor = GraphConstructor(proximity_threshold=50)

if 'features' in locals() and not features.empty:
    graph = graph_constructor.construct_graph(features, masks)
    num_nodes = graph.number_of_nodes()
    num_edges = graph.number_of_edges()
    # 2|E|/|V| is the mean degree when each edge is counted once (undirected
    # graph — assumed from the formula; confirm for GraphConstructor). Guard
    # |V| == 0 so an image with no usable cells doesn't raise ZeroDivisionError.
    avg_degree = 2 * num_edges / num_nodes if num_nodes else 0.0
    print("✓ Graph constructed:")
    print(f"  - Nodes: {num_nodes}")
    print(f"  - Edges: {num_edges}")
    print(f"  - Average degree: {avg_degree:.2f}")

    # Convert to PyTorch Geometric format
    pg_converter = PyTorchGeometricConverter()
    graph_data = pg_converter.to_pytorch_geometric(graph)
    print("\n✓ PyTorch Geometric format:")
    print(f"  - Node features shape: {graph_data['x'].shape}")
    print(f"  - Edge index shape: {graph_data['edge_index'].shape}")

In [None]:
# Render the example graph, show it inline, then save a statistics figure.
graph_viz = GraphVisualizer(output_dir=VIZ_DIR)

if 'graph' in locals():
    graph_png = "example_graph.png"
    graph_viz.plot_graph(
        graph,
        title="Protein Localization Graph",
        filename=graph_png,
        layout='spring',
    )
    display(Image(os.path.join(VIZ_DIR, graph_png)))

    # Summary statistics of the same graph (saved to VIZ_DIR)
    graph_viz.plot_graph_statistics(
        graph,
        title="Graph Statistics",
        filename="graph_stats.png",
    )

## 6. Model Definition

In [None]:
# Define models
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The Graph-CNN input width must equal the node-feature dimension produced by
# PyTorchGeometricConverter. Take it from the example graph when one was built
# (assumes graph_data['x'] is (num_nodes, num_features) — confirm); otherwise
# fall back to the previous hard-coded width.
FALLBACK_NODE_FEATURES = 20
if 'graph_data' in globals():
    num_node_features = int(graph_data['x'].shape[1])
else:
    num_node_features = FALLBACK_NODE_FEATURES

# Graph-CNN
graph_cnn = GraphCNN(
    in_channels=num_node_features,  # Number of node features
    hidden_channels=64,
    out_channels=NUM_CLASSES,
    num_layers=3,
    dropout=0.5
).to(device)

print("✓ Graph-CNN initialized")
print(f"  Input features: {num_node_features}")
print(f"  Parameters: {sum(p.numel() for p in graph_cnn.parameters())}")

# VGG-16 on single-channel images
vgg16 = VGG16Classifier(
    num_classes=NUM_CLASSES,
    pretrained=False,  # Set True if you have internet
    in_channels=1
).to(device)

print("✓ VGG-16 initialized")
print(f"  Parameters: {sum(p.numel() for p in vgg16.parameters())}")

## 7. Training (Demo)

Note: This section is a demonstration only. It builds synthetic inputs and does not actually train a model; real training requires a labeled dataset.

In [None]:
# Create dummy dataset for demonstration.
# In practice, load your actual labeled images/graphs here instead.
print("Creating demo dataset...")

NUM_DEMO_SAMPLES = 100
# Dedicated seeded generator: the synthetic tensors are identical on every
# Restart & Run All without consuming torch's global RNG stream.
demo_generator = torch.Generator().manual_seed(0)

dummy_images = [torch.randn(1, 224, 224, generator=demo_generator)
                for _ in range(NUM_DEMO_SAMPLES)]
# One vectorized draw instead of NUM_DEMO_SAMPLES separate .item() calls
dummy_labels = torch.randint(0, NUM_CLASSES, (NUM_DEMO_SAMPLES,),
                             generator=demo_generator).tolist()

# Reuse the example graph for every sample. `graph_data` only exists when the
# graph-construction cell ran on real cells; without this guard a fresh kernel
# with no TIFF input fails here with NameError.
if 'graph_data' in globals():
    dummy_graphs = [graph_data] * NUM_DEMO_SAMPLES
else:
    dummy_graphs = None
    print("  ⚠ No example graph available; dummy_graphs is None")

# Create dataset (enable once real labeled data replaces the synthetic inputs)
# dataset = ProteinLocalizationDataset(dummy_images, dummy_graphs, dummy_labels)
# train_loader, val_loader = create_data_loaders(dataset, train_split=0.8, batch_size=BATCH_SIZE)

print("✓ Demo dataset created")
print("  Note: Replace with actual labeled data for real training")

In [None]:
# Training loop, disabled by default for the demo. It needs `train_loader` and
# `val_loader`, which only exist once the dataset lines in the previous cell are
# enabled with real labeled data. An explicit flag (instead of code hidden in a
# string literal) keeps this code syntax-checked and easy to switch on.
RUN_TRAINING = False

if RUN_TRAINING:
    # Initialize trainer
    trainer = ModelTrainer(
        model=graph_cnn,
        device=device,
        learning_rate=LEARNING_RATE
    )

    # Train model
    trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=NUM_EPOCHS,
        early_stopping_patience=10,
        save_dir=MODELS_DIR
    )
    print("✓ Training complete")
else:
    print("Training code ready (set RUN_TRAINING = True to run)")

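## 8. Evaluation Metrics

The predictions below are random placeholders. The resulting metrics only demonstrate the reporting and plotting API; they say nothing about model quality.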

In [None]:
# Demo evaluation with synthetic results
metrics_evaluator = MetricsEvaluator(output_dir=VIZ_DIR)

# Synthetic predictions: labels are drawn uniformly at random, so the metrics
# below only exercise the reporting code. A local seeded generator keeps the
# numbers identical across Restart & Run All.
DEMO_EVAL_SEED = 0
NUM_DEMO_PREDICTIONS = 100
eval_rng = np.random.default_rng(DEMO_EVAL_SEED)
y_true = eval_rng.integers(0, NUM_CLASSES, NUM_DEMO_PREDICTIONS)
y_pred = eval_rng.integers(0, NUM_CLASSES, NUM_DEMO_PREDICTIONS)

# Calculate metrics
metrics = metrics_evaluator.calculate_all_metrics(
    y_true, y_pred,
    class_names=[f'Class_{i}' for i in range(NUM_CLASSES)]
)

# NOTE(review): the '%' suffix assumes calculate_all_metrics returns values
# already scaled to 0-100 — confirm against MetricsEvaluator.
print("Evaluation Metrics:")
print(f"  Accuracy: {metrics['accuracy']:.2f}%")
print(f"  Precision: {metrics['precision_avg']:.2f}%")
print(f"  Recall: {metrics['recall_avg']:.2f}%")
print(f"  F1-Score: {metrics['f1_avg']:.2f}%")
print(f"  Specificity: {metrics['specificity_avg']:.2f}%")

In [None]:
# Plot the confusion matrix of the demo predictions with short class labels,
# then show the saved image inline.
cm_filename = "confusion_matrix.png"
short_labels = ['C' + str(i) for i in range(NUM_CLASSES)]
cm = np.array(metrics['confusion_matrix'])

metrics_evaluator.plot_confusion_matrix(
    cm,
    class_names=short_labels,
    title="Confusion Matrix",
    filename=cm_filename,
)
display(Image(os.path.join(VIZ_DIR, cm_filename)))

In [None]:
# Plot a comparison of the computed metrics, then show the saved image inline.
comparison_png = "metrics_comparison.png"
metrics_evaluator.plot_metrics_comparison(
    metrics,
    title="Performance Metrics",
    filename=comparison_png,
)
display(Image(os.path.join(VIZ_DIR, comparison_png)))

## 9. Inference on New Image

In [None]:
def predict_localization(image_path, model, device=None,
                         segmenter_model_type='cyto2', proximity_threshold=50):
    """
    Run the full inference pipeline on one TIFF image:
    load -> segment -> extract features -> build graph -> model prediction.

    Args:
        image_path: Path to the TIFF file.
        model: Trained model, called as ``model(x, edge_index)``.
        device: Device for the input tensors. ``None`` (default) uses the
            device of the model's parameters, so a model already moved to
            CUDA works without passing this explicitly.
        segmenter_model_type: Cellpose model type. Defaults to ``'cyto2'`` to
            match the preprocessing used earlier in this notebook.
        proximity_threshold: ``GraphConstructor`` proximity threshold. Defaults
            to 50 to match the graph built earlier in this notebook.

    Returns:
        dict with keys ``prediction`` (class index), ``confidence`` (max
        softmax probability), ``num_cells``, ``graph_nodes`` and
        ``graph_edges``.

    Raises:
        ValueError: If the image cannot be loaded, segmentation fails, or no
            cell features are extracted.
    """
    # Resolve the device from the model so CUDA models don't get CPU tensors
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    # Load (load_tiff returns None on failure)
    loader = TIFFLoader()
    image = loader.load_tiff(image_path)
    if image is None:
        raise ValueError(f"Could not load TIFF image: {image_path}")

    # Segment with the same Cellpose model as the preprocessing section
    segmenter = CellposeSegmenter(model_type=segmenter_model_type)
    masks, _ = segmenter.segment_image(image)
    if masks is None:
        raise ValueError(f"Segmentation failed for: {image_path}")

    # Extract per-cell features
    extractor = FeatureExtractor()
    features = extractor.extract_all_features(image, masks)
    if features is None or features.empty:
        raise ValueError(f"No cell features extracted from: {image_path}")

    # Build graph with the same proximity threshold as the example graph
    constructor = GraphConstructor(proximity_threshold=proximity_threshold)
    graph = constructor.construct_graph(features, masks)

    # Convert to tensors
    converter = PyTorchGeometricConverter()
    graph_data = converter.to_pytorch_geometric(graph)

    # Predict
    model.eval()
    with torch.no_grad():
        x = graph_data['x'].to(device)
        edge_index = graph_data['edge_index'].to(device)
        output = model(x, edge_index)
        # NOTE(review): .item() assumes graph-level logits of shape
        # (1, num_classes); it raises if the model returns per-node output.
        pred = output.argmax(dim=1).item()
        confidence = torch.softmax(output, dim=1).max().item()

    return {
        'prediction': pred,
        'confidence': confidence,
        'num_cells': len(features),
        'graph_nodes': graph.number_of_nodes(),
        'graph_edges': graph.number_of_edges()
    }

print("✓ Inference function ready")
print("  Usage: result = predict_localization(image_path, model)")

## 10. Summary and Next Steps

In [None]:
# Final recap: implemented components, artifact locations, and follow-up work.
_rule = "=" * 60
_component_lines = [
    "  ✓ Preprocessing pipeline with Cellpose segmentation",
    "  ✓ Feature extraction (spatial, morphological, intensity)",
    "  ✓ Graph construction with PyTorch Geometric",
    "  ✓ Graph-CNN and VGG-16 models",
    "  ✓ Combined CNN + Graph-CNN architecture",
    "  ✓ Training and evaluation framework",
    "  ✓ Comprehensive visualization suite",
    "  ✓ Web interface with Gradio",
]
_next_step_lines = [
    "  1. Collect and label training data",
    "  2. Train models on labeled dataset",
    "  3. Fine-tune hyperparameters",
    "  4. Deploy web interface",
    "  5. Generate predictions on new data",
]

print(_rule)
print("PROTEIN SUB-CELLULAR LOCALIZATION PIPELINE COMPLETE")
print(_rule)

print("\nImplemented Components:")
for _line in _component_lines:
    print(_line)

print("\nOutput Directories:")
print(f"  - Models: {MODELS_DIR}")
print(f"  - Visualizations: {VIZ_DIR}")

print("\nNext Steps:")
for _line in _next_step_lines:
    print(_line)