In [1]:
import numpy as np
import torch

from tqdm import tqdm
from matplotlib import pyplot as plt

import seaborn as sns

%load_ext autoreload
%autoreload 2

In [2]:
# Choose which GPU would be used for training: enumerate devices by PCI bus id
# and expose only device 0 to this process.
# NOTE(review): options.device is hard-coded to CPU further down, so these
# variables only take effect if that is switched to a CUDA device. They must be
# set before CUDA is first initialised; torch is already imported above, which
# is presumably fine as long as no CUDA call has happened yet — confirm.
import os
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0",
)

In [3]:
from utils import generate_run_ID, load_trained_weights
from place_cells import PlaceCells
from trajectory_generator import TrajectoryGenerator
from model import RNN
from trainer import Trainer
from visualize import compute_ratemaps
from scores import GridScorer, BandScorer

In [4]:
# Training options and hyperparameters.
# `Options` is a plain attribute container; the populated instance is passed to
# generate_run_ID and to the project constructors in the next cell.
class Options:
    pass
options = Options()

options.save_dir = './models'         # root directory for saved models and analysis data
options.n_steps = 1000                # number of training steps
options.batch_size = 200              # trajectories per batch
options.sequence_length = 20          # steps per trajectory
options.learning_rate = 1e-4          # gradient-descent learning rate
options.Np = 512                      # number of place cells
options.Ng = 4096                     # number of grid cells
options.place_cell_rf = 0.12          # width of the place-cell centre tuning curve (m)
options.surround_scale = 2            # when DoG: ratio sigma2^2 / sigma1^2
options.RNN_type = 'RNN'              # one of RNN, RNN_2RNN, RNN_reconstruction
options.activation = 'relu'           # recurrent nonlinearity
options.weight_decay = 1e-4           # weight-decay strength on the recurrent weights
options.DoG = True                    # difference-of-Gaussians place-cell tuning curves
options.periodic = False              # periodic boundary conditions for trajectories
options.box_width = 2.2               # environment width (m)
options.box_height = 2.2              # environment height (m)
options.seed = None                   # random seed

# The run ID is built from the options above and names the save directory, so it
# is computed only after every option is set; the device is assigned afterwards.
options.run_ID = generate_run_ID(options)
options.device = torch.device('cpu')

In [5]:
# Build the place-cell population, the RNN, the trajectory sampler and the trainer.
# Per the original note: if a model was trained before with these params (same
# options.run_ID), it will be restored; the log below shows the from-scratch case.
# model and trajectory_generator both take place_cells, and trainer takes both of
# them, so the construction order below is required.
place_cells = PlaceCells(options)
model = RNN(options, place_cells)
trajectory_generator = TrajectoryGenerator(options, place_cells)
trainer = Trainer(options, model, trajectory_generator)

Initializing new model from scratch.
Saving to: ./models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001


In [6]:
# Replace the freshly initialised weights with the example trained weights that
# are stored on GitHub, so the analysis below does not require training first.
# (Despite its name, weight_dir is the path of a single .npy file.)
weight_dir = f"{options.save_dir}/example_trained_weights.npy"
load_trained_weights(model, trainer, weight_dir)

Initialized trained weights.
Epoch: 0/1. Step 0/1. Loss: 6.37. Err: 92.58cm
Loaded trained weights.
Epoch: 0/1. Step 0/1. Loss: 6.16. Err: 4.66cm


In [7]:
# Every analysis product below is saved/loaded by bare file name, so move into
# <save_dir>/<run_ID>/data/ for the rest of the notebook.
# The launch directory is captured only on the first run. Resolving against it
# keeps this cell idempotent: a relative chdir would nest
# ./models/<run_ID>/data/ inside itself each time the cell is re-run.
# NOTE(review): after this cell, relative paths such as options.save_dir no
# longer resolve from the notebook directory (e.g. when re-loading weights).
if 'notebook_dir' not in globals():
    notebook_dir = os.getcwd()
save_dir = os.path.abspath(
    os.path.join(notebook_dir, options.save_dir, options.run_ID, 'data'))
os.makedirs(save_dir, exist_ok=True)
os.chdir(save_dir)

# Activations

In [None]:
# Spatial rate maps for all Ng units, computed at two resolutions.
res = 50              # high-resolution grid: res x res bins
n_avg = 100           # forwarded to compute_ratemaps as n_avg
Ng = options.Ng
idxs = np.arange(Ng)  # compute maps for every unit
ratemap_opts = dict(n_avg=n_avg, Ng=Ng, idxs=idxs)

# High-resolution pass: keep the activations, rate maps, g, positions and the
# direction-binned activations.
activations, rate_map, g, pos, activations_theta = compute_ratemaps(
    model, trajectory_generator, options, res=res, **ratemap_opts)

# Low-resolution pass, used for evaluating the grid score; only the rate map is kept.
lo_res = 20
_, rate_map_lores, _, _, _ = compute_ratemaps(
    model, trajectory_generator, options, res=lo_res, **ratemap_opts)

In [None]:
# Persist the computed maps so the analysis sections below can start from disk
# without recomputing them.
for filename, array in (
    ('rate_map.npy', rate_map),
    ('activations.npy', activations),
    ('activations_theta.npy', activations_theta),
    ('rate_map_lores.npy', rate_map_lores),
):
    np.save(filename, array)

# Band Cell

In [8]:
# Band scorer for the res x res high-resolution rate maps.
res = 50  # must match the resolution used when rate_map.npy was computed
band_scorer = BandScorer(res,
                         options.box_width,
                         options.box_height)

## Band score, spacing, orientation

In [None]:
# Band score, wave number k, spacing L and orientation theta for every unit.
rates = np.load('rate_map.npy')   # one flattened res*res map per unit (reshaped below)

Ng = rates.shape[0]
score = np.zeros(Ng)
k = np.zeros(Ng)
theta = np.zeros(Ng)

for i in tqdm(range(Ng)):
    # Scores are computed on the mean-subtracted map.
    rate = (rates[i] - np.mean(rates[i])).reshape(res, res)
    score[i], _, _, _, k[i], _, _ = band_scorer.comput_band_score(rate)
    theta[i] = band_scorer.comput_orientation(rate)

# Spacing L = 2*pi/k. A unit with k == 0 gets an infinite spacing; make that
# explicit instead of emitting a divide-by-zero RuntimeWarning.
with np.errstate(divide='ignore'):
    L = 2*np.pi/k
np.save('band_score.npy', score)
np.save('L.npy', L)
np.save('orientation.npy', theta)

  0%|          | 0/4096 [00:00<?, ?it/s]

100%|██████████| 4096/4096 [00:46<00:00, 88.40it/s] 


## Phase

In [None]:
# Indices of units whose band score exceeds the threshold; the phase section
# below only processes these units.
score_thres = 5.8  # NOTE(review): empirically chosen cutoff — document its origin
high_band_idxs = np.flatnonzero(score > score_thres)

In [None]:
# Spatial phase of each high-band-score unit: the angle of the rate-weighted
# mean of exp(i * (k . x)) over all bins. Units below the threshold keep phase 0.
phase = np.zeros(Ng)

# Physical (x, y) coordinates of the res x res bins, shape (2, res*res).
# These differ from the coordinates used inside the band scorer; they are in the
# same units as box_width/box_height.
xs = np.linspace(-options.box_width/2, options.box_width/2, res)
ys = np.linspace(-options.box_height/2, options.box_height/2, res)
X, Y = np.meshgrid(xs, ys)
x_flat, y_flat = X.flatten(), Y.flatten()
loc = np.stack([x_flat, y_flat])

N_band = high_band_idxs.shape[0]
for idx in high_band_idxs:
    k_vec = k[idx] * np.array([np.cos(theta[idx]), np.sin(theta[idx])])  # shape (2,)
    # k . x is already an angle in radians because L = 2*pi/k. Wrap it modulo
    # 2*pi (not modulo the spacing L, which mixes radians with metres) and shift
    # it into [-pi, pi).
    loc_phase = np.mod(k_vec @ loc, 2*np.pi) - np.pi  # shape (res*res,)
    phase[idx] = np.angle(np.sum(np.exp(1j*loc_phase) * rates[idx]) / np.sum(rates[idx]))

np.save('phase.npy', phase)

## Direction

In [11]:
# Direction-binned activations saved by the Activations section:
# one row per unit, one column per direction bin.
activations_theta = np.load(file='activations_theta.npy')

### Preferred direction

In [None]:
# Preferred direction of every unit: the angle of the activation-weighted sum
# of unit vectors at the direction-bin centres.
# The bin centres get their own name. The original reused `theta`, which clobbered
# the per-unit band orientation array that the phase section reads.
# Assumes activations_theta is 2-D (units x direction bins) — TODO confirm.
dir_bins = np.linspace(-np.pi, np.pi, activations_theta.shape[1], endpoint=False)
prefer_dir = np.angle(activations_theta @ np.exp(1j * dir_bins))

np.save('prefer_dir.npy', prefer_dir)

### Direction score

In [None]:
# Direction score per unit, along with the A / mu / sigma parameters returned by
# the scorer (presumably those of a fitted tuning curve — see
# BandScorer.direction_score).
dir_scores, A_params, mu_params, sigma_params = band_scorer.direction_score(activations_theta)

for filename, values in (
    ('direction_scores.npy', dir_scores),
    ('A_params.npy', A_params),
    ('sigma_params.npy', sigma_params),
    ('mu_params.npy', mu_params),
):
    np.save(filename, values)

# Grid cell


## Grid score

In [None]:
rate_map_lores = np.load('rate_map_lores.npy')

# Grid scorer on the res x res low-resolution grid; res must match the lo_res
# used when rate_map_lores.npy was computed.
res = 20
box_width = options.box_width
box_height = options.box_height
coord_range = ((-box_width/2, box_width/2),
               (-box_height/2, box_height/2))

# Ten (start, end) mask parameter pairs: start fixed at 0.2, end evenly spaced
# from 0.4 to 1.0 (presumably inner/outer ring radii — see GridScorer).
starts = [0.2] * 10
ends = np.linspace(0.4, 1.0, num=10)
masks_parameters = zip(starts, ends.tolist())
grid_scorer = GridScorer(res, coord_range, masks_parameters)

In [None]:
# Score every low-resolution map. get_scores returns a 6-tuple per map; zip(*...)
# regroups those tuples into one tuple per quantity across all maps.
per_map_scores = [grid_scorer.get_scores(rm.reshape(res, res))
                  for rm in tqdm(rate_map_lores)]
score_60, score_90, max_60_mask, max_90_mask, sac, max_60_ind = zip(*per_map_scores)

np.save('grid_score.npy', score_60)

  x_coef = np.divide(covar, np.multiply(std_seq1, std_seq2))
100%|██████████| 4096/4096 [00:34<00:00, 119.61it/s]
