In [2]:
# %% cell 1: imports and load npz data
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
from collections import defaultdict

%matplotlib inline

# load the apd weights npz file (saved by apd_train.py)
npz_path = "apd_stuff/apd_weights.npz"
npz_data = np.load(npz_path)
print("loaded npz file with keys:")
print(npz_data.files)

loaded npz file with keys:
['epoch_0_layer_0', 'epoch_0_layer_1', 'epoch_0_layer_2', 'epoch_0_layer_3', 'epoch_0_layer_4', 'epoch_0_layer_5', 'epoch_0_layer_6', 'epoch_0_layer_7', 'epoch_0_layer_8', 'epoch_0_layer_9', 'epoch_0_layer_10', 'epoch_0_layer_11', 'epoch_1_layer_0', 'epoch_1_layer_1', 'epoch_1_layer_2', 'epoch_1_layer_3', 'epoch_1_layer_4', 'epoch_1_layer_5', 'epoch_1_layer_6', 'epoch_1_layer_7', 'epoch_1_layer_8', 'epoch_1_layer_9', 'epoch_1_layer_10', 'epoch_1_layer_11', 'epoch_2_layer_0', 'epoch_2_layer_1', 'epoch_2_layer_2', 'epoch_2_layer_3', 'epoch_2_layer_4', 'epoch_2_layer_5', 'epoch_2_layer_6', 'epoch_2_layer_7', 'epoch_2_layer_8', 'epoch_2_layer_9', 'epoch_2_layer_10', 'epoch_2_layer_11', 'epoch_3_layer_0', 'epoch_3_layer_1', 'epoch_3_layer_2', 'epoch_3_layer_3', 'epoch_3_layer_4', 'epoch_3_layer_5', 'epoch_3_layer_6', 'epoch_3_layer_7', 'epoch_3_layer_8', 'epoch_3_layer_9', 'epoch_3_layer_10', 'epoch_3_layer_11', 'epoch_4_layer_0', 'epoch_4_layer_1', 'epoch_4_layer

In [3]:
# %% cell 2: aggregate weights data by layer and epoch
# keys are in the form "epoch_{epoch}_layer_{layer_idx}"
data = defaultdict(list)
pattern = r"epoch_(\d+)_layer_(\d+)"
for key in npz_data.files:
    m = re.match(pattern, key)
    if m:
        epoch = int(m.group(1))
        layer = int(m.group(2))
        weights = npz_data[key]  # shape: (n_components,)
        data[layer].append((epoch, weights))

# sort each layer's data by epoch
for layer in data:
    data[layer].sort(key=lambda x: x[0])
    
# summary of epochs available per layer
for layer in sorted(data.keys()):
    epochs = [e for e, _ in data[layer]]
    print(f"layer {layer}: epochs recorded: {epochs}")

layer 0: epochs recorded: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216

In [11]:
# %% cell 3: analyze early training weight changes (first 200 epochs)
summary_stats = {}
for layer in sorted(data.keys()):
    epochs, vec_list = zip(*data[layer])
    vec_array = np.array(vec_list)
    
    # compute diffs and get early training mask
    diffs = np.diff(vec_array, axis=0)
    early_epochs = np.array(epochs)[:-1]
    early_mask = early_epochs < 200
    early_diffs = diffs[early_mask]
    
    # compute key statistics
    max_changes = np.max(np.abs(early_diffs), axis=0)  # max change per component
    mean_changes = np.mean(np.abs(early_diffs), axis=0)  # mean change per component
    active_comps = np.where(max_changes > np.percentile(np.abs(diffs), 90))[0]
    
    # find epoch with max change (properly indexed)
    max_change_idx = np.unravel_index(np.argmax(np.abs(early_diffs)), early_diffs.shape)[0]
    
    summary_stats[layer] = {
        'most_active_components': active_comps.tolist(),
        'max_change_epoch': early_epochs[max_change_idx],
        'max_change_magnitude': np.max(np.abs(early_diffs)),
        'mean_activity': np.mean(np.abs(early_diffs))
    }
    
    print(f"\nlayer {layer}:")
    print(f"most active components: {active_comps}")
    print(f"peak change at epoch {summary_stats[layer]['max_change_epoch']}")
    print(f"max change magnitude: {summary_stats[layer]['max_change_magnitude']:.6f}")
    print(f"mean activity level: {summary_stats[layer]['mean_activity']:.6f}")


layer 0:
most active components: [0 1 2 3 4 5 6 7]
peak change at epoch 66
max change magnitude: 0.028958
mean activity level: 0.009849

layer 1:
most active components: [0 1 2 3 4 5 6 7]
peak change at epoch 76
max change magnitude: 0.028533
mean activity level: 0.009336

layer 2:
most active components: [0 1 2 3 4 5 6 7]
peak change at epoch 0
max change magnitude: 0.031290
mean activity level: 0.011540

layer 3:
most active components: [0 1 2 3 4 5 6 7]
peak change at epoch 137
max change magnitude: 0.031122
mean activity level: 0.010400

layer 4:
most active components: [0 1 2 3 4 5 6 7]
peak change at epoch 169
max change magnitude: 0.032681
mean activity level: 0.011899

layer 5:
most active components: [0 1 2 3 4 5 6 7]
peak change at epoch 82
max change magnitude: 0.027223
mean activity level: 0.009479

layer 6:
most active components: [0 1 2 3 4 5 6 7]
peak change at epoch 92
max change magnitude: 0.032481
mean activity level: 0.011757

layer 7:
most active components: [0 1 2

In [13]:
# %% cell 4: analyze learning waves through network
for layer in sorted(data.keys()):
    epochs, vec_list = zip(*data[layer])
    vec_array = np.array(vec_list)
    diffs = np.diff(vec_array, axis=0)
    early_epochs = np.array(epochs)[:-1]
    early_mask = early_epochs < 200
    early_diffs = diffs[early_mask]
    
    # find top 5 biggest change epochs
    top_changes = np.sort(np.abs(early_diffs).max(axis=1))[-5:]
    top_epochs = early_epochs[np.argsort(np.abs(early_diffs).max(axis=1))[-5:]]
    
    print(f"\nlayer {layer}:")
    print(f"top 5 change epochs: {top_epochs.tolist()}")
    print(f"change magnitudes: {[f'{x:.6f}' for x in top_changes.tolist()]}")


layer 0:
top 5 change epochs: [41, 146, 25, 131, 66]
change magnitudes: ['0.026466', '0.026578', '0.026775', '0.027780', '0.028958']

layer 1:
top 5 change epochs: [158, 23, 22, 153, 76]
change magnitudes: ['0.025762', '0.025769', '0.026291', '0.026717', '0.028533']

layer 2:
top 5 change epochs: [167, 114, 103, 29, 0]
change magnitudes: ['0.028931', '0.029230', '0.029957', '0.031078', '0.031290']

layer 3:
top 5 change epochs: [41, 30, 82, 127, 137]
change magnitudes: ['0.027921', '0.027988', '0.029197', '0.030499', '0.031122']

layer 4:
top 5 change epochs: [18, 174, 115, 87, 169]
change magnitudes: ['0.029915', '0.030541', '0.031356', '0.031470', '0.032681']

layer 5:
top 5 change epochs: [83, 139, 130, 49, 82]
change magnitudes: ['0.025433', '0.025766', '0.025856', '0.026554', '0.027223']

layer 6:
top 5 change epochs: [90, 150, 93, 131, 92]
change magnitudes: ['0.030510', '0.031381', '0.031387', '0.031764', '0.032481']

layer 7:
top 5 change epochs: [32, 135, 113, 97, 33]
change 