# Pip Installs

In [9]:
!pip install --upgrade numpy
!pip install --upgrade tensorflow==1.12.3
!pip install --upgrade dm-sonnet==1.27
!pip install --upgrade scipy
!pip install --upgrade matplotlib
!pip install --upgrade tensorflow-probability==0.8.0
!pip install --upgrade wrapt==1.9.0

Requirement already up-to-date: numpy in /usr/local/lib/python3.6/dist-packages (1.17.4)
Requirement already up-to-date: tensorflow==1.12.3 in /usr/local/lib/python3.6/dist-packages (1.12.3)
Requirement already up-to-date: dm-sonnet==1.27 in /usr/local/lib/python3.6/dist-packages (1.27)
Requirement already up-to-date: scipy in /usr/local/lib/python3.6/dist-packages (1.3.2)
Requirement already up-to-date: matplotlib in /usr/local/lib/python3.6/dist-packages (3.1.1)
Requirement already up-to-date: tensorflow-probability==0.8.0 in /usr/local/lib/python3.6/dist-packages (0.8.0)
Requirement already up-to-date: wrapt==1.9.0 in /usr/local/lib/python3.6/dist-packages (1.9.0)


# dataset_reader.py

In [0]:
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Minimal queue based TFRecord reader for the Grid Cell paper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import tensorflow as tf
nest = tf.contrib.framework.nest

DatasetInfo = collections.namedtuple(
    'DatasetInfo', ['basepath', 'size', 'sequence_length', 'coord_range'])

_DATASETS = dict(
    square_room=DatasetInfo(
        basepath='square_room_100steps_2.2m_1000000',
        size=100,
        sequence_length=100,
        coord_range=((-1.1, 1.1), (-1.1, 1.1))),)


def _get_dataset_files(dateset_info, root):
  """Generates lists of files for a given dataset version."""
  basepath = dateset_info.basepath
  base = os.path.join(root, basepath)
  num_files = dateset_info.size
  template = '{:0%d}-of-{:0%d}.tfrecord' % (4, 4)
  return [
      os.path.join(base, template.format(i, num_files - 1))
      for i in range(num_files)
  ]


class DataReader(object):
  """Minimal queue based TFRecord reader.
  You can use this reader to load the datasets used to train the grid cell
  network in the 'Vector-based Navigation using Grid-like Representations
  in Artificial Agents' paper.
  See README.md for a description of the datasets and an example of how to use
  the reader.
  """

  def __init__(
      self,
      dataset,
      root,
      # Queue params
      num_threads=4,
      capacity=256,
      min_after_dequeue=128,
      seed=None):
    """Instantiates a DataReader object and sets up queues for data reading.
    Args:
      dataset: string, one of ['jaco', 'mazes', 'rooms_ring_camera',
        'rooms_free_camera_no_object_rotations',
        'rooms_free_camera_with_object_rotations', 'shepard_metzler_5_parts',
        'shepard_metzler_7_parts'].
      root: string, path to the root folder of the data.
      num_threads: (optional) integer, number of threads used to feed the reader
        queues, defaults to 4.
      capacity: (optional) integer, capacity of the underlying
        RandomShuffleQueue, defaults to 256.
      min_after_dequeue: (optional) integer, min_after_dequeue of the underlying
        RandomShuffleQueue, defaults to 128.
      seed: (optional) integer, seed for the random number generators used in
        the reader.
    Raises:
      ValueError: if the required version does not exist;
    """

    if dataset not in _DATASETS:
      raise ValueError('Unrecognized dataset {} requested. Available datasets '
                       'are {}'.format(dataset, _DATASETS.keys()))

    self._dataset_info = _DATASETS[dataset]
    self._steps = _DATASETS[dataset].sequence_length

    with tf.device('/cpu'):
      file_names = _get_dataset_files(self._dataset_info, root)
      filename_queue = tf.train.string_input_producer(file_names, seed=seed)
      reader = tf.TFRecordReader()

      read_ops = [
          self._make_read_op(reader, filename_queue) for _ in range(num_threads)
      ]
      dtypes = nest.map_structure(lambda x: x.dtype, read_ops[0])
      shapes = nest.map_structure(lambda x: x.shape[1:], read_ops[0])

      self._queue = tf.RandomShuffleQueue(
          capacity=capacity,
          min_after_dequeue=min_after_dequeue,
          dtypes=dtypes,
          shapes=shapes,
          seed=seed)

      enqueue_ops = [self._queue.enqueue_many(op) for op in read_ops]
      tf.train.add_queue_runner(tf.train.QueueRunner(self._queue, enqueue_ops))

  def read(self, batch_size):
    """Reads batch_size."""
    in_pos, in_hd, ego_vel, target_pos, target_hd = self._queue.dequeue_many(
        batch_size)
    return in_pos, in_hd, ego_vel, target_pos, target_hd

  def get_coord_range(self):
    return self._dataset_info.coord_range

  def _make_read_op(self, reader, filename_queue):
    """Instantiates the ops used to read and parse the data into tensors."""
    _, raw_data = reader.read_up_to(filename_queue, num_records=64)
    feature_map = {
        'init_pos':
            tf.FixedLenFeature(shape=[2], dtype=tf.float32),
        'init_hd':
            tf.FixedLenFeature(shape=[1], dtype=tf.float32),
        'ego_vel':
            tf.FixedLenFeature(
                shape=[self._dataset_info.sequence_length, 3],
                dtype=tf.float32),
        'target_pos':
            tf.FixedLenFeature(
                shape=[self._dataset_info.sequence_length, 2],
                dtype=tf.float32),
        'target_hd':
            tf.FixedLenFeature(
                shape=[self._dataset_info.sequence_length, 1],
                dtype=tf.float32),
    }
    example = tf.parse_example(raw_data, feature_map)
    batch = [
        example['init_pos'], example['init_hd'],
        example['ego_vel'][:, :self._steps, :],
        example['target_pos'][:, :self._steps, :],
        example['target_hd'][:, :self._steps, :]
    ]
    return batch

# ensembles.py

In [0]:
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Ensembles of place and head direction cells.
These classes provide the targets for the training of grid-cell networks.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def one_hot_max(x, axis=-1):
  """Compute one-hot vectors setting to one the index with the maximum value."""
  return tf.one_hot(tf.argmax(x, axis=axis),
                    depth=x.get_shape()[-1],
                    dtype=x.dtype)


def softmax(x, axis=-1):
  """Compute softmax values for each sets of scores in x."""
  return tf.nn.softmax(x, dim=axis)


def softmax_sample(x):
  """Sample the categorical distribution from logits and sample it."""
  dist = tf.contrib.distributions.OneHotCategorical(logits=x, dtype=tf.float32)
  return dist.sample()


class CellEnsemble(object):
  """Abstract parent class for place and head direction cell ensembles."""

  def __init__(self, n_cells, soft_targets, soft_init):
    self.n_cells = n_cells
    if soft_targets not in ["softmax", "voronoi", "sample", "normalized"]:
      raise ValueError
    else:
      self.soft_targets = soft_targets
    # Provide initialization of LSTM in the same way as targets if not specified
    # i.e one-hot if targets are Voronoi
    if soft_init is None:
      self.soft_init = soft_targets
    else:
      if soft_init not in [
          "softmax", "voronoi", "sample", "normalized", "zeros"
      ]:
        raise ValueError
      else:
        self.soft_init = soft_init

  def get_targets(self, x):
    """Type of target."""

    if self.soft_targets == "normalized":
      targets = tf.exp(self.unnor_logpdf(x))
    elif self.soft_targets == "softmax":
      lp = self.log_posterior(x)
      targets = softmax(lp)
    elif self.soft_targets == "sample":
      lp = self.log_posterior(x)
      targets = softmax_sample(lp)
    elif self.soft_targets == "voronoi":
      lp = self.log_posterior(x)
      targets = one_hot_max(lp)
    return targets

  def get_init(self, x):
    """Type of initialisation."""

    if self.soft_init == "normalized":
      init = tf.exp(self.unnor_logpdf(x))
    elif self.soft_init == "softmax":
      lp = self.log_posterior(x)
      init = softmax(lp)
    elif self.soft_init == "sample":
      lp = self.log_posterior(x)
      init = softmax_sample(lp)
    elif self.soft_init == "voronoi":
      lp = self.log_posterior(x)
      init = one_hot_max(lp)
    elif self.soft_init == "zeros":
      init = tf.zeros_like(self.unnor_logpdf(x))
    return init

  def loss(self, predictions, targets):
    """Loss."""

    if self.soft_targets == "normalized":
      smoothing = 1e-2
      loss = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=(1. - smoothing) * targets + smoothing * 0.5,
          logits=predictions,
          name="ensemble_loss")
      loss = tf.reduce_mean(loss, axis=-1)
    else:
      loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=targets,
          logits=predictions,
          name="ensemble_loss")
    return loss

  def log_posterior(self, x):
    logp = self.unnor_logpdf(x)
    log_posteriors = logp - tf.reduce_logsumexp(logp, axis=2, keep_dims=True)
    return log_posteriors


class PlaceCellEnsemble(CellEnsemble):
  """Calculates the dist over place cells given an absolute position."""

  def __init__(self, n_cells, stdev=0.35, pos_min=-5, pos_max=5, seed=None,
               soft_targets=None, soft_init=None):
    super(PlaceCellEnsemble, self).__init__(n_cells, soft_targets, soft_init)
    # Create a random MoG with fixed cov over the position (Nx2)
    rs = np.random.RandomState(seed)
    self.means = rs.uniform(pos_min, pos_max, size=(self.n_cells, 2))
    self.variances = np.ones_like(self.means) * stdev**2

  def unnor_logpdf(self, trajs):
    # Output the probability of each component at each point (BxTxN)
    diff = trajs[:, :, tf.newaxis, :] - self.means[np.newaxis, np.newaxis, ...]
    unnor_logp = -0.5 * tf.reduce_sum((diff**2)/ self.variances, axis=-1)
    return unnor_logp


class HeadDirectionCellEnsemble(CellEnsemble):
  """Calculates the dist over HD cells given an absolute angle."""

  def __init__(self, n_cells, concentration=20, seed=None,
               soft_targets=None, soft_init=None):
    super(HeadDirectionCellEnsemble, self).__init__(n_cells,
                                                    soft_targets,
                                                    soft_init)
    # Create a random Von Mises with fixed cov over the position
    rs = np.random.RandomState(seed)
    self.means = rs.uniform(-np.pi, np.pi, (n_cells))
    self.kappa = np.ones_like(self.means) * concentration

  def unnor_logpdf(self, x):
    return self.kappa * tf.cos(x - self.means[np.newaxis, np.newaxis, :])

# model.py

In [14]:
!pip install --upgrade tensorflow-probability==0.5.0

Collecting tensorflow-probability==0.5.0
[?25l  Downloading https://files.pythonhosted.org/packages/a1/ca/6f213618b5f7d0bf6139e6ec928d412a5ca14e4776adfd41a59c74a34021/tensorflow_probability-0.5.0-py2.py3-none-any.whl (680kB)
[K     |▌                               | 10kB 16.3MB/s eta 0:00:01[K     |█                               | 20kB 1.8MB/s eta 0:00:01[K     |█▌                              | 30kB 2.6MB/s eta 0:00:01[K     |██                              | 40kB 3.4MB/s eta 0:00:01[K     |██▍                             | 51kB 2.1MB/s eta 0:00:01[K     |███                             | 61kB 2.5MB/s eta 0:00:01[K     |███▍                            | 71kB 2.9MB/s eta 0:00:01[K     |███▉                            | 81kB 3.3MB/s eta 0:00:01[K     |████▍                           | 92kB 3.6MB/s eta 0:00:01[K     |████▉                           | 102kB 2.8MB/s eta 0:00:01[K     |█████▎                          | 112kB 2.8MB/s eta 0:00:01[K     |█████▉       

In [15]:
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model for grid cells supervised training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy
import sonnet as snt
import tensorflow as tf


def displaced_linear_initializer(input_size, displace, dtype=tf.float32):
  stddev = 1. / numpy.sqrt(input_size)
  return tf.truncated_normal_initializer(
      mean=displace*stddev, stddev=stddev, dtype=dtype)


class GridCellsRNNCell(snt.RNNCore):
  """LSTM core implementation for the grid cell network."""

  def __init__(self,
               target_ensembles,
               nh_lstm,
               nh_bottleneck,
               nh_embed=None,
               dropoutrates_bottleneck=None,
               bottleneck_weight_decay=0.0,
               bottleneck_has_bias=False,
               init_weight_disp=0.0,
               name="grid_cells_core"):
    """Constructor of the RNN cell.
    Args:
      target_ensembles: Targets, place cells and head direction cells.
      nh_lstm: Size of LSTM cell.
      nh_bottleneck: Size of the linear layer between LSTM output and output.
      nh_embed: Number of hiddens between input and LSTM input.
      dropoutrates_bottleneck: Iterable of keep rates (0,1]. The linear layer is
        partitioned into as many groups as the len of this parameter.
      bottleneck_weight_decay: Weight decay used in the bottleneck layer.
      bottleneck_has_bias: If the bottleneck has a bias.
      init_weight_disp: Displacement in the weights initialisation.
      name: the name of the module.
    """
    super(GridCellsRNNCell, self).__init__(name=name)
    self._target_ensembles = target_ensembles
    self._nh_embed = nh_embed
    self._nh_lstm = nh_lstm
    self._nh_bottleneck = nh_bottleneck
    self._dropoutrates_bottleneck = dropoutrates_bottleneck
    self._bottleneck_weight_decay = bottleneck_weight_decay
    self._bottleneck_has_bias = bottleneck_has_bias
    self._init_weight_disp = init_weight_disp
    self.training = False
    with self._enter_variable_scope():
      self._lstm = snt.LSTM(self._nh_lstm)

  def _build(self, inputs, prev_state):
    """Build the module.
    Args:
      inputs: Egocentric velocity (BxN)
      prev_state: Previous state of the recurrent network
    Returns:
      ((predictions, bottleneck, lstm_outputs), next_state)
      The predictions
    """
    conc_inputs = tf.concat(inputs, axis=1, name="conc_inputs")
    # Embedding layer
    lstm_inputs = conc_inputs
    # LSTM
    lstm_output, next_state = self._lstm(lstm_inputs, prev_state)
    # Bottleneck
    bottleneck = snt.Linear(self._nh_bottleneck,
                            use_bias=self._bottleneck_has_bias,
                            regularizers={
                                "w": tf.contrib.layers.l2_regularizer(
                                    self._bottleneck_weight_decay)},
                            name="bottleneck")(lstm_output)
    if self.training and self._dropoutrates_bottleneck is not None:
      tf.logging.info("Adding dropout layers")
      n_scales = len(self._dropoutrates_bottleneck)
      scale_pops = tf.split(bottleneck, n_scales, axis=1)
      dropped_pops = [tf.nn.dropout(pop, rate, name="dropout")
                      for rate, pop in zip(self._dropoutrates_bottleneck,
                                           scale_pops)]
      bottleneck = tf.concat(dropped_pops, axis=1)
    # Outputs
    ens_outputs = [snt.Linear(
        ens.n_cells,
        regularizers={
            "w": tf.contrib.layers.l2_regularizer(
                self._bottleneck_weight_decay)},
        initializers={
            "w": displaced_linear_initializer(self._nh_bottleneck,
                                              self._init_weight_disp,
                                              dtype=tf.float32)},
        name="pc_logits")(bottleneck)
                   for ens in self._target_ensembles]
    return (ens_outputs, bottleneck, lstm_output), tuple(list(next_state))

  @property
  def state_size(self):
    """Returns a description of the state size, without batch dimension."""
    return self._lstm.state_size

  @property
  def output_size(self):
    """Returns a description of the output size, without batch dimension."""
    return tuple([ens.n_cells for ens in self._target_ensembles] +
                 [self._nh_bottleneck, self._nh_lstm])


class GridCellsRNN(snt.AbstractModule):
  """RNN computes place and head-direction cell predictions from velocities."""

  def __init__(self, rnn_cell, nh_lstm, name="grid_cell_supervised"):
    super(GridCellsRNN, self).__init__(name=name)
    self._core = rnn_cell
    self._nh_lstm = nh_lstm

  def _build(self, init_conds, vels, training=False):
    """Outputs place, and head direction cell predictions from velocity inputs.
    Args:
      init_conds: Initial conditions given by ensemble activatons, list [BxN_i]
      vels:  Translational and angular velocities [BxTxV]
      training: Activates and deactivates dropout
    Returns:
      [logits_i]:
        logits_i: Logits predicting i-th ensemble activations (BxTxN_i)
    """
    # Calculate initialization for LSTM. Concatenate pc and hdc activations
    concat_init = tf.concat(init_conds, axis=1)

    init_lstm_state = snt.Linear(self._nh_lstm, name="state_init")(concat_init)
    init_lstm_cell = snt.Linear(self._nh_lstm, name="cell_init")(concat_init)
    self._core.training = training

    # Run LSTM
    output_seq, final_state = tf.nn.dynamic_rnn(cell=self._core,
                                                inputs=(vels,),
                                                time_major=False,
                                                initial_state=(init_lstm_state,
                                                               init_lstm_cell))
    ens_targets = output_seq[:-2]
    bottleneck = output_seq[-2]
    lstm_output = output_seq[-1]
    # Return
    return (ens_targets, bottleneck, lstm_output), final_state

  def get_all_variables(self):
    return (super(GridCellsRNN, self).get_variables()
            + self._core.get_variables())

  version_spec = semantic_version.Spec('>=' + min_version)
  return _inspect.getargspec(target)


# scores.py

In [16]:
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Grid score calculations.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal


def circle_mask(size, radius, in_val=1.0, out_val=0.0):
  """Calculating the grid scores with different radius."""
  sz = [math.floor(size[0] / 2), math.floor(size[1] / 2)]
  x = np.linspace(-sz[0], sz[1], size[1])
  x = np.expand_dims(x, 0)
  x = x.repeat(size[0], 0)
  y = np.linspace(-sz[0], sz[1], size[1])
  y = np.expand_dims(y, 1)
  y = y.repeat(size[1], 1)
  z = np.sqrt(x**2 + y**2)
  z = np.less_equal(z, radius)
  vfunc = np.vectorize(lambda b: b and in_val or out_val)
  return vfunc(z)


class GridScorer(object):
  """Class for scoring ratemaps given trajectories."""

  def __init__(self, nbins, coords_range, mask_parameters, min_max=False):
    """Scoring ratemaps given trajectories.
    Args:
      nbins: Number of bins per dimension in the ratemap.
      coords_range: Environment coordinates range.
      mask_parameters: parameters for the masks that analyze the angular
        autocorrelation of the 2D autocorrelation.
      min_max: Correction.
    """
    self._nbins = nbins
    self._min_max = min_max
    self._coords_range = coords_range
    self._corr_angles = [30, 45, 60, 90, 120, 135, 150]
    # Create all masks
    self._masks = [(self._get_ring_mask(mask_min, mask_max), (mask_min,
                                                              mask_max))
                   for mask_min, mask_max in mask_parameters]
    # Mask for hiding the parts of the SAC that are never used
    self._plotting_sac_mask = circle_mask(
        [self._nbins * 2 - 1, self._nbins * 2 - 1],
        self._nbins,
        in_val=1.0,
        out_val=np.nan)

  def calculate_ratemap(self, xs, ys, activations, statistic='mean'):
    return scipy.stats.binned_statistic_2d(
        xs,
        ys,
        activations,
        bins=self._nbins,
        statistic=statistic,
        range=self._coords_range)[0]

  def _get_ring_mask(self, mask_min, mask_max):
    n_points = [self._nbins * 2 - 1, self._nbins * 2 - 1]
    return (circle_mask(n_points, mask_max * self._nbins) *
            (1 - circle_mask(n_points, mask_min * self._nbins)))

  def grid_score_60(self, corr):
    if self._min_max:
      return np.minimum(corr[60], corr[120]) - np.maximum(
          corr[30], np.maximum(corr[90], corr[150]))
    else:
      return (corr[60] + corr[120]) / 2 - (corr[30] + corr[90] + corr[150]) / 3

  def grid_score_90(self, corr):
    return corr[90] - (corr[45] + corr[135]) / 2

  def calculate_sac(self, seq1):
    """Calculating spatial autocorrelogram."""
    seq2 = seq1

    def filter2(b, x):
      stencil = np.rot90(b, 2)
      return scipy.signal.convolve2d(x, stencil, mode='full')

    seq1 = np.nan_to_num(seq1)
    seq2 = np.nan_to_num(seq2)

    ones_seq1 = np.ones(seq1.shape)
    ones_seq1[np.isnan(seq1)] = 0
    ones_seq2 = np.ones(seq2.shape)
    ones_seq2[np.isnan(seq2)] = 0

    seq1[np.isnan(seq1)] = 0
    seq2[np.isnan(seq2)] = 0

    seq1_sq = np.square(seq1)
    seq2_sq = np.square(seq2)

    seq1_x_seq2 = filter2(seq1, seq2)
    sum_seq1 = filter2(seq1, ones_seq2)
    sum_seq2 = filter2(ones_seq1, seq2)
    sum_seq1_sq = filter2(seq1_sq, ones_seq2)
    sum_seq2_sq = filter2(ones_seq1, seq2_sq)
    n_bins = filter2(ones_seq1, ones_seq2)
    n_bins_sq = np.square(n_bins)

    std_seq1 = np.power(
        np.subtract(
            np.divide(sum_seq1_sq, n_bins),
            (np.divide(np.square(sum_seq1), n_bins_sq))), 0.5)
    std_seq2 = np.power(
        np.subtract(
            np.divide(sum_seq2_sq, n_bins),
            (np.divide(np.square(sum_seq2), n_bins_sq))), 0.5)
    covar = np.subtract(
        np.divide(seq1_x_seq2, n_bins),
        np.divide(np.multiply(sum_seq1, sum_seq2), n_bins_sq))
    x_coef = np.divide(covar, np.multiply(std_seq1, std_seq2))
    x_coef = np.real(x_coef)
    x_coef = np.nan_to_num(x_coef)
    return x_coef

  def rotated_sacs(self, sac, angles):
    return [
        scipy.ndimage.interpolation.rotate(sac, angle, reshape=False)
        for angle in angles
    ]

  def get_grid_scores_for_mask(self, sac, rotated_sacs, mask):
    """Calculate Pearson correlations of area inside mask at corr_angles."""
    masked_sac = sac * mask
    ring_area = np.sum(mask)
    # Calculate dc on the ring area
    masked_sac_mean = np.sum(masked_sac) / ring_area
    # Center the sac values inside the ring
    masked_sac_centered = (masked_sac - masked_sac_mean) * mask
    variance = np.sum(masked_sac_centered**2) / ring_area + 1e-5
    corrs = dict()
    for angle, rotated_sac in zip(self._corr_angles, rotated_sacs):
      masked_rotated_sac = (rotated_sac - masked_sac_mean) * mask
      cross_prod = np.sum(masked_sac_centered * masked_rotated_sac) / ring_area
      corrs[angle] = cross_prod / variance
    return self.grid_score_60(corrs), self.grid_score_90(corrs), variance

  def get_scores(self, rate_map):
    """Get summary of scrores for grid cells."""
    sac = self.calculate_sac(rate_map)
    rotated_sacs = self.rotated_sacs(sac, self._corr_angles)

    scores = [
        self.get_grid_scores_for_mask(sac, rotated_sacs, mask)
        for mask, mask_params in self._masks  # pylint: disable=unused-variable
    ]
    scores_60, scores_90, variances = map(np.asarray, zip(*scores))  # pylint: disable=unused-variable
    max_60_ind = np.argmax(scores_60)
    max_90_ind = np.argmax(scores_90)

    return (scores_60[max_60_ind], scores_90[max_90_ind],
            self._masks[max_60_ind][1], self._masks[max_90_ind][1], sac)

  def plot_ratemap(self, ratemap, ax=None, title=None, *args, **kwargs):  # pylint: disable=keyword-arg-before-vararg
    """Plot ratemaps."""
    if ax is None:
      ax = plt.gca()
    # Plot the ratemap
    ax.imshow(ratemap, interpolation='none', *args, **kwargs)
    # ax.pcolormesh(ratemap, *args, **kwargs)
    ax.axis('off')
    if title is not None:
      ax.set_title(title)

  def plot_sac(self,
               sac,
               mask_params=None,
               ax=None,
               title=None,
               *args,
               **kwargs):  # pylint: disable=keyword-arg-before-vararg
    """Plot spatial autocorrelogram."""
    if ax is None:
      ax = plt.gca()
    # Plot the sac
    useful_sac = sac * self._plotting_sac_mask
    ax.imshow(useful_sac, interpolation='none', *args, **kwargs)
    # ax.pcolormesh(useful_sac, *args, **kwargs)
    # Plot a ring for the adequate mask
    if mask_params is not None:
      center = self._nbins - 1
      ax.add_artist(
          plt.Circle(
              (center, center),
              mask_params[0] * self._nbins,
              # lw=bump_size,
              fill=False,
              edgecolor='k'))
      ax.add_artist(
          plt.Circle(
              (center, center),
              mask_params[1] * self._nbins,
              # lw=bump_size,
              fill=False,
              edgecolor='k'))
    ax.axis('off')
    if title is not None:
      ax.set_title(title)

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


# utils.py

In [0]:
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper functions for creating the training graph and plotting.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# import ensembles  # pylint: disable=g-bad-import-order


np.seterr(invalid="ignore")


def get_place_cell_ensembles(
    env_size, neurons_seed, targets_type, lstm_init_type, n_pc, pc_scale):
  """Create the ensembles for the Place cells."""
  place_cell_ensembles = [
      PlaceCellEnsemble(
          n,
          stdev=s,
          pos_min=-env_size / 2.0,
          pos_max=env_size / 2.0,
          seed=neurons_seed,
          soft_targets=targets_type,
          soft_init=lstm_init_type)
      for n, s in zip(n_pc, pc_scale)
  ]
  return place_cell_ensembles


def get_head_direction_ensembles(
    neurons_seed, targets_type, lstm_init_type, n_hdc, hdc_concentration):
  """Create the ensembles for the Head direction cells."""
  head_direction_ensembles = [
      HeadDirectionCellEnsemble(
          n,
          concentration=con,
          seed=neurons_seed,
          soft_targets=targets_type,
          soft_init=lstm_init_type)
      for n, con in zip(n_hdc, hdc_concentration)
  ]
  return head_direction_ensembles


def encode_initial_conditions(init_pos, init_hd, place_cell_ensembles,
                              head_direction_ensembles):
  initial_conds = []
  for ens in place_cell_ensembles:
    initial_conds.append(
        tf.squeeze(ens.get_init(init_pos[:, tf.newaxis, :]), axis=1))
  for ens in head_direction_ensembles:
    initial_conds.append(
        tf.squeeze(ens.get_init(init_hd[:, tf.newaxis, :]), axis=1))
  return initial_conds


def encode_targets(target_pos, target_hd, place_cell_ensembles,
                   head_direction_ensembles):
  ensembles_targets = []
  for ens in place_cell_ensembles:
    ensembles_targets.append(ens.get_targets(target_pos))
  for ens in head_direction_ensembles:
    ensembles_targets.append(ens.get_targets(target_hd))
  return ensembles_targets


def clip_all_gradients(g, var, limit):
  # print(var.name)
  return (tf.clip_by_value(g, -limit, limit), var)


def clip_bottleneck_gradient(g, var, limit):
  if ("bottleneck" in var.name or "pc_logits" in var.name):
    return (tf.clip_by_value(g, -limit, limit), var)
  else:
    return (g, var)


def no_clipping(g, var):
  return (g, var)


def concat_dict(acc, new_data):
  """Dictionary concatenation function."""

  def to_array(kk):
    if isinstance(kk, np.ndarray):
      return kk
    else:
      return np.asarray([kk])

  for k, v in new_data.iteritems():
    if isinstance(v, dict):
      if k in acc:
        acc[k] = concat_dict(acc[k], v)
      else:
        acc[k] = concat_dict(dict(), v)
    else:
      v = to_array(v)
      if k in acc:
        acc[k] = np.concatenate([acc[k], v])
      else:
        acc[k] = np.copy(v)
  return acc


def get_scores_and_plot(scorer,
                        data_abs_xy,
                        activations,
                        directory,
                        filename,
                        plot_graphs=True,  # pylint: disable=unused-argument
                        nbins=20,  # pylint: disable=unused-argument
                        cm="jet",
                        sort_by_score_60=True):
  """Plotting function."""

  # Concatenate all trajectories
  xy = data_abs_xy.reshape(-1, data_abs_xy.shape[-1])
  act = activations.reshape(-1, activations.shape[-1])
  n_units = act.shape[1]
  # Get the rate-map for each unit
  s = [
      scorer.calculate_ratemap(xy[:, 0], xy[:, 1], act[:, i])
      for i in xrange(n_units)
  ]
  # Get the scores
  score_60, score_90, max_60_mask, max_90_mask, sac = zip(
      *[scorer.get_scores(rate_map) for rate_map in s])
  # Separations
  # separations = map(np.mean, max_60_mask)
  # Sort by score if desired
  if sort_by_score_60:
    ordering = np.argsort(-np.array(score_60))
  else:
    ordering = range(n_units)
  # Plot
  cols = 16
  rows = int(np.ceil(n_units / cols))
  fig = plt.figure(figsize=(24, rows * 4))
  for i in xrange(n_units):
    rf = plt.subplot(rows * 2, cols, i + 1)
    acr = plt.subplot(rows * 2, cols, n_units + i + 1)
    if i < n_units:
      index = ordering[i]
      title = "%d (%.2f)" % (index, score_60[index])
      # Plot the activation maps
      scorer.plot_ratemap(s[index], ax=rf, title=title, cmap=cm)
      # Plot the autocorrelation of the activation maps
      scorer.plot_sac(
          sac[index],
          mask_params=max_60_mask[index],
          ax=acr,
          title=title,
          cmap=cm)
  # Save
  if not os.path.exists(directory):
    os.makedirs(directory)
  with PdfPages(os.path.join(directory, filename), "w") as f:
    plt.savefig(f, format="pdf")
  plt.close(fig)
  return (np.asarray(score_60), np.asarray(score_90),
          np.asarray(map(np.mean, max_60_mask)),
          np.asarray(map(np.mean, max_90_mask)))

# train.py

In [24]:
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Supervised training for the Grid cell network."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib
import numpy as np
import tensorflow as tf
import tkinter  # pylint: disable=unused-import

matplotlib.use('Agg')

# import dataset_reader  # pylint: disable=g-bad-import-order, g-import-not-at-top
# import model  # pylint: disable=g-bad-import-order
# import scores  # pylint: disable=g-bad-import-order
# import utils  # pylint: disable=g-bad-import-order


# Task config
tf.flags.DEFINE_string('task_dataset_info', 'square_room',
                       'Name of the room in which the experiment is performed.')
tf.flags.DEFINE_string('task_root',
                       None,
                       'Dataset path.')
tf.flags.DEFINE_float('task_env_size', 2.2,
                      'Environment size (meters).')
tf.flags.DEFINE_list('task_n_pc', [256],
                     'Number of target place cells.')
tf.flags.DEFINE_list('task_pc_scale', [0.01],
                     'Place cell standard deviation parameter (meters).')
tf.flags.DEFINE_list('task_n_hdc', [12],
                     'Number of target head direction cells.')
tf.flags.DEFINE_list('task_hdc_concentration', [20.],
                     'Head direction concentration parameter.')
tf.flags.DEFINE_integer('task_neurons_seed', 8341,
                        'Seeds.')
tf.flags.DEFINE_string('task_targets_type', 'softmax',
                       'Type of target, soft or hard.')
tf.flags.DEFINE_string('task_lstm_init_type', 'softmax',
                       'Type of LSTM initialisation, soft or hard.')
tf.flags.DEFINE_bool('task_velocity_inputs', True,
                     'Input velocity.')
tf.flags.DEFINE_list('task_velocity_noise', [0.0, 0.0, 0.0],
                     'Add noise to velocity.')

# Model config
tf.flags.DEFINE_integer('model_nh_lstm', 128, 'Number of hidden units in LSTM.')
tf.flags.DEFINE_integer('model_nh_bottleneck', 256,
                        'Number of hidden units in linear bottleneck.')
tf.flags.DEFINE_list('model_dropout_rates', [0.5],
                     'List of floats with dropout rates.')
tf.flags.DEFINE_float('model_weight_decay', 1e-5,
                      'Weight decay regularisation')
tf.flags.DEFINE_bool('model_bottleneck_has_bias', False,
                     'Whether to include a bias in linear bottleneck')
tf.flags.DEFINE_float('model_init_weight_disp', 0.0,
                      'Initial weight displacement.')

# Training config
tf.flags.DEFINE_integer('training_epochs', 1000, 'Number of training epochs.')
tf.flags.DEFINE_integer('training_steps_per_epoch', 1000,
                        'Number of optimization steps per epoch.')
tf.flags.DEFINE_integer('training_minibatch_size', 10,
                        'Size of the training minibatch.')
tf.flags.DEFINE_integer('training_evaluation_minibatch_size', 4000,
                        'Size of the minibatch during evaluation.')
tf.flags.DEFINE_string('training_clipping_function', 'clip_all_gradients',
                       'Function for gradient clipping.')
tf.flags.DEFINE_float('training_clipping', 1e-5,
                      'The absolute value to clip by.')

tf.flags.DEFINE_string('training_optimizer_class', 'tf.train.RMSPropOptimizer',
                       'The optimizer used for training.')
tf.flags.DEFINE_string('training_optimizer_options',
                       '{"learning_rate": 1e-5, "momentum": 0.9}',
                       'Defines a dict with opts passed to the optimizer.')

# Store
tf.flags.DEFINE_string('saver_results_directory',
                       None,
                       'Path to directory for saving results.')
tf.flags.DEFINE_integer('saver_eval_time', 2,
                        'Frequency at which results are saved.')

# Options
tf.flags.DEFINE_string('task_root',
                       './grid-cells-datasets',
                       'Root directory')
tf.flags.DEFINE_string('saver_results_directory',
                       './results',
                        'Results directory.')

# Require flags
tf.flags.mark_flag_as_required('task_root')
tf.flags.mark_flag_as_required('saver_results_directory')
FLAGS = tf.flags.FLAGS


def train():
  """Training loop."""

  tf.reset_default_graph()

  # Create the motion models for training and evaluation
  data_reader = DataReader(
      FLAGS.task_dataset_info, root=FLAGS.task_root, num_threads=4)
  train_traj = read(batch_size=FLAGS.training_minibatch_size)

  # Create the ensembles that provide targets during training
  place_cell_ensembles = get_place_cell_ensembles(
      env_size=FLAGS.task_env_size,
      neurons_seed=FLAGS.task_neurons_seed,
      targets_type=FLAGS.task_targets_type,
      lstm_init_type=FLAGS.task_lstm_init_type,
      n_pc=FLAGS.task_n_pc,
      pc_scale=FLAGS.task_pc_scale)

  head_direction_ensembles = get_head_direction_ensembles(
      neurons_seed=FLAGS.task_neurons_seed,
      targets_type=FLAGS.task_targets_type,
      lstm_init_type=FLAGS.task_lstm_init_type,
      n_hdc=FLAGS.task_n_hdc,
      hdc_concentration=FLAGS.task_hdc_concentration)
  target_ensembles = place_cell_ensembles + head_direction_ensembles

  # Model creation
  rnn_core = GridCellsRNNCell(
      target_ensembles=target_ensembles,
      nh_lstm=FLAGS.model_nh_lstm,
      nh_bottleneck=FLAGS.model_nh_bottleneck,
      dropoutrates_bottleneck=np.array(FLAGS.model_dropout_rates),
      bottleneck_weight_decay=FLAGS.model_weight_decay,
      bottleneck_has_bias=FLAGS.model_bottleneck_has_bias,
      init_weight_disp=FLAGS.model_init_weight_disp)
  rnn = GridCellsRNN(rnn_core, FLAGS.model_nh_lstm)

  # Get a trajectory batch
  input_tensors = []
  init_pos, init_hd, ego_vel, target_pos, target_hd = train_traj
  if FLAGS.task_velocity_inputs:
    # Add the required amount of noise to the velocities
    vel_noise = tf.distributions.Normal(0.0, 1.0).sample(
        sample_shape=ego_vel.get_shape()) * FLAGS.task_velocity_noise
    input_tensors = [ego_vel + vel_noise] + input_tensors
  # Concatenate all inputs
  inputs = tf.concat(input_tensors, axis=2)

  # Replace euclidean positions and angles by encoding of place and hd ensembles
  # Note that the initial_conds will be zeros if the ensembles were configured
  # to provide that type of initialization
  initial_conds = encode_initial_conditions(
      init_pos, init_hd, place_cell_ensembles, head_direction_ensembles)

  # Encode targets as well
  ensembles_targets = encode_targets(
      target_pos, target_hd, place_cell_ensembles, head_direction_ensembles)

  # Estimate future encoding of place and hd ensembles inputing egocentric vels
  outputs, _ = rnn(initial_conds, inputs, training=True)
  ensembles_logits, bottleneck, lstm_output = outputs

  # Training loss
  pc_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=ensembles_targets[0], logits=ensembles_logits[0], name='pc_loss')
  hd_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=ensembles_targets[1], logits=ensembles_logits[1], name='hd_loss')
  total_loss = pc_loss + hd_loss
  train_loss = tf.reduce_mean(total_loss, name='train_loss')

  # Optimisation ops
  optimizer_class = eval(FLAGS.training_optimizer_class)  # pylint: disable=eval-used
  optimizer = optimizer_class(**eval(FLAGS.training_optimizer_options))  # pylint: disable=eval-used
  grad = optimizer.compute_gradients(train_loss)
  clip_gradient = eval(FLAGS.training_clipping_function)  # pylint: disable=eval-used
  clipped_grad = [
      clip_gradient(g, var, FLAGS.training_clipping) for g, var in grad
  ]
  train_op = optimizer.apply_gradients(clipped_grad)

  # Store the grid scores
  grid_scores = dict()
  grid_scores['btln_60'] = np.zeros((FLAGS.model_nh_bottleneck,))
  grid_scores['btln_90'] = np.zeros((FLAGS.model_nh_bottleneck,))
  grid_scores['btln_60_separation'] = np.zeros((FLAGS.model_nh_bottleneck,))
  grid_scores['btln_90_separation'] = np.zeros((FLAGS.model_nh_bottleneck,))
  grid_scores['lstm_60'] = np.zeros((FLAGS.model_nh_lstm,))
  grid_scores['lstm_90'] = np.zeros((FLAGS.model_nh_lstm,))

  # Create scorer objects
  starts = [0.2] * 10
  ends = np.linspace(0.4, 1.0, num=10)
  masks_parameters = zip(starts, ends.tolist())
  latest_epoch_scorer = GridScorer(20, data_reader.get_coord_range(),
                                          masks_parameters)

  with tf.train.SingularMonitoredSession() as sess:
    for epoch in range(FLAGS.training_epochs):
      loss_acc = list()
      for _ in range(FLAGS.training_steps_per_epoch):
        res = sess.run({'train_op': train_op, 'total_loss': train_loss})
        loss_acc.append(res['total_loss'])

      tf.logging.info('Epoch %i, mean loss %.5f, std loss %.5f', epoch,
                      np.mean(loss_acc), np.std(loss_acc))
      if epoch % FLAGS.saver_eval_time == 0:
        res = dict()
        for _ in xrange(FLAGS.training_evaluation_minibatch_size //
                        FLAGS.training_minibatch_size):
          mb_res = sess.run({
              'bottleneck': bottleneck,
              'lstm': lstm_output,
              'pos_xy': target_pos
          })
          res = concat_dict(res, mb_res)

        # Store at the end of validation
        filename = 'rates_and_sac_latest_hd.pdf'
        grid_scores['btln_60'], grid_scores['btln_90'], grid_scores[
            'btln_60_separation'], grid_scores[
                'btln_90_separation'] = get_scores_and_plot(
                    latest_epoch_scorer, res['pos_xy'], res['bottleneck'],
                    FLAGS.saver_results_directory, filename)


def main(unused_argv):
  tf.logging.set_verbosity(3)  # Print INFO log messages.
  train()

if __name__ == '__main__':
  tf.app.run()

DuplicateFlagError: ignored

In [25]:
!git clone https://github.com/deepmind/grid-cells.git

Cloning into 'grid-cells'...
remote: Enumerating objects: 13, done.[K
remote: Total 13 (delta 0), reused 0 (delta 0), pack-reused 13[K
Unpacking objects: 100% (13/13), done.


In [27]:
!ls grid-cells

CONTRIBUTING.md    ensembles.py  model.py   scores.py  utils.py
dataset_reader.py  LICENSE	 README.md  train.py
