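# PrescientPLMFold
Notebook adapted from the Hugging Face [protein folding notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb). It predicts structures for a single protein and for a homodimer with PrescientPLMFold (pretrained ESMFold weights), writes them out as PDB files, and then exercises the 3Di tokenizer and the FASTA data module.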

In [None]:
# Reload imported modules before each cell executes, so edits to library source
# (e.g. lobster) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import pandas as pd
import numpy as np
import os
import torch.nn.functional as F

from lobster.model import PrescientPLMFold

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Allow TF32 tensor-core matmuls on CUDA (faster, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True

Load the model (pretrained ESMFold weights).

In [None]:
# Load the pretrained ESMFold weights. To train from scratch instead, construct
# PrescientPLMFold(model_name="PPLM"), which is randomly initialized.
m = PrescientPLMFold(model_name="esmfold_v1")  # pre-trained
# Move to the device chosen in the config cell rather than hard-coding CUDA,
# so this cell does not crash on a CPU-only host.
m.to(device);

In [None]:
# Number of trainable parameters in the wrapped model, shown with thousands separators.
n_trainable_params = sum(
    param.numel() for param in m.model.parameters() if param.requires_grad
)
f"{n_trainable_params:,}"

In [None]:
# Cast only the ESM language-model stem to float16 to reduce memory use;
# no other submodule is cast here.
m.model.esm = m.model.esm.to(torch.float16)

In [None]:
# Chunk size for the folding trunk. Presumably smaller values trade speed for
# lower peak memory; see the transformers ESMFold docs to confirm.
TRUNK_CHUNK_SIZE = 64
m.model.trunk.set_chunk_size(TRUNK_CHUNK_SIZE)

Predict coordinates for a single protein and convert them to PDB format.

In [None]:
test_protein = "MGAGASAEEKHSRELEKKLKEDAEKDARTVKLLLLGAGESGKSTIVKQMKIIHQDGYSLEECLEFIAIIYGNTLQSILAIVRAMTTLNIQYGDSARQDDARKLMHMADTIEEGTMPKEMSDIIQRLWKDSGIQACFERASEYQLNDSAGYYLSDLERLVTPGYVPTEQDVLRSRVKTTGIIETQFSFKDLNFRMFDVGGQRSERKKWIHCFEGVTCIIFIAALSAYDMVLVEDDEVNRMHESLHLFNSICNHRYFATTSIVLFLNKKDVFFEKIKKAHLSICFPDYDGPNTYEDAGNYIKVQFLELNMRRDVKEIYSHMTCATDTQNVKFVFDAVTDIIIKENLKDCGLF"

# Encode the sequence without special tokens and keep only the token-id tensor.
tokenized_input = m.tokenizer(
    [test_protein],
    return_tensors="pt",
    add_special_tokens=False,
)["input_ids"]

In [None]:
# Place the token ids on the device selected in the config cell rather than
# assuming CUDA.
tokenized_input = tokenized_input.to(device)

In [None]:
# Inference only: run the forward pass with autograd tracking switched off.
with torch.set_grad_enabled(False):
    output = m.model(tokenized_input)

In [None]:
# Fields returned by the folding forward pass.
output.keys()

In [None]:
# Inspect the underlying model's configuration.
m.model.config

In [None]:
# Hidden states from the folding trunk.
output.states.shape  # (num structure blocks, B, L, sequence hidden dim)

In [None]:
# Shape of the predicted frames tensor. NOTE(review): the axis layout is not
# shown in this notebook; confirm it against the model's output documentation.
output.frames.shape

In [None]:
# Predicted atom coordinates.
output.positions.shape  # (..., B, L, 14 (atom14), 3 (xyz))

In [None]:
# Per-atom confidence (pLDDT) scores.
output.plddt.shape  # (B, L, 37 (atom37))

In [None]:
# output_to_pdb appears to return one PDB string per batch entry; the batch here
# holds a single sequence, so keep the first.
pdb_strings = m.model.output_to_pdb(output)
pdb_file = pdb_strings[0]

In [None]:
# Save the predicted structure so it can be opened in a molecular viewer.
with open('output.pdb', 'w') as pdb_handle:
    pdb_handle.writelines([pdb_file])

Predict the coordinates of a homodimer.

We insert a flexible poly-glycine "linker" between the chains we want to fold simultaneously. We then offset the position IDs of each chain from the others, so the model treats the chains as very distant portions of one long chain.

Tip: If you are predicting a multimeric structure and getting low-quality outputs, try varying the order of the chains (for a heteromer) or the length of the linker.

In [None]:
sequence = "MRLIPLHNVDQVAKWSARYIVDRINQFQPTEARPFVLGLPTGGTPLKTYEALIELYKAGEVSFKHVVTFNMDEYVGLPKEHPESYHSFMYKNFFDHVDIQEKNINILNGNTEDHDAECQRYEEKIKSYGKIHLFMGGVGVDGHIAFNEPASSLSSRTRIKTLTEDTLIANSRFFDNDVNKVPKYALTIGVGTLLDAEEVMILVTGYNKAQALQAAVEGSINHLWTVTALQMHRRAIIVCDEPATQELKVKTVKYFTELEASAIRSVK"

# Number of glycines joining the two copies; worth tuning if the predicted
# complex looks poor.
LINKER_LENGTH = 25
linker = 'G' * LINKER_LENGTH

# chain + linker + chain, fed to the model as a single sequence.
homodimer_sequence = linker.join([sequence, sequence])

In [None]:
# Tokenize the linked homodimer as one chain, again without special tokens.
tokenized_homodimer = m.tokenizer(
    [homodimer_sequence],
    return_tensors="pt",
    add_special_tokens=False,
)

In [None]:
# Shift the position IDs of the second chain (everything after the linker) by a
# large offset, so the model treats the two copies as distant parts of one chain.
# No torch.no_grad() is needed here: torch.arange returns a tensor that does not
# require grad.
CHAIN_POSITION_OFFSET = 512
second_chain_start = len(sequence) + len(linker)
position_ids = torch.arange(len(homodimer_sequence), dtype=torch.long)
position_ids[second_chain_start:] += CHAIN_POSITION_OFFSET

# There must be exactly one position ID per input token; this would fail if the
# tokenizer added special tokens or did not map one residue to one token.
assert position_ids.shape[0] == tokenized_homodimer['input_ids'].shape[-1], (
    f"{position_ids.shape[0]} position ids vs "
    f"{tokenized_homodimer['input_ids'].shape[-1]} tokens"
)

In [None]:
# Add the shifted position IDs (with a leading batch dimension) to the model inputs.
tokenized_homodimer['position_ids'] = position_ids.unsqueeze(0)

# Move every input tensor to the device selected in the config cell rather than
# assuming CUDA.
tokenized_homodimer = {key: tensor.to(device) for key, tensor in tokenized_homodimer.items()}

In [None]:
# Inference only: fold the linked homodimer with autograd tracking switched off.
with torch.set_grad_enabled(False):
    output = m.model(**tokenized_homodimer)

In [None]:
# Remove the poly-G linker from the output so the structure can be displayed as
# fully independent chains: zero the atom-existence flags over the linker residues.
n_chain, n_linker = len(sequence), len(linker)
linker_mask = torch.ones(n_chain + n_linker + n_chain, dtype=torch.long)
linker_mask[n_chain:n_chain + n_linker] = 0
# Shape (1, L, 1); assumes atom37_atom_exists is (B, L, 37), so the mask broadcasts.
linker_mask = linker_mask[None, :, None]

atom_exists = output['atom37_atom_exists']
output['atom37_atom_exists'] = atom_exists * linker_mask.to(atom_exists.device)

In [None]:
# output_to_pdb appears to return one PDB string per batch entry; keep the single
# homodimer structure.
pdb_strings = m.model.output_to_pdb(output)
pdb_file = pdb_strings[0]

In [None]:
# Save the linker-masked homodimer structure.
with open('homodimer.pdb', 'w') as pdb_handle:
    pdb_handle.writelines([pdb_file])

## Data: 3Di tokenizer and FASTA data module

In [None]:
from lobster.data import FastaLightningDataModule
from lobster.tokenization import PmlmTokenizer
import importlib.resources

In [None]:
# The 3Di tokenizer files ship inside the lobster package's assets directory.
lobster_assets = importlib.resources.files("lobster") / "assets"
path = lobster_assets / '3di_tokenizer'

t = PmlmTokenizer.from_pretrained(path)

# A short example 3Di string used for a tokenize/decode round trip.
inputs = ['GdPfQaPfIlSvRvLvEcQvClGpId']

In [None]:
# Tokenize the example 3Di string(s).
tokenized_inputs = t(inputs)

In [None]:
# Inspect the resulting token ids.
tokenized_inputs['input_ids']

In [None]:
decoded_3di = t.decode(token_ids=tokenized_inputs['input_ids'][0], skip_special_tokens=True)
# Drop the spaces in the decoded text so it can be compared with the raw input string.
out = decoded_3di.replace(' ', '')

In [None]:
# Round-trip check: decoding the tokenized 3Di string must reproduce the input.
# The message shows both strings so a mismatch is diagnosable.
assert inputs[0] == out, f"3Di tokenizer round-trip mismatch: {inputs[0]!r} -> {out!r}"

In [None]:
# FASTA file of 3Di sequences. The default is a user-specific cluster path; set
# the PDB_3DI_FASTA environment variable to point elsewhere instead of editing it.
PDB_3DI_FASTA = os.environ.get(
    "PDB_3DI_FASTA", "/scratch/site/u/freyn6/data/fasta/pdb_3di.fasta"
)

dm = FastaLightningDataModule(
    path_to_fasta=[PDB_3DI_FASTA],
    tokenizer_dir='3di_tokenizer',
    batch_size=4,
)
dm.setup(stage='fit')

In [None]:
# Pull a single batch from the training loader for inspection.
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))

In [None]:
# Inspect the tokenizer held by the data module's transform.
# NOTE(review): reads private attributes (_transform_fn, _auto_tokenizer), so this
# may break if lobster's internals change.
dm._transform_fn._auto_tokenizer

In [None]:
# Token ids of the first example in the training batch.
batch['input_ids'][0]