# Chunking

First off, the following code should not exist.

```python
# Now pop the first n_rollout steps from this walk and append them to the chunk
for step in range(params['n_rollout']):
    # For the first environment: simply copy the components (g, x, a) of each step
    if len(chunk) < params['n_rollout']:
        chunk.append([[comp] for comp in walk.pop(0)])
    # For all next environments: add the components to the existing list of components for each step
    else:
        for comp_i, comp in enumerate(walk.pop(0)):
            chunk[step][comp_i].append(comp)
```

What that is trying to do is something like so

1. `walk.pop(0)` returns `[location, observation, action]`
2. `[[comp] for comp in walk.pop(0)]` becomes `[[location], [observation], [action]]`

When the chunk has fewer than `n_rollout` entries, it looks like this:
```
chunk = [
    [[loc1], [obs1], [act1]],
    [[loc2], [obs2], [act2]],
    ...
]
```

**POSSIBLE BUG**
```python
# Now pop the first n_rollout steps from this walk and append them to the chunk
for step in range(params['n_rollout']):
    ...
    # For all next environments: add the components to the existing list of components for each step
    else:
        for comp_i, comp in enumerate(walk.pop(0)):
            chunk[step][comp_i].append(comp)
```

Not only is it confusing, but that should give an error because the code doesn't know that `chunk` is a list of lists.
Switching to the `chunk[...][...]` syntax should raise an error.

In [108]:
import torch
import numpy as np
import json

In [109]:
# Read the 5x5 environment description from its JSON file.
with open('./envs/5x5.json', 'r') as env_file:
    env = json.load(env_file)

# Give every location a 'shiny' entry, set to None (no shiny objects here).
for loc in env['locations']:
    loc['shiny'] = None

# The rest of the notebook works on a list of environments; here it holds one.
environments = [env]

In [110]:
def get_location(env, walk):
    """
    Pick the location dictionary for the next step of a walk.

    An empty walk starts at a uniformly random location index. Otherwise the
    next location is sampled from the transition vector of the action that was
    chosen in the last step of the walk.

    Args:
        env: Environment dictionary; this function reads ``env['n_locations']``
            and ``env['locations']``.
        walk: List of steps taken so far, each ``[location, observation, action]``.
            The last step's location must still carry its ``'actions'`` list,
            because the transition vector is read from it.

    Returns:
        The entry of ``env['locations']`` at the sampled index.

    Raises:
        IndexError: If no cumulative transition probability exceeds the random
            draw (for example an all-zero transition vector).
    """
    # First step: start at random location
    if not walk:
        new_location = np.random.randint(env['n_locations'])
    # Any other step: get new location from previous location and action
    else:
        prev_location = walk[-1][0]
        prev_action_chosen = walk[-1][2]
        transition = prev_location['actions'][prev_action_chosen]['transition']

        # Inverse-CDF sampling: the first index whose cumulative probability
        # exceeds a uniform draw is the index (into env['locations']) of the
        # location we move to. For a one-hot transition vector this is simply
        # the position of the 1.
        reachable = np.flatnonzero(np.cumsum(transition) > np.random.rand())
        new_location = int(reachable[0])
    # Return the location dictionary of the new location.
    return env['locations'][new_location]

def get_observation(env, new_location):
    """
    Build the one-hot observation tensor for a location.

    Args:
        env: Environment dictionary; ``env['n_observations']`` sets the length
            of the one-hot vector.
        new_location: Location dictionary; ``new_location['observation']`` is
            the index that is set to 1.

    Returns:
        A 1-D ``torch.float`` tensor of length ``env['n_observations']``.
    """
    observation_id = new_location['observation']
    # Row `observation_id` of the identity matrix is the one-hot encoding.
    identity = np.eye(env['n_observations'])
    one_hot = identity[observation_id]
    # Convert to a float tensor and keep it flat, with one entry per observation.
    as_tensor = torch.tensor(one_hot, dtype=torch.float)
    return as_tensor.view((one_hot.shape[0]))

def get_action(env, new_location, walk, repeat_bias_factor=2):
    """
    Sample the action to take at ``new_location``.

    The policy is built from the ``'probability'`` of each action of the
    location. If the walk is non-empty and the previous step was at a different
    location (i.e. the previous action was a move), the probability of repeating
    that previous action is multiplied by ``repeat_bias_factor`` so that walks
    tend to go in straight lines. The policy is then renormalised and sampled.

    Args:
        env: Environment dictionary. Not read by this function; kept so the
            signature matches the other step helpers.
        new_location: Location dictionary with ``'id'`` and ``'actions'``.
        walk: List of previous steps, each ``[location, observation, action]``.
        repeat_bias_factor: Multiplier applied to the previous action's
            probability when the previous action moved the agent.

    Returns:
        The index of the sampled action, as a Python ``int``.

    Raises:
        IndexError: If no cumulative probability exceeds the random draw
            (for example when every action has probability 0).
    """
    # Force a float array: with integer probabilities an in-place multiply by a
    # fractional bias factor would otherwise be truncated back to an integer.
    policy = np.array(
        [action['probability'] for action in new_location['actions']],
        dtype=float,
    )

    # Bias towards repeating the previous action, only if this is not the
    # first step and the previous action actually changed the location.
    previous_action_moved = (
        len(walk) > 0 and new_location['id'] != walk[-1][0]['id']
    )
    if previous_action_moved:
        policy[walk[-1][2]] *= repeat_bias_factor

    # Renormalise (unavailable actions have probability 0 and stay 0). An
    # all-zero policy is left untouched to avoid dividing by zero.
    total = policy.sum()
    if total > 0:
        policy = policy / total

    # Inverse-CDF sampling: first action whose cumulative probability exceeds
    # a uniform draw.
    draw = np.random.rand()
    return int(np.flatnonzero(np.cumsum(policy) > draw)[0])

def walk_default(env, walk, walk_length, repeat_bias_factor=2):
    """
    Extend ``walk`` in place until it contains ``walk_length`` steps.

    Each new step is ``[location, observation, action]``, produced by
    ``get_location``, ``get_observation`` and ``get_action`` in that order.
    A walk that already has ``walk_length`` or more steps is returned unchanged.

    Args:
        env: Environment dictionary passed through to the step helpers.
        walk: List of existing steps; new steps are appended to this list.
        walk_length: Target number of steps.
        repeat_bias_factor: Bias for repeating the previous action, forwarded
            to ``get_action``.

    Returns:
        The same ``walk`` list that was passed in.
    """
    # Finish the provided walk until it contains walk_length steps
    for _ in range(walk_length - len(walk)):
        # Get new location based on previous action and location
        new_location = get_location(env, walk)
        # Get new observation at new location
        new_observation = get_observation(env, new_location)
        # Get new action based on policy at new location; forward the bias
        # factor so the caller's value is used rather than get_action's default
        new_action = get_action(env, new_location, walk, repeat_bias_factor)
        # Append location, observation, and action to the walk
        walk.append([new_location, new_observation, new_action])
    # Return the final walk
    return walk

def generate_walks(env, walk_length=10, n_walk=100, repeat_bias_factor=2, shiny=False):
    """
    Generate ``n_walk`` walks on ``env``.

    Each walk is a list of ``[location, observation, action]`` steps. Prints
    ``walk_length`` and ``n_walk`` before generating.

    Args:
        env: Environment dictionary passed to ``walk_default``.
        walk_length: Number of steps per walk.
        n_walk: Number of walks to generate.
        repeat_bias_factor: Passed to ``walk_default``.
        shiny: When exactly ``False``, walks follow the default policy.
            NOTE: for any other value no policy is applied here (the shiny
            policy is not implemented in this notebook), so each walk stays
            empty.

    Returns:
        A list with ``n_walk`` walks.
    """
    print(f'{walk_length=}')
    print(f'{n_walk=}')
    use_default_policy = shiny is False
    all_walks = []
    for _ in range(n_walk):
        steps = []
        if use_default_policy:
            steps = walk_default(env, steps, walk_length, repeat_bias_factor)
        # Slim down every location except the one in the final step: keep only
        # its 'id' and 'shiny' entries. The final step keeps its full
        # location dictionary.
        for entry in steps[:-1]:
            full_location = entry[0]
            entry[0] = {'id': full_location['id'], 'shiny': full_location['shiny']}
        all_walks.append(steps)
    return all_walks

In [111]:
# Make a single walk for each environment. The walk length is a random multiple
# of n_rollout, so that world switches are de-synchronised across environments.
params = {
    # Number of steps to roll out before backpropagation through time
    'n_rollout': 2,
    # Lower limit of the window from which the walk-length multiplier is drawn
    'walk_it_min': 5,
    # Upper limit of that window
    'walk_it_max': 8,
}
walks = []
for environment in environments:
    # np.random.randint excludes its upper bound, so the multiplier is drawn
    # from walk_it_min .. walk_it_max - 1.
    n_steps = params['n_rollout'] * np.random.randint(
        params['walk_it_min'], params['walk_it_max']
    )
    # generate_walks returns a list of walks; one is requested, keep only that.
    walks.append(generate_walks(environment, n_steps, 1)[0])

walk_length=14
n_walk=1


In [112]:
len(walks), type(walks)

(1, list)

In [113]:
len(walks[0]), type(walks[0])

(14, list)

In [114]:
walks[0][0]

[{'id': 5, 'shiny': None},
 tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 3]

In [115]:
walks[0][1]

[{'id': 10, 'shiny': None},
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1.]),
 3]

In [116]:
walks[0][2]

[{'id': 15, 'shiny': None},
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 3]

In [117]:
walks[0][-2]

[{'id': 11, 'shiny': None},
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 2]

In [118]:
walks[0][-1]

[{'id': 12,
  'observation': 6,
  'x': 0.5,
  'y': 0.5,
  'in_locations': [7, 11, 12, 13, 17],
  'in_degree': 5,
  'out_locations': [7, 11, 12, 13, 17],
  'out_degree': 5,
  'actions': [{'id': 0,
    'transition': [0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     1,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0],
    'probability': 0.2},
   {'id': 1,
    'transition': [0,
     0,
     0,
     0,
     0,
     0,
     0,
     1,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0],
    'probability': 0.2},
   {'id': 2,
    'transition': [0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     1,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0,
     0],
    'probability': 0.2},
   {'id': 3,
    'transition': [0,
     0,
     0,
     0,

In [119]:
# Number of training iterations; one chunk is built per iteration
params['train_it'] = 5

# Build the chunk that would be fed to TEM in each backprop iteration
for i in range(params['train_it']):
    # Start every iteration with an empty chunk
    chunk = []
    # Take the first n_rollout steps off the front of every environment's walk
    for env_i, walk in enumerate(walks):
        for step in range(params['n_rollout']):
            popped = walk.pop(0)
            if len(chunk) >= params['n_rollout']:
                # Chunk already has an entry per rollout step (a previous
                # environment filled it): append each component to the list
                # kept for that step and component.
                for comp_i, comp in enumerate(popped):
                    chunk[step][comp_i].append(comp)
            else:
                # First environment: create the entry for this step, wrapping
                # each component (g, x, a) in its own single-item list.
                chunk.append([[comp] for comp in popped])

In [120]:
chunk, len(chunk), type(chunk) # chunk.append([[comp] for comp in walk.pop(0)])

([[[{'id': 15, 'shiny': None}],
   [tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0.])],
   [0]],
  [[{'id': 15, 'shiny': None}],
   [tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0.])],
   [2]]],
 2,
 list)

In [121]:
chunk[0], len(chunk[0]), type(chunk[0]) # [[comp] for comp in walk.pop(0)]

([[{'id': 15, 'shiny': None}],
  [tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0.])],
  [0]],
 3,
 list)

In [122]:
chunk[0][0], len(chunk[0][0]), type(chunk[0][0]) # [chunk]

([{'id': 15, 'shiny': None}], 1, list)

In [123]:
chunk[0][0][0], len(chunk[0][0][0]), type(chunk[0][0][0]) # chunk

({'id': 15, 'shiny': None}, 2, dict)

In [124]:
# NOTE: with a single environment each list below holds only one tensor, so
# stacking just adds a leading dimension of size 1.
# TODO: check that it actually has a purpose.

# Component 1 of every step is the list of observations (x), one per environment.
# Stack that list into a single tensor along a new first dimension for batch processing.
for step in chunk:
    print(f'{step}')
    step[1] = torch.stack(step[1], dim=0)

[[{'id': 15, 'shiny': None}], [tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.])], [0]]
[[{'id': 15, 'shiny': None}], [tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.])], [2]]


In [125]:
chunk[0][0][0], len(chunk[0][0][0]), type(chunk[0][0][0]) # chunk

({'id': 15, 'shiny': None}, 2, dict)

In [126]:
chunk[0][1][0], len(chunk[0][1][0]), type(chunk[0][1][0]) # chunk

(tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 45,
 torch.Tensor)

In [127]:
chunk

[[[{'id': 15, 'shiny': None}],
  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
  [0]],
 [[{'id': 15, 'shiny': None}],
  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
  [2]]]

# Making sense of the damn model

The TEM has two main parts:

- **Grid cells (g)** - learn abstract spatial structure that generalizes across environments
- **Place cells (p)** - bind specific sensory experiences to locations in the current environment

Key Functions to Understand:

1. `gen_g()` - Does "path integration" (moves grid cells based on action)
1. `inference()` - Figures out current place cell state from sensory input + grid cells
1. `generative()` - Predicts what observation should be seen given current state
1. `hebbian()` - Updates memory matrix M

MLP is just a simple Multi-Layer Perceptron (basic neural network).
It's a utility class that creates a 2-layer network: input > hidden > output.