# 🧬 MultiOmicsBind: Complete Tutorial with Temporal Multi-Omics Integration

This notebook walks through the main features of **MultiOmicsBind**, including:

- ✅ Loading multi-omics data from CSV files
- ✅ Temporal multi-omics integration (time-series proteomics + static transcriptomics/cell painting)
- ✅ Automatic NaN detection and fixing
- ✅ Model training with binding modality
- ✅ Feature importance analysis
- ✅ Cross-modal similarity computation
- ✅ UMAP visualizations with **custom class names** (NEW!)
- ✅ **Dose-response analysis visualization** (NEW!)
- ✅ Modality contribution analysis

## Dataset Overview

We'll work with:
- **Transcriptomics** (6000 genes, baseline measurement)
- **Cell Painting** (1500 features, baseline measurement)
- **Proteomics** (4000 proteins, 5 timepoints: 0h, 1h, 2h, 4h, 8h)
- **Metadata** (dose information and response labels)

## Response Classes
- **No Response** (Class 0): Low dose, minimal effect
- **Partial Response** (Class 1): Medium dose, moderate effect
- **Full Response** (Class 2): High dose, strong effect

## 1️⃣ Import Required Libraries

In [None]:
# Standard libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from pathlib import Path

# MultiOmicsBind imports
from multiomicsbind.data.dataset import TemporalMultiOmicsDataset
from multiomicsbind.core.model import MultiOmicsBindModel
from multiomicsbind.training.trainer import train_temporal_model, evaluate_temporal_model
from multiomicsbind.analysis import (
    compute_feature_importance,
    compute_cross_modal_similarity,
    create_analysis_report
)
from multiomicsbind.utils.visualization import (
    plot_training_history_detailed,
    plot_cross_modal_similarity_matrices,
    plot_embeddings_umap,
    plot_dose_response_analysis  # NEW!
)
from multiomicsbind.utils.nan_handling import check_and_fix_all_nan_values

# PyTorch utilities
from torch.utils.data import random_split

# Set style for better-looking plots
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"NumPy version: {np.__version__}")

## 2️⃣ Load CSV Data Files

First, check whether the data files exist. If any are missing, the next cell generates synthetic data by running `temporal_example.py`.
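When a file is missing, it helps to know which one. The sketch below uses a hypothetical helper, `find_missing` (not part of MultiOmicsBind), that takes the same name-to-path mapping as `data_files` and returns the names whose files are absent. The demo runs in a temporary directory so it does not touch your data.

```python
import tempfile
from pathlib import Path

def find_missing(files, base_dir="."):
    """Return the names in `files` whose path does not exist under base_dir."""
    base = Path(base_dir)
    return [name for name, rel_path in files.items()
            if not (base / rel_path).is_file()]

# Demo in a throwaway directory that contains only one of the two expected files
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "metadata.csv").write_text("sample_id,dose\nS1,1.0\n")
    missing = find_missing(
        {"metadata": "metadata.csv", "proteomics": "proteomics.csv"}, tmp
    )

print(missing)
```

In the notebook you could call `find_missing(data_files)` and print the result before falling back to synthetic data.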

In [None]:
# Check if data files exist
data_files = {
    'transcriptomics': 'transcriptomics_baseline.csv',
    'cell_painting': 'cell_painting_baseline.csv',
    'proteomics': 'proteomics_timeseries.csv',
    'metadata': 'temporal_metadata.csv'
}

all_exist = all(os.path.exists(f) for f in data_files.values())

if not all_exist:
    print("⚠️ Data files not found. Generating synthetic data...")
    print("(In production, you would load your own CSV files here)")
    
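    # Run the data-generation script from temporal_example.py.
    # runpy executes the file as a script and closes it afterwards,
    # without copying the script's variables into the notebook namespace.
    import sys
    import runpy
    sys.path.append('..')
    runpy.run_path('temporal_example.py', run_name='__main__')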
else:
    print("✅ All data files found!")

# List files with sizes
for name, file in data_files.items():
    if os.path.exists(file):
        size = os.path.getsize(file) / 1024  # KB
        print(f"  {name:20s}: {file:40s} ({size:>8.1f} KB)")

### 2.1 Explore the Data Files

Let's load and inspect each CSV file to understand the data structure.
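Before integrating the tables, it is worth confirming that they describe the same samples. The sketch below is a hypothetical check (`check_sample_ids` is not a MultiOmicsBind function) and it assumes every table carries a `sample_id` column, which this notebook only shows for the proteomics and metadata files. Plain lists stand in for the `df['sample_id']` columns.

```python
def check_sample_ids(tables):
    """Return the IDs shared by all tables and, per table, the IDs that are not shared."""
    id_sets = {name: set(ids) for name, ids in tables.items()}
    shared = set.intersection(*id_sets.values())
    extras = {name: sorted(ids - shared) for name, ids in id_sets.items()}
    return sorted(shared), extras

# Toy IDs standing in for df["sample_id"] from each loaded table
shared, extras = check_sample_ids({
    "transcriptomics": ["S1", "S2", "S3"],
    "proteomics": ["S1", "S1", "S2", "S2"],  # long format repeats each ID per timepoint
    "metadata": ["S1", "S2", "S3"],
})

print("shared:", shared)
print("extras:", extras)
```

Here `S3` has no proteomics rows, so it shows up as an extra in the two tables that contain it.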

In [None]:
# Load transcriptomics data (static, baseline)
transcriptomics_df = pd.read_csv('transcriptomics_baseline.csv')
print("=" * 80)
print("📊 TRANSCRIPTOMICS DATA (Baseline Gene Expression)")
print("=" * 80)
print(f"Shape: {transcriptomics_df.shape}")
print(f"Columns: {list(transcriptomics_df.columns[:5])} ... (showing first 5)")
print("\nFirst 3 samples:")
display(transcriptomics_df.head(3))

# Load cell painting data (static, baseline)
cell_painting_df = pd.read_csv('cell_painting_baseline.csv')
print("\n" + "=" * 80)
print("🔬 CELL PAINTING DATA (Baseline Morphology Features)")
print("=" * 80)
print(f"Shape: {cell_painting_df.shape}")
print(f"Columns: {list(cell_painting_df.columns[:5])} ... (showing first 5)")
print("\nFirst 3 samples:")
display(cell_painting_df.head(3))

In [None]:
# Load proteomics time-series data (temporal)
proteomics_df = pd.read_csv('proteomics_timeseries.csv')
print("=" * 80)
print("⏱️  PROTEOMICS TIME-SERIES DATA (5 timepoints per sample)")
print("=" * 80)
print(f"Shape: {proteomics_df.shape}")
print(f"Columns: {list(proteomics_df.columns[:5])} ... (showing first 5)")
print("\nFirst 10 timepoints (2 samples × 5 timepoints):")
display(proteomics_df.head(10))

# Check unique samples and timepoints
print(f"\nUnique samples: {proteomics_df['sample_id'].nunique()}")
print(f"Timepoints per sample: {proteomics_df.groupby('sample_id').size().unique()}")
print(f"Timepoint values: {sorted(proteomics_df['timepoint'].unique())}")

In [None]:
# Load metadata (includes dose and response labels)
metadata_df = pd.read_csv('temporal_metadata.csv')
print("=" * 80)
print("📋 METADATA (Dose Information and Response Labels)")
print("=" * 80)
print(f"Shape: {metadata_df.shape}")
print(f"Columns: {list(metadata_df.columns)}")
print("\nFirst 10 samples:")
display(metadata_df.head(10))

# Analyze dose distribution by response class
print("\n" + "=" * 80)
print("💊 DOSE-RESPONSE RELATIONSHIP")
print("=" * 80)
for response_class in sorted(metadata_df['response'].unique()):
    class_data = metadata_df[metadata_df['response'] == response_class]
    dose_mean = class_data['dose'].mean()
    dose_std = class_data['dose'].std()
    dose_range = (class_data['dose'].min(), class_data['dose'].max())
    count = len(class_data)
    print(f"Class {response_class}: {count:3d} samples, "
          f"dose = {dose_mean:.2f} ± {dose_std:.2f} μM, "
          f"range [{dose_range[0]:.2f}-{dose_range[1]:.2f}]")

## 3️⃣ Create Temporal Multi-Omics Dataset

Now let's create the MultiOmicsBind dataset that integrates all modalities.
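To make the temporal part concrete: the proteomics CSV is in long format, with one row per sample and timepoint, while a sequence model needs an array of shape `(n_samples, n_timepoints, n_features)`. The sketch below shows one way to do that reshaping with pandas. It illustrates the idea and is not the code `TemporalMultiOmicsDataset` runs internally; the `protein_*` column names are made up.

```python
import pandas as pd

# Toy long-format table: one row per (sample, timepoint); protein_* names are hypothetical
long_df = pd.DataFrame({
    "sample_id": ["S1", "S1", "S1", "S2", "S2", "S2"],
    "timepoint": [0, 1, 2, 0, 1, 2],
    "protein_0": [1.0, 1.1, 1.2, 2.0, 2.1, 2.2],
    "protein_1": [0.5, 0.4, 0.3, 0.9, 0.8, 0.7],
})

def long_to_tensor(df, id_col="sample_id", time_col="timepoint"):
    """Reshape a long-format table into an (n_samples, n_timepoints, n_features) array."""
    feature_cols = [c for c in df.columns if c not in (id_col, time_col)]
    ordered = df.sort_values([id_col, time_col])
    counts = ordered.groupby(id_col).size()
    if counts.nunique() != 1:
        raise ValueError("every sample needs the same number of timepoints")
    n_samples, n_time = len(counts), int(counts.iloc[0])
    values = ordered[feature_cols].to_numpy(dtype=float)
    return values.reshape(n_samples, n_time, len(feature_cols)), list(counts.index)

tensor, sample_ids = long_to_tensor(long_df)
print(tensor.shape, sample_ids)
```

Sorting by sample and then by timepoint before reshaping is what guarantees that axis 1 runs in time order.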

In [None]:
# Create TemporalMultiOmicsDataset
print("Creating temporal multi-omics dataset...")

dataset = TemporalMultiOmicsDataset(
    static_files={
        'transcriptomics': 'transcriptomics_baseline.csv',
        'cell_painting': 'cell_painting_baseline.csv'
    },
    temporal_files={
        'proteomics': 'proteomics_timeseries.csv'
    },
    metadata_file='temporal_metadata.csv',
    label_col='response',
    num_cols=['dose'],  # Include dose as numerical metadata
    cat_cols=['treatment_day']  # Categorical metadata
)

print(f"\n✅ Dataset created successfully!")
print(f"   Total samples: {len(dataset)}")
print(f"   Static modalities: {dataset.static_modalities}")
print(f"   Temporal modalities: {dataset.temporal_modalities}")
print(f"   Number of classes: {len(dataset.get_labels().unique())}")

# Get a sample to check structure
sample = dataset[0]
print(f"\n📦 Sample structure:")
for key, value in sample.items():
    if isinstance(value, torch.Tensor):
        print(f"   {key:20s}: shape {tuple(value.shape)}, dtype {value.dtype}")
    else:
        print(f"   {key:20s}: {value}")

## 4️⃣ Automatic NaN Detection and Fixing

MultiOmicsBind includes automatic NaN detection and fixing for all modalities.
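The notebook does not say which filling strategy `check_and_fix_all_nan_values` applies. As a point of reference, the sketch below shows one common strategy, per-feature median imputation, in plain NumPy. Treat it as an illustration of the idea rather than a description of the library's behavior.

```python
import numpy as np

def fill_nan_with_column_median(X):
    """Return a copy of X with each NaN replaced by the median of its column."""
    X = np.array(X, dtype=float)  # copy, so the caller's array is left untouched
    medians = np.nanmedian(X, axis=0)
    medians = np.where(np.isnan(medians), 0.0, medians)  # all-NaN columns fall back to 0
    rows, cols = np.where(np.isnan(X))
    X[rows, cols] = medians[cols]
    return X

raw = np.array([[1.0, np.nan],
                [3.0, 4.0],
                [np.nan, 6.0]])
clean = fill_nan_with_column_median(raw)
print(clean)
```

If you impute with statistics such as a median, computing them on the training rows only avoids carrying information from the test set into training.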

In [None]:
# Check and fix NaN values across all modalities
print("🔍 Checking for NaN values in all modalities...")
print("=" * 80)

check_and_fix_all_nan_values(dataset, verbose=True)

print("\n✅ Dataset is now clean and ready for training!")

## 5️⃣ Train/Test Split

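Hold out 30% of the samples as a test set before training, so that evaluation uses data the model has never seen. The split is seeded for reproducibility.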
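`random_split` draws samples uniformly, so with small or imbalanced datasets the class proportions can differ between the two subsets. If that matters, a stratified split is an alternative. The sketch below is a hypothetical NumPy helper that returns index arrays with each class split by the same fraction.

```python
import numpy as np

def stratified_split_indices(labels, test_fraction=0.3, seed=42):
    """Split indices so every class contributes about test_fraction of its samples to the test set."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_test = int(round(test_fraction * len(idx)))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return np.sort(np.array(train_idx, dtype=int)), np.sort(np.array(test_idx, dtype=int))

labels = np.array([0] * 10 + [1] * 10 + [2] * 10)
train_idx, test_idx = stratified_split_indices(labels)
print(len(train_idx), len(test_idx), np.bincount(labels[test_idx]))
```

The resulting indices can be wrapped with `torch.utils.data.Subset(dataset, train_idx.tolist())` to obtain objects that behave like the output of `random_split`.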

In [None]:
# Split dataset into train (70%) and test (30%)
train_size = int(0.7 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(
    dataset, 
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42)  # For reproducibility
)

print(f"✅ Dataset split:")
print(f"   Training samples: {len(train_dataset)} ({len(train_dataset)/len(dataset)*100:.1f}%)")
print(f"   Test samples: {len(test_dataset)} ({len(test_dataset)/len(dataset)*100:.1f}%)")

# Verify no overlap
train_indices = set(train_dataset.indices)
test_indices = set(test_dataset.indices)
overlap = train_indices & test_indices
print(f"\n🔒 Data leakage check: {len(overlap)} overlapping samples (should be 0)")

## 6️⃣ Model Training

Train the MultiOmicsBind model on the training split using the binding-modality approach.
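This notebook does not spell out the training objective. One common way to "bind" modalities is to pull each modality's embedding toward the embedding of the same sample in a chosen anchor modality, using a contrastive (InfoNCE) loss. The NumPy sketch below illustrates that loss only. Whether `train_temporal_model` uses exactly this objective is an assumption, and `info_nce` is a hypothetical name.

```python
import numpy as np

def info_nce(anchor, other, temperature=0.1):
    """Mean cross-entropy of matching anchor[i] to other[i] against every other[j] in the batch."""
    a = anchor / np.linalg.norm(anchor, axis=1, keepdims=True)
    b = other / np.linalg.norm(other, axis=1, keepdims=True)
    logits = a @ b.T / temperature                        # cosine similarities, scaled
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))

rng = np.random.default_rng(0)
binding = rng.normal(size=(8, 16))                     # anchor-modality embeddings
aligned = binding + 0.01 * rng.normal(size=(8, 16))    # same samples, nearly identical
shuffled = aligned[::-1]                               # sample pairing deliberately broken

loss_aligned = info_nce(binding, aligned)
loss_shuffled = info_nce(binding, shuffled)
print(f"aligned: {loss_aligned:.4f}, shuffled: {loss_shuffled:.4f}")
```

The loss is low when row `i` of one modality is most similar to row `i` of the other, and high when the pairing is scrambled, which is the signal that drives the embeddings into a shared space.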

In [None]:
# Train model using high-level API
print("🚀 Training MultiOmicsBind model...")
print("=" * 80)

model, history = train_temporal_model(
    dataset=train_dataset,
    device=device,
    n_classes=3,  # No Response, Partial Response, Full Response
    embedding_dim=128,
    hidden_dim=256,
    n_heads=4,
    n_layers=2,
    dropout=0.1,
    learning_rate=0.001,
    batch_size=32,
    epochs=15,
    save_path='multiomicsbind_model.pth',
    verbose=True
)

print("\n✅ Training complete!")
print(f"   Final training accuracy: {history['train_acc'][-1]:.4f}")
print(f"   Model saved to: multiomicsbind_model.pth")

### 6.1 Visualize Training History

In [None]:
# Plot training history
plot_training_history_detailed(history, save_path='training_history.png')
plt.show()

print("✅ Training history plot saved to 'training_history.png'")

## 7️⃣ Model Evaluation on Test Set

Evaluate the trained model on held-out test data.
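Overall accuracy can hide class-specific failures, so a confusion matrix is a useful companion to the numbers printed below. The sketch uses toy label arrays; in the notebook the `labels` and `predictions` returned by `evaluate_temporal_model` would take their place, assuming they are integer NumPy arrays.

```python
import numpy as np

def confusion_matrix(labels, predictions, n_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (labels, predictions), 1)  # unbuffered add handles repeated index pairs
    return cm

toy_labels = np.array([0, 0, 1, 1, 2, 2])
toy_predictions = np.array([0, 1, 1, 1, 2, 0])

cm = confusion_matrix(toy_labels, toy_predictions, n_classes=3)
# Per-class recall: correct predictions divided by the number of true samples in the class
recall = cm.diagonal() / np.maximum(cm.sum(axis=1), 1)
print(cm)
print(recall)
```

Off-diagonal entries show which classes are confused with each other, for example whether "Partial Response" samples are mostly mistaken for "No Response" or for "Full Response".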

In [None]:
# Evaluate on test set
print("📊 Evaluating model on test set...")
embeddings, labels, predictions = evaluate_temporal_model(model, test_dataset, device)

# Calculate accuracy
test_accuracy = (predictions == labels).mean()
print(f"\n✅ Test Set Accuracy: {test_accuracy:.4f}")

# Show per-class accuracy
for class_idx in range(3):
    class_mask = labels == class_idx
    if class_mask.sum() > 0:
        class_acc = (predictions[class_mask] == labels[class_mask]).mean()
        print(f"   Class {class_idx} accuracy: {class_acc:.4f} ({class_mask.sum()} samples)")

## 8️⃣ Feature Importance Analysis

Compute which features contribute most to predictions.
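Gradient-based importance asks how much a model's output changes when a single input feature is nudged. The sketch below approximates that with central finite differences on a toy scoring function, so it needs no deep-learning framework. It illustrates the principle; `compute_feature_importance` presumably relies on automatic differentiation instead, and `gradient_importance` is a hypothetical name.

```python
import numpy as np

def gradient_importance(score_fn, X, eps=1e-4):
    """Mean absolute central-difference gradient of score_fn with respect to each feature."""
    X = np.asarray(X, dtype=float)
    importance = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        step = np.zeros(X.shape[1])
        step[j] = eps
        grad_j = (score_fn(X + step) - score_fn(X - step)) / (2 * eps)
        importance[j] = np.mean(np.abs(grad_j))
    return importance

def toy_score(X):
    """Toy model: feature 0 matters most, feature 2 is ignored."""
    return 3.0 * X[:, 0] - 0.5 * X[:, 1]

X = np.random.default_rng(0).normal(size=(20, 3))
importance = gradient_importance(toy_score, X)
print(importance.round(3))
```

For this linear toy model the importances recover the absolute coefficients, with the ignored feature scoring zero.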

In [None]:
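# Compute feature importance using gradients.
# Note: this runs on the full dataset (train + test). To restrict the analysis
# to held-out data, pass test_dataset instead, as is done for
# create_analysis_report later in this notebook.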
print("🔍 Computing feature importance...")
importance_dict, importance_df = compute_feature_importance(
    model, dataset, device, n_batches=10, verbose=True
)

# Save to CSV
importance_df.to_csv('feature_importance.csv', index=False)
print("\n✅ Feature importance saved to 'feature_importance.csv'")

# Display top features per modality
print("\n" + "=" * 80)
print("TOP 5 FEATURES PER MODALITY")
print("=" * 80)
for modality in importance_df['modality'].unique():
    modality_df = importance_df[importance_df['modality'] == modality]
    top_features = modality_df.nlargest(5, 'importance')
    print(f"\n{modality.upper()}:")
    for idx, row in top_features.iterrows():
        print(f"  {row['feature_name']:30s}: {row['importance']:.6f}")

### 8.1 Modality Contribution Analysis

Analyze how much each modality contributes to predictions.
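One caveat when reading summed importances: a modality with more features (6000 genes versus 1500 Cell Painting features here) can accumulate a larger total even if each individual feature matters less. Reporting the per-feature mean next to the total makes this visible. The sketch uses a toy table with the same `modality` and `importance` columns as `importance_df`.

```python
import pandas as pd

toy_importance = pd.DataFrame({
    "modality": ["transcriptomics"] * 4 + ["cell_painting"] * 2,
    "importance": [0.1, 0.1, 0.1, 0.1, 0.15, 0.15],
})

summary = toy_importance.groupby("modality")["importance"].agg(
    total="sum", per_feature="mean", n_features="count"
)
summary["total_pct"] = 100 * summary["total"] / summary["total"].sum()
print(summary)
```

In this toy case transcriptomics has the larger total, yet each Cell Painting feature is individually more important, so the two rankings disagree.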

In [None]:
# Calculate contribution by modality
modality_contribution = importance_df.groupby('modality')['importance'].sum().sort_values(ascending=False)
total_importance = modality_contribution.sum()

print("\n" + "=" * 80)
print("MODALITY CONTRIBUTION TO PREDICTIONS")
print("=" * 80)

# Create visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Bar plot
modality_pct = (modality_contribution / total_importance * 100)
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(modality_pct)))
ax1.bar(range(len(modality_pct)), modality_pct, color=colors)
ax1.set_xticks(range(len(modality_pct)))
ax1.set_xticklabels(modality_pct.index, rotation=45, ha='right')
ax1.set_ylabel('Contribution (%)', fontsize=12)
ax1.set_title('Modality Contribution to Predictions', fontsize=13, fontweight='bold')
ax1.grid(axis='y', alpha=0.3)

# Pie chart
ax2.pie(modality_pct, labels=modality_pct.index, autopct='%1.1f%%', 
        colors=colors, startangle=90)
ax2.set_title('Modality Distribution', fontsize=13, fontweight='bold')

plt.tight_layout()
plt.savefig('modality_contribution.png', dpi=300, bbox_inches='tight')
plt.show()

# Print statistics
for modality, importance in modality_contribution.items():
    pct = (importance / total_importance) * 100
    print(f"  {modality:20s}: {pct:>6.2f}% ({importance:.4f})")

# Temporal vs static analysis
if 'proteomics' in modality_contribution.index:
    proteomics_pct = (modality_contribution['proteomics'] / total_importance) * 100
    print(f"\n{'✓' if proteomics_pct > 50 else '→'} Temporal proteomics: {proteomics_pct:.1f}%")
    if proteomics_pct > 60:
        print("  → Temporal dynamics are highly informative")
    elif proteomics_pct < 40:
        print("  → Baseline state more predictive than dynamics")
    else:
        print("  → Balanced temporal and static contributions")

## 9️⃣ Cross-Modal Similarity Analysis

Measure how similar embeddings are across different modalities.
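A standard similarity measure for embeddings is cosine similarity. The sketch below computes the full sample-by-sample cosine similarity matrix between two embedding sets with NumPy. That `compute_cross_modal_similarity` uses cosine similarity is an assumption; the notebook does not state the metric.

```python
import numpy as np

def cosine_similarity_matrix(A, B):
    """Entry (i, j) is the cosine of the angle between A[i] and B[j]."""
    A = A / np.linalg.norm(A, axis=1, keepdims=True)
    B = B / np.linalg.norm(B, axis=1, keepdims=True)
    return A @ B.T

emb_a = np.array([[1.0, 0.0], [0.0, 2.0]])  # toy embeddings from one modality
emb_b = np.array([[3.0, 0.0], [1.0, 1.0]])  # toy embeddings from another modality
sim = cosine_similarity_matrix(emb_a, emb_b)
print(sim.round(4))
```

When the rows of both inputs refer to the same samples in the same order, a strong diagonal indicates that the two modalities place each sample in a similar region of the shared space.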

In [None]:
# Compute cross-modal similarity
print("🔗 Computing cross-modal similarity...")
similarity_dict = compute_cross_modal_similarity(model, dataset, device)

# Visualize similarity matrices
plot_cross_modal_similarity_matrices(similarity_dict, save_path='similarity_matrices.png')
plt.show()

print("\n✅ Cross-modal similarity matrices saved to 'similarity_matrices.png'")

## 🔟 Comprehensive Analysis Report with Class Names (NEW!)

Generate a complete analysis report with **custom class names** appearing in all visualizations.

In [None]:
# Define meaningful class names
class_names = ['No Response', 'Partial Response', 'Full Response']

print("📊 Generating comprehensive analysis report with custom class names...")
print("=" * 80)

# Generate report on TEST SET only (prevents data leakage!)
report = create_analysis_report(
    model=model,
    dataset=test_dataset,  # ← Use test set!
    device=device,
    class_names=class_names,  # ← Custom class names!
    output_dir='./analysis_results',
    compute_importance=True,
    compute_similarity=True,
    n_importance_batches=10,
    verbose=True
)

print("\n" + "=" * 80)
print("✅ ANALYSIS COMPLETE!")
print("=" * 80)
print(f"\nTest Accuracy: {report['accuracy']:.4f}")
print(f"Output directory: analysis_results/")
print("\nGenerated files:")
print("  ├── classification_report.txt")
print("  ├── confusion_matrix.png")
print("  ├── embeddings_umap_transcriptomics.png (with class names!)")
print("  ├── embeddings_umap_cell_painting.png (with class names!)")
print("  ├── embeddings_umap_proteomics.png (with class names!)")
print("  ├── feature_importance.csv")
print("  └── cross_modal_similarity.png")

### 10.1 View UMAP Visualizations with Class Names

Let's display the UMAP plots showing our custom class names!

In [None]:
# Display UMAP plots
from IPython.display import Image, display

print("UMAP Visualizations (showing class names, not 'Class 0, 1, 2'):")
print("=" * 80)

modalities = ['transcriptomics', 'cell_painting', 'proteomics']
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, modality in enumerate(modalities):
    img_path = f'analysis_results/embeddings_umap_{modality}.png'
    if os.path.exists(img_path):
        img = plt.imread(img_path)
        axes[idx].imshow(img)
        # 'cell_painting' -> 'Cell Painting'
        axes[idx].set_title(modality.replace('_', ' ').title(), fontsize=14, fontweight='bold')
    else:
        print(f"⚠️ Missing plot: {img_path}")
    axes[idx].axis('off')  # hide ticks whether or not the image was found

plt.tight_layout()
plt.show()

print("\n✅ All UMAPs show 'No Response', 'Partial Response', 'Full Response'!")
print("   (Not generic 'Class 0', 'Class 1', 'Class 2')")

## 1️⃣1️⃣ Dose-Response Visualization (NEW! 🎉)

Analyze how **dose** contributes to predictions with a comprehensive 3-panel visualization.

In [None]:
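# Load metadata and extract test set doses.
# Assumes the dataset's sample order matches the row order of temporal_metadata.csv
# and that the report returns labels/predictions in test_dataset order.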
metadata = pd.read_csv('temporal_metadata.csv')
test_metadata = metadata.iloc[test_dataset.indices].reset_index(drop=True)

# Get test set predictions and labels
test_labels = report['labels']
test_predictions = report['predictions']
test_doses = test_metadata['dose'].values

print("💊 DOSE-RESPONSE ANALYSIS")
print("=" * 80)

# Show dose statistics by class
print("\nDose distribution by TRUE response class:")
for class_idx in range(3):
    mask = test_labels == class_idx
    if mask.sum() > 0:
        doses = test_doses[mask]
        print(f"  {class_names[class_idx]:20s}: "
              f"mean={doses.mean():.2f} μM, "
              f"median={np.median(doses):.2f} μM, "
              f"range=[{doses.min():.2f}-{doses.max():.2f}]")

print("\nDose distribution by PREDICTED response class:")
for class_idx in range(3):
    mask = test_predictions == class_idx
    if mask.sum() > 0:
        doses = test_doses[mask]
        print(f"  {class_names[class_idx]:20s}: "
              f"mean={doses.mean():.2f} μM, "
              f"median={np.median(doses):.2f} μM, "
              f"range=[{doses.min():.2f}-{doses.max():.2f}]")

print("\nInterpretation:")
print("  → If the predicted-class doses track the true-class doses above,")
print("    the model has captured the dose-response trend")
print("  → Dose enters the model as continuous numerical metadata")

### 11.1 Generate 3-Panel Dose-Response Visualization

`plot_dose_response_analysis` produces a three-panel figure showing:
1. **Dose distribution by class** (violin plots)
2. **Dose vs predictions** (scatter plot with accuracy markers)
3. **Mean dose comparison** (true vs predicted classes)

In [None]:
# Generate dose-response visualization
print("\n📊 Generating dose-response visualization...")

plot_dose_response_analysis(
    doses=test_doses,
    labels=test_labels,
    predictions=test_predictions,
    class_names=class_names,
    save_path='dose_response_analysis.png'
)

plt.show()

print("✅ Dose-response analysis saved to 'dose_response_analysis.png'")
print("\nThis plot shows:")
print("  1. Left panel: Dose distribution by true class (violin plots)")
print("  2. Center panel: Dose vs predictions (circles=correct, X=incorrect)")
print("  3. Right panel: Mean dose per class (blue=true, coral=predicted)")

### 11.2 Sample Predictions with Dose Information

Let's look at some example predictions with dose values.
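When accuracy is broken down by dose range, the bin edges need care: with half-open bins `[low, high)`, a sample whose dose equals the top edge falls into no bin at all. The sketch below is a hypothetical helper built on `np.digitize` that closes the last bin on the right. The toy arrays stand in for `test_doses`, `test_labels`, and `test_predictions`.

```python
import numpy as np

def accuracy_by_bin(doses, labels, predictions, edges):
    """Return (low, high, n_samples, accuracy) per bin; the last bin includes its upper edge."""
    doses = np.asarray(doses, dtype=float)
    correct = np.asarray(labels) == np.asarray(predictions)
    in_range = (doses >= edges[0]) & (doses <= edges[-1])
    # digitize gives 1..len(edges)-1 inside the range and len(edges) at the top edge;
    # clipping folds dose == edges[-1] into the last bin
    bins = np.clip(np.digitize(doses, edges), 1, len(edges) - 1) - 1
    results = []
    for b in range(len(edges) - 1):
        mask = in_range & (bins == b)
        acc = float(correct[mask].mean()) if mask.any() else float("nan")
        results.append((edges[b], edges[b + 1], int(mask.sum()), acc))
    return results

toy_doses = np.array([0.5, 2.9, 3.0, 5.0, 6.0, 10.0])
toy_labels = np.array([0, 0, 1, 1, 2, 2])
toy_predictions = np.array([0, 1, 1, 1, 2, 2])

table = accuracy_by_bin(toy_doses, toy_labels, toy_predictions, edges=[0, 3, 6, 10])
for low, high, n, acc in table:
    print(f"{low}-{high} μM: {acc:.3f} accuracy ({n} samples)")
```

The sample at exactly 10 μM is counted in the 6-10 bin, and samples outside the overall range are excluded rather than silently assigned to an edge bin.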

In [None]:
# Create a DataFrame with predictions and doses for up to 10 test samples
n_show = min(10, len(test_doses))  # avoid an IndexError on very small test sets
results_df = pd.DataFrame({
    'sample_id': test_metadata['sample_id'].values[:n_show],
    'dose_μM': test_doses[:n_show],
    'true_class': [class_names[int(i)] for i in test_labels[:n_show]],
    'predicted_class': [class_names[int(i)] for i in test_predictions[:n_show]],
    'correct': ['✓' if test_labels[i] == test_predictions[i] else '✗'
                for i in range(n_show)]
})

print("\n" + "=" * 80)
print("EXAMPLE PREDICTIONS WITH DOSE INFORMATION (First 10 Test Samples)")
print("=" * 80)
display(results_df)

# Calculate overall accuracy by dose range
print("\n" + "=" * 80)
print("ACCURACY BY DOSE RANGE")
print("=" * 80)
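dose_ranges = [(0, 3), (3, 6), (6, 10)]
for i, (low, high) in enumerate(dose_ranges):
    # Bins are half-open [low, high); the last bin also includes its upper edge
    # so that samples at the maximum dose are not dropped
    is_last = i == len(dose_ranges) - 1
    upper = (test_doses <= high) if is_last else (test_doses < high)
    mask = (test_doses >= low) & upper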
    if mask.sum() > 0:
        accuracy = (test_predictions[mask] == test_labels[mask]).mean()
        n_samples = mask.sum()
        print(f"  {low}-{high} μM: {accuracy:.3f} accuracy ({n_samples} samples)")

## 🎯 Summary and Key Findings

Let's summarize all the analyses we performed.

In [None]:
print("=" * 80)
print("🎯 MULTIOMICSBIND COMPLETE TUTORIAL SUMMARY")
print("=" * 80)

print("\n📊 MODEL PERFORMANCE:")
print(f"  • Test Accuracy: {report['accuracy']:.4f}")
print(f"  • Training samples: {len(train_dataset)}")
print(f"  • Test samples: {len(test_dataset)}")
print(f"  • No data leakage: Train and test sets completely separate")

print("\n🧬 DATA INTEGRATION:")
print(f"  • Successfully integrated 3 modalities:")
print(f"    - Transcriptomics: 6000 genes (static baseline)")
print(f"    - Cell Painting: 1500 features (static baseline)")
print(f"    - Proteomics: 4000 proteins × 5 timepoints (temporal)")
print(f"  • Metadata: dose + treatment_day")
print(f"  • Total samples: {len(dataset)}")

print("\n✨ NEW FEATURES DEMONSTRATED:")
print("  ✅ Custom class names in all visualizations")
print("     ('No Response', 'Partial Response', 'Full Response')")
print("     instead of generic 'Class 0', 'Class 1', 'Class 2'")
print("\n  ✅ Dose-response visualization (3-panel plot)")
print("     - Dose distribution by class (violin plots)")
print("     - Dose vs predictions with accuracy markers")
print("     - Mean dose comparison (true vs predicted)")
print("\n  ✅ Modality contribution analysis")
print("     - Percentage contribution of each modality")
print("     - Temporal vs static comparison")
print("\n  ✅ Automatic NaN detection and fixing")
print("     - Scans all modalities for missing values")
print("     - Intelligent filling strategies")

print("\n📁 GENERATED FILES:")
files = [
    'multiomicsbind_model.pth',
    'training_history.png',
    'feature_importance.csv',
    'modality_contribution.png',
    'similarity_matrices.png',
    'dose_response_analysis.png',
    'analysis_results/classification_report.txt',
    'analysis_results/confusion_matrix.png',
    'analysis_results/embeddings_umap_*.png (with class names!)',
    'analysis_results/cross_modal_similarity.png'
]
for i, file in enumerate(files, 1):
    print(f"  {i:2d}. {file}")

print("\n" + "=" * 80)
print("🎉 TUTORIAL COMPLETE!")
print("=" * 80)
print("\nYou've learned how to:")
print("  1. Load multi-omics data from CSV files")
print("  2. Create temporal multi-omics datasets")
print("  3. Handle NaN values automatically")
print("  4. Train MultiOmicsBind models with binding modality")
print("  5. Evaluate on held-out test sets (no data leakage)")
print("  6. Compute feature importance")
print("  7. Analyze cross-modal similarity")
print("  8. Generate UMAPs with custom class names")
print("  9. Visualize dose-response relationships")
print(" 10. Analyze modality contributions")
print("\n💡 Next steps:")
print("  • Try with your own multi-omics data!")
print("  • Experiment with different architectures")
print("  • Explore advanced features in the documentation")
print("  • Check DOSE_RESPONSE_VISUALIZATION.md for detailed guide")

## 📚 Additional Resources

For more information, check out:

### Documentation
- **[README.md](../README.md)** - Overview and getting started
- **[ADVANCED_USAGE_GUIDE.md](../ADVANCED_USAGE_GUIDE.md)** - Advanced features and patterns
- **[QUICK_ANSWERS.md](../QUICK_ANSWERS.md)** - Quick reference for common questions
- **[DOSE_RESPONSE_VISUALIZATION.md](../DOSE_RESPONSE_VISUALIZATION.md)** - Detailed dose-response guide

### Examples
- **[basic_example.py](basic_example.py)** - Simple multi-omics integration
- **[temporal_example.py](temporal_example.py)** - Temporal multi-omics (source for this notebook)
- **[flexible_modalities_example.py](flexible_modalities_example.py)** - Flexible modality combinations

### Key Concepts
- **Binding Modality**: Efficient attention mechanism for multi-omics integration
- **Temporal Integration**: LSTM-based encoding for time-series omics data
- **Cross-Modal Learning**: Learning unified representations across modalities
- **Gradient-Based Importance**: Feature importance via gradient attribution

### Citation
If you use MultiOmicsBind in your research, please cite:

```bibtex
@software{multiomicsbind2024,
  author = {Shivaprasad Patil},
  title = {MultiOmicsBind: Integrative Multi-Omics Analysis with Binding Modality},
  year = {2024},
  url = {https://github.com/shivaprasad-patil/MultiOmicsBind}
}
```

---

**Repository**: https://github.com/shivaprasad-patil/MultiOmicsBind

**Questions or Issues?** Open an issue on GitHub!

---

Thank you for using MultiOmicsBind! 🧬✨