# Average models over different epochs

In [1]:
import os
import torch

path_models = "saved_models/models_ablation"
# Epochs whose checkpoints are averaged into a single set of weights.
epochs = [16,17,18,19,20]

# Collect the run names. Checkpoints are loaded below as "<run>_epoch<N>.pt", so the
# text before the first underscore identifies the run. Directory entries whose name
# does not contain "_epoch" are skipped so they are not mistaken for a run.
models = set(
    f.split('_')[0]
    for f in os.listdir(path_models)
    if '_epoch' in f and os.path.isfile(os.path.join(path_models, f))
)
print(models)

for model in models:

    # Load every checkpoint exactly once, mapped to CPU.
    # NOTE(review): torch.load unpickles the file; only use it on checkpoints you trust.
    weights = [
        torch.load(os.path.join(path_models, model + f"_epoch{epoch}" + ".pt"), map_location=torch.device('cpu'))
        for epoch in epochs
    ]

    # The last checkpoint (epoch 20) doubles as the container for the averaged values,
    # so the saved object keeps its type and key order.
    weights_final = weights[-1]
    for key in weights_final.keys():
        # The sum is built out-of-place and only then stored: weights_final aliases
        # weights[-1], so accumulating in place would corrupt one of the terms.
        weights_final[key] = sum(w[key] for w in weights) / float(len(weights))

    # The list itself is interpolated into the name, i.e. the file is saved as
    # "<run>_epoch[16, 17, 18, 19, 20]_avg.pt".
    torch.save(weights_final, os.path.join(path_models, model + f"_epoch{epochs}_avg.pt"))


{'worldly-river-105', 'robust-water-107', 'royal-snowball-97', 'efficient-dream-82', 'morning-sunset-80', 'pious-dew-104', 'firm-grass-79', 'glamorous-hill-96', 'desert-paper-83', 'laced-water-90', 'logical-darkness-92', 'serene-flower-81'}


## Show distribution of structure prediction performance

Reuse the same format as the `algo_benchmark` plot in Figure 1.

Add our algorithm, which should hopefully show a smaller performance gap.

**Assigned to**: Alberic

Use Plotly with a white background.

In [2]:
import pandas as pd
from rouskinhf import get_dataset
import torch
import numpy as np
import os

In [3]:
def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    F1 score between a predicted pairing matrix and its reference.

    The prediction is binarised by keeping the entries strictly greater than
    ``threshold``; the score is ``2 * overlap / (sum(pred) + sum(target))``.

    :param pred_matrix: Predicted pairing matrix probability  (L,L)
    :param target_matrix: True binary pairing matrix (L,L)
    :param threshold: Cut-off above which a predicted entry counts as a pair
    :return: F1 score for this RNA structure (1.0 when both matrices are empty)
    """

    binary_pred = (pred_matrix > threshold).float()

    # Denominator of the F1 formula: number of predicted pairs + number of true pairs.
    denominator = binary_pred.sum() + target_matrix.sum()

    if denominator != 0:
        overlap = (binary_pred * target_matrix).sum()
        return (2 * overlap / denominator).item()

    # Nothing predicted and nothing to find: counted as a perfect prediction.
    return 1.0
    
def pairList2pairMatrix(pair_list, len_seq):
    """
    Turn a list of index pairs into a symmetric (len_seq, len_seq) matrix.

    :param pair_list: Sequence of (i, j) index pairs; may be empty
    :param len_seq: Side length of the square matrix to build
    :return: Float tensor with 1.0 at [i, j] and [j, i] for every pair, 0.0 elsewhere
    """
    pairs = np.array(pair_list).astype(int)
    matrix = torch.zeros((len_seq, len_seq))

    # No pairs: the all-zero matrix is the answer.
    if len(pairs) == 0:
        return matrix

    first, second = pairs[:,0], pairs[:,1]
    # Write both orientations so the matrix is symmetric.
    matrix[first, second] = 1.0
    matrix[second, first] = 1.0

    return matrix

In [4]:
# Build one ground-truth table (sequence, structure, dataset) covering every test set.
# Each set is turned into its own frame and all frames are concatenated once at the
# end, instead of growing a DataFrame inside the loop.
ground_truth_parts = []

for test_set in ["PDB", "archiveII_blast", "viral_fragments", "lncRNA_nonFiltered"]:
    # force_download=True asks for a fresh download on every run.
    data = get_dataset(test_set, force_download=True)
    # The raw object is transposed so that 'sequence' and 'structure' become columns;
    # .assign returns a new frame, so no column is set on a slice.
    ground_truth_parts.append(
        pd.DataFrame(data).T[['sequence', 'structure']].assign(dataset=test_set)
    )
    del data

ground_truth = pd.concat(ground_truth_parts)

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII_blast: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

archiveII_blast: Download complete. File saved at data/archiveII_blast/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json
lncRNA_nonFiltered: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA_nonFiltered: Download complete. File saved at data/lncRNA_nonFiltered/data.json


In [5]:
# Score every saved prediction table against the ground truth and gather the
# per-structure F1 values of all models in `data_comparison`.
path_data = 'saved_data_plot/results_main'

# Raw test-set names -> labels shown on the figure.
dataset_labels = {
    'viral_fragments': 'Viral mRNA',
    'lncRNA_nonFiltered': 'Long ncRNA',
    'archiveII_blast': 'ArchiveII',
}
comparison_columns = ['reference', 'sequence', 'model', 'dataset', 'f1', 'structure_true', 'structure_pred']

# Loop-invariant: expose the ground-truth index as a 'reference' column so it can
# serve as a merge key together with 'sequence'.
ground_truth_keyed = ground_truth.reset_index().rename(columns={'index': 'reference'})

per_model_results = []

for file_name in os.listdir(path_data):

    # Only feather prediction tables can be read below; skip anything else in the folder.
    if not file_name.endswith('.feather'):
        continue

    prediction = pd.read_feather(os.path.join(path_data, file_name))

    # Columns present on both sides (here: 'structure') get the _true / _pred suffixes.
    merged = ground_truth_keyed.merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))
    # NOTE(review): 'dataset' comes from the ground truth, whose labels are the test-set
    # names loaded earlier; none of them is exactly 'lncRNA', so this filter presumably
    # drops nothing here — confirm whether 'lncRNA_nonFiltered' was meant.
    merged = merged[merged['dataset']!= 'lncRNA']

    f1s = [
        compute_f1(torch.tensor(np.stack(row['structure_pred'])),
                   pairList2pairMatrix(row['structure_true'], len(row['sequence'])), threshold=0.5)
        for _, row in merged.iterrows()
    ]

    # The model name is the last underscore-separated token of the file stem.
    model = file_name.split('.feather')[0].split('_')[-1]
    print(model)
    # .assign returns a new frame, so nothing is written onto the filtered slice.
    merged = merged.assign(f1=f1s, model=model)
    print(merged.groupby('dataset')['f1'].mean())

    per_model_results.append(merged[comparison_columns])

# Concatenate once; fall back to an empty, correctly-shaped frame when no table was found.
if per_model_results:
    data_comparison = pd.concat(per_model_results)
else:
    data_comparison = pd.DataFrame(columns=comparison_columns)

for raw_name, label in dataset_labels.items():
    data_comparison.loc[data_comparison['dataset']==raw_name, 'dataset'] = label

ribonanza
dataset
PDB                   0.849423
archiveII_blast       0.492899
lncRNA_nonFiltered    0.298363
viral_fragments       0.640913
Name: f1, dtype: float64
RNAstralign
dataset
PDB                   0.728857
archiveII_blast       0.664344
lncRNA_nonFiltered    0.019084
viral_fragments       0.148102
Name: f1, dtype: float64
bpRNA
dataset
PDB                   0.816722
archiveII_blast       0.647147
lncRNA_nonFiltered    0.183122
viral_fragments       0.332785
Name: f1, dtype: float64
PT-UFold
dataset
PDB                   0.843829
archiveII_blast       0.627999
lncRNA_nonFiltered    0.301590
viral_fragments       0.644054
Name: f1, dtype: float64
primiRNA
dataset
PDB                   0.900898
archiveII_blast       0.666530
lncRNA_nonFiltered    0.407392
viral_fragments       0.738163
Name: f1, dtype: float64
PT+FT
dataset
PDB                   0.889730
archiveII_blast       0.635189
lncRNA_nonFiltered    0.451463
viral_fragments       0.729952
Name: f1, dtype: float64
RNAcen

In [6]:
# Load the benchmark results of the other algorithms and give the test sets the
# labels used on the figure.
result_algos = pd.read_feather('../Figure1/saved_data_plot/results_benchmark_algos.feather')
# .copy() makes the filtered table an independent frame, so the relabelling below
# does not write onto a slice of the frame that was read from disk.
result_algos = result_algos[result_algos['dataset']!= 'lncRNA'].copy()

for raw_name, label in {
    'viral_fragments': 'Viral mRNA',
    'lncRNA_nonFiltered': 'Long ncRNA',
    'archiveII_blast': 'ArchiveII',
}.items():
    result_algos.loc[result_algos['dataset']==raw_name, 'dataset'] = label

In [7]:
# Grouped violin plot (Plotly) of the per-structure F1 scores, one violin per model and test set
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
fig = make_subplots(rows=1, cols=1, shared_xaxes=True)

colors = ['#444444', '#888888', "rgb(148, 103, 189)", '#FF3333']

# One entry per violin group: (source table, score column, model name).
# The order fixes both the trace order and the color taken from `colors`.
violin_specs = [
    (result_algos, 'F1', 'UFold'),
    (result_algos, 'F1', 'MXFold2'),
    (data_comparison, 'f1', 'PT'),
    (data_comparison, 'f1', 'PT+FT'),
]

for (source, score_column, model), color in zip(violin_specs, colors):
    results_dataset = source[source['model']==model]
    violin = go.Violin(
        x=results_dataset['dataset'],
        y=results_dataset[score_column],
        name=model,
        marker_color=color,
        meanline_visible=True,
        points=False,
    )
    fig.add_trace(violin)

# Axes, grouping, size, theme and a horizontal legend placed above the plot area.
fig.update_layout(
    yaxis_title='F1 score',
    xaxis_title='',
    violinmode='group',
    yaxis_range=[0, 1],
    template='plotly_white',
    font_size=18,
    font_color='black',
    font_family='times new roman',
    legend_orientation="h",
    legend=dict(x=0.65, y=1.15),
    width=1000,
    height=400,
)
fig.show()

In [10]:
# Save the figure as a PDF. The output folder is created first so the export does
# not fail on a fresh checkout where "images/" does not exist yet.
os.makedirs("images", exist_ok=True)
fig.write_image("images/a_structure_performance.pdf")