In [None]:
# default_exp berteome


In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
from transformers import BertForMaskedLM, BertTokenizer, pipeline
import pandas as pd

In [None]:
 #export
# Load the ProtBERT tokenizer; do_lower_case=False so residue letters are not lower-cased.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
# Load the same checkpoint with a masked-language-model head.
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
# Module-level fill-mask pipeline; allResiduePredictions calls this for every masked position.
unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
#export
def spacifySeq(seq):
  """Return the residues of `seq` separated by single spaces.

  Leading and trailing whitespace is stripped from the result, so
  "MENDEL" becomes "M E N D E L".
  """
  spaced = " ".join(seq)
  return spaced.strip()

In [None]:
# "MENDEL" should come back as its residues separated by single spaces.
assert spacifySeq("MENDEL") == "M E N D E L"

In [None]:
#export
def maskifySeq(seq, pos, mask="[MASK]"):
  """Replace the token at index `pos` of a space-separated sequence with `mask`.

  `seq` is split on whitespace, the token at `pos` (0-based; negative
  indices count from the end) is overwritten, and the tokens are joined
  back together with single spaces.
  """
  tokens = seq.split()
  tokens[pos] = mask
  masked = " ".join(tokens)
  return masked.strip()

In [None]:
# Index 3 (0-based) of the spaced sequence is the "D", which should be replaced by the mask token.
assert  maskifySeq("M E N D E L",3) == 'M E N [MASK] E L'


In [None]:
#export
def allResidueCoordinates(seq,residue):
  """Return the 0-based positions in `seq` whose element equals `residue`."""
  coordinates = []
  for index, aa in enumerate(seq):
    if aa == residue:
      coordinates.append(index)
  return coordinates

In [None]:
# "E" occurs at 0-based positions 1 and 4 of "MENDEL".
assert allResidueCoordinates("MENDEL","E") == [1,4]

In [None]:
#export
def allResiduePredictions(seq, top_k=30):
  """Mask each position of `seq` in turn and collect the fill-mask predictions.

  Args:
    seq: amino-acid sequence as a plain string with no whitespace
      (one character per residue).
    top_k: number of candidate tokens requested from the pipeline for each
      masked position. Defaults to 30, the value previously hard-coded here.

  Returns:
    A list with one entry per position of `seq`, in order. Each entry is
    whatever the module-level `unmasker` pipeline returns for that masked
    sequence (elsewhere in this file it is indexed as a ranked list of dicts
    with "token_str", "score" and "sequence" keys).
  """
  # Space-separate once; every masked variant is derived from this string.
  spaceSeq = spacifySeq(seq)
  return [
      unmasker(maskifySeq(spaceSeq, aaPos), top_k=top_k)
      for aaPos in range(len(seq))
  ]

In [None]:
# One prediction entry is expected per residue of the 6-letter input.
assert len(allResiduePredictions("MENDEL")) == 6

6

In [None]:
# Top-ranked token when the first position is masked is expected to be "E";
# NOTE(review): this depends on the Rostlab/prot_bert weights loaded above.
assert allResiduePredictions("MENDEL")[0][0]["token_str"] == "E"

In [None]:
#export
def getTopSeq(allPredictions):
  """Concatenate the first-ranked "token_str" of every position.

  `allPredictions` is a sequence of per-position prediction lists; the
  first element of each list supplies one token of the returned string.
  """
  return "".join(aaPred[0]["token_str"] for aaPred in allPredictions)

In [None]:
#export
def residuePredictionScore(allPredictions, seq):
  """Tabulate per-position prediction scores for the 20 standard residues.

  Args:
    allPredictions: one entry per position of `seq`; each entry is an
      iterable of dicts carrying at least "token_str" and "score".
    seq: the wild-type sequence the predictions were made for.

  Returns:
    A DataFrame with one row per position and the columns
    "wt" (wild-type residue), "wtIndex" (1-based position), "wtScore"
    (score of the wild-type token at that position) followed by one column
    per standard residue (A, C, D, ..., Y).

  Every column receives exactly one value per position. A residue (or the
  wild type) that is absent from a position's predictions is recorded as
  NaN instead of being skipped, which previously left columns with
  different lengths or shifted later scores onto the wrong rows.
  """
  residues = "ACDEFGHIKLMNPQRSTVWY"
  missing = float("nan")

  residueScoreDict = {
      "wt": list(seq),
      "wtIndex": list(range(1, len(seq) + 1)),
      "wtScore": [],
  }
  for residue in residues:
    residueScoreDict[residue] = []

  for aaPredPos, aaPred in enumerate(allPredictions):
    wtAA = seq[aaPredPos]

    # Index this position's predictions by token; if a token were ever
    # repeated, the first (highest-ranked) occurrence is kept.
    scoreByToken = {}
    for posPred in aaPred:
      scoreByToken.setdefault(posPred["token_str"], posPred["score"])

    # Looked up directly, so a wild-type residue outside the 20-letter
    # alphabet still gets its score when the model predicted that token.
    residueScoreDict["wtScore"].append(scoreByToken.get(wtAA, missing))
    for residue in residues:
      residueScoreDict[residue].append(scoreByToken.get(residue, missing))

  return pd.DataFrame.from_dict(residueScoreDict)


In [None]:
#export
def hasNonStandardAA(seq, alphabet="ACDEFGHIKLMNPQRSTVWY"):
  """Return True when `seq` contains any element that is not in `alphabet`."""
  allowed = set(alphabet)
  return not set(seq) <= allowed

In [None]:
# "MENDEL" uses only letters of the default alphabet; "X" is not part of it.
assert not hasNonStandardAA("MENDEL")
assert hasNonStandardAA("MENDELX")

# Unified approach to interface with both ProtBERT and ESM

I will likely leave the previous few functions alone so that they continue to be supported. Luckily, I have unit tests for the helper functions, so I can be sure not to break those either.

# Experimental

In [None]:
def childrenPredictions(allPredictions,seq, topN=5):
  """Score the parent sequence and its top-ranked single-position variants.

  For every position, the first `topN` predictions are turned into "child"
  sequences (the prediction's "sequence" field with spaces removed). Each
  child that differs from the parent is re-run through
  `allResiduePredictions` and tabulated with `residuePredictionScore`.

  Args:
    allPredictions: per-position predictions for `seq`, as returned by
      `allResiduePredictions`.
    seq: the parent (wild-type) sequence.
    topN: how many of the leading predictions per position to expand into
      children. Defaults to 5, the value previously hard-coded here.

  Returns:
    The parent's score table followed by one score table per accepted
    child, concatenated into a single DataFrame. A "seq" column records
    which sequence each row belongs to.
  """
  parentScoreDF = residuePredictionScore(allPredictions, seq)
  parentScoreDF["seq"] = seq
  scoreDFs = [parentScoreDF]

  for prediction in allPredictions:
    for child in prediction[:topN]:
      childSeq = child["sequence"].replace(" ","")

      # The prediction that reproduces the parent adds nothing new.
      if childSeq == seq:
        continue
      # A child of a different length cannot be aligned to the parent by
      # wtIndex (presumably happens when a predicted special token is
      # dropped while decoding; confirm against the pipeline output).
      if len(childSeq) != len(seq):
        continue
      if hasNonStandardAA(childSeq):
        continue

      childPredictions = allResiduePredictions(childSeq)
      childScoreDF = residuePredictionScore(childPredictions, childSeq)
      childScoreDF["seq"] = childSeq
      scoreDFs.append(childScoreDF)

  return pd.concat(scoreDFs)

In [None]:
def childrenPPM(childrenScores, alphabet="ACDEFGHIKLMNPQRSTVWY"):
  """Sum residue scores per position and normalise each position to 1.

  Args:
    childrenScores: DataFrame with a "wtIndex" column and one score column
      per residue, e.g. the output of `childrenPredictions`.
    alphabet: residue letters whose columns are aggregated; columns of
      `childrenScores` not named by one of these letters are ignored.

  Returns:
    A DataFrame indexed by "wtIndex" with one column per residue present in
    `childrenScores`; each row is that position's summed scores divided by
    the row total.
  """
  allowed = set(alphabet)
  # Select residue columns by name rather than by position: whether the
  # string columns ("wt", "seq") survive groupby().sum() depends on the
  # pandas version, so a positional slice can pick up the wrong columns.
  residueCols = [col for col in childrenScores.columns if col in allowed]
  childrenScoreSum = childrenScores.groupby("wtIndex")[residueCols].sum()
  rowTotals = childrenScoreSum.sum(axis=1)
  return childrenScoreSum.div(rowTotals, axis=0)