In [3]:
import transformers
import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
import argparse
from tqdm import tqdm
from tree_spex import lgboost_fit, lgboost_to_fourier, lgboost_tree_to_fourier, ExactSolver # type:ignore

In [18]:
# Set before the tokenizer below is created; presumably silences the Hugging
# Face tokenizers fork/parallelism warning — confirm against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Device index handed to torch.device in both the CUDA and the CPU case.
gpu_idx = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', gpu_idx)
# Tokenizer and GPT-2 LM-head model are loaded from the same
# 'nferruz/ProtGPT2' checkpoint; the model is then moved to `device`.
tokenizer = AutoTokenizer.from_pretrained('nferruz/ProtGPT2')
model = GPT2LMHeadModel.from_pretrained('nferruz/ProtGPT2').to(device) # type: ignore

In [64]:
def protgpt_wrapper(samples, model, tokenizer):
    """Score each sequence with a causal language model.

    Every sequence is tokenized on its own, the resulting ``input_ids`` are
    passed to the model as both inputs and labels, and the score is
    ``outputs.loss`` multiplied by the number of token ids
    (``input_ids.shape[1]``).

    Note: this is a length-scaled loss, not a perplexity (no exponential is
    taken). Presumably ``outputs.loss`` is the mean per-token loss, which would
    make the score an approximate total negative log-likelihood — confirm
    against the model's documentation.

    Args:
        samples: Iterable of sequences accepted by ``tokenizer``.
        model: Callable as ``model(input_ids, labels=input_ids)`` returning an
            object with a scalar tensor ``loss``; must expose ``.device``.
        tokenizer: Callable as ``tokenizer(seq, return_tensors="pt")``
            returning an object with an ``input_ids`` tensor.

    Returns:
        np.ndarray holding one float score per sequence, in input order
        (an empty array when ``samples`` is empty).
    """
    scores = []
    for seq in samples:
        encoded = tokenizer(seq, return_tensors="pt")
        # .to() works for CPU and CUDA alike; .cuda(device=...) raises when
        # the model lives on the CPU.
        input_ids = encoded.input_ids.to(model.device)

        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids)

        scores.append((outputs.loss * input_ids.shape[1]).item())

    return np.array(scores)

def calc_mask_reward(seq, mask):
    """Score ``seq`` after replacing the masked positions with ``'X'``.

    Positions ``j`` with ``mask[j] == 1`` are overwritten by the character
    ``'X'``; all other characters are kept. The masked string is scored with
    ``protgpt_wrapper`` using the module-level ``model`` and ``tokenizer``.

    Args:
        seq: Sequence string to mask.
        mask: Array-like with a ``shape`` attribute; entries equal to 1 mark
            the positions to mask.

    Returns:
        The array returned by ``protgpt_wrapper`` for the single masked
        sequence.
    """
    chars = list(seq)
    masked_positions = [pos for pos in range(mask.shape[0]) if mask[pos] == 1]
    for pos in masked_positions:
        chars[pos] = 'X'
    return protgpt_wrapper(["".join(chars)], model, tokenizer)

def analyze_seq(seq, num_masks=500, p=0.5, top_interactions=25, seed=None):
    """Fit a surrogate on random maskings of ``seq`` and print its top terms.

    Steps:
      1. Draw ``num_masks`` random 0/1 masks of length ``len(seq)`` and score
         each masked sequence with ``calc_mask_reward``.
      2. Fit a model with ``lgboost_fit`` and print its cross-validated R^2.
      3. Convert the model with ``lgboost_to_fourier`` and keep the
         ``top_interactions`` entries with the largest (signed) coefficients.
         The keys are treated here as 0/1 tuples marking which positions take
         part in a term.
      4. Seed the target set with the positions of the highest-ranked term
         (skipping the all-zero term if it ranks first), then walk the ranked
         terms while their coefficient is positive, printing each one and
         merging any term that overlaps the target set.
      5. Print the sorted target positions.

    Args:
        seq: Sequence string to analyze.
        num_masks: Number of random masks to draw.
        p: Probability that a given position is masked (mask value 1).
        top_interactions: Number of highest-coefficient terms to keep.
        seed: Optional seed for mask sampling. ``None`` (default) draws from
            NumPy's global random state as before; any other value uses a
            dedicated ``np.random.default_rng(seed)`` so runs are repeatable.

    Returns:
        None. Results are printed.
    """
    num_tokens = len(seq)
    mask_probs = np.array([1 - p, p])
    mask_shape = (num_masks, num_tokens)
    if seed is None:
        # Legacy path: honours any np.random.seed(...) set by the caller.
        all_masks = np.random.choice(2, size=mask_shape, p=mask_probs)
    else:
        all_masks = np.random.default_rng(seed).choice(2, size=mask_shape, p=mask_probs)

    rewards = np.array([calc_mask_reward(seq, m) for m in all_masks]).ravel()

    feature_names = [f"x{i}" for i in range(num_tokens)]

    # Convert to DataFrame
    X_df = pd.DataFrame(all_masks, columns=feature_names)

    best_model, cv_r2 = lgboost_fit(X_df, rewards)
    print(f'CV r2: {cv_r2}')

    # Algorithm: select top parent and its children
    fourier_dict = lgboost_to_fourier(best_model)
    ranked_terms = sorted(fourier_dict.items(), key=lambda item: item[1], reverse=True)
    fourier_dict_trunc = dict(ranked_terms[:top_interactions])

    fourier_iter = iter(fourier_dict_trunc)
    top_coefficient = next(fourier_iter, None)
    if top_coefficient is not None and sum(top_coefficient) == 0:  # type: ignore
        # The all-zero key involves no position; use the next-ranked term.
        top_coefficient = next(fourier_iter, None)
    if top_coefficient is None:
        print("No meaningful interactions found")
        return

    # Positions are stored as plain ints so the printed list reads "[31]"
    # regardless of how the installed NumPy version reprs its scalars.
    target_features = set()
    seed_positions, = np.where(np.array(top_coefficient) == 1)
    target_features.update(int(pos) for pos in seed_positions)

    for support, coefficient in fourier_dict_trunc.items():
        if coefficient <= 0:
            break  # no more contributing coefficients left
        nonzero_pos, = np.where(np.array(support) == 1)
        positions = [int(pos) for pos in nonzero_pos]
        if target_features.intersection(positions):
            target_features.update(positions)
        descr = "(" + ", ".join(str(pos) for pos in positions) + ")"
        print(descr, coefficient)
    print(f"SPECTRAL targets: {sorted(target_features)}")


In [53]:
# Structure: r6_560_TrROS_Hall

# Two sequences to compare; the elementwise comparison below relies on both
# having the same length.
seq1 = "AVPAPVVTVLVAVTNPDGKVVLKRVTLSGLPRELKPGDKVTLPETGQEATIVEVLP"
seq2 = "APPPPRVRVVVAVTRPDGRTELVTVELTGLPRPLRPGDTVTLPETGQKATVVEVLP"
# np.where on a 1-D boolean array returns a 1-tuple of index arrays, which is
# why the printed value is a tuple wrapping the array of differing positions.
diff_pos = np.where(np.array(list(seq1)) != np.array(list(seq2)))
print(diff_pos)

# protgpt_wrapper returns one score per input sequence, in the order passed.
reward1, reward2 = protgpt_wrapper([seq1, seq2], model, tokenizer)

(array([ 1,  3,  5,  7,  9, 14, 18, 19, 20, 22, 23, 25, 27, 32, 34, 38, 47,
       50]),)


In [21]:
print("Reward 1:", reward1)
print("Reward 2:", reward2)

Reward 1: 147.87730407714844
Reward 2: 130.5386199951172


In [65]:
analyze_seq(seq1, num_masks=500, p=0.25, top_interactions=10)

CV r2: 0.47047575571636563
() 147.9111299747269
(31) 1.0844565421061554
(0) 0.9811677968194322
(14) 0.9570326399791297
(21) 0.7194319694136097
(23) 0.6554349626852427
(46) 0.5598891939967457
(12) 0.4897228231537926
(18) 0.46777505095499883
(13) 0.45707358829896627
SPECTRAL targets: [31]


In [66]:
analyze_seq(seq2, num_masks=500, p=0.25, top_interactions=10)

CV r2: 0.4812200030625572
() 145.01779040351684
(47) 1.5206155727991773
(0) 0.9261279739756545
(43) 0.9166225113008971
(19) 0.72583083639104
(21) 0.7224110147689785
(46) 0.7101205746900472
(9) 0.5912795634302482
(10) 0.5626561106433193
(27) 0.5153143672236693
SPECTRAL targets: [47]


In [56]:
print(diff_pos)

(array([ 1,  3,  5,  7,  9, 14, 18, 19, 20, 22, 23, 25, 27, 32, 34, 38, 47,
       50]),)
