In [1]:
# imports
import math, random

# import gym
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 
import torch.nn.functional as F
from collections import deque

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
USE_CUDA = torch.cuda.is_available()


def Variable(*args, **kwargs):
    """Build an ``autograd.Variable`` from the given arguments, moving it to the GPU if CUDA is available.

    ``USE_CUDA`` is read at call time, so reassigning it later changes where new values land.
    NOTE(review): ``autograd.Variable`` is a legacy API; consider passing plain tensors instead.
    """
    wrapped = autograd.Variable(*args, **kwargs)
    if not USE_CUDA:
        return wrapped
    return wrapped.cuda()

In [None]:
class MonkRNN(nn.Module):
    """Single-step recurrent controller; each ``forward`` call advances one time step.

    Architecture::

        input -> RNNCell(nonlinearity=rnn_nonlinearity) -> Linear -> ReLU -> Linear -> (+ curl) -> output

    Per the design notes the output is a velocity (x, y). The intended training loss
    (not computed in this class) is squared error with the target velocity plus L2
    regularization on all parameters.

    Parameters
    ----------
    inp_size : int
        Number of input features fed to the recurrent cell on each step.
    n_neurons : int
        Number of hidden units in the recurrent cell and in the hidden readout layer.
    out_size : int
        Number of output features.
    rnn_nonlinearity : str
        Passed to ``nn.RNNCell`` as ``nonlinearity``.
    curl_field : bool
        Currently unused; kept for interface compatibility.
    """

    def __init__(self, inp_size, n_neurons, out_size, rnn_nonlinearity='relu', curl_field=False):
        super(MonkRNN, self).__init__()

        self.inp_size = inp_size
        self.out_size = out_size
        self.n_neurons = n_neurons
        # Learned initial hidden state, used when forward() receives no h_old.
        self.init_hidden_state = nn.Parameter(torch.zeros(1, n_neurons))

        # Recurrent layer. Its input width must be inp_size (not n_neurons), otherwise
        # any input whose feature count differs from the hidden size is rejected.
        self.rnn = nn.RNNCell(input_size=inp_size, hidden_size=n_neurons, nonlinearity=rnn_nonlinearity)
        # Recurrent weights drawn with std 2/sqrt(n_neurons) instead of the default init.
        nn.init.normal_(self.rnn.weight_hh, 0.0, 2 / np.sqrt(n_neurons))

        # Readout MLP: Linear -> ReLU -> Linear, weights drawn with std 1.5/sqrt(n_neurons).
        self.out_layer = nn.Sequential(
            nn.Linear(n_neurons, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, out_size),
        )
        nn.init.normal_(self.out_layer[0].weight, 0.0, 1.5 / np.sqrt(n_neurons))
        nn.init.normal_(self.out_layer[2].weight, 0.0, 1.5 / np.sqrt(n_neurons))

    def forward(self, inp, curl, h_old=None):
        """Advance the network by one time step.

        Parameters
        ----------
        inp : torch.Tensor
            Input for this step, shape (inp_size,) or (batch, inp_size).
            Presumably (go, hand_x, hand_y, curr_tgx, curr_tgy, next_tgx, next_tgy)
            when inp_size == 7 -- TODO confirm against the caller.
        curl : torch.Tensor or float
            Term added to the readout; must broadcast against shape (..., out_size).
        h_old : torch.Tensor, optional
            Previous hidden state, shape (n_neurons,) or (batch, n_neurons) matching
            ``inp``. Defaults to the learned ``init_hidden_state``.

        Returns
        -------
        dv : torch.Tensor
            Readout plus ``curl``, shape (out_size,) or (batch, out_size).
        h_new : torch.Tensor
            Updated hidden state, shape (n_neurons,) or (batch, n_neurons).
        """
        if h_old is None:
            # RNNCell needs hx batched exactly like the input: 1-D for unbatched
            # input, (batch, n_neurons) otherwise.
            if inp.dim() == 1:
                h_old = self.init_hidden_state[0]
            else:
                h_old = self.init_hidden_state.expand(inp.shape[0], -1)

        h_new = self.rnn(inp, h_old)
        dv = self.out_layer(h_new) + curl
        return dv, h_new
    
   

In [None]:
class GenerateInputTargetTimeseries:
    """Point-mass "hand" whose acceleration blends a PI teacher controller with a network output.

    Call ``initialize`` at the start of every trial, then ``forward`` once per time step.

    Parameters
    ----------
    screen_size : float
        Hand coordinates are clamped to ``[0.1, screen_size - 0.1]`` on both axes.
    teacher_scale : float, optional
        Mixing weight: acceleration = teacher * teacher_scale + acc * (1 - teacher_scale).
        Must be set here or by assigning the attribute before ``forward`` is called.
    friction : float, optional
        Velocity is divided by this value each step (1.0 means no damping).
        Must be set here or by assigning the attribute before ``forward`` is called.
    prop_gain : float, optional
        Proportional gain of the teacher. If None, the module-level ``PROP_GAIN``
        is looked up when ``forward`` runs.
    """

    # The teacher's integral term accumulates INT_ERR_WEIGHT * error each step and
    # contributes INT_GAIN * accumulated error to the teacher acceleration.
    INT_ERR_WEIGHT = 0.9
    INT_GAIN = 0.01

    def __init__(self, screen_size, teacher_scale=None, friction=None, prop_gain=None):
        super().__init__()
        self.screen_size = screen_size
        self.teacher_scale = teacher_scale
        self.friction = friction
        self.prop_gain = prop_gain
        # Per-trial state, populated by initialize().
        self.hx = self.hy = None
        self.vx = self.vy = None
        self.int_errx = self.int_erry = 0.0

    def initialize(self, hx, hy):
        """Start a trial: hand at (hx, hy), zero velocity, zero teacher integral term."""
        self.hx = torch.tensor([float(hx)])
        self.hy = torch.tensor([float(hy)])
        self.vx = torch.zeros(1)
        self.vy = torch.zeros(1)
        # Reset the integral error so it does not carry over between trials.
        self.int_errx = 0.0
        self.int_erry = 0.0

    def forward(self, acc, targ_loc):
        """Advance the hand by one time step.

        Parameters
        ----------
        acc : torch.Tensor
            Network acceleration; after ``squeeze()`` it must have shape (2,) = (ax, ay).
        targ_loc : torch.Tensor or sequence of float
            Location (x, y) of the current target.

        Returns
        -------
        hand : torch.Tensor
            Shape (2,), the new (hx, hy).
        teacher_loss : torch.Tensor
            Squared distance between teacher and network acceleration, capped at 10.
        teacher : torch.Tensor
            Shape (2,), the teacher acceleration (teach_x, teach_y).

        Raises
        ------
        RuntimeError
            If ``initialize`` has not been called.
        ValueError
            If ``teacher_scale`` or ``friction`` is unset.
        """
        if self.hx is None:
            raise RuntimeError("initialize() must be called before forward()")
        if self.teacher_scale is None or self.friction is None:
            raise ValueError("teacher_scale and friction must be set before calling forward()")
        prop_gain = PROP_GAIN if self.prop_gain is None else self.prop_gain

        acc = acc.squeeze()
        hx = self.hx.item()
        hy = self.hy.item()
        tgx = float(targ_loc[0])
        tgy = float(targ_loc[1])

        # Teacher: PI controller toward the current target. It is computed from
        # Python floats, so no gradient flows through it.
        errx = tgx - hx
        erry = tgy - hy
        self.int_errx += self.INT_ERR_WEIGHT * errx
        self.int_erry += self.INT_ERR_WEIGHT * erry
        teach_x = self.INT_GAIN * self.int_errx + prop_gain * errx
        teach_y = self.INT_GAIN * self.int_erry + prop_gain * erry

        # Blend teacher and network accelerations.
        ax = teach_x * self.teacher_scale + acc[0] * (1 - self.teacher_scale)
        ay = teach_y * self.teacher_scale + acc[1] * (1 - self.teacher_scale)

        # Integrate: damp velocity, add acceleration, keep the hand inside the margins.
        self.vx = self.vx / self.friction + ax
        self.vy = self.vy / self.friction + ay
        self.hx = torch.clamp(self.hx + self.vx, .1, self.screen_size - .1)
        self.hy = torch.clamp(self.hy + self.vy, .1, self.screen_size - .1)

        # A sum of squares is never negative, so only the upper cap matters.
        # NOTE(review): clamping zeroes the gradient once the loss exceeds 10 -- confirm intended.
        teacher_loss = torch.clamp((teach_x - acc[0]) ** 2 + (teach_y - acc[1]) ** 2, max=10.0)

        return (
            torch.cat((self.hx, self.hy)),
            teacher_loss,
            torch.tensor([teach_x, teach_y]),
        )