# Protein Sub-Cellular Localization Pipeline
## Complete End-to-End Pipeline for 4D TIFF Image Analysis

This notebook demonstrates the complete pipeline for:
1. Loading and preprocessing 4D TIFF images
2. Cellpose segmentation of neuronal structures
3. Feature extraction and graph construction
4. Graph-CNN model training and evaluation
5. Visualization of results
6. Final prediction on sample images

**Author:** Protein Localization Analysis Team  
**Date:** 2024  
**Environment:** Ubuntu + JupyterLab

## 1. Setup and Imports

In [None]:
# Standard library imports
import os
import sys
from pathlib import Path
import pickle
import json
import warnings
warnings.filterwarnings('ignore')

# Scientific computing
import numpy as np
import pandas as pd
from scipy import ndimage
from sklearn.model_selection import train_test_split

# Image processing
import tifffile
import cv2
from skimage import measure, morphology

# Deep learning
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # replaces the deprecated torch_geometric.data.DataLoader
from cellpose import models

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx

# Add src to path
sys.path.append('../src')

from preprocessing.preprocess import PreprocessingPipeline, TIFFProcessor, CellposeSegmenter, FeatureExtractor
from graph.graph_builder import GraphBuilder, GraphDataset
from models.train import GraphCNN, ModelTrainer, prepare_data_loaders
from visualization.plots import VisualizationSuite

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Configuration and Paths

In [None]:
# Define paths (adjust these to your environment)
INPUT_DIR = "/mnt/d/5TH_SEM/CELLULAR/input"
OUTPUT_DIR = "/mnt/d/5TH_SEM/CELLULAR/output"
MODELS_DIR = f"{OUTPUT_DIR}/models"

# Create output directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)

# Model parameters
NUM_CLASSES = 6  # soma, dendrite, axon, nucleus, synapse, mitochondria
CLASS_NAMES = ['Soma', 'Dendrite', 'Axon', 'Nucleus', 'Synapse', 'Mitochondria']
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 0.001

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

print(f"\nConfiguration:")
print(f"  Input Directory: {INPUT_DIR}")
print(f"  Output Directory: {OUTPUT_DIR}")
print(f"  Models Directory: {MODELS_DIR}")
print(f"  Number of Classes: {NUM_CLASSES}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Training Epochs: {NUM_EPOCHS}")
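
Several later cells draw random numbers (synthetic images, placeholder labels, demo plots), so each rerun gives different results. A minimal sketch of a seeding helper, assuming reproducible reruns are wanted; `set_seed` and `SEED` are hypothetical names, not part of the project's `src` modules. GPU kernels can still be nondeterministic even with fixed seeds.

```python
import random

import numpy as np

SEED = 42  # hypothetical default; any fixed integer works


def set_seed(seed: int = SEED) -> None:
    """Seed Python's random module, NumPy's global RNG and, if installed, PyTorch."""
    random.seed(seed)
    np.random.seed(seed)  # covers the np.random.rand / randint / normal calls in this notebook
    try:
        import torch
    except ImportError:
        return  # PyTorch not installed: nothing more to seed
    torch.manual_seed(seed)  # seeds the generators for CPU and CUDA devices


set_seed()
```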

## 3. Data Loading and Preprocessing

### 3.1 Scan for TIFF Files

In [None]:
# Initialize preprocessing pipeline
preprocessing_pipeline = PreprocessingPipeline(input_dir=INPUT_DIR, output_dir=OUTPUT_DIR)

# Scan for TIFF files
tiff_files = preprocessing_pipeline.processor.scan_directories()

print(f"\nFound {len(tiff_files)} TIFF files:")
for i, file in enumerate(tiff_files[:5], 1):
    print(f"  {i}. {file.name}")
if len(tiff_files) > 5:
    print(f"  ... and {len(tiff_files) - 5} more")
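
`scan_directories` belongs to the project's `TIFFProcessor`, and its exact behavior is not shown here. As a hedged sketch of what such a scan typically does, the hypothetical `find_tiff_files` below walks the input tree recursively with `pathlib`, matches `.tif`/`.tiff` case-insensitively, and sorts the result so that the file order (and therefore `tiff_files[0]`) is stable across runs.

```python
from pathlib import Path
from typing import List


def find_tiff_files(root: str) -> List[Path]:
    """Recursively collect .tif/.tiff files under `root`, sorted by path (hypothetical helper)."""
    suffixes = {".tif", ".tiff"}
    return sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in suffixes  # case-insensitive extension match
    )
```

For example, `find_tiff_files(INPUT_DIR)` would return every TIFF below the input directory, including files in subfolders.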

### 3.2 Process Sample Image

Let's process a single image to understand the pipeline:

In [None]:
# Load a sample TIFF file
if len(tiff_files) > 0:
    sample_file = tiff_files[0]
    print(f"Processing sample file: {sample_file.name}")
    
    # Load image
    sample_img = preprocessing_pipeline.processor.load_tiff(sample_file)
    
    print(f"\nImage properties:")
    print(f"  Shape: {sample_img.shape}")
    print(f"  Dtype: {sample_img.dtype}")
    print(f"  Min value: {sample_img.min()}")
    print(f"  Max value: {sample_img.max()}")
    print(f"  Mean value: {sample_img.mean():.2f}")
else:
    print("No TIFF files found. Creating synthetic data for demonstration...")
    # Create synthetic data
    sample_img = np.random.rand(10, 512, 512) * 255
    sample_img = sample_img.astype(np.uint8)

### 3.3 Image Segmentation with Cellpose

In [None]:
# Normalize and segment
img_norm = preprocessing_pipeline.processor.normalize_image(sample_img)

print("Running Cellpose segmentation...")
masks, seg_metadata = preprocessing_pipeline.segmenter.segment(img_norm)

print(f"\nSegmentation results:")
print(f"  Number of cells detected: {seg_metadata['num_cells']}")
print(f"  Mask shape: {masks.shape}")
print(f"  Unique regions: {len(np.unique(masks)) - 1}")
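
`normalize_image` comes from the project code and its scheme is not shown. Cellpose itself rescales inputs so that the 1st and 99th intensity percentiles map to 0 and 1, which makes segmentation robust to a few very bright or very dark pixels. The hypothetical `percentile_normalize` below sketches that style of normalization in NumPy, under the assumption that the project does something similar.

```python
import numpy as np


def percentile_normalize(img, lower=1.0, upper=99.0, clip=False):
    """Map the `lower`/`upper` intensity percentiles to 0 and 1 (hypothetical helper)."""
    img = np.asarray(img, dtype=np.float32)
    lo, hi = np.percentile(img, [lower, upper])
    scaled = (img - lo) / max(hi - lo, 1e-8)  # guard against division by zero on flat images
    return np.clip(scaled, 0.0, 1.0) if clip else scaled  # optionally clamp outliers to [0, 1]
```

Leaving `clip=False` keeps outlier pixels slightly below 0 or above 1, which preserves their relative brightness.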

### 3.4 Feature Extraction

In [None]:
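# Prepare a 2D image for feature extraction.
# Assumed axis orders: 4D = (T, Z, Y, X); 3D = (Z, Y, X) when the last axis is longer than 3,
# otherwise (Y, X, C), a color image that is passed through unchanged.
if sample_img.ndim == 4:
    img_2d = np.max(sample_img, axis=(0, 1))  # max projection over time and depth
elif sample_img.ndim == 3:
    img_2d = np.max(sample_img, axis=0) if sample_img.shape[-1] > 3 else sample_img
else:
    img_2d = sample_img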

# Extract features
region_features = preprocessing_pipeline.feature_extractor.extract_region_properties(img_2d, masks)
spatial_features = preprocessing_pipeline.feature_extractor.extract_spatial_features(masks)

print(f"\nExtracted features for {len(region_features)} regions")
print(f"\nSample region features:")
if len(region_features) > 0:
    sample_features = region_features[0]
    for key, value in list(sample_features.items())[:8]:
        print(f"  {key}: {value}")

### 3.5 Process All Images

In [None]:
# Process all images (this may take time)
print("Processing all TIFF files...")
print("Note: This step may take several minutes depending on the number and size of images.\n")

if len(tiff_files) > 0:
    results = preprocessing_pipeline.process_all()
    print(f"\nSuccessfully preprocessed {len(results)} images")
else:
    print("Creating demo preprocessed data...")
    # Create demo data: every sample reuses the single segmentation from Section 3.3
    results = []
    for i in range(10):
        demo_result = {
            'filename': f'demo_{i}.tif',
            'masks': masks,
            'region_features': region_features,
            'spatial_features': spatial_features,
            'num_regions': len(region_features)
        }
        results.append(demo_result)
    
    # Save demo data
    with open(f"{OUTPUT_DIR}/preprocessed_data.pkl", 'wb') as f:
        pickle.dump(results, f)
    print(f"Created {len(results)} demo samples")

## 4. Graph Construction

### 4.1 Build Biological Graphs

In [None]:
# Initialize graph dataset
print("Building graphs from preprocessed data...")
graph_dataset = GraphDataset(f"{OUTPUT_DIR}/preprocessed_data.pkl")
graph_dataset.load_and_build_graphs()

# Save graphs
graph_dataset.save_graphs(f"{OUTPUT_DIR}/graphs.pkl")

print(f"\nGraph dataset statistics:")
print(f"  Total graphs: {len(graph_dataset)}")

if len(graph_dataset) > 0:
    sample_graph, sample_label = graph_dataset.get_graph(0)
    print(f"\nSample graph properties:")
    print(f"  Number of nodes: {sample_graph.num_nodes}")
    print(f"  Number of edges: {sample_graph.edge_index.shape[1]}")
    print(f"  Node feature dimension: {sample_graph.x.shape[1]}")
    print(f"  Label: {sample_label}")
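
How `GraphBuilder` connects regions is defined in `src/graph/graph_builder.py` and is not shown here. One common choice for spatial graphs is to connect every pair of region centroids closer than a distance cutoff. The hypothetical `radius_edge_index` below builds such edges in the `[2, num_edges]` COO layout that PyTorch Geometric uses for `edge_index`, listing each undirected edge once in each direction; the cutoff is an assumption to tune. It builds a dense N x N distance matrix, which is fine for a few thousand regions.

```python
import numpy as np


def radius_edge_index(centroids: np.ndarray, cutoff: float) -> np.ndarray:
    """Connect centroid pairs closer than `cutoff`; returns a (2, E) int64 array."""
    centroids = np.asarray(centroids, dtype=float)
    diff = centroids[:, None, :] - centroids[None, :, :]  # pairwise offsets, shape (N, N, D)
    dist = np.linalg.norm(diff, axis=-1)                  # pairwise distances, shape (N, N)
    adj = (dist < cutoff) & ~np.eye(len(centroids), dtype=bool)  # exclude self-loops
    src, dst = np.nonzero(adj)  # adjacency is symmetric, so both directions are included
    return np.stack([src, dst]).astype(np.int64)
```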

### 4.2 Visualize Sample Graph

In [None]:
# Visualize a sample graph
if len(graph_dataset) > 0:
    # Build NetworkX graph for visualization
    sample_result = results[0]
    G = graph_dataset.graph_builder.build_graph_from_features(sample_result)
    
    viz_suite = VisualizationSuite(output_dir=OUTPUT_DIR)
    viz_suite.plot_graph_visualization(G, title="Sample Biological Graph")
    
    print("Graph visualization saved!")
    print(f"  Nodes: {G.number_of_nodes()}")
    print(f"  Edges: {G.number_of_edges()}")
    print(f"  Density: {nx.density(G):.4f}")
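
For an undirected simple graph, `nx.density` is the fraction of possible node pairs that are connected, 2E / (N(N - 1)); for a directed graph it is E / (N(N - 1)). A quick check of the undirected formula without NetworkX follows. Note that if the PyTorch Geometric graph stores each edge in both directions, `sample_graph.edge_index.shape[1]` in Section 4.1 is twice the undirected edge count.

```python
def undirected_density(num_nodes: int, num_edges: int) -> float:
    """Edge density of an undirected simple graph: 2E / (N(N - 1))."""
    if num_nodes < 2:
        return 0.0  # matches nx.density for graphs with 0 or 1 nodes
    return 2.0 * num_edges / (num_nodes * (num_nodes - 1))
```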

## 5. Model Training

### 5.1 Prepare Data Loaders

In [None]:
# Prepare train and test loaders
print("Preparing data loaders...")

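# Placeholder labels: random class indices for demonstration only. Real labels must come from
# annotations; with random labels, the metrics in Section 6 can only reflect chance performance.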
with open(f"{OUTPUT_DIR}/graphs.pkl", 'rb') as f:
    graph_data = pickle.load(f)

graphs = graph_data['graphs']
# Create random labels for demonstration
labels = np.random.randint(0, NUM_CLASSES, len(graphs))

# Add labels to graphs
for graph, label in zip(graphs, labels):
    graph.y = torch.tensor([label], dtype=torch.long)

# Update saved data
graph_data['labels'] = labels.tolist()
with open(f"{OUTPUT_DIR}/graphs.pkl", 'wb') as f:
    pickle.dump(graph_data, f)

# Create data loaders
train_loader, test_loader = prepare_data_loaders(
    f"{OUTPUT_DIR}/graphs.pkl",
    batch_size=BATCH_SIZE,
    test_size=0.2
)

print(f"\nData split:")
print(f"  Training batches: {len(train_loader)}")
print(f"  Testing batches: {len(test_loader)}")
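
Because the labels above are drawn uniformly at random and independently of the graphs, no model can beat chance on held-out data: expected test accuracy is 1/NUM_CLASSES, about 16.7% for six classes. Training accuracy may rise above that through memorization, but validation accuracy should not. The short simulation below (plain NumPy, independent of the project code) confirms the chance level and gives a yardstick for the metrics in Section 6.

```python
import numpy as np

n_classes = 6  # same value as NUM_CLASSES
rng = np.random.default_rng(0)

# Labels drawn at random, as in the cell above, and predictions that are independent of them
sim_labels = rng.integers(0, n_classes, size=100_000)
sim_predictions = rng.integers(0, n_classes, size=sim_labels.size)

chance_accuracy = (sim_predictions == sim_labels).mean()
print(f"Empirical accuracy: {chance_accuracy:.4f} (expected {1 / n_classes:.4f})")
```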

### 5.2 Initialize Model

In [None]:
# Get feature dimension from first batch
sample_batch = next(iter(train_loader))
num_features = sample_batch.x.shape[1]

print(f"Model configuration:")
print(f"  Input features: {num_features}")
print(f"  Output classes: {NUM_CLASSES}")
print(f"  Hidden channels: 64")
print(f"  Number of layers: 3")

# Initialize Graph-CNN model
model = GraphCNN(
    num_features=num_features,
    num_classes=NUM_CLASSES,
    hidden_channels=64,
    num_layers=3
)

print(f"\nModel architecture:")
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
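
`GraphCNN` is defined in `src/models/train.py` and its layer type is not shown. If it is built from graph convolutions in the style of Kipf & Welling (2017), cited in Section 10, each layer computes H' = act(D^(-1/2) (A + I) D^(-1/2) H W), where A + I is the adjacency with self-loops and D its degree matrix; "hidden channels" is the output width of W. The dense NumPy sketch below implements one such layer from an `edge_index` purely for illustration; it is not the project's layer, and real code would use a sparse implementation such as PyG's `GCNConv`.

```python
import numpy as np


def gcn_layer(x, edge_index, weight, activation=lambda h: np.maximum(h, 0.0)):
    """One dense GCN layer: act(D^-1/2 (A + I) D^-1/2 X W). Assumes no self-loops in edge_index."""
    num_nodes = x.shape[0]
    adj = np.zeros((num_nodes, num_nodes))
    adj[edge_index[0], edge_index[1]] = 1.0   # edges as listed (both directions expected)
    adj_hat = adj + np.eye(num_nodes)         # add self-loops
    deg = adj_hat.sum(axis=1)                 # degrees of A + I, always >= 1
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    norm_adj = d_inv_sqrt[:, None] * adj_hat * d_inv_sqrt[None, :]  # symmetric normalization
    return activation(norm_adj @ x @ weight)  # aggregate neighbors, then project with W
```

With `hidden_channels=64`, a first layer of this kind would use a `(num_features, 64)` weight matrix.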

### 5.3 Train Model

In [None]:
# Initialize trainer
trainer = ModelTrainer(model, device=device)

print(f"Starting training for {NUM_EPOCHS} epochs...\n")

# Train the model
metrics = trainer.train(
    train_loader=train_loader,
    val_loader=test_loader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    num_classes=NUM_CLASSES
)

print("\n" + "="*50)
print("Training Complete!")
print("="*50)

### 5.4 Save Model and Metrics

In [None]:
# Save trained model
model_path = f"{MODELS_DIR}/graph_cnn.pth"
trainer.save_model(model_path)

# Save metrics
metrics_path = f"{MODELS_DIR}/metrics.json"
with open(metrics_path, 'w') as f:
    # Convert numpy arrays to lists for JSON
    metrics_json = {
        k: v.tolist() if isinstance(v, np.ndarray) else v
        for k, v in metrics.items()
    }
    json.dump(metrics_json, f, indent=2)

print(f"Model saved to: {model_path}")
print(f"Metrics saved to: {metrics_path}")
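
The conversion above only handles `np.ndarray` values. NumPy scalars such as `np.float32` or `np.int64` are not JSON-serializable, so if any metric comes back as one, `json.dump` raises `TypeError`. A hedged alternative is to pass a `default` hook that converts any NumPy object; `numpy_json_default` is a hypothetical helper name.

```python
import json

import numpy as np


def numpy_json_default(obj):
    """`default` hook for json.dump: convert NumPy arrays and scalars to Python types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # np.float32, np.int64, np.bool_, ...
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Usage with the metrics dict from the training cell:
# json.dump(metrics, f, indent=2, default=numpy_json_default)
```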

## 6. Model Evaluation

### 6.1 Display Evaluation Metrics

In [None]:
print("\n" + "="*50)
print("EVALUATION METRICS")
print("="*50)

print(f"\nOverall Performance:")
print(f"  Accuracy:    {metrics['accuracy']:.4f}")
print(f"  Precision:   {metrics['precision']:.4f}")
print(f"  Recall:      {metrics['recall']:.4f}")
print(f"  F1-Score:    {metrics['f1_score']:.4f}")
print(f"  Specificity: {metrics['specificity']:.4f}")

print(f"\nPer-Class Specificity:")
for class_name, spec in zip(CLASS_NAMES, metrics['specificity_per_class']):
    print(f"  {class_name:15s}: {spec:.4f}")

# Create metrics summary dataframe
metrics_df = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'Specificity'],
    'Value': [
        metrics['accuracy'],
        metrics['precision'],
        metrics['recall'],
        metrics['f1_score'],
        metrics['specificity']
    ]
})

print("\n" + metrics_df.to_string(index=False))
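
Specificity is not part of scikit-learn's `classification_report`, so it helps to see how it follows from the confusion matrix. For class k, with rows as true labels and columns as predictions (scikit-learn's convention), FP is column k minus the diagonal entry, FN is row k minus the diagonal entry, and TN is everything else; specificity is TN / (TN + FP). The sketch below computes per-class values; whether `ModelTrainer` reports their unweighted (macro) mean as the overall specificity is an assumption.

```python
import numpy as np


def per_class_specificity(cm) -> np.ndarray:
    """Specificity TN / (TN + FP) per class; rows = true labels, columns = predictions."""
    cm = np.asarray(cm, dtype=float)
    total = cm.sum()
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as k but truly another class
    fn = cm.sum(axis=1) - tp  # truly k but predicted as another class
    tn = total - tp - fp - fn
    denom = tn + fp
    return np.divide(tn, denom, out=np.zeros_like(tn), where=denom > 0)  # 0 where undefined


cm_example = np.array([[5, 1, 0],
                       [2, 3, 1],
                       [0, 0, 4]])
spec_example = per_class_specificity(cm_example)
print(spec_example, spec_example.mean())  # per-class values and their macro average
```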

### 6.2 Confusion Matrix

In [None]:
# Plot confusion matrix
cm = np.array(metrics['confusion_matrix'])

viz_suite = VisualizationSuite(output_dir=OUTPUT_DIR)
viz_suite.plot_confusion_matrix(cm, CLASS_NAMES, save_name="confusion_matrix.png")

print("Confusion Matrix:")
print(cm)
print("\nConfusion matrix visualization saved!")
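
Raw counts are hard to compare when classes have different sizes. Dividing each row by its total gives the fraction of each true class assigned to each predicted class, so the diagonal holds per-class recall (scikit-learn's `confusion_matrix(..., normalize='true')` produces the same result). A small NumPy sketch; `row_normalize` is a hypothetical helper, and rows without samples are left as zeros instead of dividing by zero.

```python
import numpy as np


def row_normalize(cm) -> np.ndarray:
    """Divide each row of a confusion matrix by its sum; empty rows stay zero."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
```

For example, `np.diag(row_normalize(cm))` gives the recall of each class in `CLASS_NAMES` order.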

### 6.3 Training History

In [None]:
# Plot training history
viz_suite.plot_training_history(trainer.history, save_name="training_history.png")

print("Training history visualization saved!")
print(f"\nFinal training accuracy: {trainer.history['train_acc'][-1]:.4f}")
print(f"Final validation accuracy: {trainer.history['val_acc'][-1]:.4f}")

## 7. Visualization Suite

### 7.1 Image Overlay

In [None]:
# Generate image overlay visualization
if len(results) > 0:
    sample_result = results[0]
    
    # Get original image
    if len(tiff_files) > 0:
        original_img = tifffile.imread(str(tiff_files[0]))
        if original_img.ndim == 4:
            original_img = np.max(original_img, axis=(0, 1))
        elif original_img.ndim == 3 and original_img.shape[-1] > 3:
            original_img = np.max(original_img, axis=0)
    else:
        original_img = sample_img if sample_img.ndim == 2 else sample_img[0]
    
    viz_suite.plot_image_overlay(
        original_img,
        sample_result['masks'],
        title="Segmentation Overlay",
        save_name="sample_overlay.png"
    )
    
    print("Image overlay visualization saved!")
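
Segmentation overlays usually draw each region's outline on top of the image. How `plot_image_overlay` does this is not shown; a NumPy-only way to get outlines from an integer label mask (0 = background, as Cellpose produces) is to mark every labelled pixel whose 4-connected neighbor carries a different label. `label_boundaries` is a hypothetical helper; `skimage.segmentation.find_boundaries` offers a fuller version.

```python
import numpy as np


def label_boundaries(masks: np.ndarray) -> np.ndarray:
    """Boolean map of labelled pixels that touch a different label (4-connectivity)."""
    padded = np.pad(masks, 1, mode="edge")  # replicate edges so the image border is not an outline
    center = padded[1:-1, 1:-1]
    differs = np.zeros(masks.shape, dtype=bool)
    for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        # neighbor shifted by (dy, dx), aligned with `center`
        neighbor = padded[1 + dy: padded.shape[0] - 1 + dy, 1 + dx: padded.shape[1] - 1 + dx]
        differs |= neighbor != center
    return differs & (masks > 0)  # keep only pixels inside a region
```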

### 7.2 Compartment Distribution

In [None]:
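# Synthetic placeholder data (random normal draws per compartment), not measurements;
# replace with per-compartment intensities measured from segmented images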
compartment_data = {
    'Soma': np.random.normal(100, 20, 30).tolist(),
    'Dendrite': np.random.normal(80, 15, 35).tolist(),
    'Axon': np.random.normal(60, 10, 25).tolist(),
    'Nucleus': np.random.normal(120, 25, 20).tolist(),
    'Synapse': np.random.normal(40, 8, 40).tolist(),
    'Mitochondria': np.random.normal(70, 12, 30).tolist()
}

# Grouped bar plot
viz_suite.plot_grouped_bar_with_points(
    compartment_data,
    ylabel="Mean Intensity (a.u.)",
    title="Protein Intensity by Compartment",
    save_name="compartment_intensity.png"
)

# Box and violin plot
viz_suite.plot_box_violin(
    compartment_data,
    ylabel="Intensity Distribution",
    title="Compartment Intensity Distribution",
    save_name="intensity_distribution.png"
)

print("Compartment distribution visualizations saved!")
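
The values above are random draws, not measurements. With a real image and its Cellpose label mask, the mean intensity of every segmented region can be computed in one pass with `np.bincount`. Assigning regions to compartments (soma, dendrite, ...) needs a classification step that this notebook does not provide, so the hypothetical `region_mean_intensity` below stops at per-region values that could then be grouped by compartment.

```python
import numpy as np


def region_mean_intensity(image: np.ndarray, masks: np.ndarray) -> dict:
    """Mean image intensity per labelled region; label 0 (background) is skipped."""
    if image.shape != masks.shape:
        raise ValueError(f"Shape mismatch: image {image.shape} vs masks {masks.shape}")
    region_ids = masks.ravel().astype(np.int64)
    values = image.ravel().astype(np.float64)
    sums = np.bincount(region_ids, weights=values)  # total intensity per label
    counts = np.bincount(region_ids)                # pixel count per label
    return {int(k): sums[k] / counts[k] for k in np.nonzero(counts)[0] if k != 0}
```

Assuming `masks` is 2D and matches `img_2d`, `region_mean_intensity(img_2d, masks)` would return one mean per detected region.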

### 7.3 Colocalization Analysis

In [None]:
# Generate synthetic dual-channel data
channel_a = np.random.rand(512, 512) * 255
channel_b = channel_a * 0.7 + np.random.rand(512, 512) * 76.5  # correlated with A; 0.7 * 255 + 76.5 = 255 keeps values below 255

# Colocalization scatter plot
viz_suite.plot_colocalization_scatter(
    channel_a,
    channel_b,
    title="Channel A vs Channel B Colocalization",
    save_name="colocalization.png"
)

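# Colocalization metrics: only Pearson is computed here;
# the Manders and overlap values are hardcoded placeholders, not measurements
coloc_metrics = {
    'Pearson': np.corrcoef(channel_a.flatten(), channel_b.flatten())[0, 1],
    'Manders M1': 0.72,  # placeholder
    'Manders M2': 0.68,  # placeholder
    'Overlap': 0.75      # placeholder
}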

viz_suite.plot_colocalization_metrics(
    coloc_metrics,
    save_name="coloc_metrics.png"
)

print("Colocalization visualizations saved!")
print(f"\nColocalization metrics:")
for key, value in coloc_metrics.items():
    print(f"  {key}: {value:.3f}")
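
In the cell above only the Pearson value is computed; the Manders and overlap entries are fixed numbers. The sketch below computes them from the two channels with the standard definitions: M1 is the fraction of channel-A intensity in pixels where channel B is above its threshold, M2 the converse, and Manders' overlap coefficient is sum(A * B) / sqrt(sum(A^2) * sum(B^2)). The thresholds default to 0 here; choosing them (manually or with an automatic method such as Costes') is left as an assumption. `manders_coefficients` is a hypothetical helper.

```python
import numpy as np


def manders_coefficients(a, b, thr_a=0.0, thr_b=0.0):
    """Manders M1, M2 and overlap coefficient for two same-shaped, non-empty channels."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    m1 = a[b > thr_b].sum() / a.sum()  # share of A intensity where B is present
    m2 = b[a > thr_a].sum() / b.sum()  # share of B intensity where A is present
    overlap = (a * b).sum() / np.sqrt((a ** 2).sum() * (b ** 2).sum())
    return {"Manders M1": m1, "Manders M2": m2, "Overlap": overlap}
```

With zero thresholds and the noise-filled synthetic channels above, M1 and M2 come out close to 1, which is why sensible background thresholds matter for real data.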

### 7.4 Intensity Profile

In [None]:
# Generate a synthetic intensity profile: exponential decay (amplitude 100, decay length 50) plus Gaussian noise (SD 5)
distances = np.linspace(0, 200, 100)
intensities = 100 * np.exp(-distances / 50) + np.random.normal(0, 5, 100)

viz_suite.plot_intensity_profile(
    distances,
    intensities,
    title="Protein Intensity vs Distance from Soma",
    save_name="intensity_profile.png"
)

print("Intensity profile visualization saved!")
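
The synthetic profile follows I(d) = I0 * exp(-d / L) plus noise, with I0 = 100 and decay length L = 50. A simple way to estimate L from a measured profile is a straight-line fit to log-intensity, since log I = log I0 - d / L. The sketch below uses `np.polyfit`; it drops non-positive points (the log is undefined there) and weights all points equally, so for noisy tails a nonlinear fit such as `scipy.optimize.curve_fit` may be preferable. `fit_exponential_decay` is a hypothetical helper.

```python
import numpy as np


def fit_exponential_decay(distances, intensities):
    """Estimate (I0, decay_length) for I(d) = I0 * exp(-d / decay_length) via a log-linear fit."""
    d = np.asarray(distances, dtype=np.float64)
    y = np.asarray(intensities, dtype=np.float64)
    keep = y > 0  # log is undefined for non-positive values
    slope, intercept = np.polyfit(d[keep], np.log(y[keep]), deg=1)  # highest degree first
    return np.exp(intercept), -1.0 / slope
```

On the noisy synthetic data above, the noise (SD 5) exceeds the signal near d = 200, so many tail points are dropped and the estimate is biased; a positive slope would also yield a meaningless negative decay length.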

## 8. Final Prediction Demo

### 8.1 Load Trained Model

In [None]:
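# Rebuild the model with the same hyperparameters used in training (Section 5.2) so the saved weights fit
inference_model = GraphCNN(
    num_features=num_features,
    num_classes=NUM_CLASSES,
    hidden_channels=64,
    num_layers=3
)
inference_trainer = ModelTrainer(inference_model, device=device)
inference_trainer.load_model(f"{MODELS_DIR}/graph_cnn.pth")

inference_model.to(device)  # keep the model on the same device as the test batch
inference_model.eval()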
print("Model loaded for inference!")

### 8.2 Predict on Sample TIFF

In [None]:
# Select a test sample
test_sample = next(iter(test_loader))
test_sample = test_sample.to(device)

# Make prediction
with torch.no_grad():
    output = inference_model(test_sample.x, test_sample.edge_index, test_sample.batch)
    probabilities = torch.softmax(output, dim=1)
    predictions = output.argmax(dim=1)

# Display results
print("\n" + "="*50)
print("PREDICTION RESULTS")
print("="*50)

for i, (pred, prob) in enumerate(zip(predictions, probabilities)):
    predicted_class = CLASS_NAMES[pred.item()]
    confidence = prob[pred].item()
    
    print(f"\nSample {i+1}:")
    print(f"  Predicted Class: {predicted_class}")
    print(f"  Confidence: {confidence:.2%}")
    print(f"\n  Class Probabilities:")
    for class_name, p in zip(CLASS_NAMES, prob):
        print(f"    {class_name:15s}: {p.item():.2%}")
    
    if i >= 2:  # Show first 3 samples
        break

## 9. Summary and Next Steps

### 9.1 Pipeline Summary

In [None]:
print("\n" + "="*60)
print("PIPELINE SUMMARY")
print("="*60)

summary = f"""
✅ Data Processing:
   - Processed {len(results)} TIFF images
   - Segmented {sum(r['num_regions'] for r in results)} total regions
   - Extracted morphological and intensity features

✅ Graph Construction:
   - Built {len(graph_dataset)} biological graphs
   - Average nodes per graph: {np.mean([g.num_nodes for g in graphs]):.1f}
   - Average edges per graph: {np.mean([g.edge_index.shape[1] for g in graphs]):.1f}

✅ Model Training:
   - Trained Graph-CNN with {total_params:,} parameters
   - Final accuracy: {metrics['accuracy']:.2%}
   - F1-Score: {metrics['f1_score']:.4f}

✅ Outputs Saved:
   - Preprocessed data: {OUTPUT_DIR}/preprocessed_data.pkl
   - Graph data: {OUTPUT_DIR}/graphs.pkl
   - Trained model: {MODELS_DIR}/graph_cnn.pth
   - Metrics: {MODELS_DIR}/metrics.json
   - Visualizations: {OUTPUT_DIR}/*.png

📊 Performance Metrics:
   - Accuracy:    {metrics['accuracy']:.4f}
   - Precision:   {metrics['precision']:.4f}
   - Recall:      {metrics['recall']:.4f}
   - F1-Score:    {metrics['f1_score']:.4f}
   - Specificity: {metrics['specificity']:.4f}
"""

print(summary)

print("="*60)
print("Pipeline execution completed successfully!")
print("="*60)

### 9.2 Next Steps and Recommendations

1. **Improve Model Performance:**
   - Collect more annotated training data
   - Try Graph Attention Networks (GAT)
   - Implement data augmentation
   - Tune hyperparameters

2. **Advanced Analysis:**
   - Implement temporal analysis for 4D data
   - Add multi-channel colocalization
   - Develop hierarchical compartment models

3. **Deployment:**
   - Use the Flask web interface: `python src/frontend/app.py`
   - Deploy on a cloud platform
   - Create a batch processing pipeline

4. **Documentation:**
   - Add more inline comments
   - Create user guide
   - Document model architecture choices

## 10. References and Citations

- **Cellpose:** Stringer, C., Wang, T., Michaelos, M., & Pachitariu, M. (2021). "Cellpose: a generalist algorithm for cellular segmentation." *Nature Methods*, 18, 100–106.
- **PyTorch Geometric:** Fey, M., & Lenssen, J. E. (2019). "Fast Graph Representation Learning with PyTorch Geometric." ICLR Workshop on Representation Learning on Graphs and Manifolds.
- **Graph Convolutional Networks:** Kipf, T. N., & Welling, M. (2017). "Semi-Supervised Classification with Graph Convolutional Networks." ICLR.

---

**End of Notebook**