In [58]:
import pandas as pd

import os
import torch
import numpy as np
from rouskinhf import get_dataset

import plotly.graph_objects as go


In [59]:
# Load the four benchmark test sets from the HuggingFace Hub and stack them into a
# single frame whose index is the reference name, tagging every row with its dataset.
TEST_SETS = ["PDB", "archiveII_blast", "lncRNA_nonFiltered", "viral_fragments"]

frames = []
for test_set in TEST_SETS:
    # force_download=True re-fetches every run (see the download log in the output).
    data = get_dataset(test_set, force_download=True)
    # `data` appears to map reference -> record with 'sequence'/'structure' fields;
    # transpose so that references become rows.
    data = pd.DataFrame(data).T[['sequence', 'structure']]
    data['dataset'] = test_set
    frames.append(data)

# Concatenate once instead of re-growing the frame on every loop iteration.
ground_truth = pd.concat(frames)
del data, frames

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII_blast: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

archiveII_blast: Download complete. File saved at data/archiveII_blast/data.json
lncRNA_nonFiltered: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA_nonFiltered: Download complete. File saved at data/lncRNA_nonFiltered/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json


In [60]:
def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    F1 score between a thresholded prediction and a binary target pairing matrix.

    The prediction is binarised with ``pred_matrix > threshold``; the score is
    2 * |pred AND target| / (|pred| + |target|).

    :param pred_matrix: Predicted pairing probability tensor (L, L).
    :param target_matrix: True binary pairing tensor (L, L).
    :param threshold: Probability above which a predicted entry counts as paired.
    :return: F1 score as a Python float; 1.0 when neither matrix contains a pair.
    """
    binary_pred = (pred_matrix > threshold).float()
    total_pairs = binary_pred.sum() + target_matrix.sum()

    # Nothing predicted and nothing expected: treat as perfect agreement.
    if total_pairs == 0:
        return 1.0

    overlap = (binary_pred * target_matrix).sum()
    return (2 * overlap / total_pairs).item()
    
def pairList2pairMatrix(pair_list, len_seq):
    """
    Build a symmetric 0/1 pairing tensor from a list of base pairs.

    :param pair_list: Sequence of (i, j) 0-based index pairs (array-like, shape (n, 2)).
    :param len_seq: Sequence length L.
    :return: float tensor (L, L) with 1.0 at both (i, j) and (j, i) for every pair.
    """
    pairs = np.array(pair_list).astype(int)
    matrix = torch.zeros((len_seq, len_seq))

    # Empty structure: leave the matrix all zeros.
    if len(pairs) == 0:
        return matrix

    rows, cols = pairs[:, 0], pairs[:, 1]
    matrix[rows, cols] = 1.0
    matrix[cols, rows] = 1.0
    return matrix

def basePairingmatrix2pairList(base_pairing_matrix, threshold=0.8):
    """
    Extract unique base pairs from a pairing-probability matrix.

    Every entry strictly above ``threshold`` is a candidate pair. Each pair is
    reordered to (low, high), so (i, j) and (j, i) collapse into one row.

    :param base_pairing_matrix: 2-D array-like of pairing scores.
    :param threshold: Score above which an entry counts as paired.
    :return: int array of shape (n, 2), rows sorted lexicographically, no duplicates.
    """
    hits = np.where(base_pairing_matrix > threshold)
    rows, cols = hits[0], hits[1]
    candidates = np.stack([rows, cols], axis=1)
    ordered = np.sort(candidates, axis=1)
    return np.unique(ordered, axis=0)

def pairList2dotbracket(pair_list, len_seq):
    """
    Render a pair list as a single-bracket-type dot-bracket string.

    Unpaired positions are '.', the lower index of each pair is '(' and the
    higher index is ')'. Only one bracket type is used, so crossing
    (pseudoknotted) pairs cannot be distinguished in the output.

    :param pair_list: Iterable of (i, j) 0-based index pairs, in either order.
    :param len_seq: Sequence length.
    :return: dot-bracket string of length ``len_seq``.
    """
    dot_bracket = ['.'] * len_seq
    for pair in pair_list:
        # Pairs may arrive as (i, j) or (j, i); the 5' partner must always open.
        opening, closing = min(pair[0], pair[1]), max(pair[0], pair[1])
        dot_bracket[opening] = '('
        dot_bracket[closing] = ')'
    return ''.join(dot_bracket)

def basePairingmatrix2dotbracket(base_pairing_matrix, threshold=0.8):
    """
    Threshold a pairing-probability matrix and render it as dot-bracket.

    :param base_pairing_matrix: 2-D array-like; its first dimension is the sequence length.
    :param threshold: Score above which an entry counts as paired.
    :return: dot-bracket string.
    """
    return pairList2dotbracket(
        basePairingmatrix2pairList(base_pairing_matrix, threshold),
        base_pairing_matrix.shape[0],
    )

def pairList2ctfile(base_pairs, sequence, filename):
    """
    Write a secondary structure to a CT (connectivity table) file.

    Header line: "<length> <basename of filename>". Then one line per nucleotide:
    index, base, previous index, next index (0 after the last nucleotide),
    1-based partner index (0 if unpaired), natural numbering.

    :param base_pairs: Iterable of (i, j) 0-based index pairs.
    :param sequence: Nucleotide sequence (one CT line per character).
    :param filename: Output path; its basename is used as the CT title.
    """
    n = len(sequence)
    # partner[k] is the 1-based index k pairs with, or 0 when k is unpaired.
    partner = np.zeros(n, dtype=int)
    base_pairs = np.array(base_pairs, dtype=int)

    if len(base_pairs) > 0:
        partner[base_pairs[:, 0]] = base_pairs[:, 1] + 1
        partner[base_pairs[:, 1]] = base_pairs[:, 0] + 1

    with open(filename, 'w') as f:
        f.write(f'{n} {os.path.basename(filename)}\n')
        for i in range(n):
            # The CT format uses 0 for "no next nucleotide" on the final line.
            next_index = i + 2 if i < n - 1 else 0
            f.write(f'{i+1} {sequence[i]} {i} {next_index} {partner[i]} {i+1}\n')


def pairList2dotbracketRNAstructure(pair_list, sequence):
    """
    Convert a pair list to dot-bracket by delegating to RNAstructure's `ct2dot`.

    The structure is written to a CT file, `ct2dot` converts it, and the last
    line of the resulting dot file is returned. Requires the `ct2dot`
    executable on PATH.

    :param pair_list: Iterable of (i, j) 0-based index pairs.
    :param sequence: Nucleotide sequence.
    :return: Last line of the dot file produced by ct2dot, stripped.
    :raises FileNotFoundError: if `ct2dot` is not installed.
    :raises subprocess.CalledProcessError: if `ct2dot` exits with a non-zero status.
    """
    import subprocess
    import tempfile

    # Use a private temp dir instead of fixed names in the cwd: no clobbering of
    # user files, no races between concurrent calls, and cleanup even on error.
    with tempfile.TemporaryDirectory() as workdir:
        ct_path = os.path.join(workdir, 'temp.ct')
        dot_path = os.path.join(workdir, 'temp.dot')
        pairList2ctfile(pair_list, sequence, ct_path)

        # Argument list (no shell); output silenced as before, but failures now raise.
        subprocess.run(
            ['ct2dot', ct_path, '0', dot_path],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )

        with open(dot_path, 'r') as f:
            dotbracket = f.readlines()[-1].strip()

    return dotbracket

# def dotbracket2mountain(dotbracket):

#     mountain = [0]
#     for i in range(len(dotbracket)):
#         if dotbracket[i] in '({[<': mountain.append(mountain[-1] + 1)
#         elif dotbracket[i] in ')}]>': mountain.append(mountain[-1] - 1)
#         else: mountain.append(mountain[-1])

#     return mountain[1:]

def pairList2mountain(pair_list, len_seq):
    """
    Mountain-plot profile of a structure.

    Position k's height is the number of pairs (i, j) with min(i, j) <= k <= max(i, j),
    i.e. both endpoints of a pair are counted as enclosed by it.

    :param pair_list: Iterable of (i, j) 0-based index pairs, in either order.
    :param len_seq: Sequence length.
    :return: int array of length ``len_seq``.
    """
    heights = np.zeros(len_seq, dtype=int)
    for pair in pair_list:
        lo, hi = min(pair), max(pair)
        heights[lo:hi + 1] += 1
    return heights

    

In [61]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Build `merged`: one row per (reference, model) with true and predicted pair lists,
# then print/plot the PAN transcript for every model (F1 on a window + mountain plots).

# Existing algorithms
prediction = pd.read_feather('../Figure1/saved_data_plot/results_benchmark_algos.feather')
merged = ground_truth.reset_index().rename(columns={'index': 'reference'}).merge(prediction, on=['reference', 'dataset'], suffixes=('_true', '_pred'))
# Stack each stored prediction into a 2-D array (presumably (n_pairs, 2) -- confirm);
# empty predictions become an empty array.
merged['structure_pred'] = merged['structure_pred'].apply(lambda x: np.vstack(x) if len(x) > 0 else np.array([]))

# eFold
prediction = pd.read_feather('saved_data_plot/results_main_old/royal-snowball-97_epoch[16, 17, 18, 19, 20]_avg_PT+FT.feather')
merged_eFold = ground_truth.reset_index().rename(columns={'index': 'reference'}).merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))
# NOTE(review): `dataset` here comes from ground_truth, whose labels are 'PDB',
# 'archiveII_blast', 'lncRNA_nonFiltered' and 'viral_fragments' -- none equals
# 'lncRNA', so this filter appears to drop nothing. Confirm the intent before
# "fixing" it: PAN (used below) may live in the lncRNA set.
merged_eFold = merged_eFold[merged_eFold['dataset']!= 'lncRNA']

# eFold stores a pairing matrix per RNA: stack its rows, then threshold it
# (basePairingmatrix2pairList default 0.8) into a pair list.
merged_eFold['structure_pred'] = merged_eFold['structure_pred'].apply(np.vstack)
merged_eFold['structure_pred'] = merged_eFold['structure_pred'].apply(basePairingmatrix2pairList)
merged_eFold['model'] = 'eFold'

# Drop the per-RNA metric columns carried by the benchmark file before stacking eFold rows.
merged = pd.concat([merged.drop(columns=['Precision', 'F1', 'Recall', 'length']), merged_eFold])


models = merged.model.unique()
fig = make_subplots(rows=len(models), cols=1, shared_xaxes=True, vertical_spacing=0.02, subplot_titles=models)

# 0-based, end-exclusive window of PAN used for the zoomed F1 score.
itvl = 889, 974
print(merged[(merged.reference == 'PAN')].sequence.iloc[0])#[interval[0]:interval[1]])
for i, model in enumerate(models):

    rna = merged[(merged.reference == 'PAN') & (merged.model == model)].iloc[0]

    print('>'+model)
    # F1 restricted to the itvl x itvl sub-matrix.
    # NOTE(review): a later cell redefines compute_f1(label, pred); re-running this
    # cell after that one silently switches to the other definition.
    print(compute_f1(pairList2pairMatrix(rna['structure_pred'], len(rna['sequence']))[itvl[0]:itvl[1], itvl[0]:itvl[1]], 
                     pairList2pairMatrix(rna['structure_true'], len(rna['sequence']))[itvl[0]:itvl[1], itvl[0]:itvl[1]]))
    print(rna['sequence'])
    # print(pairList2dotbracketRNAstructure(rna['structure_true'], rna['sequence'])[interval[0]:interval[1]])
    # Full-length predicted dot-bracket via RNAstructure's ct2dot (must be on PATH).
    print(pairList2dotbracketRNAstructure(rna['structure_pred'], rna['sequence']))#[interval[0]:interval[1]])
    # print('---------')

    # One subplot row per model: true vs predicted mountain profile over the full sequence.
    fig.add_trace(go.Scatter(y=pairList2mountain(rna['structure_true'], len(rna['sequence'])), mode='lines', name='True'), row=i+1, col=1)
    fig.add_trace(go.Scatter(y=pairList2mountain(rna['structure_pred'], len(rna['sequence'])), mode='lines', name='Pred'), row=i+1, col=1)

    # # plot 
    # pred = pairList2pairMatrix(rna['structure_pred'], len(rna['sequence']))[itvl[0]:itvl[1], itvl[0]:itvl[1]]
    # label = pairList2pairMatrix(rna['structure_true'], len(rna['sequence']))[itvl[0]:itvl[1], itvl[0]:itvl[1]]
    # confusion_matrix = compute_confusion_matrix(label, pred)
    # fig_ = plot_confusion_matrix("Test", confusion_matrix, label, pred)

    # plt.savefig(f'{model}_confusion_matrix.pdf', bbox_inches='tight')

fig.update_layout(height=1300, width=1000, title_text="RNA structure prediction")
fig.show()


ACUGGGACUGCCCAGUCACCUUGGCUGCCGCUUCACCUAUGGAUUUUGUGCUCGCUGCUUGCCUUCUUGCCGCUUCUGGUUUUCAUUGGUGCCGCCGAUUGUGGGUUGAUUGCGUCGCUUUUGGCAAUAUACCCAUCCUGGCUUUCGGCUAGGUUUUCCGUCCUACUUUUCCCACAUUGGCCUGAGAGCUGUAGUACAAAAAACACCGCGCGGUCUGGAGCUCUCCAUAAGCCCGCAGAACAAAAGCUGCGAUUUGCCCAAAAACCUUGCCAUGGCAACUAUACAGUCACCCCUUGCGGGUUAUUGCAUUGGAUUCAAUCUCCAGGCCAGUUGUAGCCCCCUUUUAUGAUAUGCGAGGAUACUUAACGUGUCUGAAUGUGGAAUAUAAUGUGAAAGGAAAGCAGCGCCCACUGGUGUAUCAGAACAGUGGUGCACUACCUAUCUGCUCAUUCGUUGUUUCGGUUCUGUGUUUGUCUGAUUCUUAGAUAGUGUUGAGGUAAUUCUAGAAAGCGGAUUGAGUGUAAAUCGGGCCACUUUGCCCUAAAUGUGACAAUCUGGAUGUGUAUCUUAUUGGUGCGUUGUGAAGCAUUUUAAAAUGCGUUUUAGAUUGUAUCAGGCUAGUGCUGUAAUGGUGUGUUUAUUUUUCCAGUGUAAGCAAGUCGAUUUGAAUGACAUAGGCGACAAAGUGAGGUGGCAUUUGUCAGAAGUUUCAAAGUCGUGUAAGAACAUUGGACUAAAGUGGUGUGCGGCAGCUGGGAGCGCUCUUUCAAUGUUAAUGUUUUAAUGUGUAUGUUGUGUUGGAAGUUCCAGGCUAAUAUUUGAUGUUUUGCUAGGUUGACUAACGAUGUUUUCUUGUAGGUGAAAGCGUUGUGUAACAAUGAUAACGGUGUUUUGGCUGGGUUUUUCCUUGUUCGCACCGGACACCUCCAGUGACCAGACGGCAAGGUUUUUAUCCCAGUGUAUAUUGGAAAAACAUGUUAUACUUUUGACAAUUUAACGU

In [215]:
def compute_precision(label, pred):
    """
    Precision TP / (TP + FP) for 0/1 tensors of matching shape.

    :param label: Ground-truth 0/1 tensor.
    :param pred: Predicted 0/1 tensor.
    :return: 0-dim tensor; 0/0 (no predicted positives) yields nan.
    """
    hits = (label * pred).sum()
    spurious = ((1 - label) * pred).sum()
    return hits / (hits + spurious)

def compute_recall(label, pred):
    """
    Recall TP / (TP + FN) for 0/1 tensors of matching shape.

    :param label: Ground-truth 0/1 tensor.
    :param pred: Predicted 0/1 tensor.
    :return: 0-dim tensor; 0/0 (no true positives to find) yields nan.
    """
    hits = (label * pred).sum()
    missed = (label * (1 - pred)).sum()
    return hits / (hits + missed)

def compute_f1(label, pred):
    """
    F1 score between 0/1 tensors ``label`` and ``pred`` of matching shape.

    Computed as 2*TP / (|label| + |pred|). This equals 2PR / (P + R) whenever
    that is defined, but avoids the nan that precision/recall produce when either
    tensor has no positives (it returns 0.0 instead). It returns 1.0 when both
    are empty, the same convention as the threshold-based compute_f1 defined
    earlier in this notebook, which this definition shadows.

    :param label: Ground-truth 0/1 tensor.
    :param pred: Predicted 0/1 tensor.
    :return: 0-dim tensor.
    """
    true_positives = (label * pred).sum()
    denominator = label.sum() + pred.sum()
    if denominator == 0:
        # Nothing predicted and nothing to find: perfect agreement.
        return torch.tensor(1.0)
    return 2 * true_positives / denominator

def compute_confusion_matrix(label, pred):
    """
    Encode each cell's outcome as an integer-valued code.

    Codes: 0 = true negative, 1 = true positive, 2 = false positive,
    3 = false negative. The result has the same dtype as the inputs.

    :param label: Ground-truth 0/1 tensor.
    :param pred: Predicted 0/1 tensor of the same shape.
    :return: tensor of codes, same shape as the inputs.
    :raises AssertionError: if the TN cells and the code-0 cells disagree
        (e.g. inputs that are not 0/1).
    """
    not_label = 1 - label
    not_pred = 1 - pred

    tp = label * pred
    fp = not_label * pred
    fn = label * not_pred
    tn = not_label * not_pred

    codes = tp + 2 * fp + 3 * fn
    # Sanity check: a cell should be a true negative exactly when its code is 0.
    assert ((tn == 1) == (codes == 0)).all(), "True negatives are not correctly computed"
    return codes

In [152]:
# One zoomed confusion-matrix image of the PAN window per model.
models = merged.model.unique()

# RGB colour per code returned by compute_confusion_matrix:
# 0 = TN (black), 1 = TP (white), 2 = FP (red), 3 = FN (blue).
colors = np.array([[0, 0, 0],
                   [255, 255, 255],
                   [255, 0, 0],
                   [0, 0, 255]])

# 0-based, end-exclusive window of PAN to display.
itvl = 889, 974
print(merged[(merged.reference == 'PAN')].sequence.iloc[0])

# fig.write_image fails when the target directory does not exist.
os.makedirs('images/b', exist_ok=True)

for i, model in enumerate(models):

    rna = merged[(merged.reference == 'PAN') & (merged.model == model)].iloc[0]

    pred = pairList2pairMatrix(rna['structure_pred'], len(rna['sequence']))[itvl[0]:itvl[1], itvl[0]:itvl[1]]
    label = pairList2pairMatrix(rna['structure_true'], len(rna['sequence']))[itvl[0]:itvl[1], itvl[0]:itvl[1]]

    confusion_matrix = compute_confusion_matrix(label, pred).type(torch.int)

    fig = go.Figure()
    # Index the palette with the integer codes to obtain an RGB image.
    fig.add_trace(go.Image(z=colors[confusion_matrix]))

    fig.update_layout(height=500, width=500, title={
                                                    'text': f'{model} (F1: {compute_f1(pred, label):.2f})',
                                                    'y':0.95,
                                                    'x':0.5,
                                                    'xanchor': 'center',
                                                    'yanchor': 'top',
                                                    'font': dict(size=25, family='Times New Roman', color='black'),
                                                }, margin=dict(l=0, r=0, t=50, b=20),)
    # NOTE(review): pixel 0 is 1-based position itvl[0] + 1 = 890, yet the label
    # 890 is placed at pixel 2 -- confirm this 2-position offset is intended.
    fig.update_xaxes(tickmode="array", ticktext=np.arange(890, 980, 10), tickvals=np.arange(2, 90, 10), tickangle=-45)
    fig.update_yaxes(tickmode="array", ticktext=np.arange(890, 980, 10), tickvals=np.arange(2, 90, 10))
    fig.show()
    fig.write_image(f'images/b/{model}_confusion_matrix.pdf', width=500, height=500, scale=2)

ACUGGGACUGCCCAGUCACCUUGGCUGCCGCUUCACCUAUGGAUUUUGUGCUCGCUGCUUGCCUUCUUGCCGCUUCUGGUUUUCAUUGGUGCCGCCGAUUGUGGGUUGAUUGCGUCGCUUUUGGCAAUAUACCCAUCCUGGCUUUCGGCUAGGUUUUCCGUCCUACUUUUCCCACAUUGGCCUGAGAGCUGUAGUACAAAAAACACCGCGCGGUCUGGAGCUCUCCAUAAGCCCGCAGAACAAAAGCUGCGAUUUGCCCAAAAACCUUGCCAUGGCAACUAUACAGUCACCCCUUGCGGGUUAUUGCAUUGGAUUCAAUCUCCAGGCCAGUUGUAGCCCCCUUUUAUGAUAUGCGAGGAUACUUAACGUGUCUGAAUGUGGAAUAUAAUGUGAAAGGAAAGCAGCGCCCACUGGUGUAUCAGAACAGUGGUGCACUACCUAUCUGCUCAUUCGUUGUUUCGGUUCUGUGUUUGUCUGAUUCUUAGAUAGUGUUGAGGUAAUUCUAGAAAGCGGAUUGAGUGUAAAUCGGGCCACUUUGCCCUAAAUGUGACAAUCUGGAUGUGUAUCUUAUUGGUGCGUUGUGAAGCAUUUUAAAAUGCGUUUUAGAUUGUAUCAGGCUAGUGCUGUAAUGGUGUGUUUAUUUUUCCAGUGUAAGCAAGUCGAUUUGAAUGACAUAGGCGACAAAGUGAGGUGGCAUUUGUCAGAAGUUUCAAAGUCGUGUAAGAACAUUGGACUAAAGUGGUGUGCGGCAGCUGGGAGCGCUCUUUCAAUGUUAAUGUUUUAAUGUGUAUGUUGUGUUGGAAGUUCCAGGCUAAUAUUUGAUGUUUUGCUAGGUUGACUAACGAUGUUUUCUUGUAGGUGAAAGCGUUGUGUAACAAUGAUAACGGUGUUUUGGCUGGGUUUUUCCUUGUUCGCACCGGACACCUCCAGUGACCAGACGGCAAGGUUUUUAUCCCAGUGUAUAUUGGAAAAACAUGUUAUACUUUUGACAAUUUAACGU

In [214]:
# Colour legend for the confusion-matrix images: a 1x4 strip of swatches labelled
# TP, TN, FN, FP, reusing the `colors` palette from the confusion-matrix cell.
fig = go.Figure()

fig.add_trace(go.Image(z=colors[np.array([[1, 0, 3, 2]])]))

fig.update_layout(height=500, width=500, margin=dict(l=0, r=0, t=0, b=0),)
fig.update_xaxes(tickmode="array", ticktext=['TP', 'TN', 'FN', 'FP'], tickvals=[0, 1, 2, 3], tickfont=dict(size=40), title_font=dict(size=36))
fig.update_yaxes(ticktext=None, showticklabels=False)

# Black border around the whole strip (image pixels span x in [-0.5, 3.5]).
fig.add_shape(type="rect", x0=-0.5, x1=3.5, y0=-0.5, y1=0.5, line=dict(color="Black", width=2))


fig.show()
# fig.write_image fails when the target directory does not exist.
os.makedirs('images/b', exist_ok=True)
fig.write_image('images/b/legend_confusion_matrix.pdf')
