# Nitroplast Import Code Discovery Tutorial

This notebook demonstrates the complete pipeline for discovering hidden protein targeting signals using ESM-2 embeddings and supervised contrastive learning.

## Overview

**Goal**: Identify the "hidden" targeting signals in 148 nitroplast-localized proteins that lack the known uTP motif.

**Approach**:
1. Train ESM-2 + LoRA to distinguish nitroplast from cytosolic proteins
2. Use supervised contrastive learning to group all positives (uTP+ and uTP-) together
3. Analyze attention weights to discover what the model "sees" in the hidden proteins
4. Predict novel nitroplast-localized proteins in the full proteome

## Setup

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from pathlib import Path
import yaml

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100

# Check GPU availability
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## 1. Data Exploration

Let's first load both FASTA files and compare class sizes and sequence-length distributions.
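
`load_fasta` returns parallel lists of sequences and IDs. For readers without the repository's `utils` package, here is a minimal stand-in parser. `parse_fasta` is a hypothetical name. It assumes the ID is the first whitespace-delimited token of each header, which may not match how `load_fasta` handles headers, lowercase residues, or stop symbols.

```python
import io
from typing import Iterable, List, Tuple


def parse_fasta(lines: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Hypothetical stand-in for load_fasta: return (sequences, ids) from FASTA lines.

    The ID is the first whitespace-delimited token of each header; sequence
    lines belonging to one record are concatenated and upper-cased.
    """
    sequences, ids, chunks = [], [], []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith('>'):
            if ids:  # close the previous record
                sequences.append(''.join(chunks))
            header_tokens = line[1:].split()
            ids.append(header_tokens[0] if header_tokens else '')
            chunks = []
        else:
            chunks.append(line.upper())
    if ids:  # close the final record
        sequences.append(''.join(chunks))
    return sequences, ids


# Usage on an in-memory FASTA; an open file handle works the same way
fasta = io.StringIO(">prot1 example description\nMKTA\nYIAK\n>prot2\nggss\n")
seqs, ids = parse_fasta(fasta)
print(seqs, ids)  # ['MKTAYIAK', 'GGSS'] ['prot1', 'prot2']
```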

In [None]:
from utils.data_utils import load_fasta, validate_sequences

# Load data
pos_sequences, pos_ids = load_fasta('../data/raw/nitroplast_proteins.fasta')
neg_sequences, neg_ids = load_fasta('../data/raw/host_cytosolic.fasta')

print(f"Positive (nitroplast) proteins: {len(pos_sequences)}")
print(f"Negative (cytosolic) proteins: {len(neg_sequences)}")

# Sequence lengths per class (computed once, reused below)
pos_lengths = [len(s) for s in pos_sequences]
neg_lengths = [len(s) for s in neg_sequences]

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Overlaid histograms; density=True keeps classes of different sizes comparable
axes[0].hist(pos_lengths, bins=30, alpha=0.7, density=True, label='Nitroplast', color='red')
axes[0].hist(neg_lengths, bins=30, alpha=0.7, density=True, label='Cytosolic', color='blue')
axes[0].set_xlabel('Sequence Length (aa)')
axes[0].set_ylabel('Density')
axes[0].set_title('Sequence Length Distribution')
axes[0].legend()

# Box plots of the same lengths; tick labels are set separately so this works across
# Matplotlib versions (boxplot's `labels` argument was renamed `tick_labels` in 3.9)
axes[1].boxplot([pos_lengths, neg_lengths])
axes[1].set_xticklabels(['Nitroplast', 'Cytosolic'])
axes[1].set_ylabel('Sequence Length (aa)')
axes[1].set_title('Length Comparison')

plt.tight_layout()
plt.show()

print(f"\nAverage length (nitroplast): {np.mean(pos_lengths):.1f} aa")
print(f"Average length (cytosolic): {np.mean(neg_lengths):.1f} aa")

## 2. Data Preparation

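We split by sequence similarity rather than at random: closely related sequences are kept in the same split, so near-duplicates of training proteins cannot end up in the validation or test sets and inflate performance (data leakage).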
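
The idea behind similarity-based splitting can be sketched as follows. This is a simplified illustration, not the code inside `prepare_datasets`. It uses k-mer Jaccard similarity as a cheap stand-in for sequence identity (dedicated clustering tools such as CD-HIT or MMseqs2 are the usual choice), greedily clusters sequences around representatives, and then assigns whole clusters to train, validation, or test. The function names, `k=3`, and the `0.5` threshold are hypothetical.

```python
import random


def kmer_set(seq: str, k: int = 3) -> set:
    """Set of overlapping k-mers in a sequence."""
    return {seq[i:i + k] for i in range(len(seq) - k + 1)}


def jaccard(a: set, b: set) -> float:
    """Jaccard similarity of two k-mer sets (0.0 if both are empty)."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def greedy_cluster(seqs, threshold=0.5, k=3):
    """Put each sequence in the first cluster whose representative is at least
    `threshold` similar; otherwise it becomes the representative of a new cluster."""
    reps = []        # k-mer sets of cluster representatives
    assignment = []  # cluster index for each input sequence
    for seq in seqs:
        kmers = kmer_set(seq, k)
        for idx, rep in enumerate(reps):
            if jaccard(kmers, rep) >= threshold:
                assignment.append(idx)
                break
        else:
            reps.append(kmers)
            assignment.append(len(reps) - 1)
    return assignment


def split_by_cluster(assignment, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle cluster ids and give every member of a cluster the same split;
    clusters left over after train and val go to test."""
    clusters = sorted(set(assignment))
    random.Random(seed).shuffle(clusters)
    n_train = round(train_frac * len(clusters))
    n_val = round(val_frac * len(clusters))
    split_of = {}
    for i, cluster in enumerate(clusters):
        if i < n_train:
            split_of[cluster] = 'train'
        elif i < n_train + n_val:
            split_of[cluster] = 'val'
        else:
            split_of[cluster] = 'test'
    return [split_of[c] for c in assignment]


# Usage: two pairs of near-duplicates plus one unrelated sequence
seqs = ["MKTAYIAKQR", "MKTAYIAKQL", "GGSSWWPPLL", "GGSSWWPPLA", "QQQQRRRREE"]
assignment = greedy_cluster(seqs, threshold=0.5, k=3)
print(assignment)  # [0, 0, 1, 1, 2]
splits = split_by_cluster(assignment)
```

Because clusters rather than individual sequences are shuffled, the realized split fractions can drift from the targets when cluster sizes vary.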

In [None]:
from utils.data_utils import prepare_datasets

# Load config
with open('../configs/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Prepare datasets
train_dataset, val_dataset, test_dataset = prepare_datasets(
    positive_fasta='../data/raw/nitroplast_proteins.fasta',
    negative_fasta='../data/raw/host_cytosolic.fasta',
    output_dir='../data/processed/',
    config=config,
    force_recompute=False
)

print(f"Train: {len(train_dataset)} sequences")
print(f"Val: {len(val_dataset)} sequences")
print(f"Test: {len(test_dataset)} sequences")

# Check label distribution
for name, dataset in [('Train', train_dataset), ('Val', val_dataset), ('Test', test_dataset)]:
    n_pos = sum(dataset.labels)
    n_neg = len(dataset) - n_pos
    print(f"\n{name} set:")
    print(f"  Positive: {n_pos} ({n_pos/len(dataset)*100:.1f}%)")
    print(f"  Negative: {n_neg} ({n_neg/len(dataset)*100:.1f}%)")

## 3. Model Architecture

Our model consists of:
1. **ESM-2 Encoder** (650M parameters) with LoRA adaptation
2. **Projection Head** that maps the pooled ESM-2 embedding to a 128-dimensional, L2-normalized contrastive space
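
The two components can be sketched in plain PyTorch. This is an assumption about their general shape, not the code in `models/esm_encoder.py` or `models/projector.py`. `LoRALinear` shows the low-rank update that LoRA adds to a frozen linear layer, and `SketchProjectionHead` is a small MLP followed by L2 normalization. The hidden size of 512, the dropout, and both class names are hypothetical; `input_dim=1280` matches the embedding size of the 650M-parameter ESM-2 model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
    """Frozen base linear layer plus a trainable low-rank update:
    y = W x + (alpha / r) * B(A(dropout(x))).
    B is zero-initialized, so the wrapped layer starts out identical to the base."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # only the low-rank factors are trained
        self.lora_A = nn.Linear(base.in_features, r, bias=False)  # default random init
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(self.dropout(x)))


class SketchProjectionHead(nn.Module):
    """MLP mapping pooled encoder embeddings to an L2-normalized contrastive space."""

    def __init__(self, input_dim: int = 1280, hidden_dim: int = 512,
                 output_dim: int = 128, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.net(x), p=2, dim=1)  # unit-length rows


# Toy usage: one LoRA-wrapped 1280x1280 layer and a projection of 4 pooled embeddings
layer = LoRALinear(nn.Linear(1280, 1280), r=8, alpha=16)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable:,} of {total:,}")  # trainable: 20,480 of 1,660,160

head = SketchProjectionHead().eval()
z = head(torch.randn(4, 1280))
print(z.shape, z.norm(dim=1))
```

Only the two low-rank factors are trainable, which is why a LoRA-adapted encoder reports a small trainable-parameter fraction.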

In [None]:
from models.esm_encoder import ESMEncoder
from models.projector import ProjectionHead, NitroplastContrastiveModel

# Build encoder
encoder = ESMEncoder(
    model_name=config['model']['esm_model_name'],
    use_lora=config['model']['use_lora'],
    lora_config={
        'r': config['model']['lora_r'],
        'alpha': config['model']['lora_alpha'],
        'dropout': config['model']['lora_dropout'],
        'target_modules': config['model']['lora_target_modules']
    },
    freeze_layers=config['model']['freeze_esm_layers'],
    pooling_method=config['model']['pooling_method'],
    device=device
)

# Build projection head
projector = ProjectionHead(
    input_dim=config['model']['projector']['input_dim'],
    hidden_dims=config['model']['projector']['hidden_dims'],
    output_dim=config['model']['projector']['output_dim'],
    dropout=config['model']['projector']['dropout'],
    use_batch_norm=config['model']['projector']['use_batch_norm']
)

# Combine into full model
model = NitroplastContrastiveModel(encoder, projector).to(device)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")

## 4. Test Forward Pass

Run a forward pass on four training sequences to check the output shapes and that the contrastive embeddings are L2-normalized.

In [None]:
# Test with sample sequences
test_sequences = train_dataset.sequences[:4]
test_labels = torch.tensor(train_dataset.labels[:4])

print("Sample sequences:")
for i, (seq, label) in enumerate(zip(test_sequences, test_labels)):
    print(f"{i+1}. Length: {len(seq)}, Label: {'Nitroplast' if label == 1 else 'Cytosolic'}")

# Forward pass
model.eval()
with torch.no_grad():
    outputs = model(test_sequences)

print("\nOutput shapes:")
print(f"  ESM-2 embeddings: {outputs['embeddings'].shape}")
print(f"  Contrastive embeddings: {outputs['projected'].shape}")
print(f"  Residue embeddings: {outputs['residue_embeddings'].shape}")

# Check L2 normalization
norms = outputs['projected'].norm(dim=1)
print(f"\nL2 norms (should be ~1.0): {norms}")

## 5. Training (Simplified Version)

For the full training, run: `python train.py --config configs/config.yaml`

Here we run a single optimization step on one batch to confirm that the forward pass, loss, backward pass, and optimizer update all work.
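
The loss in this step is a supervised contrastive (SupCon) loss: for a given anchor, every other sample in the batch with the same label counts as a positive. Below is a minimal sketch of the standard formulation for L2-normalized embeddings. `supcon_loss` is a hypothetical name and may differ from `models.projector.SupConLoss`, for example in how it treats anchors without positives. With `batch_size=8` and imbalanced classes, a batch can contain a single nitroplast protein, which then has no positive partner.

```python
import torch
import torch.nn.functional as F


def supcon_loss(z: torch.Tensor, labels: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Supervised contrastive loss for L2-normalized embeddings z of shape (N, D).

    For each anchor i, the positives are the other samples sharing its label.
    The loss is the negative mean log-probability of those positives under a
    softmax over all samples except i. Anchors with no positive are skipped.
    """
    n = z.shape[0]
    sim = z @ z.T / temperature                          # scaled cosine similarities
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float('-inf'))      # drop i == j from the softmax
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)

    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    n_pos = pos_mask.sum(dim=1)
    valid = n_pos > 0
    if not valid.any():
        return z.sum() * 0.0                             # no positive pairs in this batch
    # Zero out non-positives (including the -inf diagonal) before summing
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    return -(pos_log_prob[valid] / n_pos[valid]).mean()


# Usage: two well-separated classes give a much lower loss than mixed ones
labels = torch.tensor([0, 0, 1, 1])
separated = F.normalize(torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]), dim=1)
mixed = F.normalize(torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]), dim=1)
print(supcon_loss(separated, labels).item(), supcon_loss(mixed, labels).item())
```

Lower temperatures sharpen the softmax, so the loss penalizes positives that are not the closest neighbors more strongly.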

In [None]:
from torch.utils.data import DataLoader
from utils.training_utils import collate_fn, create_optimizer
from models.projector import SupConLoss

# Create data loader
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)

# Setup training
loss_fn = SupConLoss(temperature=config['training']['temperature'])
optimizer = create_optimizer(model, learning_rate=2e-4, weight_decay=0.01)

# Training step
model.train()
batch = next(iter(train_loader))

sequences = batch['sequences']
labels = batch['labels'].to(device)

# Forward pass
outputs = model(sequences)
projected = outputs['projected']

# Compute loss
loss = loss_fn(projected, labels)

print(f"Batch size: {len(sequences)}")
print(f"Labels: {labels}")
print(f"Loss: {loss.item():.4f}")

# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("\n✓ Training step completed successfully!")

## 6. Load Pre-trained Model

After training completes, load the best model for analysis.
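
`load_model_for_inference` presumably rebuilds the architecture from `config` and loads the saved weights into it. The common PyTorch pattern behind that is sketched below. The checkpoint keys (`model_state_dict`, `epoch`, `val_loss`) and helper names are hypothetical, and the layout actually written by `train.py` may differ.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, path, epoch: int, val_loss: float) -> None:
    """Store the state dict plus metadata (hypothetical key names)."""
    torch.save({'model_state_dict': model.state_dict(),
                'epoch': epoch, 'val_loss': val_loss}, path)


def load_checkpoint(model: nn.Module, path, device: str = 'cpu') -> dict:
    """Load weights into an already-built model of the same architecture, then eval()."""
    checkpoint = torch.load(path, map_location=device)  # remap GPU tensors if needed
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()
    return checkpoint


# Usage with a toy model standing in for NitroplastContrastiveModel
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'best_model.pt'
    trained = nn.Linear(4, 2)
    save_checkpoint(trained, path, epoch=3, val_loss=0.42)
    restored = nn.Linear(4, 2)                 # same architecture, fresh weights
    meta = load_checkpoint(restored, path)
print(meta['epoch'], torch.equal(restored.weight, trained.weight))  # 3 True
```

Saving only the state dict (rather than the whole module) keeps checkpoints loadable after code refactors, as long as parameter names stay the same.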

In [None]:
from utils.inference_utils import load_model_for_inference

# Load trained model
checkpoint_path = '../results/checkpoints/best_model.pt'

if Path(checkpoint_path).exists():
    model = load_model_for_inference(
        checkpoint_path=checkpoint_path,
        config=config,
        device=device
    )
    print("✓ Model loaded successfully")
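else:
    # Without a checkpoint, `model` is still the model from Section 5 (one training step),
    # so the embeddings, attention maps, and predictions below will not be meaningful.
    print("⚠ Checkpoint not found. Run training first: python train.py --config configs/config.yaml")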

## 7. Visualize Learned Embeddings

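Use t-SNE to visualize how well the model separates nitroplast from cytosolic proteins in the test set. t-SNE preserves local neighborhoods but not global distances, so treat the plot as a qualitative check and avoid over-interpreting cluster sizes or the gaps between clusters.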
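
Because t-SNE layouts are qualitative, it can help to pair the plot with a number computed directly in the embedding space. A minimal sketch, assuming `embeddings` and `labels` are the NumPy arrays built in the next cell (label 1 = nitroplast): compare the mean cosine similarity within each class with the mean similarity between classes. `class_similarity_summary` is a hypothetical helper, not part of `utils.visualization`.

```python
import numpy as np


def class_similarity_summary(embeddings: np.ndarray, labels: np.ndarray) -> dict:
    """Mean pairwise cosine similarity within positives, within negatives,
    and between the two classes (self-pairs excluded within a class)."""
    x = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = x @ x.T
    pos = labels == 1
    neg = ~pos

    def mean_offdiag(block: np.ndarray) -> float:
        n = block.shape[0]
        if n < 2:
            return float('nan')
        return float((block.sum() - np.trace(block)) / (n * (n - 1)))

    return {
        'within_pos': mean_offdiag(sim[np.ix_(pos, pos)]),
        'within_neg': mean_offdiag(sim[np.ix_(neg, neg)]),
        'between': float(sim[np.ix_(pos, neg)].mean()),
    }


# Usage on toy data: two classes pointing in orthogonal directions
embeddings_demo = np.array([[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 3.0]])
labels_demo = np.array([1, 1, 0, 0])
print(class_similarity_summary(embeddings_demo, labels_demo))
# {'within_pos': 1.0, 'within_neg': 1.0, 'between': 0.0}
```

In a well-separated space, both within-class means sit clearly above `between`. Restricting the positive mask to the hidden (uTP-lacking) proteins gives the same summary for that subset.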

In [None]:
from utils.visualization import visualize_embeddings

# Compute embeddings for test set
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn
)

all_embeddings = []
all_labels = []
all_ids = []

model.eval()
with torch.no_grad():
    for batch in test_loader:
        sequences = batch['sequences']
        labels = batch['labels']
        ids = batch['ids']
        
        embeddings = model.get_embeddings(sequences).cpu().numpy()
        
        all_embeddings.append(embeddings)
        all_labels.append(labels.numpy())
        all_ids.extend(ids)

embeddings = np.vstack(all_embeddings)
labels = np.concatenate(all_labels)

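# Visualize (create the output directory in case it does not exist yet)
plot_path = Path('../results/plots/test_embeddings.png')
plot_path.parent.mkdir(parents=True, exist_ok=True)
visualize_embeddings(
    embeddings=embeddings,
    labels=labels,
    ids=all_ids,
    output_path=str(plot_path),
    method='tsne',
    title='Test Set Protein Embeddings'
)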

print("✓ Embedding visualization saved")

## 8. Attention Analysis

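Analyze what the model attends to in the 148 "hidden" proteins (nitroplast-localized but lacking the uTP motif). The cell below demonstrates the workflow on the first test-set protein; substitute a hidden protein to analyze that set.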
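
High-attention region extraction can be pictured as a sliding-window average over per-residue attention scores. The sketch below is an assumed approach, not the implementation of `extract_high_attention_regions`. The hypothetical `window_attention_regions` reduces a residue-by-residue map to the attention each residue receives, keeps windows whose mean reaches `min_attention`, and greedily selects non-overlapping windows, best first. It expects scores already aligned to residues, so any special tokens added by the ESM tokenizer would need to be trimmed first.

```python
import numpy as np


def window_attention_regions(sequence: str, scores, window_size: int = 15,
                             min_attention: float = 0.1) -> list:
    """Return non-overlapping windows whose mean per-residue score is at least
    `min_attention`, best first. `scores` must already be aligned to residues."""
    scores = np.asarray(scores, dtype=float)
    if scores.ndim == 2:
        scores = scores.mean(axis=0)  # (L, L) map -> attention received by each residue
    if len(scores) < window_size:
        return []
    # Moving average: means[s] is the mean score of residues s .. s + window_size - 1
    means = np.convolve(scores, np.ones(window_size) / window_size, mode='valid')
    taken = np.zeros(len(scores), dtype=bool)
    regions = []
    for start in np.argsort(means)[::-1]:  # highest-scoring windows first
        if means[start] < min_attention:
            break
        end = start + window_size
        if taken[start:end].any():
            continue  # overlaps a window that was already selected
        taken[start:end] = True
        regions.append({'position': int(start), 'sequence': sequence[start:end],
                        'avg_attention': float(means[start])})
    return regions


# Usage: a toy 24-residue protein with a high-attention stretch at positions 10-13
seq = "A" * 10 + "MKLV" + "A" * 10
scores = np.zeros(len(seq))
scores[10:14] = 1.0
print(window_attention_regions(seq, scores, window_size=4, min_attention=0.5))
# [{'position': 10, 'sequence': 'MKLV', 'avg_attention': 1.0}]
```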

In [None]:
from utils.visualization import visualize_attention_map, extract_high_attention_regions

# Ideally, analyze one of the hidden (uTP-lacking) proteins.
# For demonstration, use the first test-set protein, which may not be one of them.
test_protein_id = test_dataset.ids[0]
test_sequence = test_dataset.sequences[0]

print(f"Analyzing: {test_protein_id}")
print(f"Length: {len(test_sequence)} aa")

# Get attention weights
attention = model.encoder.get_attention_weights(
    [test_sequence],
    layer_idx=-1,
    aggregate_heads='mean'
)
attention = attention[0].cpu().numpy()

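# Visualize (FASTA IDs can contain '|' or '/', which are unsafe in file names)
safe_id = test_protein_id.replace('/', '_').replace('|', '_')
attention_dir = Path('../results/attention_maps')
attention_dir.mkdir(parents=True, exist_ok=True)
visualize_attention_map(
    sequence=test_sequence,
    attention=attention,
    protein_id=test_protein_id,
    output_path=str(attention_dir / f'{safe_id}_attention.png'),
    top_k_residues=20,
    window_size=30
)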

# Extract high-attention regions
regions = extract_high_attention_regions(
    sequence=test_sequence,
    attention=attention,
    window_size=15,
    min_attention=0.1
)

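# Sort by score so the "top 5" holds regardless of the order the helper returns
regions = sorted(regions, key=lambda r: r['avg_attention'], reverse=True)
print("\nTop 5 high-attention regions:")
for i, region in enumerate(regions[:5], 1):
    print(f"{i}. Position {region['position']}: {region['sequence']} (score: {region['avg_attention']:.3f})")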

## 9. Zero-Shot Prediction

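Classify proteins by nearest-neighbor search against labeled reference embeddings. Applied to the full proteome, the same procedure predicts novel nitroplast-localized proteins.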
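
Nearest-neighbor prediction of this kind can be sketched in a few lines of NumPy. This assumes cosine distance, a vote over the `k` closest labeled references, and confidence defined as the fraction of neighbors agreeing with the prediction. `knn_predict` and `k` are hypothetical, and `NitroplastPredictor` may weight neighbors or apply `confidence_threshold` differently. The query protein itself should not be in the reference set, otherwise its own label is returned at distance 0.

```python
import numpy as np


def knn_predict(query: np.ndarray, ref_embeddings: np.ndarray, ref_labels: np.ndarray,
                k: int = 5) -> dict:
    """Vote over the k reference embeddings closest to `query` under cosine
    distance (1 - cosine similarity). Labels are binary; ties go to 0."""
    q = query / np.linalg.norm(query)
    refs = ref_embeddings / np.linalg.norm(ref_embeddings, axis=1, keepdims=True)
    distances = 1.0 - refs @ q                        # cosine distance to every reference
    nearest = np.argsort(distances)[:k]
    votes = ref_labels[nearest]
    prediction = int(votes.sum() * 2 > len(votes))    # strict majority for label 1
    confidence = float((votes == prediction).mean())  # fraction of neighbors agreeing
    return {'prediction': prediction, 'confidence': confidence,
            'nearest': [(int(i), float(distances[i])) for i in nearest]}


# Usage on toy 2-D references: the query lies next to the three label-1 points
refs = np.array([[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.0, 1.0], [0.1, 0.9]])
ref_labels_demo = np.array([1, 1, 1, 0, 0])
result = knn_predict(np.array([1.0, 0.05]), refs, ref_labels_demo, k=3)
print(result['prediction'], result['confidence'])  # 1 1.0
```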

In [None]:
from utils.inference_utils import NitroplastPredictor, compute_reference_embeddings

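# Compute reference embeddings from the training set. Using the test set here would put
# each query protein among its own references (distance 0), inflating the result.
ref_embeddings, ref_labels, ref_ids = compute_reference_embeddings(
    model=model,
    dataset=train_dataset,
    batch_size=32,
    device=device
)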

# Create predictor
predictor = NitroplastPredictor(
    model=model,
    reference_embeddings=ref_embeddings,
    reference_labels=ref_labels,
    reference_ids=ref_ids,
    distance_metric='cosine',
    confidence_threshold=0.8,
    device=device
)

# Example prediction
example_seq = test_dataset.sequences[10]
example_id = test_dataset.ids[10]

prediction = predictor.predict_single(example_seq, example_id)

print(f"Protein: {prediction['protein_id']}")
print(f"Prediction: {'Nitroplast' if prediction['prediction'] == 1 else 'Cytosolic'}")
print(f"Confidence: {prediction['confidence']:.3f}")
print(f"\nNearest neighbors:")
for i, nn in enumerate(prediction['nearest_neighbors'], 1):
    label = 'Nitroplast' if nn['label'] == 1 else 'Cytosolic'
    print(f"  {i}. {nn['id']} ({label}), distance: {nn['distance']:.3f}")

## 10. Summary and Next Steps

### What We've Accomplished:
1. ✓ Built a contrastive learning model to distinguish nitroplast from cytosolic proteins
2. ✓ Visualized the learned embedding space of the test set
3. ✓ Extracted high-attention regions for an example protein
4. ✓ Ran nearest-neighbor predictions on a held-out protein

### Next Steps:
1. **Analyze all 148 hidden proteins** systematically
   - Run: `python analyze_attention.py --checkpoint results/checkpoints/best_model.pt ...`
2. **Look for consensus motifs** in high-attention regions
3. **Predict localization across the full proteome**
   - Run: `python predict.py --checkpoint ... --proteome ...`
4. **Validate discoveries experimentally**
   - Test predicted proteins for nitroplast localization
   - Mutate high-attention regions to test functionality
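
As a starting point for step 2, one can count short k-mers in the high-attention regions and compare them with their frequency in background sequence. This sketch is an assumption, not part of the project's scripts: `kmer_enrichment`, `k=3`, and the pseudocount are hypothetical, and the log2 ratio is only a screening score. Candidate motifs would still need a dedicated motif finder (such as MEME) and a proper statistical test.

```python
import math
from collections import Counter


def count_kmers(seqs, k=3):
    """Counter of overlapping k-mers across all sequences."""
    counts = Counter()
    for s in seqs:
        counts.update(s[i:i + k] for i in range(len(s) - k + 1))
    return counts


def kmer_enrichment(region_seqs, background_seqs, k=3, pseudocount=1.0, top_n=10):
    """Rank k-mers seen in the regions by log2(region frequency / background frequency),
    with a pseudocount so k-mers absent from the background stay finite."""
    fg = count_kmers(region_seqs, k)
    bg = count_kmers(background_seqs, k)
    fg_total = sum(fg.values())
    bg_total = sum(bg.values())
    vocab = len(set(fg) | set(bg))
    scores = {}
    for kmer, n in fg.items():
        p_fg = (n + pseudocount) / (fg_total + pseudocount * vocab)
        p_bg = (bg.get(kmer, 0) + pseudocount) / (bg_total + pseudocount * vocab)
        scores[kmer] = math.log2(p_fg / p_bg)
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_n]


# Usage: "RRK" recurs in the toy regions and never appears in the background
region_seqs_demo = ["RRKRRK", "AARRKA"]
background_demo = ["AAAAAAA", "GGGGAAA"]
ranked = kmer_enrichment(region_seqs_demo, background_demo, k=3, top_n=3)
print(ranked[0][0])  # RRK
```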