In [54]:
import os
import sys
import gzip
import shutil
import tempfile
import warnings

import numpy as np
import pandas as pd
from tqdm import tqdm
import requests

import torch
import esm
from esm.inverse_folding.util import (
    load_structure,
    extract_coords_from_structure,
    get_encoder_output,
)

import training_utils.partitioning_utils as pat_utils


warnings.simplefilter(action="ignore", category=FutureWarning)

# dataframe_path = "/work3/s232958/data/PPint_DB/disordered_interfaces_no_cutoff_filtered_nonredundant80_3å_5.csv.gz"
dataframe_path = "/work3/s232958/data/PPint_DB/PPint_train.csv"
download_dir = "/work3/s232958/data/PPint_DB/pdb_cache"
name_column = "PDB_chain_name"
sequence_column = "sequence"

# Where to save ESM-IF embeddings (.npy per chain)
path_to_output_embeddings = "/work3/s232958/data/PPint_DB/esmif_embeddings_full"

# Root of your local PDB mirror (with subfolders a0, a1, etc.)
root_pdb_dir = "/novo/users/cpjb/rdd/PDB_mirror_pdb"

In [55]:
def download_pdb(pdb_id: str, out_dir: str) -> str:
    """
    Download PDB file (.pdb) from RCSB if not already cached.
    Returns path to the downloaded file.
    """
    pdb_id = pdb_id.lower()
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{pdb_id}.pdb.gz")

    if os.path.exists(out_path):
        return out_path

    url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    # print(f"Downloading {pdb_id} from RCSB...", file=sys.stderr)

    r = requests.get(url)
    if r.status_code != 200:
        raise FileNotFoundError(f"Failed to download {pdb_id} from RCSB")

    # Save gzipped version for compatibility with your loader
    with gzip.open(out_path, "wb") as f:
        f.write(r.content)

    return out_path

def index_pdb_files(root_dir: str):
    """
    Walk the PDB mirror and build:
        pdb_id (4-letter, lowercase) -> full path to .pdb.gz

    Adjust filename parsing if your mirror uses a different naming scheme.
    """
    pdb_index = {}

    for dirpath, dirnames, filenames in os.walk(root_dir):
        for fn in filenames:
            fn_lower = fn.lower()
            if not fn_lower.endswith(".pdb.gz"):
                continue

            # strip ".pdb.gz"
            stem = fn_lower[:-7]

            # Common naming patterns:
            # - "1abc.pdb.gz"           -> pdb_id = "1abc"
            # - "pdb1abc.ent.gz" etc.   -> adjust below if needed
            if stem.startswith("pdb") and len(stem) >= 7:
                # e.g., "pdb1abc" -> "1abc"
                pdb_id = stem[3:7]
            else:
                # assume first 4 chars are the PDB id
                pdb_id = stem[:4]

            pdb_index[pdb_id] = os.path.join(dirpath, fn)

    return pdb_index


def load_structure_from_gz(path_gz: str, chain_id: str):
    """
    Decompress a .pdb.gz to a temporary file, load with esm's load_structure,
    and return the Structure object.
    """
    if not os.path.isfile(path_gz):
        raise FileNotFoundError(f"PDB file not found: {path_gz}")

    with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp:
        with gzip.open(path_gz, "rb") as fin, open(tmp.name, "wb") as fout:
            shutil.copyfileobj(fin, fout)
        structure = load_structure(tmp.name, chain=chain_id)

    return structure


def calculate_esm_if_embeddings(model, alphabet, pdb_path: str, chain_id: str, device="cuda") -> np.ndarray:
    """
    Returns per-residue ESM-IF encoder embeddings for a single chain.
    Shape: [L, D]
    """
    structure = load_structure_from_gz(pdb_path, chain_id)
    coords, _seq = extract_coords_from_structure(structure)
    coords = torch.as_tensor(coords, dtype=torch.float32, device=device)
    
    device = next(model.parameters()).device
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, None)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(batch, device=device)
    with torch.no_grad():
        encoder_out = model.encoder.forward(coords, padding_mask, confidence, return_all_hiddens=False)
    # remove beginning and end (bos and eos tokens)
    return encoder_out['encoder_out'][0][1:-1, 0].cpu().numpy()

In [56]:
# Creating dirs
os.makedirs(path_to_output_embeddings, exist_ok=True)
os.makedirs(download_dir, exist_ok=True)

# Loading Df
print(f"Reading dataframe from: {dataframe_path}", file=sys.stderr)
sequence_df = pd.read_csv(dataframe_path)
sequence_df

Reading dataframe from: /work3/s232958/data/PPint_DB/PPint_train.csv


Unnamed: 0,interface_id,seq1,seq2,ID1,ID2,dimer,seq_target_len,seq_binder_len,target_binder_id
0,6NZA_0,MNTVRSEKDSMGAIDVPADKLWGAQTQRSLEHFRISTEKMPTSLIH...,TVRSEKDSMGAIDVPADKLWGAQTQRSLEHFRISTEKMPTSLIHAL...,6NZA_0_A,6NZA_0_B,True,461,459,6NZA_0_A_6NZA_0_B
1,9JKA_1,VAAGATLALLSFLTPLAFLLLPPLLWREELEPCGTACEGLFISVAF...,VAAGATLALLSFLTPLAFLLLPPLLWREELEPCGTACEGLFISVAF...,9JKA_1_B,9JKA_1_C,True,362,362,9JKA_1_B_9JKA_1_C
2,2YMZ_0,ARMFEMFNLDWKSGGTMKIKGHISEDAESFAINLGCKSSDLALHFN...,ARMFEMFNLDWKSGGTMKIKGHISEDAESFAINLGCKSSDLALHFN...,2YMZ_0_A,2YMZ_0_B,True,130,130,2YMZ_0_A_2YMZ_0_B
3,1EFP_0,AVLLLGEVTNGALNRDATAKAVAAVKALGDVTVLCAGASAKAAAEE...,MKVLVPVKRLIDYNVKARVKSDGSGVDLANVKMSMNPFDEIAVEEA...,1EFP_0_A,1EFP_0_B,False,307,246,1EFP_0_A_1EFP_0_B
4,6KW6_0,IDWPALRTQARDAMSRAYAPYSGYPVGAAALVDDGRTVTGCNVENA...,IDWPALRTQARDAMSRAYAPYSGYPVGAAALVDDGRTVTGCNVENA...,6KW6_0_A,6KW6_0_B,True,119,119,6KW6_0_A_6KW6_0_B
...,...,...,...,...,...,...,...,...,...
5891,3ON7_2,KLETIDYRAADSAKRFVESLRETGFGVLSNHPIDKELVERIYTEWQ...,KLETIDYRAADSAKRFVESLRETGFGVLSNHPIDKELVERIYTEWQ...,3ON7_2_C,3ON7_2_D,True,263,266,3ON7_2_C_3ON7_2_D
5892,4DPL_2,TLKAAILGATGLVGIEYVRMLSNHPYIKPAYLAGKGSVGKPYGEVV...,RTLKAAILGATGLVGIEYVRMLSNHPYIKPAYLAGKGSVGKPYGEV...,4DPL_2_C,4DPL_2_D,True,353,354,4DPL_2_C_4DPL_2_D
5893,4R5M_0,MRVGLVGWRGMVGSVLMQRMVEERDFDLIEPVFFSTSQIGVPAPNF...,MRVGLVGWRGMVGSVLMQRMVEERDFDLIEPVFFSTSQIGVPAPNF...,4R5M_0_A,4R5M_0_B,True,369,369,4R5M_0_A_4R5M_0_B
5894,6O42_0,EIVLTQSPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPR...,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLE...,6O42_0_L,6O42_0_H,False,214,220,6O42_0_L_6O42_0_H


In [58]:
cols = ["interface_id", "dimer", "target_binder_id"]

# Create long format for target
df_target = pd.DataFrame({
    "interface_id": sequence_df["interface_id"],
    "role": "target",
    "ID": sequence_df["ID1"],
    "sequence": sequence_df["seq1"],
    "length": sequence_df["seq_target_len"],
    "dimer": sequence_df["dimer"],
    "target_binder_id": sequence_df["target_binder_id"]
})

# Create long format for binder
df_binder = pd.DataFrame({
    "interface_id": sequence_df["interface_id"],
    "role": "binder",
    "ID": sequence_df["ID2"],
    "sequence": sequence_df["seq2"],
    "length": sequence_df["seq_binder_len"],
    "dimer": sequence_df["dimer"],
    "target_binder_id": sequence_df["target_binder_id"]
})

# Combine
sequence_df_LONG = pd.concat([df_target, df_binder], ignore_index=True)
sequence_df_LONG["PDB"] = [row["interface_id"].split("_")[0] for __, row in sequence_df_LONG.iterrows()]
sequence_df_LONG["chainname"] = [row["ID"].split("_")[-1] for __, row in sequence_df_LONG.iterrows()]
sequence_df_LONG

Unnamed: 0,interface_id,role,ID,sequence,length,dimer,target_binder_id,PDB,chainname
0,6NZA_0,target,6NZA_0_A,MNTVRSEKDSMGAIDVPADKLWGAQTQRSLEHFRISTEKMPTSLIH...,461,True,6NZA_0_A_6NZA_0_B,6NZA,A
1,9JKA_1,target,9JKA_1_B,VAAGATLALLSFLTPLAFLLLPPLLWREELEPCGTACEGLFISVAF...,362,True,9JKA_1_B_9JKA_1_C,9JKA,B
2,2YMZ_0,target,2YMZ_0_A,ARMFEMFNLDWKSGGTMKIKGHISEDAESFAINLGCKSSDLALHFN...,130,True,2YMZ_0_A_2YMZ_0_B,2YMZ,A
3,1EFP_0,target,1EFP_0_A,AVLLLGEVTNGALNRDATAKAVAAVKALGDVTVLCAGASAKAAAEE...,307,False,1EFP_0_A_1EFP_0_B,1EFP,A
4,6KW6_0,target,6KW6_0_A,IDWPALRTQARDAMSRAYAPYSGYPVGAAALVDDGRTVTGCNVENA...,119,True,6KW6_0_A_6KW6_0_B,6KW6,A
...,...,...,...,...,...,...,...,...,...
11787,3ON7_2,binder,3ON7_2_D,KLETIDYRAADSAKRFVESLRETGFGVLSNHPIDKELVERIYTEWQ...,266,True,3ON7_2_C_3ON7_2_D,3ON7,D
11788,4DPL_2,binder,4DPL_2_D,RTLKAAILGATGLVGIEYVRMLSNHPYIKPAYLAGKGSVGKPYGEV...,354,True,4DPL_2_C_4DPL_2_D,4DPL,D
11789,4R5M_0,binder,4R5M_0_B,MRVGLVGWRGMVGSVLMQRMVEERDFDLIEPVFFSTSQIGVPAPNF...,369,True,4R5M_0_A_4R5M_0_B,4R5M,B
11790,6O42_0,binder,6O42_0_H,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLE...,220,False,6O42_0_L_6O42_0_H,6O42,H


In [59]:
sequence_df_LONG[name_column] = (sequence_df_LONG["PDB"] + "_" + sequence_df_LONG["chainname"]).tolist()

# Basic cleaning
sequence_df_LONG = sequence_df_LONG[sequence_df_LONG[sequence_column].notna()]
sequence_df_LONG = sequence_df_LONG.drop_duplicates(subset=[name_column, sequence_column])
sequence_df_LONG["pdb_path"] = [f'{row["PDB"].lower()}.pdb.gz' for __, row in df_long.iterrows()]
sequence_df_LONG

Unnamed: 0,interface_id,role,ID,sequence,length,dimer,target_binder_id,PDB,chainname,PDB_chain_name,pdb_path
0,6NZA_0,target,6NZA_0_A,MNTVRSEKDSMGAIDVPADKLWGAQTQRSLEHFRISTEKMPTSLIH...,461,True,6NZA_0_A_6NZA_0_B,6NZA,A,6NZA_A,6nza.pdb.gz
1,9JKA_1,target,9JKA_1_B,VAAGATLALLSFLTPLAFLLLPPLLWREELEPCGTACEGLFISVAF...,362,True,9JKA_1_B_9JKA_1_C,9JKA,B,9JKA_B,9jka.pdb.gz
2,2YMZ_0,target,2YMZ_0_A,ARMFEMFNLDWKSGGTMKIKGHISEDAESFAINLGCKSSDLALHFN...,130,True,2YMZ_0_A_2YMZ_0_B,2YMZ,A,2YMZ_A,2ymz.pdb.gz
3,1EFP_0,target,1EFP_0_A,AVLLLGEVTNGALNRDATAKAVAAVKALGDVTVLCAGASAKAAAEE...,307,False,1EFP_0_A_1EFP_0_B,1EFP,A,1EFP_A,1efp.pdb.gz
4,6KW6_0,target,6KW6_0_A,IDWPALRTQARDAMSRAYAPYSGYPVGAAALVDDGRTVTGCNVENA...,119,True,6KW6_0_A_6KW6_0_B,6KW6,A,6KW6_A,6kw6.pdb.gz
...,...,...,...,...,...,...,...,...,...,...,...
11787,3ON7_2,binder,3ON7_2_D,KLETIDYRAADSAKRFVESLRETGFGVLSNHPIDKELVERIYTEWQ...,266,True,3ON7_2_C_3ON7_2_D,3ON7,D,3ON7_D,3on7.pdb.gz
11788,4DPL_2,binder,4DPL_2_D,RTLKAAILGATGLVGIEYVRMLSNHPYIKPAYLAGKGSVGKPYGEV...,354,True,4DPL_2_C_4DPL_2_D,4DPL,D,4DPL_D,4dpl.pdb.gz
11789,4R5M_0,binder,4R5M_0_B,MRVGLVGWRGMVGSVLMQRMVEERDFDLIEPVFFSTSQIGVPAPNF...,369,True,4R5M_0_A_4R5M_0_B,4R5M,B,4R5M_B,4r5m.pdb.gz
11790,6O42_0,binder,6O42_0_H,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLE...,220,False,6O42_0_L_6O42_0_H,6O42,H,6O42_H,6o42.pdb.gz


In [60]:
# Downloading PDBs
for __, row in tqdm(sequence_df_LONG.iterrows(), total=len(sequence_df_LONG)):
    pid = row.PDB
    download_pdb(pid, download_dir)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 11731/11731 [35:39<00:00,  5.48it/s]


In [None]:
# Drop rows where we don't have a PDB file
missing_mask = sequence_df["pdb_path"].isna()
if missing_mask.any():
    missing_ids = sequence_df.loc[missing_mask, "PDB"].unique()
    print("Warning: no PDB file found for these IDs:", file=sys.stderr)
    for pid in missing_ids:
        print(f"  {pid}", file=sys.stderr)
    sequence_df = sequence_df[~missing_mask]

In [None]:
# Load ESM-IF model
print("Loading ESM-IF1 model...", file=sys.stderr)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.to(device).eval()
print(f"Model loaded on device: {device}", file=sys.stderr)

In [None]:
# Set of already computed embeddings ----
already_calculated_files = {
    fname[:-4] for fname in os.listdir(path_to_output_embeddings) if fname.endswith(".npy")
}
print(f"Found {len(already_calculated_files)} existing .npy files", file=sys.stderr)

In [6]:
# ---- Main loop ----
for idx, row in tqdm(sequence_df.iterrows(), total=len(sequence_df)):
    name = row[name_column]     # e.g. "1ABC_A"
    if name in already_calculated_files:
        continue

    pdb_path = row["pdb_path"]
    chain_id = row["chainname"]  # e.g. "A"
    try:
        embeddings = calculate_esm_if_embeddings(model, alphabet, pdb_path, chain_id)
    except Exception as e:
        print(f"Failed for {name}: {e}", file=sys.stderr)
        continue

    out_path = os.path.join(path_to_output_embeddings, name + ".npy")
    np.save(out_path, embeddings)

print("Done.", file=sys.stderr)

Reading dataframe from: /work3/s232958/data/PPint_DB/disordered_interfaces_no_cutoff_filtered_nonredundant80_3å_5.csv.gz


Downloading .pdbs...


KeyboardInterrupt: 

In [45]:
files = os.listdir("/work3/s232958/data/PPint_DB/esmif_embeddings_full/")
for file in files:
    embeding = np.load(os.path.join("/work3/s232958/data/PPint_DB/esmif_embeddings_full/", f"{file}"))
    print(embeding.shape)

(130, 512)
(130, 512)
(459, 512)
(362, 512)
(317, 512)
(362, 512)
(461, 512)
(109, 512)
(172, 512)
(97, 512)


In [26]:
print("Loading ESM-IF1 model...", file=sys.stderr)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.to(device).eval()
print(f"Model loaded on device: {device}", file=sys.stderr)

names = [(row[name_column], row["pdb_path"], row["chainname"]) for idx, row in sequence_df[:10].iterrows()]
name, path_gz, chainname = names[0][0], names[0][1], names[0][2]
print(name, path_gz, chainname)

for i in range(len(names)):
    name, path_gz, chainname = names[i][0], names[i][1], names[i][2]
    if not os.path.isfile(path_gz):
        raise FileNotFoundError(f"PDB file not found: {path_gz}")
    
    with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp:
        with gzip.open(path_gz, "rb") as fin, open(tmp.name, "wb") as fout:
            shutil.copyfileobj(fin, fout)
        structure = load_structure(tmp.name, chain=chain_id)

    coords, _seq = extract_coords_from_structure(structure)
    coords = torch.as_tensor(coords, dtype=torch.float32, device=device)
    
    device = next(model.parameters()).device
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, None)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(batch, device=device)
    with torch.no_grad():
        encoder_out = model.encoder.forward(coords, padding_mask, confidence, return_all_hiddens=False)
    # remove beginning and end (bos and eos tokens)
    print(encoder_out['encoder_out'][0][1:-1, 0].shape)

torch.Size([459, 512])
torch.Size([459, 512])
torch.Size([362, 512])
torch.Size([362, 512])
torch.Size([109, 512])
torch.Size([109, 512])
torch.Size([130, 512])
torch.Size([130, 512])
torch.Size([172, 512])
torch.Size([172, 512])
