# In [None]:
import numpy as np
import torch 
from torch import nn
import random
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import scipy.optimize as optimize
from sklearn.metrics import r2_score
from matplotlib import cm
import pandas

# In [None]:
#%% rnn setting
# 1. receptive fields (2d Gaussian, center and sigma)
# 2. structured connection (weight depending on center distance)
# 3. self-adjusted possibility (as gain on connection weight)

        
def receptive_field_function(rf_type, dist_to_center, **kwargs):
    """Evaluate a radially symmetric receptive-field profile.

    Parameters
    ----------
    rf_type : str
        One of 'gaussian2d', 'step2d' or 'pulse2d'.
    dist_to_center : array_like
        Distance(s) from the receptive-field center at which to evaluate.
    **kwargs
        Shape parameters, depending on ``rf_type``:

        * 'gaussian2d': ``sigma``.
        * 'step2d': either ``radius`` (hard disc) or ``radius1`` and
          ``radius2`` (plateau of 1 inside ``radius2``, falling linearly
          to 0 at ``radius1``; requires ``radius1 > radius2``).
        * 'pulse2d': ``radius`` and ``sigma`` (Gaussian bump centered on
          the ring of the given radius).

    Returns
    -------
    numpy.ndarray or scalar
        Profile value(s), same shape as ``dist_to_center``.

    Raises
    ------
    ValueError
        For an unknown ``rf_type``, missing 'step2d' radii, or
        ``radius1 <= radius2``.
    KeyError
        If a required ``sigma`` / ``radius`` keyword is missing.
    """
    if rf_type == 'gaussian2d':
        sigma = kwargs['sigma']
        return (1 / (2*np.pi*(sigma**2))
                * np.exp(-(dist_to_center**2) / (2*sigma**2)))

    if rf_type == 'step2d':
        if 'radius' in kwargs:
            # Compare squared distance against the *squared* radius, as the
            # two-radius variant below does.
            d = np.asarray(dist_to_center)
            return (d**2 <= kwargs['radius']**2).astype(int)

        if 'radius1' in kwargs and 'radius2' in kwargs:
            radius_on = kwargs['radius1']
            radius_peak = kwargs['radius2']
            if radius_on <= radius_peak:
                raise ValueError("radius1 should be > radius2")

            # Linear ramp that is 1 at radius_peak and 0 at radius_on, so
            # the profile is continuous with the plateau and the outside;
            # clipping supplies the plateau (1) and the outside (0).
            d = np.abs(np.asarray(dist_to_center, dtype=float))
            return np.clip((radius_on - d)/(radius_on - radius_peak),
                           0.0, 1.0)

        raise ValueError(
            "step2d needs either 'radius' or both 'radius1' and 'radius2'")

    if rf_type == 'pulse2d':
        radius = kwargs['radius']
        sigma = kwargs['sigma']
        return (1 / (2*np.pi*(sigma**2))
                * np.exp(-((dist_to_center-radius)**2) / (2*sigma**2)))

    raise ValueError("unknown receptive field type: %r" % (rf_type,))


class pRF_RNN_Cell(nn.Module):
    """One update step of a rate network with hand-wired recurrent weights.

    Every unit owns a 2-D receptive-field center. The recurrent weight
    between two units is a Gaussian of the distance between their centers
    (width ``a``) and is frozen (``requires_grad`` is False). The stimulus
    drive of a unit is its receptive-field profile evaluated at the
    target position, scaled by ``b``.

    Parameters
    ----------
    argument : object
        Attribute container providing ``hidden_size``, ``tau``, ``a``,
        ``b``, ``rf_sigma``, ``noise_sigma``, ``phi`` and ``kappa``.
    centers_distr_type : str
        'ring' (centers evenly spaced on a circle) or 'uniform' (centers
        on a square grid; ``hidden_size`` must be a perfect square).
    radius : float, optional
        Radius of the ring / half-width of the grid. Defaults to the
        module-level ``R_circle``.

    Raises
    ------
    ValueError
        For an unknown ``centers_distr_type`` or a 'uniform' layout whose
        ``hidden_size`` is not a perfect square.

    Notes
    -----
    ``set_receptive_field`` must be called before ``receptive_field_map``,
    ``stimulus_activated_response`` or ``forward``; ``rf_type`` and
    ``rf_kwargs`` are not initialised in ``__init__``.
    """

    # Valid second elements of a (expression, target) ``ptype`` pair.
    _MODULATION_TARGETS = ('i', 'h', 'ih')

    def __init__(self, argument, centers_distr_type, radius=None):

        super().__init__()

        self.N = argument.hidden_size
        self.tau = argument.tau
        self.a = argument.a  # range of excitatory connection
        self.b = argument.b  # amplitude of input-driven response
        self.rf_sigma = argument.rf_sigma  # sigma for rf gaussian
        self.noise_sigma = argument.noise_sigma  # sigma for noise gaussian

        self.phi = argument.phi
        self.kappa = argument.kappa

        if radius is None:
            radius = R_circle

        if centers_distr_type == 'ring':
            # N angles in [0, 2*pi), end point excluded.
            angles = np.linspace(0, np.pi*2-np.pi*2/self.N, self.N)
            self.centers = radius*np.array(
                [np.cos(angles), np.sin(angles)]).T

        elif centers_distr_type == 'uniform':
            side = int(np.sqrt(self.N))
            if side*side != self.N:
                raise ValueError(
                    "'uniform' centers need a perfect-square hidden_size, "
                    "got %d" % self.N)
            # A complex step makes np.mgrid return `side` points with both
            # end points included.
            n_points = complex(0, side)
            cx, cy = np.mgrid[-radius:radius:n_points,
                              -radius:radius:n_points]
            self.centers = np.array([cx.flatten(), cy.flatten()]).T

        else:
            raise ValueError(
                "unknown centers_distr_type: %r" % (centers_distr_type,))

        self.w_h = nn.Linear(self.N, self.N)
        self.w_h.weight.requires_grad = False

        # connection weight varies with distance between centers
        # referring to CANN
        offsets = self.centers[:, np.newaxis, :] - self.centers[np.newaxis, :, :]
        sq_dist = np.sum(offsets**2, axis=-1)
        weights = (np.exp(-0.5 * sq_dist/self.a**2)
                   / (np.sqrt(2*np.pi)*self.a))
        with torch.no_grad():
            self.w_h.weight.copy_(torch.from_numpy(weights))
        # NOTE(review): the bias of w_h keeps nn.Linear's default
        # initialisation and stays trainable; it is applied whenever
        # self.w_h(...) is called in forward().

    def set_receptive_field(self, rf_type='gaussian2d', **kwargs):
        """Select the receptive-field profile shared by all units.

        ``rf_type`` and ``kwargs`` are stored and later forwarded
        unchanged to ``receptive_field_function``.
        """
        self.rf_type = rf_type
        self.rf_kwargs = kwargs

    def receptive_field_map(self):
        """Sample every unit's receptive field on a 50x50 grid.

        Returns a list with one ``[x, y, z]`` entry per unit, where ``x``
        and ``y`` are grid coordinates spanning +/-5 around the unit's
        center and ``z`` is the profile value at each grid point.
        """
        # The profile depends only on the offset from the center, so the
        # grid and its values are identical for every unit.
        gx, gy = np.mgrid[-5:5:50j, -5:5:50j]
        z = receptive_field_function(self.rf_type,
                                     np.sqrt(gx**2+gy**2),
                                     **self.rf_kwargs)
        # Each entry gets its own copy of z so callers may modify one
        # map without affecting the others.
        return [[gx+x0, gy+y0, np.copy(z)] for x0, y0 in self.centers]

    def stimulus_activated_response(self, target_pos):
        """Return the stimulus drive of every unit for each target position.

        ``target_pos`` is indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y).
        The result is a float32 tensor of shape
        ``(target_pos.shape[0], N)`` holding the receptive-field value at
        the target-to-center distance, multiplied by ``b``.
        """
        target_pos = np.asarray(target_pos)
        # (n_targets, 1) against (N,) broadcasts to (n_targets, N).
        dx = target_pos[:, 0:1] - self.centers[:, 0]
        dy = target_pos[:, 1:2] - self.centers[:, 1]
        z = receptive_field_function(self.rf_type,
                                     np.sqrt(dx**2 + dy**2),
                                     **self.rf_kwargs)
        return torch.from_numpy(np.asarray(z*self.b, dtype=np.float32))

    def _advance(self, h_last, recurrent, drive):
        """Leaky update of the hidden state; returns ``(h_now, r_now)``."""
        # random.gauss yields one Python float, so the same noise sample
        # is added to every element of the state at this step.
        h_now = ((1-1/self.tau)*h_last +
                 (recurrent + drive)/self.tau +
                 random.gauss(0, self.noise_sigma))
        return h_now, torch.relu(torch.tanh(h_now))

    def forward(self, target_pos, h_last, h_delta, ptype):
        """Advance the network by one time step.

        Parameters
        ----------
        target_pos : array_like
            Target position(s), passed to ``stimulus_activated_response``.
        h_last : torch.Tensor
            Hidden state of the previous step.
        h_delta : torch.Tensor
            Recent change of the hidden state; only used when the gain
            expression contains 'p1'.
        ptype : str or (str, str)
            'basic' for the unmodulated update, or a pair
            ``(expression, target)``. ``expression`` is a Python
            expression in ``p1`` and/or ``p2`` giving the gain ``p``;
            ``target`` is 'i' (gain on the stimulus drive), 'h' (gain on
            the recurrent weights) or 'ih' (both).

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            New hidden state and firing rate ``relu(tanh(h))``.

        Raises
        ------
        ValueError
            For an unknown mode, an unknown target, or a gain expression
            that mentions neither 'p1' nor 'p2'.
        TypeError
            If ``ptype`` is neither a string nor a tuple/list.
        """
        # Reject unsupported modes up front instead of returning None or
        # failing later on an unbound local.
        if isinstance(ptype, str):
            if ptype != 'basic':
                raise ValueError("unknown ptype: %r" % (ptype,))
        elif isinstance(ptype, (tuple, list)):
            if len(ptype) < 2 or ptype[1] not in self._MODULATION_TARGETS:
                raise ValueError(
                    "ptype must be (expression, target) with target in "
                    "%r, got %r" % (self._MODULATION_TARGETS, ptype))
        else:
            raise TypeError(
                "ptype must be 'basic' or an (expression, target) pair")

        r_last = torch.relu(torch.tanh(h_last))
        stimulus = self.stimulus_activated_response(target_pos)

        # basic
        if isinstance(ptype, str):
            return self._advance(h_last, self.w_h(r_last), stimulus)

        p1, p2 = 0, 0
        if 'p1' in ptype[0]:
            p1 = 1 / (1 + torch.exp(-self.phi*h_delta+self.kappa))

        if 'p2' in ptype[0]:
            p2 = self.phi*r_last+self.kappa

        if not (isinstance(p1, torch.Tensor)
                or isinstance(p2, torch.Tensor)):
            raise ValueError(
                "gain expression %r must reference 'p1' and/or 'p2'"
                % (ptype[0],))

        # SECURITY: ptype[0] is evaluated as Python code with this
        # method's locals in scope. Only pass trusted, hard-coded
        # expressions such as 'p1', 'p2' or 'p1*p2'.
        p = eval(ptype[0])

        target = ptype[1]
        if target == 'i':
            # only adjusting input
            recurrent = self.w_h(r_last)
            drive = p*stimulus
        elif target == 'h':
            # only adjusting interaction
            # p.repeat(self.N, 1) -> adjust output as same elements each column
            # p.repeat(self.N, 1).T -> adjust input as same elements each row
            # NOTE(review): unlike self.w_h(r_last), this product does not
            # add the bias of w_h.
            recurrent = torch.matmul(
                r_last, p.repeat(self.N, 1).T*self.w_h.weight)
            drive = stimulus
        else:
            # both
            # NOTE(review): the gain matrix is not transposed here, in
            # contrast to the 'h' case - confirm this is intended.
            recurrent = torch.matmul(
                r_last, p.repeat(self.N, 1)*self.w_h.weight)
            drive = p*stimulus

        return self._advance(h_last, recurrent, drive)


class pRF_RNN(nn.Module):
    """Unrolls a ``pRF_RNN_Cell`` along the time axis of a target series.

    ``argument`` supplies ``hidden_size`` and ``k`` (the look-back used
    for the hidden-state drift) and is forwarded, together with
    ``centers_distr_type``, to the cell constructor.
    """

    def __init__(self, argument, centers_distr_type='ring'):
        super().__init__()

        self.N = argument.hidden_size
        self.k = argument.k
        self.cell = pRF_RNN_Cell(argument, centers_distr_type)

    def forward(self, targ_pos_series, initial_state, ptype):
        """Run the cell once per index of axis 1 of ``targ_pos_series``.

        Starts from ``initial_state`` and passes ``ptype`` to every cell
        call. Returns ``(h, r)``: the hidden states and firing rates of
        all steps, each stacked along dimension 1.
        """
        lag = self.k
        state = initial_state
        hidden_trace, rate_trace = [], []

        for step in range(targ_pos_series.shape[1]):
            if step >= lag:
                # NOTE(review): `state` equals hidden_trace[-1], so this
                # difference spans lag-1 steps while being divided by lag.
                drift = (state-hidden_trace[-lag])/lag
            else:
                # Not enough history yet: report no drift.
                drift = torch.zeros(self.N)

            state, rate = self.cell(targ_pos_series[:, step, :], state,
                                    drift, ptype)
            hidden_trace.append(state)
            rate_trace.append(rate)

        return (torch.stack(hidden_trace, dim=1),
                torch.stack(rate_trace, dim=1))


# In [None]:
#%% helper functions
def cal_pv(r, centers):
    """Return the population vector ``(angle, length)`` of the rates ``r``.

    Each unit's preferred direction is the polar angle of its center
    (``centers[:, 0]`` is x, ``centers[:, 1]`` is y). The population
    vector is the sum of the unit direction vectors weighted by ``r``.
    """
    preferred = np.arctan2(centers[:, 1], centers[:, 0])

    x_comp = np.sum(r*np.cos(preferred))
    y_comp = np.sum(r*np.sin(preferred))

    angle = np.arctan2(y_comp, x_comp)
    length = np.sqrt(x_comp**2+y_comp**2)
    return angle, length


def dist(p, z_range):
    """Wrap ``p`` into the half-open interval ``(-z_range/2, z_range/2]``.

    Adapted from Wong et al.,
    https://github.com/fccaa/cann_base/blob/master/cann_base.py

    Parameters
    ----------
    p : scalar or array_like
        Value(s) to wrap, e.g. angular differences.
    z_range : scalar
        Period of the variable, e.g. ``2*pi`` for angles in radians.

    Returns
    -------
    scalar or numpy.ndarray
        A scalar for scalar input, otherwise a new array of wrapped
        values; the input is never modified.
    """
    wrapped = np.remainder(p, z_range)
    half_range = 0.5*z_range

    # Decide on dimensionality rather than on the Python type, so NumPy
    # integer and float32 scalars are wrapped as well.
    if np.ndim(wrapped) == 0:
        if wrapped > half_range:
            return wrapped - z_range
        return wrapped

    return np.where(wrapped > half_range, wrapped - z_range, wrapped)


def generate_targ_pos(t_max, sp_deg_per_sec, z0_deg, radius=None):
    """Generate a target moving on a circle at constant angular speed.

    Parameters
    ----------
    t_max : int
        Number of time steps; the speed conversion treats one step as
        one millisecond.
    sp_deg_per_sec : float
        Angular speed in degrees per second.
    z0_deg : float
        Starting angle in degrees.
    radius : float, optional
        Radius of the circle. Defaults to the module-level ``R_circle``.

    Returns
    -------
    theta_t : numpy.ndarray
        Target angle in radians at every step, shape ``(t_max,)``.
    targ_pos_series : numpy.ndarray
        Cartesian target positions with a leading axis of length 1,
        shape ``(1, t_max, 2)``.
    """
    if radius is None:
        radius = R_circle

    sp = sp_deg_per_sec/180*np.pi/1000  # angular speed: rad/ms
    z0 = z0_deg/180*np.pi
    theta_t = z0 + np.arange(t_max)*sp

    positions = radius*np.array([np.cos(theta_t), np.sin(theta_t)]).T
    return theta_t, positions[np.newaxis, :, :]


def cal_diff(f, h, order, accuracy):
    """Central finite-difference derivative of uniformly sampled data.

    Parameters
    ----------
    f : array_like
        One-dimensional samples of the function.
    h : float
        Spacing between consecutive samples.
    order : int
        Derivative order: 1, 2 or 3.
    accuracy : int
        Order of accuracy of the stencil: 2 or 4.

    Returns
    -------
    numpy.ndarray
        Float array with the same length as ``f``. Points too close to
        either end for the stencil to fit are set to 0.

    Raises
    ------
    ValueError
        If the ``(accuracy, order)`` combination is not supported.
    """
    # (accuracy, order) -> (integer weights for offsets -m..+m, divisor).
    # The derivative is sum(weight * f[i+offset]) / (divisor * h**order).
    stencils = {
        (2, 1): ((-1, 0, 1), 2),
        (2, 2): ((1, -2, 1), 1),
        (2, 3): ((-1, 2, 0, -2, 1), 2),
        (4, 1): ((1, -8, 0, 8, -1), 12),
        (4, 2): ((-1, 16, -30, 16, -1), 12),
        (4, 3): ((1, -8, 13, 0, -13, 8, -1), 8),
    }
    try:
        weights, divisor = stencils[(accuracy, order)]
    except KeyError:
        raise ValueError(
            "unsupported combination: order=%r, accuracy=%r"
            % (order, accuracy)) from None

    values = np.asarray(f, dtype=float)
    n = values.shape[0]
    margin = len(weights)//2
    result = np.zeros(n)

    n_interior = n - 2*margin
    if n_interior > 0:
        total = np.zeros(n_interior)
        for k, weight in enumerate(weights):
            if weight:
                # Slice k holds f[i + (k - margin)] for every interior i.
                total += weight*values[k:k+n_interior]
        result[margin:n-margin] = total/(divisor*h**order)

    return result

# In [None]:
#%% plot functions
def visualize(now_rnn, r_np, targ_pos_series, theta_t):
    """Show diagnostic figures for one simulated trial.

    Parameters
    ----------
    now_rnn : object
        Network whose ``cell.centers`` holds the receptive-field centers.
    r_np : numpy.ndarray
        Firing rates indexed as ``[trial, time, unit]``; only trial 0 is
        plotted.
    targ_pos_series : numpy.ndarray
        Target positions indexed as ``[trial, time, xy]``.
    theta_t : array_like
        Target angle (radians) at every time step.

    Five figures are shown in turn: a rate heatmap, rate profiles over
    preferred angle every 500 steps, rate maps in receptive-field space
    every 200 steps, the wrapped angular error between population vector
    and target over time, and a polar plot of both angles over time.
    """
    t_max = targ_pos_series.shape[1]
    centers = now_rnn.cell.centers
    center_angles = np.arctan2(centers[:, 1], centers[:, 0])

    # Population vector (angle, length) at every step, computed once and
    # reused by all figures below.
    pvs = np.array([cal_pv(r_np[0, t, :], centers) for t in range(t_max)])

    # Plot population activity as heatmap
    plt.figure(dpi=300)
    sns.heatmap(r_np[0, :, :].T)
    plt.xlabel('Time (ms)')
    plt.ylabel('Node firing rate')
    plt.show()

    # Plot population activity with receptive field centers
    plt.figure(dpi=300)
    for t in np.arange(0, t_max, 500):
        plt.scatter(center_angles, r_np[0, t, :], label=t)
        plt.vlines(x=np.arctan2(targ_pos_series[0, t, 1],
                                targ_pos_series[0, t, 0]),
                   ymin=0, ymax=1)
        plt.vlines(x=pvs[t, 0], ymin=0, ymax=1)

    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    plt.show()

    # Plot population activity in receptive field space with target and PV
    snapshot_times = np.arange(0, t_max, 200)
    n_cols = 5
    # Keep the 2x5 layout for up to 10 snapshots and add rows beyond
    # that, so longer simulations do not overflow the subplot grid.
    n_rows = max(2, int(np.ceil(len(snapshot_times)/n_cols)))
    plt.figure(figsize=(12, 2.5*n_rows), dpi=300)
    for i, t in enumerate(snapshot_times):
        plt.subplot(n_rows, n_cols, i+1)
        plt.scatter(centers[:, 0], centers[:, 1], c=r_np[0, t, :])
        pva = pvs[t, 0]
        plt.plot([0, 0.8*R_circle*np.cos(pva)], [0, 0.8*R_circle*np.sin(pva)])
        plt.scatter(targ_pos_series[0, t, 0]*0.85, targ_pos_series[0, t, 1]*0.85,
                    c='k')
        plt.title('t=%d' % t)
    plt.tight_layout()
    plt.show()

    # Plot angle difference between target and PV, with time
    plt.figure(dpi=300)
    plt.plot(dist(pvs[:, 0] - theta_t, np.pi*2))
    plt.hlines(y=0, xmin=0, xmax=t_max, ls='--', color='k')
    plt.ylim(-np.pi, np.pi)
    plt.xlabel('time')
    plt.ylabel('PV(t) - target_angle(t)')
    plt.show()

    # Polar plot: angle as polar angle, time as radius
    plt.figure(dpi=300)
    ax = plt.subplot(111, projection='polar')
    ax.plot(np.ones_like(pvs[:, 0])*theta_t, np.arange(t_max), ls='--', c='k', label='target')
    ax.plot(pvs[:, 0], np.arange(t_max), label='PV')
    plt.legend(bbox_to_anchor=(2.04, 1), loc="upper left")
    plt.show()