# Elman_network_ER.ipynb

Using an Elman network to test Error Regression mechanisms.

In [None]:
import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pickle
# import lovely_tensors as lt
# lt.monkey_patch()
# Expects to be in '/home/z/OIST Dropbox/Sergio Verduzco/code/python/pytorch/deep_explorations/rnn'
%cd ./data/

In [None]:
# Run on the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"device = {device}")

In [3]:
def set_length(seq, des_len):
    """ Change the sequence so it has the desired length.

        A longer sequence is obtained by repeating randomly chosen points in
        place (a point may be repeated more than once). A shorter sequence is
        obtained by dropping randomly chosen points. The relative order of
        the points is preserved in both cases.

        :param seq: the sequence; its first dimension indexes the points
        :type seq: torch.Tensor
        :param des_len: desired number of points
        :type des_len: int
        :returns: tensor with `des_len` points and the same trailing shape and
            dtype as `seq`; `seq` itself if it already has the desired length
        :rtype: torch.Tensor
        :raises ValueError: if `des_len` is negative, or if an empty sequence
            would have to be extended
    """
    seq_len = seq.shape[0]
    if des_len < 0:
        raise ValueError(f"des_len must be non-negative, got {des_len}.")
    if des_len == seq_len:
        return seq

    if des_len > seq_len:  # we will repeat a set of random points
        if seq_len == 0:
            raise ValueError("Cannot extend an empty sequence.")
        # Every original index appears once, plus one extra copy per random
        # draw. Sorting puts each copy right next to its original, so no slot
        # is left unfilled even when the same index is drawn several times.
        repeated = torch.randint(0, seq_len, (des_len - seq_len,))
        gather_idxs = torch.cat((torch.arange(seq_len), repeated)).sort().values
        return seq[gather_idxs.to(seq.device)]

    # des_len < seq_len: we will remove a set of random points
    remove_idxs = torch.randperm(seq_len)[:seq_len - des_len]
    keep_mask = torch.ones(seq_len, dtype=torch.bool)
    keep_mask[remove_idxs] = False
    return seq[keep_mask.to(seq.device)]

In [None]:
# Load training data
# Every pickle is read, converted to a tensor and resampled to `seq_len`
# points so that the four patterns can be stacked into a single tensor.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
seq_len = 940

fname = 'switch_circle5_triangle5_1.pkl'
with open(fname, 'rb') as f:
    raw_points = pickle.load(f)
c1 = set_length(torch.tensor(raw_points), seq_len)
print(f"c1 shape: {c1.shape}")

fname = 'infty10_1.pkl'
with open(fname, 'rb') as f:
    raw_points = pickle.load(f)
c2 = set_length(torch.tensor(raw_points), seq_len)
print(f"c2 shape: {c2.shape}")

fname = 'eye.pkl'
with open(fname, 'rb') as f:
    raw_points = pickle.load(f)
c3 = set_length(torch.tensor(raw_points), seq_len)
print(f"c3 shape: {c3.shape}")

fname = 'switch_circle2_circle2.pkl'
with open(fname, 'rb') as f:
    raw_points = pickle.load(f)
c4 = set_length(torch.tensor(raw_points), seq_len)
print(f"c4 shape: {c4.shape}")

# One pattern per row of the leading dimension.
coordinates = torch.stack((c1, c2, c3, c4), dim=0)
# coordinates = coordinates[:2].to(device).type(torch.float)
# put each pattern in a different part of the plane
# coordinates[0, :, 0] -= 2.0  # first pattern moves left
# coordinates[1, :, 0] += 2.0  # second pattern moves right
# coordinates[2, :, 1] += 2.0  # third pattern moves up
# coordinates[3, :, 1] -= 2.0  # fourth pattern moves down

print(coordinates.shape)

In [None]:
# plot the loaded array: one panel per pattern, first coordinate against the
# second one, with equal axis scaling
fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(16, 4))
for col in range(4):
    axs[col].axis('equal')
    axs[col].plot(coordinates[col, :, 0].cpu().numpy(), coordinates[col, :, 1].cpu().numpy())

# Copy the data to the training device as float32. Calling torch.tensor() on
# an existing tensor emits a copy-construct UserWarning; clone().detach() is
# the documented equivalent.
coordinates = coordinates.clone().detach().to(device=device, dtype=torch.float32)
print(f"Coordinates shape: {coordinates.shape}")

## The Elman RNN

In [7]:
# Create the network, the optimizer, and the loss function
# Hyper-parameters and result containers used by the cells below.
input_size, hidden_size = 2, 256  # input_size shouldn't change
nonlinearity = 'tanh'  # tanh or relu
lr = 3e-4  # learning rate
all_generated, all_generated_er = [], []

# Seed the global RNG before any parameters are initialised.
torch.manual_seed(345)

class ElmanRNN_ER(nn.Module):
    """ RNN with the same 'A' tensor for all time steps.

        'A' is a learnable offset that is added to the initial hidden state
        before the recurrence starts. The hidden states are mapped to
        2-dimensional outputs by a linear layer.

        :param input_size: number of features of each input point
        :type input_size: int
        :param hidden_size: number of hidden units
        :type hidden_size: int
        :param nonlinearity: 'tanh' or 'relu'
        :type nonlinearity: str
        :param device: device where the parameters are created
        :type device: torch.device
        :param batch_size: number of rows of 'A'. With the default of 1 the
            same offset is broadcast to every sequence in the batch; otherwise
            it must match the batch size of the inputs.
        :type batch_size: int
    """
    def __init__(self, input_size, hidden_size, nonlinearity, device, batch_size=1):
        super().__init__()
        self.input_size = input_size
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.rnn = nn.RNN(input_size=input_size,
                          hidden_size=hidden_size,
                          nonlinearity=nonlinearity,
                          num_layers=1,
                          batch_first=True,
                          device=device)
        self.fc = nn.Linear(hidden_size, 2, device=device)
        self.A = nn.Parameter(torch.zeros((batch_size, hidden_size), device=device),
                              requires_grad=True)

    def forward(self, x, h0=None):
        """ Predict an output for every point of the input sequences.

            :param x: input sequences, shaped (N, seq_len, input_size)
            :type x: torch.Tensor
            :param h0: initial hidden state, shaped (1, N, hidden_size) or
                (N, hidden_size). Zeros if omitted.
            :type h0: torch.Tensor
            :returns: outputs shaped (N, seq_len, 2) and the final hidden
                state shaped (1, N, hidden_size)
            :rtype: 2-tuple
        """
        if h0 is None:
            # Sized from the input instead of module-level globals.
            h0 = torch.zeros((1, x.shape[0], self.hidden_size),
                             device=x.device, dtype=x.dtype)
        elif h0.dim() == 2:
            # nn.RNN wants (num_layers, N, hidden_size) for batched input.
            h0 = h0.unsqueeze(0)
        out, h = self.rnn(x, h0 + self.A)
        return self.fc(out), h


class ElmanRNN_ER2(nn.Module):
    """ RNN with an 'A' tensor for each time step.

        The recurrence of `torch.nn.RNN` is unrolled by hand so that, at every
        step t, `A[:, t]` is added to the previous hidden state before it is
        multiplied by the recurrent weights. The weights themselves are those
        of the wrapped `nn.RNN` module.

        Based on:
        https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html

        :param input_size: number of features of each input point
        :type input_size: int
        :param hidden_size: number of hidden units
        :type hidden_size: int
        :param batch_size: number of sequences 'A' has rows for
        :type batch_size: int
        :param seq_len: number of time steps 'A' has entries for
        :type seq_len: int
        :param nonlinearity: 'tanh' or 'relu'
        :type nonlinearity: str
        :param device: device where the parameters are created
        :type device: torch.device
    """
    def __init__(self, input_size, hidden_size, batch_size, seq_len, nonlinearity, device):
        super().__init__()
        self.input_size = input_size
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.rnn = nn.RNN(input_size=input_size,
                          hidden_size=hidden_size,
                          nonlinearity=nonlinearity,
                          num_layers=1,  # more layers would require another dimension in A
                          batch_first=True,
                          device=device)
        self.fc = nn.Linear(hidden_size, 2, device=device)
        # Scale the random initial 'A' by 1 / sqrt(number of elements).
        scaling_factor = np.sqrt(batch_size * seq_len * hidden_size) ** (-1)
        self.A = nn.Parameter(scaling_factor * torch.randn((batch_size, seq_len, hidden_size),
                                                           device=device), requires_grad=True)
        self.params = dict(self.rnn.named_parameters())

    def forward(self, x, hx=None, batch_first=False):
        """ Predict an output for every point of the input sequences.

            :param x: input, shaped (seq_len, N, input_size), or
                (N, seq_len, input_size) when `batch_first` is True
            :type x: torch.Tensor
            :param hx: initial hidden state (n_layers, N, hidden_size).
                Zeros if omitted.
            :type hx: torch.Tensor
            :param batch_first: if True, x and the returned outputs have the
                batch in their first dimension
            :type batch_first: bool
            :returns: outputs (same layout as x, 2 features) and the final
                hidden state (n_layers, N, hidden_size)
            :rtype: 2-tuple
            :raises IndexError: if x has more time steps than 'A'
        """
        if batch_first:
            x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()
        assert batch_size == self.A.shape[0], f"Different input and A batch sizes: {batch_size}, {self.A.shape[0]}."
        if seq_len > self.A.shape[1]:
            # Fail before any work is done rather than midway through the loop.
            raise IndexError(
                f"Input has {seq_len} time steps but A only has {self.A.shape[1]}.")
        num_layers = self.rnn.num_layers
        if hx is None:
            # Follow the input's device/dtype instead of a module-level global.
            hx = torch.zeros(num_layers, batch_size, self.rnn.hidden_size,
                             device=x.device, dtype=x.dtype)
        # Honour the nonlinearity the wrapped nn.RNN was configured with.
        activation = torch.relu if self.nonlinearity == 'relu' else torch.tanh
        # One hidden-state tensor per layer; rebuilt at every step so that no
        # tensor tracked by autograd is modified in place.
        h_prev = list(hx.unbind(0))
        output = []
        for t in range(seq_len):
            h_now = []
            for layer in range(num_layers):
                input_t = x[t] if layer == 0 else h_now[layer - 1]
                h_now.append(activation(
                    input_t @ self.params[f"weight_ih_l{layer}"].T
                    + (h_prev[layer] + self.A[:, t]) @ self.params[f"weight_hh_l{layer}"].T
                    + self.params[f"bias_hh_l{layer}"]
                    + self.params[f"bias_ih_l{layer}"]
                ))
            output.append(self.fc(h_now[-1]))
            h_prev = h_now
        output = torch.stack(output)
        if batch_first:
            output = output.transpose(0, 1)
        return output, torch.stack(h_prev)

    def reshape_A_batches(self, new_batch_size, device):
        """ Change the batch size for the 'A' tensor.

            'A' is replaced by a freshly initialised parameter, so its learned
            values are discarded.

            WARNING: an optimizer created before this call keeps pointing to
            the old 'A' tensor; create the optimizer again afterwards.

            :param new_batch_size: number of rows of the new 'A'
            :type new_batch_size: int
            :param device: device where the new 'A' is created
            :type device: torch.device
        """
        seq_len = self.A.shape[1]
        hidden_size = self.A.shape[2]
        scaling_factor = np.sqrt(new_batch_size * seq_len * hidden_size) ** (-1)
        self.A = nn.Parameter(scaling_factor * torch.randn((new_batch_size, seq_len, hidden_size),
                                                           device=device), requires_grad=True)


# rnn = ElmanRNN_ER(input_size, hidden_size, nonlinearity, coordinates.size()[0], device)
# rnn = ElmanRNN_ER2(input_size, hidden_size, coordinates.shape[0], coordinates.size()[1], nonlinearity, device)
# rnn.reshape_A_batches(4, device)

# optim = torch.optim.Adam(rnn.parameters(), lr=lr)

# loss = nn.MSELoss(reduction='mean')

In [None]:
# training loop with all training examples in a single batch (non-PVRNN)
# Next-point prediction: the network receives every point but the last one
# and its outputs are compared against the same sequences shifted by one step.
# NOTE(review): `rnn`, `optim` and `loss` are not created in this cell. The
# two-value unpacking below matches ElmanRNN_ER2.forward, not
# ElmanPVRNN.forward (which returns four values) -- confirm which model is live.
n_epochs = 1000 # number of epochs
# xcoordinates = coordinates.unsqueeze(1)

for epoch in range(n_epochs):
    optim.zero_grad()
    out, _  = rnn(coordinates[:, :-1], batch_first=True)
    error = loss(out, coordinates[:, 1:])
    error.backward()
    optim.step()
    # report the training error every 20 epochs
    if epoch % 20 == 0:
        print(f"Error at epoch {epoch} = {error}")

In [None]:
# training loop breaking the data into two batches of 2 elements each (non-PVRNN)
# Every epoch takes one optimizer step on each batch, in order.
n_epochs = 1000
batch1 = coordinates[0:2]
batch2 = coordinates[2:]
# NOTE(review): reshape_A_batches replaces rnn.A with a new parameter, and
# `optim` is not re-created in this cell -- confirm that the optimizer in use
# actually holds the new A tensor.
# NOTE(review): both batches are run through the same two rows of A -- confirm
# that sharing them between batch1 and batch2 is intended.
rnn.reshape_A_batches(2, device)

for epoch in range(n_epochs):
    # batch 1
    optim.zero_grad()
    out, _  = rnn(batch1[:, :-1], batch_first=True)
    error = loss(out, batch1[:, 1:])
    error.backward()
    optim.step()
    if epoch % 20 == 0:
        print(f"Error at batch 1 epoch {epoch} = {error}")
    # batch 2 
    optim.zero_grad()
    out, _  = rnn(batch2[:, :-1], batch_first=True)
    error = loss(out, batch2[:, 1:])
    error.backward()
    optim.step()
    if epoch % 20 == 0:
        print(f"Error at batch 2 epoch {epoch} = {error}")


---
### The Elman PVRNN

In [6]:
# Create the network, the optimizer, and the loss function
# Hyper-parameters and result containers for the PV-RNN cells.
input_size, hidden_size = 2, 256  # input_size shouldn't change
nonlinearity = 'tanh'  # tanh or relu
lr = 3e-4  # learning rate
all_generated, all_generated_er = [], []

# Seed the global RNG before the model parameters are initialised.
torch.manual_seed(345)

class ElmanPVRNN(nn.Module):
    """ Elman RNN driven by a latent variable z at every time step.

        At step t the hidden state is updated from the input, the previous
        hidden state and a linear projection of z_t. z_t comes from one of:
        the learnable posterior (`A_mu`, `A_logvar`), a prior decoded from the
        previous hidden state, or an explicit tensor given to `forward`.
        The weights of the recurrence are those of the wrapped `nn.RNN`.

        :param input_size: number of features of each input point
        :type input_size: int
        :param hidden_size: number of hidden units (also the size of z)
        :type hidden_size: int
        :param batch_size: number of sequences the posterior has rows for
        :type batch_size: int
        :param seq_len: number of time steps the posterior has entries for
        :type seq_len: int
        :param nonlinearity: 'tanh' or 'relu'
        :type nonlinearity: str
        :param device: device where the parameters are created
        :type device: torch.device
    """
    def __init__(self, input_size, hidden_size, batch_size, seq_len, nonlinearity, device):
        super().__init__()
        self.input_size = input_size
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.rnn = nn.RNN(input_size=input_size,
                          hidden_size=hidden_size,
                          nonlinearity=nonlinearity,
                          num_layers=1,  # more layers would require another dimension in A
                          batch_first=True,
                          device=device)
        self.fc = nn.Linear(hidden_size, 2, device=device)
        # Maps a hidden state to the concatenated (mu, logvar) of the prior.
        self.prior_decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, 2 * hidden_size, device=device)
            )
        # Scale the random initial posterior by 1 / sqrt(number of elements).
        scaling_factor = np.sqrt(batch_size * seq_len * hidden_size) ** (-1)
        self.A_mu = nn.Parameter(scaling_factor * torch.randn((batch_size, seq_len, hidden_size),
                                                           device=device), requires_grad=True)
        self.A_logvar = nn.Parameter(scaling_factor * torch.randn((batch_size, seq_len, hidden_size),
                                                           device=device), requires_grad=True)
        # Projects z_t into the pre-activation of the hidden state.
        self.zh_l = nn.Linear(hidden_size, hidden_size, device=device)
        self.params = dict(self.rnn.named_parameters())

    def forward(self, x, hx=None, A_mu=None, batch_first=False, use_prior=False, deterministic=False):
        """ Predict a sequence with Elman PV-RNN.

        :param x: input sequence, shaped (seq_len, N, seq_dim) or (N, seq_len, seq_dim)
        :type x: torch.tensor
        :param hx: initial hidden state (n_layers, N, hidden_size). Zeros if omitted.
        :type hx: torch.tensor
        :param A_mu: values to use as z instead of sampling the posterior.
            `use_prior` overrides this.
        :type A_mu: torch.tensor, shape (N, seq_len, hidden_size).
        :param batch_first: if True, x has shape (N, seq_len, seq_dim)
        :type batch_first: bool
        :param use_prior: if True, use the z value from the prior decoder
        :type use_prior: bool
        :param deterministic: if True, z will be equal to mu when sampling
        :type deterministic: bool
        :returns: predicted sequence (same layout as x), final hidden state
            (n_layers, N, hidden_size), and the mu and logvar priors, both
            shaped (N, seq_len, hidden_size)
        :rtype: 4-tuple
        """
        if batch_first:
            x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()
        assert batch_size == self.A_mu.shape[0], f"Different input and A batch sizes: {batch_size}, {self.A_mu.shape[0]}."
        num_layers = self.rnn.num_layers
        if hx is None:
            # Follow the input's device/dtype instead of a module-level global.
            hx = torch.zeros(num_layers, batch_size, self.rnn.hidden_size,
                             device=x.device, dtype=x.dtype)
        # Honour the nonlinearity the wrapped nn.RNN was configured with.
        activation = torch.relu if self.nonlinearity == 'relu' else torch.tanh
        # One hidden-state tensor per layer; rebuilt at every step so that no
        # tensor is modified in place (keeps autograd and torch.func happy).
        h_prev = list(hx.unbind(0))
        output = []
        mu_priors = []
        logvar_priors = []
        for t in range(seq_len):
            # The prior is decoded from the previous first-layer hidden state.
            # It is detached, so the prior does not send gradients into the
            # recurrence.
            mu_prior_t, logvar_prior_t = torch.tensor_split(
                self.prior_decoder(h_prev[0].detach()), 2, dim=1)  # 2*(batch_size, hidden size)

            # z_t always has shape (batch_size, hidden_size).
            if use_prior:
                z_t = self.sample_z_prior(mu_prior_t, logvar_prior_t, deterministic=deterministic)
            elif A_mu is None:
                z_t = self.sample_z_post(t, deterministic=deterministic)
            else:
                z_t = A_mu[:, t]
            z_drive = self.zh_l(z_t)

            h_now = []
            for layer in range(num_layers):
                input_t = x[t] if layer == 0 else h_now[layer - 1]
                h_now.append(activation(
                    input_t @ self.params[f"weight_ih_l{layer}"].T
                    + h_prev[layer] @ self.params[f"weight_hh_l{layer}"].T
                    + self.params[f"bias_hh_l{layer}"]
                    + self.params[f"bias_ih_l{layer}"]
                    + z_drive
                ))
            output.append(self.fc(h_now[-1]))
            mu_priors.append(mu_prior_t)
            logvar_priors.append(logvar_prior_t)
            h_prev = h_now
        output = torch.stack(output)
        if batch_first:
            output = output.transpose(0, 1)
        # The priors are always returned batch-first.
        mu_priors = torch.stack(mu_priors, dim=1)
        logvar_priors = torch.stack(logvar_priors, dim=1)
        return output, torch.stack(h_prev), mu_priors, logvar_priors

    def sample_z_post(self, step, deterministic=False):
        """ Draw z for time step `step` from the posterior.

            :param step: index of the time step
            :type step: int
            :param deterministic: if True, return the posterior mean
            :type deterministic: bool
            :returns: z shaped (batch_size, hidden_size)
            :rtype: torch.Tensor
        """
        mu = self.A_mu[:, step]
        if deterministic:
            return mu
        # Reparameterisation: the standard deviation is exp(logvar / 2).
        std = torch.exp(0.5 * self.A_logvar[:, step])
        return mu + std * torch.randn_like(mu)

    def sample_z_prior(self, mu_prior_t, logvar_prior_t, deterministic=False):
        """ Draw z from a prior with the given mean and log variance.

            :param mu_prior_t: prior mean
            :type mu_prior_t: torch.Tensor
            :param logvar_prior_t: prior log variance, same shape as the mean
            :type logvar_prior_t: torch.Tensor
            :param deterministic: if True, return the prior mean
            :type deterministic: bool
            :returns: z with the shape of `mu_prior_t`
            :rtype: torch.Tensor
        """
        if deterministic:
            return mu_prior_t
        # Reparameterisation: the standard deviation is exp(logvar / 2).
        std = torch.exp(0.5 * logvar_prior_t)
        return mu_prior_t + std * torch.randn_like(mu_prior_t)

    def reshape_A_batches(self, new_batch_size, device, keep_norm=False):
        """ Change the batch size for the 'A' tensor.

            `A_mu` and `A_logvar` are replaced by freshly initialised
            parameters, so their learned values are discarded.

            WARNING: this will disconnect the optimizer from the new
            A_mu and A_logvar tensors. Make sure to update the optimizer.

            :param new_batch_size: number of rows of the new tensors
            :type new_batch_size: int
            :param device: device where the new tensors are created
            :type device: torch.device
            :param keep_norm: if True, scale the new random tensors by the
                norm of the tensors they replace
            :type keep_norm: bool
        """
        seq_len = self.A_mu.shape[1]
        hidden_size = self.A_mu.shape[2]
        if keep_norm:
            # Plain floats: the new parameters must not carry the autograd
            # history of the old ones.
            mu_norm = float(self.A_mu.detach().norm())
            logvar_norm = float(self.A_logvar.detach().norm())
        else:
            mu_norm = 1.0
            logvar_norm = 1.0
        inv_sqrt_numel = np.sqrt(new_batch_size * seq_len * hidden_size) ** (-1)
        scaling_factor_mu = mu_norm * inv_sqrt_numel
        scaling_factor_logvar = logvar_norm * inv_sqrt_numel
        self.A_mu = nn.Parameter(scaling_factor_mu * torch.randn((new_batch_size, seq_len, hidden_size),
                                                           device=device), requires_grad=True)
        self.A_logvar = nn.Parameter(scaling_factor_logvar * torch.randn((new_batch_size, seq_len, hidden_size),
                                                           device=device), requires_grad=True)

    def set_A_mu(self, A_mu_values):
        """ Replace the values of A_mu.

            The values are copied in place, without tracking gradients, so the
            parameter object (and any optimizer that holds it) is unchanged.
        """
        with torch.no_grad():
            self.A_mu.copy_(A_mu_values)


def GaussianKLD(mu1, logvar1, mu2, logvar2, *,
                logvar_min=-50.0, logvar_max=80.0, eps=1e-8, compress_thresh=None):
    """ KL divergence between two Gaussian distributions.

        Computes KL(N(mu1, exp(logvar1)) || N(mu2, exp(logvar2))) element by
        element and adds everything up.

        :param mu1: mean of the posterior for all time steps
        :type mu1: torch.Tensor shaped (N, T, hidden_size)
        :param logvar1: log variance of the posterior for all time steps
        :type logvar1: torch.Tensor shaped (N, T, hidden_size)
        :param mu2: mean of the prior for all time steps
        :type mu2: torch.Tensor shaped (N, T, hidden_size)
        :param logvar2: log variance of the prior for all time steps
        :type logvar2: torch.Tensor shaped (N, T, hidden_size)
        :param logvar_min: lower clamp applied to both log variances
        :type logvar_min: float
        :param logvar_max: upper clamp applied to both log variances
        :type logvar_max: float
        :param eps: added to the prior variance to guard the division
        :type eps: float
        :param compress_thresh: the part of the sum beyond this value is
            "compressed" by the log function: a sum S > thresh becomes
            thresh + log(1 + S - thresh). Sums below the threshold are
            returned unchanged.
        :type compress_thresh: float
        :returns: sum of all N * T * hidden_size KL divergences
        :rtype: torch.Tensor, 0-dimensional

    """
    # Ensure same dtype/device (cheap no-ops if already matching)
    mu2 = mu2.to(dtype=mu1.dtype, device=mu1.device)
    logvar2 = logvar2.to(dtype=mu1.dtype, device=mu1.device)
    # Clamp log-variances to avoid exp overflow/underflow
    lv1 = torch.clamp(logvar1, min=logvar_min, max=logvar_max)
    lv2 = torch.clamp(logvar2, min=logvar_min, max=logvar_max)
    var1 = torch.exp(lv1)              # bounded by clamp
    var2 = torch.exp(lv2) + eps        # eps guards divide-by-zero
    diff = mu1 - mu2
    # KL per element
    kl = 0.5 * (lv2 - lv1 + (var1 + diff * diff) / var2 - 1.0)
    # Sum over all elements -> scalar
    kl_sum = kl.sum()
    # Optional smooth compression of huge values (keeps graph/device).
    # The result is continuous at the threshold and never exceeds the
    # uncompressed sum; the clamps keep log1p away from invalid arguments.
    if compress_thresh is not None:
        excess = torch.clamp(kl_sum - compress_thresh, min=0.0)
        kl_sum = torch.clamp(kl_sum, max=compress_thresh) + torch.log1p(excess)
    return kl_sum


# Build the PV-RNN with one posterior row per training sequence and one
# posterior entry per time step, then its optimizer and the prediction loss.
rnn = ElmanPVRNN(
    input_size=input_size,
    hidden_size=hidden_size,
    batch_size=coordinates.size(0),
    seq_len=coordinates.size(1),
    nonlinearity=nonlinearity,
    device=device,
)
# rnn.reshape_A_batches(4, device)
optim = torch.optim.Adam(params=rnn.parameters(), lr=lr)
loss = nn.MSELoss(reduction='mean')


In [None]:
# training loop for ElmanPVRNN, single batch
# Each epoch takes one optimizer step on the next-point prediction error (MSE)
# plus 0.0001 times the KL divergence between the posterior stored in the
# model (A_mu, A_logvar) and the priors returned by the forward pass.
n_epochs = 1000 # number of epochs
deterministic = True
shuffle = True

for epoch in range(n_epochs):
    # NOTE(review): the inputs are permuted with batch_idxs, but the targets
    # below and the rows of rnn.A_mu / rnn.A_logvar are not -- confirm that
    # this pairing of inputs, targets and posterior rows is intended.
    batch_idxs = torch.randperm(4) if shuffle else torch.arange(4)
    optim.zero_grad()
    out, _, mu_priors, logvar_priors  = rnn(coordinates[batch_idxs, :-1], batch_first=True, deterministic=deterministic)
    # print(f"mu_priors shape: {mu_priors.shape}")
    # print(f"logvar_priors shape: {logvar_priors.shape}")
    pred_error = loss(out, coordinates[:, 1:])
    # The posterior is cut by one step so it lines up with the priors, which
    # have one entry per input step.
    KL_div = GaussianKLD(rnn.A_mu[:, :-1], rnn.A_logvar[:, :-1], mu_priors, logvar_priors)
    error = pred_error + 0.0001 * KL_div
    error.backward()
    optim.step()
    # report the losses every 20 epochs
    if epoch % 20 == 0:
        print(f"Error at epoch {epoch} = {error}. pred error: {pred_error}, KL div: {KL_div}")

In [7]:
# Optimizer over every model parameter except the posterior tensors.
# The log-variance parameter is registered as 'A_logvar'; 'A_sigma' is kept in
# the exclusion set only for models that may still use that name.
excluded_names = {'A_mu', 'A_logvar', 'A_sigma'}
no_A_params_list = [param for name, param in rnn.named_parameters()
                    if name not in excluded_names]


optim = torch.optim.Adam(no_A_params_list, lr=lr)

In [None]:
# training loop for ElmanPVRNN, single batch, Jacobian Trace Maximization
# Same training step as the plain single-batch loop. In addition, whenever the
# norm of the A_mu gradient falls below A_norm_thresh, an extra optimizer step
# is taken on the negated average of `num_probes` random-probe estimates.
from torch.func import jvp
n_epochs = 1000 # number of epochs
deterministic = True
shuffle = True
A_norm_thresh = 0.001
num_probes = 4


def loss_A(A):
    """ Prediction loss on all sequences, with `A` used as the z values.

        Reads `optim`, `rnn`, `loss` and `coordinates` from the notebook
        globals, and clears the optimizer's gradients as a side effect.
    """
    optim.zero_grad()
    out, _, _, _ = rnn(coordinates[:, :-1], A_mu=A, batch_first=True, deterministic=True)
    return loss(out, coordinates[:, 1:])


for epoch in range(n_epochs):
    # NOTE(review): the inputs are permuted with batch_idxs, but the targets
    # and the rows of rnn.A_mu / rnn.A_logvar are not -- confirm it is intended.
    batch_idxs = torch.randperm(4) if shuffle else torch.arange(4)
    optim.zero_grad()
    out, _, mu_priors, logvar_priors  = rnn(coordinates[batch_idxs, :-1], batch_first=True, deterministic=deterministic)
    # print(f"mu_priors shape: {mu_priors.shape}")
    # print(f"logvar_priors shape: {logvar_priors.shape}")
    pred_error = loss(out, coordinates[:, 1:])
    KL_div = GaussianKLD(rnn.A_mu[:, :-1], rnn.A_logvar[:, :-1], mu_priors, logvar_priors)
    error = pred_error + 0.0001 * KL_div
    error.backward()
    optim.step()
    if epoch % 20 == 0:
        print(f"Error at epoch {epoch} = {error}. pred error: {pred_error}, KL div: {KL_div}")
        print(f"A_mu grad norm: {rnn.A_mu.grad.norm()}")
        print(f"hh grad norm: {rnn.rnn.weight_hh_l0.grad.norm()}")
    
    # Extra step, taken only when the A_mu gradient has become small.
    if rnn.A_mu.grad.norm() < A_norm_thresh:
        A_mu = rnn.A_mu.clone()
        trace_est = 0.0
        optim.zero_grad()
        for _ in range(num_probes):
            # random probe with entries in {-1, +1}
            z = torch.empty_like(A_mu).bernoulli_(0.5).mul(2).sub(1)
            # NOTE(review): loss_A returns the scalar loss, so Jz is the
            # directional derivative of the loss along z, not a
            # Jacobian-vector product of the gradient -- confirm that
            # (Jz * z).sum() is the intended trace estimate.
            _, Jz = jvp(loss_A, (A_mu,), (z,) )
            trace_est = trace_est + (Jz * z).sum()
        # trace_est = -trace_est / (num_probes * A_mu.shape[1])  # invert sign to maximize
        trace_est = -trace_est / num_probes
        print(f"----------------->> Epoch {epoch}. Trace est: {-trace_est}")
        trace_est.backward()
        optim.step()

In [None]:
# Norm of the difference between the current A_mu and `init_A`.
# NOTE(review): `init_A` is not defined in the cells shown here -- presumably
# a copy of A_mu saved before training; confirm.
(rnn.A_mu - init_A).norm()

In [None]:
# training loop for ElmanPVRNN, two batches
# Every epoch takes one optimizer step on each batch, in order. The objective
# is the prediction error plus 0.0001 times the KL divergence.
n_epochs = 1000
deterministic = True

batch1 = coordinates[0:2]
batch2 = coordinates[2:]
# NOTE(review): reshape_A_batches replaces A_mu and A_logvar with new
# parameters and `optim` is not re-created in this cell -- confirm that the
# optimizer in use holds the new tensors.
# NOTE(review): both batches use the same two posterior rows -- confirm.
rnn.reshape_A_batches(2, device, keep_norm=True)

for epoch in range(n_epochs):
    # batch 1
    optim.zero_grad()
    out, _, mu_priors, logvar_priors  = rnn(batch1[:, :-1], batch_first=True, deterministic=deterministic)
    pred_error = loss(out, batch1[:, 1:])
    KL_div = GaussianKLD(rnn.A_mu[:, :-1], rnn.A_logvar[:, :-1], mu_priors, logvar_priors)
    error = pred_error + 0.0001 * KL_div
    error.backward()
    optim.step()
    if epoch % 10 == 0:
        print(f"Batch1 error at epoch {epoch} = {error}. pred error: {pred_error}, KL div: {KL_div}")

    # batch 2 
    optim.zero_grad()
    out, _, mu_priors, logvar_priors  = rnn(batch2[:, :-1], batch_first=True, deterministic=deterministic)
    pred_error = loss(out, batch2[:, 1:])
    KL_div = GaussianKLD(rnn.A_mu[:, :-1], rnn.A_logvar[:, :-1], mu_priors, logvar_priors)
    error = pred_error + 0.0001 * KL_div
    error.backward()
    optim.step()
    if epoch % 10 == 0:
        print(f"Batch2 error at epoch {epoch} = {error}. pred error: {pred_error}, KL div: {KL_div}")
        # gradient norms left by the batch-2 backward pass
        print(f"A_mu grad norm: {rnn.A_mu.grad.norm()}")
        print(f"hh grad norm: {rnn.rnn.weight_hh_l0.grad.norm()}")


In [None]:
# training loop for ElmanPVRNN, two batches, Jacobian trace maximization
# Every epoch takes one optimizer step on each batch. After each of them, if
# the norm of the A_mu gradient is below A_norm_thresh, an extra step is taken
# on the negated average of `num_probes` random-probe estimates.
from torch.func import jvp

n_epochs = 1000
deterministic = True
A_norm_thresh = 0.001
num_probes = 4

batch1 = coordinates[0:2]
batch2 = coordinates[2:]
rnn.reshape_A_batches(2, device, keep_norm=True)
# The optimizer is rebuilt so that it holds the new A_mu / A_logvar tensors.
optim = torch.optim.Adam(rnn.parameters(), lr=lr)

def loss_A1(A):
    """ Prediction loss on batch1, with `A` used as the z values.

        Reads `optim`, `rnn`, `loss` and `batch1` from the notebook globals,
        and clears the optimizer's gradients as a side effect.
    """
    # rnn.set_A_mu(A)  # doesn't work with jvp
    optim.zero_grad()
    out, _, _, _ = rnn(batch1[:, :-1], A_mu=A, batch_first=True, deterministic=True)
    return loss(out, batch1[:, 1:])

def loss_A2(A):
    """ Prediction loss on batch2, with `A` used as the z values.

        Reads `optim`, `rnn`, `loss` and `batch2` from the notebook globals,
        and clears the optimizer's gradients as a side effect.
    """
    # rnn.set_A_mu(A)  # doesn't work with jvp
    optim.zero_grad()
    out, _, _, _ = rnn(batch2[:, :-1], A_mu=A, batch_first=True, deterministic=True)
    return loss(out, batch2[:, 1:])
    
for epoch in range(n_epochs):
    # batch 1
    optim.zero_grad()
    out, _, mu_priors, logvar_priors  = rnn(batch1[:, :-1], batch_first=True, deterministic=deterministic)
    pred_error = loss(out, batch1[:, 1:])
    KL_div = GaussianKLD(rnn.A_mu[:, :-1], rnn.A_logvar[:, :-1], mu_priors, logvar_priors)
    error = pred_error + 0.0001 * KL_div
    error.backward()
    optim.step()
    if epoch % 10 == 0:
        print(f"Batch1 error at epoch {epoch} = {error}. pred error: {pred_error}, KL div: {KL_div}")
        print(f"A_mu grad norm: {rnn.A_mu.grad.norm()}")
    
    # Extra step for batch 1, taken only when the A_mu gradient is small.
    if rnn.A_mu.grad.norm() < A_norm_thresh:
        A_mu = rnn.A_mu.clone()
        trace_est = 0.0
        optim.zero_grad()
        for _ in range(num_probes):
            # random probe with entries in {-1, +1}
            z = torch.empty_like(A_mu).bernoulli_(0.5).mul(2).sub(1)
            # NOTE(review): loss_A1 returns the scalar loss, so Jz is the
            # directional derivative of the loss along z -- confirm that
            # (Jz * z).sum() is the intended trace estimate.
            _, Jz = jvp(loss_A1, (A_mu,), (z,) )
            trace_est = trace_est + (Jz * z).sum()
        # trace_est = -trace_est / (num_probes * A_mu.shape[1])  # invert sign to maximize
        trace_est = -trace_est / num_probes
        print(f"----------------->> Epoch {epoch}. Trace est: {-trace_est}")
        trace_est.backward()
        optim.step()

    # batch 2 
    optim.zero_grad()
    out, _, mu_priors, logvar_priors  = rnn(batch2[:, :-1], batch_first=True, deterministic=deterministic)
    pred_error = loss(out, batch2[:, 1:])
    KL_div = GaussianKLD(rnn.A_mu[:, :-1], rnn.A_logvar[:, :-1], mu_priors, logvar_priors)
    error = pred_error + 0.0001 * KL_div
    error.backward()
    optim.step()
    if epoch % 10 == 0:
        print(f"Batch2 error at epoch {epoch} = {error}. pred error: {pred_error}, KL div: {KL_div}")
        print(f"A_mu grad norm: {rnn.A_mu.grad.norm()}")
        print(f"hh grad norm: {rnn.rnn.weight_hh_l0.grad.norm()}")

    # Extra step for batch 2, same procedure as for batch 1.
    if rnn.A_mu.grad.norm() < A_norm_thresh:
        A_mu = rnn.A_mu.clone()
        trace_est = 0.0
        optim.zero_grad()
        for _ in range(num_probes):
            z = torch.empty_like(A_mu).bernoulli_(0.5).mul(2).sub(1)
            _, Jz = jvp(loss_A2, (A_mu,), (z,) )
            trace_est = trace_est + (Jz * z).sum()
        # trace_est = -trace_est / (num_probes * A_mu.shape[1])  # invert sign to maximize
        trace_est = -trace_est / num_probes
        print(f"------------------------>>  Epoch {epoch}. Trace est: {-trace_est}")
        trace_est.backward()
        optim.step()
            


In [None]:
# training loop for ElmanPVRNN, two batches, Error Regression, Jacobian trace maximization
from torch.func import jvp

# Settings for this cell. The er_* values configure error regression (ER).
n_epochs = 5000
deterministic = False  # also read as a global by the ER functions below
A_norm_thresh = 0.05
num_probes = 4
er_steps = 6
er_window = 5
er_interval =20 
er_stride = 8 * er_window
bidirectional_er = True

batch1 = coordinates[0:2]
batch2 = coordinates[2:]
rnn.reshape_A_batches(2, device, keep_norm=True)
# The optimizer is rebuilt so that it holds the new A_mu / A_logvar tensors.
optim = torch.optim.Adam(rnn.parameters(), lr=lr)
seq_len = coordinates.shape[1]


def loss_A1(A):
    """ Prediction loss on batch1, with `A` used as the z values.

        Reads `optim`, `rnn`, `loss` and `batch1` from the notebook globals,
        and clears the optimizer's gradients as a side effect.
        NOTE(review): returns the loss itself, not its gradient with respect
        to A -- confirm against the code that differentiates it.
    """
    # rnn.set_A_mu(A)  # doesn't work with jvp
    optim.zero_grad()
    out, _, _, _ = rnn(batch1[:, :-1], A_mu=A, batch_first=True, deterministic=True)
    return loss(out, batch1[:, 1:])


def loss_A2(A):
    """ Prediction loss on batch2, with `A` used as the z values.

        Reads `optim`, `rnn`, `loss` and `batch2` from the notebook globals,
        and clears the optimizer's gradients as a side effect.
        NOTE(review): returns the loss itself, not its gradient with respect
        to A -- confirm against the code that differentiates it.
    """
    # rnn.set_A_mu(A)  # doesn't work with jvp
    optim.zero_grad()
    out, _, _, _ = rnn(batch2[:, :-1], A_mu=A, batch_first=True, deterministic=True)
    return loss(out, batch2[:, 1:])


def pointwise_error_regression(model, optim, init_idx, h, er_steps, er_window, coordinates):
    """ Run error regression one time step at a time. 

        For each point in the window, the model is run `er_steps` times on a
        single input point. After every run the gradients of all parameters
        other than "A_mu" and "A_logvar" are zeroed before the optimizer step.
        The function has no return statement.

        Reads `device` and `deterministic` from the notebook globals.

        :param model: the model to run
        :type model: torch.Module
        :param optim: the parameter optimizer
        :type optim: torch.optim
        :param init_idx: index of the first coordinate to use.
        :type init_idx: int
        :param h: hidden state, shape (1, batch_size, hidden_size)
        :type h: torch.tensor
        :param er_steps: number of ER steps
        :type er_steps: int
        :param er_window: size of ER window
        :type er_window: int
        :param coordinates: all coordinates, shape (batch_size, seq_len, 2)
        :type coordinates: torch.tensor
    """
    batch_size = coordinates.shape[0]
    seq_len = coordinates.shape[1]
    # buffer holding the model output for every visited point
    gen_coordinates = torch.zeros((batch_size, seq_len, 2), device=device)
    # NOTE(review): `coords` is set once and never advanced, so the model is
    # fed the point at the initial index for every point_idx -- confirm.
    coords = coordinates[:, init_idx:init_idx+1].clone()
    for point_idx in range(init_idx, min(seq_len, init_idx + er_window)):
        for er_idx in range(er_steps):
            optim.zero_grad()
            h = h.detach()  # so h doesn't point to the t-1 computational graph
            # the first ER step of every point draws z from the prior
            if er_idx == 0:
                use_prior = True
            else:
                use_prior = False
            # h is overwritten with the model's final hidden state on every ER step
            new_coords, h, _, _ = model(coords, h, batch_first=True, use_prior=use_prior, deterministic=deterministic)
            gen_coordinates[:, point_idx, :] = new_coords.squeeze()
            # NOTE(review): this evaluates to min(point_idx, er_window), not to
            # max(0, point_idx - er_window) -- confirm the intended start of
            # the error window. It also rebinds the `init_idx` argument.
            init_idx = point_idx - max(0, point_idx - er_window)
            error = (coordinates[:, init_idx:point_idx + 1] - 
                     gen_coordinates[:, init_idx:point_idx + 1]).norm()
            error.backward()
            # drop the graph so earlier outputs act as constants in later steps
            gen_coordinates = gen_coordinates.detach()
            # print(f"pass {er_idx}")
            # Zero the gradient of everything except the posterior tensors.
            # NOTE(review): with an optimizer that keeps moment estimates
            # (e.g. Adam), a zero gradient may still move a parameter --
            # confirm that only A_mu / A_logvar change here.
            for name, param in model.named_parameters():
                if name in ["A_mu", "A_logvar"]:
                    continue
                else:
                    if param.grad is not None:
                        # NOTE(review): detach() returns a new tensor that is
                        # discarded; this line has no effect on param.grad.
                        param.grad.detach()
                        param.grad.zero_()
            optim.step()


def error_regression(model, optim, init_idx, h, er_steps, er_window, coordinates, bidirectional=True):
    """ Run error regression. 

        Feeds a window of `coordinates` to the model `er_steps` times. After
        each pass the norm of (window - model output) is backpropagated and
        one optimizer step is taken in which only the parameters named
        "A_mu" and "A_logvar" are updated. Nothing is returned; the effect is
        the in-place update of those parameters.

        The flag passed as `deterministic` to the model is read from the
        global variable of the same name.

        :param model: the model to run
        :type model: torch.Module
        :param optim: the parameter optimizer
        :type optim: torch.optim
        :param init_idx: reference index; the window starts at
            max(0, init_idx - er_window)
        :type init_idx: int
        :param h: hidden state, shape (1, batch_size, hidden_size)
        :type h: torch.tensor
        :param er_steps: number of ER steps
        :type er_steps: int
        :param er_window: size of ER window
        :type er_window: int
        :param coordinates: all coordinates, shape (batch_size, seq_len, 2)
        :type coordinates: torch.tensor
        :param bidirectional: if True the window holds up to `er_window`
            points, otherwise a single point
        :type bidirectional: bool
    """
    seq_len = coordinates.shape[1]
    init_idx = max(0, init_idx - er_window)
    window_len = er_window if bidirectional else 1
    end_idx = min(seq_len, init_idx + window_len)
    coords = coordinates[:, init_idx:end_idx].clone()
    for er_idx in range(er_steps):
        optim.zero_grad()
        h = h.detach()  # so h doesn't point to the t-1 computational graph
        use_prior = er_idx == 0  # the prior is only used on the first pass
        new_coords, h, _, _ = model(coords, h, batch_first=True, use_prior=use_prior, deterministic=deterministic)
        error = (coordinates[:, init_idx:end_idx] - new_coords).norm()
        error.backward()
        # Freeze everything except A_mu / A_logvar. The gradient must be None
        # rather than zero: optimizers skip parameters without a gradient,
        # whereas a zero gradient still lets stateful optimizers (momentum,
        # Adam moments, weight decay) move the parameter.
        for name, param in model.named_parameters():
            if name not in ("A_mu", "A_logvar"):
                param.grad = None
        optim.step()

#------------------------------ the training loop ------------------------------
    
# Alternates, within every epoch, one optimizer step on batch1 and one on
# batch2. Every `er_interval` epochs (and only after epoch 20) each batch is
# first swept with error regression over windows spaced `er_stride` apart.
for epoch in range(n_epochs):
    # batch 1
    if epoch % er_interval == 0 and epoch > 20:
        # every error_regression call receives this same zero hidden state
        h = torch.zeros((1, batch1.shape[0], hidden_size), device=device)
        for coord_idx in range(0, seq_len - er_window, er_stride):
            error_regression(rnn, optim, coord_idx, h, er_steps, er_window, batch1)
            print('.', end='')  # progress marker, one per ER window
        print(" ")
    optim.zero_grad()
    # next-step prediction: inputs are points [0, T-1), targets are points [1, T)
    out, _, mu_priors, logvar_priors  = rnn(batch1[:, :-1], batch_first=True, deterministic=deterministic)
    pred_error = loss(out, batch1[:, 1:])
    KL_div = GaussianKLD(rnn.A_mu[:, :-1], rnn.A_logvar[:, :-1], mu_priors, logvar_priors)
    # total loss: prediction error plus the KL term weighted by 1e-4
    error = pred_error + 0.0001 * KL_div
    error.backward()
    optim.step()
    if epoch % 10 == 0:
        print(f"Batch1 error at epoch {epoch} = {error}. pred error: {pred_error}, KL div: {KL_div}")
        print(f"A_mu grad norm: {rnn.A_mu.grad.norm()}")
        print(f"A_mu norm: {rnn.A_mu.norm()}")
    
    # Disabled experiment: extra optimizer step on a Hutchinson-style trace
    # estimate when the A_mu gradient norm falls below A_norm_thresh.
    # if rnn.A_mu.grad.norm() < A_norm_thresh:
    #     A_mu = rnn.A_mu.clone()
    #     trace_est = 0.0
    #     optim.zero_grad()
    #     for _ in range(num_probes):
    #         z = torch.empty_like(A_mu).bernoulli_(0.5).mul(2).sub(1)
    #         _, Jz = jvp(loss_A1, (A_mu,), (z,) )
    #         trace_est = trace_est + (Jz * z).sum()
    #     trace_est = -trace_est / num_probes
    #     print(f"------------------------>>  Epoch {epoch}. Trace est: {-trace_est}")
    #     trace_est.backward()
    #     optim.step()

    # batch 2 
    if epoch % er_interval == 0 and epoch > 20:
        # every error_regression call receives this same zero hidden state
        h = torch.zeros((1, batch2.shape[0], hidden_size), device=device)
        for coord_idx in range(0, seq_len - er_window, er_stride):
            error_regression(rnn, optim, coord_idx, h, er_steps, er_window, batch2)
            print('*', end='')  # progress marker, one per ER window
        print(" ")
    optim.zero_grad()
    # same next-step prediction objective as for batch 1
    out, _, mu_priors, logvar_priors  = rnn(batch2[:, :-1], batch_first=True, deterministic=deterministic)
    pred_error = loss(out, batch2[:, 1:])
    KL_div = GaussianKLD(rnn.A_mu[:, :-1], rnn.A_logvar[:, :-1], mu_priors, logvar_priors)
    error = pred_error + 0.0001 * KL_div
    error.backward()
    optim.step()
    if epoch % 10 == 0:
        print(f"Batch2 error at epoch {epoch} = {error}. pred error: {pred_error}, KL div: {KL_div}")
        print(f"A_mu grad norm: {rnn.A_mu.grad.norm()}")
        print(f"A_mu norm: {rnn.A_mu.norm()}")
        print(f"hh grad norm: {rnn.rnn.weight_hh_l0.grad.norm()}")

    # Disabled experiment: as above, but also triggered when the A_mu gradient
    # norm exceeds 20, with a sign/scale factor depending on the case.
    # A_mu_norm = rnn.A_mu.grad.norm()
    # if  A_mu_norm < A_norm_thresh or A_mu_norm > 20:
    #     trace_factor = -0.5 if A_mu_norm < A_norm_thresh else 1.0
    #     A_mu = rnn.A_mu.clone()
    #     trace_est = 0.0
    #     optim.zero_grad()
    #     for _ in range(num_probes):
    #         z = torch.empty_like(A_mu).bernoulli_(0.5).mul(2).sub(1)
    #         _, Jz = jvp(loss_A2, (A_mu,), (z,) )
    #         trace_est = trace_est + (Jz * z).sum()
    #     trace_est = trace_factor * trace_est / num_probes
    #     print(f"-------------->>  Epoch {epoch}. Trace est: {trace_est}. trace factor: {trace_factor}. A_mu norm: {A_mu_norm}")
    #     trace_est.backward()
    #     optim.step()
            


In [None]:
# Sanity check of the optimizer/gradient state.
# set_to_none=False keeps the existing .grad tensors (filled with zeros)
# instead of dropping them, so the norms below can still be evaluated.
optim.zero_grad(set_to_none=False)
print(rnn.A_mu.grad.norm())
print(rnn.rnn.weight_hh_l0.grad.norm())
# should be True: A_mu must be one of the parameters handled by the optimizer
print(any(p is rnn.A_mu for g in optim.param_groups for p in g["params"]))


In [16]:
# save network: writes the model's state_dict to a file in the current
# working directory
torch.save(rnn.state_dict(), 'pvrnn_2batches_ER_JTM_nonoise_epoch5000')

In [None]:
# load saved data (2 batches)
# NOTE(review): reshape_A_batches presumably resizes A to match the 2-batch
# checkpoint before loading — confirm against its definition.
rnn.reshape_A_batches(2, device, keep_norm=True)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can also be loaded on a CPU-only one.
rnn.load_state_dict(torch.load('pvrnn_2batches_epoch10000', map_location=device))

In [None]:
# Get distances between sections of A
# For every unordered pair (i, j) of entries along the first axis of A_clone,
# record the norm of their difference and the angle between them.
distances = []
angles = []
batch_size = A_clone.shape[0]
for i in range(batch_size):
    for j in range(i+1, batch_size):
        dist = (A_clone[i] - A_clone[j]).norm()
        Ai = A_clone[i] / A_clone[i].norm()
        Aj = A_clone[j] / A_clone[j].norm()
        # Rounding can push the cosine of (nearly) parallel entries slightly
        # outside [-1, 1], where arccos is NaN; clamp it back into range.
        cos_ij = (Ai * Aj).sum().detach().cpu().clamp(-1.0, 1.0)
        angle = np.arccos(cos_ij)
        print(f"i = {i}, j = {j}, dist = {dist}, angle = {angle}")
        distances.append(dist)
        angles.append(angle)

In [None]:
# NOTE(review): reshape_A_batches presumably re-creates A for 4 batches —
# confirm against its definition.
rnn.reshape_A_batches(4, device)
# Scale A in place by the norm of A_clone; no_grad keeps the in-place
# multiplication out of the autograd graph.
with torch.no_grad():
    rnn.A *= A_clone.norm()

print(rnn.A.norm())

In [None]:
# Print the norm of each entry along the first axis of A_clone.
for subten in A_clone:
    print(subten.norm())
    


In [27]:
# Start fresh result lists and seed them with the most recently generated
# trajectories (gen_coordinates_er / gen_coordinates must already exist from
# an earlier run of a generation cell).
all_generated_er = []
all_generated = []
all_generated_er.append(gen_coordinates_er)
all_generated.append(gen_coordinates)

In [None]:
# Generate a trajectory autoregressively (non-PVRNN)
# Generate as many points as the reference sequences contain.
n_points = coordinates.shape[1]
# coordinate = torch.tensor([-1, 1], dtype=torch.float32, device=device)
# First point of every sequence; the 0:1 slice keeps the time axis.
init_coords = coordinates[:,0:1]
print(f"init coords shape : {init_coords.shape}")
er_steps = 30  # ER passes per generated point
er_window = 15  # ER window size

# rnn.reshape_A_batches(2, device)

def generate_trajectory(init_coordinates, n_points):
    """ Roll the network forward on its own outputs, without gradients.

        Starting from `init_coordinates` and a zero hidden state, the global
        `rnn` is called `n_points` times; each output is stored and becomes
        the next input.

        :param init_coordinates: starting points, first axis is the batch
        :type init_coordinates: torch.tensor
        :param n_points: number of points to generate
        :type n_points: int
        :returns: array of shape (batch_size, n_points, 2)
    """
    n_seqs = init_coordinates.shape[0]
    state = torch.zeros((1, n_seqs, hidden_size), device=device)
    trajectory = np.empty((n_seqs, n_points, 2))
    current = init_coordinates.clone()  # leave the caller's tensor untouched
    with torch.no_grad():
        for step in range(n_points):
            current, state = rnn(current, state, batch_first=True)
            as_array = current.detach().squeeze().cpu().numpy()
            trajectory[:, step, :] = as_array
    return trajectory


def generate_trajectory_ER(init_coordinates, n_points, er_steps, er_window, all_coordinates):
    """ Autoregressive trajectory generated with error regression.

        For every generated point the global `rnn` is run `er_steps` times on
        the same input. After each pass the norm of the difference between
        `all_coordinates` and the generated points over the last `er_window`
        points (plus the current one) is backpropagated, and the global
        `optim` takes a step in which only the parameter named "A" is updated.
        The output of the last pass becomes the next input.

        :param init_coordinates: starting points, first axis is the batch
        :type init_coordinates: torch.tensor
        :param n_points: number of points to generate
        :type n_points: int
        :param er_steps: number of ER passes per point
        :type er_steps: int
        :param er_window: number of previous points included in the error
        :type er_window: int
        :param all_coordinates: reference coordinates the error is measured against
        :type all_coordinates: torch.tensor
        :returns: numpy array of shape (batch_size, n_points, 2)
    """
    batch_size = init_coordinates.shape[0]
    gen_coordinates = torch.zeros((batch_size, n_points, 2), device=device)
    coords = init_coordinates.clone()
    h = torch.zeros((1, batch_size, hidden_size), device=device)
    for point_idx in range(n_points):
        # The window covers the last `er_window` points up to the current one,
        # clipped at the start of the sequence.
        init_idx = max(0, point_idx - er_window)
        for er_idx in range(er_steps):
            optim.zero_grad()
            h = h.detach()  # so h doesn't point to the t-1 computational graph
            # NOTE(review): h is overwritten on every ER pass, so the hidden
            # state advances er_steps times per point — confirm this is intended.
            new_coords, h = rnn(coords, h, batch_first=True)
            gen_coordinates[:, point_idx, :] = new_coords.squeeze()
            error = (all_coordinates[:, init_idx:point_idx + 1] -
                     gen_coordinates[:, init_idx:point_idx + 1]).norm()
            error.backward()
            # drop the graph so earlier points act as constants in later passes
            gen_coordinates = gen_coordinates.detach()
            # Freeze everything except A. The gradient must be None rather
            # than zero: optimizers skip parameters without a gradient, while
            # a zero gradient still lets stateful optimizers (e.g. Adam's
            # moments) move the parameter.
            for name, param in rnn.named_parameters():
                if name != "A":
                    param.grad = None
            optim.step()
            if point_idx % 100 == 0:
                print(f"Point {point_idx} ER step {er_idx} error: {error} ")
        coords = new_coords.detach()
    return gen_coordinates.detach().cpu().numpy()


# Generate one plain and one error-regression trajectory from the same
# starting points and append them to the result lists used by the plots.
gen_coordinates = generate_trajectory(init_coords, n_points)
gen_coordinates_er = generate_trajectory_ER(init_coords, n_points, er_steps, er_window, coordinates[:])
all_generated_er.append(gen_coordinates_er)
all_generated.append(gen_coordinates)


In [None]:
# Generate a trajectory autoregressively (PVRNN)
n_points = 600 #coordinates.shape[1]
# coordinate = torch.tensor([-1, 1], dtype=torch.float32, device=device)
# First point of every sequence; the 0:1 slice keeps the time axis.
init_coords = coordinates[:,0:1]
print(f"init coords shape : {init_coords.shape}")
er_steps = 100  # ER passes per generated point
er_window = 40  # ER window size
deterministic = True  # forwarded to the model on every call below
# all_generated = []
# all_generated_er = []

# NOTE(review): reshape_A_batches presumably re-creates A for 4 batches —
# confirm against its definition.
rnn.reshape_A_batches(4, device)
# Fresh optimizer over the model's current parameters.
optim = torch.optim.Adam(rnn.parameters(), lr=lr)

def generate_trajectory(init_coordinates, n_points):
    """ Standard autoregressive trajectory generation.

        Starting from `init_coordinates` and a zero hidden state, the global
        `rnn` is called `n_points` times without gradients; each output is
        stored and becomes the next input. The hidden state returned by one
        call is passed to the next, as in `generate_trajectory_ER`.

        :param init_coordinates: starting points, first axis is the batch
        :type init_coordinates: torch.tensor
        :param n_points: number of points to generate
        :type n_points: int
        :returns: numpy array of shape (batch_size, n_points, 2)
    """
    batch_size = init_coordinates.shape[0]
    gen_coordinates = np.empty((batch_size, n_points, 2))
    coords = init_coordinates.clone()
    h = torch.zeros((1, batch_size, hidden_size), device=device)
    with torch.no_grad():
        for point_idx in range(n_points):
            # Carry h across steps; without it every step would restart from
            # the model's default hidden state.
            coords, h, _, _ = rnn(coords, h, batch_first=True, deterministic=deterministic)
            gen_coordinates[:, point_idx, :] = coords.detach().squeeze().cpu().numpy()
    return gen_coordinates


def generate_trajectory_ER(init_coordinates, n_points, er_steps, er_window, all_coordinates):
    """ Autoregressive trajectory generated with error regression.

        For every generated point the global `rnn` is run `er_steps` times on
        the same input (with `use_prior=True` on the first pass only). After
        each pass the norm of the difference between `all_coordinates` and the
        generated points over the last `er_window` points (plus the current
        one) is backpropagated, and the global `optim` takes a step in which
        only "A_mu" and "A_logvar" are updated. The output of the last pass
        becomes the next input.

        :param init_coordinates: starting points, first axis is the batch
        :type init_coordinates: torch.tensor
        :param n_points: number of points to generate
        :type n_points: int
        :param er_steps: number of ER passes per point
        :type er_steps: int
        :param er_window: number of previous points included in the error
        :type er_window: int
        :param all_coordinates: reference coordinates the error is measured against
        :type all_coordinates: torch.tensor
        :returns: numpy array of shape (batch_size, n_points, 2)
    """
    batch_size = init_coordinates.shape[0]
    gen_coordinates = torch.zeros((batch_size, n_points, 2), device=device)
    coords = init_coordinates.clone()
    h = torch.zeros((1, batch_size, hidden_size), device=device)
    for point_idx in range(n_points):
        # The window covers the last `er_window` points up to the current one,
        # clipped at the start of the sequence.
        init_idx = max(0, point_idx - er_window)
        for er_idx in range(er_steps):
            optim.zero_grad()
            h = h.detach()  # so h doesn't point to the t-1 computational graph
            use_prior = er_idx == 0  # the prior is only used on the first pass
            # NOTE(review): h is overwritten on every ER pass, so the hidden
            # state advances er_steps times per point — confirm this is intended.
            new_coords, h, _, _ = rnn(coords, h, batch_first=True, use_prior=use_prior, deterministic=deterministic)
            gen_coordinates[:, point_idx, :] = new_coords.squeeze()
            error = (all_coordinates[:, init_idx:point_idx + 1] -
                     gen_coordinates[:, init_idx:point_idx + 1]).norm()
            error.backward()
            # drop the graph so earlier points act as constants in later passes
            gen_coordinates = gen_coordinates.detach()
            # Freeze everything except A_mu / A_logvar. The gradient must be
            # None rather than zero: optimizers skip parameters without a
            # gradient, while a zero gradient still lets stateful optimizers
            # (e.g. Adam's moments) move the parameter.
            for name, param in rnn.named_parameters():
                if name in ["A_mu", "A_logvar"]:
                    print('.', end='')
                    continue
                param.grad = None
            optim.step()
            if point_idx % 100 == 0:
                print(f"Point {point_idx} ER step {er_idx} error: {error} ")

        # Disabled experiment: extra optimizer step on a trace estimate when
        # the A_mu gradient norm is small.
        # A_mu_norm = rnn.A_mu.grad.norm()
        # if  A_mu_norm < A_norm_thresh: # or A_mu_norm > 20:
        #     trace_factor = -0.5 if A_mu_norm < A_norm_thresh else 1.0
        #     A_mu = rnn.A_mu.clone()
        #     trace_est = 0.0
        #     optim.zero_grad()
        #     for _ in range(num_probes):
        #         z = torch.empty_like(A_mu).bernoulli_(0.5).mul(2).sub(1)
        #         _, Jz = jvp(loss_A2, (A_mu,), (z,) )
        #         trace_est = trace_est + (Jz * z).sum()
        #     trace_est = trace_factor * trace_est / num_probes
        #     print(f"-------------->>  Epoch {epoch}. Trace est: {trace_est}. trace factor: {trace_factor}. A_mu norm: {A_mu_norm}")
        #     trace_est.backward()
        #     optim.step()

        coords = new_coords.detach()
    return gen_coordinates.detach().cpu().numpy()


# Generate one plain and one error-regression trajectory from the same
# starting points and append them to the result lists used by the plots.
gen_coordinates = generate_trajectory(init_coords, n_points)
gen_coordinates_er = generate_trajectory_ER(init_coords, n_points, er_steps, er_window, coordinates[:])
all_generated_er.append(gen_coordinates_er)
all_generated.append(gen_coordinates)


In [None]:
# Plot the original trace of one example next to every trajectory stored in
# all_generated.
ex_idx = 0  # index of the example to plot

n_traces = len(all_generated)
n_panels = n_traces + 1  # the extra panel shows the original trace

# Few panels go on a single row; more are arranged on a roughly square grid.
if n_traces > 3:
    n_rows = int(np.floor(np.sqrt(n_panels)))
    n_cols = int(np.ceil(n_panels / n_rows))
else:
    n_cols = n_panels
    n_rows = 1
# squeeze=False makes plt.subplots always return a 2-D array of axes, so the
# same [row, col] indexing works for every layout, including a single panel
# when no trajectory has been generated yet.
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False,
                        figsize=(5.1*n_cols, 4.6*n_rows))
for ax in axs.flat:
    ax.set_aspect('equal')

original = coordinates.cpu().detach()
ax0 = axs[0, 0]
ax0.plot(original[ex_idx, :, 0], original[ex_idx, :, 1])
ax0.set_title('original trace')
for trace in range(n_traces):
    gen_coordinates = all_generated[trace]
    # panel 0 holds the original, so trace k goes to flat position k + 1
    row, col = divmod(trace + 1, n_cols)
    ax = axs[row, col]
    # ax.set_title(f'After {n_epochs*(trace+1)} epochs')
    # NOTE(review): the labels are hard-coded per position and this list holds
    # the trajectories generated without ER — confirm the titles are intended.
    if trace == 0:
        er_st = 10
        er_win = 10
    elif trace == 1:
        er_st = 30
        er_win = 15
    else:
        er_st = 100
        er_win = 40
    ax.set_title(f"ER steps: {er_st}, ER window: {er_win}")
    ax.plot(gen_coordinates[ex_idx, :, 0], gen_coordinates[ex_idx, :, 1])

In [None]:
# Plot the original trace of example `ex_idx` (set by an earlier cell) next to
# every error-regression trajectory stored in all_generated_er.
n_traces = len(all_generated_er)
n_panels = n_traces + 1  # the extra panel shows the original trace

# Few panels go on a single row; more are arranged on a roughly square grid.
if n_traces > 3:
    n_rows = int(np.floor(np.sqrt(n_panels)))
    n_cols = int(np.ceil(n_panels / n_rows))
else:
    n_cols = n_panels
    n_rows = 1
# squeeze=False makes plt.subplots always return a 2-D array of axes, so the
# same [row, col] indexing works for every layout, including a single panel
# when no trajectory has been generated yet.
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False,
                        figsize=(5.1*n_cols, 4.6*n_rows))
for ax in axs.flat:
    ax.set_aspect('equal')

original = coordinates.cpu().detach()
ax0 = axs[0, 0]
# ax0.plot(original[ex_idx+2, :, 0], original[ex_idx+2, :, 1])
ax0.plot(original[ex_idx, :, 0], original[ex_idx, :, 1])
ax0.set_title('original trace')
for trace in range(n_traces):
    gen_coordinates = all_generated_er[trace]
    # panel 0 holds the original, so trace k goes to flat position k + 1
    row, col = divmod(trace + 1, n_cols)
    ax = axs[row, col]
    # ax.set_title(f'After {n_epochs*(trace+1)} epochs')
    # NOTE(review): labels are hard-coded per position; presumably they mirror
    # the ER settings of the runs appended to the list — confirm.
    if trace == 0:
        er_st = 10
        er_win = 10
    elif trace == 1:
        er_st = 30
        er_win = 15
    else:
        er_st = 100
        er_win = 40
    ax.set_title(f"ER steps: {er_st}, ER window: {er_win}")
    ax.plot(gen_coordinates[ex_idx, :, 0], gen_coordinates[ex_idx, :, 1])

In [None]:
# Plot the original trace of one example next to every error-regression
# trajectory stored in all_generated_er.
ex_idx = 1  # index of the example to plot

n_traces = len(all_generated_er)
n_panels = n_traces + 1  # the extra panel shows the original trace

# Few panels go on a single row; more are arranged on a roughly square grid.
if n_traces > 3:
    n_rows = int(np.floor(np.sqrt(n_panels)))
    n_cols = int(np.ceil(n_panels / n_rows))
else:
    n_cols = n_panels
    n_rows = 1
# squeeze=False makes plt.subplots always return a 2-D array of axes, so the
# same [row, col] indexing works for every layout, including a single panel
# when no trajectory has been generated yet.
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False,
                        figsize=(5.1*n_cols, 4.6*n_rows))
for ax in axs.flat:
    ax.set_aspect('equal')

original = coordinates.cpu().detach()
ax0 = axs[0, 0]
# ax0.plot(original[ex_idx+2, :, 0], original[ex_idx+2, :, 1])
ax0.plot(original[ex_idx, :, 0], original[ex_idx, :, 1])
ax0.set_title('original trace')
for trace in range(n_traces):
    gen_coordinates = all_generated_er[trace]
    # panel 0 holds the original, so trace k goes to flat position k + 1
    row, col = divmod(trace + 1, n_cols)
    ax = axs[row, col]
    # ax.set_title(f'After {n_epochs*(trace+1)} epochs')
    # NOTE(review): labels are hard-coded per position; presumably they mirror
    # the ER settings of the runs appended to the list — confirm.
    if trace == 0:
        er_st = 10
        er_win = 10
    elif trace == 1:
        er_st = 30
        er_win = 15
    else:
        er_st = 100
        er_win = 40
    ax.set_title(f"ER steps: {er_st}, ER window: {er_win}")
    ax.plot(gen_coordinates[ex_idx, :, 0], gen_coordinates[ex_idx, :, 1])

In [None]:
# Plot the original trace of one example next to every error-regression
# trajectory stored in all_generated_er.
ex_idx = 2  # index of the example to plot

n_traces = len(all_generated_er)
n_panels = n_traces + 1  # the extra panel shows the original trace

# Few panels go on a single row; more are arranged on a roughly square grid.
if n_traces > 3:
    n_rows = int(np.floor(np.sqrt(n_panels)))
    n_cols = int(np.ceil(n_panels / n_rows))
else:
    n_cols = n_panels
    n_rows = 1
# squeeze=False makes plt.subplots always return a 2-D array of axes, so the
# same [row, col] indexing works for every layout, including a single panel
# when no trajectory has been generated yet.
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False,
                        figsize=(5.1*n_cols, 4.6*n_rows))
for ax in axs.flat:
    ax.set_aspect('equal')

original = coordinates.cpu().detach()
ax0 = axs[0, 0]
ax0.plot(original[ex_idx, :, 0], original[ex_idx, :, 1])
ax0.set_title('original trace')
for trace in range(n_traces):
    gen_coordinates = all_generated_er[trace]
    # panel 0 holds the original, so trace k goes to flat position k + 1
    row, col = divmod(trace + 1, n_cols)
    ax = axs[row, col]
    # ax.set_title(f'After {n_epochs*(trace+1)} epochs')
    # NOTE(review): labels are hard-coded per position; presumably they mirror
    # the ER settings of the runs appended to the list — confirm.
    if trace == 0:
        er_st = 10
        er_win = 10
    elif trace == 1:
        er_st = 30
        er_win = 15
    else:
        er_st = 100
        er_win = 40
    ax.set_title(f"ER steps: {er_st}, ER window: {er_win}")
    ax.plot(gen_coordinates[ex_idx, :, 0], gen_coordinates[ex_idx, :, 1])

In [None]:
# Plot the original trace of one example next to every error-regression
# trajectory stored in all_generated_er.
ex_idx = 3  # index of the example to plot

n_traces = len(all_generated_er)
n_panels = n_traces + 1  # the extra panel shows the original trace

# Few panels go on a single row; more are arranged on a roughly square grid.
if n_traces > 3:
    n_rows = int(np.floor(np.sqrt(n_panels)))
    n_cols = int(np.ceil(n_panels / n_rows))
else:
    n_cols = n_panels
    n_rows = 1
# squeeze=False makes plt.subplots always return a 2-D array of axes, so the
# same [row, col] indexing works for every layout, including a single panel
# when no trajectory has been generated yet.
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False,
                        figsize=(5.1*n_cols, 4.6*n_rows))
for ax in axs.flat:
    ax.set_aspect('equal')

original = coordinates.cpu().detach()
ax0 = axs[0, 0]
ax0.plot(original[ex_idx, :, 0], original[ex_idx, :, 1])
ax0.set_title('original trace')
for trace in range(n_traces):
    gen_coordinates = all_generated_er[trace]
    # panel 0 holds the original, so trace k goes to flat position k + 1
    row, col = divmod(trace + 1, n_cols)
    ax = axs[row, col]
    # ax.set_title(f'After {n_epochs*(trace+1)} epochs')
    # NOTE(review): labels are hard-coded per position; presumably they mirror
    # the ER settings of the runs appended to the list — confirm.
    if trace == 0:
        er_st = 10
        er_win = 10
    elif trace == 1:
        er_st = 30
        er_win = 15
    else:
        er_st = 100
        er_win = 40
    ax.set_title(f"ER steps: {er_st}, ER window: {er_win}")
    ax.plot(gen_coordinates[ex_idx, :, 0], gen_coordinates[ex_idx, :, 1])

In [None]:
# plot a direction field
# From every point of an N_x by N_y grid, generate one trajectory without ER
# (left panel) and one with ER (right panel) and draw the one for ex_idx.
ex_idx = 0
lim_x = 1.2 # largest magnitude of x
lim_y = 1.2
N_x = 3 # number of points in x dimension
N_y = 3
n_points = 400
bs = rnn.A_mu.shape[0]  # the start point is replicated once per entry of A_mu
#df_fig = plt.figure(figsize = (10,10))
fig, axs = plt.subplots(1, 2, figsize=(14, 7))
# axis limits and equal scaling are only configured for the left panel
axs[0].set_xlim(-1.1*lim_x, 1.1*lim_x)
axs[0].set_ylim(-1.1*lim_y, 1.1*lim_y)
axs[0].axis('equal')
for x in np.linspace(-lim_x, lim_x, N_x):
    for y in np.linspace(-lim_y, lim_y, N_y):
        #print(f"x = {x}, y = {y}")
        # pink: the requested start point
        axs[0].scatter([x], [y], color='tab:pink')
        axs[1].scatter([x], [y], color='tab:pink')
        coordinate = torch.tensor([x, y], dtype=torch.float32, device=device).unsqueeze(dim=0)
        coordinate = coordinate.tile((bs, 1, 1))  # shape (bs, 1, 2)
        trajectory1 = generate_trajectory(coordinate, n_points)
        # ER run with 20 steps and a window of 10 against the first bs sequences
        trajectory2 = generate_trajectory_ER(coordinate, n_points, 20, 10, coordinates[:bs])
        # cyan: the first generated point of each trajectory
        axs[0].scatter([trajectory1[ex_idx, 0, 0]], [trajectory1[ex_idx, 0, 1]], color='tab:cyan')
        axs[1].scatter([trajectory2[ex_idx, 0, 0]], [trajectory2[ex_idx, 0,1]], color='tab:cyan')
        axs[0].plot(trajectory1[ex_idx, :, 0], trajectory1[ex_idx, :, 1])
        axs[1].plot(trajectory2[ex_idx, :, 0], trajectory2[ex_idx, :, 1])
plt.show()

In [None]:
ex_idx = 1

# plot a direction field
# From every point of an N_x by N_y grid, generate one trajectory without ER
# (left panel) and one with ER (right panel) and draw the one for ex_idx.
lim_x = 1.2 # largest magnitude of x
lim_y = 1.2
N_x = 3 # number of points in x dimension
N_y = 3
n_points = 400
bs = rnn.A_mu.shape[0]  # the start point is replicated once per entry of A_mu
#df_fig = plt.figure(figsize = (10,10))
fig, axs = plt.subplots(1, 2, figsize=(14, 7))
# axis limits and equal scaling are only configured for the left panel
axs[0].set_xlim(-1.1*lim_x, 1.1*lim_x)
axs[0].set_ylim(-1.1*lim_y, 1.1*lim_y)
axs[0].axis('equal')
for x in np.linspace(-lim_x, lim_x, N_x):
    for y in np.linspace(-lim_y, lim_y, N_y):
        #print(f"x = {x}, y = {y}")
        # pink: the requested start point
        axs[0].scatter([x], [y], color='tab:pink')
        axs[1].scatter([x], [y], color='tab:pink')
        coordinate = torch.tensor([x, y], dtype=torch.float32, device=device).unsqueeze(dim=0)
        coordinate = coordinate.tile((bs, 1, 1))  # shape (bs, 1, 2)
        trajectory1 = generate_trajectory(coordinate, n_points)
        # ER run with 20 steps and a window of 10 against the first bs sequences
        trajectory2 = generate_trajectory_ER(coordinate, n_points, 20, 10, coordinates[:bs])
        # cyan: the first generated point of each trajectory
        axs[0].scatter([trajectory1[ex_idx, 0, 0]], [trajectory1[ex_idx, 0, 1]], color='tab:cyan')
        axs[1].scatter([trajectory2[ex_idx, 0, 0]], [trajectory2[ex_idx, 0,1]], color='tab:cyan')
        axs[0].plot(trajectory1[ex_idx, :, 0], trajectory1[ex_idx, :, 1])
        axs[1].plot(trajectory2[ex_idx, :, 0], trajectory2[ex_idx, :, 1])
plt.show()

## LSTM Comparison

In [29]:
# Same thing, but using LSTM
# Create the network, the optimizer, and the loss function
input_size = 2  # this shouldn't change
hidden_size = 256
lr = 3e-4  # learning rate
all_generated = []  # restart the list of generated trajectories

# fixed seed so the weight initialisation is reproducible
torch.manual_seed(345)

class LSTM_RNN(nn.Module):
    """ Single-layer LSTM followed by a linear readout to 2 outputs.

        The layers are created on the global `device`.

        :param input_size: number of features per input point
        :type input_size: int
        :param hidden_size: number of LSTM hidden units
        :type hidden_size: int
    """
    def __init__(self, input_size, hidden_size):
        super(LSTM_RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            batch_first=True,
                            device=device)
        self.fc = nn.Linear(hidden_size, 2, device=device)

    def _zero_state(self, x):
        """ Zero initial state matching the layout, device and dtype of x. """
        if x.dim() == 2:  # unbatched input: (seq_len, input_size)
            shape = (1, self.hidden_size)
        else:  # batched input: (batch_size, seq_len, input_size)
            shape = (1, x.shape[0], self.hidden_size)
        return torch.zeros(shape, device=x.device, dtype=x.dtype)

    def forward(self, x, h0=None, c0=None):
        """ Run the LSTM over x and map every output to 2 values.

            :param x: input, (seq_len, input_size) or (batch_size, seq_len, input_size)
            :type x: torch.tensor
            :param h0: initial hidden state; zeros when omitted
            :type h0: torch.tensor
            :param c0: initial cell state; zeros when omitted
            :type c0: torch.tensor
            :returns: (readout of every step, final hidden state, final cell state)
        """
        # Size the default states from this module's own width rather than a
        # global, so the model works whatever the notebook's hidden_size is.
        if h0 is None:
            h0 = self._zero_state(x)
        if c0 is None:
            c0 = self._zero_state(x)
        out, (h, c) = self.lstm(x, (h0, c0))
        return self.fc(out), h, c

# rnn = ElmanRNN(input_size, hidden_size, nonlinearity)
# Rebinds the global `rnn`, `optim` and `loss` used by the cells below.
rnn = LSTM_RNN(input_size, hidden_size)

optim = torch.optim.Adam(rnn.parameters(), lr=lr)

# mean squared error between predicted and target points
loss = nn.MSELoss(reduction='mean')

In [None]:
# training loop
n_epochs = 1000  # number of epochs

for epoch in range(n_epochs):
    optim.zero_grad()
    # next-step prediction: entry k of the input is trained to predict entry k+1
    out, _, _  = rnn(coordinates[:-1])
    error = loss(out, coordinates[1:])
    error.backward()
    optim.step()
    if epoch % 20 == 0:
        print(f"Error at epoch {epoch} = {error}")

In [53]:
# Generate a trajectory
n_points = 1000
# fixed starting point (-1, 1)
coordinate = torch.tensor([-1, 1], dtype=torch.float32, device=device)
#coordinate = coordinates[0]
# add a leading axis: shape (2,) -> (1, 2)
coordinate = coordinate.unsqueeze(dim=0)

def generate_trajectory(init_coordinate, n_points):
    """ Roll the global `rnn` forward on its own outputs, without gradients.

        Starting from `init_coordinate` and zero hidden/cell states, the
        network is called `n_points` times; each output is stored and becomes
        the next input.

        :param init_coordinate: starting point
        :type init_coordinate: torch.tensor
        :param n_points: number of points to generate
        :type n_points: int
        :returns: numpy array of shape (n_points, 2)
    """
    hidden = torch.zeros((1, hidden_size), device=device)
    cell = torch.zeros((1, hidden_size), device=device)
    trajectory = np.empty((n_points, 2))
    point = init_coordinate
    with torch.no_grad():
        for step in range(n_points):
            point, hidden, cell = rnn(point, hidden, cell)
            as_array = point.detach().cpu().numpy()
            trajectory[step, :] = as_array
    return trajectory

# Generate a trajectory from the start point and keep it for plotting.
gen_coordinates = generate_trajectory(coordinate, n_points)
all_generated.append(gen_coordinates)

In [None]:
# Plot the original trace next to every trajectory stored in all_generated.
n_traces = len(all_generated)
n_panels = n_traces + 1  # the extra panel shows the original trace

# Few panels go on a single row; more are arranged on a roughly square grid.
if n_traces > 3:
    n_rows = int(np.floor(np.sqrt(n_panels)))
    n_cols = int(np.ceil(n_panels / n_rows))
else:
    n_cols = n_panels
    n_rows = 1
# squeeze=False makes plt.subplots always return a 2-D array of axes, so the
# same [row, col] indexing works for every layout, including a single panel
# when no trajectory has been generated yet.
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False,
                        figsize=(5.1*n_cols, 4.6*n_rows))
for ax in axs.flat:
    ax.set_aspect('equal')
original = coordinates.cpu().detach()
ax0 = axs[0, 0]
ax0.plot(original[:,0], original[:,1])
ax0.set_title('original trace')
for trace in range(n_traces):
    gen_coordinates = all_generated[trace]
    # panel 0 holds the original, so trace k goes to flat position k + 1
    row, col = divmod(trace + 1, n_cols)
    ax = axs[row, col]
    # the title assumes n_epochs of training between consecutive entries
    ax.set_title(f'After {n_epochs*(trace+1)} epochs')
    ax.plot(gen_coordinates[:,0], gen_coordinates[:,1])

In [None]:
# plot a direction field
# From every point of an N_x by N_y grid, generate a trajectory and draw it.
lim_x = 2. # largest magnitude of x
lim_y = 2.
N_x = 2 # number of points in x dimension
N_y = 2
n_points = 800
#df_fig = plt.figure(figsize = (10,10))
fig = plt.figure(figsize=(5,5))
plt.xlim(-1.1*lim_x, 1.1*lim_x)
plt.ylim(-1.1*lim_y, 1.1*lim_y)
plt.axis('equal')
for x in np.linspace(-lim_x, lim_x, N_x):
    for y in np.linspace(-lim_y, lim_y, N_y):
        #print(f"x = {x}, y = {y}")
        # pink: the requested start point
        plt.scatter([x], [y], color='tab:pink')
        coordinate = torch.tensor([x, y], dtype=torch.float32, device=device).unsqueeze(dim=0)
        trajectory = generate_trajectory(coordinate, n_points)
        # cyan: the first generated point
        plt.scatter([trajectory[0, 0]], [trajectory[0,1]], color='tab:cyan')
        plt.plot(trajectory[:, 0], trajectory[:, 1])
plt.show()

In [None]:
# Create the network, the optimizer, and the loss function
input_size = 2  # this shouldn't change
hidden_size = 256
nonlinearity = 'tanh'  # tanh or relu
lr = 3e-4  # learning rate
all_generated = []  # restart the list of generated trajectories

# fixed seed so the weight initialisation is reproducible
torch.manual_seed(345)

class ElmanRNN(nn.Module):
    """ Single-layer Elman RNN followed by a linear readout to 2 outputs.

        :param input_size: number of features per input point
        :type input_size: int
        :param hidden_size: number of recurrent units
        :type hidden_size: int
        :param nonlinearity: 'tanh' or 'relu', forwarded to nn.RNN
        :type nonlinearity: str
    """
    def __init__(self, input_size, hidden_size, nonlinearity):
        super(ElmanRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.rnn = nn.RNN(input_size=input_size,
                          hidden_size=hidden_size,
                          nonlinearity=nonlinearity,
                          num_layers=1,
                          batch_first=True)
        self.fc = nn.Linear(hidden_size, 2)

    def forward(self, x, h0=None):
        """ Run the RNN over x and map every output to 2 values.

            :param x: input, (seq_len, input_size) or (batch_size, seq_len, input_size)
            :type x: torch.tensor
            :param h0: initial hidden state; zeros when omitted
            :type h0: torch.tensor
            :returns: (readout of every step, final hidden state)
        """
        if h0 is None:
            # Size the default state from this module's own width (not a
            # global) and create it with the layout, device and dtype of x.
            if x.dim() == 2:  # unbatched input: (seq_len, input_size)
                shape = (1, self.hidden_size)
            else:  # batched input: (batch_size, seq_len, input_size)
                shape = (1, x.shape[0], self.hidden_size)
            h0 = torch.zeros(shape, device=x.device, dtype=x.dtype)
        out, h = self.rnn(x, h0)
        return self.fc(out), h

# Rebinds the global `rnn`, `optim` and `loss` used by the cells below.
rnn = ElmanRNN(input_size, hidden_size, nonlinearity)

optim = torch.optim.Adam(rnn.parameters(), lr=lr)

# mean squared error between predicted and target points
loss = nn.MSELoss(reduction='mean')