In [None]:
import pandas as pd
import torch
import esm
from tqdm import tqdm
from Bio import SeqIO
from pathlib import Path
from scipy.stats import spearmanr

In [None]:
# Directory holding the ESM checkpoint files; score_fitness loads <model_path>/<model_name>.pt.
# NOTE(review): absolute, machine-specific path - adjust before running on another host.
model_path = Path("/zhouxibin/models")
# Checkpoint file stem; also used as the prefix of the score column added to each dataset.
model_name = "esm1v_t33_650M_UR90S_1"
# Root data directory; each dataset lives in <data_path>/<name>/ as <name>.fasta plus <name>.csv.
data_path = Path("data")
# NOTE(review): no cell visible here reads this value; score_fitness derives dataset
# names from the directory listing instead.
dataset_name = "B3VI55_LIPST_Whitehead2015"

In [None]:
def load_wt_and_data(data_path, dataset_name):
    """Read the wild-type record and the mutation table of one dataset.

    Both files are expected inside ``<data_path>/<dataset_name>/`` and are
    named after the dataset: ``<dataset_name>.fasta`` and ``<dataset_name>.csv``.

    Args:
        data_path (str or pathlib.Path): root directory containing all datasets.
        dataset_name (str): name of the dataset sub-directory and file stem.

    Returns:
        tuple: ``(record, table)`` where ``record`` is the single FASTA record
        returned by ``SeqIO.read`` and ``table`` is the CSV loaded as a
        ``pd.DataFrame``.
    """
    dataset_dir = Path(data_path) / dataset_name

    # SeqIO.read expects exactly one record in the FASTA file.
    wildtype_record = SeqIO.read(dataset_dir / f"{dataset_name}.fasta", "fasta")
    fitness_table = pd.read_csv(dataset_dir / f"{dataset_name}.csv")
    return wildtype_record, fitness_table

In [None]:
def label_row(row, sequence, token_probs, alphabet, offset_idx):
    """Score one single-point mutation from precomputed per-position log-probabilities.

    Args:
        row (str): mutation code ``<wt><position><mt>``, e.g. ``"A24G"``; the first
            and last characters are residues, everything in between is the position.
        sequence (str): wild-type sequence the mutation refers to.
        token_probs: array/tensor indexed as ``[0, token_position, vocab_index]``,
            where token position 0 is the BOS token (hence the ``1 + idx`` below).
        alphabet: object whose ``get_idx(residue)`` returns the vocabulary index.
        offset_idx (int): numbering offset of ``position`` (1 for 1-based positions).

    Returns:
        float: ``log p(mt) - log p(wt)`` at the mutated position.

    Raises:
        IndexError: if the position falls outside ``sequence``.
        AssertionError: if the listed wild-type residue differs from ``sequence``.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]

    # Reject out-of-range positions explicitly: a negative idx would otherwise wrap
    # around to the end of the sequence and read the score from the wrong token slot.
    if not 0 <= idx < len(sequence):
        raise IndexError(
            "Mutation position out of range, mutation={}, idx={}, len(sequence)={}".format(
                row, idx, len(sequence)
            )
        )
    assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence, idx={}, sequence[idx]={}, wt={}".format(idx, sequence[idx], wt)

    wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt)

    # add 1 for BOS
    score = token_probs[0, 1 + idx, mt_encoded] - token_probs[0, 1 + idx, wt_encoded]
    return score.item()

In [None]:
def compute_pppl(row, sequence, model, alphabet, offset_idx):
    """Score one single-point mutant by the pseudo-log-likelihood of its full sequence.

    The mutation is applied to ``sequence``; then every residue of the mutant is
    masked in turn and the log-probability the model assigns to the true residue at
    the masked position is accumulated.

    Args:
        row (str): mutation code ``<wt><position><mt>``, e.g. ``"A24G"``.
        sequence (str): wild-type sequence.
        model: callable returning a dict with a ``"logits"`` tensor of shape
            ``(batch, tokens, vocab)`` for a batch of token ids.
        alphabet: tokenizer object providing ``get_batch_converter()`` and ``mask_idx``.
        offset_idx (int): numbering offset of ``position`` (1 for 1-based positions).

    Returns:
        float: sum over all residues of the masked log-probability of the true token.

    Raises:
        IndexError: if the position falls outside ``sequence``.
        AssertionError: if the listed wild-type residue differs from ``sequence``.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    # A negative idx would silently wrap around to the end of the sequence.
    if not 0 <= idx < len(sequence):
        raise IndexError(
            "Mutation position out of range, mutation={}, idx={}, len(sequence)={}".format(
                row, idx, len(sequence)
            )
        )
    assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"

    # modify the sequence
    mutant = sequence[:idx] + mt + sequence[(idx + 1):]

    # encode the sequence
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein1", mutant)])

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch_tokens = batch_tokens.to(first_param.device)

    # Token position 0 is BOS, so residue k (0-based) sits at token position k + 1.
    # Visit every residue (1 .. len(mutant)) and read the target id from the
    # unmasked tokens so the looked-up residue always matches the masked position.
    log_probs = []
    for i in range(1, len(mutant) + 1):
        batch_tokens_masked = batch_tokens.clone()
        batch_tokens_masked[0, i] = alphabet.mask_idx
        with torch.no_grad():
            token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
        target = batch_tokens[0, i].item()
        log_probs.append(token_probs[0, i, target].item())
    return sum(log_probs)

## 数据集的特殊情况
- POLG_HCVJF_Sun2014: wildtype长度超过1022
- PTEN_HUMAN_Fowler2018: 最后有一项是WT，需要手动在数据集里面删掉
- PABP_YEAST_Fields2013-doubles: 双点突变，暂时处理不了
- BRCA1_HUMAN_RING: wildtype长度超过1022
- MTH3_HAEAESTABILIZED_Tawfik2015: csv中第26号位置的wildtype是A，而sequence中是C。由于该数据集是Stabilized版本，我们以csv中的标记为准，将sequence中的C改为A；同理，104号位置I改为K，115号位置M改为L，181号位置F改为L，327号位置C改为R
- UBC9_HUMAN_Roth2017: csv中159号位置有一个Y，但是sequence中没有，所以给sequence最后加入一个Y
- UBE4B_MOUSE_Klevit2013-singles: wildtype长度超过1022
- TIM_THEMA_b0: 102号位置 C换成S
- F7YBW7_MESOW_vae: 四点突变，暂时不处理
- B3VI55_LIPSTSTABLE: 140号位置L变成I，142号位置从S变成A，373号位置A变成C
- HIS7_YEAST_Kondrashov2017: 多点同时突变，暂时不处理
- BRCA1_HUMAN_BRCT: wildtype长度超过1022
- TPMT_HUMAN_Fowler2018: 最后有一项是WT，需要手动在数据集里面删掉

In [None]:
# Datasets that score_fitness skips. According to the notebook's notes they either have a
# wild type longer than 1022 residues (POLG_HCVJF_Sun2014, BRCA1_HUMAN_RING,
# UBE4B_MOUSE_Klevit2013-singles, BRCA1_HUMAN_BRCT) or contain multi-point mutants
# (PABP_YEAST_Fields2013-doubles, F7YBW7_MESOW_vae, HIS7_YEAST_Kondrashov2017), which
# are not handled for now.
wrong_datasets = [
    "POLG_HCVJF_Sun2014", "PABP_YEAST_Fields2013-doubles", "BRCA1_HUMAN_RING", 
    "UBE4B_MOUSE_Klevit2013-singles", "F7YBW7_MESOW_vae", "HIS7_YEAST_Kondrashov2017",
    "BRCA1_HUMAN_BRCT"
]
def score_fitness(model_path, model_name, data_path, scoring_strategies=("wt-marginals",)):
    """Score every usable dataset under ``data_path`` with one ESM checkpoint.

    Args:
        model_path (str or pathlib.Path): directory containing ``<model_name>.pt``.
        model_name (str): checkpoint file stem; also the prefix of the score column
            (``<model_name>_<strategy>``) added to each dataset table.
        data_path (str or pathlib.Path): directory with one sub-directory per dataset.
        scoring_strategies (str or iterable of str): strategies to evaluate, any of
            ``"wt-marginals"``, ``"masked-marginals"``, ``"pseudo-ppl"``. The default
            runs only ``"wt-marginals"``, matching the previous behaviour where the
            strategy loop stopped after its first entry.

    Returns:
        dict: ``total[strategy][dataset_name]`` is a dict with keys ``"wildtype"``
        (sequence string), ``"mut_fitness"`` (scored DataFrame) and ``"spearmanr"``
        (correlation between measured fitness and the model score).
    """
    model_path, data_path = Path(model_path), Path(data_path)
    if isinstance(scoring_strategies, str):
        scoring_strategies = [scoring_strategies]

    esm_path = model_path / "{}.pt".format(model_name)
    model, alphabet = esm.pretrained.load_model_and_alphabet(str(esm_path.absolute()))

    total = {}
    for scoring_strategy in scoring_strategies:
        score_column = model_name + "_" + scoring_strategy
        total[scoring_strategy] = {}
        # Sorted so that the result order does not depend on the filesystem listing.
        for dataset in sorted(data_path.glob("*")):
            # .name rather than .stem: a dot inside a directory name is not a suffix.
            dataset_name = dataset.name
            if dataset_name in wrong_datasets or not dataset.is_dir():
                continue
            wildtype, mut_fitness = score_fitness_one_iteration(
                model, alphabet, data_path, dataset_name, scoring_strategy,
                column_prefix=model_name,
            )
            if wildtype is None:
                # Over-length wild type: the helper already reported it; nothing to score.
                continue
            # Correlate against the score column by name rather than by position.
            # NOTE(review): assumes column 1 of the CSV is the measured fitness - confirm.
            spearmanr_ = spearmanr(
                mut_fitness.iloc[:, 1], mut_fitness[score_column], nan_policy="omit"
            )
            total[scoring_strategy][dataset_name] = {
                "wildtype": str(wildtype.seq),
                "mut_fitness": mut_fitness,
                "spearmanr": spearmanr_.correlation,
            }
    return total

def score_fitness_one_iteration(model, alphabet, data_path, dataset_name, scoring_strategy, column_prefix=None):
    """Score all mutations of one dataset with one scoring strategy.

    Args:
        model: ESM model (a ``torch.nn.Module`` returning ``{"logits": ...}``).
        alphabet: matching ESM alphabet / tokenizer.
        data_path (str or pathlib.Path): root directory containing all datasets.
        dataset_name (str): dataset to load via ``load_wt_and_data``.
        scoring_strategy (str): ``"wt-marginals"``, ``"masked-marginals"`` or
            ``"pseudo-ppl"``.
        column_prefix (str, optional): prefix of the score column. When omitted, the
            notebook-level ``model_name`` is used, as before.

    Returns:
        tuple: ``(wildtype, mut_fitness)`` where ``mut_fitness`` has the extra column
        ``<column_prefix>_<scoring_strategy>``; ``(None, None)`` if the wild type is
        longer than 1022 residues.

    Raises:
        ValueError: if ``scoring_strategy`` is not one of the supported names.
    """
    if scoring_strategy not in ("wt-marginals", "masked-marginals", "pseudo-ppl"):
        raise ValueError("Unknown scoring strategy: {}".format(scoring_strategy))
    if column_prefix is None:
        # Legacy fallback: earlier callers relied on the notebook-level model_name.
        column_prefix = model_name
    score_column = column_prefix + "_" + scoring_strategy

    wildtype, mut_fitness = load_wt_and_data(data_path, dataset_name)
    wildtype_seq = str(wildtype.seq)

    if len(wildtype_seq) > 1022:
        print("wild type size is over 1022")
        return None, None

    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
        print("Transferred model to GPU")
    # Send the tokens wherever the model actually is, so CPU-only hosts work too.
    device = next(model.parameters()).device

    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein1", wildtype_seq)])
    batch_tokens = batch_tokens.to(device)

    offset_idx = 1  # mutation positions are 1-based, as is usual in biology
    # The first CSV column holds the mutation codes (e.g. "A24G").
    mutations = mut_fitness.iloc[:, 0]

    if scoring_strategy == "pseudo-ppl":
        tqdm.pandas()
        mut_fitness[score_column] = mutations.progress_apply(
            lambda mutation: compute_pppl(mutation, wildtype_seq, model, alphabet, offset_idx)
        )
        return wildtype, mut_fitness

    if scoring_strategy == "wt-marginals":
        # One forward pass over the unmasked wild type.
        with torch.no_grad():
            token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)
    else:
        # masked-marginals: mask each token position in turn and keep the
        # distribution predicted at the masked position.
        all_token_probs = []
        for i in tqdm(range(batch_tokens.size(1))):
            batch_tokens_masked = batch_tokens.clone()
            batch_tokens_masked[0, i] = alphabet.mask_idx
            with torch.no_grad():
                token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
            all_token_probs.append(token_probs[:, i])  # vocab size
        token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)

    mut_fitness[score_column] = mutations.apply(
        lambda mutation: label_row(mutation, wildtype_seq, token_probs, alphabet, offset_idx)
    )
    return wildtype, mut_fitness

# Load the checkpoint and score the dataset directories found under data_path.
# The result is a nested dict: total[scoring_strategy][dataset_name] holds the
# wild-type sequence, the scored DataFrame and the Spearman correlation.
total = score_fitness(model_path, model_name, data_path)


In [None]:
# Show a compact table of Spearman correlations (rows: datasets, columns: scoring
# strategies) instead of dumping the whole nested dict, whose values include full
# wild-type sequences and DataFrames. The complete results stay available in `total`.
pd.DataFrame(
    {
        strategy: {name: result["spearmanr"] for name, result in per_dataset.items()}
        for strategy, per_dataset in total.items()
    }
)