<a href="https://colab.research.google.com/github/tijeco/berteome/blob/14-esm-hugging-face/notebooks/scratch/03_general_approach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.25.1


So, since BERT and ESM are both available through Hugging Face, I'm pretty sure I can now take a more general approach to working with the models.

The first big issue to think about is how to load the various tokenizer and masked-LM classes from the transformers library. The options seem to be:

1. Load all the libraries at once
2. Have each of the libraries be associated with different files that are loaded as needed
3. Have a function that somehow loads the library on command???

The third one sounds impossible, so I'll try that first! Yep, probably impossible. I think just loading BERT and ESM should be fine for now; it's only four things to import.
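For what it's worth, option 3 can be done with the standard library: `importlib.import_module` plus `getattr` fetches a class by name only at the moment it is requested. A minimal sketch, where `load_class` is a hypothetical helper name; the demo uses a stdlib class so it runs without transformers installed.

```python
import importlib


def load_class(module_name, class_name):
    """Import `module_name` on demand and return its attribute `class_name`.

    Hypothetical helper: nothing is imported until this function is called.
    """
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise ImportError(f"{module_name!r} has no attribute {class_name!r}") from err


# Stdlib demo so the sketch runs anywhere
OrderedDict = load_class("collections", "OrderedDict")
print(OrderedDict.__name__)  # OrderedDict
```

With transformers installed, `load_class("transformers", "EsmTokenizer")` should give the same class as `from transformers import EsmTokenizer`. The library also ships `AutoTokenizer` and `AutoModelForMaskedLM`, whose `from_pretrained` picks the matching class from the checkpoint's config, which may make a helper like this unnecessary.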

In [7]:
from transformers import BertTokenizer, BertForMaskedLM, EsmTokenizer, EsmForMaskedLM
import torch
import pandas as pd

In [8]:
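class modelPredDF():
    """Wrap per-position predictions in a DataFrame, one row per residue."""
    def __init__(self, predDict, seq, aas):
        # rows = masked positions (0-based), columns = amino acids in `aas`
        self.predDf = pd.DataFrame.from_dict(predDict, orient="index", columns=list(aas))
        # renormalize so each row's natural-amino-acid probabilities sum to 1
        self.predDf = self.predDf.div(self.predDf.sum(axis=1), axis=0)
        self.predDf.insert(0, "wt", list(seq))
        self.predDf.insert(1, "wtIndex", list(range(1, len(seq) + 1)))  # 1-based position
        self.predDf.insert(2, "wtScore", self.wtScoreCol())

    def wtScoreCol(self):
        # probability assigned to the wild-type residue at each position
        wtScore = []
        for row in self.predDf.to_dict(orient="records"):
            wt = row["wt"]
            wtScore.append(row[wt])
        return wtScore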

In [9]:
def load_model(model_path, tokenizerLM, maskedLM):
  tokenizer = tokenizerLM.from_pretrained(model_path)
  model = maskedLM.from_pretrained(model_path)
  return tokenizer, model

As far as an entry point where users decide which model to use, this seems as good a place as any! As it stands, though, the user has to provide the model_path, the tokenizer class, and the maskedLM class. I don't like that!

I think what we'll eventually do here is put all that info in some sort of data structure, to cut down on the input needed from the user. Off the top of my head, querying just "ESM" or "ESM1b" seems a bit too vague, so using the actual model path as the key seems useful. That way the user only needs to know the path of the model they want (which they probably need to know pretty specifically anyway), and they're off to the races from there! Running multiple models should then be as easy as making new variables with different model paths.
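One way the path-keyed data structure could look: a plain dict from model path to the names of its tokenizer and masked-LM classes, resolved lazily. This is a sketch under assumptions; `MODEL_REGISTRY`, `lookup_classes` and `load_model_by_path` are hypothetical names, and only the two checkpoints used in this notebook are registered.

```python
import importlib

# Hypothetical registry: model path -> (tokenizer class name, masked-LM class name)
MODEL_REGISTRY = {
    "Rostlab/prot_bert": ("BertTokenizer", "BertForMaskedLM"),
    "facebook/esm1b_t33_650M_UR50S": ("EsmTokenizer", "EsmForMaskedLM"),
}


def lookup_classes(model_path):
    """Return the (tokenizer, maskedLM) class names registered for model_path."""
    try:
        return MODEL_REGISTRY[model_path]
    except KeyError:
        known = ", ".join(sorted(MODEL_REGISTRY))
        raise KeyError(f"unknown model path {model_path!r}; known: {known}") from None


def load_model_by_path(model_path):
    """Resolve both classes by name and load them with from_pretrained.

    transformers is imported here, lazily, so the registry has no dependency on it.
    """
    tok_name, lm_name = lookup_classes(model_path)
    transformers = importlib.import_module("transformers")
    tokenizer = getattr(transformers, tok_name).from_pretrained(model_path)
    model = getattr(transformers, lm_name).from_pretrained(model_path)
    return tokenizer, model
```

Usage would then be a single argument, e.g. `bert_tokenizer, bert_model = load_model_by_path("Rostlab/prot_bert")`. Keying on the full path also means a vague query like "ESM" fails loudly with the list of known paths instead of silently picking a checkpoint.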

In [10]:
def run_model(model, inputs):
  with torch.no_grad():
    logits = model(**inputs).logits
  return logits

In [11]:
def logits2prob(logits):
  # logits has shape (batch, seq_len, vocab_size); softmax over the vocabulary axis
  return torch.softmax(logits, dim=2)

In [90]:
def maskifySeq(seq, tokenizer, i):
    # split into residues, swap position i for the tokenizer's mask token, space-join
    seqList = list(seq)
    if i is not None:
        seqList[i] = tokenizer.mask_token
    return " ".join(seqList)

In [89]:
def tokenizeSeq(seq, tokenizer, mask_index = None, return_tensors = "pt"):
  maskified_seq = maskifySeq(seq, tokenizer, mask_index)
  return tokenizer(maskified_seq, return_tensors=return_tensors)

In [95]:
def naturalAAIndex(aas, tokenizer):
    # vocabulary ids of the amino acids in `aas`; [1:-1] drops the start/end special tokens
    return tokenizeSeq(aas, tokenizer, return_tensors=None)["input_ids"][1:-1]

In [87]:
def getNatProbs(natAAList, probList):
    # pick out the probabilities at the natural amino acid vocabulary ids
    return [probList[natAAIndex] for natAAIndex in natAAList]

In [86]:
def predictionDF(seq, tokenizer, model, aas = "ACDEFGHIKLMNPQRSTVWY"):
  naturalAAIndices = naturalAAIndex(aas, tokenizer)
  predDict = {}
  # mask one position at a time; one full forward pass per position
  for wtIndex in range(len(seq)):
    maskedSeq = tokenizeSeq(seq, tokenizer, mask_index = wtIndex)
    seq_logits = run_model(model, maskedSeq)
    seq_probs = logits2prob(seq_logits)
    # +1 skips the start token the tokenizer prepends
    predDict[wtIndex] = [i.item() for i in getNatProbs(naturalAAIndices, seq_probs[0, wtIndex + 1])]
  predDF = modelPredDF(predDict, seq, aas).predDf
  return predDF
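predictionDF normalizes twice: logits2prob applies a softmax over the whole vocabulary, and modelPredDF then rescales the 20 natural-amino-acid probabilities so each row sums to 1. Mathematically that equals a softmax over just those 20 logits, because the full-vocabulary denominator cancels. A stdlib-only check on made-up numbers (the logits and ids below are hypothetical, not model output):

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for a 6-token vocabulary; pretend ids 2..4 are the "natural" ones
logits = [-19.0, -18.5, 0.7, -0.1, 1.0, -5.1]
natural_ids = [2, 3, 4]

# Route used in predictionDF: full softmax, pick the natural ids, renormalize
full = softmax(logits)
picked = [full[i] for i in natural_ids]
renormalized = [p / sum(picked) for p in picked]

# Direct route: softmax over the natural logits only
restricted = softmax([logits[i] for i in natural_ids])

print(all(math.isclose(a, b) for a, b in zip(renormalized, restricted)))  # True
```

In other words, the renormalization in modelPredDF discards whatever probability mass the model put on special or non-standard tokens; whether that is desirable is a modelling choice rather than a bug.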

So a key part of generalizing this is knowing when to use space separation: ProtBert expects residues separated by spaces, which is why maskifySeq joins them with " ". I guess it's still an open question whether ESM can work with space-separated input or not??
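To see exactly what string each tokenizer receives, here is a standalone sketch of the masking step. `maskify` and its `sep` argument are hypothetical: it mirrors maskifySeq but takes the mask token directly instead of a tokenizer. `[MASK]` and `<mask>` are the mask tokens of the BERT and ESM tokenizers respectively; in a real run, read them from `tokenizer.mask_token`.

```python
def maskify(seq, mask_token, i=None, sep=" "):
    """Replace residue i with mask_token and join the residues with sep."""
    residues = list(seq)
    if i is not None:
        residues[i] = mask_token
    return sep.join(residues)


print(maskify("MENDEL", "[MASK]", 3))          # M E N [MASK] E L
print(maskify("MENDEL", "<mask>", 3))          # M E N <mask> E L
print(maskify("MENDEL", "<mask>", 3, sep=""))  # MEN<mask>EL
```

Comparing `predictionDF` outputs for the `sep=" "` and `sep=""` strings is one way to settle the ESM question empirically.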

In [21]:
# NB: despite the esm2_ prefix, this loads the ESM-1b checkpoint
esm2_tokenizer, esm2_model = load_model("facebook/esm1b_t33_650M_UR50S", EsmTokenizer, EsmForMaskedLM)

Some weights of the model checkpoint at facebook/esm1b_t33_650M_UR50S were not used when initializing EsmForMaskedLM: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [24]:
bert_tokenizer, bert_model = load_model("Rostlab/prot_bert",BertTokenizer,BertForMaskedLM)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [93]:
tokenizeSeq("MENDEL", bert_tokenizer, mask_index=3)

{'input_ids': tensor([[ 2, 21,  9, 17,  4,  9,  5,  3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
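This output also explains the `wtIndex + 1` in predictionDF and the `[1:-1]` in naturalAAIndex: the tokenizer wraps the residues in one start token and one end token. A quick check using only the ids printed above; the special-token ids (2, 3, 4 for `[CLS]`, `[SEP]`, `[MASK]`) are read off ProtBert's vocabulary and are an assumption for any other checkpoint.

```python
# input_ids printed above for tokenizeSeq("MENDEL", bert_tokenizer, mask_index=3)
input_ids = [2, 21, 9, 17, 4, 9, 5, 3]
MASK_ID = 4  # ProtBert's [MASK] id (2 is [CLS], 3 is [SEP])

seq = "MENDEL"
mask_index = 3  # 0-based position in the residue string

# One special token is prepended, so residue i sits at token position i + 1
token_pos = input_ids.index(MASK_ID)
print(token_pos, token_pos == mask_index + 1)  # 4 True
print(len(input_ids) - len(seq))  # 2 (one start token, one end token)
```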

In [28]:
mendel_mask3 = tokenizeSeq("MENDEL", bert_tokenizer, mask_index=3)
run_model(bert_model, mendel_mask3)

tensor([[[-1.9025e+01, -1.9418e+01, -1.9070e+01, -2.0259e+01, -2.1078e+01,
           6.8966e-01, -5.8320e-02,  9.9683e-02,  1.1744e-01,  9.2746e-01,
           1.7481e-01,  2.6523e-01,  1.0348e+00,  3.9002e-01, -1.5987e-01,
           9.6185e-02, -7.2865e-01,  2.0590e-01,  1.6183e-01, -6.3175e-01,
          -8.3162e-01,  2.3495e-01, -8.9098e-01, -1.4568e+00, -1.2295e+00,
          -5.1355e+00, -1.8929e+01, -1.8691e+01, -1.8911e+01, -1.9173e+01],
         [-1.8777e+01, -1.9759e+01, -1.9174e+01, -1.7491e+01, -2.1409e+01,
           4.5803e-01, -1.3080e-01, -4.1512e-01, -8.9476e-02,  4.4306e-01,
          -2.1111e-01, -2.7136e-01,  1.9480e-01, -3.5983e-01, -4.5520e-01,
          -4.6519e-01, -1.1010e+00, -3.5952e-01, -4.2683e-01, -8.3092e-01,
          -1.0543e+00,  3.9943e+00, -1.4092e+00, -1.9026e+00, -1.7119e+00,
          -3.4918e+00, -1.7906e+01, -1.8299e+01, -1.9515e+01, -1.9495e+01],
         [-2.1036e+01, -2.1122e+01, -2.1379e+01, -1.9386e+01, -2.2437e+01,
           3.5269e-01, 

In [96]:
naturalAAIndex("ACDEFGHIKLMNPQRSTVWY",bert_tokenizer)

[6, 23, 14, 9, 19, 7, 22, 11, 12, 5, 21, 17, 16, 18, 13, 10, 15, 8, 24, 20]

In [97]:
naturalAAIndex("ACDEFGHIKLMNPQRSTVWY",esm2_tokenizer)

[5, 23, 13, 9, 18, 6, 21, 12, 15, 4, 20, 17, 14, 16, 10, 8, 11, 7, 22, 19]
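Both lists are in the same order as the query string, so zipping each with it gives a letter-to-id lookup per tokenizer. This small check, built only from the two outputs above, makes the point of naturalAAIndex explicit: most residues map to different vocabulary ids in ProtBert and ESM-1b, so the ids must always come from the tokenizer paired with the model.

```python
AAS = "ACDEFGHIKLMNPQRSTVWY"

# Token ids copied from the two naturalAAIndex(...) outputs above
bert_ids = [6, 23, 14, 9, 19, 7, 22, 11, 12, 5, 21, 17, 16, 18, 13, 10, 15, 8, 24, 20]
esm_ids = [5, 23, 13, 9, 18, 6, 21, 12, 15, 4, 20, 17, 14, 16, 10, 8, 11, 7, 22, 19]

bert_vocab = dict(zip(AAS, bert_ids))
esm_vocab = dict(zip(AAS, esm_ids))

# Residues whose ids happen to coincide in both vocabularies
shared = [aa for aa in AAS if bert_vocab[aa] == esm_vocab[aa]]
print(bert_vocab["L"], esm_vocab["L"])  # 5 4
print(shared)  # ['C', 'E', 'N']
```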

In [98]:
predictionDF("MENDEL", bert_tokenizer, bert_model)

,wt,wtIndex,wtScore,A,C,D,E,F,G,H,...,M,N,P,Q,R,S,T,V,W,Y
0,M,1,0.076602,0.036697,0.011504,0.048245,0.118906,0.024072,0.039202,0.012621,...,0.076602,0.072661,0.024722,0.038672,0.043105,0.07028,0.056544,0.049927,0.007781,0.021699
1,E,2,0.07483,0.045721,0.015662,0.041921,0.07483,0.037153,0.044325,0.018264,...,0.043581,0.062667,0.025277,0.036911,0.055543,0.064425,0.049955,0.056789,0.012691,0.029893
2,N,3,0.04199,0.043564,0.009685,0.16259,0.184364,0.033782,0.044661,0.012355,...,0.041484,0.04199,0.019992,0.025515,0.029433,0.048106,0.030303,0.054742,0.00743,0.024924
3,D,4,0.049748,0.042083,0.013244,0.049748,0.086194,0.039736,0.055911,0.016861,...,0.04008,0.060822,0.032024,0.039689,0.046228,0.062323,0.044901,0.058937,0.010875,0.026596
4,E,5,0.086915,0.046641,0.01877,0.079822,0.086915,0.050638,0.050466,0.022397,...,0.028962,0.062234,0.023879,0.030534,0.040489,0.065195,0.044938,0.068038,0.012156,0.038034
5,L,6,0.060736,0.038191,0.009217,0.065189,0.152547,0.02095,0.049525,0.013955,...,0.040042,0.096484,0.020712,0.035022,0.046888,0.049071,0.046247,0.048276,0.010486,0.022727


In [100]:
predictionDF("MENDEL", esm2_tokenizer, esm2_model)

,wt,wtIndex,wtScore,A,C,D,E,F,G,H,...,M,N,P,Q,R,S,T,V,W,Y
0,M,1,0.464699,0.034593,0.007172,0.055771,0.064563,0.018942,0.029819,0.010409,...,0.464699,0.029717,0.023785,0.019403,0.024049,0.03074,0.023793,0.034624,0.004882,0.013323
1,E,2,0.072099,0.059034,0.021043,0.054654,0.072099,0.037208,0.053121,0.025863,...,0.024344,0.059745,0.035749,0.040393,0.052331,0.073661,0.057352,0.063012,0.014037,0.030357
2,N,3,0.044189,0.055648,0.014066,0.077918,0.122202,0.034741,0.058845,0.019178,...,0.03133,0.044189,0.02968,0.034349,0.052393,0.057417,0.047451,0.06946,0.013287,0.025869
3,D,4,0.036312,0.044627,0.017634,0.036312,0.071052,0.031641,0.0498,0.023527,...,0.058537,0.043779,0.037241,0.054184,0.05514,0.060113,0.052931,0.070151,0.017574,0.027221
4,E,5,0.0576,0.045986,0.027939,0.047452,0.0576,0.052473,0.054569,0.030684,...,0.025428,0.051129,0.033465,0.038958,0.055219,0.076457,0.050082,0.063579,0.017671,0.038108
5,L,6,0.072376,0.048256,0.016242,0.060321,0.100453,0.031553,0.052616,0.02303,...,0.029372,0.063879,0.025561,0.039337,0.062976,0.064187,0.053823,0.060214,0.013722,0.028164


Interestingly (or annoyingly), ESM seems to work regardless of the space separator?? So I guess I'll just default to using spaces. I was considering scrapping the delimiter option altogether and hard-coding " " in one place, but I might as well make it a parameter that defaults to sep = " ", so that if something crazy happens and the separator needs to be configured, I can! Nah, I just hard-coded the " ", so much simpler for now.