### Notebook to Demonstrate Gradient Descent Over Intrinsic ###

In [1]:
# Basic Imports
from intrinsic.module import PlasticEdges
from intrinsic.model import Intrinsic
import numpy as np
import torch
from matplotlib import pyplot as plt

We'll initialize a fully connected Intrinsic network with 3 state nodes, each with 3 channels and spatial dimensions of size 5x5.

In [2]:
rnet = Intrinsic(num_nodes=3, node_shape=(1, 3, 5, 5), kernel_size=3, edge_module=PlasticEdges, track_activation_history=True, mask=None, inject_noise=True, optimize_weights=False)

First, run 1000 iterations of the randomly initialized network with no input to get a sense of what its internal dynamics are like.

In [3]:
for i in range(1000):
    rnet()

Plot the series of states for a single unit somewhere in the middle of the run (ideally after the dynamics have largely converged). Here we look at time steps 200 to 600 of the unit at spatial location (1, 1) in channel 1 of node 2.

In [4]:
history = rnet.past_states
history = np.array([s.detach().squeeze().numpy() for s in history])
plt.plot(history[200:600, 2, 1, 1, 1])

Note that there are some semi-repetitive oscillatory dynamics that arise from the structure and local update functions on the random graph. 

Now attempt to use gradient descent to alter the dynamics at this unit to match a given temporal pattern (a sine wave with amplitude 2, below).

In [5]:
x = torch.arange(0, 200)
target_pattern = 2 * torch.sin((x * 4 * torch.pi) / 60)
plt.plot(target_pattern)
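
As a quick check on the target: sin(4πx/60) equals sin(2πx/30), so the pattern repeats every 30 time steps with amplitude 2, and one 200-step generation covers roughly 6.7 cycles. The sketch below confirms this using only the standard library; `target` is a hypothetical helper that mirrors the expression above and is not part of the notebook.

```python
import math

def target(x):
    # Same expression as target_pattern above, evaluated at one time step x.
    return 2 * math.sin((x * 4 * math.pi) / 60)

period = 60 / 2            # 2*pi / (4*pi/60) = 30 time steps
cycles = 200 / period      # cycles covered by one 200-step generation

# The pattern should repeat after one period (up to floating point error).
max_dev = max(abs(target(x) - target(x + period)) for x in range(200))
print(period, round(cycles, 2), max_dev < 1e-9)
```

Knowing the period is useful when reading the later plots: a well-entrained unit should show a little under seven peaks across the 200 plotted steps.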

In [6]:
loss_history = []

In the training loop below, take the time course of the unit of interest (2, 1, 1, 1) at each generation and compute the square root of the summed squared error relative to the target time series above (the Euclidean distance, which has the same minimizer as the mean squared error). This loss is minimized via gradient descent over the initial weight, channel map, and plasticity parameters of the model (~200 parameters).

In [11]:
# initialize an Adam optimizer over all model parameters
optim = torch.optim.Adam(params=rnet.parameters(), eps=1e-6, lr=.001)
# we will run 2000 generations with a gradient update after each generation
for gen in range(2000):
    rnet.past_states = []
    rnet = rnet.detach(reset_intrinsic=True)
    optim.zero_grad()
    # will model 200 time steps per generation
    for i in range(200):
        rnet()
    history = torch.stack(rnet.past_states)
    node_history = history[:, 2, 1, 1, 1].squeeze()
    # compute the root of the summed squared error between the target and observed time series at the unit of interest
    loss = torch.sqrt(torch.sum((target_pattern - node_history) ** 2))
    loss_history.append(loss.detach().cpu().item())
    print("gen", gen, "loss", loss.item())
    # perform gradient update.
    loss.backward()
    optim.step()
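
The quantity minimized in the loop above is the Euclidean norm of the error rather than a mean squared error. The two are related by norm = sqrt(N * MSE), so they share the same minimizer even though their gradient magnitudes differ. A minimal standard-library sketch of that relationship follows; `mse`, `l2_norm`, and the flat `observed` series are illustrative assumptions, not part of the notebook.

```python
import math

def mse(target, observed):
    # Mean of the squared differences.
    return sum((t - o) ** 2 for t, o in zip(target, observed)) / len(target)

def l2_norm(target, observed):
    # Square root of the summed squared differences, as in the loop above.
    return math.sqrt(sum((t - o) ** 2 for t, o in zip(target, observed)))

target = [2 * math.sin((x * 4 * math.pi) / 60) for x in range(200)]
observed = [0.0] * 200  # hypothetical flat time course, for illustration only

n = len(target)
print(l2_norm(target, observed), math.sqrt(n * mse(target, observed)))
```

Because the square root is monotonic, either form ranks candidate parameter settings identically; the choice mainly affects how the effective step size scales with the size of the error.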

The loss generally decreases, but the loss landscape is clearly very complex. Adam optimization (or perhaps gradient descent in general) is not very well suited to it.

In [15]:
from scipy.ndimage import gaussian_filter
plt.plot(gaussian_filter(np.array(loss_history), 1))
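
For intuition about the smoothing step: a Gaussian filter with sigma 1 replaces each loss value with a Gaussian-weighted average of its neighbours, which suppresses generation-to-generation noise. The sketch below is a simplified standard-library approximation, assuming a kernel truncated at 4 sigma and mirrored edges; it is not SciPy's implementation, and `gaussian_kernel`, `mirror`, and `smooth` are hypothetical helpers.

```python
import math

def gaussian_kernel(sigma=1.0, truncate=4.0):
    # Normalized Gaussian weights on the integer offsets -radius..radius.
    radius = int(truncate * sigma + 0.5)
    weights = [math.exp(-0.5 * (k / sigma) ** 2) for k in range(-radius, radius + 1)]
    total = sum(weights)
    return [w / total for w in weights]

def mirror(i, n):
    # Reflect an out-of-range index back into 0..n-1 (edge sample repeated).
    period = 2 * n
    i %= period
    return i if i < n else period - 1 - i

def smooth(values, sigma=1.0):
    kernel = gaussian_kernel(sigma)
    radius = len(kernel) // 2
    n = len(values)
    return [
        sum(w * values[mirror(i + k - radius, n)] for k, w in enumerate(kernel))
        for i in range(n)
    ]

noisy = [10.0, 2.0, 9.0, 1.0, 8.0, 0.0, 7.0]  # made-up loss values
print(smooth(noisy))
```

A larger sigma gives a smoother curve at the cost of blurring genuine jumps in the loss, so sigma 1 is a fairly light touch.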

The time course of the unit of interest now looks like our target sinusoidal time course.

In [16]:
plt.plot(history[0:, 2, 1, 1, 1].detach().numpy())

If we look at units in other nodes / channels they have various oscillatory patterns with similar frequency. 

In [17]:
plt.plot(history[0:, 0, 2, 3, 2].detach().numpy())

Now show the plasticity values for each node to node edge. 

In [19]:
plast = rnet.edge.plasticity.detach().numpy()
print(plast)
# plt.imshow(plast, cmap="hot")

In [11]:
import pickle
with open("../models/cos_entrained.pkl", "wb") as f:
    pickle.dump(rnet, f)
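
To reuse the saved network later, the file can be read back with `pickle.load`, provided the `intrinsic` package is importable in the loading session (pickle stores a reference to the class, not its code). The sketch below demonstrates the round trip with a stand-in dictionary and a temporary file so that it runs without the trained model; the stand-in contents are made up for illustration.

```python
import os
import pickle
import tempfile

stand_in = {"name": "cos_entrained", "num_nodes": 3}  # placeholder, not the real model

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cos_entrained.pkl")
    # Write the object out, then read it back from the same file.
    with open(path, "wb") as f:
        pickle.dump(stand_in, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored == stand_in)
```

Unpickling can execute arbitrary code, so only load pickle files from sources you trust.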

In [None]:
print(rnet.edge.chan_map.detach().numpy())

In [None]:
print(torch.linalg.inv(rnet.edge.chan_map.detach()).numpy())