In [1]:
from collections import defaultdict
import os.path as op
from cycler import cycler
from glob import glob

import h5py
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from viz import (plot_learning, plot_state_traj, plot_all_units,
                 plot_weight_distr)

In [None]:
# tau, include_stp, noise_tau, noise_std, include_corr_noise, p_rel_range
sim_params_all = [[0.01, False, 0.01, 0.0, False, 2],
                  [0.05, False, 0.01, 0.0, False, 2],
                  [0.01, False, 0.01, 0.0, False, 2],
                  [0.01, True, 0.01, 0.0, False, 2],
                  [0.01, False, 0.01, 1e-6, False, 2],
                  [0.01, True, 0.01, 1e-6, False, 2],
                  [0.01, False, 0.01, 1e-6, True, 2],
                  [0.01, True, 0.01, 1e-6, True, 2],
                  [0.01, True, 0.01, 1e-6, False, 1],
                  [0.01, True, 0.01, 1e-6, False, 0]]
sim_type_labels = ['w/o STP, small tau',
                   'w/o STP, large tau',
                   'w/o STP, no noise',
                   'w/ STP, no noise',
                   'w/o STP + uncorr. noise',
                   'w/ STP + uncorr. noise',
                   'w/o STP + corr. noise',
                   'w/ STP + corr. noise',
                   'w/o STP + uncorr. noise, low hetero p_rel',
                   'w/ STP + uncorr. noise, homo p_rel']
n_sim_types = len(sim_params_all)

In [None]:
f_names = glob('/home/ryan/time_coding_output/*.hdf5')
f_names = sorted(f_names)
n_nets = len(f_names)

loss_traj = list()
for f_name in f_names:
    with h5py.File(f_name, 'r') as f_read:
        losses = f_read.get('losses')[:]
        loss_traj.append(losses)

loss_traj = np.stack(loss_traj)

[[5.7222376  2.7740734  2.0313363  ... 0.28380057 0.28359058 0.28338084]
 [1.1941375  1.1088403  1.0771329  ... 0.15273821 0.15263924 0.15254046]
 [0.74793476 0.63584197 0.60506755 ... 0.17522496 0.17513373 0.17504257]
 ...
 [0.9653939  0.9335863  0.92111754 ... 0.46545443 0.40763834 0.50626695]
 [0.68672246 0.52280897 0.5155395  ... 0.33147064 0.36252144 0.32060727]
 [0.77325946 0.59137934 0.6521334  ... 0.34862483 0.32984307 0.42596486]]


In [10]:
with h5py.File('/home/ryan/time_coding_output/sim_data_net00_3752213_2025-08-27_14:19:25.hdf5', 'r') as f_read:
    trial = f_read['z_t'][-1, 0, :, :2]

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = '/home/ryan/time_coding_output/sim_data_net00_3752213_2025-08-27_14:19:25.hdf5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [None]:
# parse data
learning_metrics = defaultdict(list)
learning_metrics['stp'] = param_keys.tolist()
divergences = list()
response_times = list()
perturbation_mags = list()
stp_types = np.repeat(param_keys, 3)

for key in res[0].keys():
    for trial in res:
        if trial[key] is not None:
            if key == 'divergence':
                divergences.extend(trial[key])
            elif key == 'perturbation_mag':
                perturbation_mags.extend(trial[key])
            elif key == 'response_times':
                response_times.extend(np.tile(trial[key], (3, 1)).tolist())
            else:
                learning_metrics[key].append(trial[key])

# fig_divergence, ax = plt.subplots(1, 1, figsize=(3, 3))
# divergence = np.mean(metrics['divergence'], axis=0)
# delay_times = metrics['response_times'][0]
# perturb_mags = metrics['perturbation_mag'][0]
# plot_divergence(divergence, delay_times, perturb_mags, ax=ax)
# fig_divergence.tight_layout()

# plot avg divergence over time for each STP condition w/ error bars
n_times = len(response_times[0])
data = np.array([np.ravel(divergences), np.ravel(response_times),
                 np.repeat(perturbation_mags, n_times),
                 np.repeat(stp_types, n_times)])
div_df = pd.DataFrame(data.T, columns=['MSE', 'time (s)', 'perturbation',
                                       'stp_type'])
fig_divergence, axes = plt.subplots(1, len(param_labels), sharey=True,
                                    figsize=(10, 3))
for stp_type_idx, stp_type in enumerate(param_labels):
    sns.lineplot(data=div_df[div_df['stp_type'] == stp_type], x='time (s)',
                 y='MSE', hue='perturbation', ax=axes[stp_type_idx])
fig_divergence.tight_layout()
fname = 'divergence.png'
fig_divergence.savefig(op.join(output_dir, fname))

# plot avg learning curve across STP conditions on one set of axes
cm_hidden = sns.color_palette('colorblind')
fig_learning, axes = plt.subplots(1, 1, figsize=(4, 3))
axes.set_prop_cycle(cycler('color', cm_hidden))
loss_groupby_stp = defaultdict(list)
for stp_type, losses in zip(learning_metrics['stp'], learning_metrics['losses']):
    loss_groupby_stp[stp_type].append(losses)

for idx, (key, val) in enumerate(loss_groupby_stp.items()):
    losses_avg = np.mean(val, axis=0)
    iter_idxs = np.arange(len(losses_avg))
    axes.semilogy(iter_idxs, losses_avg, lw=2, label=key)
axes.grid(axis='y')
axes.grid(which="minor", color="0.9")
ub_xtick = iter_idxs[-1]
axes.set_xticks([0, ub_xtick])
axes.set_xlabel('iteration')
axes.set_ylabel('normalized MSE')
axes.legend()
fig_learning.tight_layout()
fname = 'learning.png'
fig_learning.savefig(op.join(output_dir, fname))