## Ablation study of our final best model, on structure prediction

Show the performance difference (ΔF1, the change in F1 score) of different versions of our model relative to our best model (the baseline, ΔF1 = 0).

- Is pretraining useful? (Ribo, ArchiveII, bpRNA, RNAstralign, zuber, synthetic)
- Is finetuning useful? (UTR, pri-miRNA, human_mRNA fragments?)
- Is training on everything at once better?
- Is it useful to have a SHAPE/DMS head?
- Pearson loss vs. MSE for SHAPE/DMS
- model architecture (Evoformer, Transformer, CNN)

The overall test score could be a weighted average over the three standard test sets (PDB, viral_fragments, lncRNA), provided there are no outliers.


**Assigned to**: Alberic

Use Plotly, with a white background.

In [2]:
import pandas as pd
from rouskinhf import get_dataset
import torch
import numpy as np
import os

def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    F1 score between a predicted and a reference base-pairing matrix.

    The prediction is binarised with a strict ``> threshold`` test, then
    F1 = 2 * |pred AND target| / (|pred| + |target|).
    If neither matrix contains any pair, the structures agree trivially and 1.0 is returned.

    :param pred_matrix: Predicted pairing matrix probability, torch tensor (L,L)
    :param target_matrix: True binary pairing matrix, torch tensor (L,L)
    :param threshold: Probability strictly above which a pair counts as predicted
    :return: F1 score for this RNA structure, as a Python float
    """
    binary_pred = (pred_matrix > threshold).float()
    denominator = binary_pred.sum() + target_matrix.sum()

    # Empty prediction and empty target: perfect agreement by convention.
    if denominator == 0:
        return 1.0

    true_positives = (binary_pred * target_matrix).sum()
    return (2 * true_positives / denominator).item()
    
def pairList2pairMatrix(pair_list, len_seq):
    """
    Convert a list of base pairs into a symmetric (len_seq, len_seq) pairing matrix.

    :param pair_list: Sequence of (i, j) pairs. Values are cast to int and used
        directly as matrix indices, so they must be valid positions in [0, len_seq).
    :param len_seq: Side length of the square output matrix (the sequence length)
    :return: float torch tensor with 1.0 at (i, j) and (j, i) for every pair, 0.0 elsewhere
    """
    pairs = np.array(pair_list).astype(int)
    matrix = torch.zeros((len_seq, len_seq))

    # No pairs: an all-zero matrix (also avoids 2-D slicing on a 1-D empty array).
    if len(pairs) == 0:
        return matrix

    rows, cols = pairs[:, 0], pairs[:, 1]
    matrix[rows, cols] = 1.0
    matrix[cols, rows] = 1.0
    return matrix

In [3]:
# Ground-truth structures of every test set, stacked into a single frame with
# columns: sequence, structure, dataset (name of the source test set).
# The index holds the downloaded dataset's row keys (used later as `reference`).
TEST_SETS = ["PDB", "archiveII_blast", "lncRNA", "viral_fragments"]


def _load_test_set(name):
    """Download one test set and return its sequence/structure columns tagged with the set name."""
    # force_download=True re-fetches the data from the HuggingFace Hub on every run.
    records = pd.DataFrame(get_dataset(name, force_download=True)).T
    return records[['sequence', 'structure']].assign(dataset=name)


# Concatenate once instead of growing a frame inside the loop (quadratic copying,
# plus a pandas FutureWarning when concatenating onto an empty DataFrame).
ground_truth = pd.concat([_load_test_set(name) for name in TEST_SETS])

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
archiveII_blast: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

archiveII_blast: Download complete. File saved at data/archiveII_blast/data.json
lncRNA: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA: Download complete. File saved at data/lncRNA/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json


In [4]:
# Per-structure F1 score of every saved model's predictions against the ground truth.
path_data = '../Figure5/saved_data_plot/results_main'

# Display labels used in the figure. Model keys are the part of the file name
# after the last '_' (with the '.feather' extension removed).
DATASET_LABELS = {
    'viral_fragments': 'viral mRNA',
    'lncRNA': 'long ncRNA',
    'archiveII_blast': 'archiveII',
}
MODEL_LABELS = {
    'PT+FT': 'baseline<br> (PT+FT)',
    'PT': 'no finetuning<br> (PT)',
    'PT-UFold': 'no finetuning, with <br> UFold architecture',
    'PT+FT-mRNA': 'finetune on mRNA',
    'PT+FT-primiRNA': 'finetune on pri-miRNA',
}

# The merge key 'reference' is the ground-truth index; build this view once, not per model.
ground_truth_ref = ground_truth.reset_index().rename(columns={'index': 'reference'})

per_model_scores = []
# Sorted for a reproducible processing/print order (os.listdir order is arbitrary);
# non-feather files (e.g. .DS_Store) are skipped instead of crashing read_feather.
for file_name in sorted(os.listdir(path_data)):
    if not file_name.endswith('.feather'):
        continue

    prediction = pd.read_feather(os.path.join(path_data, file_name))
    # Inner join: only structures present in both the ground truth and the predictions are scored.
    merged = ground_truth_ref.merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))

    model = file_name.split('.feather')[0].split('_')[-1]
    print(model)

    merged['f1'] = [
        compute_f1(torch.tensor(np.stack(pred)),
                   pairList2pairMatrix(true_pairs, len(seq)), threshold=0.5)
        for pred, true_pairs, seq in zip(merged['structure_pred'], merged['structure_true'], merged['sequence'])
    ]
    merged['model'] = model
    print(merged.groupby('dataset')['f1'].mean())

    per_model_scores.append(merged[['reference', 'sequence', 'model', 'dataset', 'f1']])

if not per_model_scores:
    raise FileNotFoundError(f"No .feather prediction files found in {path_data}")

data_comparison = pd.concat(per_model_scores)
data_comparison['dataset'] = data_comparison['dataset'].replace(DATASET_LABELS)
data_comparison['model'] = data_comparison['model'].replace(MODEL_LABELS)

PT+FT-primiRNA
dataset
PDB                0.900898
archiveII_blast    0.666530
lncRNA             0.399686
viral_fragments    0.696693
Name: f1, dtype: float64
ribonanza
dataset
PDB                0.849423
archiveII_blast    0.492899
lncRNA             0.295483
viral_fragments    0.605082
Name: f1, dtype: float64
RNAstralign
dataset
PDB                0.728857
archiveII_blast    0.664344
lncRNA             0.019434
viral_fragments    0.124726
Name: f1, dtype: float64
bpRNA
dataset
PDB                0.816722
archiveII_blast    0.647147
lncRNA             0.192751
viral_fragments    0.302771
Name: f1, dtype: float64
PT-UFold
dataset
PDB                0.843829
archiveII_blast    0.627999
lncRNA             0.262879
viral_fragments    0.609994
Name: f1, dtype: float64
PT+FT-mRNA
dataset
PDB                0.890331
archiveII_blast    0.611764
lncRNA             0.423370
viral_fragments    0.660925
Name: f1, dtype: float64
PT+FT
dataset
PDB                0.889730
archiveII_blast    0.6351

In [6]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Figure rows (top to bottom) and columns. Labels must match data_comparison['model'] / ['dataset'].
BASELINE_LABEL = 'baseline<br> (PT+FT)'
models = [
    BASELINE_LABEL,
    'finetune on mRNA',
    'finetune on pri-miRNA',
    'no finetuning<br> (PT)',
    'no finetuning, with <br> UFold architecture',
]
datasets = ['PDB', "archiveII", 'viral mRNA', 'long ncRNA']

# Mean F1 of the baseline model per dataset: the zero point of every ΔF1 panel.
baseline = data_comparison[data_comparison['model'] == BASELINE_LABEL].groupby('dataset')['f1'].mean()

fig = make_subplots(rows=len(models), cols=len(datasets),
                    subplot_titles=list(datasets), shared_yaxes=True, shared_xaxes=True,
                    vertical_spacing=0,
                    horizontal_spacing=0.08)

palette = px.colors.qualitative.Pastel
colors = {BASELINE_LABEL: palette[1],
          'finetune on mRNA': palette[7],
          'finetune on pri-miRNA': palette[4],
          'no finetuning<br> (PT)': palette[0],
          'no finetuning, with <br> UFold architecture': palette[9]}

for i, model in enumerate(models):
    for j, dataset in enumerate(datasets):
        scores = data_comparison[(data_comparison['model'] == model) & (data_comparison['dataset'] == dataset)]

        # Marker: mean ΔF1 relative to the baseline.
        # Error bar: std of this model's per-structure F1 on the dataset.
        fig.add_trace(go.Scatter(x=[scores['f1'].mean() - baseline[dataset]], y=[model],
                                 error_x=dict(type='data', array=[scores['f1'].std()], thickness=3),
                                 mode='markers', marker_color=colors[model], marker_size=7,
                                 showlegend=False), row=1 + i, col=1 + j)

fig.update_annotations(font_size=20)
fig.update_xaxes(range=[-0.4, 0.25], mirror=True)
fig.update_layout(
    # NOTE(review): `xaxis` is only the first (top-left) subplot's x-axis; with
    # shared_xaxes its tick labels are likely hidden, so this may have no visible
    # effect. Confirm whether all panels were meant to use this tick size.
    xaxis=dict(tickfont=dict(size=5)),
    height=len(models) * 100 + 200, width=1000,
    template='plotly_white',
    font_size=20, font_color='black', font_family='helvetica light')
fig.show()

In [8]:
# Export the ablation figure to PDF. Create the output folder first so that a fresh
# checkout, where images/S4 may not exist yet, does not fail on write.
output_path = "images/S4/a_ablation_study.pdf"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.write_image(output_path)