In [None]:
import pandas as pd

In [None]:
from pathlib import Path
# Pretrained ESM checkpoint: <model_path>/<model_name>.pt
model_name = "esm1v_t33_650M_UR90S_1"
model_path = Path("/zhouxibin/models")
esm_path = model_path.joinpath("{}.pt".format(model_name))

# Dataset root directory and the name of the dataset to evaluate
dataset_name = "B3VI55_LIPST_Whitehead2015"
data_path = Path("data")

In [None]:
from Bio import SeqIO
def load_wt_and_data(data_path, dataset_name):
    """Read the wild-type sequence and the mutation table of one dataset.

    Both files are looked up in ``<data_path>/<dataset_name>/`` and are named
    ``<dataset_name>.fasta`` and ``<dataset_name>.csv``.

    Args:
        data_path (str or pathlib.Path): Root directory holding the datasets.
        dataset_name (str): Name of the dataset sub-directory, also used as
            the stem of both file names.

    Returns:
        tuple: ``(fasta, mut_fitness)`` where ``fasta`` is the single record
        returned by ``SeqIO.read`` for the FASTA file and ``mut_fitness`` is
        the ``pandas.DataFrame`` read from the CSV file.
    """
    dataset_dir = Path(data_path) / dataset_name

    fasta = SeqIO.read(dataset_dir / "{}.fasta".format(dataset_name), "fasta")
    mut_fitness = pd.read_csv(dataset_dir / "{}.csv".format(dataset_name))
    return fasta, mut_fitness
# Load the wild-type FASTA record and the mutation table for the dataset
# selected in the configuration cell.
wildtype, mut_fitness = load_wt_and_data(data_path, dataset_name)

In [None]:
import esm
# Load the model and its alphabet from the local checkpoint file; the loader
# is given the absolute checkpoint path as a plain string.
model, alphabet = esm.pretrained.load_model_and_alphabet(str(esm_path.absolute()))

In [None]:
from tqdm import tqdm

In [None]:
import torch
# Switch the model to evaluation mode before scoring.
model.eval()
# Move the model to the GPU only when CUDA is available; otherwise it stays
# on the device it was loaded on.
if torch.cuda.is_available():
    model = model.cuda()
    print("Transferred model to GPU")

In [None]:
# Scoring strategies handled by the scoring cell of this notebook.
scoring_strategy_pool = ["wt-marginals", "masked-marginals", "pseudo-ppl"]
# Strategy used for this run; change the index to pick another one.
scoring_strategy = scoring_strategy_pool[0]

In [None]:

# Converter that turns (label, sequence) pairs into labels, strings and tokens.
batch_converter = alphabet.get_batch_converter()

# A single-item batch containing only the wild-type sequence.
data = [
    ("protein1", str(wildtype.seq)),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
# Sanity check: sequence length, token tensor shape, converted string length and label.
# NOTE(review): the token axis is presumably longer than the sequence because of
# special tokens (label_row adds 1 for BOS) - confirm against the printed shape.
print(len(wildtype.seq), batch_tokens.shape, len(batch_strs[0]), batch_labels[0])

In [None]:
def label_row(row, sequence, token_probs, alphabet, offset_idx):
    """Score a single substitution from precomputed per-position token scores.

    The score is the ``token_probs`` entry of the mutant residue minus the
    entry of the wild-type residue, both taken at the mutated position.
    Callers in this notebook pass log-softmax outputs, which makes the result
    a log-odds style score.

    Args:
        row (str): Mutation string: wild-type residue, position, mutant
            residue (e.g. ``"A24G"``).
        sequence (str): Wild-type sequence the mutation refers to.
        token_probs: Tensor indexed as ``[0, token position, token index]``.
        alphabet: Object whose ``get_idx(token)`` maps a residue to its
            token index.
        offset_idx (int): Number of the first residue in the numbering used
            by ``row`` (1 for 1-based positions).

    Returns:
        float: Mutant entry minus wild-type entry at the mutated position.

    Raises:
        AssertionError: If the wild-type residue named in ``row`` differs
            from ``sequence`` at that position.
    """
    wt_residue = row[0]
    position = int(row[1:-1]) - offset_idx
    mt_residue = row[-1]
    assert sequence[position] == wt_residue, "The listed wildtype does not match the provided sequence"

    wt_token = alphabet.get_idx(wt_residue)
    mt_token = alphabet.get_idx(mt_residue)

    # Token position 0 is BOS, so residue `position` sits at `position + 1`.
    site_scores = token_probs[0, position + 1]
    difference = site_scores[mt_token] - site_scores[wt_token]
    return difference.item()

In [None]:

def compute_pppl(row, sequence, model, alphabet, offset_idx):
    """Pseudo-log-likelihood of ``sequence`` carrying a single substitution.

    The substitution described by ``row`` is applied to ``sequence``; every
    residue of the mutated sequence is then masked in turn and the model's
    log-probability of the true residue at the masked position is summed.

    Fixes relative to the previous version:
      * every residue is scored (the old loop skipped the last two), and the
        looked-up residue is the one actually sitting at the masked token
        position (token 0 is BOS, so residue ``k`` is token ``k + 1``);
      * inputs are sent to the device the model lives on instead of calling
        ``.cuda()`` unconditionally, so CPU-only machines work too.

    Args:
        row (str): Mutation string: wild-type residue, position, mutant
            residue (e.g. ``"A24G"``).
        sequence (str): Wild-type sequence the mutation refers to.
        model: Callable torch module returning a dict with a ``"logits"``
            entry indexed as ``[batch, token position, token index]``.
        alphabet: Object providing ``get_batch_converter()``, ``get_idx()``
            and ``mask_idx``.
        offset_idx (int): Number of the first residue in the numbering used
            by ``row`` (1 for 1-based positions).

    Returns:
        float: Sum of the per-residue masked log-probabilities (higher means
        the model finds the mutated sequence more likely).

    Raises:
        AssertionError: If the wild-type residue named in ``row`` differs
            from ``sequence`` at that position.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"

    # modify the sequence
    sequence = sequence[:idx] + mt + sequence[(idx + 1):]

    # encode the sequence
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein1", sequence)])

    # Run on whatever device the model is on (falls back to CPU for a
    # parameter-free model).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    batch_tokens = batch_tokens.to(device)

    # compute the log-probability of the true residue at each masked position
    log_probs = []
    for i in range(1, len(sequence) + 1):
        batch_tokens_masked = batch_tokens.clone()
        batch_tokens_masked[0, i] = alphabet.mask_idx
        with torch.no_grad():
            token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
        # token position i holds residue i - 1 because of the leading BOS token
        log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i - 1])].item())
    return sum(log_probs)

In [None]:
# Mutation strings number residues from 1 (the usual biological convention).
offset_idx = 1

# Send inputs to the device the model actually lives on. The model is only
# moved to the GPU when CUDA is available, so an unconditional `.cuda()` on
# the inputs would crash on CPU-only machines.
device = next(model.parameters()).device
wt_sequence = str(wildtype.seq)

# The mutation string is the first column of each row; `.iloc[0]` selects it
# by position explicitly instead of relying on the deprecated positional
# fallback of `row[0]` on a label-indexed Series.
if scoring_strategy == "pseudo-ppl":
    tqdm.pandas()
    mut_fitness[model_name] = mut_fitness.progress_apply(
        lambda row: compute_pppl(
            row.iloc[0], wt_sequence, model, alphabet, offset_idx
        ),
        axis=1,
    )
elif scoring_strategy in ("wt-marginals", "masked-marginals"):
    if scoring_strategy == "wt-marginals":
        # One forward pass over the unmasked wild-type sequence.
        with torch.no_grad():
            token_probs = torch.log_softmax(
                model(batch_tokens.to(device))["logits"], dim=-1
            )
    else:
        # One forward pass per token position, with that position masked;
        # keep only the distribution predicted at the masked position.
        all_token_probs = []
        for i in tqdm(range(batch_tokens.size(1))):
            batch_tokens_masked = batch_tokens.clone()
            batch_tokens_masked[0, i] = alphabet.mask_idx
            with torch.no_grad():
                masked_probs = torch.log_softmax(
                    model(batch_tokens_masked.to(device))["logits"], dim=-1
                )
            all_token_probs.append(masked_probs[:, i])  # vocab size
        token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)

    # Both marginal strategies score each mutation from `token_probs`.
    mut_fitness[model_name] = mut_fitness.apply(
        lambda row: label_row(
            row.iloc[0],
            wt_sequence,
            token_probs,
            alphabet,
            offset_idx,
        ),
        axis=1,
    )
else:
    # Fail loudly instead of silently leaving the score column missing.
    raise ValueError(
        "Unknown scoring strategy {!r}; expected one of "
        "'wt-marginals', 'masked-marginals', 'pseudo-ppl'".format(scoring_strategy)
    )

In [None]:
from scipy.stats import spearmanr
# Spearman rank correlation between the second column of the table and the
# scores written by this model. The score column is selected by name rather
# than by position, so the result stays correct when the CSV has extra columns
# or when score columns from earlier runs are already present.
# NOTE(review): column 1 is presumed to hold the experimental measurement -
# confirm against the CSV header.
spearmanr(mut_fitness.iloc[:, 1], mut_fitness[model_name])