In [8]:
import os
import medmnist
from medmnist import INFO
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import pandas as pd
import scipy.stats as stats

from models.FunctorModel import FunctorModel
from utils.initialise_W_utils import initialise_W_real_Cn_irreps
from utils.char_tables import Z2_CharTable, CharTable, Cn_CharTable
import torch
from torch import nn


In [9]:
def get_results(path):
    result = {}
    event_acc = EventAccumulator(path)
    event_acc.Reload()
    x = event_acc.Scalars('test_acc')
    assert len(x) == 1
    result['acc'] = x[0].value
    x = event_acc.Scalars('test_auc')
    assert len(x) == 1
    result['auc'] = x[0].value
    return result

In [10]:
def get_W_from_checkpoint(path):
    checkpoint = torch.load(path, map_location='cuda:0')
    base_model = checkpoint['hyper_parameters']['model_flag']
    n_classes = checkpoint['hyper_parameters']['n_classes']
    task = checkpoint['hyper_parameters']['task']
    data_flag = checkpoint['hyper_parameters']['data_flag']
    size = checkpoint['hyper_parameters']['size']
    name = checkpoint['hyper_parameters']['run']
    W_init = checkpoint['hyper_parameters']['W_init']
    lambda_t = checkpoint['hyper_parameters']['lambda_t']
    lambda_W = checkpoint['hyper_parameters']['lambda_W']
    W_block_size = checkpoint['hyper_parameters']['W_block_size']
    fix_rep = checkpoint['hyper_parameters']['fix_rep']
    mod_exponent = checkpoint['hyper_parameters']['modularity_exponent']
    model = FunctorModel(base_model, 3, n_classes, task, data_flag, size, name, W_init=W_init, lambda_t=lambda_t, lambda_W=lambda_W, W_block_size=W_block_size, fix_rep=fix_rep, modularity_exponent=mod_exponent)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to('cuda:0')
    return model.get_W(), W_init, fix_rep

In [24]:
data = pd.DataFrame(columns=["dataset", "lambdaT", "lambdaW", "W_init", "fixed_rep", "algebra_satisfied", "acc", "auc"])


datasets = list(INFO.keys())
datasets = [dataset for dataset in datasets if "3d" not in dataset]
tb_logs = "./tb_logs"

# Select dataset
for dataset in datasets:
    path = f"{tb_logs}/{dataset}/z2"
    experiment_paths = os.listdir(path)

    #Select experiment
    for experiment in experiment_paths:
        lambda_t = experiment.split("lambdaT_")[1].split("_")[0]
        lambda_W = experiment.split("lambdaW_")[1].split("_")[0]

        #Select version
        for version in range(0, 10):
            result_path = f"{path}/{experiment}/version_{version}"
            print(result_path)
            results = get_results(result_path)
            results['lambdaT'] = lambda_t
            results['lambdaW'] = lambda_W
            results['dataset'] = dataset

            if 'vanilla' in experiment:
                W_init = 'Vanilla'
                fixed_rep = 'None'
                n_ones = 'None'
                n_neg_ones = 'None'
                algebra_satisfied = 'None'

            else:
                W, W_init, fixed_rep = get_W_from_checkpoint(f"{result_path}/checkpoints/best_model.ckpt")
                if 'MSE' in experiment:
                    W_init += 'MSE'
                eigs, _ = torch.linalg.eig(W)
                
                n_ones = torch.sum(torch.abs(eigs - 1) < 1e-2).item()
                n_neg_ones = torch.sum(torch.abs(eigs + 1) < 1e-2).item()

                algebra_loss = torch.dist(W @ W, torch.eye(W.shape[0], device=W.device)).item()
                if algebra_loss < 1e-2:
                    algebra_satisfied = True
                else:
                    algebra_satisfied = False

            results['W_init'] = W_init
            results['fixed_rep'] = fixed_rep
            results['n_ones'] = n_ones
            results['n_neg_ones'] = n_neg_ones
            results['algebra_satisfied'] = algebra_satisfied

            data = pd.concat([data, pd.DataFrame([results])])

    

./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_0
Using latent transformation from generators


  data = pd.concat([data, pd.DataFrame([results])])


./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_1
Using latent transformation from generators
./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_2
Using latent transformation from generators
./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_3
Using latent transformation from generators
./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_4
Using latent transformation from generators
./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_5
Using latent transformation from generators
./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_6
Using latent transformation from generators
./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_7
Using latent transformation from generators
./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lambdaW_0.5/version_8
Using latent transformation from generators
./tb_logs/pathmnist/z2/functor_resnet18_lambdaT_0.5_lamb

In [25]:
data.to_csv("results.csv")

In [15]:
data[(data['W_init']=='orthogonalMSE')].head(30)

Unnamed: 0,dataset,lambdaT,lambdaW,W_init,fixed_rep,algebra_satisfied,acc,auc,n_ones,n_neg_ones
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.917967,0.989571,214,203
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.905153,0.987483,192,185
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.9039,0.985793,174,152
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.897911,0.987041,182,140
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.875627,0.988975,203,198
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.887326,0.983905,181,145
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.891922,0.98738,219,200
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.882033,0.988527,190,140
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.892479,0.990098,196,178
0,pathmnist,0.5,0.5,orthogonalMSE,False,False,0.872284,0.989542,215,206


In [102]:
def check_significance_t_test(mean1, std1, n1, mean2, std2, n2, alpha=0.05):
    # Compute the t-statistic and p-value
    t_stat, p_value = stats.ttest_ind_from_stats(mean1, std1, n1, mean2, std2, n2, equal_var=False)
    if p_value < alpha:
        return True
    else:
        return False

def better_than_baseline(performance, baseline):
    if performance > baseline:
        return True
    else:
        return False
    
def check_significance_error_bars(mean1, std1, n1, mean2, std2, n2):
    # Compute Standard Error of the Mean (SEM)
    sem1 = std1 / (n1 ** 0.5)
    sem2 = std2 / (n2 ** 0.5)

    # Compute error bar ranges
    lower1, upper1 = mean1 - sem1, mean1 + sem1
    lower2, upper2 = mean2 - sem2, mean2 + sem2
    # Check if error bars overlap
    if upper1 < lower2 or upper2 < lower1:
        return True
    else:
        return False
    
STATISTIC = 'acc'
POOLING = 'max'
summary = data.groupby(['dataset', 'W_init'])[STATISTIC].agg([POOLING, 'std'])

# Get statistics for W_init = None
summary = summary.reset_index()
baseline = summary[summary['W_init'] == 'Vanilla']

#Apply check significance to each row of summary
if POOLING == 'max':
    summary['significant_error_bars'] = summary.apply(lambda row: better_than_baseline(row[POOLING],  baseline[POOLING][baseline['dataset']==row['dataset']].values[0]), axis=1)
else:
    summary['significant_error_bars'] = summary.apply(lambda row: check_significance_error_bars(row[POOLING], row['std'], 10, 
                                                                        baseline[POOLING][baseline['dataset']==row['dataset']].values[0], 
                                                                        baseline['std'][baseline['dataset']==row['dataset']].values[0], 
                                                                        10), axis=1)
# summary['significant_t_test'] = summary.apply(lambda row: check_significance_t_test(row[POOLING], row['std'], 10, 
#                                                                       baseline[POOLING][baseline['dataset']==row['dataset']].values[0], 
#                                                                       baseline['std'][baseline['dataset']==row['dataset']].values[0], 
#                                                                       10), axis=1)

summary['difference'] = summary.apply(lambda row: (row[POOLING] - baseline[POOLING][baseline['dataset']==row['dataset']]).values[0], axis=1)

# See when significant error_bars is different than significant t-test
#summary[summary['significant_error_bars'] != summary['significant_t_test']].head(100)
#summary[(summary['significant_error_bars'] == True) & (summary['difference']>0)].head(100)

summary.head()

Unnamed: 0,dataset,W_init,max,std,significant_error_bars,difference
0,bloodmnist,Vanilla,0.965799,0.007603,False,0.0
1,bloodmnist,block_diagonal,0.969892,0.008107,True,0.004092
2,bloodmnist,identity,0.971938,0.005714,True,0.006139
3,bloodmnist,orthogonal,0.970476,0.008192,True,0.004677
4,bloodmnist,orthogonalMSE,0.971353,0.003836,True,0.005554


In [103]:
summary.to_csv(f"summary_{STATISTIC}_{POOLING}.csv")

In [106]:
print(POOLING)

max


In [108]:
import pandas as pd

# Read the CSV file (adjust the filename if necessary)
df = pd.read_csv(f'summary_auc_{POOLING}.csv')

# Get sorted unique datasets and models.
datasets = sorted(df['dataset'].unique())
models = sorted(df['W_init'].unique(), key=str)  # Sort by string representation

# Create a dictionary for easy lookup: 
# key = (dataset, W_init) and value = corresponding row as a Series.
data_dict = {(row['dataset'], row['W_init']): row for _, row in df.iterrows()}

# Start building the LaTeX table
latex_lines = [
    "\\begin{table}[h]",
    "    \\centering",
    "    \\renewcommand{\\arraystretch}{1.2}  % Adjust row height for readability",
    "    \\resizebox{\\textwidth}{!}{  % Resize table to fit within the page width",
    "        \\begin{tabular}{l|" + "c" * len(models) + "}",
    "            \\hline",
    "            Dataset & " + " & ".join(map(str, models)) + " \\\\ \\hline"
]

# Process each dataset row by row
for dataset in datasets:
    # Extract all entries for the current dataset.
    subset = df[df['dataset'] == dataset].copy()
    
    # Determine the highest mean in this dataset
    max_mean = subset[POOLING].max()

    # Build the table row for the current dataset
    row_entries = []
    for model in models:
        key = (dataset, model)
        if key in data_dict:
            row_data = data_dict[key]
            # Convert to percentages
            mean_val = row_data[POOLING] * 100
            std_val = row_data['std'] * 100
            # Append a star if the result is significant.
            star = "*" if row_data['significant_error_bars'] else ""
            if POOLING == 'max':
                entry = f"{mean_val:.2f}"
            else:
                entry = f"{mean_val:.2f}$\\pm${std_val:.2f}{star}"
            
            # Apply coloring if applicable.
            if row_data['significant_error_bars']:
                if row_data['difference'] > 0:
                    entry = f"\\textcolor{{ForestGreen}}{{{entry}}}"
                elif row_data['difference'] < 0:
                    entry = f"\\textcolor{{red}}{{{entry}}}"

            # Bold the highest mean value
            if row_data[POOLING] == max_mean:
                entry = f"\\textbf{{{entry}}}"
        else:
            entry = "-"
        row_entries.append(entry)
    
    line = f"            {dataset} & " + " & ".join(row_entries) + " \\\\ \\hline"
    latex_lines.append(line)

# Compute aggregate mean and standard deviation of the means per model
aggregate_means = df.groupby('W_init')[POOLING].mean() * 100  # Convert to percentages
df['variance'] = df['std'] ** 2
def calculate_pooled_std(x):
    return (((x ** 2).sum() / (len(x) - 1)) ** 0.5)*100
aggregate_stds = df.groupby('W_init')['variance'].agg(calculate_pooled_std)

# Determine the highest aggregate mean
max_aggregate_mean = aggregate_means.max()

# Build the final row for aggregate statistics
aggregate_row = []
for model in models:
    mean_val = aggregate_means.get(model, float('nan'))
    std_val = aggregate_stds.get(model, float('nan'))
    baseline_mean = aggregate_means.get('Vanilla', float('nan'))
    baseline_std = aggregate_stds.get('Vanilla', float('nan'))
    if POOLING == 'max':
        star = "*" if better_than_baseline(mean_val, baseline_mean) else ""
        entry = f"{mean_val:.2f}"
    else:
        star = "*" if check_significance_error_bars(mean_val, std_val, 12, baseline_mean, baseline_std, 12) else ""
        entry = f"{mean_val:.2f}$\\pm${std_val:.2f}{star}"

    # Apply coloring if applicable.
    if star == '*':
        if mean_val - baseline_mean > 0:
            entry = f"\\textcolor{{ForestGreen}}{{{entry}}}"
        elif mean_val - baseline_mean < 0:
            entry = f"\\textcolor{{red}}{{{entry}}}"

    # Bold the highest mean
    if mean_val == max_aggregate_mean:
        entry = f"\\textbf{{{entry}}}"

    aggregate_row.append(entry)

# Append aggregate row to the table
latex_lines.append(f"            \\textbf{{Aggregate}} & " + " & ".join(aggregate_row) + " \\\\ \\hline")

# Closing the LaTeX table
latex_lines.extend([
    "        \\end{tabular}",
    "    }",
    "    \\caption{Summary of ML experiments. Statistically significant improvements are in \\textcolor{ForestGreen}{green} if positive and \\textcolor{red}{red} if negative. The highest mean in each row is \\textbf{bold}. A * indicates statistical significance. The last row shows the aggregate mean and standard deviation of the means per model.}",
    "    \\label{tab:results}",
    "\\end{table}"
])

# Combine all lines into one LaTeX table string
latex_table = "\n".join(latex_lines)

# Print or save the output
print(latex_table)


\begin{table}[h]
    \centering
    \renewcommand{\arraystretch}{1.2}  % Adjust row height for readability
    \resizebox{\textwidth}{!}{  % Resize table to fit within the page width
        \begin{tabular}{l|cccccc}
            \hline
            Dataset & Vanilla & block_diagonal & identity & orthogonal & orthogonalMSE & regular \\ \hline
            bloodmnist & 99.83 & \textcolor{ForestGreen}{99.84} & \textbf{\textcolor{ForestGreen}{99.84}} & \textcolor{ForestGreen}{99.83} & 99.83 & 99.83 \\ \hline
            breastmnist & 91.37 & \textcolor{ForestGreen}{91.73} & 90.60 & \textbf{\textcolor{ForestGreen}{92.04}} & \textcolor{ForestGreen}{91.83} & 91.35 \\ \hline
            chestmnist & 78.00 & 77.49 & \textcolor{ForestGreen}{78.03} & 77.37 & 77.60 & \textbf{\textcolor{ForestGreen}{78.41}} \\ \hline
            dermamnist & 92.57 & \textcolor{ForestGreen}{92.86} & \textbf{\textcolor{ForestGreen}{93.10}} & 91.75 & 92.55 & \textcolor{ForestGreen}{92.94} \\ \hline
            octmnist 