In [1]:
import numpy as np
import torch

In [5]:
import os
import datetime
import logging
import glob
import shutil

In [14]:
def make_directories(base_dir='./Summaries'):
    '''
    Creates directories for storing data during a model training run.

    Folders are created under ``<base_dir>/<YYYY-MM-DD>/run<N>/``, where ``N`` is the first
    run index (starting at 0) for which none of the train/model/save folders exist yet.

    Parameters
    ----------
    base_dir : str, optional
        Root folder for all run summaries (default ``'./Summaries'``).

    Returns
    -------
    tuple of str
        ``(run_path, train_path, model_path, save_path, script_path, envs_path)``;
        ``run_path`` ends with a ``/``.
    '''
    # Get current date for saving folder
    date = datetime.datetime.today().strftime('%Y-%m-%d')
    run = 0
    # Find the current run: the first run index whose train/model/save folders don't exist yet
    while True:
        run_path    = f'{base_dir}/{date}/run{run}/'
        train_path  = os.path.join(run_path, 'train')
        model_path  = os.path.join(run_path, 'model')
        save_path   = os.path.join(run_path, 'save')
        script_path = os.path.join(run_path, 'script')
        envs_path   = os.path.join(script_path, 'envs')
        run += 1
        if not any(os.path.exists(p) for p in (train_path, model_path, save_path)):
            break
    # exist_ok=True: script/ (or script/envs/) can be left over from an aborted run even though the
    # folders checked above are absent; without it os.makedirs would raise FileExistsError here.
    for path in (train_path, model_path, save_path, script_path, envs_path):
        os.makedirs(path, exist_ok=True)
    # Return folders to new path
    return run_path, train_path, model_path, save_path, script_path, envs_path

In [15]:
# Start training from step 0
i_start = 0

# Create directories for storing all information about the current run.
# Every execution creates a new ./Summaries/<date>/run<N>/ folder, so re-running this cell
# starts a fresh run folder rather than reusing the previous one.
run_path, train_path, model_path, save_path, script_path, envs_path = make_directories()

In [16]:
# Save all python files in current directory to script directory, so the exact code used
# for this run is stored alongside its outputs.
for file in glob.iglob(os.path.join('.', '*.py')):
    if os.path.isfile(file):
        # Use the bare file name (as the environment-copy cell does) so the destination
        # path does not contain a stray './' segment
        dst = os.path.join(script_path, os.path.basename(file))
        print(f'copying {file=} to {dst=}')
        shutil.copy2(file, dst)

copying file='./hello.py' to dst='./Summaries/2025-05-19/run0/script/./hello.py'


# Data

In [35]:
# Create list of environments that we will sample from during training to provide TEM with trajectory input
envs = ['./envs/5x5.json']
# Save all environment files that are being used in training in the script directory.
# set() skips duplicate entries; each file is stored under envs_path by its bare file name.
for file in set(envs):
    dst = os.path.join(envs_path, os.path.basename(file))
    print(f'copying {file=} to {dst=}')
    shutil.copy2(file, dst)

copying file='./envs/5x5.json' to dst='./Summaries/2025-05-19/run0/script/envs/5x5.json'


In [36]:
import json

In [37]:
# Load the raw environment description (the same file listed in envs) to inspect its structure
with open('./envs/5x5.json', 'r') as f:
    _data = json.load(f)

In [47]:
# Top-level keys of the environment description
_data.keys()

dict_keys(['n_locations', 'n_observations', 'n_actions', 'adjacency', 'locations'])

In [48]:
# Environment sizes: number of locations, observations and actions
_data['n_locations'], _data['n_observations'], _data['n_actions']

(25, 45, 5)

In [49]:
# 'adjacency' is a list with one row per location; show its type, length and the row for location 0
type(_data['adjacency']), len(_data['adjacency']), _data['adjacency'][0]

(list,
 25,
 [1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [51]:
# 'locations' is a list with one dictionary per location
type(_data['locations']), len(_data['locations'])

(list, 25)

In [63]:
# Pretty-print the dictionary of location 0: one line per key, and one line per action entry
for k, v in _data['locations'][0].items():
    if k != 'actions':
        print(f'{k:<15} -> {v}')
    else:
        print(f'{k} ->')
        for action in v:
            print(f'  {action}')

# Length of a single action's transition row (matches n_locations in the output)
len(_data['locations'][0]['actions'][0]['transition'])

id              -> 0
observation     -> 31
x               -> 0.1
y               -> 0.1
in_locations    -> [0, 1, 5]
in_degree       -> 3
out_locations   -> [0, 1, 5]
out_degree      -> 3
actions ->
  {'id': 0, 'transition': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 1, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}
  {'id': 2, 'transition': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 3, 'transition': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 4, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}


25

In [73]:
# Indices with non-zero transition weight for action 0 at location 0
np.where( np.array(_data['locations'][0]['actions'][0]['transition']) > 0 )

(array([0]),)

In [74]:
# Action 1 at location 0 has an all-zero transition row (and probability 0), so the result is empty
np.where( np.array(_data['locations'][0]['actions'][1]['transition']) > 0 )

(array([], dtype=int64),)

In [77]:
# Re-defines envs (same single environment as above) to try out sampling environments
envs = ['./envs/5x5.json']

# np.random.choice samples with replacement; no seed is set, so results vary between runs
np.random.choice(envs, 3)

array(['./envs/5x5.json', './envs/5x5.json', './envs/5x5.json'],
      dtype='<U15')

In [80]:
# Unseeded draw of 3 values from [0, 1] with replacement; output differs between runs
np.random.choice([0, 1], 3)

array([0, 1, 1])

In [81]:
# This is done in
# https://github.com/jbakermans/torch_tem/blob/bf103fb32b5fdc7541ebbd95ba77a2d35d049d7c/world.py#L56
# TEM needs to know that this is a non-shiny environment (e.g. for providing actions to
# generative model), so set shiny to None for each location.
# Mutates the location dictionaries of _data in place; re-running assigns the same value again.
for location in _data['locations']:
    location['shiny'] = None

In [82]:
# Re-print location 0 (same code as the earlier inspection cell) to confirm the new 'shiny' key
for k, v in _data['locations'][0].items():
    if k != 'actions':
        print(f'{k:<15} -> {v}')
    else:
        print(f'{k} ->')
        for action in v:
            print(f'  {action}')

# Length of a single action's transition row (matches n_locations in the output)
len(_data['locations'][0]['actions'][0]['transition'])

id              -> 0
observation     -> 31
x               -> 0.1
y               -> 0.1
in_locations    -> [0, 1, 5]
in_degree       -> 3
out_locations   -> [0, 1, 5]
out_degree      -> 3
actions ->
  {'id': 0, 'transition': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 1, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}
  {'id': 2, 'transition': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 3, 'transition': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0.3333333333333333}
  {'id': 4, 'transition': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'probability': 0}
shiny           -> None


25

In [92]:
# Wrap the single loaded environment in a list so the code below can iterate over environments
environments = [_data]

# Initialise whether a state has been visited for each world:
# visited[i][j] is False for every location j of environment i
visited = [ [ False for _ in range(env['n_locations']) ] for env in environments]
len(visited), len(visited[0])

(1, 25)

In [None]:
def get_location(env, walk):
    '''
    Pick the location dictionary for the next step of a walk.

    With an empty walk, the start index is drawn uniformly from range(env['n_locations']).
    Otherwise the previous step's location and action select a 'transition' row, and the next
    location is the first index whose cumulative transition weight exceeds a uniform draw.
    '''
    if len(walk) == 0:
        return env['locations'][np.random.randint(env['n_locations'])]
    last_step = walk[-1]
    transition_row = last_step[0]['actions'][last_step[2]]['transition']
    above_draw = np.cumsum(transition_row) > np.random.rand()
    next_index = int(np.flatnonzero(above_draw)[0])
    return env['locations'][next_index]

def get_observation(env, new_location):
    '''
    Return the sensory observation at a location as a one-hot torch vector.

    Parameters
    ----------
    env : dict
        Environment dictionary; only env['n_observations'] is read.
    new_location : dict
        Location dictionary; new_location['observation'] is the observation index.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (env['n_observations'],) with a single 1 at the observation index.
    '''
    # Build the one-hot vector directly instead of taking a row of an
    # n_observations x n_observations identity matrix (O(n^2) memory on every step)
    new_observation = torch.zeros(env['n_observations'], dtype=torch.float)
    new_observation[new_location['observation']] = 1.0
    return new_observation

def get_action(env, new_location, walk, repeat_bias_factor=2):
    '''
    Sample the action index taken at new_location.

    The policy starts from each action's 'probability'. If the walk has a previous step and the
    agent actually moved (new location id differs from the previous one), the previous action's
    weight is multiplied by repeat_bias_factor to favour straight lines. The policy is then
    renormalised when it has positive mass, and one index is sampled from it.
    env is not read by this function.
    '''
    weights = np.array([action['probability'] for action in new_location['actions']])
    moved = len(walk) > 0 and new_location['id'] != walk[-1][0]['id']
    if moved:
        weights[walk[-1][2]] *= repeat_bias_factor
    total = sum(weights)
    if total > 0:
        weights = weights / total
    # Unavailable actions have weight 0 and stay 0, so they can never be sampled
    return int(np.flatnonzero(np.cumsum(weights) > np.random.rand())[0])

def walk_default(env, walk, walk_length, repeat_bias_factor=2):
    '''
    Extend walk in place with the default policy until it holds walk_length steps.

    Each appended step is [location_dict, observation_tensor, action_index], produced by
    get_location, get_observation and get_action.

    Parameters
    ----------
    env : dict
        Environment dictionary.
    walk : list
        Existing (possibly empty) walk; it is extended in place.
    walk_length : int
        Desired total number of steps.
    repeat_bias_factor : float, optional
        Multiplier on the previous action's weight after a move (passed to get_action).

    Returns
    -------
    list
        The same walk list object, now extended.
    '''
    for _ in range(walk_length - len(walk)):
        # Get new location based on previous action and location
        new_location = get_location(env, walk)
        # Get new observation at new location
        new_observation = get_observation(env, new_location)
        # Forward repeat_bias_factor so the caller's setting actually affects the policy
        new_action = get_action(env, new_location, walk, repeat_bias_factor)
        walk.append([new_location, new_observation, new_action])
    return walk

def generate_walks(env, walk_length=10, n_walk=100, repeat_bias_factor=2, shiny=False):
    '''
    Generate n_walk random walks of walk_length steps through env.

    Actions are sampled according to the policy, then the next state according to the graph.

    Parameters
    ----------
    env : dict
        Environment dictionary.
    walk_length : int, optional
        Steps per walk.
    n_walk : int, optional
        Number of walks.
    repeat_bias_factor : float, optional
        Bias towards repeating the previous action (see get_action).
    shiny : optional
        Any falsy value (None or the default False) selects the default, non-shiny policy.
        Shiny-object walks are not implemented here, so a truthy value raises NotImplementedError.

    Returns
    -------
    list of list
        One walk per entry; each walk is a list of [location, observation, action] steps. All but
        the last step have their location reduced to {'id', 'shiny'}. The last step keeps the full
        location dictionary, which get_location needs (it reads walk[-1][0]['actions']) if the
        walk is extended.
    '''
    # Any falsy shiny means there are no shiny objects: use the default policy
    if shiny:
        # The shiny policy (approach shiny objects sequentially) has not been ported yet
        raise NotImplementedError('shiny-object walks are not implemented')
    walks = []
    for _ in range(n_walk):
        new_walk = walk_default(env, [], walk_length, repeat_bias_factor)
        # Clean up walk a bit by only keeping essential location dictionary entries;
        # .get because locations loaded straight from JSON have no 'shiny' key
        for step in new_walk[:-1]:
            step[0] = {'id': step[0]['id'], 'shiny': step[0].get('shiny')}
        walks.append(new_walk)
    return walks

In [None]:
# And make a single walk for each environment, where walk lengths can be any between the min and max
# length to de-synchronise world switches.
# Environments here are plain dicts, so call the module-level generate_walks(env, ...) rather than a
# method; shiny=None explicitly selects the default (non-shiny) policy.
# NOTE(review): parameters() does not yet define 'n_rollout', 'walk_it_min' or 'walk_it_max';
# add them there before running this cell.
walks = [
    generate_walks(
        env,
        params['n_rollout'] * np.random.randint(params['walk_it_min'], params['walk_it_max']),
        1,
        shiny=None,
    )[0]
    for env in environments
]


In [None]:
# TODO: forward-pass this walk through the network once the TEM model (`tem`) is defined:
# forward = tem(chunk, prev_iter)


# Model

The Tolman-Eichenbaum Machine (TEM; Whittington et al., 2020, *Cell*) is built from three kinds of representation:

* Abstract location (`g`) - corresponding to grid cells in medial entorhinal cortex
* Grounded location (`p`) - corresponding to place cells in hippocampus
* Sensory observations (`x`) - corresponding to lateral entorhinal cortex

In [28]:
def parameters():
    '''
    Build the hyperparameter dictionary for the TEM model.

    Returns
    -------
    dict
        Keys, in order: separate_ovc, n_g_subsampled, n_ovc, n_f, n_f_ovc, f_initial.
    '''
    # Object vector cells (OVC) may get their own grid modules that receive shiny information.
    # Disabled here: separate_ovc is False and n_ovc is all zeros.
    separate_ovc = False
    # Neurons for the subsampled entorhinal abstract location f_g(g), one entry per frequency module
    n_g_subsampled = [10, 10, 8, 6, 6]
    # Object vector neurons; they form new modules when separate_ovc is True, otherwise they are
    # added to the existing abstract location modules:
    #   a) [0] * n_modules, separate_ovc False -> no extra modules or neurons (non-shiny environments)
    #   b) [n] * n_modules, separate_ovc False -> n extra OVC neurons in each grid module
    #   c) [n, m, ...],     separate_ovc True  -> additional separate OVC modules with n, m, ... neurons
    n_ovc = [0] * len(n_g_subsampled)
    # Number of hierarchical frequency modules for OVCs (only when they are separate)
    n_f_ovc = len(n_ovc) if separate_ovc else 0
    # Initial frequency per module, stored as 1 - (James' frequency) so a higher number means a
    # higher frequency; separate OVC modules reuse the first n_f_ovc values
    base_frequencies = [0.99, 0.3, 0.09, 0.03, 0.01]
    return {
        'separate_ovc': separate_ovc,
        'n_g_subsampled': n_g_subsampled,
        'n_ovc': n_ovc,
        'n_f': len(n_g_subsampled),  # number of modules, one per n_g_subsampled entry
        'n_f_ovc': n_f_ovc,
        'f_initial': base_frequencies + base_frequencies[:n_f_ovc],
    }

# Initialise hyperparameters for model
params = parameters()

In [29]:
# Display the hyperparameter dictionary
params

{'separate_ovc': False,
 'n_g_subsampled': [10, 10, 8, 6, 6],
 'n_ovc': [0, 0, 0, 0, 0],
 'n_f': 5,
 'n_f_ovc': 0,
 'f_initial': [0.99, 0.3, 0.09, 0.03, 0.01]}

In [30]:
# Save parameters
# Save parameters. np.save stores the dict as a 0-d object array in params.npy (the .npy suffix is
# appended automatically); load it back with np.load(path, allow_pickle=True).item()
np.save(os.path.join(save_path, 'params'), params)

In [31]:
# `hyper` is an alias (not a copy) of params: mutating one also mutates the other
hyper = params

In [32]:
# Scale factor in the Laplacian transform for each frequency module; high frequency first, low
# frequency last. Instead of the scale factor itself, alpha stores its inverse sigmoid (logit),
# log(f / (1 - f)) of the initial frequency f, so the learnable value ranges over (-inf, inf).
alpha = torch.nn.ParameterList([
    torch.nn.Parameter(
        torch.tensor(np.log(hyper['f_initial'][i] / (1 - hyper['f_initial'][i])), dtype=torch.float)
    )
    for i in range(hyper['n_f'])
])

In [33]:
# Display the ParameterList: one scalar (0-d) parameter per frequency module
alpha

ParameterList(
    (0): Parameter containing: [torch.float32 of size ]
    (1): Parameter containing: [torch.float32 of size ]
    (2): Parameter containing: [torch.float32 of size ]
    (3): Parameter containing: [torch.float32 of size ]
    (4): Parameter containing: [torch.float32 of size ]
)