In [1]:
import os
import medmnist
from medmnist import INFO
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import pandas as pd
import scipy.stats as stats

from models.FunctorModel import FunctorModel
from utils.initialise_W_utils import initialise_W_real_Cn_irreps
from utils.char_tables import Z2_CharTable, CharTable, Cn_CharTable
import torch
from torch import nn


In [2]:
def get_results(path):
    """Read the final test metrics logged to a TensorBoard run directory.

    Each of the 'test_acc' and 'test_auc' scalar tags must contain exactly one
    event (asserted). Returns ``{'acc': <test_acc>, 'auc': <test_auc>}``.
    """
    accumulator = EventAccumulator(path)
    accumulator.Reload()
    metrics = {}
    for key, tag in (('acc', 'test_acc'), ('auc', 'test_auc')):
        events = accumulator.Scalars(tag)
        assert len(events) == 1
        metrics[key] = events[0].value
    return metrics

In [3]:
def get_W_from_checkpoint(path, device='cuda:0'):
    """Rebuild a FunctorModel from a checkpoint and return its learned W.

    Args:
        path: Checkpoint file containing 'hyper_parameters' and 'state_dict'.
        device: Device used both as ``torch.load(map_location=...)`` and for the
            rebuilt model. Defaults to 'cuda:0' (the previously hard-coded
            value); pass 'cpu' to inspect checkpoints without a GPU.

    Returns:
        Tuple ``(W, W_init, fix_rep)``: the result of ``model.get_W()`` plus the
        'W_init' and 'fix_rep' hyper-parameters the run was trained with.
    """
    # torch.load unpickles the file: only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=device)
    hparams = checkpoint['hyper_parameters']
    W_init = hparams['W_init']
    fix_rep = hparams['fix_rep']
    model = FunctorModel(
        hparams['model_flag'],
        3,  # NOTE(review): hard-coded second positional argument (presumably input channels) — confirm
        hparams['n_classes'],
        hparams['task'],
        hparams['data_flag'],
        hparams['size'],
        hparams['run'],
        W_init=W_init,
        lambda_t=hparams['lambda_t'],
        lambda_W=hparams['lambda_W'],
        W_block_size=hparams['W_block_size'],
        fix_rep=fix_rep,
        modularity_exponent=hparams['modularity_exponent'],
    )
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to(device)
    return model.get_W(), W_init, fix_rep

In [5]:
# Collect per-run test metrics for every 2D MedMNIST dataset from the
# TensorBoard logs at ./tb_logs/<dataset>/d8/<experiment>/version_<k>.
# Each run yields one record; the DataFrame is built once at the end
# (growing it with pd.concat inside the loop is quadratic).
RESULT_COLUMNS = ["dataset", "lambdaT", "lambdaW", "W_init", "fixed_rep", "algebra_satisfied", "acc", "auc"]
N_VERSIONS = 10  # number of logged versions (runs) per experiment directory

datasets = [dataset for dataset in INFO.keys() if "3d" not in dataset]
tb_logs = "./tb_logs"

records = []
for dataset in datasets:
    path = f"{tb_logs}/{dataset}/d8"
    experiment_paths = os.listdir(path)

    for experiment in experiment_paths:
        # Directory names encode the hyper-parameters, e.g.
        # "D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1".
        lambda_t = experiment.split("lambdaT_")[1].split("_")[0]
        lambda_W = experiment.split("lambdaW_")[1].split("_")[0]

        for version in range(N_VERSIONS):
            result_path = f"{path}/{experiment}/version_{version}"
            print(result_path)
            results = get_results(result_path)
            results['lambdaT'] = lambda_t
            results['lambdaW'] = lambda_W
            results['dataset'] = dataset

            if 'vanilla' in experiment:
                W_init = 'Vanilla'
                fixed_rep = 'None'
                n_ones = 'None'
                n_neg_ones = 'None'
                algebra_satisfied = 'None'
            else:
                # Checkpoint-derived diagnostics are disabled; the values below are
                # hard-coded placeholders. To re-enable, replace them with:
                #   W, W_init, fixed_rep = get_W_from_checkpoint(f"{result_path}/checkpoints/best_model.ckpt")
                #   eigs, _ = torch.linalg.eig(W)
                #   n_ones = torch.sum(torch.abs(eigs - 1) < 1e-2).item()
                #   n_neg_ones = torch.sum(torch.abs(eigs + 1) < 1e-2).item()
                #   algebra_loss = torch.dist(W @ W, torch.eye(W.shape[0], device=W.device)).item()
                #   algebra_satisfied = algebra_loss < 1e-2
                W_init = 'regular'
                n_ones = 256
                n_neg_ones = 256
                algebra_satisfied = True
                fixed_rep = True
                # Tag MSE-trained runs. This must come after W_init is assigned:
                # appending first used a stale/undefined W_init and was then overwritten.
                if 'MSE' in experiment:
                    W_init += 'MSE'

            results['W_init'] = f"{W_init}_{lambda_t}" if W_init != 'Vanilla' else W_init
            results['fixed_rep'] = fixed_rep
            results['n_ones'] = n_ones
            results['n_neg_ones'] = n_neg_ones
            results['algebra_satisfied'] = algebra_satisfied
            records.append(results)

# Same column order as before; n_ones / n_neg_ones follow the declared columns.
data = pd.DataFrame(records, columns=RESULT_COLUMNS + ["n_ones", "n_neg_ones"])

    

./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_0


2025-04-15 09:48:24.759347: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-15 09:48:24.774099: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744706904.792729 1871981 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744706904.798173 1871981 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-15 09:48:24.819194: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_1
./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_2
./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_3
./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_4
./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_5
./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_6
./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_7
./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_8
./tb_logs/pathmnist/d8/D8_regular_functor_resnet18_lambdaT_1.0_lambdaW_0.1/version_9
./tb_logs/pathmnist/d8/vanilla_resnet18_lambdaT_0.0_lambdaW_0.1/version_0
./tb_logs/pathmnist/d8/vanilla_resnet18_lambdaT_0.0_lambdaW_0.1/version_1
./tb_logs/pathmnist/d8/vanilla_resnet18_lambdaT_0.0_lambdaW_0.1/version_2
./tb_logs/pat

In [6]:
# Persist the per-run results collected above.
data.to_csv(path_or_buf="results_regular_d8.csv")

In [None]:
# Combine results from earlier Z2 sweeps: the regular-representation sweep,
# the flip runs, and only the Vanilla (no-functor) rows from the big run.
# NOTE(review): this cell reassigns `data`, replacing the D8 results collected
# above — run it only when summarising the Z2 experiments.
other_data = pd.read_csv("results/z2/results_regular_sweep.csv")
data = pd.read_csv("results_flip.csv")

vanilla_data = pd.read_csv("results/z2/results_big_run.csv")
vanilla_data = vanilla_data[vanilla_data['W_init'] == 'Vanilla']

merged_data = pd.concat([data, other_data, vanilla_data], ignore_index=True)
# Keep only the initialisations compared here (drops orthogonal,
# orthogonalMSE, block_diagonal, identity, ...).
keep_mask = merged_data['W_init'].isin(['regular_1.0', 'Vanilla', 'flip_1.0'])
merged_data = merged_data[keep_mask]

print(merged_data['W_init'].unique())

data = merged_data

['flip_1.0' 'regular_1.0' 'Vanilla']


In [12]:
def check_significance_t_test(mean1, std1, n1, mean2, std2, n2, alpha=0.05):
    """Welch's two-sample t-test computed from summary statistics.

    Returns True when the (two-sided) p-value is below ``alpha``, else False.
    """
    p_value = stats.ttest_ind_from_stats(mean1, std1, n1, mean2, std2, n2, equal_var=False).pvalue
    return bool(p_value < alpha)

def better_than_baseline(performance, baseline):
    """Return True iff ``performance`` is strictly greater than ``baseline``."""
    return bool(performance > baseline)
    
def check_significance_error_bars(mean1, sem1, mean2, sem2):
    """Heuristic significance check: do the ±1 SEM error bars fail to overlap?

    Each interval is ``[mean - sem, mean + sem]``. Returns True when the two
    intervals are strictly separated, False when they touch or overlap. This is
    a visual rule of thumb, not a formal hypothesis test.
    """
    first_entirely_below = mean1 + sem1 < mean2 - sem2
    second_entirely_below = mean2 + sem2 < mean1 - sem1
    return bool(first_entirely_below or second_entirely_below)
    
STATISTIC = 'acc'
POOLING = 'mean'  # 'mean': compare mean±SEM error bars; 'max': compare best runs only

# One row per (dataset, W_init): the pooled statistic over runs and its SEM.
summary = (
    data.groupby(['dataset', 'W_init'])[STATISTIC]
    .agg([POOLING, 'sem'])
    .reset_index()
)

# Vanilla (no-functor) rows are the per-dataset reference.
baseline = summary[summary['W_init'] == 'Vanilla']
baseline_by_dataset = baseline.set_index('dataset')

if POOLING == 'max':
    summary['significant_error_bars'] = summary.apply(
        lambda r: better_than_baseline(r[POOLING], baseline_by_dataset.at[r['dataset'], POOLING]),
        axis=1,
    )
else:
    summary['significant_error_bars'] = summary.apply(
        lambda r: check_significance_error_bars(
            r[POOLING],
            r['sem'],
            baseline_by_dataset.at[r['dataset'], POOLING],
            baseline_by_dataset.at[r['dataset'], 'sem'],
        ),
        axis=1,
    )

# Signed gap to the Vanilla baseline of the same dataset (0 for Vanilla itself).
summary['difference'] = summary.apply(
    lambda r: r[POOLING] - baseline_by_dataset.at[r['dataset'], POOLING],
    axis=1,
)

summary.head(10)

Unnamed: 0,dataset,W_init,mean,sem,significant_error_bars,difference
0,bloodmnist,Vanilla,0.970564,0.001327,False,0.0
1,bloodmnist,regular_1.0,0.973136,0.000533,True,0.002572
2,breastmnist,Vanilla,0.819231,0.015432,False,0.0
3,breastmnist,regular_1.0,0.811538,0.010692,False,-0.007692
4,chestmnist,Vanilla,0.947506,6.4e-05,False,0.0
5,chestmnist,regular_1.0,0.947448,0.000135,False,-5.7e-05
6,dermamnist,Vanilla,0.766584,0.002423,False,0.0
7,dermamnist,regular_1.0,0.76803,0.002442,False,0.001446
8,octmnist,Vanilla,0.7726,0.004164,False,0.0
9,octmnist,regular_1.0,0.7728,0.003756,False,0.0002


In [13]:
# Persist the summary so the LaTeX cell can be re-run on its own.
summary.to_csv(path_or_buf=f"d8_regular_summary_{STATISTIC}_{POOLING}.csv")

In [14]:
# Echo the active configuration (one value per line).
print(POOLING, STATISTIC, sep="\n")

mean
acc


In [15]:
import pandas as pd

# Build a LaTeX results table from the per-(dataset, W_init) summary saved above.
df = pd.read_csv(f'd8_regular_summary_{STATISTIC}_{POOLING}.csv')

# Get sorted unique datasets and models.
datasets = sorted(df['dataset'].unique())
models = sorted(df['W_init'].unique(), key=str)  # Sort by string representation

# (dataset, W_init) -> summary row, for direct lookup while filling the grid.
data_dict = {(row['dataset'], row['W_init']): row for _, row in df.iterrows()}

_LATEX_SPECIALS = {
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
}


def _latex_escape(text):
    """Escape LaTeX special characters for text mode (e.g. the '_' in 'regular_1.0')."""
    return ''.join(_LATEX_SPECIALS.get(ch, ch) for ch in str(text))


def _format_cell(value_pct, sem_pct, significant, difference, is_best):
    """Format one table cell as value[±SEM][*], coloured if significant, bold if best.

    Args:
        value_pct: pooled statistic, in percent.
        sem_pct: its standard error, in percent (not shown when POOLING == 'max').
        significant: whether the cell is flagged against the Vanilla baseline.
        difference: value minus baseline; its sign picks green (>0) or red (<0).
        is_best: whether this is the highest value in its row.
    """
    star = "*" if significant else ""
    if POOLING == 'max':
        entry = f"{value_pct:.2f}"
    else:
        entry = f"{value_pct:.2f}$\\pm${sem_pct:.2f}{star}"
    if significant:
        if difference > 0:
            entry = f"\\textcolor{{ForestGreen}}{{{entry}}}"
        elif difference < 0:
            entry = f"\\textcolor{{red}}{{{entry}}}"
    if is_best:
        entry = f"\\textbf{{{entry}}}"
    return entry


latex_lines = [
    "\\begin{table}[h]",
    "    \\centering",
    "    \\renewcommand{\\arraystretch}{1.2}  % Adjust row height for readability",
    "    \\resizebox{\\textwidth}{!}{  % Resize table to fit within the page width",
    "        \\begin{tabular}{l|" + "c" * len(models) + "}",
    "            \\hline",
    "            Dataset & " + " & ".join(_latex_escape(model) for model in models) + " \\\\ \\hline",
]

# One table row per dataset; "-" where a model has no results for that dataset.
for dataset in datasets:
    max_mean = df.loc[df['dataset'] == dataset, POOLING].max()
    row_entries = []
    for model in models:
        row_data = data_dict.get((dataset, model))
        if row_data is None:
            row_entries.append("-")
            continue
        row_entries.append(_format_cell(
            row_data[POOLING] * 100,
            row_data['sem'] * 100,
            row_data['significant_error_bars'],
            row_data['difference'],
            row_data[POOLING] == max_mean,
        ))
    latex_lines.append(f"            {_latex_escape(dataset)} & " + " & ".join(row_entries) + " \\\\ \\hline")


def calculate_pooled_sem(x):
    """Standard error (in percent) of the mean of k per-dataset means.

    For M = mean_i(m_i) with independent estimates, SE(M) = sqrt(sum_i sem_i**2) / k.
    (The previous sqrt(sum_i sem_i**2 / (k - 1)) is not the SE of that average
    and divides by zero when k == 1.)
    """
    return ((x ** 2).sum() ** 0.5) / len(x) * 100


# Aggregate row: average of the per-dataset values for each model (in percent).
aggregate_means = df.groupby('W_init')[POOLING].mean() * 100
aggregate_stds = df.groupby('W_init')['sem'].agg(calculate_pooled_sem)
max_aggregate_mean = aggregate_means.max()
baseline_mean = aggregate_means.get('Vanilla', float('nan'))
baseline_std = aggregate_stds.get('Vanilla', float('nan'))

aggregate_row = []
for model in models:
    mean_val = aggregate_means.get(model, float('nan'))
    std_val = aggregate_stds.get(model, float('nan'))
    if POOLING == 'max':
        significant = better_than_baseline(mean_val, baseline_mean)
    else:
        significant = check_significance_error_bars(mean_val, std_val, baseline_mean, baseline_std)
    aggregate_row.append(_format_cell(
        mean_val, std_val, significant, mean_val - baseline_mean, mean_val == max_aggregate_mean,
    ))

latex_lines.append(f"            \\textbf{{Aggregate}} & " + " & ".join(aggregate_row) + " \\\\ \\hline")

# The caption must describe what the table actually shows: ±SEM and a
# non-overlapping-error-bar heuristic, not standard deviations or a formal test.
if POOLING == 'max':
    caption = (
        "Summary of ML experiments (best run per model, in \\%). Results above the Vanilla "
        "baseline are in \\textcolor{ForestGreen}{green}. The highest value in each row is "
        "\\textbf{bold}. The last row averages the per-dataset values for each model."
    )
else:
    caption = (
        "Summary of ML experiments: mean $\\pm$ standard error over runs (in \\%). Results whose "
        "$\\pm$1 SE error bars do not overlap those of the Vanilla baseline are marked with * and "
        "shown in \\textcolor{ForestGreen}{green} if better or \\textcolor{red}{red} if worse. "
        "The highest mean in each row is \\textbf{bold}. The last row averages the per-dataset "
        "means for each model; its $\\pm$ is the standard error of that average."
    )

latex_lines.extend([
    "        \\end{tabular}",
    "    }",
    f"    \\caption{{{caption}}}",
    "    \\label{tab:results}",
    "\\end{table}",
])

latex_table = "\n".join(latex_lines)
print(latex_table)


\begin{table}[h]
    \centering
    \renewcommand{\arraystretch}{1.2}  % Adjust row height for readability
    \resizebox{\textwidth}{!}{  % Resize table to fit within the page width
        \begin{tabular}{l|cc}
            \hline
            Dataset & Vanilla & regular_1.0 \\ \hline
            bloodmnist & 97.06$\pm$0.13 & \textbf{\textcolor{ForestGreen}{97.31$\pm$0.05*}} \\ \hline
            breastmnist & \textbf{81.92$\pm$1.54} & 81.15$\pm$1.07 \\ \hline
            chestmnist & \textbf{94.75$\pm$0.01} & 94.74$\pm$0.01 \\ \hline
            dermamnist & 76.66$\pm$0.24 & \textbf{76.80$\pm$0.24} \\ \hline
            octmnist & 77.26$\pm$0.42 & \textbf{77.28$\pm$0.38} \\ \hline
            organamnist & 76.49$\pm$0.56 & \textbf{77.34$\pm$0.32} \\ \hline
            organcmnist & 75.71$\pm$0.63 & \textbf{77.11$\pm$0.85} \\ \hline
            organsmnist & 75.49$\pm$0.24 & \textbf{\textcolor{ForestGreen}{76.12$\pm$0.20*}} \\ \hline
            pathmnist & \textbf{88.18$\pm$0.62} & 86