## Ablation study of our final best model, on structure prediction

Show performance difference (Delta of F1 score) of different versions of our model, vs our best model (baseline, ∆F1=0)

- Is pretraining useful (Ribo, ArchiveII, bpRNA, RNAstralign, zuber, synthetic)
- Is finetuning useful (UTR, pri-miRNA, human_mRNA fragments?)
- Is training on everything at once better ?
- Is it useful to have a SHAPE/DMS head?
- Pearson loss for SHAPE/DMS vs MSE
- model architecture (Evoformer, Transformer, CNN)

The test set could be a weighted average of the three standard sets (PDB, viral_fragments, lncRNA) if there are no outliers


**Assigned to**: Alberic

Use Plotly, with a white background

In [2]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from rouskinhf import get_dataset

def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    F1 score between a predicted and a reference base-pairing matrix.

    :param pred_matrix: Predicted pairing probabilities, tensor (L,L)
    :param target_matrix: Reference binary pairing matrix, tensor (L,L)
    :param threshold: Probabilities strictly above this value count as predicted pairs
    :return: F1 score as a Python float; 1.0 when neither matrix contains a pair
    """
    binarized = (pred_matrix > threshold).float()
    denominator = binarized.sum() + target_matrix.sum()

    # Nothing predicted and nothing expected: treated as a perfect match.
    if denominator == 0:
        return 1.0

    overlap = (binarized * target_matrix).sum()
    return (2 * overlap / denominator).item()
    
def pairList2pairMatrix(pair_list, len_seq):
    """
    Build a symmetric pairing matrix from a list of index pairs.

    :param pair_list: Sequence of (i, j) index pairs; may be empty
    :param len_seq: Side length of the returned square matrix
    :return: Float tensor (len_seq, len_seq) holding 1.0 at [i, j] and [j, i]
             for every listed pair and 0.0 everywhere else
    """
    pairs = np.array(pair_list).astype(int)
    matrix = torch.zeros((len_seq, len_seq))

    # No pairs listed: the all-zero matrix is the answer.
    if len(pairs) == 0:
        return matrix

    first, second = pairs[:, 0], pairs[:, 1]
    # Fill both triangles so the matrix is symmetric.
    matrix[first, second] = 1.0
    matrix[second, first] = 1.0
    return matrix

In [127]:
# Load the reference sequences and structures of every test set into one frame.
# Each row keeps the index it had in the per-dataset frame and is tagged with
# the name of the test set it came from.
ground_truth_parts = []

for test_set in ["PDB", "lncRNA", "viral_fragments"]:
    # NOTE(review): force_download=True appears to re-fetch the dataset on every run.
    records = get_dataset(test_set, force_download=True)
    ground_truth_parts.append(
        # .assign returns a new frame, so no column is set on a slice.
        pd.DataFrame(records).T[['sequence', 'structure']].assign(dataset=test_set)
    )
    del records

# Concatenate once instead of growing a DataFrame inside the loop.
ground_truth = pd.concat(ground_truth_parts)
del ground_truth_parts

PDB: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

PDB: Download complete. File saved at data/PDB/data.json
lncRNA: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

lncRNA: Download complete. File saved at data/lncRNA/data.json
viral_fragments: Downloading dataset from HuggingFace Hub...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

viral_fragments: Download complete. File saved at data/viral_fragments/data.json


In [128]:
# Score every model variant against the ground truth: one F1 value per test
# sequence, collected in `data_comparison` with display-ready labels.

# Display names used in the figure for test sets and model variants.
dataset_labels = {
    'viral_fragments': 'viral mRNA',
    'lncRNA': 'long ncRNA',
}
model_labels = {
    'PT+FT': 'baseline',
    'PT': 'no finetuning',
    'UFoldPT': 'no finetuning, <br>with UFold architecture',
    'PT+mRNA_FT': 'finetune on mRNA',
    'PT+primiRNA_FT': 'finetune on pri-miRNA',
}

# Loop-invariant: expose the ground-truth index as a `reference` column once.
truth = ground_truth.reset_index().rename(columns={'index': 'reference'})

scored_parts = []
for model in ['PT', 'PT+FT', 'UFoldPT', 'PT+mRNA_FT', 'PT+primiRNA_FT']:

    prediction = pd.read_feather(f'data/test_results_{model}.feather')

    # Default (inner) merge: only (sequence, reference) keys present in both
    # frames are scored.
    merged = truth.merge(prediction, on=['sequence', 'reference'], suffixes=('_true', '_pred'))

    # Iterate over the three needed columns directly rather than whole rows.
    merged['f1'] = [
        compute_f1(torch.tensor(np.stack(pred)),
                   pairList2pairMatrix(true, len(seq)), threshold=0.5)
        for pred, true, seq in zip(merged['structure_pred'],
                                   merged['structure_true'],
                                   merged['sequence'])
    ]
    merged['model'] = model
    print(model)
    print(merged.groupby('dataset')['f1'].mean())

    scored_parts.append(merged[['reference', 'sequence', 'model', 'dataset', 'f1']])

# Concatenate once instead of growing a DataFrame inside the loop.
data_comparison = pd.concat(scored_parts)

# Relabel in one pass; values missing from a mapping are left untouched.
data_comparison['dataset'] = data_comparison['dataset'].replace(dataset_labels)
data_comparison['model'] = data_comparison['model'].replace(model_labels)

PT
dataset
PDB                0.846920
lncRNA             0.315571
viral_fragments    0.527542
Name: f1, dtype: float64
PT+FT
dataset
PDB                0.879710
lncRNA             0.368814
viral_fragments    0.576445
Name: f1, dtype: float64
UFoldPT
dataset
PDB                0.829544
lncRNA             0.073930
viral_fragments    0.491223
Name: f1, dtype: float64
PT+mRNA_FT
dataset
PDB                0.841140
lncRNA             0.364154
viral_fragments    0.559930
Name: f1, dtype: float64
PT+primiRNA_FT
dataset
PDB                0.857658
lncRNA             0.278404
viral_fragments    0.544281
Name: f1, dtype: float64


In [133]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Model variants in display order; each one gets its own subplot row.
models = [
    'baseline',
    'finetune on mRNA',
    'finetune on pri-miRNA',
    'no finetuning',
    'no finetuning, <br>with UFold architecture',
]

# Test sets in order of first appearance; each one gets its own subplot column.
datasets = data_comparison['dataset'].unique()

# Mean F1 of the baseline model per test set: the zero point of every ∆F1.
is_baseline = data_comparison['model'] == 'baseline'
baseline = data_comparison[is_baseline].groupby('dataset')['f1'].mean()

fig = make_subplots(
    rows=len(models),
    cols=len(datasets),
    subplot_titles=[f'∆F1 score on {dataset}' for dataset in datasets],
    shared_yaxes=True,
    shared_xaxes=True,
)

colors = px.colors.qualitative.D3
for row_idx, model in enumerate(models, start=1):
    for col_idx, dataset in enumerate(datasets, start=1):

        in_cell = (data_comparison['model'] == model) & (data_comparison['dataset'] == dataset)
        scores = data_comparison.loc[in_cell, 'f1']

        # One marker per (model, test set): mean F1 minus the baseline mean,
        # with the standard deviation of the F1 scores as horizontal error bar.
        marker = go.Scatter(
            x=[scores.mean() - baseline[dataset]],
            y=[model],
            error_x=dict(type='data', array=[scores.std()]),
            mode='markers',
            marker_color=colors[row_idx - 1],
            showlegend=False,
        )
        fig.add_trace(marker, row=row_idx, col=col_idx)

fig.update_xaxes(range=[-0.5, 0.2])
fig.update_layout(
    height=len(models)*130,
    width=1100,
    template='plotly_white',
    font_size=15,
    font_color='black',
)

fig.show()

In [None]:
# Save the figure as a PDF. The target folder is created first because the
# write fails when the parent directory does not exist.
output_file = "images/b_ablation_study.pdf"
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(output_file)