In [1]:
from collections import defaultdict
import os.path as op
from cycler import cycler
from glob import glob

import h5py
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from viz import (plot_learning, plot_state_traj, plot_all_units,
                 plot_weight_distr)

In [5]:
# Load the loss trace from every simulation output file in `sim_output_dir`.
# NOTE(review): the recorded run of this cell raised
# KeyError "object 'losses' doesn't exist" -- confirm the dataset name that the
# simulation writer uses (the single-file cell below reads 'z_t').
sim_output_dir = '/home/ryan/time_coding_output'
fnames = sorted(glob(op.join(sim_output_dir, '*.hdf5')))
n_nets = len(fnames)

losses_per_net = list()
for fname in fnames:
    with h5py.File(fname, 'r') as fread:
        if 'losses' not in fread:
            # fail with the keys that *are* present so the right name is obvious
            raise KeyError(f"'losses' not found in {fname}; "
                           f"available keys: {sorted(fread.keys())}")
        # copy into memory with [()]: an h5py Dataset handle becomes invalid
        # once the file is closed at the end of the `with` block
        trial = fread['losses'][()]
    losses_per_net.append(trial)

KeyError: "Unable to synchronously open object (object 'losses' doesn't exist)"

In [None]:
# Read a slice of the 'z_t' dataset from a single simulation output file.
# Slicing an h5py Dataset reads the selection into a NumPy array, so `trial`
# stays usable after the file is closed.
# NOTE(review): the axis meanings of 'z_t' are not visible here -- this takes
# the last index on axis 0, index 0 on axis 1, all of axis 2 and the first two
# entries of axis 3; confirm against the simulation writer.
sim_fname = '/home/ryan/Desktop/sim_data_net0_5c09edf_2025-08-25_16:37:47.hdf5'
with h5py.File(sim_fname, 'r') as f_read:
    trial = f_read['z_t'][-1, 0, :, :2]

In [None]:
# Parse data: flatten the per-trial result records in `res` into flat lists.
# NOTE(review): `res` and `param_keys` are defined outside this view; presumably
# res[i] holds the results for the network trained with param_keys[i] -- confirm.
n_perturbations = 3  # presumably the number of perturbation magnitudes per trial -- TODO confirm
learning_metrics = defaultdict(list)
learning_metrics['stp'] = param_keys.tolist()
divergences = []
response_times = []
perturbation_mags = []
stp_types = np.repeat(param_keys, n_perturbations)

for metric_name in res[0].keys():
    for trial in res:
        value = trial[metric_name]
        if value is None:
            # this trial did not record the metric; skip it
            continue
        if metric_name == 'divergence':
            divergences.extend(value)
        elif metric_name == 'perturbation_mag':
            perturbation_mags.extend(value)
        elif metric_name == 'response_times':
            # one copy of the time axis per perturbation (presumably so it
            # lines up row-for-row with `divergences`)
            response_times.extend(np.tile(value, (n_perturbations, 1)).tolist())
        else:
            # any other metric (e.g. 'losses') is kept as one entry per trial
            learning_metrics[metric_name].append(value)

# Plot the average divergence (MSE) over time for each STP condition: one
# panel per condition, one line per perturbation magnitude (seaborn aggregates
# repeated x values and draws an error band).
n_times = len(response_times[0])
# Build the frame column-wise so each column keeps its own dtype. Stacking the
# columns into a single np.array coerces every value to one common dtype
# (strings whenever the STP labels are strings), which breaks the numeric axes.
div_df = pd.DataFrame({
    'MSE': np.ravel(divergences),
    'time (s)': np.ravel(response_times),
    'perturbation': np.repeat(perturbation_mags, n_times),
    'stp_type': np.repeat(stp_types, n_times),
})
fig_divergence, axes = plt.subplots(1, len(param_labels), sharey=True,
                                    figsize=(10, 3))
# plt.subplots returns a bare Axes (not an array) when there is a single
# column; atleast_1d lets the per-condition loop handle both cases.
for stp_type, ax in zip(param_labels, np.atleast_1d(axes)):
    # NOTE(review): assumes the values in `param_labels` match the entries of
    # `stp_types` (built from `param_keys`); otherwise the panel is empty.
    sns.lineplot(data=div_df[div_df['stp_type'] == stp_type], x='time (s)',
                 y='MSE', hue='perturbation', ax=ax)
    ax.set_title(stp_type)
fig_divergence.tight_layout()
fname = 'divergence.png'
fig_divergence.savefig(op.join(output_dir, fname))

# Plot the average learning curve of each STP condition on one set of axes.
cm_hidden = sns.color_palette('colorblind')
fig_learning, axes = plt.subplots(1, 1, figsize=(4, 3))
axes.set_prop_cycle(cycler('color', cm_hidden))

# 'losses' are only collected for trials where they are not None, so the two
# lists can differ in length; zip() would then silently drop entries and
# mis-pair conditions with curves. Fail loudly instead.
n_stp_labels = len(learning_metrics['stp'])
n_loss_traces = len(learning_metrics['losses'])
if n_stp_labels != n_loss_traces:
    raise ValueError(f"{n_stp_labels} STP labels but {n_loss_traces} loss "
                     f"traces; cannot pair them")

loss_groupby_stp = defaultdict(list)
for stp_type, losses in zip(learning_metrics['stp'], learning_metrics['losses']):
    loss_groupby_stp[stp_type].append(losses)

ub_xtick = 0
for stp_type, loss_traces in loss_groupby_stp.items():
    # average the loss traces of all networks in this condition
    losses_avg = np.mean(loss_traces, axis=0)
    iter_idxs = np.arange(len(losses_avg))
    axes.semilogy(iter_idxs, losses_avg, lw=2, label=stp_type)
    # the right-hand x tick marks the last iteration of the longest curve
    ub_xtick = max(ub_xtick, len(losses_avg) - 1)
axes.grid(axis='y')
axes.grid(which="minor", color="0.9")
axes.set_xticks([0, ub_xtick])
axes.set_xlabel('iteration')
axes.set_ylabel('normalized MSE')
axes.set_title('learning curve by STP condition')
axes.legend()
fig_learning.tight_layout()
fname = 'learning.png'
fig_learning.savefig(op.join(output_dir, fname))