<a href="https://github.com/zshengyu14/CoDropleT/blob/main/CoDropleT_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Install dependencies
# Clone the CoDropleT repository (skipped when it already exists, so this cell
# can be re-run safely) and make its modules importable.
import os
import subprocess
import sys

if not os.path.isdir('CoDropleT'):
  subprocess.run(['git', 'clone', 'https://github.com/zshengyu14/CoDropleT.git'], check=True)
if './CoDropleT' not in sys.path:
  sys.path.append('./CoDropleT')


def _run_or_fail(cmd):
  """Run a shell command via os.system and raise RuntimeError on a non-zero exit status."""
  status = os.system(cmd)
  if status != 0:
    raise RuntimeError(f"Command failed with status {status}: {cmd}")


# COLABFOLD_READY is a marker file: once it exists, the slow installation below is skipped.
if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  _run_or_fail("pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/zshengyu14/Colabfold_distmats'")
  if os.environ.get('TPU_NAME', False) != False:
    # NOTE(review): this branch runs on TPU runtimes but installs a CUDA build of jax — confirm this is intended.
    _run_or_fail("pip uninstall -y jax jaxlib")
    _run_or_fail("pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  # Convenience symlinks to the installed packages; best-effort (ln just fails if a link already exists).
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
  # Only reached when every install above succeeded, so a failed install is retried on the next run.
  os.system("touch COLABFOLD_READY")

In [None]:
#@title Upload protein features

# @markdown ### 📂 Upload and Unzip Protein ZIP Files in Colab
from google.colab import files
import zipfile
import os


def _unzip_uploaded_protein(uploaded, target_dir):
    """Extract uploaded ZIP file(s) into ``target_dir`` and return the protein name.

    Args:
        uploaded (dict): Mapping returned by ``files.upload()``; its keys are the
            names of the files written to the current directory.
        target_dir (str): Directory to extract into, e.g. 'Protein1'.

    Returns:
        str: The part of the ZIP file name before the first '.'. If several ZIPs
        were uploaded, the last one wins.

    Raises:
        ValueError: If no ZIP file was uploaded. Later cells then never run with an
            undefined protein name or a stale one from a previous upload.
    """
    protein_name = None
    for fname in uploaded:
        if fname.lower().endswith('.zip'):
            # Everything before the first '.' (so 'x.result.zip' -> 'x'); the next
            # cell looks for <target_dir>/<name>/<name>.a3m after extraction.
            protein_name = fname.split('.')[0]
            with zipfile.ZipFile(fname, 'r') as z:
                z.extractall(target_dir)
            print(f"✅ Unzipped {fname} → ./{target_dir}/")
        else:
            print(f"❌ {fname} is not a valid ZIP file. Please upload a ZIP file generated by ColabFold.")
    if protein_name is None:
        raise ValueError(
            f"No ZIP file was uploaded for {target_dir}. "
            "Please upload a ZIP file generated by ColabFold."
        )
    return protein_name


# @markdown #### 1️⃣ Upload & Unzip **Protein1.zip**
# @markdown Please click “Browse” to select your first ZIP (Protein1).
print("Please upload the first ZIP (Protein1)…")
uploaded1 = files.upload()
protein1_name = _unzip_uploaded_protein(uploaded1, 'Protein1')

# @markdown #### 2️⃣ Upload & Unzip **Protein2.zip**
# @markdown Now select your second ZIP (Protein2).
print("Please upload the second ZIP (Protein2)…")
uploaded2 = files.upload()
protein2_name = _unzip_uploaded_protein(uploaded2, 'Protein2')

In [None]:
#@title Run CoDropleT

def extra_protein_infromation(protein_name,prefix):
    """
    Extract the query sequence, its length and a structure file for one protein.

    Expects the directory layout ``<prefix>/<protein_name>/`` containing
    ``<protein_name>.a3m`` and, optionally, one or more ``<protein_name>*.pdb``
    files.

    Args:
        protein_name (str): The name of the protein directory (also the file prefix).
        prefix (str): Directory containing the protein directory, e.g. './Protein1'.

    Returns:
        dict | None: ``{'seq': str, 'length': int, 'pdb': str | None}``, or None
        if ``<prefix>/<protein_name>`` does not exist. ``'pdb'`` is always present
        and is None when no matching PDB file is found.

    Raises:
        FileNotFoundError: If the directory exists but the .a3m file does not.
        ValueError: If the .a3m file contains no sequence record.
    """
    import os

    protein_dir = os.path.join(prefix, protein_name)
    if not os.path.exists(protein_dir):
        print(f"Directory {protein_dir} does not exist.")
        return None

    # The query is the first A3M record: skip '#' comment lines, then join the
    # sequence line(s) that follow the first '>' header, up to the next header.
    a3m_file = os.path.join(protein_dir, protein_name + '.a3m')
    seq_parts = []
    in_first_record = False
    with open(a3m_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            if line.startswith('>'):
                if in_first_record:
                    break
                in_first_record = True
                continue
            if in_first_record:
                seq_parts.append(line)
    seq = ''.join(seq_parts)
    if not seq:
        raise ValueError(f"No query sequence found in {a3m_file}.")

    # Sorted so the choice is deterministic; with '..._rank_001_...'-style names
    # this presumably selects the top-ranked model.
    pdb_files = sorted(
        f for f in os.listdir(protein_dir)
        if f.startswith(protein_name) and f.endswith('.pdb')
    )

    return {
        'seq': seq,
        'length': len(seq),
        'pdb': os.path.join(protein_dir, pdb_files[0]) if pdb_files else None,
    }

# Parse both uploaded outputs; stop with a clear error if either directory is missing.
input_protein1 = extra_protein_infromation(protein1_name, './Protein1')
if input_protein1 is None:
    raise ValueError(f"Failed to extract information for Protein1: {protein1_name}")

input_protein2 = extra_protein_infromation(protein2_name, './Protein2')
if input_protein2 is None:
    raise ValueError(f"Failed to extract information for Protein2: {protein2_name}")

# One-row pair table, written to input.csv for run_inference_colab.
input_pair = dict(
    raw_id=0,
    id_1=protein1_name,
    len_1=input_protein1['length'],
    dir_1=f'./Protein1/{protein1_name}',
    seq_1=input_protein1['seq'],
    id_2=protein2_name,
    len_2=input_protein2['length'],
    dir_2=f'./Protein2/{protein2_name}',
    seq_2=input_protein2['seq'],
)

from CoDropleT.run_model import run_inference_colab
import pandas as pd

input_csv = pd.DataFrame([input_pair])
input_csv.to_csv('input.csv', index=False)

run_inference_colab('input.csv')
print("Inference completed.")

# output.txt is read without a header row, as (raw_id, score) columns.
result_df = pd.read_csv('results/output.txt', header=None, names=['raw_id', 'score'])
print(f"\nCoDropleT scores between {protein1_name} and {protein2_name}:")
print(result_df.at[0, 'score'])

In [None]:
#@title Visualize profile
from CoDropleT.utils import update_pdb_b_factors
import pickle
import py3Dmol
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LinearSegmentedColormap
from matplotlib.cm import ScalarMappable
from IPython.display import display, Markdown

# Fail early if either structure is missing: the profile is visualised by writing
# it into copies of the PDB files. .get() also covers protein-info dicts that have
# no 'pdb' key at all.
if input_protein1.get('pdb') is None or input_protein2.get('pdb') is None:
    raise ValueError("PDB files for one or both proteins are missing. Please ensure the ZIP files contain valid PDB files to visualize the profiles.")

# Presumably written by run_inference_colab above; [0] presumably selects the single
# pair (raw_id 0). NOTE: pickle.load can execute code, so only load trusted files.
with open('results/profiles.pkl', 'rb') as fh:
    profile = pickle.load(fh)[0]

len_1 = input_protein1['length']
len_2 = input_protein2['length']
# The profile covers protein 1 followed by protein 2; a shorter profile would be
# silently truncated by the slices below.
if len(profile) < len_1 + len_2:
    raise ValueError(
        f"Profile has {len(profile)} entries, expected at least {len_1 + len_2} "
        f"({len_1} + {len_2} residues)."
    )
profile_protein1 = profile[:len_1]
profile_protein2 = profile[len_1:len_1+len_2]

update_pdb_b_factors(input_protein1['pdb'], "protein1.pdb", profile_protein1)
update_pdb_b_factors(input_protein2['pdb'], "protein2.pdb", profile_protein2)


# 3D display
def show_pdb(
    pdb_file,
    show_sidechains=False,
    show_mainchains=False,
    add_colorbar=False,
    vmin=None,
    vmax=None,
    bar_label="Profile",
    tick_format="{:.1f}",
):
    """Render a PDB file in 3D, coloured by B-factor, with an optional colour bar.

    The cartoon uses 3Dmol's 'roygb' gradient on the B-factor ('b') property. The
    optional matplotlib colour bar uses a red→orange→yellow→green→blue map.

    Args:
        pdb_file (str): Path to the PDB file to display.
        show_sidechains (bool): Also draw side chains as sticks (Gly as a CA sphere).
        show_mainchains (bool): Also draw backbone atoms (C, O, N, CA) as sticks.
        add_colorbar (bool): Draw a horizontal colour bar below the 3D view.
        vmin, vmax (float | None): Colour range. When both are given, the 3D
            gradient is pinned to this range so it matches the colour bar. For the
            colour bar, missing values default to 0 and 100.
        bar_label (str): Label shown under the colour bar.
        tick_format (str): ``str.format`` pattern for the min/mid/max tick labels.
    """
    # --- 3D view ---
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    with open(pdb_file, 'r') as fh:
        view.addModel(fh.read(), 'pdb')

    colorscheme = {'prop': 'b', 'gradient': 'roygb'}
    if vmin is not None and vmax is not None:
        # Tie the 3D gradient to the colour-bar range; float() turns numpy scalars
        # into plain Python floats before they are handed to py3Dmol.
        colorscheme.update(min=float(vmin), max=float(vmax))
    view.setStyle({'cartoon': {'colorscheme': colorscheme}})

    if show_sidechains:
        BB = ['C','O','N']
        view.addStyle({'and':[
            {'resn':["GLY","PRO"],'invert':True},
            {'atom':BB,'invert':True}
        ]}, {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                      {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                      {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    if show_mainchains:
        BB = ['C','O','N','CA']
        view.addStyle({'atom':BB},
                      {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.zoomTo()
    view.show()

    # --- horizontal colorbar with tick labels ---
    if add_colorbar:
        vmin = 0 if vmin is None else vmin
        vmax = 100 if vmax is None else vmax
        cmap = LinearSegmentedColormap.from_list('roygb',
                                                 ['red','orange','yellow','green','blue'])
        norm = Normalize(vmin=vmin, vmax=vmax)
        sm = ScalarMappable(norm=norm, cmap=cmap)
        sm.set_array([vmin, vmax])   # ← necessary to get ticks

        fig, ax = plt.subplots(figsize=(6, 0.4), dpi=100)
        cbar = fig.colorbar(sm, cax=ax, orientation='horizontal')
        cbar.set_label(bar_label, labelpad=4)
        # place ticks at min, mid, max
        mid = (vmin + vmax) / 2
        ticks = [vmin, mid, vmax]
        cbar.set_ticks(ticks)
        cbar.set_ticklabels([tick_format.format(t) for t in ticks])
        # ensure tick labels are visible
        cbar.ax.xaxis.set_tick_params(labelbottom=True)
        plt.show()

# ── Per-protein 3D views, coloured by the co-condensation profile ─────────────
for _name, _pdb_path, _prof in (
    (protein1_name, "protein1.pdb", profile_protein1),
    (protein2_name, "protein2.pdb", profile_protein2),
):
    display(Markdown(f"## {_name}"))
    show_pdb(
        _pdb_path,
        show_sidechains=False,
        show_mainchains=False,
        add_colorbar=True,
        vmin=_prof.min(),
        vmax=_prof.max(),
        bar_label=f"{_name} profile",
    )

# ── 2D profile plot with legend ───────────────────────────────────────────────
# Residues are numbered along the concatenated pair: protein 1 occupies 1..len_1
# and protein 2 continues from len_1 + 1, so the two traces sit side by side.
x1 = range(1, len_1 + 1)
x2 = range(len_1 + 1, len_1 + len_2 + 1)

fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
ax.plot(x1, profile_protein1, label=protein1_name)
ax.plot(x2, profile_protein2, label=protein2_name)
ax.set_xlabel('Residue Number')
ax.set_ylabel('Co-condensation Profile Score')
ax.set_title('Per‐Residue Co-condensation Profile')
ax.legend(frameon=False, loc='best')
fig.tight_layout()
plt.show()