## Figure 2 and Table 2 reproduction

In this notebook, we reproduce Figure 2 (the Wasserstein distances between events) and Table 2.
    
<img src="paper_figures/events_diffs.png" alt="Figure 2" width="400"/>

<img src="paper_figures/table2.png" alt="Table 2" width="800"/>

In [1]:
import numpy as np
import os
import re
import pandas as pd
import ot
import plotly.graph_objects as go
import itertools
from tqdm import tqdm

In [2]:
def delete_timestamp(df1, df2):
    cols1 = df1.columns
    if "Timestamp" in cols1:
        df1 = df1.drop(["Timestamp"], axis=1)

    cols2 = df2.columns
    if "Timestamp" in cols2:
        df2 = df2.drop(["Timestamp"], axis=1)
    return df1, df2

def extract_number(filename):
    match = re.search(r'(\d+)', filename)
    if match:
        return int(match.group(1))
    return float('inf')  

In [3]:
folder_path = '../data/MESWE-38-processed'
file_names = [file_name for file_name in os.listdir(folder_path) if file_name.endswith('.csv')]
sorted_csv_file_names = sorted(file_names, key=extract_number)
sorted_csv_file_names_numbers = [filename.replace("meswe_event_", "").replace(".csv", "") for filename in sorted_csv_file_names]
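
The numeric sort key matters because a plain lexicographic sort would place event 10 before event 2. The short sketch below illustrates this with a copy of the `extract_number` helper; the three file names are made-up examples that only assume the `meswe_event_<number>.csv` naming pattern implied by the cell above.

```python
import re


def extract_number(filename):
    # Same logic as the notebook helper: use the first run of digits,
    # and sort names without digits last.
    match = re.search(r'(\d+)', filename)
    if match:
        return int(match.group(1))
    return float('inf')


# Hypothetical file names, for illustration only.
example_names = ["meswe_event_10.csv", "meswe_event_2.csv", "meswe_event_1.csv"]

lexicographic = sorted(example_names)
numeric = sorted(example_names, key=extract_number)
labels = [name.replace("meswe_event_", "").replace(".csv", "") for name in numeric]

print(lexicographic)
print(numeric)
print(labels)
```

The `labels` list corresponds to what the notebook stores in `sorted_csv_file_names_numbers` and later uses as axis labels and for index lookups.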

In [4]:
distance_matrix =  np.zeros((len(sorted_csv_file_names), len(sorted_csv_file_names)))

In [None]:

for i in range(len(sorted_csv_file_names)):
    print(i)  # progress indicator: index of the current row
    for j in range(len(sorted_csv_file_names)):

        #print(f"{sorted_csv_file_names[i]} - {sorted_csv_file_names[j]}")

        # Load both events and drop the Timestamp column so only feature columns remain
        data1 = pd.read_csv(os.path.join(folder_path, sorted_csv_file_names[i]))
        data2 = pd.read_csv(os.path.join(folder_path, sorted_csv_file_names[j]))

        data1, data2 = delete_timestamp(data1, data2)

        data1_np = data1.values
        data2_np = data2.values

        # Pairwise ground cost between samples (ot.dist defaults to squared Euclidean)
        M = ot.dist(data1_np, data2_np)

        # Uniform weights over the samples of each event
        weights_P = np.ones(data1_np.shape[0]) / data1_np.shape[0]
        weights_Q = np.ones(data2_np.shape[0]) / data2_np.shape[0]

        # Optimal transport cost between the two empirical distributions
        wasserstein_dist = ot.emd2(weights_P, weights_Q, M)

        distance_matrix[i, j] = wasserstein_dist
        distance_matrix[j, i] = wasserstein_dist
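
To make concrete what `ot.emd2` returns in the loop above, the following minimal sketch brute-forces the same kind of quantity for two tiny point clouds. It assumes equal-sized clouds with uniform weights, in which case an optimal transport plan can be taken to be a one-to-one matching, so the cost is the smallest mean squared Euclidean distance over all matchings. The helper names and the toy points are illustrative and not part of the notebook; real events are far too large for brute force and need not have equal lengths, which is why the notebook relies on POT's solver.

```python
import itertools


def squared_euclidean(p, q):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def brute_force_ot_cost(xs, ys):
    """Optimal transport cost for equal-sized, uniformly weighted point clouds.

    Tries every one-to-one matching and returns the smallest mean cost.
    Only feasible for a handful of points.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch only handles equal-sized point clouds")
    n = len(xs)
    best_total = min(
        sum(squared_euclidean(xs[i], ys[perm[i]]) for i in range(n))
        for perm in itertools.permutations(range(n))
    )
    return best_total / n  # each point carries weight 1/n


# Toy point clouds, made up for illustration.
cloud_a = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
cloud_b = [(1.0, 1.0), (0.0, 2.0), (2.0, 0.0)]

toy_cost = brute_force_ot_cost(cloud_a, cloud_b)
print(toy_cost)
```

Because the ground cost is squared, the value is a squared 2-Wasserstein distance rather than its square root; the notebook uses the `ot.emd2` output as is.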




In [6]:
fig = go.Figure(data=go.Heatmap(
        z=distance_matrix,
        x=sorted_csv_file_names_numbers,
        y=sorted_csv_file_names_numbers,
        colorscale='Blues',
        text=distance_matrix, 
        hoverinfo="text"
    ))


fig.update_layout(
    title="Wasserstein Distance Matrix Between Events",
    xaxis_title="Events",
    yaxis_title="Events",
    xaxis_nticks=len(sorted_csv_file_names_numbers),  
    yaxis_nticks=len(sorted_csv_file_names_numbers),
    width=800,   
    height=800   
)

fig.show()

## Table 2 reproduction

### Plain Language Instructions
1. We obtain the Wasserstein distance of every event with respect to every other event (this is essentially Figure 2).
2. WDS(Event) is the sum of the Wasserstein distances from one event to every other event, divided by 100 simply to keep the numbers smaller.
3. The validation set consists of 3 events: 1 from the CIR category, 1 from the CME category, and 1 from the Unclassified category.
4. The test set is built the same way: 1 event from the CIR category, 1 from the CME category, and 1 from the Unclassified category.
5. Validation Set Diversity (VSD) is the sum of WDS(Event) over all events in the validation set. For example, if the validation set contains events 1, 2, and 3, then VSD(1,2,3) = WDS(1) + WDS(2) + WDS(3).
6. Test Set Diversity is defined the same way as in step 5, but over the events in the test set.
7. Calculate the difference between the validation set diversity and the test set diversity with the formula: "diff = ((test_diversity/val_diversity)*100)-100"
8. Go through all valid combinations of validation-test pairs constructed from the events and find the 5 pairs with the smallest absolute difference.
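
Steps 2 to 7 can be sketched on a toy example. The six events, their made-up "distances", and the helper names `wds` and `set_diversity` below are invented for illustration and are not the notebook's data or functions; only the formulas follow the instructions above. The toy events carry no CIR/CME/Unclassified labels, so the category constraint from steps 3 and 4 is not modelled here.

```python
# Toy data, made up for illustration: six events placed on a line,
# with "distance" = 100 * |position difference|.
toy_events = ["1", "2", "3", "4", "5", "6"]
toy_positions = [0.0, 1.0, 3.0, 6.0, 10.0, 15.0]
toy_matrix = [[100.0 * abs(p - q) for q in toy_positions] for p in toy_positions]


def wds(event):
    """Step 2: column sum of the distance matrix for one event, divided by 100."""
    col = toy_events.index(str(event))
    return sum(row[col] for row in toy_matrix) / 100


def set_diversity(events):
    """Steps 5 and 6: sum of WDS over the events in a set."""
    return sum(wds(event) for event in events)


toy_val_diversity = set_diversity([1, 3, 5])
toy_test_diversity = set_diversity([2, 4, 6])

# Step 7: relative difference of test vs. validation diversity, in percent.
toy_diff = ((toy_test_diversity / toy_val_diversity) * 100) - 100

print(toy_val_diversity, toy_test_diversity, toy_diff)
```

A difference close to zero means the two sets are about equally "far" from the rest of the data under this measure, which is the criterion step 8 ranks by.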

In [7]:
def calculate_event_diversity(event, distances_matrix, events_names_list: list):
    """Return WDS(event): the event's column sum in the distance matrix, divided by 100."""
    column_index = events_names_list.index(str(event))
    column_sum = np.sum(distances_matrix[:, column_index])
    return column_sum / 100


def find_k_fold(val_events: list, test_events: list):
    """Return (validation diversity, test diversity, absolute percentage difference)."""
    # Set diversity is the sum of WDS over the events in the set
    val_diversity = sum(
        calculate_event_diversity(event, distance_matrix, sorted_csv_file_names_numbers)
        for event in val_events
    )
    test_diversity = sum(
        calculate_event_diversity(event, distance_matrix, sorted_csv_file_names_numbers)
        for event in test_events
    )

    # Relative difference of test vs. validation diversity, in percent
    difference = ((test_diversity/val_diversity)*100)-100

    return val_diversity, test_diversity, abs(difference)


In [8]:
cir_events = [4, 5, 6, 7, 8, 10, 22, 25, 34, 35, 36, 37, 38]
cme_events = [9, 14, 15, 17, 19, 21, 30, 31, 32, 33]
idk_events = [1, 11, 12, 13, 23, 28, 29]  # Unclassified events

val_diversity, test_diversity, difference = find_k_fold([31,35,7], [21,10,23])
print(val_diversity)
print(test_diversity)
print(difference)

13058.805523182404
12960.324132831578
0.7541378128037621


In [9]:
def increment_progress(steps):
    global progress_variable
    progress_variable += steps
    progress_bar.update(steps)

cir_combinations = list(itertools.combinations(cir_events, 1))
cme_combinations = list(itertools.combinations(cme_events, 1))
idk_combinations = list(itertools.combinations(idk_events, 1))

top_k_results = []

k_results = 5

iteration = 1
# approximate number of iterations needed
all_iterations = len(cir_events) * len(cme_events) * len(idk_events) * len(cir_events) * len(cme_events) * len(idk_events) / 2
progress_bar = tqdm(total=all_iterations)
progress_variable = 0

for cir_val in cir_combinations:
    for cme_val in cme_combinations:
        for idk_val in idk_combinations:
            val_set = [cir_val[0], cme_val[0], idk_val[0]]
            
            for cir_test in cir_combinations:
                for cme_test in cme_combinations:
                    for idk_test in idk_combinations:
                        test_set = [cir_test[0], cme_test[0], idk_test[0]]
                        
                        if val_set < test_set:
                            val_diversity, test_diversity, difference = find_k_fold(val_set, test_set)
                            increment_progress(1)
                            
                            top_k_results.append((val_set, test_set, difference))
                            top_k_results = sorted(top_k_results, key=lambda x: x[2])[:k_results]

progress_bar.close()
print("All done!")


100%|█████████▉| 413595/414050.0 [00:05<00:00, 71280.01it/s]

All done!
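
As a design note, the same search can be written without re-sorting the result list on every iteration. The sketch below is one possible alternative, not the notebook's implementation: it enumerates candidate sets with `itertools.product` and keeps the smallest differences with `heapq.nsmallest`. The per-event diversity values and the three category lists are invented toy numbers. Like the loop above, it only requires `val_set < test_set` and does not exclude pairs that share an event.

```python
import heapq
import itertools

# Toy per-event diversity values and categories, made up for illustration.
toy_wds = {1: 10.0, 2: 12.0, 3: 20.0, 4: 21.0, 5: 30.0, 6: 33.0}
toy_cir, toy_cme, toy_unclassified = [1, 2], [3, 4], [5, 6]


def toy_difference(val_set, test_set):
    """Absolute percentage difference between test and validation diversity."""
    val = sum(toy_wds[event] for event in val_set)
    test = sum(toy_wds[event] for event in test_set)
    return abs(((test / val) * 100) - 100)


# Every set takes one event from each category.
candidate_sets = [
    list(combo) for combo in itertools.product(toy_cir, toy_cme, toy_unclassified)
]


def candidate_pairs():
    """Yield (val_set, test_set, difference), visiting each unordered pair once."""
    for val_set, test_set in itertools.product(candidate_sets, repeat=2):
        if val_set < test_set:
            yield val_set, test_set, toy_difference(val_set, test_set)


# Keep the 3 pairs with the smallest difference, in ascending order.
toy_top = heapq.nsmallest(3, candidate_pairs(), key=lambda item: item[2])

for val_set, test_set, diff in toy_top:
    print(val_set, test_set, diff)
```

With a few hundred thousand pairs either approach finishes in seconds, so this is a matter of readability rather than necessity.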





In [10]:
for idx, (best_val_set, best_test_set, min_difference) in enumerate(top_k_results, 1):
    print(f"Pair {idx}:")
    print("  Validation Set:", best_val_set)
    print("  Test Set:", best_test_set)
    
    val_diversity, test_diversity, difference = find_k_fold(best_val_set, best_test_set)
    print("  Validation diversity:", val_diversity)
    print("  Test diversity:", test_diversity)
    print("  Difference between val and test:", difference)
    print("--"*25)

Pair 1:
  Validation Set: [4, 21, 11]
  Test Set: [8, 14, 28]
  Validation diversity: 12172.994016663155
  Test diversity: 12172.994067932721
  Difference between val and test: 4.2117464715829556e-07
--------------------------------------------------
Pair 2:
  Validation Set: [22, 19, 28]
  Test Set: [25, 21, 1]
  Validation diversity: 11300.068263784347
  Test diversity: 11300.065832367618
  Difference between val and test: 2.1516832219958815e-05
--------------------------------------------------
Pair 3:
  Validation Set: [4, 31, 1]
  Test Set: [34, 15, 28]
  Validation diversity: 12426.595099787635
  Test diversity: 12426.568826186029
  Difference between val and test: 0.00021143041513482785
--------------------------------------------------
Pair 4:
  Validation Set: [7, 15, 13]
  Test Set: [35, 21, 1]
  Validation diversity: 10472.18861848962
  Test diversity: 10472.245269805137
  Difference between val and test: 0.0005409692050051262
------------------------------------------------