In [None]:
# Add the local ../model directory to the import path so its modules can be imported below
import sys
sys.path.append('../model')
import torch
import numpy as np

# Device configuration: the notebook runs on the CPU; `device` is passed to
# collect_data and used to place the model.
device = torch.device("cpu")

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
from models import SimpleCNN2Layer as Model
from collectdata import collect_data

In [None]:
# Only the third value returned by collect_data is kept (as dataset_test); the
# first two return values are discarded. dataset_test is not used in the cells
# visible here.
# NOTE(review): 30000 and 10000 are presumably the sizes of the two discarded
# splits (training / validation) -- confirm against collect_data's signature.
# NOTE(review): the input file is an absolute, machine-specific path.
_, _, dataset_test = collect_data(
    "/data/schreihf/PvFinder/July_31_75000.npz",
    30000, 10000,
    device=device)

In [None]:
# Build the network (SimpleCNN2Layer, imported as Model) with its default
# constructor arguments and place it on `device`. run_over() loads each
# checkpoint into this single shared instance.
model = Model().to(device)

In [None]:
# Filename template of the saved checkpoints; `{}` is filled by run_over's argument.
name = 'July_31_75000_2layer_{}.pyt'
def run_over(i):
    """Load checkpoint ``i`` into the global ``model`` and return its conv1 weights.

    Parameters
    ----------
    i : int
        Value substituted into the ``name`` template to select the checkpoint file.

    Returns
    -------
    numpy.ndarray
        An independent copy of the ``conv1.weight`` entry of the model's state dict.

    Notes
    -----
    Side effect: the global ``model`` is left holding checkpoint ``i`` in eval mode.
    """
    # map_location keeps the load working when the checkpoint was written from
    # a GPU run but this notebook is executed on a CPU-only machine.
    state = torch.load(name.format(i), map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    # Tensor.numpy() shares memory with the parameter, and the next
    # load_state_dict() overwrites that memory in place. Copy so that arrays
    # returned by earlier calls keep their values.
    vals = model.state_dict()['conv1.weight'].cpu().numpy().copy()
    return vals

In [None]:
# Five panels side by side on a shared y axis, one per leading index of conv1.weight.
fig, axs = plt.subplots(1,5, figsize=(18,3), sharey=True)
# Read the first checkpoint once instead of reloading it from disk for every panel.
initial_vals = run_over(0)
lns = []
for i, ax in enumerate(axs):
    # ax.plot returns a one-element list; keep the Line2D so plot_over can update it.
    line, = ax.plot(initial_vals[i,0,:], '.')
    lns.append(line)
    ax.set_xlim(0, 24)
    ax.set_ylim(-4, 4)
# Remove the horizontal gaps so the panels read as one strip.
fig.subplots_adjust(wspace=0)
# Epoch label in the last panel; plot_over rewrites its text on every frame.
txt = axs[4].text(17, 3, "Epoch 0")

In [None]:
def plot_over(i):
    """Animation callback: update every panel with the conv1 weights of checkpoint ``i``.

    Parameters
    ----------
    i : int
        Frame value supplied by FuncAnimation; forwarded to ``run_over`` and
        shown in the epoch label.

    Returns
    -------
    tuple
        The artists that were modified (all lines plus the label). FuncAnimation
        only uses this when blitting is enabled; it is ignored otherwise.
    """
    vals = run_over(i)
    # Pair each existing line with its kernel instead of hard-coding the panel
    # count, and take the x range from the kernel length instead of assuming 25.
    for line, kernel in zip(lns, vals[:, 0, :]):
        line.set_data(np.arange(kernel.shape[0]), kernel)
    txt.set_text(f"Epoch {i}")
    return (*lns, txt)

In [None]:
# Build the animation: plot_over is called once per frame with i = 0..199, so a
# checkpoint file must exist for each of those indices. Frames are 100 ms apart
# with a 2000 ms pause before the loop repeats.
ani = animation.FuncAnimation(fig, plot_over, frames=range(200), interval=100, repeat_delay=2_000)
# Render to a GIF at 80 dpi through matplotlib's ImageMagick writer, which
# requires ImageMagick to be installed on the machine.
ani.save('July_31_75000_training.gif', dpi=80, writer='imagemagick')