In [1]:
# %load_ext autotime
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell executes, so local project code changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [20]:
import subprocess
import argparse
import random
import subprocess
import os
import time
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import esm
import biotite.structure.io as bsio
# from biopandas.pdb import PandasPdb
from pathlib import Path
from datetime import datetime

import Bio
import Bio.PDB as bp

def calculate_tm_score(
    pred_path, pdb_path, chain_id=None, use_tmalign=False, verbose=False,
    tmscore_path="/scratch/project/open-32-14/pimenol1/ProteinTTT/ProteinTTT/TMalign",
    tmalign_path="/scratch/project/open-32-14/pimenol1/ProteinTTT/ProteinTTT/TMalign.cpp"
):
    """Run an external TM-score executable and return the score it prints.

    Args:
        pred_path: Path to the predicted structure file.
        pdb_path: Path to the reference structure file.
        chain_id: Must be ``None``; chain selection is not implemented.
        use_tmalign: If true, run ``tmalign_path`` with the arguments
            ``(pdb_path, pred_path)``; otherwise run ``tmscore_path`` with
            ``(pred_path, pdb_path)``.
        verbose: If true, print the command, its stdout and its stderr.
        tmscore_path: Executable used when ``use_tmalign`` is false.
        tmalign_path: Executable used when ``use_tmalign`` is true.

    Returns:
        float: The number following ``=`` on the first stdout line that starts
        with ``"TM-score"`` and carries a parseable value.

    Raises:
        NotImplementedError: If ``chain_id`` is not ``None``.
        ValueError: If no such line exists in stdout (the command, stdout and
            stderr are printed first to help debugging).
    """
    if chain_id is not None:
        raise NotImplementedError("Chain ID is not implemented for TM-score calculation.")

    # NOTE(review): the default ``tmscore_path`` points at a file named
    # ``TMalign`` and the default ``tmalign_path`` ends in ``.cpp`` (looks like
    # a source file, not an executable) -- confirm both defaults.
    if use_tmalign:
        command = [tmalign_path, pdb_path, pred_path]
    else:
        command = [tmscore_path, pred_path, pdb_path]

    # Run the tool and capture its text output.
    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True
    )

    def print_cmd():
        print("TMscore command:")
        print(result.args)
        print("TMscore output:")
        print(result.stdout)
        print("TMscore error:")
        print(result.stderr)
    if verbose:
        print_cmd()

    # Extract the TM-score. Lines that merely start with "TM-score" but have no
    # "=<number>" part are skipped instead of aborting the parse with an
    # IndexError/ValueError.
    for line in result.stdout.splitlines():
        if not line.startswith("TM-score"):
            continue
        _, separator, remainder = line.partition("=")
        tokens = remainder.split()
        if not separator or not tokens:
            continue
        try:
            return float(tokens[0])
        except ValueError:
            continue

    print_cmd()
    raise ValueError("TM-score not found in the output")

In [3]:
# !pip install autotime

In [4]:
import torch
# Report whether PyTorch can see a CUDA device in this kernel.
print(f"CUDA available: {torch.cuda.is_available()}")

CUDA available: True


In [5]:
from pathlib import Path
import os
import warnings
import pandas as pd
import tqdm
from Bio.PDB import PDBParser

import esm
import biotite.structure.io as bsio

from proteinttt.models.esmfold import ESMFoldTTT, DEFAULT_ESMFOLD_TTT_CFG

# Widen pandas' display limits so the wide SAbDab summary table is readable.
pd.set_option("display.max_columns", 500)
pd.set_option("display.max_rows", 10000)
pd.set_option("max_colwidth", 100)

# All inputs and outputs of this notebook live under one data directory.
base_path = Path("/scratch/project/open-32-14/pimenol1/ProteinTTT/ProteinTTT/data/all_structures")
SABDAB_SUMMARY_PATH = base_path / "sabdab_summary_with_sequences.tsv"
OUTPUT_PATH = base_path / "predicted_structures"
PDB_FILES_PATH = base_path / "raw"

# Create the prediction directory (and any missing parents) up front.
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

print(f"Summary file: {SABDAB_SUMMARY_PATH}")
print(f"PDB files path: {PDB_FILES_PATH}")
print(f"Output will be saved to: {OUTPUT_PATH}")

Summary file: /scratch/project/open-32-14/pimenol1/ProteinTTT/ProteinTTT/data/all_structures/sabdab_summary_with_sequences.tsv
PDB files path: /scratch/project/open-32-14/pimenol1/ProteinTTT/ProteinTTT/data/all_structures/raw
Output will be saved to: /scratch/project/open-32-14/pimenol1/ProteinTTT/ProteinTTT/data/all_structures/predicted_structures


In [6]:
# Load the tab-separated SAbDab summary into memory.
df = pd.read_csv(SABDAB_SUMMARY_PATH, sep="\t", low_memory=False)
print(f"Entries in summary: {len(df):,}")
# Preview the first rows as the cell output.
df.head(5)

Entries in summary: 3,232


Unnamed: 0,pdb,Hchain,Lchain,model,antigen_chain,antigen_type,antigen_het_name,antigen_name,short_header,date,compound,organism,heavy_species,light_species,antigen_species,authors,resolution,method,r_free,r_factor,scfv,engineered,heavy_subclass,light_subclass,light_ctype,affinity,delta_g,affinity_method,temperature,pmid,seqH,seqL,full_sequence
0,9fys,B,C,0,X,protein,,alpha-bungarotoxin,TOXIN,07/16/25,D11 mAbs bound to alpha-Bungarotoxin,Bungarus multicinctus; Homo sapiens,homo sapiens,homo sapiens,bungarus multicinctus,"Wade, J., Bohn, M.F., Laustsen, A.H., Morth, J.P.",1.32,X-RAY DIFFRACTION,0.192,0.167,False,True,IGHV1,IGLV6,Lambda,,,,,,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELRSLRSDDTAVYYC...,ASNFMLTQPRSVSESPGKTVTISCTRSSGSIGSDYVHWYQQRPGSSPTTVIYEDNQRPSGVPDRFSGSIDSSSNSASLTISGLKTEDEADYYCQSY...,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELRSLRSDDTAVYYC...
1,9fys,I,M,0,Z,protein,,alpha-bungarotoxin,TOXIN,07/16/25,D11 mAbs bound to alpha-Bungarotoxin,Bungarus multicinctus; Homo sapiens,homo sapiens,homo sapiens,bungarus multicinctus,"Wade, J., Bohn, M.F., Laustsen, A.H., Morth, J.P.",1.32,X-RAY DIFFRACTION,0.192,0.167,False,True,IGHV1,IGLV6,Lambda,,,,,,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELRSLRSDDTAVYYC...,ASNFMLTQPRSVSESPGKTVTISCTRSSGSIGSDYVHWYQQRPGSSPTTVIYEDNQRPSGVPDRFSGSIDSSSNSASLTISGLKTEDEADYYCQSY...,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELRSLRSDDTAVYYC...
2,9fyt,I,M,0,B,protein,,alpha-cobratoxin,TOXIN,07/16/25,mAbs in complex with cobratoxin at pH 4.5,Homo sapiens; Naja kaouthia,homo sapiens,homo sapiens,naja kaouthia,"Wade, J., Bohn, M.F., Laustsen, A.H., Morth, J.P.",1.55,X-RAY DIFFRACTION,0.237,0.186,False,True,IGHV1,IGLV3,Lambda,,,,,,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELRSLRSDDTAVYYC...,SSYELTQPPSVSVAPGRTATITCEGDNIGQQIVHWYQQKPGQAPVAVISSDSDRPSGIPERFSGSNSGNTATLTISRVEAGDEADYYCQVWDSGSD...,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELRSLRSDDTAVYYC...
3,9fyt,C,D,0,A,protein,,alpha-cobratoxin,TOXIN,07/16/25,mAbs in complex with cobratoxin at pH 4.5,Homo sapiens; Naja kaouthia,homo sapiens,homo sapiens,naja kaouthia,"Wade, J., Bohn, M.F., Laustsen, A.H., Morth, J.P.",1.55,X-RAY DIFFRACTION,0.237,0.186,False,True,IGHV1,IGLV3,Lambda,,,,,,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELRSLRSDDTAVYYC...,SSYELTQPPSVSVAPGRTATITCEGDNIGQQIVHWYQQKPGQAPVAVISSDSDRPSGIPERFSGSNSGNTATLTISRVEAGDEADYYCQVWDSGSD...,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELRSLRSDDTAVYYC...
4,9h39,C,,0,A,protein,,lysm type receptor kinase,PLANT PROTEIN,07/16/25,Crystal structure of Lotus japonicus CHIP13 extracellular domain in complex with a nanobody,Lama glama; Lotus japonicus,lama glama,,lotus japonicus,"Gysel, K., Andersen, K.R.",1.66,X-RAY DIFFRACTION,0.225,0.182,False,True,IGHV1,,,,,,,,QLQLVESGGGLVQAGGSLRLSCATSGTTFRLNTMGWYRQAPGKQRELVATISRDFKTNYADSVKGRFTISRDNAKHTVDLQMNSLTPEDTAVYYCL...,,QLQLVESGGGLVQAGGSLRLSCATSGTTFRLNTMGWYRQAPGKQRELVATISRDFKTNYADSVKGRFTISRDNAKHTVDLQMNSLTPEDTAVYYCL...


In [7]:
def parse_antibody_data(df: pd.DataFrame, pdb_dir: Path):
    """Fill ``seqH``, ``seqL`` and ``full_sequence`` from the PDB files in ``pdb_dir``.

    For every row the file ``<pdb_dir>/<pdb lower-cased>.pdb`` is parsed and the
    one-letter sequences of the chains named in ``Hchain`` and ``Lchain`` are
    extracted. Only residues with a blank hetero flag (``res.id[0] == " "``)
    are used; residue names outside the 20-entry lookup table become ``"X"``.

    Sequences are taken from the first model of each structure only, so a file
    with several models does not yield the same chain repeated once per model.
    Each PDB file is parsed once, even when several rows refer to it.

    Args:
        df: Summary table with at least ``pdb``, ``Hchain`` and ``Lchain``
            columns. It is modified in place.
        pdb_dir: Directory containing the ``<pdb>.pdb`` files.

    Returns:
        The same ``df`` object with ``seqH``, ``seqL`` and ``full_sequence``
        (``"<seqH>:<seqL>"``) set for every row. A chain id that is missing
        from the structure gives an empty string. If parsing fails, both
        sequence columns hold the marker string ``"ERROR: <message>"``.
    """
    parser = PDBParser(QUIET=True)
    three_to_one = {
        "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F", "GLY": "G", "HIS": "H",
        "ILE": "I", "LYS": "K", "LEU": "L", "MET": "M", "ASN": "N", "PRO": "P", "GLN": "Q", "ARG": "R",
        "SER": "S", "THR": "T", "VAL": "V", "TRP": "W", "TYR": "Y"
    }

    def extract_chain_sequences(structure):
        """Return ``{chain id: one-letter sequence}`` for the first model."""
        first_model = next(iter(structure.get_models()), None)
        if first_model is None:
            return {}
        return {
            chain.id: "".join(
                three_to_one.get(res.resname, "X")
                for res in chain.get_residues()
                if res.id[0] == " "
            )
            for chain in first_model.get_chains()
        }

    # pdb code -> dict of chain sequences, or the error marker string when the
    # file could not be parsed. Only sequences are kept, not whole structures.
    parsed_by_code = {}

    for i, row in df.iterrows():
        pdb_code = row["pdb"].lower()

        if pdb_code not in parsed_by_code:
            try:
                structure = parser.get_structure(pdb_code, pdb_dir / f"{pdb_code}.pdb")
                parsed_by_code[pdb_code] = extract_chain_sequences(structure)
            except Exception as e:
                parsed_by_code[pdb_code] = f"ERROR: {e}"

        parsed = parsed_by_code[pdb_code]
        if isinstance(parsed, str):
            # NOTE(review): the error text is stored in the sequence columns;
            # downstream code must not treat it as a protein sequence.
            seqH = seqL = parsed
        else:
            # A missing/NaN chain id simply has no entry -> empty sequence.
            seqH = parsed.get(row["Hchain"], "")
            seqL = parsed.get(row["Lchain"], "")

        df.at[i, "seqH"] = seqH
        df.at[i, "seqL"] = seqL
        df.at[i, "full_sequence"] = f"{seqH}:{seqL}"

    return df


# df["resolution"] = pd.to_numeric(df["resolution"], errors="coerce")
# df = df.query("resolution < 2.0").copy()
# df_with_seq = parse_antibody_data(df, PDB_FILES_PATH)

In [8]:
# df_with_seq.to_csv(base_path / "sabdab_summary_with_sequences.tsv", sep="\t", index=False)

In [9]:
import copy

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using", device)

# Pretrained ESMFold in eval mode; it is wrapped by the TTT model below.
base_model = esm.pretrained.esmfold_v1().eval().to(device)

# Work on a private copy so the overrides below do not mutate the shared
# module-level default config (which would otherwise keep the changes across
# cell re-runs and leak into any other user of the default).
ttt_cfg = copy.deepcopy(DEFAULT_ESMFOLD_TTT_CFG)
ttt_cfg.steps = 20
ttt_cfg.batch_size = 4
ttt_cfg.seed = 0
# ttt_cfg.lr = 2e-4
model = ESMFoldTTT.ttt_from_pretrained(base_model,
                                       ttt_cfg=ttt_cfg,
                                       esmfold_config=base_model.cfg).to(device)

Using cuda


In [28]:
def predict_structure(model, sequence, pdb_id, tag, out_dir=OUTPUT_PATH):
    """Fold ``sequence`` with ``model``, save the PDB and score the prediction.

    The predicted structure is written to ``<out_dir>/<pdb_id lower-cased>_<tag>.pdb``
    and compared with the reference ``PDB_FILES_PATH/<pdb_id lower-cased>.pdb``
    via ``calculate_tm_score``.

    Args:
        model: Object exposing ``infer_pdb(sequence)`` that returns PDB text.
        sequence: Amino-acid sequence to fold.
        pdb_id: PDB code of the entry; used for both file names.
        tag: Suffix that distinguishes predictions of the same entry.
        out_dir: Directory for the predicted PDB file (``str`` or ``Path``).

    Returns:
        tuple: ``(pLDDT, tm_score)`` where ``pLDDT`` is the mean of the
        ``b_factor`` column of the written PDB file (presumably where the
        model stores per-atom pLDDT -- confirm) and ``tm_score`` is the value
        returned by ``calculate_tm_score``.

    Raises:
        Whatever ``calculate_tm_score`` or the file/model calls raise; in that
        case no scores are returned even though the PDB file may already have
        been written.
    """
    # parse_antibody_data looks raw files up by lower-cased PDB code, so use the
    # same normalisation for the reference path as for the output file name.
    pdb_code = pdb_id.lower()
    out_path = Path(out_dir) / f"{pdb_code}_{tag}.pdb"

    with torch.no_grad():
        pdb_str = model.infer_pdb(sequence)
    out_path.write_text(pdb_str)

    struct = bsio.load_structure(out_path, extra_fields=["b_factor"])
    pLDDT = struct.b_factor.mean()
    tm_score = calculate_tm_score(
        pred_path=out_path,
        pdb_path=PDB_FILES_PATH / f"{pdb_code}.pdb")
    return pLDDT, tm_score

In [None]:
def fold_chain(sequence, pdb_id, *, model, tag, out_dir = OUTPUT_PATH):
    """Adapt ``model`` to ``sequence`` with TTT, predict its structure, then reset.

    The PDB is written as ``<out_dir>/<pdb_id>_<tag>_ttt.pdb`` by
    ``predict_structure``.

    Args:
        sequence: Amino-acid sequence to adapt on and fold.
        pdb_id: PDB code of the entry.
        model: Model exposing ``ttt``, ``ttt_reset`` and ``infer_pdb``.
        tag: Chain label used in the output file name.
        out_dir: Directory for the predicted PDB file.

    Returns:
        tuple: ``(pLDDT, tm_score)`` as returned by ``predict_structure``.
    """
    try:
        model.ttt(sequence)
        return predict_structure(model, sequence, pdb_id, tag=f'{tag}_ttt', out_dir=out_dir)
    finally:
        # Always undo the adaptation, even when adaptation, prediction or
        # scoring fails: callers catch the exception and move on to the next
        # chain, which must not start from weights tuned on this one.
        model.ttt_reset()

## Calculate pLDDT and TM-scores before applying ProteinTTT

In [None]:
# Counts chain predictions whose baseline pLDDT is below 70.
num_of_low = 0
# Column holding the chain id -> column holding that chain's sequence.
chain_map = {"Hchain": "seqH", "Lchain": "seqL"}

for i, row in tqdm.tqdm(df.iterrows(), total=len(df)):
    # Resume support: rows that already have a baseline score are skipped.
    if pd.notna(row.get("pLDDT_before")):
        continue
    pdb_id = str(row.get("pdb"))

    for chain_col, seq_col in chain_map.items():
        if pd.isna(row[seq_col]):
            continue
        seq = str(row[seq_col]).strip().upper()
        tag = f'{str(row[chain_col])}'

        try:
            pLDDT_before, tm_score_before = predict_structure(model, seq, pdb_id, tag=f'{tag}_before_ttt', out_dir=OUTPUT_PATH)
            # NOTE(review): both chains write to the same two columns, so when a
            # row has a heavy and a light chain the light-chain scores replace
            # the heavy-chain ones -- confirm this is intended.
            df.at[i, 'pLDDT_before'] = pLDDT_before
            df.at[i, 'tm_score_before'] = tm_score_before
            if pLDDT_before < 70:
                num_of_low+=1
        except Exception as e:
            # Best effort: report the failure and continue with the next chain.
            warnings.warn(f"{pdb_id}{tag}: {e}")

# Writes back to the same file name the summary was loaded from, so the
# scores are available for the resume check on the next run.
df.to_csv(base_path / "sabdab_summary_with_sequences.tsv", sep="\t", index=False)

1673it [06:57,  2.34s/it]  

In [None]:
# Shape (rows, columns) of the subset whose baseline pLDDT is at most 70.
# NOTE(review): the stored output under this cell is a table of rows, not a
# shape tuple, so it was probably produced by an earlier version of the cell.
df.query("pLDDT_before <= 70").shape

Unnamed: 0,pdb,Hchain,Lchain,model,antigen_chain,antigen_type,antigen_het_name,antigen_name,short_header,date,compound,organism,heavy_species,light_species,antigen_species,authors,resolution,method,r_free,r_factor,scfv,engineered,heavy_subclass,light_subclass,light_ctype,affinity,delta_g,affinity_method,temperature,pmid,seqH,seqL,full_sequence,pLDDT_before,tm_score_before
31,9lux,C,c,0,,,,,IMMUNE SYSTEM,06/11/25,Single-chain Fv antibody of G2 fused with antigen peptide from chicken prion protein,Gallus gallus; Mus musculus,"gallus gallus, mus musculus",,,"Hanazono, Y., Yabuno, S., Hayashi, T., Numoto, N., Ito, N., Oda, M.",1.85,X-RAY DIFFRACTION,0.226,0.194,True,True,unknown,unknown,unknown,,,,,,GEAVAAANQTEVEMENKVVLEVLFQGPDIVLTQSPASLAVSLGQRATISCKASQTLDYDGGTYMNWYQQKPGQPPKLLIFAASNLESGIPARFSGS...,,GEAVAAANQTEVEMENKVVLEVLFQGPDIVLTQSPASLAVSLGQRATISCKASQTLDYDGGTYMNWYQQKPGQPPKLLIFAASNLESGIPARFSGS...,70.15248,
32,9lux,D,d,0,,,,,IMMUNE SYSTEM,06/11/25,Single-chain Fv antibody of G2 fused with antigen peptide from chicken prion protein,Gallus gallus; Mus musculus,"gallus gallus, mus musculus",,,"Hanazono, Y., Yabuno, S., Hayashi, T., Numoto, N., Ito, N., Oda, M.",1.85,X-RAY DIFFRACTION,0.226,0.194,True,True,unknown,unknown,unknown,,,,,,GEAVAAANQTEVEMENKVVLEVLFQGPDIVLTQSPASLAVSLGQRATISCKASQTLDYDGGTYMNWYQQKPGQPPKLLIFAASNLESGIPARFSGS...,,GEAVAAANQTEVEMENKVVLEVLFQGPDIVLTQSPASLAVSLGQRATISCKASQTLDYDGGTYMNWYQQKPGQPPKLLIFAASNLESGIPARFSGS...,70.15248,
1155,7x1t,D,d,0,B | C,protein | protein,NA | NA,mini-g alpha q protein | guanine nucleotide-binding protein g(i)/g(s)/g(t) subunitbeta-1,MEMBRANE PROTEIN,08/31/22,Structure of Thyrotropin-Releasing Hormone Receptor bound with Taltirelin.,Bos taurus; Homo sapiens; Rattus norvegicus; SYNTHETIC CONSTRUCT,rattus norvegicus,,homo sapiens | bos taurus,"Yang, F., Zhang, H.H., Meng, X.Y., Li, Y.G., Zhou, Y.X., Ling, S.L., Liu, L., Shi, P., Tian, C.L.",0.0,ELECTRON MICROSCOPY,,,True,True,unknown,unknown,unknown,,,,,,VQLVESGGGLVQPGKLSCSASGFAFSSFGMHWVRQAPEWVAYISSGSGTIYYADTVKGRFTISRDDPKNTLFLQMTSLRSEDTAMYYCVRSIYYYG...,,VQLVESGGGLVQPGKLSCSASGFAFSSFGMHWVRQAPEWVAYISSGSGTIYYADTVKGRFTISRDDPKNTLFLQMTSLRSEDTAMYYCVRSIYYYG...,70.056513,0.10517
1156,7x1u,D,d,0,B | F | C,protein | peptide | protein,NA | NA | NA,mini-g alpha q prtoein | guanine nucleotide-binding protein g(i)/g(s)/g(o) subunitgamma-2 | gua...,MEMBRANE PROTEIN,08/31/22,Structure of Thyrotropin-Releasing Hormone Receptor bound with an Endogenous Peptide Agonist TRH.,Bos taurus; Homo sapiens; Rattus norvegicus; SYNTHETIC CONSTRUCT,rattus norvegicus,,homo sapiens | bos taurus | bos taurus,"Yang, F., Zhang, H.H., Meng, X.Y., Li, Y.G., Zhou, Y.X., Ling, S.L., Liu, L., Shi, P., Tian, C.L.",0.0,ELECTRON MICROSCOPY,,,True,True,unknown,unknown,unknown,,,,,,VQLVESGGGLVQPGGLSCSASGFAFSSFGMHWVRQAPEKGLEWVAYISSGSGTIYYADTVKGRFTISRDDPKNTLFSLRSEDTAMYYCVRSIYYYG...,,VQLVESGGGLVQPGGLSCSASGFAFSSFGMHWVRQAPEKGLEWVAYISSGSGTIYYADTVKGRFTISRDDPKNTLFSLRSEDTAMYYCVRSIYYYG...,60.279506,0.10009
1469,7aqx,D,,0,B,protein,,variant surface glycoprotein mitat 1.2,MEMBRANE PROTEIN,10/23/20,Co-Crystal Structure of Variant Surface Glycoprotein VSG2 in complex with Nanobody VSG2(NB9),LAMA GLAMA; TRYPANOSOMA BRUCEI BRUCEI,lama glama,,trypanosoma brucei brucei,"Stebbins, C.E., Hempelmann, A.",1.499,X-RAY DIFFRACTION,0.192,0.164,False,True,IGHV1,,,,,,,,VQLQESGLRLSCTTSGLTFSNYAFSWFRQREFVGAISWSGGRTDYASVKGRFTISRDNAKNTFYLQMNSLDTAVYYCAADLLGEGSRRSEYEYWGQGTQ,,VQLQESGLRLSCTTSGLTFSNYAFSWFRQREFVGAISWSGGRTDYASVKGRFTISRDNAKNTFYLQMNSLDTAVYYCAADLLGEGSRRSEYEYWGQ...,58.131869,0.30879
1472,7aqz,D,,0,A,protein,,variant surface glycoprotein mitat 1.2,MEMBRANE PROTEIN,10/23/20,Co-Crystal Structure of Variant Surface Glycoprotein VSG2 in complex with Nanobody VSG2(NB14),LAMA GLAMA; TRYPANOSOMA BRUCEI BRUCEI,lama glama,,trypanosoma brucei brucei,"Stebbins, C.E., Hempelmann, A., VanStraaten, M.",1.3,X-RAY DIFFRACTION,0.172,0.151,False,True,IGHV1,,,,,,,,QVQLQESLSCEASGLTFSNYAMAWFRQEFVAGISWTGSRTYYADSVRGTSRDGHKNTVYLQMNDTAVYLCAADLLGSGKDGTSVYEYWGQGTQ,,QVQLQESLSCEASGLTFSNYAMAWFRQEFVAGISWTGSRTYYADSVRGTSRDGHKNTVYLQMNDTAVYLCAADLLGSGKDGTSVYEYWGQGTQ:,58.205746,0.2993


In [34]:
# Persist the table (including the baseline score columns) to the summary TSV;
# this is the same file name the summary was loaded from, so it is overwritten.
df.to_csv(base_path / "sabdab_summary_with_sequences.tsv", sep="\t", index=False)

## Run ProteinTTT on SAbDab entries with a baseline pLDDT of at most 70

In [None]:
# Column holding the chain id -> column holding that chain's sequence.
chain_map = {"Hchain": "seqH", "Lchain": "seqL"}

for i, row in tqdm.tqdm(df.iterrows(), total=len(df)):
    baseline_plddt = row.get("pLDDT_before")
    # Refine only rows with a known baseline of at most 70. A missing baseline
    # (NaN/None) is skipped explicitly: `NaN > 70` is False, so the plain
    # comparison would let unscored rows through, and `None > 70` raises.
    if pd.isna(baseline_plddt) or baseline_plddt > 70:
        continue
    pdb_id = str(row.get("pdb"))

    for chain_col, seq_col in chain_map.items():
        if pd.isna(row[seq_col]):
            continue
        seq = str(row[seq_col]).strip().upper()
        tag = f'{str(row[chain_col])}'

        try:
            pLDDT_after, tm_score_after = fold_chain(seq, pdb_id, model=model, tag=tag)
            # NOTE(review): both chains write to the same two columns, so the
            # light-chain scores replace the heavy-chain ones -- confirm.
            df.at[i, 'pLDDT_after'] = pLDDT_after
            df.at[i, 'tm_score_after'] = tm_score_after
        except Exception as e:
            # Best effort: report the failure and continue with the next chain.
            warnings.warn(f"{pdb_id}{tag}: {e}")

# Overwrites the summary TSV with the post-TTT score columns added.
df.to_csv(base_path / "sabdab_summary_with_sequences.tsv", sep="\t", index=False)

In [None]:
# Display the summary table. display.max_rows was raised to 10000 earlier in
# this notebook, so every row is rendered rather than a truncated preview.
df

In [None]:
# NOTE(review): `results` is not assigned in any cell above this one in the
# visible notebook, so on a fresh kernel this line raises NameError -- confirm
# where it is meant to come from.
results = pd.DataFrame(results)
# Store the results table next to the predicted structures.
results.to_csv(OUTPUT_PATH / "ttt_results.tsv", sep="\t", index=False)
# Echo the output directory as the cell result.
OUTPUT_PATH

In [None]:
# NOTE(review): this reads `ttt_results2.tsv` from `data/predicted_structures`,
# which is neither the directory (`OUTPUT_PATH`) nor the file name
# (`ttt_results.tsv`) written by the cell that saves `results` -- confirm this
# is the intended input.
results = pd.read_csv("/scratch/project/open-32-14/pimenol1/ProteinTTT/ProteinTTT/data/predicted_structures/ttt_results2.tsv", sep="\t")

In [None]:
# Keep only the rows that have a post-TTT pLDDT value.
results = results.query("pLDDT_after.notna()")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# `+` on two pandas Series adds them element-wise instead of concatenating,
# which leaves the value column half as long as the 'Model' labels. Convert
# to lists first so the before/after values are stacked one after the other.
plddt_before = list(results['pLDDT_before'])
plddt_after = list(results['pLDDT_after'])

# NOTE(review): this rebinds `df`, which held the SAbDab summary table in the
# earlier cells -- re-running those cells after this one will not work.
df = pd.DataFrame({
    'pLDDT': plddt_before + plddt_after,
    'Model': ['ESMFold'] * len(plddt_before) + ['ESMFold+ProteinTTT'] * len(plddt_after)
})

# One violin per model on a shared pLDDT axis.
sns.violinplot(data=df, x='Model', y='pLDDT')
plt.show()

In [None]:
! pip install seaborn matplotlib

In [None]:
# res_df = pd.DataFrame(results)
# Report how many rows `results` holds, then show count/mean/min/max of its
# numeric columns.
print(f"Finished {len(results)} chains")
display(results.describe().loc[["count", "mean", "min", "max"]])

In [None]:
# `+` on two pandas Series adds them element-wise instead of concatenating,
# which leaves the value column half as long as the 'Model' labels. Convert
# to lists first so the before/after values are stacked one after the other.
plddt_before = list(results['pLDDT_before'])
plddt_after = list(results['pLDDT_after'])

# NOTE(review): this rebinds `df`, which held the SAbDab summary table in the
# earlier cells -- re-running those cells after this one will not work.
df = pd.DataFrame({
    'pLDDT': plddt_before + plddt_after,
    'Model': ['ESMFold'] * len(plddt_before) + ['ESMFold+ProteinTTT'] * len(plddt_after)
})

# Violin plot
sns.violinplot(data=df, x='Model', y='pLDDT')
plt.title('Распределение pLDDT до и после кастомизации')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Per-row comparison: baseline pLDDT on x, post-TTT pLDDT on y.
fig, ax = plt.subplots()
ax.scatter(results['pLDDT_before'], results['pLDDT_after'])
# Dashed diagonal marks "no change"; points above it improved.
ax.plot([0, 100], [0, 100], 'r--', label='равенство')
ax.set_xlabel('ESMFold')
ax.set_ylabel('ESMFold + ProteinTTT')
ax.set_title('Сравнение pLDDT по белкам')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Keep only the rows that have a post-TTT pLDDT value (same filter as applied
# earlier in the notebook; harmless when run again).
results = results.query("pLDDT_after.notna()")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Overlay the two pLDDT distributions as filled density curves.
fig, ax = plt.subplots()
sns.kdeplot(results['pLDDT_before'], label='ESMFold', fill=True, ax=ax)
sns.kdeplot(results['pLDDT_after'], label='ESMFold + ProteinTTT', fill=True, ax=ax)
ax.set_xlabel('pLDDT')
ax.legend()
plt.show()