In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.rna_model.model_slim import ESM2
from models.rna_model import rna_esm
from models.rna_model.evo.tokenization import Vocab, mapdict
from models.rna_model.config import TransformerConfig, OptimizerConfig, Config, DataConfig, ProduceConfig, TrainConfig, LoggingConfig
esm_path = "/home/fkli/RNAm/RNA-ESM2-trans-2a100-mappro-KDNY-epoch_06-valid_F1_0.564.ckpt"
class RNAESM2(nn.Module):
    def __init__(self, esm_ckpt, device="cuda:0"):
        super(RNAESM2, self).__init__()
        self.device = device
        if esm_ckpt is None:
            raise ValueError("Please provide a valid RNA-ESM2 checkpoint")
        else:
            self.esm_ckpt = esm_ckpt
        self.model, self.rna_map_vocab, self.rna_alphabet = self.__init_model()

    def __init_model(self):
        _, protein_alphabet = rna_esm.pretrained.esm2_t30_150M_UR50D()
        rna_alphabet = rna_esm.data.Alphabet.from_architecture("rna-esm")
        protein_vocab = Vocab.from_esm_alphabet(protein_alphabet)
        rna_vocab = Vocab.from_esm_alphabet(rna_alphabet)
        rna_map_dict = mapdict(protein_vocab, rna_vocab)
        rna_map_vocab = Vocab.from_esm_alphabet(rna_alphabet, rna_map_dict)
        model = ESM2(
            vocab=protein_vocab,
            model_config=TransformerConfig(),
            optimizer_config=OptimizerConfig(),
            contact_train_data=None,
            token_dropout=True,
        )
        print(f"Loading RNA-ESM2 model: {self.esm_ckpt}")
        model.load_state_dict(
            torch.load(self.esm_ckpt, map_location="cpu")[
                "state_dict"
            ],
            strict=True,
        )
        return model, rna_map_vocab, rna_alphabet
    
    def forward(self, data_seq_raw, set_max_len=80):
        self.model.eval()
        self.model.to(self.device)

        output = dict()
        for i, seq in enumerate(data_seq_raw):
            if "Y" in seq:
                data_seq_raw[i] = seq.replace("Y", "N")

        with torch.no_grad():
            tokens = torch.from_numpy(self.rna_map_vocab.encode(data_seq_raw))
            infer = self.model(
                tokens.to(self.device), repr_layers=[30], return_contacts=True
            )
            embedding = infer["representations"][30]
            attention = infer["attentions"]

            try:
                embedding = F.softmax(embedding, dim=-1)
            except:
                ValueError("Error in softmax")

            output["embedding"] = embedding
            output["attention"] = attention
            output["contacts"] = infer["contacts"]

        return output


In [2]:
import numpy as np
import pickle as pkl
base_path = "/home/fkli/RNAdata"
with open(base_path + '/RNAcmap2_231.pkl', 'rb') as f:
    true_labels_pdb = pkl.load(f)
data_nums = len(true_labels_pdb)
print(data_nums, true_labels_pdb[0])

model = RNAESM2(esm_path)
model.eval()

def pred_map_to_pair(pred_contact_map, seq_len):
    pair_list = [[i, j, pred_contact_map[i][j]] for i in range(seq_len) for j in range(i)]
    pair_list.sort(key=lambda x: x[2], reverse=True)
    return [[x[1], x[0]] for x in pair_list]
count = 0
save_all_bps = []
for data in true_labels_pdb:
    true_pairs = [i for i in data[3] if abs(i[0]-i[1]) > 3]  # non-local base-pairs
    # print(true_pairs)
    L = len(data[2])
    # print(f"count is {count}, seq is {data[2]}, seq len is {L}")
    with torch.no_grad ():
        p_contact = model([data[2]])['contacts'].cpu().float()
    p_contact = torch.Tensor.tolist(p_contact.squeeze())
    pair_list = pred_map_to_pair(p_contact, L)
    n = int(L/2)
    positive_pairs = pair_list[:]
    # print(positive_pairs)
    negative_pairs = pair_list[n:]
    tp = 0;fp = 0;fn = 0
    correct_pairs = []
    for i,I in enumerate(positive_pairs):
        if I in true_pairs:
            tp +=1
            correct_pairs.append(I)
        elif I not in true_pairs:
            fp += 1
            # print(tp)
    for i,I in enumerate(true_pairs):
        if I not in positive_pairs:
            fn += 1
    tn = L*L- tp - fp - fn

    try:
        pre = tp / (tp + fp)
        sen = tp / (tp + fn)
        f1 = 2*((pre*sen)/(pre + sen))
        #with np.errstate(invalid='ignore'):
        mcc = ((tp * tn) - (fp * fn)) / np.sqrt(np.float64((tp + fp) * (tp + fn) * (tn + fn) * (tn + fp)))
    except:
        pre = 0
        sen = 0
        f1 = 0
        mcc = 0; #print(k)
    save_all_bps.append([f1, pre, sen])
    count += 1
all_metrics = np.mean(save_all_bps, axis=0) 
print(all_metrics, count)

231 [0, '1wz2_C', 'GCGGGGGUUGCCGAGCCUGGUCAAAGGCGGGGGACUCAAGAUCCCCUCCCGUAGGGGUUCCGGGGUUCGAAUCCCCGCCCCCGCACCA', [[0, 83], [1, 82], [2, 81], [3, 80], [4, 79], [5, 78], [6, 77], [7, 13], [8, 12], [9, 27], [10, 26], [11, 25], [12, 24], [13, 23], [14, 21], [14, 23], [14, 59], [18, 66], [19, 67], [21, 59], [23, 58], [28, 46], [29, 45], [30, 44], [31, 43], [32, 42], [33, 41], [47, 56], [48, 55], [49, 54], [60, 76], [61, 75], [62, 74], [63, 73], [64, 72], [65, 69]]]


Vocab contains non-special token of length > 1: <null_1>


Loading RNA-ESM2 model: /home/fkli/RNAm/RNA-ESM2-trans-2a100-mappro-KDNY-epoch_06-valid_F1_0.564.ckpt
<generator object Module.parameters at 0x7f89e46dd820>
[0.02433111 0.01235184 1.        ] 231


In [57]:
predicted_map = [[1, 0, 1],  
                 [0.1, 1, 0], 
                 [0.7, 0.4, 1.9]]
def pred_map_to_pair(pred_contact_map, seq_len):
    pair_list = [[i, j, pred_contact_map[i][j]] for i in range(seq_len) for j in range(i)]
    pair_list.sort(key=lambda x: x[2], reverse=True)
    return [[x[0], x[1]] for x in pair_list]

pair_list = pred_map_to_pair(predicted_map, 3)

[[2, 0], [2, 1], [1, 0]]


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pickle

datasets_path = "/home/fkli/Projects/RNADiffFold/dataset/dataset.pkl"
with open(datasets_path, "rb") as f:
    dataset = pickle.load(f)

train_seq_len_list = [len(seq[2]) for seq in dataset["train_dataset"]]
train_seq_len_list.sort()
train_seq_num = len(train_seq_len_list)
print(f"train seq num: {train_seq_num}")
my_x_ticks = np.arange(0, max(train_seq_len_list)+80, 80)
plt.xticks(my_x_ticks)
plt.xlabel('seq length')
plt.ylabel('seq num')
plt.hist(train_seq_len_list, bins=70)
plt.show()

In [None]:
from data.data_generator import RNADataset, diff_collate_fn, get_data_id
from torch.utils.data import DataLoader
from functools import partial
from os.path import join

DATA_PATH = "/home/fkli/Projects/DiffRNA/datasets/batching"
train = RNADataset([join(DATA_PATH, "train")], upsampling=False)
print(len(train))
partial_collate_fn = partial(diff_collate_fn)

train_loader = DataLoader(
    train,
    batch_size=1,
    shuffle=True,
    num_workers=8,
    # collate_fn=partial_collate_fn,
    pin_memory=False,
    drop_last=True,
)
print(len(train_loader))

In [None]:
import os
import numpy as np
import pickle

def rna_evaluation(preds, targets):
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)
    tp = torch.sum(preds * targets)
    tn = torch.sum((1 - preds) * (1 - targets))
    fp = torch.sum(preds * (1 - targets))
    fn = torch.sum((1 - preds) * targets)
    accuracy = (tp + tn) / (tp + tn + fp + fn) # accuracy
    prec = tp / (tp + fp)  # precision
    recall = tp / (tp + fn)  # recall
    sens = tp / (tp + fn)  # senstivity
    spec = tn / (tn + fp)  # spec

    F1 = 2 * ((prec * sens) / (prec + sens))
    MCC = (tp * tn - fp * fn) / torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

    return accuracy, prec, recall, sens, spec, F1, MCC.cpu().item()
    
def padding_two(data_array, maxlen):
    a, b = data_array.shape
    # np.pad(array, ((before_1,after_1),……,(before_n,after_n),module)
    return np.pad(data_array, ((0, maxlen - a), (0, maxlen - b)), "constant")

d_path = "/home/fkli/RNAdata/RNAcmap2/datasets/test"
i = 0
FF = 0
PP = 0
model = RNAESM2(esm_path)
model.eval()
for file in os.listdir(d_path):
    i += 1
    file_name = os.path.join(d_path, file)

    with open(file_name, "rb") as f:
        dataset = pickle.load(f)
    str = [seq_data["seq_raw"] for seq_data in dataset]
    seq_max_len = max(len(seq) for seq in str)
    with torch.no_grad ():
        p_contact = model(str, seq_max_len)['contacts'].cpu().float()
    contact_map = [padding_two(seq_data["contact"], seq_max_len) for seq_data in dataset]
    test_no_train_tmp = [rna_evaluation( p_contact[i], contact_map[i]) for i in range(p_contact.shape[0])]
    accuracy, prec, recall, sens, spec, F1, MCC = zip(*test_no_train_tmp)
    precision = np.average(np.nan_to_num(np.array(prec)))
    F1 = np.average(np.nan_to_num(np.array(F1)))
    PP += precision
    FF += F1
    torch.cuda.empty_cache()
print(FF/i, PP/i)

In [1]:
import torch
from models.model import DiffusionRNA2dPrediction
from models.rna_model.config import TransformerConfig, OptimizerConfig, Config, DataConfig, ProduceConfig, TrainConfig, LoggingConfig
esm_path = "/home/fkli/RNAm/RNA-ESM2-trans-2a100-mappro-KDNY-epoch_06-valid_F1_0.564.ckpt"
model = DiffusionRNA2dPrediction(
    num_classes=2,
    diffusion_dim=8,
    cond_dim=8,
    diffusion_steps=20,
    dp_rate=0.0,
    u_ckpt="/home/fkli/RNAm/ufold_train_alldata.pt",
    esm_ckpt=esm_path,
)
ckpt_path = '/home/fkli/RNAm/best_checkpoint_15.pt'
checkpoint = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint["model"])
model.to("cuda:3")
print('load model from {}'.format(ckpt_path))

Vocab contains non-special token of length > 1: <null_1>


Loading RNA-ESM2 model: /home/fkli/RNAm/RNA-ESM2-trans-2a100-mappro-KDNY-epoch_06-valid_F1_0.564.ckpt
load model from /home/fkli/RNAm/best_checkpoint_15.pt


In [2]:
import os
import numpy as np
import pandas as pd
from common.data_utils import contact_map_masks
from torch.utils.data import DataLoader
from data.data_generator import RNADataset, diff_collate_fn
from functools import partial
from common.loss_utils import rna_evaluation
DATA_PATH = "/home/fkli/RNAdata/RNAcmap2/datasets"
partial_collate_fn = partial(diff_collate_fn)
test = RNADataset([os.path.join(DATA_PATH, "test")], upsampling=False)
test_loader = DataLoader(
    test,
    batch_size=1,
    shuffle=False,
    num_workers=8,
    collate_fn=partial_collate_fn,
    pin_memory=False,
    drop_last=False,
)
device = torch.device("cuda:3")
def model_test():
    model.eval()
    with torch.no_grad():
        test_no_train = list()
        total_name_list = list()
        total_length_list = list()

        for _, (contact, base_info, data_seq_raw, data_length, data_name, set_max_len, data_seq_encoding) in enumerate(test_loader):
            total_name_list += [item for item in data_name]
            total_length_list += [item.item() for item in data_length]

            base_info = base_info.to(device)
            matrix_rep = torch.zeros_like(contact)
            data_length = data_length.to(device)
            # data_seq_raw = data_seq_raw.to(device)
            data_seq_encoding = data_seq_encoding.to(device)
            contact_masks = contact_map_masks(data_length, matrix_rep).to(device)

            # calculate contact loss
            batch_size = contact.shape[0]
            pred_x0, _ = model.sample(batch_size, base_info, data_seq_raw, set_max_len, contact_masks, data_seq_encoding)

            pred_x0 = pred_x0.cpu().float()
            test_no_train_tmp = list(map(lambda i: rna_evaluation(
                pred_x0[i].squeeze(), contact.float()[i].squeeze()), range(pred_x0.shape[0])))
            test_no_train += test_no_train_tmp
            torch.cuda.empty_cache()

        accuracy, prec, recall, sens, spec, F1, MCC = zip(*test_no_train)

        f1_pre_rec_df = pd.DataFrame({'name': total_name_list,
                                        'length': total_length_list,
                                        'accuracy': list(np.array(accuracy)),
                                        'precision': list(np.array(prec)),
                                        'recall': list(np.array(recall)),
                                        'sensitivity': list(np.array(sens)),
                                        'specificity': list(np.array(spec)),
                                        'f1': list(np.array(F1)),
                                        'mcc': list(np.array(MCC))})

        accuracy = np.average(np.nan_to_num(np.array(accuracy)))
        precision = np.average(np.nan_to_num(np.array(prec)))
        recall = np.average(np.nan_to_num(np.array(recall)))
        sensitivity = np.average(np.nan_to_num(np.array(sens)))
        specificity = np.average(np.nan_to_num(np.array(spec)))
        F1 = np.average(np.nan_to_num(np.array(F1)))
        MCC = np.average(np.nan_to_num(np.array(MCC)))

        print('#' * 40)
        print('Average testing accuracy: ', round(accuracy, 3))
        print('Average testing F1 score: ', round(F1, 3))
        print('Average testing precision: ', round(precision, 3))
        print('Average testing recall: ', round(recall, 3))
        print('Average testing sensitivity: ', round(sensitivity, 3))
        print('Average testing specificity: ', round(specificity, 3))
        print('#' * 40)
        print('Average testing MCC', round(MCC, 3))
        print('#' * 40)
        print('')
    
    return {'f1': F1, 'precision': precision, 'recall': recall, 'sensitivity': sensitivity, 'specificity': specificity, 'accuracy': accuracy, 'mcc': MCC}, f1_pre_rec_df

In [4]:
model_test()
# f1_pre.to_csv(
#     os.path.join("/home/fkli/", f"test.csv"),
#     index=False,
#     header=False,
# )

sampling loop time step: 100%|██████████| 20/20 [00:00<00:00, 20.55it/s]
sampling loop time step: 100%|██████████| 20/20 [00:00<00:00, 34.60it/s]
sampling loop time step: 100%|██████████| 20/20 [00:01<00:00, 17.13it/s]
sampling loop time step: 100%|██████████| 20/20 [00:02<00:00,  9.62it/s]
sampling loop time step: 100%|██████████| 20/20 [00:00<00:00, 87.45it/s]
sampling loop time step: 100%|██████████| 20/20 [00:00<00:00, 65.67it/s]


########################################
Average testing accuracy:  0.996
Average testing F1 score:  0.387
Average testing precision:  0.386
Average testing recall:  0.41
Average testing sensitivity:  0.41
Average testing specificity:  0.998
########################################
Average testing MCC 0.39
########################################



({'f1': 0.38667193,
  'precision': 0.3858722,
  'recall': 0.41010335,
  'sensitivity': 0.41010335,
  'specificity': 0.99806064,
  'accuracy': 0.9962581,
  'mcc': 0.39037907342436556},
         name  length  accuracy  precision    recall  sensitivity  specificity  \
 0     5dcv_B      51  0.996094   0.400000  0.380952     0.380952     0.998119   
 1     6p2h_A      69  0.992969   0.250000  0.142857     0.142857     0.997643   
 2     2xdb_G      40  0.996562   0.214286  0.214286     0.214286     0.998277   
 3     4x0b_B      77  0.994687   0.470588  0.500000     0.500000     0.997173   
 4     6r47_A      50  0.996250   0.416667  0.227273     0.227273     0.998902   
 ..       ...     ...       ...        ...       ...          ...          ...   
 226  4v88_A4     158  0.998008   0.346154  0.209302     0.209302     0.999335   
 227   6n2v_A      99  0.997852   0.430769  0.608696     0.608696     0.998552   
 228   6vmy_A     148  0.997539   0.487179  0.306452     0.306452     0.999217