In [1]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 5.5 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 58.2 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 53.0 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.25.1


So, I'm pretty sure since bert and esm are available through hugging face, I can now make a somewhat more generalizable approach to working with the models. 

The first big issue to think about is loading the various tokenizers and masked-LM classes from the transformers library. The options seem to be as follows:

1. Load all the libraries at once
2. Have each of the libraries be associated with different files that are loaded as needed
3. Have a function that somehow loads the library on command???

The third one sounds impossible, I'll try that first! Yep probably impossible. I think just loading bert and ESM should be fine for now?? It's only four things to import..

In [2]:
from transformers import BertTokenizer, BertForMaskedLM, EsmTokenizer, EsmForMaskedLM
import torch
import pandas as pd
import numpy as np

I think I should figure out a way to plug `predictionDF()` as predictionDict or something here?? Maybe?? Worth thinking about I suppose, maybe that will be a separate refactoring..

In [14]:
class modelPredDF():
    """Per-position masked-LM amino-acid predictions for one sequence.

    Each position of `seq` is masked in turn, the model is run, and the
    probabilities of the 20 amino acids in `self.aas` are collected and
    renormalised so every row sums to 1.

    Attributes set by the constructor:
        aas: the amino-acid alphabet, one column per letter in `df`.
        df: DataFrame with columns `wt`, `wtIndex` (1-based), `wtScore`,
            `n_effective`, `topAA`, `topAAscore`, followed by one
            probability column per amino acid.
        wtSeq / wtSeqScore: the input sequence and its mean per-position
            probability.
        topAASeq / topAASeqScore: the sequence of most probable residues and
            its mean per-position probability.
    """

    def __init__(self, seq, tokenizer, model):
        self.aas = "ACDEFGHIKLMNPQRSTVWY"
        predDict = self.predictionDict(seq, tokenizer, model)
        self.df = pd.DataFrame.from_dict(predDict, orient="index", columns=list(self.aas))
        # Renormalise over the 20 amino acids only, so each row sums to 1.
        self.df = self.df.div(self.df.sum(axis=1), axis=0)
        self.df.insert(0, "wt", list(seq))
        self.df.insert(1, "wtIndex", list(range(1, len(seq) + 1)))
        self.df.insert(2, "wtScore", self.scoreCol("wt"))
        self.df.insert(3, "n_effective", self.n_effective())
        self.df.insert(4, "topAA", self.topAA())
        self.df.insert(5, "topAAscore", self.scoreCol("topAA"))

        self.wtSeq = ''.join(list(self.df["wt"]))
        self.wtSeqScore = self.scoreSeq(self.wtSeq)

        self.topAASeq = ''.join(list(self.df["topAA"]))
        self.topAASeqScore = self.scoreSeq(self.topAASeq)

    def predictionDict(self, seq, tokenizer, model):
        """Return {position: [probability per amino acid in `self.aas`]}.

        One forward pass is made per position, with that position masked.
        """
        naturalAAIndices = naturalAAIndex(self.aas, tokenizer)
        predDict = {}
        for wtIndex in range(len(seq)):
            maskedSeq = tokenizeSeq(seq, tokenizer, mask_index=wtIndex)
            seq_logits = run_model(model, maskedSeq)
            seq_probs = logits2prob(seq_logits)
            # +1 skips the token the tokenizer places before the sequence
            # (presumably the CLS/start token -- confirm per tokenizer).
            masked_position_probs = seq_probs[0, wtIndex + 1]
            predDict[wtIndex] = [
                prob.item() for prob in getNatProbs(naturalAAIndices, masked_position_probs)
            ]
        return predDict

    def scoreCol(self, col):
        """For each row, return the probability of the residue named in `col`."""
        return [row[row[col]] for row in self.df.to_dict(orient="records")]

    def scoreSeq(self, seq):
        """Return the mean probability of `seq`'s residues, position by position.

        Raises:
            ValueError: if `seq` does not have one residue per row of `df`.
            KeyError: if `seq` contains a residue that is not in `self.aas`.
        """
        if len(seq) != len(self.df):
            raise ValueError(f"The provided sequence is of length {len(seq)}, but berteome expected {len(self.df)}")
        # Look residues up by row position rather than by index label, so the
        # result does not depend on the frame keeping its default 0..n-1 index.
        aa_probs = self.df[list(self.aas)].to_numpy()
        column_of = {aa: column for column, aa in enumerate(self.aas)}
        seqScore = 0
        for position, residue in enumerate(seq):
            if residue not in column_of:
                raise KeyError(f"Residue {residue!r} at position {position} is not one of {self.aas}")
            seqScore += aa_probs[position, column_of[residue]]
        return seqScore / len(self.df)

    def n_effective(self):
        """Return exp(Shannon entropy) of each row's amino-acid distribution."""
        df_aas = self.df[list(self.aas)]
        # 0 * log(0) evaluates to NaN in floating point, but its limit is 0.
        # Substituting 1 inside the log makes those terms contribute exactly 0.
        log_safe = df_aas.where(df_aas > 0, 1.0)
        entropy = -(np.log(log_safe) * df_aas)
        return np.exp(entropy.sum(axis=1))

    def topAA(self):
        """Return the most probable amino acid for each row."""
        return self.df[list(self.aas)].idxmax(axis=1)

    def aa_correlation(self):
        """Return the pairwise correlation matrix of the amino-acid columns."""
        return self.df[list(self.aas)].corr()

In [4]:
class modelLoader():
  """Registry of supported model paths and the classes used to load them.

  Attributes:
    supported_model_dict: maps each model path to a dict with a
      "tokenizer" class and a "model" class.
    supported_models: list of the supported model paths.
  """

  def __init__(self):
    self.supported_model_dict = {
        "Rostlab/prot_bert": self.token_model_dict("prot_bert"),
        "facebook/esm2_t33_650M_UR50D": self.token_model_dict("esm"),
        "facebook/esm1b_t33_650M_UR50S": self.token_model_dict("esm"),
    }
    self.supported_models = list(self.supported_model_dict.keys())

  def token_model_dict(self, model_name):
    """Return the tokenizer and masked-LM classes for a model family.

    Args:
      model_name: "prot_bert" or "esm".

    Raises:
      ValueError: if `model_name` is not a known family.
    """
    if model_name == "prot_bert":
      return {"tokenizer": BertTokenizer, "model": BertForMaskedLM}
    if model_name == "esm":
      return {"tokenizer": EsmTokenizer, "model": EsmForMaskedLM}
    # Previously an unknown name fell through to an UnboundLocalError.
    raise ValueError(f"Unknown model family {model_name!r}; expected 'prot_bert' or 'esm'")

  def load_model(self, model_path):
    """Load and return `(tokenizer, model)` for a supported model path.

    Raises:
      KeyError: if `model_path` is not in `supported_model_dict`.
    """
    if model_path not in self.supported_model_dict:
      raise KeyError(
          f"Unsupported model {model_path!r}; supported models: {list(self.supported_model_dict)}"
      )
    entry = self.supported_model_dict[model_path]
    tokenizer = entry["tokenizer"].from_pretrained(model_path)
    model = entry["model"].from_pretrained(model_path)
    return tokenizer, model

In [None]:
modelLoader().supported_models

['Rostlab/prot_bert',
 'facebook/esm2_t33_650M_UR50D',
 'facebook/esm1b_t33_650M_UR50S']

In [15]:
berteome_models = modelLoader()

In [16]:
berteome_models.supported_models

['Rostlab/prot_bert',
 'facebook/esm2_t33_650M_UR50D',
 'facebook/esm1b_t33_650M_UR50S']

In [17]:
bert_tokenizer, bert_model = berteome_models.load_model("Rostlab/prot_bert")

Downloading:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/361 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.68G [00:00<?, ?B/s]

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
mendel_berteome = modelPredDF("MENDEL", bert_tokenizer, bert_model)

In [21]:
mendel_berteome.wtSeq

'MENDEL'

In [22]:
mendel_berteome.wtSeqScore

0.06513695385878104

In [23]:
mendel_berteome.topAASeq

'ELELLE'

In [24]:
mendel_berteome.topAASeqScore

0.127035315825644

In [28]:
mendel_berteome.scoreSeq("MMMMMM")

0.04512516515453344

In [19]:
mendel_berteome.df

Unnamed: 0,wt,wtIndex,wtScore,n_effective,topAA,topAAscore,A,C,D,E,...,M,N,P,Q,R,S,T,V,W,Y
0,M,1,0.076602,16.680519,E,0.118906,0.036697,0.011504,0.048245,0.118906,...,0.076602,0.072661,0.024722,0.038672,0.043105,0.07028,0.056544,0.049927,0.007781,0.021699
1,E,2,0.07483,17.599154,L,0.106501,0.045721,0.015662,0.041921,0.07483,...,0.043581,0.062667,0.025277,0.036911,0.055543,0.064425,0.049955,0.056789,0.012691,0.029893
2,N,3,0.04199,14.518531,E,0.184364,0.043564,0.009685,0.16259,0.184364,...,0.041484,0.04199,0.019992,0.025515,0.029433,0.048106,0.030303,0.054742,0.00743,0.024924
3,D,4,0.049748,17.561047,L,0.109088,0.042083,0.013244,0.049748,0.086194,...,0.04008,0.060822,0.032024,0.039689,0.046228,0.062323,0.044901,0.058937,0.010875,0.026596
4,E,5,0.086915,17.921406,L,0.090807,0.046641,0.01877,0.079822,0.086915,...,0.028962,0.062234,0.023879,0.030534,0.040489,0.065195,0.044938,0.068038,0.012156,0.038034
5,L,6,0.060736,16.068075,E,0.152547,0.038191,0.009217,0.065189,0.152547,...,0.040042,0.096484,0.020712,0.035022,0.046888,0.049071,0.046247,0.048276,0.010486,0.022727


In [20]:
mendel_aa_correlations = mendel_berteome.aa_correlation()

In [None]:
mendel_aa_correlations

Unnamed: 0,A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y
A,1.0,0.728715,0.23581,-0.38988,0.879478,0.295939,0.745629,0.281994,-0.521591,0.733512,-0.720194,-0.611639,0.079973,-0.433475,-0.010752,0.051076,-0.411044,0.833235,0.585926,0.854028
C,0.728715,1.0,-0.335086,-0.816555,0.854112,0.23124,0.948531,0.774243,-0.042334,0.46636,-0.382031,-0.235096,0.369489,0.063834,0.313217,0.63868,0.247711,0.876376,0.736407,0.923179
D,0.23581,-0.335086,1.0,0.76598,0.084237,-0.105943,-0.311785,-0.822663,-0.909457,0.087421,-0.275042,-0.581996,-0.599214,-0.924922,-0.89091,-0.671449,-0.903984,0.053589,-0.545103,-0.021774
E,-0.38988,-0.816555,0.76598,1.0,-0.555584,-0.275365,-0.756599,-0.960437,-0.445062,-0.449607,0.09659,-0.027763,-0.732526,-0.612387,-0.710517,-0.797275,-0.600745,-0.555534,-0.767346,-0.570185
F,0.879478,0.854112,0.084237,-0.555584,1.0,0.456554,0.850721,0.485917,-0.477467,0.699526,-0.622552,-0.579098,0.359107,-0.254099,-0.072739,0.316781,-0.244826,0.988906,0.546931,0.916871
G,0.295939,0.23124,-0.105943,-0.275365,0.456554,1.0,0.469717,0.397913,-0.077729,0.311335,-0.730916,0.058536,0.495873,0.101611,0.103227,-0.197846,-0.268709,0.464575,0.501189,0.351613
H,0.745629,0.948531,-0.311785,-0.756599,0.850721,0.469717,1.0,0.780563,-0.042422,0.403466,-0.613977,-0.096189,0.33173,0.020781,0.334186,0.428619,0.133945,0.884543,0.852824,0.949147
I,0.281994,0.774243,-0.822663,-0.960437,0.485917,0.397913,0.780563,1.0,0.529266,0.250584,-0.168636,0.251695,0.680964,0.638904,0.73224,0.718683,0.641502,0.519188,0.816,0.560873
K,-0.521591,-0.042334,-0.909457,-0.445062,-0.477467,-0.077729,-0.042422,0.529266,1.0,-0.363205,0.430718,0.773594,0.335643,0.889435,0.850884,0.411444,0.87226,-0.447166,0.317516,-0.325412
L,0.733512,0.46636,0.087421,-0.449607,0.699526,0.311335,0.403466,0.250584,-0.363205,1.0,-0.36075,-0.779562,0.554163,-0.037801,0.062683,0.196178,-0.320043,0.588138,0.326964,0.436263


In [5]:
def token_model_dict(model_name):
  """Return the tokenizer and masked-LM classes for a model family.

  Args:
    model_name: "prot_bert" or "esm".

  Returns:
    dict with keys "tokenizer" and "model" holding the classes to load with.

  Raises:
    ValueError: if `model_name` is not a known family.
  """
  if model_name == "prot_bert":
    return {"tokenizer": BertTokenizer, "model": BertForMaskedLM}
  if model_name == "esm":
    return {"tokenizer": EsmTokenizer, "model": EsmForMaskedLM}
  # Previously an unknown name fell through to an UnboundLocalError.
  raise ValueError(f"Unknown model family {model_name!r}; expected 'prot_bert' or 'esm'")

In [6]:
def load_model(model_path, tokenizerLM=None, maskedLM=None):
  """Load and return `(tokenizer, model)` for `model_path`.

  Args:
    model_path: path passed to each class's `from_pretrained`.
    tokenizerLM: optional tokenizer class; when omitted it is looked up from
      the table of supported models.
    maskedLM: optional masked-LM class; when omitted it is looked up from the
      table of supported models.

  Raises:
    KeyError: if a class has to be looked up and `model_path` is not one of
      the supported models.
  """
  # Only consult the table when the caller left at least one class out, so a
  # fully explicit call also works for paths that are not listed here.
  if tokenizerLM is None or maskedLM is None:
    supported_models = {
        "Rostlab/prot_bert": token_model_dict("prot_bert"),
        "facebook/esm2_t33_650M_UR50D": token_model_dict("esm"),
        "facebook/esm1b_t33_650M_UR50S": token_model_dict("esm"),
    }
    if model_path not in supported_models:
      raise KeyError(
          f"Unsupported model {model_path!r}; supported models: {list(supported_models)}"
      )
    entry = supported_models[model_path]
    if tokenizerLM is None:
      tokenizerLM = entry["tokenizer"]
    if maskedLM is None:
      maskedLM = entry["model"]
  tokenizer = tokenizerLM.from_pretrained(model_path)
  model = maskedLM.from_pretrained(model_path)
  return tokenizer, model

So as far as a good entrypoint for where users make decisions as to what model they use, this seems like as good of a place as ever! So as it stands, the user would have to provide the model_path, tokenizer name and maskedLM name.. I don't like that! 

I think what we'll eventually do here is put all of that info in some sort of data structure, to reduce the amount of input needed from the user. Off the top of my head, just querying "ESM" or "ESM1b" seems a bit too vague. I think using the actual model paths as the keys would be useful: that way the user only needs to know the path of the model they want to use (it's probably important to know that fairly precisely anyway), and they should be off to the races from there! Of course, running multiple different models should be just as easy as making new variables with updated model paths!

In [7]:
def run_model(model, inputs):
  """Run one forward pass with gradient tracking disabled; return the logits.

  `inputs` is unpacked as keyword arguments to `model`, and the `.logits`
  attribute of the returned object is handed back.
  """
  with torch.no_grad():
    outputs = model(**inputs)
  return outputs.logits

In [8]:
def logits2prob(logits):
  """Turn logits into probabilities with a softmax along dimension 2.

  NOTE(review): dimension 2 is presumably the vocabulary axis of
  (batch, position, vocab) logits -- confirm against the model output.
  """
  probabilities = logits.softmax(dim=2)
  return probabilities

In [9]:
def maskifySeq(seq, tokenizer, i):
    """Return `seq` as space-separated residues, optionally masking position `i`.

    Args:
        seq: the sequence to split into single characters.
        tokenizer: object whose `mask_token` replaces the masked residue.
        i: index of the residue to mask, or None to leave the sequence as is.
    """
    seqList = list(seq)
    # Identity check: position 0 is a valid index and must still be masked.
    if i is not None:
        seqList[i] = tokenizer.mask_token
    return " ".join(seqList)

In [10]:
def tokenizeSeq(seq, tokenizer, mask_index = None, return_tensors = "pt"):
  """Tokenize `seq` after spacing out its residues and applying any mask.

  The sequence is first passed through `maskifySeq` with `mask_index`; the
  resulting string is handed to `tokenizer` together with `return_tensors`.
  """
  spaced_residues = maskifySeq(seq, tokenizer, mask_index)
  encoding = tokenizer(spaced_residues, return_tensors=return_tensors)
  return encoding

In [11]:
def naturalAAIndex(aas, tokenizer):
    """Return the tokenizer's input ids for each amino acid in `aas`.

    The first and last ids are dropped; presumably these are the special
    tokens the tokenizer adds around the sequence -- confirm per tokenizer.
    """
    encoded = tokenizeSeq(aas, tokenizer, return_tensors=None)
    token_ids = encoded["input_ids"]
    return token_ids[1:-1]

In [12]:
def getNatProbs(natAAList,probList):
    """Pick out the entries of `probList` at the indices in `natAAList`.

    Returns a list in the same order as `natAAList`.
    """
    return [probList[token_id] for token_id in natAAList]

In [13]:
def predictionDF(seq, tokenizer, model, aas = "ACDEFGHIKLMNPQRSTVWY"):
  """Return a DataFrame of per-position amino-acid probabilities for `seq`.

  Each position is masked in turn and the model's probabilities for the
  residues in `aas` are collected and renormalised so every row sums to 1.

  The frame is built here directly. The previous body delegated to
  `modelPredDF(predDict, seq, aas).predDf`, a constructor signature that the
  class no longer accepts.

  Args:
    seq: sequence to score; every residue must appear in `aas`.
    tokenizer: tokenizer used to encode the masked sequence.
    model: masked-LM model that produces the logits.
    aas: alphabet of residues to report, one column per letter.

  Returns:
    DataFrame with columns `wt`, `wtIndex` (1-based), `wtScore`, then one
    probability column per residue in `aas`.
  """
  naturalAAIndices = naturalAAIndex(aas, tokenizer)
  predDict = {}
  for wtIndex in range(len(seq)):
    maskedSeq = tokenizeSeq(seq, tokenizer, mask_index=wtIndex)
    seq_probs = logits2prob(run_model(model, maskedSeq))
    # +1 skips the token the tokenizer places before the sequence
    # (presumably the CLS/start token -- confirm per tokenizer).
    masked_position_probs = seq_probs[0, wtIndex + 1]
    predDict[wtIndex] = [
        prob.item() for prob in getNatProbs(naturalAAIndices, masked_position_probs)
    ]
  predDF = pd.DataFrame.from_dict(predDict, orient="index", columns=list(aas))
  predDF = predDF.div(predDF.sum(axis=1), axis=0)
  predDF.insert(0, "wt", list(seq))
  predDF.insert(1, "wtIndex", list(range(1, len(seq) + 1)))
  wtScore = [row[row["wt"]] for row in predDF.to_dict(orient="records")]
  predDF.insert(2, "wtScore", wtScore)
  return predDF

So a key part to generalizing this is knowing when to use space separation. I guess, it's still a question if esm can work with it being space separated or not??

In [None]:
esm2_tokenizer, esm2_model = load_model("facebook/esm2_t33_650M_UR50D")

Some weights of the model checkpoint at facebook/esm2_t33_650M_UR50D were not used when initializing EsmForMaskedLM: ['esm.contact_head.regression.weight', 'esm.contact_head.regression.bias']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# load_model looks up the tokenizer and model classes from the model path,
# so only the path is passed.
bert_tokenizer, bert_model = load_model("Rostlab/prot_bert")

Downloading:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/361 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.68G [00:00<?, ?B/s]

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
predictionDF("MENDEL", bert_tokenizer, bert_model)

Unnamed: 0,wt,wtIndex,wtScore,A,C,D,E,F,G,H,...,M,N,P,Q,R,S,T,V,W,Y
0,M,1,0.076602,0.036697,0.011504,0.048245,0.118906,0.024072,0.039202,0.012621,...,0.076602,0.072661,0.024722,0.038672,0.043105,0.07028,0.056544,0.049927,0.007781,0.021699
1,E,2,0.07483,0.045721,0.015662,0.041921,0.07483,0.037153,0.044325,0.018264,...,0.043581,0.062667,0.025277,0.036911,0.055543,0.064425,0.049955,0.056789,0.012691,0.029893
2,N,3,0.04199,0.043564,0.009685,0.16259,0.184364,0.033782,0.044661,0.012355,...,0.041484,0.04199,0.019992,0.025515,0.029433,0.048106,0.030303,0.054742,0.00743,0.024924
3,D,4,0.049748,0.042083,0.013244,0.049748,0.086194,0.039736,0.055911,0.016861,...,0.04008,0.060822,0.032024,0.039689,0.046228,0.062323,0.044901,0.058937,0.010875,0.026596
4,E,5,0.086915,0.046641,0.01877,0.079822,0.086915,0.050638,0.050466,0.022397,...,0.028962,0.062234,0.023879,0.030534,0.040489,0.065195,0.044938,0.068038,0.012156,0.038034
5,L,6,0.060736,0.038191,0.009217,0.065189,0.152547,0.02095,0.049525,0.013955,...,0.040042,0.096484,0.020712,0.035022,0.046888,0.049071,0.046247,0.048276,0.010486,0.022727


In [None]:
predictionDF("MENDEL", esm2_tokenizer, esm2_model)

Unnamed: 0,wt,wtIndex,wtScore,A,C,D,E,F,G,H,...,M,N,P,Q,R,S,T,V,W,Y
0,M,1,0.464699,0.034593,0.007172,0.055771,0.064563,0.018942,0.029819,0.010409,...,0.464699,0.029717,0.023785,0.019403,0.024049,0.03074,0.023793,0.034624,0.004882,0.013323
1,E,2,0.072099,0.059034,0.021043,0.054654,0.072099,0.037208,0.053121,0.025863,...,0.024344,0.059745,0.035749,0.040393,0.052331,0.073661,0.057352,0.063012,0.014037,0.030357
2,N,3,0.044189,0.055648,0.014066,0.077918,0.122202,0.034741,0.058845,0.019178,...,0.03133,0.044189,0.02968,0.034349,0.052393,0.057417,0.047451,0.06946,0.013287,0.025869
3,D,4,0.036312,0.044627,0.017634,0.036312,0.071052,0.031641,0.0498,0.023527,...,0.058537,0.043779,0.037241,0.054184,0.05514,0.060113,0.052931,0.070151,0.017574,0.027221
4,E,5,0.0576,0.045986,0.027939,0.047452,0.0576,0.052473,0.054569,0.030684,...,0.025428,0.051129,0.033465,0.038958,0.055219,0.076457,0.050082,0.063579,0.017671,0.038108
5,L,6,0.072376,0.048256,0.016242,0.060321,0.100453,0.031553,0.052616,0.02303,...,0.029372,0.063879,0.025561,0.039337,0.062976,0.064187,0.053823,0.060214,0.013722,0.028164


Interestingly / annoyingly, esm seems to work regardless of the space separator?? So I guess I'll just default to having the space. I was considering dropping the delimiter altogether and just hard-coding " " in one place, but I might as well make it a default, sep = " "; then if something crazy happens and the sep needs to be configured, I can! Nah, I just hardcoded the " ", so much simpler for now..

In [None]:
class modelPredDF():
    """Per-position masked-LM amino-acid predictions, exposed as `predDf`.

    `predDf` holds the columns `wt`, `wtIndex` (1-based) and `wtScore`,
    followed by one renormalised probability column per letter of `aas`.
    """

    def __init__(self, seq, tokenizer, model):
        self.aas = "ACDEFGHIKLMNPQRSTVWY"
        raw_probs = self.predictionDict(seq, tokenizer, model)
        aa_columns = list(self.aas)
        self.predDf = pd.DataFrame.from_dict(raw_probs, orient="index", columns=aa_columns)
        # Renormalise over the 20 amino acids only, so each row sums to 1.
        row_totals = self.predDf.sum(axis=1)
        self.predDf = self.predDf.div(row_totals, axis=0)
        self.predDf.insert(0, "wt", list(seq))
        positions = list(range(1, len(seq) + 1))
        self.predDf.insert(1, "wtIndex", positions)
        self.predDf.insert(2, "wtScore", self.wtScoreCol())

    def predictionDict(self, seq, tokenizer, model):
        """Return {position: [probability per amino acid in `self.aas`]}."""
        aa_token_ids = naturalAAIndex(self.aas, tokenizer)
        per_position = {}
        for position in range(len(seq)):
            encoded = tokenizeSeq(seq, tokenizer, mask_index=position)
            probabilities = logits2prob(run_model(model, encoded))
            masked_row = probabilities[0, position + 1]
            per_position[position] = [
                value.item() for value in getNatProbs(aa_token_ids, masked_row)
            ]
        return per_position

    def wtScoreCol(self):
        """For each row, return the probability assigned to its `wt` residue."""
        records = self.predDf.to_dict(orient="records")
        return [record[record["wt"]] for record in records]

In [None]:
modelPredDF("MENDEL", esm2_tokenizer, esm2_model)

<__main__.modelPredDF at 0x7fd3d3b7d400>

In [None]:
esm2_berteome = modelPredDF("MENDEL", esm2_tokenizer, esm2_model)

In [None]:
esm2_berteome.predDf

Unnamed: 0,wt,wtIndex,wtScore,A,C,D,E,F,G,H,...,M,N,P,Q,R,S,T,V,W,Y
0,M,1,0.519067,0.03211,0.006379,0.040362,0.058591,0.018163,0.032379,0.008097,...,0.519067,0.023694,0.018577,0.016761,0.020366,0.027675,0.021765,0.039665,0.005265,0.014724
1,E,2,0.082258,0.06206,0.018147,0.047721,0.082258,0.036064,0.048342,0.01872,...,0.031536,0.054447,0.030269,0.034311,0.053452,0.072699,0.063905,0.072899,0.011034,0.026293
2,N,3,0.046711,0.053596,0.021541,0.056509,0.091755,0.039491,0.060947,0.020384,...,0.028182,0.046711,0.03186,0.033307,0.068211,0.074992,0.055649,0.061182,0.014518,0.024133
3,D,4,0.031955,0.039806,0.021652,0.031955,0.0577,0.046293,0.051755,0.022893,...,0.033664,0.039776,0.053131,0.049562,0.054575,0.083303,0.051578,0.055736,0.016449,0.030357
4,E,5,0.054317,0.046637,0.027707,0.043781,0.054317,0.051995,0.055344,0.030827,...,0.020944,0.045684,0.042034,0.044497,0.067907,0.077613,0.049258,0.058553,0.016128,0.03449
5,L,6,0.071026,0.054715,0.016047,0.054189,0.108328,0.030832,0.062917,0.019348,...,0.02177,0.06197,0.029235,0.045546,0.071085,0.058554,0.048158,0.053547,0.010247,0.021742


In [None]:
esm2_berteome.aas

'ACDEFGHIKLMNPQRSTVWY'