In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import pandas as pd
import torch

from Bio import SeqIO
from scipy.stats import spearmanr
from transformers import AutoModelForMaskedLM, AutoTokenizer

In [3]:
import sys

prosst_path = "/home/rcalef/sandbox/repos/magneton/magneton/external/ProSST"
sys.path.append(prosst_path)

In [4]:
from prosst.structure.get_sst_seq import SSTPredictor, init_shared_pool

In [5]:
init_shared_pool(8)

In [7]:
example_pdb = prosst_path + "/example_data/p1.pdb"

#example_pdb = "/weka/scratch/weka/kellislab/rcalef/data/pdb_alphafolddb/AF-A1RZJ9-F1-model_v4.pdb"
predictor = SSTPredictor(structure_vocab_size=2048) # can be 20, 128, 512, 1024, 2048, 4096
result = predictor.predict_from_pdb(example_pdb)

---------- Load Model on cuda ----------
MODEL: 5.90M parameters
---------- Building Subgraphs ----------


100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.11it/s]
100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.38s/it]


In [9]:
len(result[0]["2048_sst_seq"])

75

In [10]:
len(result[0]["aa_seq"])

75

### Scoring example from repo

In [5]:
model_path = "/home/rcalef/storage/om_storage/model_weights/ProSST-2048"
model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

In [17]:
model.cls

ProSSTOnlyMLMHead(
  (predictions): ProSSTLMPredictionHead(
    (transform): ProSSTPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (transform_act_fn): GELUActivation()
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=25, bias=False)
  )
)

In [12]:
len(model.prosst.encoder.layer)

12

In [23]:
residue_sequence = str(SeqIO.read(prosst_path + '/zero_shot/example_data/GRB2_HUMAN_Faure_2021.fasta', 'fasta').seq)

In [24]:
structure_sequence = predictor.predict_from_pdb(prosst_path + "/zero_shot/example_data/GRB2_HUMAN_Faure_2021.pdb")[0]['2048_sst_seq']


---------- Building Subgraphs ----------


100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.14it/s]
100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.57it/s]


In [25]:
structure_sequence_offset = [i + 3 for i in structure_sequence]


In [26]:
tokenized_res = tokenizer([residue_sequence], return_tensors='pt')
input_ids = tokenized_res['input_ids']
attention_mask = tokenized_res['attention_mask']
structure_input_ids = torch.tensor([1, *structure_sequence_offset, 2], dtype=torch.long).unsqueeze(0)

In [32]:
with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ss_input_ids=structure_input_ids,
        output_hidden_states=True,
    )
logits = torch.log_softmax(outputs.logits[:, 1:-1], dim=-1).squeeze()

In [28]:

df = pd.read_csv(prosst_path + "/zero_shot/example_data/GRB2_HUMAN_Faure_2021.csv")
mutants = df['mutant'].tolist()

In [29]:
vocab = tokenizer.get_vocab()
pred_scores = []
for mutant in mutants:
    mutant_score = 0
    for sub_mutant in mutant.split(":"):
        wt, idx, mt = sub_mutant[0], int(sub_mutant[1:-1]) - 1, sub_mutant[-1]
        pred = logits[idx, vocab[mt]] - logits[idx, vocab[wt]]
        mutant_score += pred.item()
    pred_scores.append(mutant_score)

In [30]:
spearmanr(pred_scores, df['DMS_score'])

SignificanceResult(statistic=0.7182950414920266, pvalue=0.0)

In [35]:
outputs.hidden_states

(tensor([[[-0.2485, -0.1342, -0.0800,  ...,  0.3164,  0.1432,  0.0799],
          [-0.0896, -0.1648,  0.4901,  ...,  0.1394,  0.0144,  0.1983],
          [ 0.1657, -0.1281,  0.3248,  ..., -0.0207,  0.0780, -0.5081],
          ...,
          [ 0.2166,  0.0608,  0.2667,  ..., -0.0236,  0.2734,  0.4715],
          [ 0.2006, -0.2071,  0.8953,  ..., -0.2292, -0.1639,  0.2802],
          [-0.1214,  0.0263,  0.0887,  ...,  0.0862,  0.0719, -0.0068]]]),
 tensor([[[-2.5764e-03, -1.0332e-01,  4.5247e-02,  ..., -3.8302e-03,
           -1.0434e-01, -1.3491e-02],
          [ 1.0533e-01,  3.2962e-03,  4.6990e-02,  ...,  3.2873e-01,
            2.7378e-01, -4.6533e-02],
          [ 6.9737e-02, -9.1406e-02,  1.9271e-01,  ...,  1.2026e-01,
           -1.9264e-01, -4.4676e-02],
          ...,
          [ 1.4992e-02,  2.1682e-02,  8.5999e-02,  ..., -5.0822e-02,
            1.1033e-01, -8.3667e-02],
          [ 6.6207e-02, -1.1911e-01,  9.3116e-02,  ..., -5.9162e-02,
            1.3843e-01, -1.0879e-04],


In [13]:
aa_toks = tokenizer(result[0]["aa_seq"], return_tensors="pt")
aa_toks

{'input_ids': tensor([[ 1, 13, 17,  3,  6, 12, 15, 11, 17, 20, 20, 20,  6, 17,  8,  3, 12, 16,
          7, 12, 15,  6, 20, 12, 17,  6, 12,  8,  4, 18, 11, 19, 20, 20, 20, 19,
          5, 18,  8, 20, 21, 18, 20, 20,  8, 18, 20, 20,  6,  8,  3, 12, 17,  8,
         12,  3, 22,  6, 20, 20, 22, 10,  6,  3,  3,  5, 14, 18, 14, 20,  6, 17,
          3, 17, 18,  3,  3, 17, 17, 20,  6,  3,  4,  3, 20,  3,  8, 12,  8,  8,
          8, 17, 15, 20,  5, 20,  3, 11, 22,  3,  3,  7, 13,  6,  8, 12, 15,  7,
         20, 18, 20, 15, 19,  3, 10, 18,  9,  5,  8,  7,  3, 18, 15, 10, 20,  3,
         12, 11,  5, 15,  6,  8, 14, 15, 12, 18, 10,  7, 19, 17, 15, 15,  3,  3,
         20, 12, 20,  5, 12,  3, 20, 20, 18, 17,  3, 15, 17, 17, 12, 12,  3, 18,
          8, 20,  8,  5, 10, 20,  8, 11, 20, 19, 18, 20,  3,  5,  3, 17, 12,  3,
         16, 17, 12, 19,  8,  6,  6, 20, 15,  6, 20,  3, 12, 17, 13,  3,  6, 19,
          3,  3, 17, 13, 20, 12,  5,  6, 20,  5,  6, 10,  3, 18, 21, 19,  6, 17,
          8, 2

In [14]:
aa_toks["input_ids"].shape

torch.Size([1, 349])

In [15]:
ss_toks = torch.tensor(result[0]["2048_sst_seq"])
ss_toks.shape

torch.Size([347])

In [39]:
torch.cat([1], ss_toks, [2])

TypeError: cat() received an invalid combination of arguments - got (list, Tensor, list), but expected one of:
 * (tuple of Tensors tensors, int dim = 0, *, Tensor out = None)
 * (tuple of Tensors tensors, name dim, *, Tensor out = None)


In [29]:
aa_toks["input_ids"].shape

torch.Size([1, 347])

In [31]:
# def forward(
#     self,
#     input_ids: Optional[torch.Tensor] = None,
#     ss_input_ids: Optional[torch.Tensor] = None,
#     attention_mask: Optional[torch.Tensor] = None,
#     token_type_ids: Optional[torch.Tensor] = None,
#     position_ids: Optional[torch.Tensor] = None,
#     inputs_embeds: Optional[torch.Tensor] = None,
#     labels: Optional[torch.Tensor] = None,
#     output_attentions: Optional[bool] = None,
#     output_hidden_states: Optional[bool] = None,
#     return_dict: Optional[bool] = None,
# ) -> Union[Tuple, MaskedLMOutput]:
with torch.inference_mode():
    out = model(
        input_ids=aa_toks["input_ids"],
        ss_input_ids=ss_toks,
        attention_mask=aa_toks["attention_mask"],
    )
out

MaskedLMOutput(loss=None, logits=tensor([[[-18.3368, -12.2662, -13.9507,  ...,   0.8657, -15.6547, -18.3795],
         [-18.8052, -12.7068, -15.0800,  ...,   0.6976, -16.0266, -18.8201],
         [-18.6050, -12.5979, -14.4972,  ...,   0.2606, -15.8135, -18.6339],
         ...,
         [-18.6463, -12.9135, -14.2035,  ...,   0.7077, -15.9107, -18.6853],
         [-18.7029, -13.0057, -14.9421,  ...,   0.2711, -16.3094, -18.7389],
         [-19.0306, -13.0453, -15.0684,  ...,   0.3415, -16.1042, -19.0603]]]), hidden_states=None, attentions=None)

In [None]:
model = AutoModelForMaskedLM.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)

# Test modules

In [6]:
import os

from magneton.config import (
    DataConfig,
    EmbeddingConfig,
    ModelConfig,
    TrainingConfig,
    PipelineConfig,
)
from magneton.data import MagnetonDataModule
from magneton.embedders.prosst_embedder import ProSSTConfig, ProSSTEmbedder
from magneton.training.embedding_mlp import EmbeddingMLP, MultitaskEmbeddingMLP

INFO:rdkit:Enabling RDKit 2023.09.6 jupyter extensions


## Load data

In [7]:
interpro_path = "/weka/scratch/weka/kellislab/rcalef/data/interpro/103.0/"
fasta_path = "/weka/scratch/weka/kellislab/rcalef/data/uniprot/uniprot_sprot.fasta.gz"
labels_path = "/weka/scratch/weka/kellislab/rcalef/data/interpro/103.0/label_sets/selected_subset"
pickle_path = os.path.join(interpro_path, "swissprot", "sharded_swissprot", "with_ss", "debug_datasets")
#pickle_path = os.path.join(interpro_path, "swissprot", "sharded_swissprot", "with_ss", "dataset_splits", "seq_splits")

In [8]:
model_path = "/home/rcalef/storage/om_storage/model_weights/ProSST-2048"

data_config = DataConfig(
    data_dir=pickle_path,
    prefix="swissprot.with_ss",
    fasta_path=fasta_path,
    labels_path=labels_path,
    substruct_types=["Domain"],
    collapse_labels=True,
    struct_template="/weka/scratch/weka/kellislab/rcalef/data/pdb_alphafolddb/AF-%s-F1-model_v4.pdb",
    batch_size=2,
)

In [9]:
module = MagnetonDataModule(
    data_config=data_config,
    model_type="prosst",
)

In [10]:
loader = module.val_dataloader()

100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.42s/it]
INFO:magneton.data.model_specific.prosst:ProSST tokens file found at: /weka/scratch/weka/kellislab/rcalef/data/interpro/103.0/swissprot/sharded_swissprot/with_ss/debug_datasets/prosst_toks.tsv.bz2


11:30:03   ProSST tokens file found at: /weka/scratch/weka/kellislab/rcalef/data/interpro/103.0/swissprot/sharded_swissprot/with_ss/debug_datasets/prosst_toks.tsv.bz2
/home/rcalef/storage/om_storage/model_weights/ProSST-2048


INFO:magneton.data.model_specific.prosst:read ProSST structure tokens for 132347 proteins


11:30:43   read ProSST structure tokens for 132347 proteins


In [11]:
batch = next(iter(loader))
batch

ProSSTBatch(protein_ids=['A0A024FA41', 'A0A059J0G5'], lengths=[365, 1567], seqs=None, substructures=[[LabeledSubstructure(ranges=[tensor([189, 335])], label=421, element_type='Domain')], [LabeledSubstructure(ranges=[tensor([525, 742]), tensor([1212, 1420])], label=551, element_type='Domain'), LabeledSubstructure(ranges=[tensor([ 95, 175])], label=735, element_type='Domain'), LabeledSubstructure(ranges=[tensor([167, 432]), tensor([ 891, 1134])], label=282, element_type='Domain')]], structure_list=None, prot_mask=None, labels=None, tokenized_seq=tensor([[ 1, 13, 15,  ...,  0,  0,  0],
        [ 1, 13,  3,  ..., 10, 22,  2]]), tokenized_struct=tensor([[   1, 1233, 1233,  ...,    0,    0,    0],
        [   1,  943,  943,  ..., 1060, 1136,    2]]))

## Make model

In [12]:
model_config = ModelConfig(
    model_type="embedding_mlp",
    model_params={
        "hidden_dims": [256, 256],
        "dropout_rate": 0.1,
    },
    frozen_embedder=False,
)
prosst_config = ProSSTConfig(
    weights_path=model_path,
)

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [20]:
embedder = ProSSTEmbedder(prosst_config, frozen=False)

In [21]:
embedder = embedder.to(device)
batch = batch.to(device)

In [22]:
out = embedder.embed_batch(batch)

In [23]:
out

tensor([[[ 2.5074e-02,  2.0554e-02,  2.9464e-03,  ...,  1.3994e-02,
           4.9377e-02, -1.4424e-02],
         [ 1.3344e-02,  7.8973e-03, -5.8249e-03,  ...,  9.2792e-03,
          -2.6190e-02, -1.2555e-03],
         [ 6.4270e-03,  1.0599e-02, -7.7542e-03,  ..., -1.3719e-03,
          -2.4192e-02, -1.6272e-02],
         ...,
         [-4.8919e-03,  5.6617e-03,  1.3259e-03,  ...,  1.7139e-03,
           7.8884e-03,  5.8250e-03],
         [-4.8919e-03,  5.6617e-03,  1.3259e-03,  ...,  1.7139e-03,
           7.8884e-03,  5.8250e-03],
         [-4.8919e-03,  5.6617e-03,  1.3259e-03,  ...,  1.7139e-03,
           7.8884e-03,  5.8250e-03]],

        [[ 4.1353e-02,  2.0393e-02,  2.4705e-02,  ...,  4.9594e-03,
           2.2046e-02,  3.6619e-03],
         [ 1.7102e-02,  4.9413e-03,  1.4392e-03,  ..., -4.4777e-03,
          -3.6532e-02, -1.0120e-02],
         [-2.2535e-03,  1.4751e-03, -1.0967e-02,  ..., -1.0333e-02,
           4.5706e-05, -1.5181e-02],
         ...,
         [ 2.9947e-03, -1

In [24]:
out.sum().backward()

In [25]:
for name, p in embedder.named_parameters():
    if p.requires_grad and p.grad is None:
        print(name)

In [29]:
embedder.model.prosst.encoder.layer[0].attention.self.pos_att_type

['aa2pos', 'pos2aa', 'aa2ss', 'ss2aa']

In [53]:
embedder.model.prosst

ProSSTModel(
  (embeddings): ProSSTEmbeddings(
    (word_embeddings): Embedding(25, 768, padding_idx=0)
    (ss_embeddings): Embedding(2051, 768)
    (ss_layer_norm): ProSSTLayerNorm()
    (LayerNorm): ProSSTLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): ProSSTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ProSSTLayer(
        (attention): ProSSTAttention(
          (self): DisentangledSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (pos_dropout): Dropout(p=0.1, inplace=False)
            (pos_proj): Linear(in_features=768, out_features=768, bias=False)
            (pos_q_proj): Linear(in_features=768, out_features=768, bias=True)
            (ss_proj): Linear(in_features=768, out_features=768, bias=False)
            (ss_q_proj): Linear(in_features=768, out

In [39]:
embedder.model.prosst.encoder.layer[10].attention.output.dense.weight

Parameter containing:
tensor([[-0.0365, -0.0922, -0.0679,  ...,  0.0392,  0.1228,  0.0718],
        [ 0.0143, -0.0194, -0.0404,  ..., -0.0288, -0.0470,  0.0308],
        [-0.0961, -0.0490, -0.0238,  ..., -0.0414, -0.0932,  0.1191],
        ...,
        [-0.0211, -0.0622,  0.0514,  ..., -0.0565,  0.0298,  0.0812],
        [ 0.0134, -0.0049, -0.0624,  ...,  0.0704, -0.0666, -0.0145],
        [ 0.0245, -0.0542, -0.0725,  ..., -0.0272,  0.0313,  0.0244]],
       requires_grad=True)

In [21]:
embedder.model.get_output_embeddings().in_features

768

In [None]:
embedder.model.

In [80]:
embedder = embedder.to(device)

In [49]:
batch = batch.to(device)

In [33]:
batch.tokenized_seqs.shape

torch.Size([2, 778])

In [81]:
out = embedder.embed_batch(batch)
out

torch.Size([2, 1163])
torch.Size([2, 1163])
torch.Size([2, 1163])


tensor([[[ 1.0127,  0.5458, -0.2287,  ..., -0.0958, -0.0664,  0.5374],
         [ 0.2572,  0.0957,  0.1295,  ..., -0.0222,  0.4835, -0.0405],
         [ 0.3038,  0.0396,  0.0052,  ..., -0.1692, -0.0866,  0.4359],
         ...,
         [-0.0923,  0.1068,  0.0250,  ...,  0.0323,  0.1488,  0.1099],
         [-0.0923,  0.1068,  0.0250,  ...,  0.0323,  0.1488,  0.1099],
         [-0.0923,  0.1068,  0.0250,  ...,  0.0323,  0.1488,  0.1099]],

        [[ 0.3875, -0.2288, -0.1128,  ..., -0.1080, -0.0196,  0.1385],
         [ 0.3325,  0.4462,  0.2209,  ..., -0.2696,  0.6936,  0.0857],
         [-0.1536,  0.0761,  0.1590,  ..., -0.3747, -0.2730, -0.0158],
         ...,
         [ 0.2958, -0.5699,  0.3141,  ...,  0.0576,  0.1312,  0.0399],
         [ 0.4738, -0.1931,  0.2375,  ...,  0.0573,  0.0022,  0.3500],
         [ 0.9418, -0.3425,  0.0878,  ..., -0.4024, -0.3270,  0.1994]]],
       device='cuda:0')

In [82]:
out.shape

torch.Size([2, 1163, 768])

In [56]:
len(out.hidden_states)

13

In [58]:
out.hidden_states[12].shape

torch.Size([2, 1163, 768])

In [62]:
import torch.nn.functional as F

In [75]:
got = F.normalize(out.hidden_states[12], dim=-1)

In [76]:

check = got.norm(dim=-1)
check

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0')

In [77]:
check.shape

torch.Size([2, 1163])

In [61]:
torch.norm()

ProSSTModel(
  (embeddings): ProSSTEmbeddings(
    (word_embeddings): Embedding(25, 768, padding_idx=0)
    (ss_embeddings): Embedding(2051, 768)
    (ss_layer_norm): ProSSTLayerNorm()
    (LayerNorm): ProSSTLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): ProSSTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ProSSTLayer(
        (attention): ProSSTAttention(
          (self): DisentangledSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (pos_dropout): Dropout(p=0.1, inplace=False)
            (pos_proj): Linear(in_features=768, out_features=768, bias=False)
            (pos_q_proj): Linear(in_features=768, out_features=768, bias=True)
            (ss_proj): Linear(in_features=768, out_features=768, bias=False)
            (ss_q_proj): Linear(in_features=768, out

In [26]:
embedder.encoder_model.W_v[0].scalar_norm.weight

Parameter containing:
tensor([0.3121, 0.3089, 0.3078, 0.3285, 0.2895, 0.2951, 0.3388, 0.2945, 0.2910,
        0.3344, 0.3081, 0.3292, 0.3346, 0.2586, 0.2828, 0.2632, 0.2735, 0.2689,
        0.3046, 0.3257], device='cuda:0', requires_grad=True)

In [None]:
pipeline_config = PipelineConfig(
    data=data_config,
    model=model_config,
    training=train_config,
    embedding=EmbeddingConfig(
        model="esmc",
        model_params=esmc_config.__dict__,
    ),
)

# Train model
if data_config.collapse_labels:
    mlp = EmbeddingMLP(
        config=pipeline_config,
        num_classes=num_labels,
    )
else:
    mlp = MultitaskEmbeddingMLP(
        config=pipeline_config,
        num_classes=num_labels,
    )


=== Creating esmc embedder ===
EmbeddingConfig(_target_='magneton.config.EmbeddingConfig',
                model='esmc',
                batch_size=32,
                model_params={'mask_prob': 0.15,
                              'max_seq_length': 2048,
                              'model_size': '600m',
                              'rep_layer': 33,
                              'use_flash_attn': True,
                              'weights_path': '/weka/scratch/weka/kellislab/rcalef/model_weights/esmc-600m-2024-12'})
Config parameters: None
