In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
%matplotlib inline

from lib.FramesDataset import FramesDataset
from lib import network

# --- Configuration -----------------------------------------------------------
WARMUP = 20      # forwarded to FramesDataset as its third argument (presumably warm-up frames) — TODO confirm
T_STEPS = 5      # number of recurrent steps the RNN is unrolled for when probing / predicting
FRAME_SIZE = 20  # frames are FRAME_SIZE x FRAME_SIZE images, flattened to FRAME_SIZE**2 values
SEED = 0         # RNG seed so the white-noise analysis is reproducible under Restart & Run All

# Seed every RNG the notebook may draw from: the noise stimuli come from np.random,
# torch is seeded as well so any torch-side randomness is repeatable too.
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using", DEVICE)

In [None]:
# --- Load data and trained model ---------------------------------------------
# Training split of the processed dataset; WARMUP is forwarded as FramesDataset's third argument.
train_dataset = FramesDataset('./datasets/processed_dataset_small.pkl', 'train', WARMUP)
print(f"Training dataset length: {len(train_dataset)}")

# Mini-batch loader over the training split (batch size 128, default sampler settings).
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128)

# Restore the trained RNN from its checkpoint. The hyper-parameters passed here presumably
# have to match those the checkpoint was trained with — confirm against the training run.
# NOTE(review): this cell never moves the model to DEVICE or calls .eval(), while later cells
# send inputs to DEVICE; it presumably relies on RNN.load handling placement/mode — verify.
model = network.RNN.load(
    path='./models/model-2000epochs-20210206-080756.pt',
    hidden_units=400,
    frame_size=FRAME_SIZE,
    t_steps=T_STEPS,
)

In [None]:
# --- Activation-triggered average of white-noise inputs ----------------------
# For each trial a white-noise frame is fed to the RNN, which is unrolled for T_STEPS steps
# (each step's fc output becomes the next step's input). The noise frame is recorded for every
# hidden unit whose final activation is positive; averaging the recorded frames per unit gives
# an estimate of the input pattern that drives that unit (spike-triggered-average style).
hidden_units = 400    # should match the hidden_units passed to network.RNN.load
hidden_unit_rf = 4    # NOTE(review): not referenced by any visible cell
N_TRIALS = 100
PRINT_EVERY = max(1, N_TRIALS // 10)  # the old `t % 1000` check only ever fired on trial 0

stimuli = {i: [] for i in range(hidden_units)}

# Inference only: without no_grad every trial would build an autograd graph for nothing.
with torch.no_grad():
    for trial in range(N_TRIALS):
        if trial % PRINT_EVERY == 0:
            print('Trial', trial)

        # Cast to float32 once so the stored copy is exactly the tensor the model sees.
        noise_np = np.random.normal(loc=0, scale=1, size=FRAME_SIZE**2).astype(np.float32)
        noise = torch.from_numpy(noise_np).unsqueeze(0).to(DEVICE)

        frame = noise
        hidden_state = torch.zeros((1, hidden_units), device=DEVICE)
        for _step in range(T_STEPS):
            hidden_state = model.rnn(frame, hidden_state)
            frame = model.fc(hidden_state)

        # Indices of hidden units with a positive activation after the final step.
        active_units = torch.nonzero(hidden_state[0] > 0).flatten().tolist()
        for i in active_units:
            stimuli[i].append(noise_np)

# Mean recorded noise frame per unit; units that were never active get no entry.
averaged_stimuli = {
    i: np.mean(np.stack(frames), axis=0).reshape(FRAME_SIZE, FRAME_SIZE)
    for i, frames in stimuli.items()
    if frames
}

# One panel per unit (repeated plt.imshow calls on a single axes would only show the last unit).
if averaged_stimuli:
    n_cols = min(20, len(averaged_stimuli))
    n_rows = -(-len(averaged_stimuli) // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, (i, average) in zip(axes.flat, averaged_stimuli.items()):
        ax.imshow(average, cmap='gray')
        ax.set_title(str(i), fontsize=6)
    fig.suptitle('Average white-noise input preceding positive activation, per hidden unit')
    plt.show()
else:
    print('No hidden unit had a positive activation in any trial.')

In [None]:
# --- Input weights and one-sample prediction check ---------------------------
network.plot_input_weights(model, hidden_units = 10, frame_size = FRAME_SIZE)


def show_frame_sequence(frames, title, n_frames=T_STEPS + 1, frame_size=FRAME_SIZE):
    """Draw the first `n_frames` frames of `frames` side by side as grayscale images.

    Args:
        frames: tensor indexed by frame along its first dimension; each frame must hold
            frame_size**2 values (it is reshaped with .view(frame_size, frame_size)).
            Assumed to contain at least n_frames frames — TODO confirm for y / pred.
        title: figure title.
        n_frames: number of frames to draw (defaults to T_STEPS + 1 as before).
        frame_size: side length of each square frame (defaults to FRAME_SIZE instead of a
            hard-coded 20).
    """
    # squeeze=False keeps `axes` 2-D even when n_frames == 1.
    fig, axes = plt.subplots(1, n_frames, dpi=150, squeeze=False)
    for i in range(n_frames):
        image = frames[i].view(frame_size, frame_size).detach().cpu().numpy()
        axes[0][i].imshow(image, cmap='gray')
        axes[0][i].axis('off')
    fig.suptitle(title)
    plt.show()


x, y = train_dataset[0]
x = x.to(DEVICE)
y = y.to(DEVICE)
# Inference only — no autograd graph needed. Add the batch dim, run, then drop it again.
with torch.no_grad():
    pred = torch.squeeze(model(torch.unsqueeze(x, 0)), 0)

show_frame_sequence(y, 'Target frames')
show_frame_sequence(pred, 'Predicted frames')