In [1]:
import os
import medmnist
from medmnist import INFO
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import pandas as pd
import scipy.stats as stats

from models.FunctorModel import FunctorModel
from utils.initialise_W_utils import initialise_W_real_Cn_irreps
from utils.char_tables import Z2_CharTable, CharTable, Cn_CharTable
import torch
from torch import nn


In [2]:
def get_results(path):
    result = {}
    event_acc = EventAccumulator(path)
    event_acc.Reload()
    #print(event_acc.Tags()['scalars'])
    x = event_acc.Scalars('test_acc/dataloader_idx_0')
    assert len(x) == 1
    result['acc'] = x[0].value
    x = event_acc.Scalars('aug_test_acc/dataloader_idx_1')
    assert len(x) == 1
    result['aug_acc'] = x[0].value
    return result

In [3]:
def get_W_from_checkpoint(path):
    checkpoint = torch.load(path, map_location='cuda:0')
    base_model = checkpoint['hyper_parameters']['model_flag']
    n_classes = checkpoint['hyper_parameters']['n_classes']
    task = checkpoint['hyper_parameters']['task']
    data_flag = checkpoint['hyper_parameters']['data_flag']
    size = checkpoint['hyper_parameters']['size']
    name = checkpoint['hyper_parameters']['run']
    W_init = checkpoint['hyper_parameters']['W_init']
    lambda_t = checkpoint['hyper_parameters']['lambda_t']
    lambda_W = checkpoint['hyper_parameters']['lambda_W']
    W_block_size = checkpoint['hyper_parameters']['W_block_size']
    fix_rep = checkpoint['hyper_parameters']['fix_rep']
    mod_exponent = checkpoint['hyper_parameters']['modularity_exponent']
    model = FunctorModel(base_model, 3, n_classes, task, data_flag, size, name, W_init=W_init, lambda_t=lambda_t, lambda_W=lambda_W, W_block_size=W_block_size, fix_rep=fix_rep, modularity_exponent=mod_exponent)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to('cuda:0')
    return model.get_W(), W_init, fix_rep

In [4]:
def get_name(s, lambda_t):
    if "vanilla" in s:
        return "Vanilla"
    if "D8_regular_functor" in s:
        return f"regular_{lambda_t}"
    if "only" in s:
        return f"fineTunedAfterEquiv"
    if "fine_tune" in s and "only" not in s:
        return f"fineTunedWithoutEquiv"
    return "regular"

In [5]:
data = pd.DataFrame(columns=["dataset", "lambdaT", "lambdaW", "W_init", "fixed_rep", "algebra_satisfied", "acc", "auc"])


datasets = list(INFO.keys())
#datasets = [dataset for dataset in datasets if "3d" not in dataset]
datasets = ['pathmnist']
tb_logs = "./tb_logs"

# Select dataset
for dataset in datasets:
    path = f"{tb_logs}/{dataset}/ddmnist_c4"
    experiment_paths = os.listdir(path)

    #Select experiment
    for experiment in experiment_paths:
        lambda_t = experiment.split("lambdaT_")[1].split("_")[0]
        lambda_W = experiment.split("lambdaW_")[1].split("_")[0]
        #lr = 0.0005#experiment.split("lr=")[1].split("_")[0]
        if "lr=" in experiment:
            lr = experiment.split("lr=")[1].split("_")[0]
        else:
            lr = 0.0005

        #Select version
        for version in range(0, 5):
            result_path = f"{path}/{experiment}/version_{version}"
            print(result_path)
            results = get_results(result_path)
            results['lambdaT'] = lambda_t
            results['lambdaW'] = lambda_W
            results['dataset'] = dataset

            if 'vanilla' in experiment:
                W_init = 'Vanilla'
                fixed_rep = 'None'
                n_ones = 'None'
                n_neg_ones = 'None'
                algebra_satisfied = 'None'

            else:
                #W, W_init, fixed_rep = get_W_from_checkpoint(f"{result_path}/checkpoints/best_model.ckpt")
                if 'MSE' in experiment:
                    W_init += 'MSE'
                # eigs, _ = torch.linalg.eig(W)
                
                # n_ones = torch.sum(torch.abs(eigs - 1) < 1e-2).item()
                # n_neg_ones = torch.sum(torch.abs(eigs + 1) < 1e-2).item()

                # algebra_loss = torch.dist(W @ W, torch.eye(W.shape[0], device=W.device)).item()
                # if algebra_loss < 1e-2:
                #     algebra_satisfied = True
                # else:
                #     algebra_satisfied = False
                W_init = 'regular'
                n_ones = 256
                n_neg_ones = 256
                algebra_satisfied = True
                fixed_rep = True

            #results['W_init'] = f"{W_init}_{lambda_t}" 
            #W_init = "fine_tuned" if lambda_t == '0.0' and W_init != 'Vanilla' else W_init
            results['W_init'] = get_name(experiment, lambda_t)
            results['fixed_rep'] = fixed_rep
            results['n_ones'] = n_ones
            results['n_neg_ones'] = n_neg_ones
            results['algebra_satisfied'] = algebra_satisfied
            results['lr'] = lr
            if "id_" in experiment:
                layer_id = experiment.split("id_")[1].split("_")[0]
            else:
                layer_id = '9'

            results['layer_id'] = layer_id

            data = pd.concat([data, pd.DataFrame([results])])

    

./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.005_resnet18_lambdaT_1.0_lambdaW_0.1/version_0


2025-05-06 02:46:49.068013: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-06 02:46:49.082601: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746496009.100795 3473152 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746496009.106109 3473152 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-06 02:46:49.127008: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.005_resnet18_lambdaT_1.0_lambdaW_0.1/version_1
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.005_resnet18_lambdaT_1.0_lambdaW_0.1/version_2
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.005_resnet18_lambdaT_1.0_lambdaW_0.1/version_3
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.005_resnet18_lambdaT_1.0_lambdaW_0.1/version_4
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.0009_resnet18_lambdaT_0.0_lambdaW_0.1/version_0
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.0009_resnet18_lambdaT_0.0_lambdaW_0.1/version_1
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.0009_resnet18_lambdaT_0.0_lambdaW_0.1/version_2
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.0009_resnet18_lambdaT_0.0_lambdaW_0.1/version_3
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.0009_resnet18_lambdaT_0.0_lambdaW_0.1/version_4
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.0009_resnet18_lambdaT_0.05_lambdaW_0.1/version_0
./tb_logs/pathmnist/ddmnist_c4/id_89_lr=0.0009_resnet18_lambdaT_0.05_lambdaW_0.1/version_1
./tb_logs/pa

In [9]:
# Show test_acc grouped by lambdaT
summary = data.groupby(['lambdaT', 'layer_id', 'lr'])['aug_acc'].agg(['mean', 'sem'])
summary.head(100)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mean,sem
lambdaT,layer_id,lr,Unnamed: 3_level_1,Unnamed: 4_level_1
0.0,89,0.0009,0.87652,0.002201
0.0,89,0.005,0.89092,0.005358
0.0,89,0.008,0.8806,0.010118
0.05,89,0.0009,0.89884,0.002423
0.05,89,0.005,0.88944,0.004212
0.05,89,0.008,0.90192,0.002577
0.5,89,0.0009,0.90396,0.00383
0.5,89,0.005,0.89736,0.004823
0.5,89,0.008,0.89492,0.004833
1.0,89,0.0009,0.91076,0.003096


In [14]:
data.head()

Unnamed: 0,dataset,lambdaT,lambdaW,W_init,fixed_rep,algebra_satisfied,acc,auc,n_ones,n_neg_ones
0,pathmnist,1.5,0.1,regular,True,True,0.9116,,256.0,256.0
0,pathmnist,1.5,0.1,regular,True,True,0.9022,,256.0,256.0
0,pathmnist,1.5,0.1,regular,True,True,0.9096,,256.0,256.0
0,pathmnist,1.5,0.1,regular,True,True,0.9178,,256.0,256.0
0,pathmnist,1.5,0.1,regular,True,True,0.9154,,256.0,256.0


In [17]:
data.to_csv("all_d8.csv")

In [None]:
other_data = pd.read_csv("results/z2/results_regular_sweep.csv")
data = pd.read_csv("results_flip.csv")

vanilla_data = pd.read_csv("results/z2/results_big_run.csv")
vanilla_data = vanilla_data[vanilla_data['W_init'].isin(['Vanilla'])]

# Merge other_data and data
merged_data = pd.concat([data, other_data, vanilla_data], ignore_index=True)
# Delete entries where W_init='orthogonal', 'orthogonalMSE', 'block_diagonal', 'identity'
merged_data = merged_data[merged_data['W_init'].isin(['regular_1.0', 'Vanilla', 'flip_1.0'])]

print(merged_data['W_init'].unique())

# Rename regular to regular_0.5
data = merged_data

['flip_1.0' 'regular_1.0' 'Vanilla']


In [26]:
def check_significance_t_test(mean1, std1, n1, mean2, std2, n2, alpha=0.05):
    # Compute the t-statistic and p-value
    t_stat, p_value = stats.ttest_ind_from_stats(mean1, std1, n1, mean2, std2, n2, equal_var=False)
    if p_value < alpha:
        return True
    else:
        return False

def better_than_baseline(performance, baseline):
    if performance > baseline:
        return True
    else:
        return False
    
def check_significance_error_bars(mean1, sem1, mean2, sem2):
    # Compute Standard Error of the Mean (SEM)
    # sem1 = std1 / (n1 ** 0.5)
    # sem2 = std2 / (n2 ** 0.5)

    # Compute error bar ranges
    lower1, upper1 = mean1 - sem1, mean1 + sem1
    lower2, upper2 = mean2 - sem2, mean2 + sem2
    # Check if error bars overlap
    if upper1 < lower2 or upper2 < lower1:
        return True
    else:
        return False
    
STATISTIC = 'acc'
POOLING = 'mean'
summary = data.groupby(['dataset', 'W_init'])[STATISTIC].agg([POOLING, 'sem'])

# Get statistics for W_init = None
summary = summary.reset_index()
baseline = summary[summary['W_init'] == 'Vanilla']

#Apply check significance to each row of summary
if POOLING == 'max':
    summary['significant_error_bars'] = summary.apply(lambda row: better_than_baseline(row[POOLING],  baseline[POOLING][baseline['dataset']==row['dataset']].values[0]), axis=1)
else:
    summary['significant_error_bars'] = summary.apply(lambda row: check_significance_error_bars(row[POOLING], row['sem'], 
                                                                        baseline[POOLING][baseline['dataset']==row['dataset']].values[0], 
                                                                        baseline['sem'][baseline['dataset']==row['dataset']].values[0]), axis=1)


summary['difference'] = summary.apply(lambda row: (row[POOLING] - baseline[POOLING][baseline['dataset']==row['dataset']]).values[0], axis=1)

# See when significant error_bars is different than significant t-test
#summary[summary['significant_error_bars'] != summary['significant_t_test']].head(100)
#summary[(summary['significant_error_bars'] == True) & (summary['difference']>0)].head(100)

summary.head(10)

Unnamed: 0,dataset,W_init,mean,sem,significant_error_bars,difference
0,bloodmnist,Vanilla,0.970564,0.001327,False,0.0
1,bloodmnist,fineTunedAfterEquiv,0.968781,0.001678,False,-0.001783
2,bloodmnist,fineTunedWithoutEquiv,0.972581,0.001665,False,0.002017
3,bloodmnist,regular_1.0,0.973136,0.000533,True,0.002572
4,bloodmnist,regular_2.0,0.971266,0.001278,False,0.000702
5,breastmnist,Vanilla,0.819231,0.015432,False,0.0
6,breastmnist,fineTunedAfterEquiv,0.835897,0.004984,False,0.016667
7,breastmnist,fineTunedWithoutEquiv,0.828846,0.007466,False,0.009615
8,breastmnist,regular_1.0,0.811538,0.010692,False,-0.007692
9,breastmnist,regular_2.0,0.797436,0.016307,False,-0.021795


In [27]:
summary.to_csv(f"all_d8_summary_{STATISTIC}_{POOLING}.csv")

In [28]:
print(POOLING)
print(STATISTIC)

mean
acc


In [29]:
import pandas as pd

# Read the CSV file (adjust the filename if necessary)
df = pd.read_csv(f'all_d8_summary_{STATISTIC}_{POOLING}.csv')

# Get sorted unique datasets and models.
datasets = sorted(df['dataset'].unique())
models = sorted(df['W_init'].unique(), key=str)  # Sort by string representation

# Create a dictionary for easy lookup: 
# key = (dataset, W_init) and value = corresponding row as a Series.
data_dict = {(row['dataset'], row['W_init']): row for _, row in df.iterrows()}

# Start building the LaTeX table
latex_lines = [
    "\\begin{table}[h]",
    "    \\centering",
    "    \\renewcommand{\\arraystretch}{1.2}  % Adjust row height for readability",
    "    \\resizebox{\\textwidth}{!}{  % Resize table to fit within the page width",
    "        \\begin{tabular}{l|" + "c" * len(models) + "}",
    "            \\hline",
    "            Dataset & " + " & ".join(map(str, models)) + " \\\\ \\hline"
]

# Process each dataset row by row
for dataset in datasets:
    # Extract all entries for the current dataset.
    subset = df[df['dataset'] == dataset].copy()
    
    # Determine the highest mean in this dataset
    max_mean = subset[POOLING].max()

    # Build the table row for the current dataset
    row_entries = []
    for model in models:
        key = (dataset, model)
        if key in data_dict:
            row_data = data_dict[key]
            # Convert to percentages
            mean_val = row_data[POOLING] * 100
            std_val = row_data['sem'] * 100
            # Append a star if the result is significant.
            star = "*" if row_data['significant_error_bars'] else ""
            if POOLING == 'max':
                entry = f"{mean_val:.2f}"
            else:
                entry = f"{mean_val:.2f}$\\pm${std_val:.2f}{star}"
            
            # Apply coloring if applicable.
            if row_data['significant_error_bars']:
                if row_data['difference'] > 0:
                    entry = f"\\textcolor{{ForestGreen}}{{{entry}}}"
                elif row_data['difference'] < 0:
                    entry = f"\\textcolor{{red}}{{{entry}}}"

            # Bold the highest mean value
            if row_data[POOLING] == max_mean:
                entry = f"\\textbf{{{entry}}}"
        else:
            entry = "-"
        row_entries.append(entry)
    
    line = f"            {dataset} & " + " & ".join(row_entries) + " \\\\ \\hline"
    latex_lines.append(line)

# Compute aggregate mean and standard deviation of the means per model
aggregate_means = df.groupby('W_init')[POOLING].mean() * 100  # Convert to percentages
#df['variance'] = df['std'] ** 2
# def calculate_pooled_std(x):
#     return (((x ** 2).sum() / (len(x) - 1)) ** 0.5)*100
# aggregate_stds = df.groupby('W_init')['std'].agg(calculate_pooled_std)
def calculate_pooled_sem(x):
    return (((x ** 2).sum() / (len(x) - 1)) ** 0.5)*100
aggregate_stds = df.groupby('W_init')['sem'].agg(calculate_pooled_sem)

# Determine the highest aggregate mean
max_aggregate_mean = aggregate_means.max()

# Build the final row for aggregate statistics
aggregate_row = []
for model in models:
    mean_val = aggregate_means.get(model, float('nan'))
    std_val = aggregate_stds.get(model, float('nan'))
    baseline_mean = aggregate_means.get('Vanilla', float('nan'))
    baseline_std = aggregate_stds.get('Vanilla', float('nan'))
    if POOLING == 'max':
        star = "*" if better_than_baseline(mean_val, baseline_mean) else ""
        entry = f"{mean_val:.2f}"
    else:
        star = "*" if check_significance_error_bars(mean_val, std_val, baseline_mean, baseline_std) else ""
        entry = f"{mean_val:.2f}$\\pm${std_val:.2f}{star}"

    # Apply coloring if applicable.
    if star == '*':
        if mean_val - baseline_mean > 0:
            entry = f"\\textcolor{{ForestGreen}}{{{entry}}}"
        elif mean_val - baseline_mean < 0:
            entry = f"\\textcolor{{red}}{{{entry}}}"

    # Bold the highest mean
    if mean_val == max_aggregate_mean:
        entry = f"\\textbf{{{entry}}}"

    aggregate_row.append(entry)

# Append aggregate row to the table
latex_lines.append(f"            \\textbf{{Aggregate}} & " + " & ".join(aggregate_row) + " \\\\ \\hline")

# Closing the LaTeX table
latex_lines.extend([
    "        \\end{tabular}",
    "    }",
    "    \\caption{Summary of ML experiments. Statistically significant improvements are in \\textcolor{ForestGreen}{green} if positive and \\textcolor{red}{red} if negative. The highest mean in each row is \\textbf{bold}. A * indicates statistical significance. The last row shows the aggregate mean and standard deviation of the means per model.}",
    "    \\label{tab:results}",
    "\\end{table}"
])

# Combine all lines into one LaTeX table string
latex_table = "\n".join(latex_lines)

# Print or save the output
print(latex_table)


\begin{table}[h]
    \centering
    \renewcommand{\arraystretch}{1.2}  % Adjust row height for readability
    \resizebox{\textwidth}{!}{  % Resize table to fit within the page width
        \begin{tabular}{l|ccccc}
            \hline
            Dataset & Vanilla & fineTunedAfterEquiv & fineTunedWithoutEquiv & regular_1.0 & regular_2.0 \\ \hline
            bloodmnist & 97.06$\pm$0.13 & 96.88$\pm$0.17 & 97.26$\pm$0.17 & \textbf{\textcolor{ForestGreen}{97.31$\pm$0.05*}} & 97.13$\pm$0.13 \\ \hline
            breastmnist & 81.92$\pm$1.54 & \textbf{83.59$\pm$0.50} & 82.88$\pm$0.75 & 81.15$\pm$1.07 & 79.74$\pm$1.63 \\ \hline
            chestmnist & 94.75$\pm$0.01 & \textbf{\textcolor{ForestGreen}{94.76$\pm$0.01*}} & 94.75$\pm$0.01 & 94.74$\pm$0.01 & 94.76$\pm$0.01 \\ \hline
            dermamnist & 76.66$\pm$0.24 & 76.54$\pm$0.18 & \textbf{76.92$\pm$0.26} & 76.80$\pm$0.24 & 76.57$\pm$0.34 \\ \hline
            octmnist & 77.26$\pm$0.42 & 76.42$\pm$0.48 & 77.07$\pm$0.38 & \textbf{77.28$\p