In [None]:
import pandas as pd
import rouskinhf as rhf
import numpy as np
import torch
%load_ext autoreload

bppm = pd.read_feather("test_results_PT+FT.feather").rename(columns={'structure': 'bppm'})
truth = pd.concat({
    dataset:pd.DataFrame.from_dict(rhf.get_dataset(dataset), orient='index') for dataset in ['PDB', 'viral_fragments','lncRNA', 'archiveII_blast']}
                  ).reset_index().rename(columns={'level_0':'dataset', 'structure': 'label'})[['dataset','sequence', 'label']]

df  = pd.merge(bppm, truth, on='sequence', how='outer')
df['bppm'] = df['bppm'].apply(lambda x: torch.tensor(np.stack(x)))
df.head()

In [None]:
from algos import *
from util import *

# Step to the next example each time this cell runs. A dedicated counter (rather
# than the generic `idx`, which a scoring loop may reassign) keeps the stepper
# independent of other cells; `globals().get` lets the first run on a fresh
# kernel start at 0 instead of raising NameError, and the modulo wraps around.
example_idx = (globals().get('example_idx', -1) + 1) % len(df)
example = df.iloc[example_idx]
example_label = bp_to_matrix(example['label'], len(example['sequence']))

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# NOTE(review): this threshold (0.5) differs from the 0.7 passed to
# HungarianAlgo in the scoring loop further down — confirm which is intended.
ha = HungarianAlgo(label=example_label).run(bppm=example['bppm'], threshold=0.5)
ha.plot_confusion_matrix(ax=axes[0, 0])

ufold = UFoldAlgo(label=example_label).run(bppm=example['bppm'], sequence=example['sequence'])
ufold.plot_confusion_matrix(ax=axes[0, 1])

ha.plot_bppm(ax=axes[1, 0])

ribo = RibonanzaNetAlgo(label=example_label).run(
    bppm=example['bppm'],
    prob_to_0_threshold_prior=float("-inf"),
    prob_to_1_threshold_prior=float("inf"),
    sigmoid_slope_factor=1,
    add_p_unpaired=False,
    theta=0.7,
)
ribo.plot_confusion_matrix(ax=axes[1, 1])

fig.suptitle(f"Example {example_idx} ({example['dataset']})")
fig.tight_layout()

In [456]:
# Compute F1 / precision / recall for every datapoint and every algorithm.
# A full run takes on the order of 10 minutes; set RECOMPUTE_SCORES = False to
# reload the previously written CSV instead.
import os
import tqdm.auto

SCORES_CSV = "algo_comparison.csv"
RECOMPUTE_SCORES = True
HUNGARIAN_THRESHOLD = 0.7
RIBONANZA_KWARGS = dict(
    prob_to_0_threshold_prior=float("-inf"),
    prob_to_1_threshold_prior=float("inf"),
    sigmoid_slope_factor=1,
    add_p_unpaired=False,
    theta=0.7,
)


def score_row(row):
    """Run the three algorithms on one datapoint and return one metrics record per algorithm.

    Keeping the per-row variables local means the notebook-level `bppm`
    DataFrame (and names like `idx`, `label`) are not overwritten by the loop.

    Args:
        row: a row of `df` with `dataset`, `reference`, `sequence`, `label`, `bppm`.

    Returns:
        A list of dicts with keys dataset, reference, algo, f1, precision, recall,
        in the order hungarian, ufold, ribonanza.
    """
    sequence, row_bppm = row['sequence'], row['bppm']
    label_matrix = bp_to_matrix(row['label'], len(sequence))
    results = {
        'hungarian': HungarianAlgo(label=label_matrix).run(bppm=row_bppm, threshold=HUNGARIAN_THRESHOLD),
        'ufold': UFoldAlgo(label=label_matrix).run(bppm=row_bppm, sequence=sequence),
        'ribonanza': RibonanzaNetAlgo(label=label_matrix).run(bppm=row_bppm, **RIBONANZA_KWARGS),
    }
    return [
        {
            'dataset': row['dataset'],
            # NOTE(review): 'reference' is not among the ground-truth columns kept
            # when building `truth`, so it presumably comes from the feather file — confirm.
            'reference': row['reference'],
            'algo': name,
            'f1': algo.f1,
            'precision': algo.precision,
            'recall': algo.recall,
        }
        for name, algo in results.items()
    ]


if RECOMPUTE_SCORES or not os.path.exists(SCORES_CSV):
    # Build all records first, then one DataFrame (no row-by-row growth).
    records = [
        record
        for _, row in tqdm.auto.tqdm(df.iterrows(), total=len(df))
        for record in score_row(row)
    ]
    out = pd.DataFrame(records)
    out.to_csv(SCORES_CSV, index=False)
else:
    out = pd.read_csv(SCORES_CSV)
    

100%|██████████| 802/802 [09:30<00:00,  1.40it/s]


In [467]:
import plotly.express as px

METRICS = ["f1", "precision", "recall"]

# Median of each metric per (dataset, algorithm); rendered as a rich table.
median_table = out.groupby(['dataset', 'algo'])[METRICS].median().reset_index()
display(median_table)

# Per-datapoint distributions, one figure per metric, algorithms side by side.
for metric in METRICS:
    fig = px.violin(out, y=metric, x="dataset", color="algo", box=True, title=metric)
    fig.show()

            dataset       algo        f1  precision    recall
0               PDB  hungarian  0.951600   1.000000  0.960000
1               PDB  ribonanza  0.955534   1.000000  1.000000
2               PDB      ufold  0.947368   1.000000  1.000000
3   archiveII_blast  hungarian  0.643875   0.752381  0.586582
4   archiveII_blast  ribonanza  0.626866   0.619048  0.641782
5   archiveII_blast      ufold  0.671498   0.729996  0.625000
6            lncRNA  hungarian  0.433333   0.498299  0.390135
7            lncRNA  ribonanza  0.426610   0.347676  0.551913
8            lncRNA      ufold  0.443114   0.467687  0.465882
9   viral_fragments  hungarian  0.663744   0.719444  0.616570
10  viral_fragments  ribonanza  0.605847   0.552778  0.659165
11  viral_fragments      ufold  0.690598   0.736332  0.658956
