<a href="https://colab.research.google.com/github/zshengyu14/CoDropleT/blob/main/CoDropleT_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CoDropleT: Interactive Analysis Notebook

Welcome! This notebook provides a complete, interactive workflow for predicting the co-condensation propensity of protein pairs using the **CoDropleT** model.

---

### About the Method & Citation

This notebook is an implementation of the CoDropleT (co-condensation into droplet transformer) model, which leverages protein structures from AlphaFold2 to predict how likely two proteins are to form condensates together.

If you use this work, please cite the original publication:

> Zhang, S., Lim, C. M., Occhetta, M., & Vendruscolo, M. (2024). AlphaFold2-based prediction of the co-condensation propensity of proteins. *PNAS*, 121(34), e2315005121. https://doi.org/10.1073/pnas.2315005121

---

### How to Use This Notebook

This is a text cell written in **Markdown**. It provides information but doesn't run any code. To use the notebook, you must run the **code cells** (the ones with a gray background and a ▶️ button) in order from top to bottom.

#### **Workflow Steps:**

1.  **▶️ Cell 1: Install Dependencies**
    * Run this cell first to set up the required software.

2.  **▶️ Cell 2: Add Proteins Interactively**
    * **Prerequisite:** You must first generate protein representations using the separate **[AlphaFold2 Representations Notebook](https://github.com/zshengyu14/ColabFold_distmats/blob/main/AlphaFold2_representations.ipynb)**.
    * Use the interactive widgets to add your proteins by uploading `.zip` files or selecting them from Google Drive.
    * Click **"Finish and Display Summary"** when you are done.

3.  **▶️ Cell 3: Run CoDropleT on All Protein Pairs**
    * This cell automatically creates all pairs from your input and runs the CoDropleT analysis.

4.  **▶️ Cell 4: Visualize Profiles**
    * Use the interactive controls to select a pair, customize the 3D and 2D plots, and explore the results.

5.  **▶️ Cell 5: Download Results**
    * Run this final cell to download all your inputs and results as a single `.zip` file.
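
As a rough guide to what Cell 3 produces, the sketch below mirrors its pairing logic: every protein is paired with itself, and every two distinct proteins are paired once. The helper name `enumerate_pairs` and the protein names are illustrative only and are not part of the CoDropleT package.

```python
from itertools import combinations

def enumerate_pairs(names):
    """Return self pairs first, then every unordered pair of distinct names."""
    pairs = [(name, name) for name in names]
    pairs += list(combinations(names, 2))
    return pairs

# Hypothetical protein names, for illustration only.
example_pairs = enumerate_pairs(["proteinA", "proteinB", "proteinC"])
for raw_id, (first, second) in enumerate(example_pairs):
    print(raw_id, first, second)
```

For n proteins this gives n + n(n-1)/2 pairs (6 for n = 3), so the number of pairs scored in Cell 3 grows quadratically with the number of proteins you add.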

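Before uploading, it can help to check that an archive looks like what Cell 2 expects. The sketch below is an illustration under stated assumptions, not part of the notebook: it derives the protein name the way Cell 2 does (the file name without `.zip` and without a trailing `.result`) and reports whether the archive contains `.a3m`, `.npy`, and `.pdb` entries. The function name `inspect_archive` and the demo file names are hypothetical, and the layout inside a real archive may differ.

```python
import io
import os
import zipfile

def inspect_archive(zip_source, zip_name):
    """Summarize which expected file types a result archive contains."""
    protein_name = os.path.splitext(os.path.basename(zip_name))[0]
    if protein_name.endswith(".result"):
        protein_name = protein_name[: -len(".result")]
    with zipfile.ZipFile(zip_source) as archive:
        members = archive.namelist()
    return {
        "protein_name": protein_name,
        "has_a3m": any(m.endswith(".a3m") for m in members),
        "has_npy": any(m.endswith(".npy") for m in members),
        "has_pdb": any(m.endswith(".pdb") for m in members),
    }

# Build a tiny in-memory archive with made-up contents to exercise the check.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as demo:
    demo.writestr("demo/demo.a3m", "#5\t1\n>101\nMKTAY\n")
    demo.writestr("demo/demo_repr.npy", b"")
buffer.seek(0)
report = inspect_archive(buffer, "demo.result.zip")
print(report)
```

A missing `.pdb` entry is not fatal in this notebook: Cell 2 only prints a warning, and Cell 4 skips the 3D view for that pair.
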
In [None]:
#@title 1. Install Dependencies
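# Clone the CoDropleT repository (skipped if it already exists, so the cell can be re-run)
import sys
import os
if not os.path.isdir('CoDropleT'):
  !git clone https://github.com/zshengyu14/CoDropleT.git
# Make the cloned code importable
if './CoDropleT' not in sys.path:
  sys.path.append('./CoDropleT')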

if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  os.system("pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/zshengyu14/Colabfold_distmats'")
  if 'TPU_NAME' in os.environ:
    os.system("pip uninstall -y jax jaxlib")
    os.system("pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
  os.system("touch COLABFOLD_READY")

In [None]:
#@title 2. Add Proteins Interactively
#@markdown ### ❗️ Important Prerequisite
#@markdown Before using this cell, you must first generate representations for each of your proteins. These representations, which include the essential `.npy` and `.a3m` files, should be created using the modified ColabFold notebook available at this link:
#@markdown > 🔗 **[AlphaFold2 Representations Notebook](https://github.com/zshengyu14/ColabFold_distmats/blob/main/AlphaFold2_representations.ipynb)**
#@markdown
#@markdown Once generated, you can provide the data for each protein below, either by uploading the resulting `.zip` file or by selecting it from your Google Drive.
#@markdown ---
#@markdown ### ➕ Add Proteins
#@markdown For each protein, choose your preferred method and click **"Add Protein"**.
#@markdown When you're done, click **"Finish and Display Summary"**.
#@markdown ---

import os
import zipfile
import shutil
import pandas as pd
import ipywidgets as widgets
from google.colab import drive, files, output
from IPython.display import display, clear_output

# --- Function Definitions ---

def extract_protein_information(protein_name, prefix):
    """Extracts protein information (sequence, length, PDB path) from the specified directory."""
    if not os.path.exists(prefix):
        print(f"❌ Directory {prefix} does not exist.")
        return None
    protein_info_out = {}
    a3m_file = os.path.join(prefix, protein_name + '.a3m')
    pdb_files = [f for f in os.listdir(prefix) if f.startswith(protein_name) and f.endswith('.pdb')]
    try:
        with open(a3m_file, 'r') as f:
            lines = f.readlines()
            # The query sequence is expected on the third line of the .a3m file
            seq = lines[2].strip() if len(lines) > 2 else ""
        if not seq: raise ValueError(f"Sequence not found or is empty in {a3m_file}")
        protein_info_out['seq'], protein_info_out['length'] = seq, len(seq)
    except Exception as e:
        print(f"❌ Error reading sequence from {a3m_file}: {e}"); return None
    protein_info_out['pdb'] = os.path.join(prefix, pdb_files[0]) if pdb_files else None
    if not protein_info_out['pdb']: print(f"⚠️ Warning: No PDB file found for {protein_name} in {prefix}")
    return protein_info_out

def process_zip_file(zip_path):
    """Processes a single ZIP file from a given path."""
    fname = os.path.basename(zip_path)
    protein_name = os.path.splitext(fname)[0]
    if protein_name.endswith('.result'):
      protein_name = protein_name[:-7]
    if any(p['name'] == protein_name for p in protein_info_list):
        print(f"⚠️ Protein '{protein_name}' has already been added. Skipping."); return
    dest_dir = os.path.join(protein_base_dir, protein_name)
    os.makedirs(dest_dir, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_path, 'r') as z: z.extractall(protein_base_dir)
        print(f"✅ Successfully processed '{protein_name}'.")
        protein_info_list.append({'name': protein_name, 'base_dir': dest_dir})
    except Exception as e:
        print(f"❌ An unexpected error occurred while processing {fname}: {e}")

# --- Widget and State Initialization ---
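# Reset the protein list on every run: the directory below is wiped, so entries
# kept from an earlier run would point at files that no longer exist.
protein_info_list = []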
protein_base_dir = 'proteins'
if os.path.exists(protein_base_dir): shutil.rmtree(protein_base_dir)
os.makedirs(protein_base_dir, exist_ok=True)

method_dropdown = widgets.Dropdown(options=['Upload ZIP File', 'Browse Google Drive'], description='Method:')
upload_widget = widgets.FileUpload(accept='.zip', multiple=False, description='Upload ZIP:')

# Create Output widgets to act as containers for dynamic content
drive_status_out = widgets.Output()
drive_selector_out = widgets.Output()
# Group the Output widgets in the VBox layout container
gdrive_selector_box = widgets.VBox([drive_status_out, drive_selector_out])

add_button = widgets.Button(description="Add Protein", button_style='success')
finish_button = widgets.Button(description="Finish and Display Summary", button_style='info')
status_output = widgets.Output()
summary_output = widgets.Output()
input_method_box = widgets.VBox([upload_widget])

# --- Widget Event Handlers ---
def on_method_change(change):
    """Handles the logic when the user switches input methods."""
    method = change['new']
    if method == 'Browse Google Drive':
        input_method_box.children = [gdrive_selector_box]
        # Use the correct Output widget to show status
        with drive_status_out:
            clear_output(wait=True)
            print("🚀 Mounting Google Drive...")
            try:
                drive.mount('/content/drive', force_remount=True)
                clear_output(wait=True)
                print("📂 Scanning for .zip files in your Drive... (this may take a moment)")
                zip_files = []
                # Loop names avoid reusing 'files', which is also the google.colab module imported above
                for root, _dirs, filenames in os.walk('/content/drive/MyDrive'):
                    for filename in filenames:
                        if filename.lower().endswith('.zip'):
                            full_path = os.path.join(root, filename)
                            display_name = os.path.relpath(full_path, '/content/drive/MyDrive')
                            zip_files.append((display_name, full_path))

                if not zip_files:
                    print("❌ No .zip files found in your Google Drive.")
                    return
                zip_files.sort()

                global gdrive_file_selector
                gdrive_file_selector = widgets.Dropdown(options=zip_files, description='Select File:', style={'description_width': 'initial'})

                # Use the other Output widget to show the file selector
                with drive_selector_out:
                    clear_output(wait=True)
                    display(gdrive_file_selector)
                clear_output() # Clear the "Scanning..." message
            except Exception as e:
                clear_output(wait=True)
                print(f"❌ Failed to mount or scan Google Drive: {e}")

    else: # Upload ZIP File
        input_method_box.children = [upload_widget]

def on_add_button_clicked(b):
    """Processes a single protein based on the selected method."""
    with status_output:
        clear_output(wait=True)
        method = method_dropdown.value
        if method == 'Browse Google Drive':
            if 'gdrive_file_selector' in globals() and gdrive_file_selector.value:
                process_zip_file(gdrive_file_selector.value)
            else:
                print("❌ No file selected from Google Drive. Please choose a file from the dropdown.")
        elif method == 'Upload ZIP File':
            if not upload_widget.value:
                print("❌ Please select a ZIP file to upload."); return
            uploaded_filename = list(upload_widget.value.keys())[0]
            content = upload_widget.value[uploaded_filename]['content']
            temp_zip_path = os.path.join(protein_base_dir, uploaded_filename)
            with open(temp_zip_path, 'wb') as f: f.write(content)
            process_zip_file(temp_zip_path)
            os.remove(temp_zip_path)

        upload_widget.value.clear()
        upload_widget._counter = 0
        print("\n--- Current Proteins ---")
        if protein_info_list:
            for p in protein_info_list: print(f"- {p['name']}")
        else:
            print("None yet.")

def on_finish_button_clicked(b):
    """Finalizes the process and displays the summary table."""
    with summary_output:
        clear_output()
        global full_protein_details
        full_protein_details = []
        if not protein_info_list:
            print("No valid proteins were added."); return
        print("✅ Finished adding proteins! Extracting details...")
        for info in protein_info_list:
            details = extract_protein_information(info['name'], info['base_dir'])
            if details:
                full_protein_details.append({'id': info['name'], 'dir': info['base_dir'], **details})
        if full_protein_details:
            print("📝 Summary of Loaded Proteins:")
            # Append an ellipsis only when the sequence is actually truncated
            display_data = [{'Protein Name': p['id'], 'Seq. Length': p['length'], 'Sequence (start)': p['seq'][:30] + ('...' if p['length'] > 30 else '')} for p in full_protein_details]
            protein_df = pd.DataFrame(display_data)
            display(protein_df)
        else:
            print("\n⚠️ No valid proteins could be processed.")
    ui.children = [finish_button, summary_output]

# --- Display UI ---
method_dropdown.observe(on_method_change, names='value')
add_button.on_click(on_add_button_clicked)
finish_button.on_click(on_finish_button_clicked)

input_box = widgets.VBox([method_dropdown, input_method_box, add_button])
ui = widgets.VBox([input_box, finish_button, status_output, summary_output])
display(ui)
# Trigger the initial UI setup
on_method_change({'new': method_dropdown.value})

In [None]:
#@title 3. Run CoDropleT on All Protein Pairs
import pandas as pd
from itertools import combinations
from CoDropleT.run_model import run_inference_colab

# This cell uses the 'full_protein_details' variable, which Step 2 creates when you click "Finish and Display Summary".
if 'full_protein_details' not in globals() or len(full_protein_details) < 1:
    raise ValueError("Please go back to Step 2, add at least one valid protein, and click 'Finish and Display Summary'.")

# Generate all unique pairs from the detailed list, including self pairs
protein_pairs = list((p, p) for p in full_protein_details)
if len(protein_pairs) > 1:
    protein_pairs += list(combinations(full_protein_details, 2))
input_pairs_list = []
for i, (p1, p2) in enumerate(protein_pairs):
    input_pairs_list.append({
        'raw_id': i,
        'id_1': p1['id'],
        'len_1': p1['length'],
        'dir_1': p1['dir'],
        'seq_1': p1['seq'],
        'id_2': p2['id'],
        'len_2': p2['length'],
        'dir_2': p2['dir'],
        'seq_2': p2['seq'],
    })

# Create and save the input CSV
input_csv = pd.DataFrame(input_pairs_list)
input_csv.to_csv('input.csv', index=False)

print(f"Generated input.csv with {len(input_pairs_list)} protein pairs.")
print(input_csv[['raw_id', 'id_1', 'id_2']])

# Run inference on all pairs
print("\nRunning CoDropleT inference...")
run_inference_colab('input.csv')
print("\nInference completed.")

# Display all scores
result_df = pd.read_csv('results/output.txt', header=None, names=['raw_id', 'score'])
pair_info_df = input_csv[['raw_id', 'id_1', 'id_2']].copy()
result_df = pd.merge(result_df, pair_info_df, on='raw_id')
print("\nCoDropleT scores for all pairs:")
print(result_df[['id_1', 'id_2', 'score']])

In [None]:
#@title 4. Visualize Profiles
#@markdown **Note:** The profile visualized here is an approximation of how each residue contributes to the co-condensation score.
#@markdown 1. Select a protein pair and visualization options.
#@markdown 2. Adjust the smoothing window for the 2D plot (1 = no smoothing).
#@markdown 3. Click the "Plot" button to generate the visualization.

import os
import py3Dmol
import pickle
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from CoDropleT.utils import update_pdb_b_factors
from IPython.display import display, Markdown, clear_output
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize, LinearSegmentedColormap

# --- Main Functions and Widget Setup ---

# Load all generated profiles
try:
    with open('results/profiles.pkl', 'rb') as f:
        all_profiles = pickle.load(f)
except FileNotFoundError:
    all_profiles = None
    print("❌ results/profiles.pkl not found. Cannot proceed with visualization.")

# Function to create and return the 3D view object
def create_3d_view(pdb_file, show_sidechains, show_mainchains, vmin, vmax):
    with open(pdb_file, 'r') as f:
        pdb_data = f.read()

    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width='auto', height=400)
    view.addModel(pdb_data, 'pdb')
    view.setStyle({'cartoon': {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': float(vmin), 'max': float(vmax)}}})

    if show_sidechains:
        BB = ['C', 'O', 'N']
        view.addStyle({'and': [{'resn': ["GLY", "PRO"], 'invert': True}, {'atom': BB, 'invert': True}]}, {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
        view.addStyle({'and': [{'resn': "GLY"}, {'atom': 'CA'}]}, {'sphere': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
        view.addStyle({'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]}, {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
    if show_mainchains:
        BB = ['C', 'O', 'N', 'CA']
        view.addStyle({'atom': BB}, {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})

    view.zoomTo()
    return view

# Function to create and return a Matplotlib figure for the colorbar
def create_colorbar_fig(vmin, vmax, bar_label):
    cmap = LinearSegmentedColormap.from_list('roygb', ['red', 'orange', 'yellow', 'green', 'blue'])
    norm = Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(6, 0.4), dpi=100)
    cbar = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation='horizontal')
    cbar.set_label(bar_label, labelpad=4, fontsize=12)
    mid = (vmin + vmax) / 2
    cbar.set_ticks([vmin, mid, vmax])
    cbar.set_ticklabels([f"{vmin:.2f}", f"{mid:.2f}", f"{vmax:.2f}"])
    fig.patch.set_alpha(0.0) # Make the figure background transparent
    ax.patch.set_alpha(0.0)
    return fig

# --- Widget Definitions ---
if all_profiles is not None:
    pair_options = {f"{row.id_1} vs. {row.id_2}": row.raw_id for index, row in input_csv.iterrows()}
    pair_dropdown = widgets.Dropdown(options=pair_options, description='Select Pair:', style={'description_width': 'initial'})
    sidechain_checkbox = widgets.Checkbox(value=False, description='Show Sidechains')
    mainchain_checkbox = widgets.Checkbox(value=False, description='Show Mainchains')

    # Moving-average window size for the 2D plot (1 = no smoothing)
    smoothing_slider = widgets.IntSlider(
        value=1, min=1, max=51, step=2, # Step of 2 encourages odd window sizes
        description='2D Plot Smoothing:', style={'description_width': 'initial'},
        continuous_update=False, readout_format='d'
    )

    plot_button = widgets.Button(description="Plot", button_style='success', tooltip='Click to generate the visualization', icon='check')
    out = widgets.Output()

# --- Button Click Handler ---
def generate_visualization(b):
    with out:
        clear_output(wait=True)
        # --- Get data for the selected pair ---
        pair_id = pair_dropdown.value
        show_sc = sidechain_checkbox.value
        show_mc = mainchain_checkbox.value
        window_size = smoothing_slider.value

        pair_info = input_csv.loc[input_csv['raw_id'] == pair_id].iloc[0]
        p1_name, p2_name = pair_info['id_1'], pair_info['id_2']
        len_1, len_2 = pair_info['len_1'], pair_info['len_2']
        p1_details = next((p for p in full_protein_details if p['id'] == p1_name), None)
        p2_details = next((p for p in full_protein_details if p['id'] == p2_name), None)
        profile = all_profiles[pair_id]
        profile1, profile2 = profile[:len_1], profile[len_1:len_1 + len_2]

        final_widgets_list = []

        # --- 3D Visualization ---
        if not (p1_details and p1_details['pdb'] and p2_details and p2_details['pdb']):
            # Display warning if PDB is missing
            warning_out = widgets.Output()
            with warning_out:
                display(Markdown("⚠️ **Note:** PDB file(s) missing for this pair, skipping 3D visualization."))
            final_widgets_list.append(warning_out)
        else:
            # Update PDBs with original (non-smoothed) data
            update_pdb_b_factors(p1_details['pdb'], "protein1_viz.pdb", profile1)
            update_pdb_b_factors(p2_details['pdb'], "protein2_viz.pdb", profile2)

            # --- Process Protein 1 ---
            title1 = Markdown(f"### 3D Profile for {p1_name}")
            view1 = create_3d_view("protein1_viz.pdb", show_sc, show_mc, profile.min(), profile.max())
            colorbar_fig1 = create_colorbar_fig(profile.min(), profile.max(), f"{p1_name} Profile")
            title1_out, view1_out, colorbar1_out = widgets.Output(), widgets.Output(), widgets.Output()
            with title1_out: display(title1)
            with view1_out: display(view1)
            with colorbar1_out: display(colorbar_fig1)
            final_widgets_list.extend([title1_out, view1_out, colorbar1_out])

            # --- Process Protein 2 ---
            title2 = Markdown(f"### 3D Profile for {p2_name}")
            view2 = create_3d_view("protein2_viz.pdb", show_sc, show_mc, profile.min(), profile.max())
            colorbar_fig2 = create_colorbar_fig(profile.min(), profile.max(), f"{p2_name} Profile")
            title2_out, view2_out, colorbar2_out = widgets.Output(), widgets.Output(), widgets.Output()
            with title2_out: display(title2)
            with view2_out: display(view2)
            with colorbar2_out: display(colorbar_fig2)
            final_widgets_list.extend([title2_out, view2_out, colorbar2_out])

            plt.close('all')

        # --- 2D Profile Plot ---
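        # Apply a moving average if window > 1. The window is capped at the profile length
        # because np.convolve(mode='same') returns max(len(a), len(v)) points, which would
        # not match the x-axis for proteins shorter than the window.
        w1, w2 = min(window_size, len(profile1)), min(window_size, len(profile2))
        profile1_display = np.convolve(profile1, np.ones(w1)/w1, mode='same') if w1 > 1 else profile1
        profile2_display = np.convolve(profile2, np.ones(w2)/w2, mode='same') if w2 > 1 else profile2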

        plot_title = f"### 2D Per-Residue Profile (Smoothing Window: {window_size})"
        title_2d = Markdown(plot_title)
        fig_2d, ax_2d = plt.subplots(figsize=(12, 6), dpi=100)

        # Define x-axis ranges to be consecutive
        x1 = range(1, len_1 + 1)
        x2 = range(len_1 + 1, len_1 + len_2 + 1)

        # Plot original data as faint line if smoothing is active
        if window_size > 1:
            ax_2d.plot(x1, profile1, color='lightblue', alpha=0.6)
            ax_2d.plot(x2, profile2, color='peachpuff', alpha=0.6)

        # Plot main (or smoothed) data
        ax_2d.plot(x1, profile1_display, label=f'{p1_name}', color='royalblue')
        ax_2d.plot(x2, profile2_display, label=f'{p2_name}', color='darkorange')

        ax_2d.set_xlabel('Residue Number (Concatenated)', fontsize=12)
        ax_2d.set_ylabel('Co-condensation Profile Score', fontsize=12)
        ax_2d.set_title('Per-Residue Co-condensation Profile', fontsize=14)
        ax_2d.legend(frameon=False, loc='best', fontsize=10)
        ax_2d.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        title_2d_out, plot_2d_out = widgets.Output(), widgets.Output()
        with title_2d_out: display(title_2d)
        with plot_2d_out: display(fig_2d)
        final_widgets_list.extend([title_2d_out, plot_2d_out])
        plt.close(fig_2d)

        display(widgets.VBox(final_widgets_list))

# --- Display UI ---
if all_profiles is not None:
    plot_button.on_click(generate_visualization)

    controls = widgets.VBox([
        pair_dropdown,
        sidechain_checkbox,
        mainchain_checkbox,
        smoothing_slider,
        plot_button
    ])

    controls.layout.width = '400px'
    controls.layout.margin = '0 20px 0 0'

    app_layout = widgets.HBox([controls, out])

    display(app_layout)
    generate_visualization(None)

In [None]:
#@title 5. Download Results
#@markdown Run this cell to package the `input.csv` file and the `results` directory into a single ZIP file for download.

import os
import zipfile
from google.colab import files

# Define the name of the output zip file
zip_filename = 'CoDropleT_results.zip'
files_to_zip = ['input.csv']
dirs_to_zip = ['results']

try:
    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Add individual files
        for file in files_to_zip:
            if os.path.exists(file):
                zipf.write(file)
                print(f"✅ Added '{file}' to the archive.")
            else:
                print(f"⚠️ Warning: '{file}' not found, skipping.")

        # Add contents of directories
        for dir_path in dirs_to_zip:
            if os.path.isdir(dir_path):
                for root, dirs, files_in_dir in os.walk(dir_path):
                    for file in files_in_dir:
                        file_path = os.path.join(root, file)
                        zipf.write(file_path)
                print(f"✅ Added the '{dir_path}' directory to the archive.")
            else:
                print(f"⚠️ Warning: Directory '{dir_path}' not found, skipping.")

    print(f"\n📦 Successfully created '{zip_filename}'.")
    print("🚀 Starting download...")
    files.download(zip_filename)

except Exception as e:
    print(f"❌ An error occurred: {e}")