# Protein Sub-Cellular Localization in Neurons
## Complete Pipeline for TIFF Image Analysis, Graph Construction, and Classification

This notebook provides a complete, executable pipeline for:
1. **Preprocessing**: TIFF loading, Cellpose segmentation, feature extraction
2. **Graph Construction**: Building biological graphs for GNN analysis
3. **Model Training**: Graph-CNN (the model trained in this notebook); the VGG-16 and hybrid models are not exercised here
4. **Visualization**: Publication-ready scientific figures
5. **Prediction**: End-to-end inference on new samples

## 1. Setup and Imports

In [None]:
# Install required packages
!pip install -q torch torchvision torch-geometric networkx scikit-learn scikit-image
!pip install -q matplotlib seaborn pandas numpy tifffile scipy
!pip install -q cellpose flask

print("✓ All packages installed successfully")

In [None]:
# Import all required libraries
import sys
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import warnings

# Silence all warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Add src directory to path. The entry is relative, so it is interpreted
# against the kernel's working directory. The membership check keeps a re-run
# of this cell from appending the same entry a second time.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# Import custom modules
from preprocessing import TIFFPreprocessor, preprocess_pipeline
from graph_builder import BiologicalGraphBuilder, build_graphs_pipeline
from models import ModelTrainer, train_model_pipeline
from visualization import ProteinVisualization, create_visualizations

print("✓ All imports successful")
print(f"Using PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Configuration and Paths

In [None]:
# Define paths.
# The defaults are the original machine-specific locations; set
# PROTEIN_LOC_INPUT_DIR / PROTEIN_LOC_OUTPUT_DIR in the environment to run the
# notebook elsewhere without editing this cell. Both stay plain strings because
# later cells build sub-paths with string concatenation.
INPUT_DIR = os.environ.get("PROTEIN_LOC_INPUT_DIR", "/mnt/d/5TH_SEM/CELLULAR/input")
OUTPUT_DIR = os.environ.get("PROTEIN_LOC_OUTPUT_DIR", "/mnt/d/5TH_SEM/CELLULAR/output")

# Create the output root and the sub-directories the later cells write into.
output_root = Path(OUTPUT_DIR)
for out_dir in (
    output_root,
    output_root / "models",
    output_root / "visualizations",
    output_root / "graphs",
):
    out_dir.mkdir(parents=True, exist_ok=True)

# Configuration shared by the cells below.
# NOTE(review): 'cellpose_diameter' and 'learning_rate' are not read by any
# cell visible in this notebook (they only reach the exported summary) —
# confirm whether they should be forwarded to the preprocessor / trainer.
CONFIG = {
    'input_dir': INPUT_DIR,
    'output_dir': OUTPUT_DIR,
    'cellpose_diameter': 30.0,
    'distance_threshold': 50.0,
    'k_neighbors': 5,
    'num_classes': 5,
    'batch_size': 32,
    'epochs': 50,
    'learning_rate': 0.001
}

print("✓ Configuration set")
print(f"Input directory: {INPUT_DIR}")
print(f"Output directory: {OUTPUT_DIR}")

## 3. Data Preprocessing
### Load TIFF files, perform segmentation, and extract features

In [None]:
# Construct the preprocessor from the configured input and output directories
preprocessor = TIFFPreprocessor(INPUT_DIR, OUTPUT_DIR)

# Discover the TIFF inputs and report how many there are
tiff_files = preprocessor.scan_tiff_files()
n_tiffs = len(tiff_files)
print(f"\nFound {n_tiffs} TIFF files to process")

if n_tiffs == 0:
    # Nothing to process; the next cell falls back to synthetic samples
    print("⚠️  No TIFF files found. Creating synthetic data for demonstration...")
else:
    # Preview at most the first five file names, numbered from 1
    print("\nFirst few files:")
    for position, tiff_path in enumerate(tiff_files[:5], start=1):
        print(f"  {position}. {tiff_path.name}")

In [None]:
# Process all TIFF files (or create synthetic data if none found).
# The synthetic fallback draws from a seeded generator so that
# Restart Kernel -> Run All reproduces the same demonstration data.
SYNTHETIC_SEED = 42
N_SYNTHETIC_SAMPLES = 5
N_SYNTHETIC_CELLS = 10       # labels 1..N_SYNTHETIC_CELLS per sample
SYNTHETIC_SIZE = 256         # masks are SYNTHETIC_SIZE x SYNTHETIC_SIZE pixels
SYNTHETIC_MARGIN = 20        # disc centres stay this far from the border
SYNTHETIC_RADIUS_SQ = 100    # squared disc radius used for each cell mask


def _make_synthetic_result(index, rng):
    """Build one synthetic result dict with the keys the later cells read.

    Args:
        index: Sample number, used only to form the fake file name.
        rng: A ``numpy.random.Generator`` supplying every random draw.

    Returns:
        A dict with 'file_path', 'file_name', 'image_shape', 'masks',
        'features', 'segmentation_metadata' and 'n_regions'. The feature
        values are independent random draws; they are not measured from the
        mask array.
    """
    # Pixel coordinate grids are the same for every disc, so build them once.
    rows, cols = np.ogrid[:SYNTHETIC_SIZE, :SYNTHETIC_SIZE]

    # Label mask: one filled disc per cell. A later disc overwrites any
    # earlier disc it overlaps.
    masks = np.zeros((SYNTHETIC_SIZE, SYNTHETIC_SIZE), dtype=int)
    for label in range(1, N_SYNTHETIC_CELLS + 1):
        y, x = rng.integers(SYNTHETIC_MARGIN, SYNTHETIC_SIZE - SYNTHETIC_MARGIN, 2)
        inside_disc = ((rows - y) ** 2 + (cols - x) ** 2) < SYNTHETIC_RADIUS_SQ
        masks[inside_disc] = label

    # One feature record per label, each value drawn at random.
    features = []
    for label in range(1, N_SYNTHETIC_CELLS + 1):
        features.append({
            'label': label,
            'centroid_y': rng.random() * SYNTHETIC_SIZE,
            'centroid_x': rng.random() * SYNTHETIC_SIZE,
            'distance_from_center': rng.random() * 100,
            'area': rng.random() * 500 + 100,
            'perimeter': rng.random() * 100 + 50,
            'eccentricity': rng.random(),
            'solidity': rng.random() * 0.3 + 0.7,
            'mean_intensity': rng.random() * 200 + 50,
            'max_intensity': rng.random() * 255,
            'intensity_std': rng.random() * 30,
            'extent': rng.random(),
            'major_axis_length': rng.random() * 50,
            'minor_axis_length': rng.random() * 30,
            'orientation': rng.random() * np.pi,
            'bbox_min_row': 0, 'bbox_min_col': 0,
            'bbox_max_row': SYNTHETIC_SIZE, 'bbox_max_col': SYNTHETIC_SIZE
        })

    return {
        'file_path': f'synthetic_{index}.tif',
        'file_name': f'synthetic_{index}.tif',
        'image_shape': (SYNTHETIC_SIZE, SYNTHETIC_SIZE),
        'masks': masks,
        'features': features,
        'segmentation_metadata': {'n_cells': N_SYNTHETIC_CELLS, 'method': 'synthetic'},
        'n_regions': len(features)
    }


if len(tiff_files) > 0:
    processed_results = preprocessor.process_all_tiffs()
else:
    # Create synthetic data for demonstration
    print("Creating synthetic data for demonstration...")

    synthetic_rng = np.random.default_rng(SYNTHETIC_SEED)
    synthetic_results = [
        _make_synthetic_result(i, synthetic_rng)
        for i in range(N_SYNTHETIC_SAMPLES)
    ]

    processed_results = synthetic_results
    print(f"✓ Created {len(processed_results)} synthetic samples")

print(f"\n✓ Preprocessing complete: {len(processed_results)} files processed")

In [None]:
# Report shape, region count and segmentation method for up to three results
print("Preprocessing Summary:")
print("=" * 50)
for file_no, entry in enumerate(processed_results[:3], start=1):
    print(f"\nFile {file_no}: {entry['file_name']}")
    print(f"  Shape: {entry['image_shape']}")
    print(f"  Regions detected: {entry['n_regions']}")
    seg_method = entry['segmentation_metadata']['method']
    print(f"  Segmentation method: {seg_method}")

## 4. Graph Construction
### Build biological graphs from segmented regions

In [None]:
# Configure the graph builder from the shared CONFIG entries
builder_options = {
    option: CONFIG[option]
    for option in ('distance_threshold', 'k_neighbors')
}
graph_builder = BiologicalGraphBuilder(**builder_options)

# Build one graph per processed result; the graphs sub-directory of the
# output root is passed as the second argument
graphs_dir = f"{OUTPUT_DIR}/graphs"
graph_results = graph_builder.process_results(processed_results, graphs_dir)

print(f"\n✓ Graph construction complete: {len(graph_results)} graphs created")

In [None]:
# Print node count, edge count and density for up to three graphs
print("Graph Statistics:")
print("=" * 50)
for graph_no, info in enumerate(graph_results[:3], start=1):
    print(f"\nGraph {graph_no}: {info['file_name']}")
    print(f"  Nodes: {info['n_nodes']}")
    print(f"  Edges: {info['n_edges']}")

    # Density = edges / number of unordered node pairs, n * (n - 1) / 2.
    # Graphs with fewer than two nodes have no pairs, so report 0.
    # NOTE(review): assumes n_edges counts each node pair once — confirm
    # against the graph builder.
    if info['n_nodes'] > 1:
        density = info['n_edges'] / (info['n_nodes'] * (info['n_nodes'] - 1) / 2)
    else:
        density = 0
    print(f"  Density: {density:.4f}")

## 5. Model Training
### Train Graph-CNN for protein localization classification

In [None]:
# Collect the 'pyg_data' entry of every graph result, preserving order
pyg_data_list = []
for graph_info in graph_results:
    pyg_data_list.append(graph_info['pyg_data'])

n_graphs = len(pyg_data_list)
print(f"Prepared {n_graphs} graphs for training")
if n_graphs > 0:
    # Use the first graph as a representative example
    first_graph = pyg_data_list[0]
    print(f"Node features per graph: {first_graph.x.shape[1]}")
    print(f"Example graph: {first_graph.num_nodes} nodes, {first_graph.edge_index.shape[1]} edges")

In [None]:
# Train the Graph-CNN with hyper-parameters taken from CONFIG.
# NOTE(review): CONFIG['learning_rate'] is not forwarded to the pipeline here —
# confirm whether train_model_pipeline accepts it or relies on its own default.
training_options = {
    'model_type': 'graph_cnn',
    'num_classes': CONFIG['num_classes'],
    'epochs': CONFIG['epochs'],
    'batch_size': CONFIG['batch_size'],
}
training_results = train_model_pipeline(pyg_data_list, OUTPUT_DIR, **training_options)

print("\n✓ Model training complete")

In [None]:
# Show the evaluation metrics returned by the training pipeline
metrics = training_results['metrics']
print("\nModel Performance Metrics:")
print("=" * 50)

# (label, key) pairs for the metrics that are always printed; labels are
# left-justified to a common 13-character column
headline_metrics = (
    ("Accuracy:", 'accuracy'),
    ("Precision:", 'precision'),
    ("Recall:", 'recall'),
    ("F1-Score:", 'f1_score'),
)
for label, key in headline_metrics:
    print(f"{label:<13}{metrics[key]:.4f}")

# Specificity is printed only when the pipeline reported it
if 'specificity' in metrics:
    print(f"{'Specificity:':<13}{metrics['specificity']:.4f}")

print("\nConfusion Matrix:")
print(np.array(metrics['confusion_matrix']))

## 6. Visualization
### Create publication-ready scientific figures

In [None]:
# Visualization helper, given the visualizations sub-directory of the output root
viz = ProteinVisualization(OUTPUT_DIR + "/visualizations")

# Pair each of the first three processed results with the graph result at the
# same position and render three figures per pair
preview_pairs = zip(processed_results[:3], graph_results[:3])
for proc_result, graph_result in preview_pairs:
    # File name without its extension, used as the prefix of every output file
    base_name = Path(proc_result['file_name']).stem

    print(f"\nCreating visualizations for {base_name}...")

    # Graph figure
    viz.plot_graph(
        graph_result['networkx_graph'],
        filename=f"{base_name}_graph.png",
        show_labels=True
    )

    # Feature-distribution figure
    viz.plot_feature_distributions(
        proc_result['features'],
        filename=f"{base_name}_features.png"
    )

    # Intensity-profile figure
    viz.plot_intensity_profile(
        proc_result['features'],
        filename=f"{base_name}_intensity.png"
    )

print("\n✓ Visualizations created")

In [None]:
# Training-history figure from the history returned by the training pipeline
history = training_results['history']
viz.plot_training_history(history, filename="training_history.png")

# Confusion-matrix figure; classes are labelled "Class 0", "Class 1", ...
# with one label per row of the matrix
cm = np.array(metrics['confusion_matrix'])
class_names = []
for class_idx in range(len(cm)):
    class_names.append(f"Class {class_idx}")
viz.plot_confusion_matrix(cm, class_names, filename="confusion_matrix.png")

print("✓ Training visualizations created")

## 7. Prediction Demo
### Demonstrate end-to-end prediction on a sample

In [None]:
# Location where the saved Graph-CNN weights are expected
model_path = Path(OUTPUT_DIR, "models", "graph_cnn_model.pt")

if not model_path.exists():
    # No saved file: fall back to the trainer object kept from the training cell
    print("⚠️  Model not found")
    predictor = training_results['trainer']
else:
    # Rebuild a trainer and load the saved weights; the node-feature count is
    # read from the first prepared graph
    predictor = ModelTrainer(model_type='graph_cnn', num_classes=CONFIG['num_classes'])
    n_node_features = pyg_data_list[0].x.shape[1]
    predictor.load_model(str(model_path), num_node_features=n_node_features)
    print("✓ Model loaded successfully")

In [None]:
# Make prediction on first sample
if len(pyg_data_list) > 0:
    from torch_geometric.data import Batch

    sample = pyg_data_list[0]

    predictor.model.eval()
    with torch.no_grad():
        sample = sample.to(predictor.device)

        # Wrap the single graph in a one-element batch before the forward pass
        batch_sample = Batch.from_data_list([sample])

        output = predictor.model(batch_sample)
        pred = output.argmax(dim=1).item()
        # NOTE(review): exp() yields probabilities only if the model emits
        # log-probabilities (e.g. log_softmax) — confirm against the model code.
        probs = torch.exp(output)[0]
        confidence = probs[pred].item()

    print("\nPrediction Results:")
    print("=" * 50)
    # pyg_data_list is built from graph_results, so take the name from
    # graph_results[0]; processed_results[0] could be a different file if graph
    # construction skipped any input.
    print(f"Sample: {graph_results[0]['file_name']}")
    print(f"Predicted Class: {pred}")
    print(f"Confidence: {confidence * 100:.2f}%")
    print("\nClass Probabilities:")
    for i, prob in enumerate(probs):
        print(f"  Class {i}: {prob.item() * 100:.2f}%")
else:
    print("No samples available for prediction")

## 8. Summary and Export
### Generate comprehensive summary report

In [None]:
# Create summary report
import json
from datetime import datetime


def _to_jsonable(value):
    """Convert a value the json module cannot encode into one it can.

    NumPy arrays become nested lists and NumPy scalars become native Python
    numbers, so they stay numeric in the report. Anything else falls back to
    its string form.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return str(value)


summary = {
    'pipeline': 'Protein Sub-Cellular Localization',
    'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    'input_dir': INPUT_DIR,
    'output_dir': OUTPUT_DIR,
    'n_files_processed': len(processed_results),
    'n_graphs_created': len(graph_results),
    'model_type': 'Graph-CNN',
    'metrics': metrics,
    'config': CONFIG
}

# Save summary as JSON in the output root
summary_path = Path(OUTPUT_DIR) / "pipeline_summary.json"
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2, default=_to_jsonable)

print("\n" + "=" * 70)
print("PIPELINE EXECUTION SUMMARY")
print("=" * 70)
print(f"Files processed: {summary['n_files_processed']}")
print(f"Graphs created: {summary['n_graphs_created']}")
print(f"Model accuracy: {metrics['accuracy']:.4f}")
print(f"\nAll outputs saved to: {OUTPUT_DIR}")
print(f"Summary saved to: {summary_path}")
print("\n✓ Pipeline execution complete!")

## 9. Web Interface Instructions

To run the web interface for file uploads and predictions:

```bash
cd ../frontend
python app.py
```

Then open your browser to: http://localhost:5000

The web interface allows you to:
- Upload TIFF files via drag-and-drop
- View real-time processing status
- See segmentation results
- View graph visualizations
- Get prediction results with confidence scores
- Download all analysis outputs