In [None]:
import torch
from transformers import BertForMaskedLM, BertTokenizer, AutoTokenizer, AutoModel, pipeline
from datasets import Dataset, load_dataset
import re
import numpy as np
import os
import requests
import pandas as pd
from tqdm import tqdm

### Get model and tokenizer

In [2]:
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False )

In [None]:
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert_bfd")

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [5]:
tqdm.pandas()

In [6]:
def get_perplexity(sequence):
    sequence = " ".join(re.sub(r"[UZOB]", "X", sequence))
    tokenized_text = tokenizer.tokenize(sequence)
    
    
    # Track log likelihoods
    log_likelihoods = []

    for masked_index in range(1, len(tokenized_text) - 1):  # Avoid masking [CLS] and [SEP]
        # Create a copy of the token list and mask one token
        masked_tokens = tokenized_text.copy()
        masked_tokens[masked_index] = '[MASK]'
        masked_text = ' '.join(masked_tokens).replace(' ##', '')

        # Tokenize masked sentence
        inputs = tokenizer(masked_text, return_tensors="pt")

        # Move input tensors to the GPU
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get model predictions for the masked token
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = outputs.logits

        # Get the softmax probabilities
        softmax = torch.nn.Softmax(dim=-1)
        mask_token_logits = predictions[0, masked_index]
        mask_token_probs = softmax(mask_token_logits)

        # Find the true token's probability
        true_token_id = inputs['input_ids'][0, masked_index]
        true_token_prob = mask_token_probs[true_token_id]

        # Calculate the negative log likelihood
        log_likelihood = torch.log(true_token_prob)
        log_likelihoods.append(log_likelihood)

    # Calculate the average negative log likelihood
    average_neg_log_likelihood = -torch.stack(log_likelihoods).mean().item()

    # Estimate the average perplexity per token
    average_perplexity = torch.exp(torch.tensor(average_neg_log_likelihood))
    return average_perplexity.item()

In [None]:
df = pd.read_csv('data/Artificial_proteins.csv')
df = df[df['Tool'] != 'Real'].reset_index(drop=True)

In [13]:
df

Unnamed: 0,ID,Sequence,Tool
0,RF_4-426_0,SLEEERKKKFIEDFRALMDVLLDRILEEIKKLCEKENKDVIIMVFV...,RFdiffusion
1,RF_4-55_0,AAIQAELEAEVARLAAAMPAVMARVAEEAKKAASLSFLELLAVLSG...,RFdiffusion
2,RF_5-600_0,KTVVLKDLPKEEMKKKLKEAAKAGDKIRIVITPENAEEVLEVIKEL...,RFdiffusion
3,RF_5-395_0,MKVLISRATIRVLSVDDEHELRFVYNEETGNLEVITVTLSEGVAYI...,RFdiffusion
4,RF_10_61_1,KKIFELTLVVSTEELALEILEELSKSCELTLVATPEGFTLLILCAL...,RFdiffusion
...,...,...,...
4995,RF_5-1318_0,SMEEVLEATVELETEAELEDLLRLIALVVALQPDARVLVAAEDGVL...,RFdiffusion
4996,RF_8_256_1,MKEKKKAVLVDITFIPASRITPETLEKMKELQAKMVAALKAGDVET...,RFdiffusion
4997,RF_1-441_0,SEKAEQLQAILDEYIEMLEKELKKKYKGLELKSIKFPLIVYDGEKD...,RFdiffusion
4998,RF_5-661_0,GLWEEVKQLVKEMKVDKEKGTLTIELVVETKDGTVLRAKITVTLPT...,RFdiffusion


In [None]:
perplexity = df['Sequence'].progress_apply(get_perplexity)

In [None]:
df['Perplexity'] = perplexity

In [20]:
df

Unnamed: 0,ID,Sequence,Tool,Perplexity
0,RF_4-426_0,SLEEERKKKFIEDFRALMDVLLDRILEEIKKLCEKENKDVIIMVFV...,RFdiffusion,1.397958
1,RF_4-55_0,AAIQAELEAEVARLAAAMPAVMARVAEEAKKAASLSFLELLAVLSG...,RFdiffusion,1.452535
2,RF_5-600_0,KTVVLKDLPKEEMKKKLKEAAKAGDKIRIVITPENAEEVLEVIKEL...,RFdiffusion,1.417952
3,RF_5-395_0,MKVLISRATIRVLSVDDEHELRFVYNEETGNLEVITVTLSEGVAYI...,RFdiffusion,1.398252
4,RF_10_61_1,KKIFELTLVVSTEELALEILEELSKSCELTLVATPEGFTLLILCAL...,RFdiffusion,1.421470
...,...,...,...,...
4995,RF_5-1318_0,SMEEVLEATVELETEAELEDLLRLIALVVALQPDARVLVAAEDGVL...,RFdiffusion,1.399607
4996,RF_8_256_1,MKEKKKAVLVDITFIPASRITPETLEKMKELQAKMVAALKAGDVET...,RFdiffusion,1.416514
4997,RF_1-441_0,SEKAEQLQAILDEYIEMLEKELKKKYKGLELKSIKFPLIVYDGEKD...,RFdiffusion,1.414330
4998,RF_5-661_0,GLWEEVKQLVKEMKVDKEKGTLTIELVVETKDGTVLRAKITVTLPT...,RFdiffusion,1.459369


In [None]:
df.to_csv('data/Artificial_proteins_perplexity.csv', index=False)

In [21]:
df.describe(include='all')

Unnamed: 0,ID,Sequence,Tool,Perplexity
count,5000,5000,5000,5000.0
unique,5000,5000,1,
top,RF_4-426_0,SLEEERKKKFIEDFRALMDVLLDRILEEIKKLCEKENKDVIIMVFV...,RFdiffusion,
freq,1,1,5000,
mean,,,,1.403765
std,,,,0.048862
min,,,,1.111967
25%,,,,1.377371
50%,,,,1.410344
75%,,,,1.437371


---------------