# Imports

In [1]:
# imports
import numpy as np
import torch
from tueplots import bundles, figsizes
import wandb
import matplotlib.pyplot as plt

import sys

# IPython autoreload extension; mode 2 re-imports modules before each cell is
# executed, so edits to imported project modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Put the current directory first on the module search path
# (the next cell imports from `analysis`).
sys.path.insert(0, '.')

In [2]:
from analysis import sweep2df, RED, BLUE, plot_typography


In [3]:
# Flag forwarded as `usetex` to both bundles.neurips2022(...) and plot_typography(...) below.
USETEX = True

In [4]:
# Apply the tueplots NeurIPS 2022 bundle, then add the extra LaTeX packages the
# figure labels need.
neurips_bundle = bundles.neurips2022(usetex=USETEX)
plt.rcParams.update(neurips_bundle)

extra_latex_packages = (
    r'\usepackage{amsfonts}',  # mathbb
    r'\usepackage{amsmath}',   # boldsymbol
)

# 'text.latex.preamble' must be ONE string in current matplotlib releases (the
# list form was deprecated and later dropped), so the lines are joined with
# newlines. Any preamble the bundle itself provides is kept in front instead of
# being overwritten; reading it from the bundle dict (not from rcParams) keeps
# this cell idempotent when it is re-run.
bundle_preamble = neurips_bundle.get('text.latex.preamble', '')
plt.rcParams.update({
    'text.latex.preamble': '\n'.join(
        part for part in (bundle_preamble, *extra_latex_packages) if part
    ),
})

In [5]:
# Project-level font-size setup (helper imported from `analysis`).
plot_typography(
    usetex=USETEX,
    small=12,
    medium=16,
    big=20,
)

In [6]:
# Weights & Biases location of the experiments
ENTITY = "causal-representation-learning"
PROJECT = "sam_test"

# API client plus a handle on every run of the project
api = wandb.Api(timeout=200)
runs = api.runs(f"{ENTITY}/{PROJECT}")

# Data loading

In [20]:
# Fetch sweep "qqqqlma7" and convert its runs with sweep2df.
# NOTE(review): save=True / load=False presumably means "recompute and write
# the result under `filename`" — confirm against analysis.sweep2df.
SWEEP_ID = "qqqqlma7"
filename = f"sam_vae_{SWEEP_ID}"
sweep = api.sweep(f"{ENTITY}/{PROJECT}/{SWEEP_ID}")
df_sam_vae, sam_val_loss, sam_val_scale_inv = sweep2df(
    sweep.runs,
    filename,
    save=True,
    load=False,
)

In [41]:
# Fetch sweep "4aggfh82" and convert its runs with sweep2df.
# NOTE(review): save=True / load=False presumably means "recompute and write
# the result under `filename`" — confirm against analysis.sweep2df.
SWEEP_ID = "4aggfh82"
filename = f"fix_enc_var_sweep_{SWEEP_ID}"
sweep = api.sweep(f"{ENTITY}/{PROJECT}/{SWEEP_ID}")
df_fix_enc_var, fix_enc_val_loss, fix_enc_val_scale_inv = sweep2df(
    sweep.runs,
    filename,
    save=True,
    load=False,
)

In [7]:
# Fetch sweep "1agmyttm" and convert its runs with sweep2df.
# NOTE(review): save=True / load=False presumably means "recompute and write
# the result under `filename`" — confirm against analysis.sweep2df.
SWEEP_ID = "1agmyttm"
filename = f"rae_sweep_{SWEEP_ID}"
sweep = api.sweep(f"{ENTITY}/{PROJECT}/{SWEEP_ID}")
df_rae, rae_val_loss, rae_val_scale_inv = sweep2df(
    sweep.runs,
    filename,
    save=True,
    load=False,
)

In [42]:
# Display the third value returned by sweep2df for the `fix_enc_var_sweep_*` sweep
# (assigned in the cell that loads sweep "4aggfh82", which must be run first).
# NOTE(review): consider showing `.shape` or a small slice instead of the whole array.
fix_enc_val_scale_inv

array([[9.20854677e-09, 6.72012301e-09, 6.08861101e-09, ...,
        7.75040569e-09, 7.86115712e-09, 7.76962865e-09],
       [8.47482161e-09, 5.96795048e-09, 5.22379315e-09, ...,
        8.18566717e-09, 8.16604696e-09, 8.17441614e-09],
       [1.03964577e-08, 6.77937201e-09, 5.81463123e-09, ...,
        7.01794258e-09, 7.08914814e-09, 7.13616138e-09],
       ...,
       [4.24299612e-06, 2.68058561e-06, 2.11165634e-06, ...,
        7.83850408e-07, 7.65252633e-07, 7.47969170e-07],
       [9.60158810e-06, 5.41860685e-06, 3.50665025e-06, ...,
        9.35496138e-07, 9.26833558e-07, 9.20830557e-07],
       [5.60429885e-06, 2.89496794e-06, 2.23653902e-06, ...,
        9.42207134e-07, 9.32686608e-07, 9.29666110e-07]])

# Pre-processing


# Plots

In [318]:
# Shared styling constants for the plotting cells.
# NOTE(review): DIMS and COLORS are not referenced in the visible cells.
TICK_PADDING, LABELPAD = 2, 3
DIMS = [3, 5, 8, 10]
COLORS = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
]

In [None]:
LABELPAD = 1
TICK_PADDING = 2

# NOTE(review): this cell needs `mcc`, `cima` and `min_len`, none of which is
# defined in the visible cells (the "Pre-processing" section is empty).
# Assumes `mcc` and `cima` are arrays whose axis 0 indexes runs and whose
# axis 1 has length `min_len` — TODO confirm and define them upstream.

fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=2, rel_width=1)['figure.figsize'])

# MCC vs CIMA over different gamma
# NOTE(review): only the left panel (121) of the 1x2 layout is populated.
ax = fig.add_subplot(121)
ax.grid(True, which="both", ls="-.")

# Double y-axis: the base axes carries MCC (left), its twin carries CIMA (right).
# The twin is drawn on top of the base axes, so CIMA ends up as the top plot.
ax_mcc = ax
ax_cima = ax.twinx()

# Plot against real epoch values instead of relabelling index ticks afterwards:
# set_xticklabels without fixed tick positions puts the labels on the wrong ticks.
# NOTE(review): presumably one validation point every 25 epochs — confirm.
val_epoch_factor = 25
epochs = np.arange(min_len) * val_epoch_factor

# MCC: mean over axis 0 with the standard deviation as error bar
ax_mcc.errorbar(epochs, mcc.mean(0), yerr=mcc.std(0), label='mcc', c=BLUE)

# CIMA: same statistics, computed on log10 of the values
log_cima = np.log10(cima)
ax_cima.errorbar(epochs, log_cima.mean(0), yerr=log_cima.std(0), label='cima', c=RED)

# Raw strings keep the mathtext backslashes from being read as (invalid)
# Python escape sequences.
ax_cima.set_ylabel(r"$\log_{10} c_{\mathrm{IMA}}$", labelpad=LABELPAD)
# The former extra offset (LABELPAD+17) looks like compensation for tick labels
# living on a separate twin axis; the label now sits on the axis that owns the
# ticks, so the plain pad is used. NOTE(review): re-tune if spacing looks off.
ax_mcc.set_ylabel(r"$\mathrm{MCC}$", labelpad=LABELPAD)

ax.set_xlabel("Epoch", labelpad=LABELPAD)

# One combined legend, attached explicitly to the top-most axes
handle1, label1 = ax_mcc.get_legend_handles_labels()
handle2, label2 = ax_cima.get_legend_handles_labels()

ax_cima.legend(
    [handle2[0], handle1[0]],
    [r"$\log_{10} c_{\mathrm{IMA}}$", r"$\mathrm{MCC}$"],
    loc='center right',
)

fig.savefig("dsprites_mcc_cima.svg")