# RibbonFold - Clean Setup for Google Colab

This notebook provides a streamlined setup for running RibbonFold on Google Colab. All redundant installations and debugging code have been removed.

## Important Notes:
- This notebook is specifically designed for Google Colab
- Requires GPU runtime (T4 or better recommended)
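- Upload your FASTA and A3M files when prompted in cell 2; they are saved as `/content/protein.fasta` and `/content/msa.a3m`. Only edit the input paths in cell 8 if your files live elsewhere. Set `RIBBON_NAME` in cell 9 to your target's name.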


In [1]:
#@title 1. Install conda and setup environment
# Bootstraps conda into the Colab runtime so that later cells can create and
# use the separate `ribbon_env` environment (via `mamba` / `conda run`).
# Cell 2 verifies the result with condacolab.check().
# NOTE(review): condacolab.install() is documented to restart the Colab
# runtime; wait for it to reconnect before running cell 2.
!pip install -q condacolab
import condacolab
condacolab.install()


✨🍰✨ Everything looks OK!


In [2]:
#@title 2. Create Python 3.9 environment with CUDA 11.8 and upload please
import condacolab
from google.colab import files
import shutil
condacolab.check()
!mamba create -y -n ribbon_env python=3.9
!mamba install -y -n ribbon_env -c nvidia cudatoolkit=11.8

# Upload files
uploaded = files.upload()

# Loop through uploaded files and rename accordingly
for fname in uploaded.keys():
    if fname.lower().endswith((".fa", ".fasta")):
        shutil.move(fname, "/content/protein.fasta")
        print(f"{fname} saved as /content/protein.fasta")
    elif fname.lower().endswith((".a3m",)):
        shutil.move(fname, "/content/msa.a3m")
        print(f"{fname} saved as /content/msa.a3m")
    else:
        print(f"Skipping {fname} (not fasta or a3m)")

✨🍰✨ Everything looks OK!

Looking for: ['python=3.9']

[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0G[+] 0.1s
conda-forge/linux-64   1%
conda-forge/noarch    ⣾  [2K[1A[2K[1A[2K[0G[+] 0.2s
conda-forge/linux-64  13%
conda-forge/noarch    17%[2K[1A[2K[1A[2K[0G[+] 0.3s
conda-forge/linux-64  25%
conda-forge/noarch    42%[2K[1A[2K[1A[2K[0G[+] 0.4s
conda-forge/linux-64  36%
conda-forge/noarch    66%[2K[1A[2K[1A[2K[0G[+] 0.5s
conda-forge/linux-64  49%
conda-forge/noarch    92%[2K[1A[2K[1A[2K[0G[+] 0.6s
conda-forge/linux-64  49%
conda-forge/noarch    92%[2K[1A[2K[1A[2K[0Gconda-forge/noarch                                
[+] 0.7s
conda-forge/linux-64  64%[2K[1A[2K[0G[+] 0.8s
conda-forge/linux-64  84%[2K[1A[2K[0G[+] 0.9s
conda-forge/linux-64  96%[2K[1A[2K[0G[+] 1.0s
conda-forge/linux-64  96%[2K[1A[2K[0Gconda-forge/linux-64                              
[?25hTransaction

  Prefix: /usr/local/envs/ribbon_env

  Updating specs:

   - python=3.9


  Pack

Saving 5oqv.fasta to 5oqv.fasta
Saving 5oqv_fixed.a3m to 5oqv_fixed.a3m
5oqv.fasta saved as /content/protein.fasta
5oqv_fixed.a3m saved as /content/msa.a3m


In [None]:
#@title 3. Install PyTorch 2.0.1 with CUDA 11.8
# The "+cu118" builds are not on PyPI. `-f` points pip at PyTorch's wheel listing
# page, where they are published. torchvision 0.15.2 / torchaudio 2.0.2 are pinned
# alongside torch 2.0.1 (presumably the matching releases - confirm against
# PyTorch's "previous versions" table if changing any of them).
!conda run -n ribbon_env pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 -f https://download.pytorch.org/whl/torch_stable.html

# Verify installation: prints the installed torch version and whether torch can
# see a CUDA device ("False" usually means the Colab runtime has no GPU attached).
!conda run -n ribbon_env python -c "import torch; print('PyTorch:', torch.__version__); print('CUDA available:', torch.cuda.is_available())"


In [None]:
#@title 4. Install all required dependencies with compatible versions
# Install TensorFlow first with its required numpy version
!conda run -n ribbon_env pip install tensorflow-cpu==2.6.0 tensorflow-estimator==2.6.0

# Install remaining dependencies with compatible versions
!conda run -n ribbon_env pip install numpy==1.19.5 pandas==1.3.5 scipy==1.7.3 protobuf==3.19.6 torchtyping==0.1.4 functorch biopython dm-tree treelib tqdm ml_collections pytz python-dateutil contextlib2 PyYAML wrapt==1.12.1 gast==0.4.0 google-pasta==0.2.0 h5py==3.1.0 opt-einsum==3.3.0 termcolor==1.1.0 astunparse==1.6.3 flatbuffers==1.12 six==1.15.0 keras_preprocessing==1.1.2 keras_applications==1.0.8 absl-py==0.15.0


In [None]:
#@title 5. Clone RibbonFold repository
!git clone https://github.com/thp42/RibbonFold.git ribbonfold


In [None]:
#@title 6. Download and extract model checkpoints
!wget -O /content/model_checkpoints.tar.gz "https://zenodo.org/record/15128410/files/model_checkpoints.tar.gz?download=1"

# Extract model checkpoints with proper handling
!mkdir -p /content/ribbonfold/ckpt
!cd /content && tar -tzf model_checkpoints.tar.gz | head -5  # Check tar structure
!cd /content && tar -xzf model_checkpoints.tar.gz
!find /content -name "*.pt" -type f 2>/dev/null | head -5  # Find the extracted .pt files
!cp /content/model_ckpt_001.pt /content/ribbonfold/ckpt/ 2>/dev/null || find /content -name "model_ckpt_001.pt" -exec cp {} /content/ribbonfold/ckpt/ \;
!ls -la /content/ribbonfold/ckpt/  # Verify extraction


In [None]:
#@title 7. Setup Python path for AlphaFold modules
# Writes `ribbonfold-af2.pth` into the first site-packages directory of ribbon_env.
# The file contains the single line /content/ribbonfold/af2. Python's `site`
# module adds directories listed in .pth files to sys.path at start-up, which
# makes the modules under af2/ importable in every `conda run -n ribbon_env python`.
# The shell's double-quote handling turns `\\n` into `\n`, which the Python
# string literal then turns into the trailing newline.
!conda run -n ribbon_env python -c "import site, pathlib; p=pathlib.Path(site.getsitepackages()[0])/'ribbonfold-af2.pth'; p.write_text('/content/ribbonfold/af2\\n'); print('Created path file:', p)"


In [None]:
#@title 8. Process MSA file (Update paths as needed)
INPUT_FASTA = "/content/protein.fasta"     # Update this path
INPUT_A3M   = "/content/msa.a3m"   # Update this path
OUT_PKL     = "/content/output.pkl.gz"

!conda run -n ribbon_env python /content/ribbonfold/process_msa_file.py --input_fasta {INPUT_FASTA} --msa_file {INPUT_A3M} --output {OUT_PKL}


In [None]:
#@title 9. Run RibbonFold inference
%cd /content/ribbonfold

# Update these parameters as needed
CHECKPOINT_PATH = "./ckpt/model_ckpt_001.pt"
INPUT_PKL = "/content/output.pkl.gz"
RIBBON_NAME = "5oqv"
OUTPUT_DIR = "./results"
ROUNDS = 10

!conda run -n ribbon_env python inference.py --checkpoint {CHECKPOINT_PATH} --input_pkl {INPUT_PKL} --ribbon_name {RIBBON_NAME} --output_dir {OUTPUT_DIR} --rounds {ROUNDS} --use_dropout true --use_init_structure true


In [None]:
#@title 10. View results
!ls -la /content/ribbonfold/results/
print("\nInference completed! Check the results directory for output PDB files.")


In [None]:
#@title 11. Zip and download results
import os
from google.colab import files
import zipfile

# Source directory (written by cell 9) and the archive offered for download.
results_path = "/content/ribbonfold/results"
zip_filename = "/content/ribbonfold_results.zip"


def _zip_tree(src_dir, archive_path):
    """Deflate every file below src_dir into archive_path, storing paths relative to src_dir."""
    with zipfile.ZipFile(archive_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for dirpath, _subdirs, filenames in os.walk(src_dir):
            for name in filenames:
                full_path = os.path.join(dirpath, name)
                archive.write(full_path, os.path.relpath(full_path, src_dir))


if not os.path.exists(results_path):
    print("Results directory not found. Make sure inference completed successfully.")
else:
    _zip_tree(results_path, zip_filename)
    size_mb = os.path.getsize(zip_filename) / (1024*1024)
    print(f"Results zipped to: {zip_filename}")
    print(f"Zip file size: {size_mb:.2f} MB")

    # Ask the browser to download the archive.
    files.download(zip_filename)
    print("Download started! Check your browser's download folder.")


In [None]:
#@title 12. Interactive Confidence Plots - View All Models {run: "auto"}
import os
import sys
import json
import glob
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, HTML, clear_output
import base64
from html import escape
import ipywidgets as widgets
from IPython.display import display

# Configuration
# Base output directory of cell 9 (OUTPUT_DIR="./results", run from /content/ribbonfold).
# If it contains subdirectories, the main section below may pick one of them instead.
RESULTS_DIR = "/content/ribbonfold/results"  # Update this path if needed

def _plddt_stat(data, key, reducer):
    """Return numeric data[key], else reducer(per-residue 'plddt' list), else 0."""
    value = data.get(key)
    if isinstance(value, (int, float)):
        return value
    scores = data.get('plddt')
    if isinstance(scores, list) and scores:
        try:
            return reducer(scores)
        except (TypeError, ValueError):
            return 0
    return 0


def find_all_models(results_dir):
    """Find all models in ``results_dir`` and rank them by mean pLDDT.

    Every confidence JSON directly inside ``results_dir`` (names matching
    ``*confidence_*.json`` or ``*_confidence.json``) is treated as one model.
    Files that cannot be read or parsed are reported and skipped.

    Args:
        results_dir: Directory to search (not recursive).

    Returns:
        A list of model dicts with keys 'id', 'file', 'data', 'mean_plddt',
        'max_plddt', 'min_plddt', 'filename' and 'rank', sorted best-first
        (rank 1 = highest mean pLDDT). Empty when no confidence files exist.
    """
    patterns = ("*confidence_*.json", "*_confidence.json")
    # De-duplicate and sort so that the ranking is deterministic when two
    # models tie on mean pLDDT (glob order is filesystem-dependent).
    confidence_files = sorted({
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(results_dir, pattern))
    })

    if not confidence_files:
        print("No confidence files found!")
        return []

    models = []

    for conf_file in confidence_files:
        try:
            with open(conf_file, 'r') as f:
                data = json.load(f)

            filename = os.path.basename(conf_file)
            # Prefer the summary stats stored in the JSON; derive them from the
            # per-residue scores only when a stat is missing.
            mean_plddt = _plddt_stat(data, 'mean_plddt', lambda s: sum(s) / len(s))
            max_plddt = _plddt_stat(data, 'max_plddt', max)
            min_plddt = _plddt_stat(data, 'min_plddt', min)

            # Model identifier: file name without the ".json" extension and the
            # "confidence" marker, e.g. "5oqv_confidence_1.json" -> "5oqv_1".
            stem = os.path.splitext(filename)[0]
            if stem.endswith('_confidence'):
                model_id = stem[:-len('_confidence')]
            else:
                model_id = stem.replace('_confidence_', '_')

            models.append({
                'id': model_id,
                'file': conf_file,
                'data': data,
                'mean_plddt': mean_plddt,
                'max_plddt': max_plddt,
                'min_plddt': min_plddt,
                'filename': filename
            })

        except Exception as e:
            print(f"Error reading {conf_file}: {e}")
            continue

    # Sort by mean pLDDT (descending); the sort is stable, so ties keep file-name order.
    models.sort(key=lambda x: x['mean_plddt'], reverse=True)

    # Add rank information (1-based)
    for i, model in enumerate(models, start=1):
        model['rank'] = i

    return models

def plot_plddt(plddt_scores, save_path, title_suffix=""):
    """Save a per-residue pLDDT bar chart to ``save_path``.

    Bars are coloured by confidence band: >90 dark blue (very high),
    >70 light blue (confident), >50 yellow (low), otherwise orange (very low).
    The mean pLDDT is shown in the title and in an inset label.

    Returns:
        ``save_path``, unchanged.
    """
    def _band_colour(score):
        if score > 90:
            return '#0053D6'  # Very high (blue)
        if score > 70:
            return '#65CBF3'  # Confident (light blue)
        if score > 50:
            return '#FFDB13'  # Low (yellow)
        return '#FF7D45'      # Very low (orange)

    average = np.mean(plddt_scores)
    positions = list(range(1, len(plddt_scores) + 1))
    bar_colours = [_band_colour(score) for score in plddt_scores]

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.bar(positions, plddt_scores, color=bar_colours, width=1.0)
    ax.set_xlabel('Residue')
    ax.set_ylabel('Predicted LDDT')
    ax.set_title(f'Predicted LDDT - {title_suffix} (Mean: {average:.1f})')
    ax.set_ylim(0, 100)
    ax.grid(True, alpha=0.3, axis='y')
    ax.text(0.02, 0.98, f'Mean pLDDT: {average:.1f}',
            transform=ax.transAxes, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    return save_path

def plot_pae(pae_matrix, save_path, title_suffix=""):
    """Save a predicted-aligned-error (PAE) heatmap of ``pae_matrix`` to ``save_path``.

    The colour scale spans 0 to the matrix maximum using the reversed viridis
    colormap. The maximum PAE is shown in the title and in an inset label.

    Returns:
        ``save_path``, unchanged.
    """
    errors = np.array(pae_matrix)
    worst = np.max(errors)

    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap = ax.imshow(errors, cmap='viridis_r', vmin=0, vmax=worst)
    colour_bar = fig.colorbar(heatmap, shrink=0.8)
    colour_bar.set_label('Expected position error (Å)', rotation=270, labelpad=20)

    ax.set_xlabel('Scored residue')
    ax.set_ylabel('Aligned residue')
    ax.set_title(f'Predicted Aligned Error - {title_suffix} (Max: {worst:.1f} Å)')
    ax.text(0.02, 0.98, f'Max PAE: {worst:.1f} Å',
            transform=ax.transAxes, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    return save_path

def plot_coverage_placeholder(num_residues, save_path, title_suffix="", seed=42):
    """Save a placeholder "MSA coverage" bar chart to ``save_path``.

    NOTE: the bar heights are random draws (Beta(2, 1) scaled to 0-100), not
    values computed from the MSA. The plot title marks them as "(simulated)".

    Args:
        num_residues: Number of bars (one per residue).
        save_path: Output image path.
        title_suffix: Text appended to the plot title.
        seed: Seed for the private random generator (default 42, so the
            pattern is reproducible across calls).

    Returns:
        ``save_path``, unchanged.
    """
    plt.figure(figsize=(12, 4))

    residue_indices = list(range(1, num_residues + 1))
    # A private RandomState seeded with `seed` yields the same values as the
    # former `np.random.seed(42); np.random.beta(...)`, but without resetting
    # numpy's global RNG for every other piece of code in the session.
    rng = np.random.RandomState(seed)
    coverage = rng.beta(2, 1, num_residues) * 100

    plt.bar(residue_indices, coverage, color='#2E8B57', width=1.0)
    plt.xlabel('Residue')
    plt.ylabel('Coverage (%)')
    plt.title(f'MSA Coverage - {title_suffix} (simulated)')
    plt.ylim(0, 100)
    plt.grid(True, alpha=0.3, axis='y')

    mean_coverage = np.mean(coverage)
    plt.text(0.02, 0.98, f'Mean coverage: {mean_coverage:.1f}%',
             transform=plt.gca().transAxes, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    return save_path

def create_plots_for_model(model, results_dir):
    """Render every available confidence plot for one ranked model.

    Images are written to ``results_dir`` as ``rank_<rank>_<id>_<kind>.png``.
    A pLDDT plot and a (simulated) coverage plot are produced when the model's
    data has a 'plddt' entry; a PAE heatmap is produced when it has 'pae'.

    Returns:
        Dict mapping plot kind ('plddt', 'pae', 'coverage') to the image path,
        containing only the plots that were produced.
    """
    scores = model['data']
    label = f"Rank #{model['rank']} ({model['id']})"
    base = os.path.join(results_dir, f"rank_{model['rank']}_{model['id']}")
    has_plddt = 'plddt' in scores

    produced = {}

    if has_plddt:
        target = f"{base}_plddt.png"
        plot_plddt(scores['plddt'], target, label)
        produced['plddt'] = target

    if 'pae' in scores:
        target = f"{base}_pae.png"
        plot_pae(scores['pae'], target, label)
        produced['pae'] = target

    # Coverage is a placeholder sized by the number of pLDDT entries.
    if has_plddt:
        target = f"{base}_coverage.png"
        plot_coverage_placeholder(len(scores['plddt']), target, label)
        produced['coverage'] = target

    return produced

def image_to_data_url(filename):
    """Encode an image file as a base64 ``data:`` URL for inline HTML display.

    The MIME type comes from ``mimetypes`` (e.g. ".jpg" -> "image/jpeg",
    ".svg" -> "image/svg+xml"). If no image type is known, it falls back to
    ``image/<extension>``.

    Returns:
        The data URL, or "" when ``filename`` does not exist.
    """
    import mimetypes  # local import keeps this cell's shared import block unchanged

    if not os.path.exists(filename):
        return ""
    mime_type, _encoding = mimetypes.guess_type(filename)
    if mime_type is None or not mime_type.startswith('image/'):
        mime_type = f"image/{filename.split('.')[-1]}"
    prefix = f'data:{mime_type};base64,'
    with open(filename, 'rb') as f:
        img = f.read()
    return prefix + base64.b64encode(img).decode('utf-8')

def display_model_plots(model, plot_files, results_dir):
    """Render one model's summary statistics and confidence plots as inline HTML.

    Args:
        model: Ranked model dict (uses 'rank', 'id', 'data' and, when present,
            'mean_plddt' / 'max_plddt' / 'min_plddt').
        plot_files: Mapping of plot kind ('pae', 'plddt', 'coverage') to an
            image path. Kinds that are missing, or whose file does not exist,
            are left out of the page instead of rendering a broken image.
        results_dir: Results directory; its basename is shown as the job name.
    """
    rank = model['rank']
    # The id is derived from a file name, so escape it before embedding it in HTML.
    model_id = escape(str(model['id']))
    data = model['data']

    # Get job name from results directory
    jobname = os.path.basename(results_dir.rstrip('/'))

    def _data_url(kind):
        """Inline the image for ``kind`` as a data URL, or "" if unavailable."""
        path = plot_files.get(kind)
        if path and os.path.isfile(path):
            return image_to_data_url(path)
        return ""

    def _number(key):
        """Numeric stat for ``key``: model dict first, then raw JSON, else 0."""
        value = model.get(key, data.get(key, 0))
        return value if isinstance(value, (int, float)) else 0

    def _img(src, css_class, alt):
        """An <img> tag for ``src``, or "" so no broken-image icon is shown."""
        return f'<img src="{src}" class="{css_class}" alt="{alt}" />' if src else ""

    pae = _data_url('pae')
    plddt = _data_url('plddt')
    coverage = _data_url('coverage')

    max_pae_html = ""
    if isinstance(data.get('max_pae'), (int, float)):
        max_pae_html = f"<div><strong>Max PAE:</strong> {data['max_pae']:.2f} Å</div>"

    # Create model statistics
    stats_html = f"""
    <div style="background-color: #f0f0f0; padding: 15px; margin: 10px 0; border-radius: 5px;">
        <h3>Model Statistics - Rank #{rank}</h3>
        <div style="display: flex; flex-wrap: wrap; gap: 20px;">
            <div><strong>Model ID:</strong> {model_id}</div>
            <div><strong>Mean pLDDT:</strong> {_number('mean_plddt'):.2f}</div>
            <div><strong>Max pLDDT:</strong> {_number('max_plddt'):.2f}</div>
            <div><strong>Min pLDDT:</strong> {_number('min_plddt'):.2f}</div>
            {max_pae_html}
        </div>
    </div>
    """

    # Descriptions only for plots that are actually shown.
    descriptions = []
    if plddt:
        descriptions.append('<li><strong>pLDDT Plot:</strong> Per-residue confidence scores (0-100). '
                            'Higher scores indicate more reliable predictions.</li>')
    if pae:
        # plot_pae uses the reversed viridis colormap with vmin=0, so low errors are drawn yellow.
        descriptions.append('<li><strong>PAE Plot:</strong> Predicted Aligned Error between residue pairs. '
                            'Lower values (yellow end of the colour scale) indicate more confident relative positions.</li>')
    if coverage:
        descriptions.append('<li><strong>Coverage Plot:</strong> Placeholder only - randomly simulated values, '
                            'not computed from your MSA.</li>')
    descriptions_html = "\n          ".join(descriptions)

    pae_img = _img(pae, 'full', 'PAE Plot')
    plddt_img = _img(plddt, 'half', 'pLDDT Plot')
    coverage_img = _img(coverage, 'half', 'Coverage Plot')

    # Display the plots using HTML
    display(HTML(f"""
    <style>
      img {{
        float:left;
        margin: 10px;
      }}
      .full {{
        max-width:90%;
        clear: both;
      }}
      .half {{
        max-width:45%;
      }}
      @media (max-width:640px) {{
        .half {{
          max-width:100%;
        }}
      }}
      .plot-container {{
        margin-bottom: 20px;
        overflow: hidden;
      }}
    </style>
    <div style="max-width:95%; padding:2em;">
      <h1>Confidence Plots for {escape(jobname)} - Rank #{rank}</h1>
      {stats_html}
      <div class="plot-container">
        {pae_img}
      </div>
      <div class="plot-container">
        {plddt_img}
        {coverage_img}
      </div>
      <div style="clear: both; margin-top: 20px;">
        <p><strong>Plot descriptions:</strong></p>
        <ul>
          {descriptions_html}
        </ul>
      </div>
    </div>
    """))

# Global variables for widget interaction
# Filled in by the main section below and read by on_model_change(). That handler
# runs later as a dropdown callback and only receives the `change` dict, so it
# reads the model list and results directory from these module-level names.
current_models = []
current_results_dir = ""

def on_model_change(change):
    """Dropdown observer: redraw the viewer for the newly selected rank.

    ``change['new']`` is the selected rank. The cell output is cleared and
    rebuilt, so a fresh dropdown (wired back to this handler) is displayed
    above the plots for the chosen model. Does nothing when no models are
    loaded or the rank is unknown.
    """
    if not current_models:
        return

    wanted_rank = change['new']
    chosen = next((m for m in current_models if m['rank'] == wanted_rank), None)
    if not chosen:
        return

    # Clear previous output
    clear_output(wait=True)

    # Rebuild the selector, since clearing the output removed the old one.
    choices = [
        (f"Rank #{m['rank']}: {m['id']} (pLDDT: {m['mean_plddt']:.2f})", m['rank'])
        for m in current_models
    ]
    selector = widgets.Dropdown(
        options=choices,
        value=wanted_rank,
        description='Select Model:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='600px')
    )
    selector.observe(on_model_change, names='value')

    print("🎯 Interactive Model Viewer")
    print("=" * 50)
    display(selector)

    # Create and display plots for selected model
    print(f"\nGenerating plots for Rank #{chosen['rank']}...")
    rendered = create_plots_for_model(chosen, current_results_dir)
    display_model_plots(chosen, rendered, current_results_dir)

# Main execution
print("🔍 Searching for models...")
print(f"Looking for results in: {RESULTS_DIR}")


def _select_results_dir(base_dir):
    """Return the directory whose confidence JSON files should be shown.

    ``base_dir`` itself is used when it does not exist, already contains
    confidence files, or has no subdirectories. Otherwise the most recently
    modified subdirectory is used.
    """
    if not os.path.isdir(base_dir):
        return base_dir
    # Files written directly into base_dir take precedence over any subdirectory.
    if glob.glob(os.path.join(base_dir, "*confidence*.json")):
        return base_dir
    subdirs = [os.path.join(base_dir, name) for name in os.listdir(base_dir)
               if os.path.isdir(os.path.join(base_dir, name))]
    if not subdirs:
        return base_dir
    # "Most recent" by modification time, so the choice does not depend on
    # how the subdirectories happen to be named.
    return max(subdirs, key=os.path.getmtime)


actual_results_dir = _select_results_dir(RESULTS_DIR)
current_results_dir = actual_results_dir
print(f"Using results directory: {actual_results_dir}")

# Check if results directory exists
if not os.path.exists(actual_results_dir):
    print(f"❌ Results directory not found: {actual_results_dir}")
    print("Make sure inference has completed successfully.")
else:
    # Find and rank all models
    try:
        models = find_all_models(actual_results_dir)
        current_models = models

        if models:
            print(f"✅ Found {len(models)} models")
            print("\n📊 Model Rankings:")
            print("-" * 60)
            for model in models:
                print(f"Rank #{model['rank']}: {model['id']} - pLDDT: {model['mean_plddt']:.2f}")

            # Create dropdown widget with all models
            model_options = [(f"Rank #{model['rank']}: {model['id']} (pLDDT: {model['mean_plddt']:.2f})",
                             model['rank']) for model in models]

            dropdown = widgets.Dropdown(
                options=model_options,
                value=1,  # Start with best model (rank 1)
                description='Select Model:',
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='600px')
            )

            dropdown.observe(on_model_change, names='value')

            print("\n🎯 Interactive Model Viewer")
            print("=" * 50)
            print("Use the dropdown below to switch between models:")
            display(dropdown)

            # Display best model initially
            best_model = models[0]  # First in sorted list
            print(f"\nShowing best model: Rank #{best_model['rank']} (pLDDT: {best_model['mean_plddt']:.2f})")

            plot_files = create_plots_for_model(best_model, actual_results_dir)
            display_model_plots(best_model, plot_files, actual_results_dir)

        else:
            print("❌ No confidence data found.")
            print("Make sure the inference completed successfully and confidence JSON files exist.")

    except Exception as e:
        print(f"❌ Error processing models: {e}")
        import traceback
        traceback.print_exc()
