# OpenFold Comprehensive 48-Layer Intermediate Representation Analysis

## Complete Guide to Multi-Layer Visualization and Analysis

**Author**: OpenFold Visualization Team  
**Purpose**: Extract and visualize intermediate representations from all 48 Evoformer layers  
**Features**: MSA, Pair, and Structure module analysis with advanced metrics

---

## 📋 Environment Setup

### Prerequisites

This notebook requires the OpenFold environment with the following packages:

**Core Dependencies:**
- Python 3.10
- PyTorch 2.5+ (with CUDA support)
- NumPy
- Matplotlib
- SciPy (for clustering)

### Quick Setup

**Option 1: Use Existing OpenFold Environment**

If you have OpenFold already installed:
```bash
conda activate openfold-env
jupyter notebook OpenFold_Comprehensive_Analysis.ipynb
```

**Option 2: Install Minimal Requirements**

For visualization only (without full OpenFold):
```bash
conda create -n openfold-viz python=3.10
conda activate openfold-viz
conda install pytorch numpy matplotlib scipy jupyter -c pytorch
pip install ipykernel
python -m ipykernel install --user --name=openfold-viz
```

**Option 3: Full OpenFold Environment**

For complete OpenFold integration:
```bash
conda env create -f environment.yml
conda activate openfold-env
pip install ipykernel
python -m ipykernel install --user --name=openfold-env
```

### Required Files

Make sure `visualize_intermediate_reps_utils.py` is in the same directory as this notebook.

---

## Table of Contents

0. [Environment Setup](#env)
1. [Setup and Imports](#setup)
2. [Data Generation](#data)
3. [Multi-Layer Evolution Analysis](#evolution)
4. [Stratified Layer Comparison](#stratified)
5. [Convergence Analysis](#convergence)
6. [Layer Importance Ranking](#importance)
7. [Structure Module Analysis](#structure)
8. [Residue-Level Feature Analysis](#residue)
9. [Hierarchical Clustering](#clustering)
10. [Contact Map Integration](#contacts)

---


## 1. Setup and Imports

First, let's import all necessary libraries and set up our environment.


In [None]:
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy
import os
import sys

# Import the visualization utilities used in this notebook
from visualize_intermediate_reps_utils import (
    generate_mock_contact_map,
    plot_layer_clustering_dendrogram,
    plot_layer_convergence_analysis,
    plot_layer_importance_ranking,
    plot_multilayer_evolution,
    plot_pair_representation_heatmap,
    plot_residue_feature_analysis,
    plot_stratified_layer_comparison,
    plot_structure_module_evolution,
    stratified_layer_sampling,
)

# Set matplotlib style
plt.style.use('default')
%matplotlib inline

# Create output directory
output_dir = "notebook_outputs"
os.makedirs(output_dir, exist_ok=True)

print("✓ All imports successful!")
print(f"✓ Output directory: {output_dir}/")

# Verify dependencies
print("\n📦 Package Versions:")
print(f"  • Python: {sys.version.split()[0]}")
print(f"  • PyTorch: {torch.__version__}")
print(f"  • NumPy: {np.__version__}")
print(f"  • Matplotlib: {matplotlib.__version__}")
print(f"  • SciPy: {scipy.__version__}")

# Check CUDA availability (optional)
if torch.cuda.is_available():
    print(f"\n🎮 CUDA Available: {torch.cuda.get_device_name(0)}")
else:
    print("\n💻 Running on CPU (CUDA not available)")

print("\n✅ Environment ready for analysis!")


## 2. Data Generation

### 2.1 Simulate 48 Evoformer Layers

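OpenFold uses 48 Evoformer layers to process protein information. The cell below simulates them with synthetic data: each layer is a shared random base tensor plus Gaussian noise whose scale decays exponentially with depth. This mimics convergence but is not real model output.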


In [None]:
# Protein parameters
n_seq = 15       # Number of MSA sequences
n_res = 100      # Number of residues
c_m = 256        # MSA channels
c_z = 128        # Pair channels
n_layers = 48    # OpenFold Evoformer layers
n_recycles = 8   # Structure module iterations, called "recycles" here (OpenFold's default structure module has 8 blocks)

print("Generating data for:")
print(f"  • {n_seq} MSA sequences")
print(f"  • {n_res} residues")
print(f"  • {n_layers} Evoformer layers")
print(f"  • {n_recycles} recycles")

# Generate base representations
base_msa = torch.randn(n_seq, n_res, c_m)
base_pair = torch.randn(n_res, n_res, c_z)

# Create 48 layers whose noise decays with depth (a stand-in for convergence)
msa_layers = {}
pair_layers = {}

for layer_idx in range(n_layers):
    noise_factor = 0.3 * np.exp(-layer_idx / 20)
    msa_layers[layer_idx] = base_msa + torch.randn_like(base_msa) * noise_factor
    pair_layers[layer_idx] = base_pair + torch.randn_like(base_pair) * noise_factor

# Generate structure module data
structure_output = {
    'backbone_frames': torch.randn(n_recycles, n_res, 7),
    'angles': torch.randn(n_recycles, n_res, 7, 2),
    'positions': torch.randn(n_recycles, n_res, 14, 3),
}

print(f"\n✓ Generated {n_layers} layers successfully!")
print(f"  • MSA: {list(msa_layers.keys())[:5]}...{list(msa_layers.keys())[-2:]}")
print(f"  • Pair: {list(pair_layers.keys())[:5]}...{list(pair_layers.keys())[-2:]}")
print(f"  • Structure: {n_recycles} recycles")
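
The noise schedule in the loop above can be inspected on its own to see how quickly the synthetic layers settle. The following is a minimal sketch using only NumPy; `noise_factor` is a helper name introduced here for illustration and is not part of the notebook's utilities.

```python
import numpy as np

def noise_factor(layer_idx, scale=0.3, decay=20.0):
    """Noise scale of the synthetic data: scale * exp(-layer_idx / decay)."""
    return scale * np.exp(-np.asarray(layer_idx) / decay)

# Evaluate the schedule for all 48 layers at once
factors = noise_factor(np.arange(48))

for idx in (0, 12, 24, 36, 47):
    print(f"layer {idx:2d}: noise factor = {factors[idx]:.4f}")
```

Each layer draws its noise independently, so the layers are noisy copies of one base tensor rather than successive refinements of each other. Trends in the later plots reflect this schedule, not OpenFold behavior.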


## 3. Multi-Layer Evolution Analysis

Track how the MSA representation evolves across the 48 layers at five selected residues.


In [None]:
# Visualize evolution across all 48 layers
fig = plot_multilayer_evolution(
    msa_layers,
    residue_indices=[10, 25, 50, 75, 90],
    save_path=f"{output_dir}/multilayer_evolution.png",
    rep_type='msa',
    layer_sampling='uniform'
)
plt.show()

print("\n📊 Key Insights:")
print("  • Top plot: Representation magnitude evolution")
print("  • Bottom plot: Layer-to-layer changes")
print("  • Notice convergence in later layers!")
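
The two quantities named in the printout, representation magnitude and layer-to-layer change, can be sketched directly in NumPy. The definitions below (an L2 norm over sequences and channels for one residue, and the norm of the difference between consecutive layers) are assumptions made for illustration; `plot_multilayer_evolution` may compute them differently. The tensor sizes are reduced to keep the example light.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, n_seq, n_res, c_m = 48, 4, 20, 16  # reduced sizes for a quick demo

# Synthetic layers: shared base plus noise that decays with depth
base = rng.standard_normal((n_seq, n_res, c_m))
layers = np.stack([
    base + rng.standard_normal(base.shape) * 0.3 * np.exp(-i / 20)
    for i in range(n_layers)
])  # shape: (n_layers, n_seq, n_res, c_m)

residue = 10
per_residue = layers[:, :, residue, :]  # (n_layers, n_seq, c_m)

# Magnitude: Frobenius norm of the residue's features in each layer
magnitude = np.linalg.norm(per_residue, axis=(1, 2))

# Change: norm of the difference between consecutive layers
delta = np.linalg.norm(np.diff(per_residue, axis=0), axis=(1, 2))

print("magnitude, first/last layer:", magnitude[0].round(3), magnitude[-1].round(3))
print("mean change, first 5 steps:", delta[:5].mean().round(3))
print("mean change, last 5 steps: ", delta[-5:].mean().round(3))
```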


## 4. Stratified Layer Comparison

Side-by-side comparison of 13 strategically sampled layers.


In [None]:
# Sample strategic layers
sampled_layers = stratified_layer_sampling(n_layers, strategy='grouped')
print(f"Comparing {len(sampled_layers)} layers: {sampled_layers}")

# MSA comparison
fig = plot_stratified_layer_comparison(
    msa_layers,
    layer_indices=sampled_layers,
    save_path=f"{output_dir}/stratified_msa.png",
    rep_type='msa'
)
plt.show()

# Pair comparison
fig = plot_stratified_layer_comparison(
    pair_layers,
    layer_indices=sampled_layers,
    save_path=f"{output_dir}/stratified_pair.png",
    rep_type='pair'
)
plt.show()

print("\n💡 Look for pattern emergence across layers!")


## 5. Convergence Analysis

Identify when representations stabilize across layers.


In [None]:
# MSA convergence
fig = plot_layer_convergence_analysis(
    msa_layers,
    save_path=f"{output_dir}/msa_convergence.png",
    rep_type='msa'
)
plt.show()

# Pair convergence
fig = plot_layer_convergence_analysis(
    pair_layers,
    save_path=f"{output_dir}/pair_convergence.png",
    rep_type='pair'
)
plt.show()

print("\n📈 High correlation = stable representations")
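
One simple way to quantify "high correlation = stable" is the Pearson correlation between each pair of consecutive layers. The sketch below is an assumption about how such a metric could be computed, not a description of `plot_layer_convergence_analysis`; the helper name `adjacent_layer_correlation` and the 0.99 threshold are hypothetical.

```python
import numpy as np

def adjacent_layer_correlation(layers):
    """Pearson correlation between each pair of consecutive layers.

    `layers` maps layer index -> array; arrays are flattened before comparison.
    """
    keys = sorted(layers)
    return np.array([
        np.corrcoef(np.ravel(layers[a]), np.ravel(layers[b]))[0, 1]
        for a, b in zip(keys[:-1], keys[1:])
    ])

# Synthetic layers with the same decaying-noise schedule as the notebook
rng = np.random.default_rng(0)
base = rng.standard_normal((16, 16, 8))
layers = {
    i: base + rng.standard_normal(base.shape) * 0.3 * np.exp(-i / 20)
    for i in range(48)
}

corr = adjacent_layer_correlation(layers)

# First layer transition whose correlation reaches the (hypothetical) threshold
threshold = 0.99
stable = np.flatnonzero(corr >= threshold)
first_stable = int(stable[0]) if stable.size else None
print(f"correlation at first step: {corr[0]:.3f}, at last step: {corr[-1]:.3f}")
print(f"first transition with correlation >= {threshold}: {first_stable}")
```

With the synthetic schedule used in this notebook, the correlation rises with depth simply because the noise scale shrinks.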


## 6. Layer Importance Ranking

Identify which layers contribute most using multiple metrics.


In [None]:
# Multi-metric importance analysis
fig = plot_layer_importance_ranking(
    msa_layers,
    save_path=f"{output_dir}/layer_importance.png",
    metrics=['variance', 'entropy', 'norm']
)
plt.show()

print("\n🎯 Top 5 layers are automatically labeled!")
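
The three metric names passed above (`variance`, `entropy`, `norm`) can be given concrete definitions in a few lines. The sketch below uses the variance of all values, the Shannon entropy of a histogram of those values, and the L2 norm; these definitions and the helper names `layer_metrics` and `rank_layers` are assumptions for illustration and may not match what `plot_layer_importance_ranking` computes.

```python
import numpy as np

def layer_metrics(x, bins=50):
    """Variance, histogram entropy (in nats), and L2 norm of one layer."""
    flat = np.ravel(x)
    counts, _ = np.histogram(flat, bins=bins)
    p = counts[counts > 0] / counts.sum()  # drop empty bins before taking logs
    return {
        "variance": float(flat.var()),
        "entropy": float(-np.sum(p * np.log(p))),
        "norm": float(np.linalg.norm(flat)),
    }

def rank_layers(layers, metric):
    """Layer indices sorted from highest to lowest value of `metric`."""
    scores = {i: layer_metrics(x)[metric] for i, x in layers.items()}
    return sorted(scores, key=scores.get, reverse=True)

# Synthetic layers with the same decaying-noise schedule as the notebook
rng = np.random.default_rng(0)
base = rng.standard_normal((4, 32, 16))
layers = {
    i: base + rng.standard_normal(base.shape) * 0.3 * np.exp(-i / 20)
    for i in range(48)
}

for metric in ("variance", "entropy", "norm"):
    print(f"top 5 by {metric}: {rank_layers(layers, metric)[:5]}")
```

On this synthetic data the rankings mostly track the noise schedule, so they say nothing about which Evoformer layers matter in a trained model.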


## 7. Structure Module Analysis

Analyze backbone frames, angles, and atomic positions with 6 subplots.


In [None]:
fig = plot_structure_module_evolution(
    structure_output,
    save_path=f"{output_dir}/structure_evolution.png"
)
plt.show()

print("\n🧬 6-panel analysis: frames, angles, RMSD, 3D trajectory, displacement, changes")
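
The RMSD panel can be understood through a small example: the root-mean-square displacement of all atoms between consecutive iterations. The sketch below assumes positions shaped `(n_iter, n_res, n_atoms, 3)` as in `structure_output['positions']` and performs no superposition; `consecutive_rmsd` is a hypothetical helper, and the plotting utility may define RMSD differently.

```python
import numpy as np

def consecutive_rmsd(positions):
    """RMSD between consecutive iterations, without superposition.

    positions: array of shape (n_iter, n_res, n_atoms, 3).
    Returns an array of length n_iter - 1.
    """
    diff = np.diff(positions, axis=0)          # (n_iter - 1, n_res, n_atoms, 3)
    sq_dist = np.sum(diff ** 2, axis=-1)       # squared displacement per atom
    return np.sqrt(sq_dist.mean(axis=(1, 2)))  # average over residues and atoms

rng = np.random.default_rng(0)
positions = rng.standard_normal((8, 100, 14, 3))  # same shape as the mock data
rmsd = consecutive_rmsd(positions)
print(np.round(rmsd, 3))
```

Because the mock positions are independent Gaussian noise at every iteration, the values stay near sqrt(6) ≈ 2.45 instead of shrinking; a decreasing trend would only be expected from real structure module outputs.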


## 8. Residue-Level Feature Analysis

Deep dive into individual residue features with 6-panel comprehensive analysis.


In [None]:
# Analyze residue 50 from layer 24 (middle layer)
residue_idx = 50
layer_idx = 24

fig = plot_residue_feature_analysis(
    msa_layers[layer_idx],
    residue_idx=residue_idx,
    save_path=f"{output_dir}/residue_{residue_idx}_analysis.png",
    rep_type='msa'
)
plt.show()

print("\n🔍 6-panel residue analysis: heatmap, distribution, stats, top channels, correlation, activation")
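
Several of the panels listed above reduce to short NumPy operations on the `(n_seq, c_m)` slice belonging to one residue. The sketch below illustrates summary statistics, a top-channel ranking by mean absolute activation, and a channel correlation matrix; these particular definitions are assumptions and may differ from those in `plot_residue_feature_analysis`.

```python
import numpy as np

rng = np.random.default_rng(0)
msa = rng.standard_normal((15, 100, 256))  # stand-in for one layer: (n_seq, n_res, c_m)
residue_idx = 50

features = msa[:, residue_idx, :]  # (n_seq, c_m) slice for one residue

# Summary statistics over all sequences and channels
summary = {
    "mean": features.mean(),
    "std": features.std(),
    "min": features.min(),
    "max": features.max(),
}

# Rank channels by mean absolute activation across sequences
mean_abs = np.abs(features).mean(axis=0)
top_channels = np.argsort(mean_abs)[::-1][:10]

# Correlation among the top channels (columns are treated as variables)
top_corr = np.corrcoef(features[:, top_channels], rowvar=False)

print({k: round(float(v), 3) for k, v in summary.items()})
print("top channels:", top_channels.tolist())
```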


## 9. Hierarchical Clustering

Group layers by similarity using hierarchical clustering.


In [None]:
fig = plot_layer_clustering_dendrogram(
    msa_layers,
    save_path=f"{output_dir}/layer_clustering.png",
    method='ward'
)
plt.show()

print("\n🌳 Dendrogram shows layer relationships, matrix shows distances")
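
For readers who want the cluster assignments rather than only the dendrogram, Ward linkage can be run directly with SciPy. The sketch below flattens each layer into one feature vector and cuts the tree into at most three clusters; the flattening and the cluster count are assumptions chosen for illustration, not settings taken from `plot_layer_clustering_dendrogram`.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

# Synthetic layers with the same decaying-noise schedule as the notebook
rng = np.random.default_rng(0)
base = rng.standard_normal((4, 20, 16))
layers = {
    i: base + rng.standard_normal(base.shape) * 0.3 * np.exp(-i / 20)
    for i in range(48)
}

# One row per layer: flatten each representation into a feature vector
features = np.stack([layers[i].ravel() for i in sorted(layers)])

# Ward linkage on Euclidean distances between the flattened layers
Z = linkage(features, method="ward")

# Cut the tree into at most 3 clusters and list the members of each
labels = fcluster(Z, t=3, criterion="maxclust")
for c in np.unique(labels):
    print(f"cluster {c}: layers {np.flatnonzero(labels == c).tolist()}")
```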


## 10. Contact Map Integration

Correlate pair representations with predicted structural contacts.


In [None]:
# Generate mock contact map
contact_map = generate_mock_contact_map(n_res, contact_probability=0.15, seed=42)

fig = plot_pair_representation_heatmap(
    pair_layers[n_layers - 1],  # Final layer
    n_layers - 1,
    f"{output_dir}/pair_with_contacts.png",
    contact_map=contact_map,
    show_correlation=True
)
plt.show()

print("\n🔗 Black dots = contacts, Pearson correlation computed!")
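
The correlation step can be illustrated independently of the plotting code. The sketch below builds a symmetric random contact map and correlates it with the channel-wise norm of a pair representation over off-diagonal residue pairs. Both helpers are hypothetical stand-ins: `mock_contact_map` is not the notebook's `generate_mock_contact_map`, and reducing channels with a norm is an assumption about how a scalar per pair could be obtained.

```python
import numpy as np

def mock_contact_map(n_res, contact_probability=0.15, seed=42):
    """Symmetric binary contact map with an empty diagonal (illustrative only)."""
    rng = np.random.default_rng(seed)
    upper = np.triu(rng.random((n_res, n_res)) < contact_probability, k=1)
    return (upper | upper.T).astype(float)

def pair_contact_correlation(pair_rep, contact_map):
    """Pearson correlation between per-pair feature norm and contacts.

    pair_rep: (n_res, n_res, c_z); contact_map: (n_res, n_res).
    Diagonal entries are excluded.
    """
    strength = np.linalg.norm(pair_rep, axis=-1)
    mask = ~np.eye(contact_map.shape[0], dtype=bool)
    return np.corrcoef(strength[mask], contact_map[mask])[0, 1]

n_res, c_z = 100, 16
contacts = mock_contact_map(n_res)
pair_rep = np.random.default_rng(0).standard_normal((n_res, n_res, c_z))

r = pair_contact_correlation(pair_rep, contacts)
print(f"contact density: {contacts.mean():.3f}, Pearson r: {r:.3f}")
```

Since the pair representation and the contact map here are generated independently, the correlation should be close to zero; a meaningful value requires real pair representations and contacts derived from a structure.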


## 🎉 Analysis Complete!

### Summary of Generated Visualizations

You've successfully created a comprehensive 48-layer analysis with:

1. ✅ **Multi-Layer Evolution** - 48 layers, 5 residues tracked
2. ✅ **Stratified MSA Comparison** - 13-layer grid
3. ✅ **Stratified Pair Comparison** - 13-layer grid
4. ✅ **MSA Convergence Analysis** - Correlation tracking
5. ✅ **Pair Convergence Analysis** - Stability detection
6. ✅ **Layer Importance Ranking** - 3 metrics analyzed
7. ✅ **Structure Module Evolution** - 6-panel analysis
8. ✅ **Residue Feature Analysis** - 6-panel deep dive
9. ✅ **Hierarchical Clustering** - Layer grouping
10. ✅ **Contact Map Integration** - Correlation with a mock contact map

All visualizations are saved in the `notebook_outputs/` directory!

### Next Steps:

- Apply to real OpenFold inference
- Compare multiple proteins
- Integrate with web interface
- Publish your findings

**Happy Analyzing! 🧬🔬**


In [None]:
# List all generated files
import os

files = sorted(f for f in os.listdir(output_dir) if f.endswith('.png'))
print(f"📊 Generated {len(files)} visualizations:")
for i, f in enumerate(files, 1):
    size_kb = os.path.getsize(os.path.join(output_dir, f)) / 1024
    print(f"  {i}. {f} ({size_kb:.1f} KB)")

print(f"\n✅ All files saved to: {output_dir}/")
