In [2]:
import sys
sys.path.append('../../eFold')

import efold.core as core

import pandas as pd
import numpy as np
from rouskinhf import get_dataset
import torch

In [3]:
# Load every benchmark test set and stack them into one frame of reference structures.
# Frames are collected in a list and concatenated once: growing a DataFrame with
# pd.concat inside the loop is quadratic and, starting from an empty frame, triggers
# pandas' "empty entries" dtype warning.
ground_truth_frames = []

for test_set in ["PDB", "archiveII_blast", "viral_fragments", "lncRNA_nonFiltered"]:
    data = get_dataset(test_set, force_download=False)
    # .assign builds a fresh frame, avoiding a SettingWithCopy warning on the column slice.
    ground_truth_frames.append(
        pd.DataFrame(data).T[['sequence', 'structure']].assign(dataset=test_set)
    )
    del data

ground_truth = pd.concat(ground_truth_frames)
del ground_truth_frames

In [4]:
def ratio_nonCanonical(sequence, structure):
    """Return the fraction of base pairs in `structure` that are not canonical.

    A pair counts as canonical when the two nucleotides it joins, read in pair order,
    form one of AU, UA, GC, CG, GU or UG (uppercase only).

    :param sequence: nucleotide string indexed by the positions stored in `structure`
    :param structure: iterable of index pairs into `sequence`
    :return: fraction of non-canonical pairs; 0 when `structure` is empty
    """
    n_pairs = len(structure)
    if n_pairs == 0:
        return 0

    allowed = {'AU', 'UA', 'GC', 'CG', 'GU', 'UG'}

    n_bad = 0
    for pair in structure:
        dinucleotide = ''.join(sequence[pos] for pos in pair)
        if dinucleotide not in allowed:
            n_bad += 1

    return n_bad / n_pairs



def ratio_sharpLoops(structure, min_dist=3):
    """Return the fraction of base pairs whose two positions are at most `min_dist` apart.

    :param structure: iterable of index pairs (each indexable as pair[0], pair[1])
    :param min_dist: largest |j - i| still counted as a sharp loop (inclusive)
    :return: fraction of sharp-loop pairs; 0 when `structure` is empty
    """
    total = len(structure)
    if total == 0:
        return 0

    n_sharp = 0
    for pair in structure:
        if np.abs(pair[1] - pair[0]) <= min_dist:
            n_sharp += 1

    return n_sharp / total

# Smoke test on a toy example. (0, 2) joins A-G (non-canonical) and (0, 1) joins A-U
# (canonical). Both pairs span <= 3 positions, so both count as sharp loops.
sequence = 'AUGAC'
structure = [[0,2], [0,1]]

# Pin the expected values so a regression fails loudly instead of relying on eyeballing.
assert ratio_nonCanonical(sequence, structure) == 0.5
assert ratio_sharpLoops(structure) == 1.0
assert ratio_nonCanonical(sequence, []) == 0 and ratio_sharpLoops([]) == 0

ratio_nonCanonical(sequence, structure), ratio_sharpLoops(structure)

(0.5, 1.0)

In [5]:
def ListofPairs2pairMatrix(pairs, length):
    """Build a symmetric 0/1 pairing matrix from a list of base pairs.

    :param pairs: (N, 2) array-like of 0-based index pairs (list, numpy array or torch
        tensor). An empty input yields an all-zero matrix.
    :param length: sequence length L
    :return: int32 tensor of shape (L, L) with 1 at (i, j) and (j, i) for every pair
    """
    # Allocate as int32 up front so empty and non-empty inputs return the same dtype
    # (previously an empty input returned a float matrix, a non-empty one an int matrix).
    matrix = torch.zeros((length, length), dtype=torch.int)

    if len(pairs) == 0:
        return matrix

    # as_tensor accepts lists, numpy arrays and tensors alike (no copy for long tensors).
    idx = torch.as_tensor(pairs, dtype=torch.long)
    matrix[idx[:, 0], idx[:, 1]] = 1
    matrix[idx[:, 1], idx[:, 0]] = 1

    return matrix


def compute_f1(pred_matrix, target_matrix, threshold=0.5):
    """
    Compute precision, recall and F1 of a predicted pairing matrix against the truth.

    Entries of `pred_matrix` strictly greater than `threshold` count as predicted pairs.
    The comparison is element-wise. For symmetric pairing matrices, every base pair is
    therefore counted twice in TP, PP and P, which leaves the ratios unchanged.

    :param pred_matrix: Predicted pairing matrix probability  (L,L)
    :param target_matrix: True binary pairing matrix (L,L)
    :param threshold: cut-off applied to `pred_matrix` before scoring
    :return: [precision, recall, F1] as Python floats.
        - If neither matrix contains a pair, all three are 1.0.
        - If only one side is empty, the undefined ratio is reported as 0.0 instead of
          NaN: precision when nothing is predicted, recall when there is nothing to find.
    """
    pred_binary = (pred_matrix > threshold).float()

    TP = torch.sum(pred_binary * target_matrix)
    PP = torch.sum(pred_binary)   # predicted positives
    P = torch.sum(target_matrix)  # actual positives
    sum_pair = PP + P

    if sum_pair == 0:
        # Nothing predicted and nothing to find: an exactly correct (empty) prediction.
        return [1.0, 1.0, 1.0]

    # Guard the 0/0 cases: NaNs here would be silently skipped when the per-sequence
    # scores are averaged with pandas' mean(), inflating the reported averages.
    precision = (TP / PP).item() if PP > 0 else 0.0
    recall = (TP / P).item() if P > 0 else 0.0
    f1 = (2 * TP / sum_pair).item()
    return [precision, recall, f1]

def compute_confusion_matrix(label, pred):
    """Label every cell of a prediction with its confusion-matrix outcome.

    Outcome codes: 0 = true negative, 1 = true positive, 2 = false positive,
    3 = false negative. The encoding is only meaningful for 0/1-valued inputs
    (NOTE: assumed, not checked beyond the assertion below).

    :param label: ground-truth 0/1 array or tensor
    :param pred: predicted 0/1 array or tensor, same shape as `label`
    :return: array/tensor of outcome codes, combined element-wise from both inputs
    :raises AssertionError: if the cells coded 0 do not coincide with the true negatives
    """
    not_label = 1 - label
    not_pred = 1 - pred

    tp = label * pred
    fp = not_label * pred
    fn = label * not_pred
    tn = not_label * not_pred

    codes = tp + 2 * fp + 3 * fn
    assert ((tn == 1) == (codes == 0)).all(), "True negatives are not correctly computed"
    return codes

In [6]:
# Sanity statistics on the reference structures:
# - fraction of non-canonical pairs,
# - fraction of pairs whose positions are at most 3 apart (sharp loops),
# - the dense (L, L) pairing matrix of each structure.
# (A mid-cell filter expression was dropped: Jupyter only displays a cell's last expression,
# so its result was discarded.)
ground_truth['non_canonical'] = ground_truth.apply(lambda x: ratio_nonCanonical(x['sequence'], x['structure']), axis=1)
ground_truth['sharp_loops'] = ground_truth.apply(lambda x: ratio_sharpLoops(x['structure']), axis=1)
ground_truth['pairing_matrix'] = ground_truth.apply(lambda x: ListofPairs2pairMatrix(np.array(x['structure']), len(x['sequence'])), axis=1)

In [7]:
import time
from tqdm import tqdm
from efold import inference

# Run eFold on every reference sequence and score it against the ground truth.
# NOTE(review): `threshold` is only recorded as a label. It is passed neither to
# `inference(...)` nor to `compute_f1(...)`, so the metrics below do not depend on it.
# Wire it into the post-processing step before sweeping several values.
thresholds = [0.5]

threshold_frames = []  # one results frame per threshold, concatenated once at the end

for threshold in thresholds:
    print(f"Threshold: {threshold}")

    Precisions = []
    Recalls = []
    F1s = []
    predictions = []
    dTs = []

    for _, row in tqdm(ground_truth.iterrows(), total=len(ground_truth)):
        true_structure = torch.tensor(row['structure'])
        sequence = row['sequence']

        # perf_counter is monotonic and high-resolution, so it suits measuring durations.
        t0 = time.perf_counter()
        # `- 1` shifts eFold's base-pair indices to the 0-based indices of the ground truth
        # (presumably eFold's 'bp' output is 1-based — confirm).
        prediction = torch.tensor(inference(sequence, fmt='bp')[sequence]) - 1
        dT = time.perf_counter() - t0

        precision, recall, f1 = compute_f1(ListofPairs2pairMatrix(prediction, len(sequence)),
                                           ListofPairs2pairMatrix(true_structure, len(sequence)))

        Precisions.append(precision)
        Recalls.append(recall)
        F1s.append(f1)
        predictions.append(prediction)
        dTs.append(dT)

    threshold_frames.append(pd.DataFrame({
        'reference': ground_truth.index,
        'sequence': ground_truth['sequence'],
        'threshold': threshold,
        'precision': Precisions, 'recall': Recalls, 'f1': F1s, 'dT': dTs,
        'structure': predictions,
        # Rows are produced in ground_truth order, so the dataset label is copied
        # positionally. Merging it back on (reference, sequence) would duplicate rows
        # whenever the same key occurs in more than one dataset.
        'dataset': ground_truth['dataset'].to_numpy(),
    }))

eFold_processed = pd.concat(threshold_frames, axis=0, ignore_index=True)

Threshold: 0.5


  0%|          | 0/781 [00:00<?, ?it/s]

100%|██████████| 781/781 [05:48<00:00,  2.24it/s]


In [10]:
# Structural sanity statistics of the eFold predictions, mirroring the ground-truth checks.
# The columns are:
# - non_canonical: fraction of non-canonical pairs
# - sharp_loops: fraction of sharp-loop pairs
# - pairing_matrix: the dense pairing matrix
# - multiPairs: fraction of positions paired more than once
# - length: sequence length
# assign() evaluates its callables in order, so multiPairs can read pairing_matrix.
eFold_processed = eFold_processed.assign(
    non_canonical=lambda frame: frame.apply(lambda r: ratio_nonCanonical(r['sequence'], r['structure']), axis=1),
    sharp_loops=lambda frame: frame.apply(lambda r: ratio_sharpLoops(r['structure']), axis=1),
    pairing_matrix=lambda frame: frame.apply(lambda r: ListofPairs2pairMatrix(r['structure'], len(r['sequence'])), axis=1),
    multiPairs=lambda frame: frame['pairing_matrix'].apply(lambda mat: (mat.sum(axis=0) > 1).sum().item() / len(mat)),
    length=lambda frame: frame['sequence'].apply(len),
)

eFold_processed.groupby('threshold')[['non_canonical', 'sharp_loops', 'multiPairs']].mean()

Unnamed: 0_level_0,non_canonical,sharp_loops,multiPairs
threshold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0.5,0.0,0.0,0.0


In [11]:
# Group the data by model and dataset and calculate the mean for each group
grouped = eFold_processed.groupby(['threshold', 'dataset']).mean(numeric_only=True).reset_index()

# Pivot the table to create a multi-level column structure
pivot_df = pd.pivot_table(grouped, index='threshold', columns='dataset', values=['precision', 'recall', 'f1'])

# Swap the level of the columns to have dataset as the top level and the metrics as the second level
pivot_df = pivot_df.swaplevel(i=0, j=1, axis=1).sort_index(axis=1)

# Define the new order for the models and reorder the rows
# new_order = ['SimpleThreshold', 'HungarianAlgorithm', 'UFold_processing', 'OptimalProcessing']
# pivot_df = pivot_df.reindex(new_order)

pivot_df = pivot_df.reindex(columns=pivot_df.columns.reindex(['precision', 'recall', 'f1'], level=1)[0])[['PDB', 'archiveII_blast', 'viral_fragments', 'lncRNA_nonFiltered']]

pivot_df = pivot_df.style\
            .format(precision=3)\
            .highlight_max(axis=0, props="font-weight:bold;font-color:black;")\
            .background_gradient(axis=1, vmin=-0.1, vmax=1, cmap="viridis", text_color_threshold=0)\
            .set_properties(**{'text-align': 'center'})\
            .set_table_styles(
                        [{"selector": "th", "props": [('text-align', 'center')]},
                        ])
pivot_df

dataset,PDB,PDB,PDB,archiveII_blast,archiveII_blast,archiveII_blast,viral_fragments,viral_fragments,viral_fragments,lncRNA_nonFiltered,lncRNA_nonFiltered,lncRNA_nonFiltered
Unnamed: 0_level_1,precision,recall,f1,precision,recall,f1,precision,recall,f1,precision,recall,f1
threshold,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
0.5,0.89,0.907,0.88,0.598,0.688,0.635,0.721,0.748,0.733,0.461,0.448,0.453
