In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import os
import requests
from io import StringIO
from Bio.PDB import PDBParser, PDBIO, Select
from tape.datasets import LMDBDataset
from collections import Counter
from functools import partial
from Bio.PDB import PDBParser, DSSP
from biotite.structure.io.pdb import PDBFile
from pathlib import Path
from foldingdiff.datasets import extract_pdb_code_and_chain
# Run the notebook from the foldingdiff checkout so relative paths used below
# (e.g. the default "pdb_files" download dir, "./generation.pdb") resolve against it.
PROJECT_ROOT = Path('/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff')
os.chdir(PROJECT_ROOT)

class ChainSelect(Select):
    """
    PDBIO selection filter that keeps exactly one chain.

    Pass an instance as ``select=`` to ``PDBIO.save``: the chain whose id equals
    ``chain_id`` is accepted (1) and every other chain is rejected (0).
    """

    def __init__(self, chain_id):
        # Id of the single chain that should end up in the written file.
        self.chain_id = chain_id

    def accept_chain(self, chain_obj):
        # 1 = write this chain, 0 = skip it.
        return int(chain_obj.get_id() == self.chain_id)

def download_and_filter_pdb(pdb_code, chain, download_dir="pdb_files", timeout=60):
    """
    Download the full PDB file for ``pdb_code`` from RCSB, parse it with Biopython and
    write out only the atoms of ``chain``.

    Args:
        pdb_code: PDB entry id, interpolated into the RCSB download URL.
        chain: id of the chain to keep (compared against ``Chain.get_id()``).
        download_dir: output directory, created if missing. The file is written to
            ``{download_dir}/{pdb_code}_{chain}.pdb``.
        timeout: seconds passed to ``requests.get``. Without a timeout, a stalled
            connection blocks its worker thread indefinitely.

    Failures (network error or a non-200 status) are reported with ``print``, and the
    function returns without writing a file. This way one bad entry does not abort a
    threaded batch.
    """
    os.makedirs(download_dir, exist_ok=True)
    filename = os.path.join(download_dir, f"{pdb_code}_{chain}.pdb")
    url = f"https://files.rcsb.org/download/{pdb_code}.pdb"
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: Could not download {pdb_code}; {exc}")
        return
    if response.status_code != 200:
        print(f"Error: Could not download {pdb_code}; status code {response.status_code}")
        return

    # Parse the downloaded text in memory; nothing unfiltered is written to disk.
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_code, StringIO(response.text))

    # Write only the requested chain via the ChainSelect filter.
    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(filename, select=ChainSelect(chain))


In [None]:
def process_sample(sample, download_dir):
    """
    Download the chain-filtered PDB file for a single dataset sample.

    ``sample['id']`` is split into a PDB code and chain id by
    ``extract_pdb_code_and_chain``, and both are forwarded to
    ``download_and_filter_pdb``. Returns the sample's original id.
    """
    sample_id = sample['id']
    pdb_code, chain_id = extract_pdb_code_and_chain(sample_id)
    download_and_filter_pdb(pdb_code, chain_id, download_dir)
    return sample_id

# Default number of download threads; pass num_workers= to process_dataset to override it.
max_workers = 100


def process_dataset(dataset_ids, download_dir, num_workers=None):
    """
    Download chain-filtered PDB files for every sample in ``dataset_ids`` using a thread pool.

    Args:
        dataset_ids: iterable of samples accepted by ``process_sample`` (each indexed with
            ``'id'``). It is materialized to a list so tqdm can display a total, which also
            covers iterables without ``__len__``.
        download_dir: directory the PDB files are written to.
        num_workers: thread-pool size. ``None`` uses the module-level ``max_workers``.

    Returns:
        List of ``process_sample`` return values (the sample ids), in input order.
    """
    samples = list(dataset_ids)
    pool_size = max_workers if num_workers is None else num_workers
    download_one = partial(process_sample, download_dir=download_dir)
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        # executor.map yields results in input order; tqdm wraps it for a progress bar.
        results = list(tqdm(executor.map(download_one, samples), total=len(samples)))

    print("Finished processing samples.")
    return results

In [None]:
homology_dir = '/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/data/remote_homology'

# Fold labels with more than 50 training examples; only samples of these folds are downloaded.
train = LMDBDataset(f'{homology_dir}/remote_homology_train.lmdb')
counts = Counter(sample['fold_label'] for sample in train)
keep = {label for label, n_samples in counts.items() if n_samples > 50}
print(len(keep))

for suffix in ['train', 'valid', 'test_family_holdout', 'test_fold_holdout', 'test_superfamily_holdout']:
    split_samples = LMDBDataset(f'{homology_dir}/remote_homology_{suffix}.lmdb')
    dataset_ids = [sample for sample in split_samples if sample['fold_label'] in keep]
    process_dataset(dataset_ids, f'{homology_dir}/{suffix}_pdbs')



In [None]:
# Debug pass: re-parse one suspicious entry from every split with both biotite and
# Biopython, and print its first model. `bad` is meant to collect files DSSP fails on
# (see the commented-out block); nothing appends to it while that block is disabled.
DEBUG_ENTRY = '1JBA_A'
bad = []
for suffix in ['train', 'valid', 'test_family_holdout', 'test_fold_holdout', 'test_superfamily_holdout']:
    folder = f'/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/data/remote_homology/{suffix}_pdbs'
    matches = [entry for entry in os.listdir(folder) if DEBUG_ENTRY in entry]
    for entry in tqdm(matches):
        fname = os.path.join(folder, entry)
        # Use a separate handle name so the loop variable is not rebound to a file object.
        with open(fname, "rt") as fh:
            source = PDBFile.read(fh)
        # Result unused; building it only checks that biotite can construct model 1.
        source_struct = source.get_structure(model=1)
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure(Path(fname).stem, fname)
        model = structure[0]  # first model
        print(model)
        # try:
        #     dssp = DSSP(model, fname)
        # except Exception:
        #     print(fname)
        #     bad.append(fname)

In [None]:
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# `login()` (imported above) prompts for a Hugging Face API token; create one with "Read" permission.
# Uncomment it if fetching the "esm3-open" weights requires authentication in this environment.
# login()

# This will download the model weights and instantiate the model on your machine.
# NOTE: rebinds the notebook-global name `model`.
model: ESM3InferenceClient = ESM3.from_pretrained("esm3-open").to("cuda") # or "cpu"

# Generate a completion for a partial Carbonic Anhydrase (2vvb).
# `_` characters are presumably the positions ESM3 should fill in; letters are the given residues.
prompt = "___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIKTKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________"
protein = ESMProtein(sequence=prompt)
# Generate the sequence, then the structure. This will iteratively unmask the sequence track.
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7))
# Predict the structure for the generated sequence and save it.
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./generation.pdb")
# Round trip: clear the sequence and regenerate it from the structure (inverse folding),
# then clear the coordinates and re-predict the structure from that new sequence.
protein.sequence = None
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
protein.coordinates = None
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./round_tripped.pdb")

In [None]:
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

# ESM C smoke test: encode a toy sequence, then request sequence logits plus embeddings.
protein = ESMProtein(sequence="AAAAA")
client = ESMC.from_pretrained("esmc_300m").to("cuda")  # or "cpu"
protein_tensor = client.encode(protein)
logits_config = LogitsConfig(sequence=True, return_embeddings=True)
logits_output = client.logits(protein_tensor, logits_config)
print(logits_output.logits, logits_output.embeddings)

In [None]:
from foldingdiff.angles_and_coords import *

# Smoke-test `canonical_distances_and_dihedrals` on one downloaded chain file;
# the cell displays the returned value.
example_pdb = '/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/data/remote_homology/test_superfamily_holdout_pdbs/1UX8_A.pdb'
canonical_distances_and_dihedrals(example_pdb)

In [None]:
import pickle
from foldingdiff.bpe import MyDataset
# NOTE(review): pickle.load can execute arbitrary code — only load project-generated pickles.
# `MyDataset` is imported in this cell, presumably so its class is resolvable while unpickling.
with open('/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/homo_datasets_with_test.pkl', 'rb') as fh:
    train, valid, test_datasets = pickle.load(fh)


In [None]:
# Attach split names to the held-out test sets, in the order they were pickled
# (assumed family, fold, superfamily — confirm against the code that wrote the pickle).
# Guarded so that re-running the cell, or running it on a pickle that was already re-saved
# in labelled form, does not wrap each (name, dataset) pair a second time.
_TEST_SPLIT_NAMES = ('test_family', 'test_fold', 'test_superfamily')
_already_labelled = len(test_datasets) > 0 and all(
    isinstance(entry, tuple) and len(entry) == 2 and isinstance(entry[0], str)
    for entry in test_datasets
)
if not _already_labelled:
    # zip() would silently drop unmatched sets, so require an exact count.
    assert len(test_datasets) == len(_TEST_SPLIT_NAMES), (
        f"expected {len(_TEST_SPLIT_NAMES)} test sets, got {len(test_datasets)}"
    )
    test_datasets = list(zip(_TEST_SPLIT_NAMES, test_datasets))

In [None]:
# Overwrite the datasets pickle with the labelled test splits. The `with` block guarantees
# the file is flushed and closed here rather than whenever the handle is garbage-collected.
with open('/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/homo_datasets_with_test.pkl', 'wb') as fh:
    pickle.dump((train, valid, test_datasets), fh)

In [None]:
import pickle 

CKPT_DIR = '/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/ckpts'


def tensors_to_numpy(obj):
    """
    Convert tensors to numpy arrays, recursing into dicts.

    A dict is rebuilt with each value converted. An object with a ``.cpu`` method is
    converted via ``.cpu().numpy()``. Anything else (e.g. an array converted by an earlier
    run) is returned unchanged, so re-running the conversion is a no-op rather than an
    AttributeError.
    """
    if isinstance(obj, dict):
        return {key: tensors_to_numpy(val) for key, val in obj.items()}
    return obj.cpu().numpy() if hasattr(obj, 'cpu') else obj


# Rewrite each checkpoint's feats_100_{i}.pkl shards in place, for i = 0, 1, 2, ... up to
# the first missing index. The 'fp' entry and the optional 'foldseek' entry are converted
# from tensors to numpy arrays.
# NOTE(review): assumes feats maps protein id -> dict holding 'fp' (and maybe 'foldseek')
# tensor dicts — TODO confirm against the feature writer.
for ckpt_id in ['1749673452.5647202', '1749673452.650502', '1749673453.07734', '1749696458.6950917', '1749696458.6925066', '1749696458.6264317']:
    i = 0
    while True:
        path = f'{CKPT_DIR}/{ckpt_id}/feats_100_{i}.pkl'
        if not os.path.exists(path):
            break
        with open(path, 'rb') as fh:
            feats = pickle.load(fh)
        for feat in feats.values():
            # Convert the nested entries themselves, not the whole per-protein dict.
            feat['fp'] = tensors_to_numpy(feat['fp'])
            if 'foldseek' in feat:
                feat['foldseek'] = tensors_to_numpy(feat['foldseek'])
        # Dump to a temp file and atomically swap it in, so an interrupted write cannot
        # leave a truncated shard behind.
        tmp_path = path + '.tmp'
        with open(tmp_path, 'wb') as fh:
            pickle.dump(feats, fh)
        os.replace(tmp_path, path)
        i += 1  # advance to the next shard; without this the loop never terminates

In [None]:
# create a csv
import pickle
import pandas as pd
# for name in ["conserved-site-prediction","CatBio","BindBio","CatInt","repeat-motif-prediction"]:
name = "remote-homology-detection"
STB_DIR = "/n/holylfs06/LABS/mzitnik_lab/Users/msun415/foldingdiff/data/struct_token_bench"
path = f"{STB_DIR}/processed_pickles/{name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only load project-generated pickles.
with open(path, 'rb') as fh:
    train, val, test = pickle.load(fh)

# For the homology task, pdb/chain columns are derived from each item's "id" (see fill_keys).
# The other tasks presumably already carry pdb_id/chain_id — confirm.
derive_ids = "homo" in name
if derive_ids:
    keys = ["pdb_id", "chain_id", "fold_label"]
else:
    keys = ["pdb_id", "chain_id", "label_type", "residue_label"]


def fill_keys(item):
    """Set ``item['pdb_id']`` and ``item['chain_id']`` in place, parsed from ``item['id']``."""
    assert "id" in item
    pdb_id, chain_id = extract_pdb_code_and_chain(item["id"])
    item["pdb_id"] = pdb_id
    item["chain_id"] = chain_id


def _split_rows(split, items, keys, derive_ids):
    """Build one CSV row per item: split name, position within the split, then the ``keys`` fields."""
    split_rows = []
    for i, item in enumerate(items):
        if derive_ids:
            fill_keys(item)
        split_rows.append({"split": split, "idx": i} | {key: item[key] for key in keys})
    return split_rows


rows = _split_rows("train", train, keys, derive_ids)
rows += _split_rows("valid", val, keys, derive_ids)
# `test` is iterated as (split_name, dataset) pairs.
for k, test_k in test:
    rows += _split_rows(k, test_k, keys, derive_ids)

df = pd.DataFrame(rows)
out_dir = f"{STB_DIR}/processed_csvs"
os.makedirs(out_dir, exist_ok=True)  # to_csv does not create missing directories
df.to_csv(f"{out_dir}/{name}.csv", index=False)
print(f"saved {name}")