In [1]:
import transformers
import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel

In [2]:
# Turn off parallelism inside the tokenizers library before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Index of the GPU to run on when CUDA is available.
gpu_idx = 1

# Attach the GPU index only when CUDA is actually selected: the previous
# torch.device('cpu', gpu_idx) fallback produced the meaningless device "cpu:1".
if torch.cuda.is_available():
    device = torch.device('cuda', gpu_idx)
else:
    device = torch.device('cpu')

# Load the ProtGPT2 tokenizer and language model, and place the model on `device`.
tokenizer = AutoTokenizer.from_pretrained('nferruz/ProtGPT2')
model = GPT2LMHeadModel.from_pretrained('nferruz/ProtGPT2').to(device)

In [7]:
def protgpt_wrapper(samples):
    """Score each sequence in `samples` with the module-level `model`.

    For every sequence the text is tokenized with the module-level
    `tokenizer`, the model is run with the token ids as both input and
    labels, and the returned loss is multiplied by the number of tokens.

    NOTE(review): the local name ``ppl`` is historical. The value is
    ``loss * n_tokens``, which looks like a summed negative log-likelihood,
    not a perplexity; this assumes ``outputs.loss`` is a mean per-token
    cross-entropy (TODO confirm). The formula is kept as-is so new scores
    stay comparable with ones already written to disk.

    Args:
        samples: Iterable of sequence strings.

    Returns:
        1-D ``numpy`` array with one float per input sequence (empty when
        `samples` is empty).
    """
    # Send inputs to the device the model actually lives on. A bare `.cuda()`
    # targets the current default CUDA device, which mismatches a model moved
    # to another GPU index and fails outright on CPU-only machines.
    target_device = model.device

    ppls = []
    for seq in samples:
        encoded = tokenizer(seq, return_tensors="pt")
        input_ids = encoded.input_ids.to(target_device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids)

        # input_ids is indexed as (batch, tokens); shape[1] is the token count.
        ppls.append((outputs.loss * input_ids.shape[1]).item())

    return np.array(ppls, dtype=float)

def extract_ll_distr(df, seq_label):
    """Return the negated `protgpt_wrapper` scores for one column of `df`.

    Args:
        df: Table holding the sequences.
        seq_label: Name of the column in `df` that contains the sequences.

    Returns:
        Whatever `protgpt_wrapper` returns for that column, multiplied by -1.
    """
    scores = protgpt_wrapper(df[seq_label])
    return -1 * scores

def extract_ll_directory(dir_name, seq_label):
    """Add a 'loglikelihood' column to every CSV file directly inside `dir_name`.

    For each file whose name ends in ".csv" (case-insensitive):
      * if the header already has a 'loglikelihood' column, the file is skipped;
      * else if it has a `seq_label` column, the file is loaded, scored with
        `extract_ll_distr`, and written back in place with the new column;
      * otherwise the file is skipped with a message.

    Processing is best-effort: a failure on one file is reported on stdout
    and the remaining files are still processed.

    Args:
        dir_name: Directory to scan (not recursive).
        seq_label: Name of the CSV column that holds the sequences.
    """
    # Sorted so that processing order and printed output are reproducible.
    for fn in sorted(os.listdir(dir_name)):
        if not fn.lower().endswith(".csv"):
            continue

        file_path = os.path.join(dir_name, fn)
        try:
            # Read only the header so already-processed files are cheap to skip.
            header = pd.read_csv(file_path, nrows=0)

            if 'loglikelihood' in header.columns:
                # Don't re-compute if the file already has a loglikelihood column.
                print(f"{fn} already processed - Skipping...")
            elif seq_label in header.columns:
                df = pd.read_csv(file_path)
                df['loglikelihood'] = extract_ll_distr(df, seq_label)
                df.to_csv(file_path, index=False)
            else:
                print(f"{fn} has no '{seq_label}' column - Skipping...")
        except Exception as e:
            # Keep going with the other files, but never hide the failure:
            # a silent `pass` here previously masked every scoring error.
            print(f"{fn} failed ({type(e).__name__}: {e}) - Skipping...")

In [4]:
# Score the CSV files in the baseline distribution directory, reading sequences from the 'seq' column.
extract_ll_directory('data/baseline_data/distribution', 'seq')

original_new_10_0.5_0_results_merge.csv already processed - Skipping...
original_old_10_0.5_0_results_merge.csv already processed - Skipping...


In [5]:
# Score the CSV files in the beam distribution directory, reading sequences from the 'seq' column.
extract_ll_directory('data/beam_data/distribution', 'seq')

original_old_10_0.5_0_results_merge_old_7JJK_scrmsd_beam_5_5.csv already processed - Skipping...
original_old_10_0.5_0_results_merge_old_7JJK_beam_10_1.csv already processed - Skipping...
original_new_10_0.5_0_results_merge_new_7JJK_scrmsd_beam_5_5.csv already processed - Skipping...
original_new_10_0.5_0_results_merge_new_7JJK_scrmsd_beam_5_10.csv already processed - Skipping...
original_new_10_0.5_0_results_merge_beam_10_1.csv already processed - Skipping...


In [6]:
# Score the CSV files in the bon distribution directory, reading sequences from the 'seq' column.
extract_ll_directory('data/bon_data/distribution', 'seq')

original_new_10_0.5_0_results_merge_bon_10.csv already processed - Skipping...
original_old_10_0.5_0_results_merge_old_7JJK_bon_10.csv already processed - Skipping...
original_old_10_0.5_0_results_merge_old_7JJK_scrmsd_bon_5.csv already processed - Skipping...
original_new_10_0.5_0_results_merge_new_7JJK_scrmsd_bon_5.csv already processed - Skipping...
