In [73]:
import pandas as pd
import joblib 
import numpy as np
import os
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, accuracy_score
from sklearn.preprocessing import LabelEncoder

In [74]:
# List every cached prediction file. os.listdir returns names in arbitrary,
# filesystem-dependent order; sorting makes `data`, the row order of
# `results_df` and any "first match" lookup reproducible across machines.
all_predictions = sorted(os.listdir('cache/predictions/all'))

In [75]:
# Keep only the prediction dumps (files whose name starts with "outputs_dict_").
all_predictions = list(
    filter(lambda file_name: file_name.startswith('outputs_dict_'), all_predictions)
)

In [76]:
# Load every cached prediction bundle and keep only those carrying run metadata
# under the 'meta' key.
# NOTE(review): joblib.load unpickles the file, so only point this at trusted caches.
data = []
for prediction in all_predictions:
    prediction_path = os.path.join('cache/predictions/all', prediction)
    try:
        outputs_dict = joblib.load(prediction_path)
        has_meta = 'meta' in outputs_dict
    except Exception as e:
        # Best-effort: an unreadable file should not abort the whole load, but
        # report which file failed so it can be inspected or deleted.
        print(f"Skipping {prediction_path}: {e!r}")
        continue
    if has_meta:
        data.append(outputs_dict)





In [97]:
# Restrict the analysis to runs whose meta says they used data/finegrained.
new_data = [run for run in data if run["meta"]["data_dir"] == "data/finegrained"]
len(new_data)

135

In [101]:
# Show the configuration of the run with the highest test F1. Selecting it with
# max() avoids a hand-tuned F1 cutoff, which would raise IndexError when no run
# clears it and would depend on listing order when several do.
max(
    new_data,
    key=lambda data_point: data_point['predictions']['metrics']['test_f1'],
)['meta']

{'batch_size': 16,
 'cbr_threshold': -10000000,
 'classifier_dropout': 0.3,
 'data_dir': 'data/finegrained',
 'learning_rate': 7.484147412800621e-05,
 'num_cases': 1,
 'num_epochs': 15,
 'predictions_dir': 'cache/predictions/all',
 'retrievers': ['simcse'],
 'weight_decay': 0.00984762513370293,
 'return_dict': True,
 'output_hidden_states': False,
 'output_attentions': False,
 'torchscript': False,
 'torch_dtype': None,
 'use_bfloat16': False,
 'tf_legacy_loss': False,
 'pruned_heads': {},
 'tie_word_embeddings': True,
 'is_encoder_decoder': False,
 'is_decoder': False,
 'cross_attention_hidden_size': None,
 'add_cross_attention': False,
 'tie_encoder_decoder': False,
 'max_length': 20,
 'min_length': 0,
 'do_sample': False,
 'early_stopping': False,
 'num_beams': 1,
 'num_beam_groups': 1,
 'diversity_penalty': 0.0,
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'typical_p': 1.0,
 'repetition_penalty': 1.0,
 'length_penalty': 1.0,
 'no_repeat_ngram_size': 0,
 'encoder_no_repeat_ngr

In [80]:
def get_metrics(y_true, y_pred):
    """Summarise how well ``y_pred`` matches ``y_true``.

    F1, precision and recall use ``average="weighted"`` (per-class scores
    weighted by class support). Accuracy is the plain fraction of matches.

    Returns:
        dict with keys 'f1', 'precision', 'recall', 'accuracy' (in that order).
    """
    weighted = dict(average="weighted")
    metrics = {}
    metrics['f1'] = f1_score(y_true, y_pred, **weighted)
    metrics['precision'] = precision_score(y_true, y_pred, **weighted)
    metrics['recall'] = recall_score(y_true, y_pred, **weighted)
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    return metrics

In [81]:
def get_overlap(sample_data):
    """Fraction of correctly predicted samples whose gold label also appears
    among the labels of their retrieved CBR cases.

    Args:
        sample_data: one loaded outputs dict with keys
            ``"label_encoder"`` (an object with ``transform``),
            ``"predictions"`` (holding ``"label_ids"`` and the score array
            ``"predictions"``, reduced with argmax over the last axis), and
            ``"cbr_labels"`` (per sample -> per retriever -> per retrieved
            case -> iterable of raw labels).

    Returns:
        float: overlap ratio in [0, 1], or ``nan`` when no sample was
        predicted correctly (the ratio is undefined in that case).
    """
    label_encoder = sample_data["label_encoder"]
    labels = np.asarray(sample_data["predictions"]["label_ids"])

    # Flatten the retriever / case nesting into one list of raw labels per
    # sample, then map it into the same id space as `labels`.
    encoded_cbr_labels = [
        label_encoder.transform([
            raw_label
            for retriever_cases in sample_cbr_labels
            for case_labels in retriever_cases
            for raw_label in case_labels
        ])
        for sample_cbr_labels in sample_data["cbr_labels"]
    ]

    predicted_labels = np.argmax(sample_data["predictions"]["predictions"], axis=-1)
    correct_indices = np.flatnonzero(labels == predicted_labels)
    if correct_indices.size == 0:
        # Avoid ZeroDivisionError; NaN is skipped by pandas groupby().mean().
        return float("nan")

    overlap_count = sum(
        1
        for index in correct_indices
        if labels[index] in encoded_cbr_labels[index]
    )
    return overlap_count / correct_indices.size

In [82]:
# One record per fine-grained run: its CBR configuration, the gold-label
# overlap from get_overlap, and the test metrics stored with its predictions.
records = []
for sample_data in new_data:
    meta = sample_data['meta']
    metrics = sample_data['predictions']['metrics']
    records.append({
        'num_cases': meta['num_cases'],
        'threshold': meta['cbr_threshold'],
        'retrievers': ' '.join(meta['retrievers']),
        'overlaps': get_overlap(sample_data),
        'f1': metrics['test_f1'],
        'precision': metrics['test_precision'],
        'recall': metrics['test_recall'],
        'accuracy': metrics['test_accuracy'],
    })

# Pin the column order so an empty `new_data` still yields the expected schema.
results_df = pd.DataFrame(records, columns=[
    'num_cases', 'threshold', 'retrievers', 'overlaps',
    'f1', 'precision', 'recall', 'accuracy',
])


In [83]:
# Quick sanity check: (rows, columns), then the first three runs.
print(f"{results_df.shape}")
results_df.head(3)

(70, 8)


Unnamed: 0,num_cases,threshold,retrievers,overlaps,f1,precision,recall,accuracy
0,4,0.8,simcse,0.119171,0.640093,0.646796,0.643333,0.643333
1,1,0.8,empathy,0.324324,0.609747,0.624443,0.616667,0.616667
2,3,0.8,coarse,0.539394,0.52997,0.547138,0.55,0.55


In [89]:
# Runs whose test F1 exceeds 0.63.
results_df.query('f1 > 0.63')

Unnamed: 0,num_cases,threshold,retrievers,overlaps,f1,precision,recall,accuracy
0,4,0.8,simcse,0.119171,0.640093,0.646796,0.643333,0.643333
13,1,-10000000.0,simcse,0.525773,0.643005,0.64609,0.646667,0.646667
56,1,0.5,empathy,0.329843,0.63013,0.646402,0.636667,0.636667


In [24]:
# Mean overlap and test metrics per number of retrieved cases.
# NOTE(review): the saved output below predates the current results_df
# (execution count 24 < 83; its f1 means exceed every f1 in the row preview)
# — re-run before reading it.
results_df[['num_cases', 'overlaps', 'f1', 'precision', 'recall']].groupby('num_cases').mean()

Unnamed: 0_level_0,overlaps,f1,precision,recall
num_cases,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,0.354736,0.874317,0.877919,0.874405
3,0.727254,0.809894,0.823409,0.810823
4,0.754629,0.799139,0.814792,0.80188
5,0.814431,0.830661,0.841763,0.831548


In [25]:
# Mean overlap and test metrics per CBR similarity threshold.
# NOTE(review): the saved output below predates the current results_df
# (execution count 25 < 83) — re-run before reading it.
results_df[['threshold', 'overlaps', 'f1', 'precision', 'recall']].groupby('threshold').mean()

Unnamed: 0_level_0,overlaps,f1,precision,recall
threshold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
-10000000.0,0.786584,0.812232,0.827292,0.814399
0.5,0.822212,0.802428,0.818328,0.803968
0.8,0.475377,0.850771,0.856269,0.851247


In [26]:
# Mean overlap and test metrics per retriever combination.
# NOTE(review): the saved output below predates the current results_df
# (execution count 26 < 83; it lacks the 'coarse' retriever seen in the row
# preview) — re-run before reading it.
results_df[['retrievers', 'overlaps', 'f1', 'precision', 'recall']].groupby('retrievers').mean()

Unnamed: 0_level_0,overlaps,f1,precision,recall
retrievers,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
empathy,0.830853,0.809991,0.824639,0.811993
simcse,0.5161,0.84009,0.848499,0.840829
