In [None]:
from easydl.dml.simulation import generate_2d_gaussian_points
import pandas as pd
import numpy as np
import plotly.express as px
from easydl.dml.evaluation import calculate_cosine_similarity_matrix, create_pairwise_similarity_ground_truth_matrix, evaluate_pairwise_score_matrix_with_true_label
from easydl.visualization import plot_precision_recall_vs_threshold_curve, plot_precision_vs_recall_curve

# Demo: two synthetic 2-D point clouds are embedded, scored by pairwise
# cosine similarity, and evaluated against "same label" ground truth.

POINTS_PER_GROUP = 100
# Third positional argument of generate_2d_gaussian_points; presumably the
# spread (std dev) of each cloud -- NOTE(review): confirm against easydl.
GROUP_SPREAD = 0.1
# One entry per cluster: (label, centre, random seed).
group_specs = [('G1', [0, 1], 42), ('G2', [1, 0], 43)]

group_1, group_2 = [
    generate_2d_gaussian_points(POINTS_PER_GROUP, centre, GROUP_SPREAD, random_seed=seed)
    for _label, centre, seed in group_specs
]
# Stack both clouds row-wise: group_1's points first, then group_2's.
embeddings = np.concatenate([group_1, group_2], axis=0)

# Label every row with the name of the cluster it came from, in stacking order.
labels = []
for (label, _centre, _seed), group in zip(group_specs, (group_1, group_2)):
    labels.extend([label] * len(group))

df_points = pd.DataFrame(embeddings, columns=['x', 'y'])
df_points['label'] = labels

fig = px.scatter(
    df_points,
    x='x',
    y='y',
    color='label',
    width=600,
    height=600,
)
fig.show()

# Score every pair of embeddings, and mark which pairs share a label.
similarity_score_matrix = calculate_cosine_similarity_matrix(embeddings)
true_labels = df_points['label'].to_numpy()
pairwise_similarity_ground_truth_matrix = create_pairwise_similarity_ground_truth_matrix(true_labels)

eval_metrics = evaluate_pairwise_score_matrix_with_true_label(
    pairwise_similarity_ground_truth_matrix,
    similarity_score_matrix,
)
print(eval_metrics)

# NOTE(review): the captured output lists threshold_list with one fewer entry
# than precision_list / recall_list (19828 vs 19829); confirm the threshold
# plot helper aligns them as intended.
precisions = eval_metrics['precision_list']
recalls = eval_metrics['recall_list']
thresholds = eval_metrics['threshold_list']
plot_precision_recall_vs_threshold_curve(precisions, recalls, thresholds).show()
plot_precision_vs_recall_curve(precisions, recalls).show()



{'precision_list': array([0.49748744, 0.49751244, 0.49753744, ..., 0.85714286, 1.        ,
       1.        ]), 'recall_list': array([1.00000000e+00, 1.00000000e+00, 1.00000000e+00, ...,
       6.06060606e-04, 1.01010101e-04, 0.00000000e+00]), 'pr_auc': 0.9907913835078976, 'threshold_list': array([-0.99999994, -0.999999  , -0.99997073, ...,  0.99999994,
        1.        ,  1.0000001 ], dtype=float32), 'top1_accuracy': 0.99}
19828 19829 19829
