# 3D-Prescient Protein Language Model
* Obtain structural embeddings of a sequence: fold the sequence -> discretize the structure into a 3Di representation -> embed the combined (AA+3Di) tokens with 3D-PPLM.

In [None]:
# IPython autoreload: re-import edited modules (e.g. prescient_plm) before each
# cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import pandas as pd
import numpy as np
import os
import torch.nn.functional as F

import os
from shutil import which

from prescient_plm.model import PrescientPMLM
from prescient_plm.transforms import FoldseekTransform

# Pick the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

Load the model and 3D tokenizer.

In [None]:
# Load the 3D-PPLM checkpoint (referenced by its S3 URI).
model = PrescientPMLM.load_from_checkpoint(
    's3://prescient-pcluster-data/freyn6/models/pmlm/prod/2023-10-27T15-58-01.675608/epoch=13-step=469494-val_loss=1.5488.ckpt'
)
# Move to the device chosen in the setup cell right away, so every inference
# call in this notebook (monomer, dimer and PDB sections) runs on it.
model.to(device)
# Inference mode for modules such as dropout; the trailing ';' suppresses the
# notebook's repr output.
model.eval();

In [None]:
format(model.num_trainable_parameters, ",")  # trainable-parameter count with thousands separators

Install Foldseek and add it to your conda env's `bin` directory, following the instructions [here](https://code.roche.com/prescient-design/manifold-sampler-pytorch/-/tree/master/prescient_plm?ref_type=heads#tokenizable-structural-descriptors-foldseek).

In [None]:
# Directory containing the foldseek binary (machine-specific; edit for your env).
bin_path = "/home/freyn6/miniconda3/envs/prescient-plm/bin"
# Add it to PATH so foldseek can be found from Jupyter. Only append when it is
# missing, so re-running this cell does not keep adding duplicate entries;
# .get() tolerates an unset PATH.
_current_path = os.environ.get("PATH", "")
if bin_path not in _current_path.split(os.pathsep):
    os.environ["PATH"] = (
        _current_path + os.pathsep + bin_path if _current_path else bin_path
    )

In [None]:
# Resolve the foldseek executable; its path is passed to FoldseekTransform below.
foldseek = which("foldseek")
if foldseek is None:
    # Fail fast: otherwise None would be silently handed to FoldseekTransform.
    raise FileNotFoundError(
        "foldseek executable not found on PATH; install it and check bin_path"
    )
foldseek

### Monomer
Transform a single sequence and embed it with 3D-PPLM.

In [None]:
# Toy 10-residue amino-acid sequence used by the monomer and dimer examples below.
example_sequence = "GYDPETGTWG"

In [None]:
# Fold with ESMFold, then discretize the structure to 3Di with foldseek.
# linker_length: specify if you have a dimer complex (presumably only used to
# join chains of a multi-chain input — TODO confirm).
foldseek_transform = FoldseekTransform(
    foldseek=foldseek,
    pplm_fold_model_name="esmfold_v1",
    linker_length=3,
)

In [None]:
# AA single chain -> AA+3Di: the sequence is folded (esmfold_v1, as configured
# on the transform above) and the structure is discretized into 3Di tokens.
seq_dict = foldseek_transform.transform(sequences=[example_sequence])

In [None]:
# Result is keyed by chain id; each value is a 3-tuple:
# {"chain id": (AA seq, 3Di seq, AA+3Di seq)}
seq_dict["A"]

In [None]:
# Index 2 is the combined AA+3Di token string (see the tuple layout above).
tokens_3d = seq_dict["A"][2]
# Embed a batch of one sequence; `h` appears to be a per-block list of hidden
# states (see the length check below).
# NOTE(review): unclear whether sequences_to_latents disables autograd
# internally; wrap in torch.no_grad() if memory matters — confirm.
h = model.sequences_to_latents([tokens_3d])

In [None]:
# NOTE(review): the trailing comment is the author's recorded output; which axis
# of h[-1] is sequence vs. hidden, and whether a batch axis is present, is not
# visible here — confirm against the model before pooling.
len(h), h[-1].shape  # 9 blocks, hidden rep of dim (1024, 384)

### Dimer (complex)
Transform a dimer and embed it with 3D-PPLM,  
e.g., `example_dimer = [concatenated antibody VH+VL, antigen sequence]`.

In [None]:
# Toy two-chain "complex": the example chain paired with its reverse.
example_dimer = [example_sequence, "".join(reversed(example_sequence))]
example_dimer

In [None]:
# Each item of `sequences` is one input: a single string for a monomer (above)
# or, here, a list of chain sequences for a complex.
dimer_seq_dict = foldseek_transform.transform(sequences=[example_dimer])

In [None]:
# NOTE(review): the complex is also read back under chain "A" — presumably the
# chains are joined (with linker_length) into a single entry; confirm.
dimer_seq_dict["A"]

In [None]:
# Embed the complex's combined AA+3Di tokens (index 2, as for the monomer).
dimer_tokens_3d = dimer_seq_dict["A"][2]
h_dimer = model.sequences_to_latents([dimer_tokens_3d])

In [None]:
# Display the last entry of the complex's per-block output.
h_dimer[-1]

# PDB embeddings

In [None]:
from Bio import SeqIO

In [None]:
# Read the PDB complex FASTA into parallel lists of record ids and sequences.
with open('/scratch/site/u/freyn6/data/fasta/pdb_3di_complexes.fasta') as fasta_file:
    id_seq_pairs = [
        (record.id, str(record.seq))
        for record in SeqIO.parse(fasta_file, 'fasta')
    ]
identifiers = [record_id for record_id, _ in id_seq_pairs]
sequences = [seq for _, seq in id_seq_pairs]

In [None]:
# Move the model to the device chosen in the setup cell (';' suppresses output).
model.to(device);

In [None]:
# Embed the first 5 FASTA records and keep only the last two entries of the
# per-block output.
# NOTE(review): these strings are passed as-is, whereas the examples above
# passed AA+3Di tokens — presumably this FASTA already holds that format
# (it is named pdb_3di_complexes); confirm.
embeddings = model.sequences_to_latents(sequences[:5])[-2:]

In [None]:
# Expected 2: the slice above keeps the last two entries.
len(embeddings)

In [None]:
# Full per-block output for the same 5 records, to cross-check the slice above.
embeddings_all = model.sequences_to_latents(sequences[:5])

In [None]:
# Sanity check: the slice's first element and the full output's second-to-last
# entry should be identical. Exact equality relies on deterministic inference
# (model.eval() above; nondeterministic GPU kernels could break it).
torch.equal(embeddings[-2], embeddings_all[-2])

In [None]:
# Second-to-last-block embeddings of the first two records.
e0, e1 = embeddings[-2][0], embeddings[-2][1]

In [None]:
# Inspect per-record tensor layout before choosing a pooling axis below.
e0.shape, e1.shape

In [None]:
# dim=0: similarity between two 1-D vectors (the pooled embeddings below).
cos = torch.nn.CosineSimilarity(dim=0)

In [None]:
# Cosine similarity between the mean-pooled second-to-last-block embeddings of
# records 0 and 4.
# NOTE(review): mean(dim=1) pools over axis 1 of each per-record tensor. If that
# tensor is (seq_len, hidden) — typical for transformer LMs — this averages over
# hidden units and compares per-position profiles rather than sequence
# embeddings (dim=0 would pool over positions). Padding positions, if any, are
# also included in the mean. Check e0.shape above before trusting this score.
output = cos(e0.mean(dim=1), embeddings[-2][4].mean(dim=1))
output