In [45]:
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from itertools import combinations
from Levenshtein import distance
from scipy.cluster.hierarchy import fcluster, linkage


# Making clusters: extract concatenated heavy + light chain sequences per structure

In [46]:
# Load every split of the three source datasets (MIPE, Paragraph, PECAN).
# MIPE presumably ships train and val combined (file "train_val.csv"), so it has
# no separate *_val frame.
mipe_train = pd.read_csv("/home/athenes/paratope_model/datasets/mipe/train_val.csv")
mipe_test = pd.read_csv("/home/athenes/paratope_model/datasets/mipe/test_set.csv")

paragraph_train = pd.read_csv("/home/athenes/paratope_model/datasets/paragraph/train_set.csv")
paragraph_val = pd.read_csv("/home/athenes/paratope_model/datasets/paragraph/val_set.csv")
paragraph_test = pd.read_csv("/home/athenes/paratope_model/datasets/paragraph/test_set.csv")

pecan_train = pd.read_csv("/home/athenes/paratope_model/datasets/pecan/train_set.csv")
pecan_val = pd.read_csv("/home/athenes/paratope_model/datasets/pecan/val_set.csv")
pecan_test = pd.read_csv("/home/athenes/paratope_model/datasets/pecan/test_set.csv")

# NOTE: despite its name, `concat_train` pools the train, val AND test splits of
# all datasets, keeping one row per PDB id. The list order matters: the first
# occurrence of a duplicated PDB id is the one that is kept.
all_splits = [
    paragraph_train, pecan_train, mipe_train,
    paragraph_val, pecan_val,
    mipe_test, paragraph_test, pecan_test,
]
concat_train = pd.concat(all_splits).drop_duplicates("pdb")


In [47]:
from typing import Optional

import pandas as pd
from biopandas.pdb import PandasPdb

def read_pdb_to_dataframe(
    pdb_path: Optional[str] = None,
) -> pd.DataFrame:
    """
    Read a PDB file and return its ATOM records, without hydrogens, as a DataFrame.

    Args:
        pdb_path (str, optional): Path to a local PDB file to read. The signature
            keeps a None default for backward compatibility, but a path is
            required; passing None raises ValueError.

    Returns:
        pd.DataFrame: The biopandas ``ATOM`` table with hydrogen atoms
        (``element_symbol == 'H'``) removed, one row per remaining atom, plus an
        ``IMGT`` column: the residue number followed by the insertion code, as a
        string.

    Raises:
        ValueError: If ``pdb_path`` is None.
    """
    if pdb_path is None:
        raise ValueError("pdb_path must be provided")
    atoms = PandasPdb().read_pdb(pdb_path).df["ATOM"]
    non_h_atoms = atoms.query("element_symbol != 'H'")
    # `.assign` returns a new frame instead of writing into the filtered slice,
    # so pandas cannot emit a SettingWithCopyWarning here.
    return non_h_atoms.assign(
        IMGT=non_h_atoms["residue_number"].astype(str) + non_h_atoms["insertion"].astype(str)
    )
# Three-letter -> one-letter codes for the 20 standard amino acids.
# Any residue name outside this table raises KeyError when looked up.
amino_acid_dict = dict(
    zip(
        (
            "ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU",
            "MET", "ASN", "PRO", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR",
        ),
        "ACDEFGHIKLMNPQRSTVWY",
    )
)


In [48]:
# For every structure, read its IMGT-numbered PDB file and store the heavy-chain
# sequence followed by the light-chain sequence as one string.
STRUCTURE_DIR = Path("/home/athenes/all_structures/imgt")
SKIPPED_PDBS = {"4fqi", "4fqy", "3gbn"}  # excluded by hand; reason not recorded here
HEAVY_MAX_RESIDUE = 129  # keep heavy-chain residues with residue_number < 129
LIGHT_MAX_RESIDUE = 128  # keep light-chain residues with residue_number < 128


def _chain_sequence(ca_df: pd.DataFrame, chain: str, max_residue: int) -> str:
    """Return the one-letter sequence of `chain` for residues numbered below `max_residue`.

    `ca_df` is expected to contain only CA atoms. A residue with alternate
    locations has several CA rows; only the first (in file order) is kept, so the
    residue is not counted twice. Unknown residue names raise KeyError.
    """
    chain_ca = (
        ca_df.query("chain_id == @chain and residue_number < @max_residue")
        .drop_duplicates(subset=["residue_number", "insertion"])
    )
    return "".join(amino_acid_dict[name] for name in chain_ca["residue_name"])


records = []
for pdb, heavy, light in tqdm(concat_train[["pdb", "Hchain", "Lchain"]].values):
    if pdb in SKIPPED_PDBS:
        continue
    pdb_path = STRUCTURE_DIR / f"{pdb}.pdb"
    if not pdb_path.exists():
        print(pdb)  # report structures with no file on disk
        continue
    hl_chains = [heavy, light]
    ca_df = read_pdb_to_dataframe(str(pdb_path)).query(
        "chain_id in @hl_chains and atom_name == 'CA'"
    )
    sequence_heavy = _chain_sequence(ca_df, heavy, HEAVY_MAX_RESIDUE)
    sequence_light = _chain_sequence(ca_df, light, LIGHT_MAX_RESIDUE)
    records.append({"pdb": pdb, "sequences": sequence_heavy + sequence_light})
df = pd.DataFrame.from_records(records)
df.to_csv("/home/athenes/paratope_model/datasets/concats/sequences.csv")


  0%|          | 0/1504 [00:00<?, ?it/s]

100%|██████████| 1504/1504 [02:01<00:00, 12.40it/s]


In [49]:
# sequences.csv was written with its index, so read that first column back as
# the index rather than as a spurious "Unnamed: 0" data column.
df = pd.read_csv("/home/athenes/paratope_model/datasets/concats/sequences.csv", index_col=0)


In [50]:
df = df[~df["pdb"].duplicated()]  # keep the first row for each PDB id


In [51]:
print(df.shape[0])  # number of unique structures with a sequence


1503


# Making train / validation / test sets

In [52]:

# PDB ids that appear in each benchmark's test split, and in any of them.
mipe_pdbs = set(mipe_test["pdb"].unique())
paragraph_pdbs = set(paragraph_test["pdb"].unique())
pecan_pdbs = set(pecan_test["pdb"].unique())
testset_pdbs = mipe_pdbs | pecan_pdbs | paragraph_pdbs


In [53]:
def exclude_from_trainset(pdb_set, sequences_df=None, threshold=0.05):
    """Return the PDB ids whose sequence is near-identical to any sequence in `pdb_set`.

    Similarity is the Levenshtein distance divided by the mean length of the two
    sequences; a pair counts as a match when this ratio is below `threshold`.
    Members of `pdb_set` that have a sequence match themselves (distance 0), so
    they are always in the result.

    Args:
        pdb_set: Collection of reference PDB ids (e.g. one benchmark's test set).
        sequences_df: Frame with "pdb" and "sequences" columns. Defaults to the
            notebook-global `df` for backward compatibility.
        threshold: Normalised-distance cutoff (default 0.05).

    Returns:
        list: Sorted, de-duplicated PDB ids to exclude from training.
    """
    if sequences_df is None:
        sequences_df = df  # notebook-global per-structure sequence table
    # Reference sequences are loop-invariant: select them once, not per row.
    reference_seqs = sequences_df.loc[sequences_df["pdb"].isin(pdb_set), "sequences"].tolist()
    excluded = set()
    for pdb1, seq1 in tqdm(sequences_df[["pdb", "sequences"]].values):
        for seq2 in reference_seqs:
            if distance(seq1, seq2) / np.mean([len(seq1), len(seq2)]) < threshold:
                excluded.add(pdb1)
                break  # one match is enough to exclude pdb1
    return sorted(excluded)


In [54]:
# Structures too similar to each benchmark's test set (evaluated in this order).
exclude_mipe, exclude_paragraph, exclude_pecan = (
    exclude_from_trainset(reference_pdbs)
    for reference_pdbs in (mipe_pdbs, paragraph_pdbs, pecan_pdbs)
)


  0%|          | 0/1503 [00:00<?, ?it/s]

100%|██████████| 1503/1503 [00:02<00:00, 637.86it/s]
100%|██████████| 1503/1503 [00:04<00:00, 348.27it/s]
100%|██████████| 1503/1503 [00:03<00:00, 423.86it/s]


In [55]:
# Every structure that is too close to any benchmark's test set.
exclude_everything = set().union(exclude_paragraph, exclude_pecan, exclude_mipe)


In [56]:
from sklearn.model_selection import train_test_split

# Split the dataset into training and validation sets.
# Structures near-identical to any benchmark test structure (exclude_everything)
# form the test set. The remaining structures, minus three hand-excluded
# entries, are split 75/25 into train/validation.
# `exclude_everything` only covers structures whose PDB file was found when the
# sequences were built, so test-set ids are also excluded directly. This
# prevents a test structure with a missing file from leaking into train/val.
hand_excluded_pdbs = ["4fqi", "4fqy", "3gbn"]
train_val_everything = concat_train.query(
    "pdb not in @exclude_everything"
    " and pdb not in @testset_pdbs"
    " and pdb not in @hand_excluded_pdbs"
)
test_everything = concat_train.query("pdb in @exclude_everything")
train_everything, val_everything = train_test_split(
    train_val_everything,
    test_size=0.25,  # 25% of the non-test structures go to validation
    random_state=42,  # for reproducibility
)


In [57]:
print(train_everything.shape[0])  # number of training structures


803


In [58]:
train_csv_path = "/home/athenes/paratope_model/datasets/concats/train_set.csv"
train_everything.to_csv(train_csv_path)


In [59]:
val_csv_path = "/home/athenes/paratope_model/datasets/concats/val_set.csv"
val_everything.to_csv(val_csv_path)


In [60]:
test_csv_path = "/home/athenes/paratope_model/datasets/concats/test_set.csv"
test_everything.to_csv(test_csv_path)
