# Environment JSON Schema Documentation

## Overview

The environment JSON defines a **graph-based spatial world** where an agent can navigate between locations. Each location has a sensory observation and a set of actions that lead to other locations with specified probabilities.

## Top-Level Structure

```json
{
 "n_locations": 25,           // Total number of locations in the environment
 "n_observations": 45,        // Total number of unique observations available
 "n_actions": 5,              // Total number of possible action types
 "adjacency": [[...], ...],   // 2D adjacency matrix (n_locations × n_locations)
 "locations": [...]           // Array of location objects
}
```

The adjacency matrix shows direct connections between locations in the grid.
For the first row `[1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]`:

- Position 0: 1 → Location 0 connects to itself (can stay put)
- Position 1: 1 → Location 0 connects to Location 1 (can move east)
- Position 5: 1 → Location 0 connects to Location 5 (can move south)
- All others: 0 → No direct connection

So Location 0 (top-left corner at coordinates (0.1, 0.1)) can only reach:

- Itself (stay)
- Location 1 (move right)
- Location 5 (move down)

This makes sense for a corner position: you can't go north or west from the top-left corner of the grid.
The adjacency matrix is a compact representation of which locations are neighbors in the spatial grid.
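
To make the layout concrete, here is a minimal sketch that rebuilds such an adjacency matrix for a rectangular grid. It assumes locations are numbered row by row (id = row × width + column), which matches the coordinates printed later in this notebook. `grid_adjacency` is a hypothetical helper, not part of the TEM code.

```python
import numpy as np


def grid_adjacency(width: int, height: int) -> np.ndarray:
    """Adjacency matrix of a width x height grid with self-loops (stay put)
    and 4-neighbour moves, numbering locations row by row (assumption)."""
    n = width * height
    adj = np.zeros((n, n), dtype=int)
    for row in range(height):
        for col in range(width):
            i = row * width + col
            adj[i, i] = 1  # staying put is always allowed
            for d_row, d_col in ((-1, 0), (0, 1), (1, 0), (0, -1)):  # N, E, S, W
                r, c = row + d_row, col + d_col
                if 0 <= r < height and 0 <= c < width:  # skip moves off the grid
                    adj[i, r * width + c] = 1
    return adj


adj = grid_adjacency(5, 5)
print(np.flatnonzero(adj[0]))  # [0 1 5]
print(np.flatnonzero(adj[3]))  # [2 3 4 8]
```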

---
# Location Object Schema

```json
{
  "id": 3,                    // Unique location identifier (0 to n_locations-1)
  "observation": 33,          // Observation ID seen at this location (0 to n_observations-1)
  "x": 0.7,                   // X coordinate in normalized space
  "y": 0.1,                   // Y coordinate in normalized space
  "in_locations": [2,3,4,8],  // List of location IDs that can reach this location
  "in_degree": 4,             // Number of incoming connections (length of in_locations)
  "out_locations": [2,3,4,8], // List of location IDs reachable from this location
  "out_degree": 4,            // Number of outgoing connections (length of out_locations)
  "shiny": null,              // Whether location contains shiny objects (null/true/false); absent from the raw JSON, set to null at load time for non-shiny environments
  "actions": [...]            // Array of possible actions from this location
}
```
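
The `in_*`/`out_*` fields are redundant with the adjacency matrix. The sketch below shows how they could be recovered. It assumes `adjacency[i][j] == 1` means location `j` is reachable from location `i`. The example row above is consistent with this, but the 5×5 matrix is symmetric, so it cannot confirm the direction on its own. `neighbors_from_adjacency` is a hypothetical helper.

```python
import numpy as np


def neighbors_from_adjacency(adjacency, location_id):
    """Rebuild the in/out fields of one location object from the adjacency matrix.
    Assumes adjacency[i][j] == 1 means location j is reachable from location i."""
    adj = np.asarray(adjacency)
    out_locations = np.flatnonzero(adj[location_id]).tolist()    # row: where we can go
    in_locations = np.flatnonzero(adj[:, location_id]).tolist()  # column: who can reach us
    return {
        "in_locations": in_locations,
        "in_degree": len(in_locations),
        "out_locations": out_locations,
        "out_degree": len(out_locations),
    }


# Directed 3-location chain 0 -> 1 -> 2 with self-loops, so in != out
chain = [[1, 1, 0],
         [0, 1, 1],
         [0, 0, 1]]
print(neighbors_from_adjacency(chain, 1))
# {'in_locations': [0, 1], 'in_degree': 2, 'out_locations': [1, 2], 'out_degree': 2}
```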

---
# Action Object Schema

```json
{
  "id": 0,                     // Action identifier (0 to n_actions-1)
  "transition": [0,0,0,1,0...], // Probability distribution over next locations
  "probability": 0.25          // Probability of selecting this action
}
```

## Action Types by ID

- 0: Stay in place / No movement
- 1: Move North
- 2: Move East
- 3: Move South
- 4: Move West

Actions that are impossible at a location (e.g. North from the top row) have probability 0 and an all-zero transition vector.
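
The hypothetical helper below maps an action id to its target location. It assumes the row-by-row numbering and that North decreases the row index, because y grows southward in the grid layout of the 5×5 example. Moves that would leave the grid return `None`, which matches the probability-0 actions in the data.

```python
# Action id -> (row offset, column offset). Assumes North means row - 1,
# i.e. towards smaller y, as in the grid layout of the 5x5 example.
ACTION_OFFSETS = {0: (0, 0), 1: (-1, 0), 2: (0, 1), 3: (1, 0), 4: (0, -1)}


def target_location(location_id, action_id, width=5, height=5):
    """Location reached by a deterministic action, or None if it leaves the grid."""
    row, col = divmod(location_id, width)
    d_row, d_col = ACTION_OFFSETS[action_id]
    r, c = row + d_row, col + d_col
    if 0 <= r < height and 0 <= c < width:
        return r * width + c
    return None  # off-grid: such actions have probability 0 in the data


print([target_location(3, a) for a in range(5)])  # [3, None, 4, 8, 2]
```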


## Key Concepts
### Transition Vector

For an available action, the transition array is a probability distribution over all locations:

- Length = n_locations (25 in the example)
- Each index corresponds to a location ID
- Value at index i = probability of transitioning to location i
- Deterministic transitions: exactly one entry is 1, the rest are 0
- Stochastic transitions: multiple non-zero values that sum to 1
- Unavailable actions: all entries are 0 (the vector does not sum to 1)
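
The sketch below checks and samples a transition vector. `check_transition` and `sample_next_location` are illustrative names. The sampling rule picks the first index whose cumulative sum exceeds a uniform draw, which is the same rule `get_location` uses later in this notebook.

```python
import numpy as np


def check_transition(transition, atol=1e-8):
    """Classify a transition vector as 'deterministic', 'stochastic' or 'unavailable'."""
    t = np.asarray(transition, dtype=float)
    if not t.any():
        return "unavailable"  # impossible action: all-zero vector
    if (t < 0).any() or not np.isclose(t.sum(), 1.0, atol=atol):
        raise ValueError("transition must be non-negative and sum to 1")
    return "deterministic" if np.count_nonzero(t) == 1 else "stochastic"


def sample_next_location(transition, rng):
    """Inverse-CDF sampling: first index whose cumulative probability exceeds u ~ U[0, 1)."""
    u = rng.random()
    return int(np.flatnonzero(np.cumsum(transition) > u)[0])


rng = np.random.default_rng(0)
east_from_3 = [0, 0, 0, 0, 1] + [0] * 20  # action 2 at location 3
print(check_transition(east_from_3), sample_next_location(east_from_3, rng))  # deterministic 4
print(check_transition([0] * 25))  # unavailable
```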

### Observation Encoding
Observations are encoded as one-hot vectors:

- Vector length = n_observations (45 in the example)
- If `observation` is 33, position 33 = 1.0 and all others = 0.0
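
The sketch below shows the one-hot encoding and its inverse. It uses the same `np.eye(...)[observation]` construction as the notebook cells below. Decoding with `argmax` is an illustrative choice.

```python
import numpy as np

n_observations = 45
observation = 33

one_hot = np.eye(n_observations)[observation]  # row 33 of the identity matrix
decoded = int(np.argmax(one_hot))              # index of the single 1.0

print(one_hot.shape, one_hot.sum(), decoded)  # (45,) 1.0 33
```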

---

# Example Walkthrough

## Location 3 Analysis

```json
{
  "id": 3,
  "observation": 33,
  "x": 0.7, "y": 0.1,
  "in_locations": [2,3,4,8],
  "out_locations": [2,3,4,8],
  "actions": [
    {"id": 0, "transition": [0,0,0,1,0,...], "probability": 0.25},  // Stay at location 3
    {"id": 1, "transition": [0,0,0,0,0,...], "probability": 0.0},   // North (impossible)
    {"id": 2, "transition": [0,0,0,0,1,...], "probability": 0.25},  // East to location 4
    {"id": 3, "transition": [0,0,0,0,0,0,0,0,1,...], "probability": 0.25}, // South to location 8
    {"id": 4, "transition": [0,0,1,0,0,...], "probability": 0.25}   // West to location 2
  ]
}
```

Interpretation:

- Agent at coordinates (0.7, 0.1): fourth column, top row of the 5×5 grid
- Observes observation 33
- Can reach 4 different locations (2, 3, 4, and 8, including staying put at 3)
- Cannot move North (action 1 has probability 0 because location 3 is on the top edge)
- The four available actions are equally likely (25% each)
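
The same reading can be done programmatically. This sketch hard-codes location 3's actions from the example above. For each action with non-zero probability, it keeps the index of the largest transition entry. `available_moves` is a hypothetical helper.

```python
location_3_actions = [
    {"id": 0, "transition": [0, 0, 0, 1] + [0] * 21, "probability": 0.25},  # stay
    {"id": 1, "transition": [0] * 25, "probability": 0},                     # North: impossible
    {"id": 2, "transition": [0, 0, 0, 0, 1] + [0] * 20, "probability": 0.25},  # East
    {"id": 3, "transition": [0] * 8 + [1] + [0] * 16, "probability": 0.25},    # South
    {"id": 4, "transition": [0, 0, 1] + [0] * 22, "probability": 0.25},        # West
]


def available_moves(actions):
    """Map each action with non-zero probability to its most likely target location."""
    moves = {}
    for action in actions:
        if action["probability"] > 0:
            t = action["transition"]
            moves[action["id"]] = max(range(len(t)), key=t.__getitem__)
    return moves


print(available_moves(location_3_actions))  # {0: 3, 2: 4, 3: 8, 4: 2}
```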

## Grid Layout Context

In a 5×5 grid with coordinates starting at (0.1, 0.1) and incrementing by 0.2 (x grows eastward, y grows southward):

```
(0.1,0.1) - (0.3,0.1) - (0.5,0.1) - (0.7,0.1) - (0.9,0.1)
    |           |           |           |           |
(0.1,0.3) - (0.3,0.3) - (0.5,0.3) - (0.7,0.3) - (0.9,0.3)
    |           |           |           |           |
(0.1,0.5) - (0.3,0.5) - (0.5,0.5) - (0.7,0.5) - (0.9,0.5)
    |           |           |           |           |
(0.1,0.7) - (0.3,0.7) - (0.5,0.7) - (0.7,0.7) - (0.9,0.7)
    |           |           |           |           |
(0.1,0.9) - (0.3,0.9) - (0.5,0.9) - (0.7,0.9) - (0.9,0.9)
```

Location 3 at (0.7, 0.1) is in the top row, fourth position (column index 3 when counting from 0).
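
The hypothetical helpers below convert between coordinates and ids. They assume the 0.1 offset, the 0.2 spacing, and row-by-row numbering described above. `round` absorbs floating-point error; for example, `0.1 + 3 * 0.2` is not exactly `0.7` in binary floating point.

```python
def coords_to_id(x, y, width=5, offset=0.1, spacing=0.2):
    """Location id from normalised coordinates, numbering row by row (assumption)."""
    col = round((x - offset) / spacing)  # round() absorbs floating-point error
    row = round((y - offset) / spacing)
    return row * width + col


def id_to_coords(location_id, width=5, offset=0.1, spacing=0.2):
    """Normalised (x, y) coordinates of a location id."""
    row, col = divmod(location_id, width)
    return round(offset + col * spacing, 10), round(offset + row * spacing, 10)


print(coords_to_id(0.7, 0.1), id_to_coords(3))   # 3 (0.7, 0.1)
print(coords_to_id(0.1, 0.9), id_to_coords(20))  # 20 (0.1, 0.9)
```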

TEM training uses this data to:

1. Generate walks: sample sequences of (location, observation, action) tuples
2. Learn structure: infer spatial relationships from the observation sequences
3. Predict observations: given a sequence, predict what comes next

The model never sees the explicit coordinates, adjacency matrix, or transition probabilities. It must learn the environment's structure purely from the sequence of observations and actions.

In [1]:
import numpy as np
import torch

In [2]:
import os
import datetime
import logging
import glob
import shutil

In [3]:
def make_directories():
    '''
    Creates directories for storing data during a model training run
    '''    
    # Get current date for saving folder
    date = datetime.datetime.today().strftime('%Y-%m-%d')
    # Initialise the run and dir_check to create a new run folder within the current date
    run = 0
    dir_check = True
    # Initialise all paths
    train_path, model_path, save_path, script_path, envs_path, run_path = None, None, None, None, None, None
    # Find the current run: the first run that doesn't exist yet
    while dir_check:
        # Construct new paths
        run_path    = f'./Summaries/{date}/run{run}/'
        train_path  = os.path.join(run_path, 'train')
        model_path  = os.path.join(run_path, 'model')
        save_path   = os.path.join(run_path, 'save')
        script_path = os.path.join(run_path, 'script')
        envs_path   = os.path.join(script_path, 'envs')
        run += 1

        # And once a path doesn't exist yet: create new folders
        if not os.path.exists(train_path) and not os.path.exists(model_path) and not os.path.exists(save_path):
            os.makedirs(train_path)
            os.makedirs(model_path)
            os.makedirs(save_path)
            os.makedirs(script_path)
            os.makedirs(envs_path)
            dir_check = False
    # Return folders to new path
    return run_path, train_path, model_path, save_path, script_path, envs_path
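
The sketch below implements the same run-folder logic with `pathlib`. It takes the base folder and date as parameters so it can be tried in a temporary directory. `make_run_dirs` is hypothetical. Unlike `make_directories`, it checks whether the `runN` folder itself exists rather than its `train`/`model`/`save` subfolders.

```python
import tempfile
from pathlib import Path


def make_run_dirs(base: Path, date: str) -> Path:
    """Create base/date/runN/{train,model,save,script/envs} for the first unused N."""
    run = 0
    while (base / date / f"run{run}").exists():
        run += 1
    run_path = base / date / f"run{run}"
    for sub in ("train", "model", "save", "script/envs"):
        (run_path / sub).mkdir(parents=True)  # parents=True also creates script/
    return run_path


with tempfile.TemporaryDirectory() as tmp:
    first = make_run_dirs(Path(tmp), "2025-05-27")
    second = make_run_dirs(Path(tmp), "2025-05-27")
    print(first.name, second.name)  # run0 run1
```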

In [4]:
# Start training from step 0
i_start = 0

# Create directories for storing all information about the current run
run_path, train_path, model_path, save_path, script_path, envs_path = make_directories()

In [5]:
# Save all python files in current directory to script directory
files = glob.iglob(os.path.join('.', '*.py'))
for file in files:
    if os.path.isfile(file):
        dst = os.path.join(script_path, file)
        print(f'copying {file=} to {dst=}')
        shutil.copy2(file, dst) 

copying file='./hello.py' to dst='./Summaries/2025-05-27/run0/script/./hello.py'


# Data

In [6]:
# Create list of environments that we will sample from during training to provide TEM with trajectory input
envs = ['./envs/5x5.json']
# Save all environment files that are being used in training in the script directory
for file in set(envs):
    dst = os.path.join(envs_path, os.path.basename(file))
    print(f'copying {file=} to {dst=}')
    shutil.copy2(file, dst)

copying file='./envs/5x5.json' to dst='./Summaries/2025-05-27/run0/script/envs/5x5.json'


In [7]:
import json

In [8]:
with open('./envs/5x5.json', 'r') as f:
    _data = json.load(f)

In [9]:
_data.keys()

dict_keys(['n_locations', 'n_observations', 'n_actions', 'adjacency', 'locations'])

In [10]:
_data['n_locations'], _data['n_observations'], _data['n_actions']

(25, 45, 5)

In [11]:
type(_data['adjacency']), len(_data['adjacency']), _data['adjacency'][0]

(list,
 25,
 [1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [12]:
type(_data['locations']), len(_data['locations'])

(list, 25)

In [13]:
for k, v in _data['locations'][0].items():
    if k != 'actions':
        print(f'{k:<15} -> {v}')
    else:
        print(f'{k} ->')
        for action in v:
            print(f'  {action}')

len(_data['locations'][0]['actions'][0]['transition'])

id              -> 0
observation     -> 31
x               -> 0.1
y               -> 0.1
in_locations    -> [0, 1, 5]
in_degree       -> 3
out_locations   -> [0, 1, 5]
out_degree      -> 3
actions ->
  {'id': 0, 'transition': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 1, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}
  {'id': 2, 'transition': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 3, 'transition': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 4, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}


25

In [14]:
np.where( np.array(_data['locations'][0]['actions'][0]['transition']) > 0 )

(array([0]),)

In [15]:
np.where( np.array(_data['locations'][0]['actions'][1]['transition']) > 0 )

(array([], dtype=int64),)

In [16]:
envs = ['./envs/5x5.json']

np.random.choice(envs, 3)

array(['./envs/5x5.json', './envs/5x5.json', './envs/5x5.json'],
      dtype='<U15')

In [17]:
np.random.choice([0, 1], 3)

array([0, 0, 0])

In [18]:
# This is done in
# https://github.com/jbakermans/torch_tem/blob/bf103fb32b5fdc7541ebbd95ba77a2d35d049d7c/world.py#L56
# TEM needs to know that this is a non-shiny environment (e.g. for providing actions to
# generative model), so set shiny to None for each location.
for location in _data['locations']:
    location['shiny'] = None

In [19]:
for k, v in _data['locations'][0].items():
    if k != 'actions':
        print(f'{k:<15} -> {v}')
    else:
        print(f'{k} ->')
        for action in v:
            print(f'  {action}')

len(_data['locations'][0]['actions'][0]['transition'])

id              -> 0
observation     -> 31
x               -> 0.1
y               -> 0.1
in_locations    -> [0, 1, 5]
in_degree       -> 3
out_locations   -> [0, 1, 5]
out_degree      -> 3
actions ->
  {'id': 0, 'transition': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 1, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}
  {'id': 2, 'transition': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 3, 'transition': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 4, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}
shiny           -> None


25

In [20]:
environments = [_data]

# Initialise whether a state has been visited for each world
visited = [ [ False for _ in range(env['n_locations']) ] for env in environments]
len(visited), len(visited[0])

(1, 25)

In [65]:
_env = environments[0]

_new_location = np.random.randint(25)
print(f'step 1: {_new_location}')

_env['locations'][_new_location]
print(f'---> step 2: \n\t{_env['locations'][_new_location]=}')
print(f'step 2: \n\t{_env['locations'][_new_location].keys()}')
print()

_new_location = _env['locations'][_new_location]
_new_observation = np.eye(_env['n_observations'])[_new_location['observation']]
print(f'step 3: {_new_location["observation"]=}')
print(f'---> step 3: {_new_observation=}')
_new_observation = torch.tensor(_new_observation, dtype=torch.float).view((_new_observation.shape[0]))
print(f'step 3: {_new_observation=}')
print()

_policy = np.array([action['probability'] for action in _new_location['actions']])
for action in _new_location['actions']:
    print(f'step 4: {action=}')
print(f'step 4: {_policy=}')

step 1: 3
---> step 2: 
	_env['locations'][_new_location]={'id': 3, 'observation': 33, 'x': 0.7, 'y': 0.1, 'in_locations': [2, 3, 4, 8], 'in_degree': 4, 'out_locations': [2, 3, 4, 8], 'out_degree': 4, 'actions': [{'id': 0, 'transition': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.25}, {'id': 1, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}, {'id': 2, 'transition': [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.25}, {'id': 3, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.25}, {'id': 4, 'transition': [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.25}], 'shiny': None}
step 2: 
	dict_keys(['id', 'observation', 'x', 'y', 'in_locations', 'in_degree', 'out_locations', 'out_degree', 'actions', 'shiny'])

step 3: _new_location["ob

In [40]:
print(f"{_new_location['id']=}")
_walk = []

[[] if len(_walk) == 0 or _new_location['id'] == _walk[-1][0]['id'] else _walk[-1][2]]

_new_location['id']=20


[[]]

In [51]:
# NumPy fancy indexing: select the entries at these indices (an empty list selects nothing)
_policy[[]]

array([], dtype=float64)
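
This indexing trick lets `get_action` below apply the repeat bias conditionally in one line. An empty index list selects nothing, so the in-place multiplication is a no-op. An integer index scales exactly one entry. Here is a small illustration with location 3's policy:

```python
import numpy as np

policy = np.array([0.25, 0.0, 0.25, 0.25, 0.25])  # location 3: stay, N, E, S, W

no_bias = policy.copy()
no_bias[[]] *= 2  # empty index list: nothing selected, array unchanged

east_bias = policy.copy()
east_bias[2] *= 2  # integer index: only action 2 (East) is scaled

print(no_bias.tolist())    # [0.25, 0.0, 0.25, 0.25, 0.25]
print(east_bias.tolist())  # [0.25, 0.0, 0.5, 0.25, 0.25]
```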

In [52]:
_policy[[0, -1]]

array([0.33333333, 0.        ])

In [53]:
_repeat_bias_factor = 1

# Add a bias for repeating previous action to walk in straight lines,
# only if (this is not the first step) and (the previous action was a move)
_policy[[] if len(_walk) == 0 or _new_location['id'] == _walk[-1][0]['id'] else _walk[-1][2]] *= _repeat_bias_factor

_policy

array([0.33333333, 0.33333333, 0.33333333, 0.        , 0.        ])

In [55]:
_policy = _policy / sum(_policy)
_policy

array([0.33333333, 0.33333333, 0.33333333, 0.        , 0.        ])
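
With `_repeat_bias_factor = 1`, the bias above changes nothing. The sketch below shows the effect with the default factor of 2, using location 3's policy and assuming the previous action was 2 (East). `biased_policy` is a hypothetical helper that mirrors the logic of `get_action` below.

```python
import numpy as np


def biased_policy(policy, previous_action, repeat_bias_factor=2):
    """Upweight the previous action (if any), then renormalise."""
    p = np.asarray(policy, dtype=float).copy()
    if previous_action is not None:
        p[previous_action] *= repeat_bias_factor
    total = p.sum()
    return p / total if total > 0 else p  # all-zero policies are left as they are


print(biased_policy([0.25, 0.0, 0.25, 0.25, 0.25], previous_action=2).tolist())
# [0.2, 0.0, 0.4, 0.2, 0.2]
```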

In [64]:
_a_rand = np.random.rand()
print(f'{_a_rand=}')
print( np.cumsum(_policy)>_a_rand )
print( np.flatnonzero(np.cumsum(_policy)>_a_rand) )
_new_action = int(np.flatnonzero(np.cumsum(_policy)>_a_rand)[0])
print(_new_action)

_a_rand=0.01538814652392606
[ True  True  True  True  True]
[0 1 2 3 4]
0
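
The cell above implements inverse-CDF (cumulative-sum) sampling. This seeded sketch checks empirically that it draws actions in proportion to the policy and never selects a zero-probability action. NumPy's `Generator.choice(len(policy), p=policy)` would sample from the same distribution. One caveat: if rounding left the cumulative sum slightly below 1, a draw above it would make `np.flatnonzero(...)[0]` raise an `IndexError`.

```python
import numpy as np


def sample_action(policy, rng):
    """Pick the first index whose cumulative probability exceeds u ~ U[0, 1)."""
    u = rng.random()
    return int(np.flatnonzero(np.cumsum(policy) > u)[0])


rng = np.random.default_rng(0)
policy = np.array([1 / 3, 1 / 3, 1 / 3, 0.0, 0.0])
counts = np.bincount([sample_action(policy, rng) for _ in range(30_000)], minlength=5)
print(counts / counts.sum())  # about one third each for actions 0-2, exactly 0 for 3 and 4
```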


In [66]:
_new_location, _new_observation, _new_action

({'id': 3,
  'observation': 33,
  'x': 0.7,
  'y': 0.1,
  'in_locations': [2, 3, 4, 8],
  'in_degree': 4,
  'out_locations': [2, 3, 4, 8],
  'out_degree': 4,
  'actions': [{'id': 0,
    'transition': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'probability': 0.25},
   {'id': 1,
    'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'probability': 0},
   {'id': 2,
    'transition': [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'probability': 0.25},
   {'id': 3,
    'transition': [0, 0, 0, 0, 0,

In [79]:
for loc in range(len(_env['locations'])):
    print(f"{loc=:<3} id={_env['locations'][loc]['id']:<3} observation={_env['locations'][loc]['observation']:<3} x={_env['locations'][loc]['x']:<3} y={_env['locations'][loc]['y']:<3}")
    

loc=0   id=0   observation=31  x=0.1 y=0.1
loc=1   id=1   observation=39  x=0.3 y=0.1
loc=2   id=2   observation=21  x=0.5 y=0.1
loc=3   id=3   observation=33  x=0.7 y=0.1
loc=4   id=4   observation=34  x=0.9 y=0.1
loc=5   id=5   observation=5   x=0.1 y=0.3
loc=6   id=6   observation=2   x=0.3 y=0.3
loc=7   id=7   observation=15  x=0.5 y=0.3
loc=8   id=8   observation=10  x=0.7 y=0.3
loc=9   id=9   observation=29  x=0.9 y=0.3
loc=10  id=10  observation=44  x=0.1 y=0.5
loc=11  id=11  observation=32  x=0.3 y=0.5
loc=12  id=12  observation=6   x=0.5 y=0.5
loc=13  id=13  observation=37  x=0.7 y=0.5
loc=14  id=14  observation=41  x=0.9 y=0.5
loc=15  id=15  observation=27  x=0.1 y=0.7
loc=16  id=16  observation=16  x=0.3 y=0.7
loc=17  id=17  observation=40  x=0.5 y=0.7
loc=18  id=18  observation=13  x=0.7 y=0.7
loc=19  id=19  observation=7   x=0.9 y=0.7
loc=20  id=20  observation=4   x=0.1 y=0.9
loc=21  id=21  observation=28  x=0.3 y=0.9
loc=22  id=22  observation=20  x=0.5 y=0.9
loc=23  id=

In [None]:
def get_location(env, walk):
    # First step: start at random location
    if len(walk) == 0:
        new_location = np.random.randint(env['n_locations'])
    # Any other step: get new location from previous location and action
    else:                        
        new_location = int(
            np.flatnonzero(
                np.cumsum(
                    walk[-1][0]['actions'][walk[-1][2]]['transition'],
                )>np.random.rand(),
            )[0]
        )
    # Return the location dictionary of the new location
    return env['locations'][new_location]

def get_observation(env, new_location):
    # Find sensory observation for new state, and store it as one-hot vector
    new_observation = np.eye(env['n_observations'])[new_location['observation']]
    # Create a new observation by converting the new observation to a torch tensor
    new_observation = torch.tensor(new_observation, dtype=torch.float).view((new_observation.shape[0]))
    # Return the new observation
    return new_observation

def get_action(env, new_location, walk, repeat_bias_factor=2):
    # Build policy from action probability of each action of provided location dictionary
    policy = np.array([action['probability'] for action in new_location['actions']])
    # Add a bias for repeating previous action to walk in straight lines, only if (this is not the first step) and (the previous action was a move)
    policy[[] if len(walk) == 0 or new_location['id'] == walk[-1][0]['id'] else walk[-1][2]] *= repeat_bias_factor
    # And renormalise policy (note that for unavailable actions, the policy was 0 and remains 0, so in that case no renormalisation needed)
    policy = policy / sum(policy) if sum(policy) > 0 else policy
    # Select action in new state
    new_action = int(np.flatnonzero(np.cumsum(policy)>np.random.rand())[0])
    # Return the new action
    return new_action

def walk_default(env, walk, walk_length, repeat_bias_factor=2):
    # Finish the provided walk until it contains walk_length steps
    for curr_step in range(walk_length - len(walk)):
        # Get new location based on previous action and location
        new_location = get_location(env, walk)
        # Get new observation at new location
        new_observation = get_observation(env, new_location)
        # Get new action based on policy at new location
        new_action = get_action(env, new_location, walk, repeat_bias_factor)
        # Append location, observation, and action to the walk
        walk.append([new_location, new_observation, new_action])
    # Return the final walk
    return walk

def generate_walks(env, walk_length=10, n_walk=100, repeat_bias_factor=2, shiny=None):
    # Generate walks by sampling actions according to the policy, then the next location according to the graph
    walks = []  # Will contain one walk per iteration; each walk is a list of [location, observation, action] steps
    for _ in range(n_walk):
        new_walk = []
        # If shiny hasn't been specified: there are no shiny objects, generate default policy
        if shiny is None:
            new_walk = walk_default(env, new_walk, walk_length, repeat_bias_factor)
        ## If shiny was specified: use policy that uses shiny policy to approach shiny objects
        ## sequentially
        ##else:
        ##    new_walk = self.walk_shiny(new_walk, walk_length, repeat_bias_factor)
        # Clean up walk a bit by keeping only the essential location dictionary entries
        for step in new_walk[:-1]:
            step[0] = {'id': step[0]['id'], 'shiny': step[0]['shiny']}
        # Append new walk to list of walks
        walks.append(new_walk)
    return walks
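
Because `generate_walks` reduces every step except the last to `{'id', 'shiny'}`, a sanity check has to look transitions up by location id. `check_walk` below is a hypothetical helper, and the two-location `toy_env` is made up for illustration. With the real data, you would pass `environments[0]` and a walk from `generate_walks`.

```python
def check_walk(env, walk):
    """True if every step's action can lead to the next step's location.
    `walk` is a list of [location, observation, action] steps; only location['id']
    and the action are used, so stripped location dictionaries are fine."""
    for (loc, _, action), (next_loc, _, _) in zip(walk[:-1], walk[1:]):
        transition = env["locations"][loc["id"]]["actions"][action]["transition"]
        if transition[next_loc["id"]] <= 0:
            return False
    return True


# Hypothetical two-location environment in the same schema (illustration only)
toy_env = {"locations": [
    {"id": 0, "actions": [{"id": 0, "transition": [1, 0]}, {"id": 1, "transition": [0, 1]}]},
    {"id": 1, "actions": [{"id": 0, "transition": [0, 1]}, {"id": 1, "transition": [1, 0]}]},
]}
good = [[{"id": 0}, None, 1], [{"id": 1}, None, 0], [{"id": 1}, None, 1]]
bad = [[{"id": 0}, None, 0], [{"id": 1}, None, 0]]
print(check_walk(toy_env, good), check_walk(toy_env, bad))  # True False
```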

In [None]:
# And make a single walk for each environment, where walk lengths can be any value between the min and max
# length to de-synchronise world switches.
params = {
    # Number of steps to roll out before backpropagation through time
    'n_rollout': 20,
    # Minimum length of a walk on one environment. Walk lengths are sampled uniformly from a window
    # that shifts down until its lower limit is walk_it_min at the end of training
    'walk_it_min': 25,
    # Maximum length of a walk on one environment. Walk lengths are sampled uniformly from a window
    # that starts with its upper limit at walk_it_max in the beginning of training, then shifts down
    'walk_it_max': 300,
}
# generate_walks is a plain function (not an environment method), so call it directly
walks = [
    generate_walks(
        env,
        walk_length=params['n_rollout']*np.random.randint(params['walk_it_min'], params['walk_it_max']),
        n_walk=1,
        shiny=None,  # non-shiny environment: use the default random-walk policy
    )[0] for env in environments
]
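
The walk length is `n_rollout` times an integer drawn from `[walk_it_min, walk_it_max)`, and NumPy's upper bound is exclusive. So every walk is between 500 and 5980 steps and splits evenly into rollouts. The sketch below shows that split. It assumes the walk is consumed in consecutive `n_rollout`-step chunks, as the commented-out forward pass (`tem(chunk, prev_iter)`) suggests, and a list of integers stands in for real steps.

```python
import numpy as np

n_rollout, walk_it_min, walk_it_max = 20, 25, 300
rng = np.random.default_rng(0)

# Upper bound is exclusive, so walk lengths range from 20 * 25 = 500 to 20 * 299 = 5980
walk_length = n_rollout * int(rng.integers(walk_it_min, walk_it_max))

walk = list(range(walk_length))  # stand-in for a list of [location, observation, action] steps
chunks = [walk[i:i + n_rollout] for i in range(0, len(walk), n_rollout)]
print(walk_length, len(chunks))
```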


In [None]:
# # Forward-pass this walk through the network
# forward = tem(chunk, prev_iter)


# Model

The TEM model is based on the architecture described in the paper, where there are representations for:

* Abstract location (`g`) - corresponding to grid cells in medial entorhinal cortex
* Grounded location (`p`) - corresponding to place cells in hippocampus
* Sensory observations (`x`) - corresponding to lateral entorhinal cortex

In [28]:
def parameters():
    params = {}
    # -- Model parameters   
    # Decide whether to use separate grid modules that receive shiny information for object vector cells.
    # To disable OVC, set this False, and set n_ovc to [0 for _ in range(len(params['n_g_subsampled']))].
    params['separate_ovc'] = False

    # ---- Neuron and module parameters
    # Neurons for subsampled entorhinal abstract location f_g(g) for each frequency module
    params['n_g_subsampled'] = [10, 10, 8, 6, 6]
    # Neurons for object vector cells. Neurons will get new modules if object vector cell modules
    # are separated; otherwise, they are added to existing abstract location modules.
    # a) No additional modules, no additional object vector neurons (e.g. when not using shiny
    #    environments): [0 for _ in range(len(params['n_g_subsampled']))], and separate_ovc set to False
    # b) No additional modules, but n additional object vector neurons in each grid module:
    #    [n for _ in range(len(params['n_g_subsampled']))], and separate_ovc set to False
    # c) Additional separate object vector modules, with n, m neurons: [n, m], and separate_ovc set to
    #    True
    params['n_ovc'] = [0 for _ in range(len(params['n_g_subsampled']))]
    # Total number of modules
    params['n_f'] = len(params['n_g_subsampled'])

    # Number of hierarchical frequency modules for object vector cells
    params['n_f_ovc'] = len(params['n_ovc']) if params['separate_ovc'] else 0

    # Initial frequencies of each module. For ease of interpretation (higher number = higher frequency)
    # this is 1 - the frequency as James uses it
    params['f_initial'] = [0.99, 0.3, 0.09, 0.03, 0.01]
    # Add frequencies of object vector cell modules, if object vector cells get separate modules
    params['f_initial'] = params['f_initial'] + params['f_initial'][0:params['n_f_ovc']]
    return params

# Initialise hyperparameters for model
params = parameters()

In [29]:
params

{'separate_ovc': False,
 'n_g_subsampled': [10, 10, 8, 6, 6],
 'n_ovc': [0, 0, 0, 0, 0],
 'n_f': 5,
 'n_f_ovc': 0,
 'f_initial': [0.99, 0.3, 0.09, 0.03, 0.01]}
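
The comments in `parameters()` describe how `n_ovc` and `separate_ovc` change the module layout. The sketch below implements that bookkeeping as read from those comments. This notebook does not show how the model itself consumes these lists, so treat `module_layout` as an assumption.

```python
def module_layout(n_g_subsampled, n_ovc, separate_ovc, f_initial):
    """Per-module neuron counts and initial frequencies, as described in parameters()."""
    if separate_ovc:
        # Object vector cells get their own modules, appended after the grid modules,
        # reusing the first len(n_ovc) initial frequencies
        sizes = list(n_g_subsampled) + list(n_ovc)
        freqs = list(f_initial) + list(f_initial[:len(n_ovc)])
    else:
        # Object vector neurons are added to the existing modules (one n_ovc entry per module)
        sizes = [g + o for g, o in zip(n_g_subsampled, n_ovc)]
        freqs = list(f_initial)
    return sizes, freqs


print(module_layout([10, 10, 8, 6, 6], [0] * 5, False, [0.99, 0.3, 0.09, 0.03, 0.01]))
# ([10, 10, 8, 6, 6], [0.99, 0.3, 0.09, 0.03, 0.01])
```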

In [30]:
# Save parameters
np.save(os.path.join(save_path, 'params'), params)
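
`np.save` stores the dictionary as a pickled 0-d object array and appends `.npy` to the file name. Reading it back therefore needs `allow_pickle=True` and `.item()`. Here is a self-contained round trip in a temporary directory:

```python
import os
import tempfile

import numpy as np

params = {"n_f": 5, "f_initial": [0.99, 0.3, 0.09, 0.03, 0.01]}
with tempfile.TemporaryDirectory() as tmp:
    np.save(os.path.join(tmp, "params"), params)  # writes params.npy
    loaded = np.load(os.path.join(tmp, "params.npy"), allow_pickle=True).item()
print(loaded == params)  # True
```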

In [31]:
hyper = params

In [32]:
# Scale factor in Laplacian transform for each frequency module. High frequency comes first, low frequency comes last. Learn the inverse sigmoid (logit) of the scale factor instead of the scale factor itself, so the domain of alpha is (-inf, inf)
alpha = torch.nn.ParameterList(
    [
        torch.nn.Parameter(
            torch.tensor(
                np.log(hyper['f_initial'][f] / (1 - hyper['f_initial'][f])),
                dtype=torch.float,
            )
        ) for f in range(hyper['n_f'])
    ]
)

In [33]:
alpha

ParameterList(
    (0): Parameter containing: [torch.float32 of size ]
    (1): Parameter containing: [torch.float32 of size ]
    (2): Parameter containing: [torch.float32 of size ]
    (3): Parameter containing: [torch.float32 of size ]
    (4): Parameter containing: [torch.float32 of size ]
)
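
Each `alpha` is the logit (inverse sigmoid) of the corresponding `f_initial` value, so applying `torch.sigmoid` maps it back into (0, 1). The quick check below assumes the sigmoid is how the scale factor is recovered from `alpha`, as the comment in the cell above implies.

```python
import numpy as np
import torch

f_initial = [0.99, 0.3, 0.09, 0.03, 0.01]
# alpha is the logit of each initial value, as in the cell above
alpha = torch.tensor([np.log(f / (1 - f)) for f in f_initial], dtype=torch.float)
recovered = torch.sigmoid(alpha)  # maps any real alpha back into (0, 1)
print(recovered)
```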