# T-Maze Demo

In this notebook, we simulate a T-Maze task, a type of two-armed contextual bandit, using an active inference (AIF) agent with `pymdp`. Note that, unlike earlier demos in the legacy version of `pymdp`, this demo relies on `pymdp`'s new `jax` backend. The `jax` backend accelerates the [`pymdp` package](https://pymdp-rtd.readthedocs.io/en/latest/), originally written in `numpy`, by dispatching the core inference, learning and planning computations to a CUDA-compatible GPU (if one is available). Combined with `jax`'s just-in-time (JIT) compilation, this lets `pymdp` take advantage of batch processing (i.e., running many AIF processes in parallel), with gains in speed and memory efficiency.

### The T-Maze Task

The T-Maze task implemented in this notebook is adapted from [the sophisticated inference paper](https://discovery.ucl.ac.uk/id/eprint/10124606/), and was originally introduced in an active inference context in "Active Inference and Epistemic Value". The T-maze is a two-armed contextual bandit: at any given time, the agent can choose between sampling a cue (context) or choosing between two reward arms (left vs. right). This task represents a classic problem in sequential decision-making, where an agent (in this case, analogized to a rat) must navigate a T-shaped maze. The agent starts at the centre of the T-maze. One of the two arms (left or right) contains a preferred (i.e., rewarding; cheese) stimulus and the other an aversive (i.e., punishing; shock) stimulus, and these reward contingencies are initially unknown to the agent. In the bottom arm of the T-Maze, a cue provides information about which arm contains the rewarding stimulus.

The agent faces a dilemma: commit to one of the potentially rewarding arms straight away, or first seek information from the cue to identify the more rewarding option before acting. We use the term "cue validity" to denote the probability that the cue correctly indicates the reward's location. If the cue carries information about which of the two arms is more rewarding (i.e., the cue validity is greater than 50%), then the optimal behavior entails first visiting the cue arm and then choosing one of the two reward arms. 
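To see why a cue validity above 50% makes the cue worth visiting, here is a minimal sketch of the Bayesian update over the reward location after a single cue observation. It assumes a flat prior and the cue coding used later in this notebook (1 = left arm cued, 2 = right arm cued); `posterior_reward_location` is a hypothetical helper, not part of `pymdp`.

```python
import numpy as np

def posterior_reward_location(cue_obs, cue_validity, prior=(0.5, 0.5)):
    """Posterior over reward location (0 = left, 1 = right) after one cue observation.

    cue_obs: 1 if the left arm is cued, 2 if the right arm is cued.
    """
    prior = np.asarray(prior, dtype=float)
    cued_arm = cue_obs - 1  # map cue observation 1/2 to arm 0/1
    # likelihood of this cue under each hypothesis about where the reward is
    likelihood = np.where(np.arange(2) == cued_arm, cue_validity, 1.0 - cue_validity)
    unnormalised = likelihood * prior
    return unnormalised / unnormalised.sum()

print(posterior_reward_location(cue_obs=1, cue_validity=0.9))  # informative cue
print(posterior_reward_location(cue_obs=1, cue_validity=0.5))  # uninformative cue
```

With a validity of 0.9 the posterior moves to 90/10 in favour of the cued arm, whereas a validity of exactly 0.5 leaves the flat prior unchanged, so visiting the cue would gain the agent nothing.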


### Overview
This notebook steps through the following:

1. A deterministic generative process (environment), and a single agent solving the task with vanilla active inference.
2. A noisy generative process, and a single agent solving the task with vanilla active inference.
3. A noisy generative model with A and B learning, with correct and incorrect prior structure of those parameters, and a single agent solving the task with vanilla active inference.
4. A deterministic generative process, and a single agent solving the task with sophisticated inference.
5. A deterministic generative process, and a single agent solving the task with inductive inference.
6. A deterministic generative process, and multiple agents solving the task with vanilla active inference.
7. A deterministic generative process, and multiple agents solving the task with sophisticated inference.


In [None]:
# a way to edit and run code and see the effects in the notebook without having to restart the kernel
%load_ext autoreload
%autoreload 2

# importing necessary libraries
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import mediapy

from jax import random as jr
from pymdp.envs import TMaze, rollout
from pymdp.agent import Agent



### Creating the T-Maze Task (Generative Process)

- The `batch_size` parameter specifies the number of environments to run in parallel.
- The `reward_condition` parameter determines the reward location: `0` for the left arm, `1` for the right arm, or `None` for random allocation.
- The `reward_probability` parameter sets the chance of receiving a reward in the correct arm. For example, if set at 0.80, there would be an 80% chance of reward and 20% chance of no outcome in the rewarding arm.
- The `punishment_probability` parameter specifies the likelihood of punishment in the other arm. For example, if set at 0.80, there would be an 80% chance of punishment and 20% chance of no outcome in the non-rewarding arm.
- The `cue_validity` parameter represents the accuracy of the cues as a probability between 0 and 1.
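The outcome parameters above can be made concrete with a small sketch of the outcome distribution for a single visit to one arm. It assumes the independent-outcomes case illustrated in the bullets (`dependent_outcomes=False`) and the outcome coding described below (0 = no outcome, 1 = reward, 2 = punishment); `arm_outcome_probs` is a hypothetical helper, not the `TMaze` implementation.

```python
import numpy as np

def arm_outcome_probs(arm, reward_arm, reward_probability, punishment_probability):
    """Probabilities of [no outcome, reward, punishment] when visiting `arm` (0 = left, 1 = right).

    Assumes independent outcomes: the rewarding arm yields reward or nothing,
    and the other arm yields punishment or nothing.
    """
    if arm == reward_arm:
        return np.array([1.0 - reward_probability, reward_probability, 0.0])
    return np.array([1.0 - punishment_probability, 0.0, punishment_probability])

rng = np.random.default_rng(0)
probs = arm_outcome_probs(arm=0, reward_arm=0, reward_probability=0.8, punishment_probability=0.8)
samples = rng.choice(3, size=10_000, p=probs)  # simulate 10,000 visits to the rewarding arm
print(probs, np.bincount(samples, minlength=3) / samples.size)
```

The empirical frequencies should land close to the 80% reward / 20% no-outcome split from the `reward_probability` example above.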

<details>
<summary> Click here to see how the generative process is set up. </summary>

#### States and Observations

**State Factors:**
1. Location (5 states):
    - 0: centre (start location)
    - 1: left arm
    - 2: right arm
    - 3: cue location (bottom arm)
    - 4: middle of arms (between left and right arm)
2. Reward Location (2 states):
    - 0: reward in left arm 
    - 1: reward in right arm

**Observation Modalities:**
1. Location (5 observations):
    - Matches the location states exactly
2. Outcome (3 observations):
    - 0: no outcome
    - 1: reward (cheese)
    - 2: punishment (shock)
3. Cue (3 observations):
    - 0: no cue
    - 1: left arm cued
    - 2: right arm cued

#### Environment Parameters

**Observation Likelihood Model (A):**
- A[0]: Location observations (5x5 tensor)
  - Perfect mapping between true and observed location.
- A[1]: Outcome observations (3x5x2 tensor)
  - In the more rewarding arm, reward is presented with a likelihood determined by the `reward_probability` parameter.
  - In the less rewarding arm, punishment is presented with a likelihood determined by the `punishment_probability` parameter.
  - No outcomes are observed in the centre/start location, cue location, or middle of the arms.
- A[2]: Cue observations (3x5x2 tensor)
  - At the cue location (bottom arm), the cue indicates the reward location, with accuracy set by the `cue_validity` parameter.
  - No cues visible elsewhere.

**Transition Model (B):**
- B[0]: Location transitions (5x5x5 tensor)
  - Agent can move between adjacent maze cells or stay in the same cell.
- B[1]: Reward location (2x2x1 tensor)
  - Reward location remains fixed throughout the trial.

**Initial Conditions (D):**
- D[0]: Starting location (5x1 tensor)
  - Agent always begins in the centre location
- D[1]: Reward placement (2x1 tensor)
  - Default: Equal chance (50/50) of reward in either arm (`reward_condition=None`)
  - Optional: Can fix reward to specific arm, by setting `reward_condition` to `0` (for left arm) or `1` (for right arm)

</details>
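As a rough illustration of the cue likelihood and initial conditions described above, the sketch below builds an unbatched version of A[2] and D with NumPy. It follows the text's description only; the real tensors come from `TMaze` (see tmaze.py) and carry a leading batch dimension, and `cue_likelihood` is a hypothetical name.

```python
import numpy as np

NUM_LOCATIONS, CUE_LOCATION = 5, 3

def cue_likelihood(cue_validity):
    """Hypothetical cue likelihood p(cue | location, reward location), shape (3, 5, 2)."""
    A_cue = np.zeros((3, NUM_LOCATIONS, 2))
    A_cue[0, :, :] = 1.0               # 'no cue' everywhere by default...
    A_cue[:, CUE_LOCATION, :] = 0.0    # ...except at the cue location (bottom arm)
    for reward_loc in range(2):
        cued_correct = 1 + reward_loc  # cue 1 = left arm cued, cue 2 = right arm cued
        cued_wrong = 2 - reward_loc
        A_cue[cued_correct, CUE_LOCATION, reward_loc] = cue_validity
        A_cue[cued_wrong, CUE_LOCATION, reward_loc] = 1.0 - cue_validity
    return A_cue

A_cue = cue_likelihood(cue_validity=0.9)
# every column p(. | location, reward location) must sum to 1
assert np.allclose(A_cue.sum(axis=0), 1.0)

# initial conditions: start in the centre (state 0), flat prior over the reward location
D_location = np.eye(NUM_LOCATIONS)[0]
D_reward = np.full(2, 0.5)
print(A_cue[:, CUE_LOCATION, :])
```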



In [None]:
# setting the parameters for the environment
batch_size = 1 # batch_size, which in this case corresponds to the number of environments to run in parallel
reward_condition = None # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 1.0 # 100% chance of reward in the correct arm
punishment_probability = 1.0 # 100% chance of punishment in the other arm
cue_validity = 1.0 # 100% valid cues
dependent_outcomes = False # if True, punishment occurs as a function of reward probability (i.e., if reward probability is 0.8, then 20% punishment). If False, punishment occurs with set probability (i.e., 20% no outcome and punishment will only occur in the other (non-rewarding) arm)

# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze( 
    batch_size=batch_size, 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition,
    dependent_outcomes=dependent_outcomes
)

# you may print the environment parameters to see the shapes of the tensors and the values by editing and uncommenting the following lines and running the code: 

# print([a.shape for a in env.params["A"]]) # shape of all A tensors; the shape should start with the batch_size, then the rows, columns, and additional dimensions for the dependencies
# print(env.params["A"][1][0][:,:,1]) # likelihood of observing no outcome, reward, or punishment (rows), in each location (columns), when the reward condition is 1 (right arm)
# print(env.params["A"][2][0][:,:,0]) # likelihood of observing no cue, left arm cued, or right arm cued (rows), in each location (columns), when the reward condition is 0 (left arm)

# print([b.shape for b in env.params["B"]]) # shape of all B tensors
# print(env.params["B"][0][0][:,:,4]) # probability of transitioning to each location (rows), from each location (columns), when the agent wants to move to the middle of the arms (location 4)

# 1. A deterministic generative process (environment), and a single agent solving the task with vanilla active inference.

### Creating the Agent (Generative Model)

We will create the agent's generative model based on the A and B tensors of the environment. The A and B tensors are copied unchanged because, in this simple design, we assume the agent knows the environment's parameters: it knows that the likelihood of observing a reward in the left arm is 1.0 (`reward_probability=1.0`) if the reward is actually in the left arm (`reward_condition=0`), it knows the cues are 100% accurate (`cue_validity=1.0`), and it knows that the reward location stays fixed throughout the trial (non-volatile environment). We can of course relax these assumptions to create a more complex agent; we do that in the next sections, where the environment becomes more stochastic and we add uncertainty to the agent's generative model, so that it has to learn whether the environment is deterministic.

The preference tensors (C) are shaped using the observation dimensions of the A tensors. The agent prefers reward and avoids punishment, and it has no preferences over location or cue observations.

The initial belief tensors (D, i.e., priors over initial states) are shaped using the state dimensions of the B tensors. The agent has a prior that it starts in the centre location, and no prior knowledge about the reward location, i.e., the prior over the reward location is uniform.
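Assuming `C` is read as unnormalised log-preferences (the usual convention in active inference), passing the outcome preferences through a softmax shows the distribution over outcomes the agent prefers to observe. The `softmax` helper here is a plain NumPy sketch, not a `pymdp` function.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last axis."""
    z = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

# outcome preferences set in the next cell: [no outcome, reward, punishment]
C_outcome = np.array([0.0, 2.0, -3.0])
preferred = softmax(C_outcome)
print(np.round(preferred, 3))
```

With these values the preferred outcome distribution is roughly 88% reward, 12% no outcome and under 1% punishment. All-zero preferences, as used for the location and cue modalities, give a uniform distribution, i.e., no preference.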

In [7]:
#  setting A tensors from the environment parameters
A = [jnp.array(a, dtype=jnp.float32) for a in env.params["A"]]
A_dependencies = env.dependencies["A"] # dependencies specify which state factors each observation modality depends on, so you don't have to store the conditional dependencies between all state factors and each modality

# setting B tensors from the environment parameters
B = [jnp.array(b, dtype=jnp.float32) for b in env.params["B"]]
B_dependencies = env.dependencies["B"]

# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A] 
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(2.0)    # prefer reward
C[1] = C[1].at[:,2].set(-3.0)   # avoid punishment


# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32) 
D_loc = D_loc.at[:, 0].set(1.0)  # set the centre location to 1.0 for every batch element
D.append(D_loc)

# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True)  # normalise to get uniform distribution
D.append(D_reward)


# initialising the agent
agent = Agent(
    A, B, C, D, 
    policy_len=2, # how long the action sequence is that the agent is evaluating
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    apply_batch=False,
    learn_A=False,
    learn_B=False
)

# you may print the agent's generative model parameters to see the shapes of the tensors and the values by editing and uncommenting the following lines and running the code: 

# print([a.shape for a in agent.A]) # shape of all A tensors
# print(agent.A[1][0][:,:,1]) # likelihood of observing no outcome, reward, or punishment (rows), in each location (columns), when the reward condition is 1 (right arm)
# print(agent.C[1]) # preferences for outcomes

### Running the active inference agent

In [9]:
key = jr.PRNGKey(0) # random key for the aif loop
T = 10 # number of timesteps to rollout the aif loop for
_, info, _ = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop

# you may print the info dictionary to see the numerical results of the aif agent completing the T-maze task by editing and uncommenting the following lines and running the code: 

# print(info.keys()) # keys in the info dictionary
# print(info["action"][:,0,:]) # actions taken by the agent (locations throughout the maze) 
# print(info["observation"][0]) # observations of the locations for each batch
# print(info["observation"][2]) # observations of the cues for each batch
# print(jnp.around(info["qs"][1], decimals=2)) # posterior beliefs about the reward location
# print(info["qpi"][0].shape) # shape of the policy tensor


### Rendering the Task to Visualise the Agent's Behaviour

In [None]:
frames = []
for t in range(info["observation"][0].shape[0]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][t, :, :],
        info["observation"][1][t, :, :],  
        info["observation"][2][t, :, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)

# # uncomment the following lines to save the video as a gif
# import os
# from PIL import Image
# os.makedirs("figures", exist_ok=True)
# pil_frames = [Image.fromarray(frame) for frame in frames]
# reward_location = "random" if reward_condition is None else ("left" if reward_condition == 0 else "right")
# filename = os.path.join("figures", f"tmaze_{batch_size}_{reward_location}.gif")
# pil_frames[0].save(
#     filename,
#     save_all=True,
#     append_images=pil_frames[1:],
#     duration=1000,  # 1000ms per frame
#     loop=0
# )

We can also run multiple agents in parallel to see how they solve the task.

In [None]:
batch_size = 9 # number of environments to run in parallel

# initialising the environment
env = TMaze( 
    batch_size=batch_size, 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition,
    dependent_outcomes=dependent_outcomes
)

# initialising the agent's generative model - we need to generate this again to be the same size as the environment's batch_size. 
#  setting A tensors from the environment parameters
A = [jnp.array(a, dtype=jnp.float32) for a in env.params["A"]]
A_dependencies = env.dependencies["A"] # dependencies specify which state factors each observation modality depends on, so you don't have to compute the full tensor over all state factors

# setting B tensors from the environment parameters
B = [jnp.array(b, dtype=jnp.float32) for b in env.params["B"]]
B_dependencies = env.dependencies["B"]

# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A] 
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(2.0)    # prefer reward
C[1] = C[1].at[:,2].set(-3.0)   # avoid punishment


# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32) 
D_loc = D_loc.at[:, 0].set(1.0)  # set the centre location to 1.0 for every batch element (not just the first)
D.append(D_loc)

# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True)  # normalise to get uniform distribution
D.append(D_reward)


# initialising the agent
agent = Agent(
    A, B, C, D, 
    policy_len=2,
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    apply_batch=False,
    learn_A=False,
    learn_B=False
)

_, info, _ = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop

# rendering the task to visualise the agent's behaviour 
frames = []
for t in range(info["observation"][0].shape[0]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][t, :, :],
        info["observation"][1][t, :, :],  
        info["observation"][2][t, :, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)
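With `batch_size > 1`, every row of each batched prior in D must be a valid probability distribution. The NumPy sketch below, where in-place indexing stands in for JAX's `.at[...].set(...)`, contrasts setting index `[0, 0]`, which only initialises the first environment, with setting `[:, 0]`, which initialises all of them. `rows_are_distributions` is a hypothetical helper you could use as a sanity check before building the `Agent`.

```python
import numpy as np

batch_size, num_locations = 3, 5

# setting only index [0, 0] fills in the first batch row and leaves the others all-zero
D_first_row_only = np.zeros((batch_size, num_locations))
D_first_row_only[0, 0] = 1.0

# setting [:, 0] puts every batch element in the centre location
D_all_rows = np.zeros((batch_size, num_locations))
D_all_rows[:, 0] = 1.0

def rows_are_distributions(D):
    """True if every batch row is non-negative and sums to 1."""
    return bool(np.all(D >= 0) and np.allclose(D.sum(axis=1), 1.0))

print(rows_are_distributions(D_first_row_only), rows_are_distributions(D_all_rows))
```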

# 2. A noisy generative process, and a single agent solving the task with vanilla active inference.

In [None]:
# THE GENERATIVE PROCESS (NOISY)
# setting the parameters for the environment
batch_size = 4 # number of environments to run in parallel
reward_condition = None # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 0.7 # 70% chance of reward in the correct arm
punishment_probability = 0.6 # 60% chance of punishment in the other arm
cue_validity = 0.9 # 90% valid cues
dependent_outcomes = False # if True, punishment occurs as a function of reward probability (i.e., if reward probability is 0.8, then 20% punishment). If False, punishment occurs with set probability (i.e., 20% no outcome and punishment will only occur in the other (non-rewarding) arm)

# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze( 
    batch_size=batch_size, 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition,
    dependent_outcomes=dependent_outcomes
)
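With these noisy parameters the cue is still worth visiting. A back-of-the-envelope calculation, assuming the agent either follows the cue or guesses an arm at random and ignoring the extra timestep the detour to the cue costs, compares the outcome probabilities of the two strategies. This is plain arithmetic on the parameters above, not output from `pymdp`.

```python
reward_probability, punishment_probability, cue_validity = 0.7, 0.6, 0.9

def outcome_probs(p_correct_arm):
    """[reward, punishment, no outcome] probabilities when the chosen arm is correct with prob p_correct_arm."""
    p_reward = p_correct_arm * reward_probability
    p_punish = (1 - p_correct_arm) * punishment_probability
    return p_reward, p_punish, 1 - p_reward - p_punish

follow_cue = outcome_probs(cue_validity)  # visit the cue first, then go where it points
guess = outcome_probs(0.5)                # skip the cue and pick an arm at random
print("follow cue:", [round(p, 3) for p in follow_cue])
print("guess:     ", [round(p, 3) for p in guess])
```

Following the cue nearly doubles the chance of reward (0.63 vs 0.35) and cuts the chance of punishment from 0.30 to 0.06.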


In [None]:
# THE GENERATIVE MODEL
#  setting A tensors from the environment parameters
A = [jnp.array(a, dtype=jnp.float32) for a in env.params["A"]]
A_dependencies = env.dependencies["A"] # dependencies specify which state factors each observation modality depends on, so you don't have to compute the full tensor over all state factors

# setting B tensors from the environment parameters
B = [jnp.array(b, dtype=jnp.float32) for b in env.params["B"]]
B_dependencies = env.dependencies["B"]

# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A] 
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(3.0)    # prefer reward
C[1] = C[1].at[:,2].set(-3.0)   # avoid punishment


# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32) 
D_loc = D_loc.at[:, 0].set(1.0)  # set the centre location to 1.0 for every batch element (not just the first)
D.append(D_loc)

# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True)  # normalise to get uniform distribution
D.append(D_reward)


# initialising the agent
agent = Agent(
    A, B, C, D, 
    policy_len=3, # how long the action sequence is that the agent is evaluating
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    apply_batch=False,
    learn_A=False,
    learn_B=False
)

key = jr.PRNGKey(0) # random key for the aif loop
T = 20 # number of timesteps to rollout the aif loop for
_, info, _ = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop


In [None]:
frames = []
for t in range(info["observation"][0].shape[0]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][t, :, :],
        info["observation"][1][t, :, :],  
        info["observation"][2][t, :, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)

# 3. A noisy generative model with A and B learning, with correct and incorrect prior structure of those parameters, and a single agent solving the task with vanilla active inference.

Here, you can toggle `learn_A` and `learn_B` between `True` and `False` when initialising the agent to run A learning only, B learning only, or A and B learning together. First, we give the agent's A and B tensors the correct prior structure (the environment's tensors with added noise). Then, we give them an incorrect prior structure (random tensors).

### Correct Prior Structure

In [None]:
# setting the parameters for the environment
batch_size = 1 # number of environments to run in parallel
reward_condition = 0 # 0 is reward in left arm, 1 is reward in right arm, None is random allocation
reward_probability = 1.0 # 100% chance of reward in the correct arm
punishment_probability = 1.0 # 100% chance of punishment in the other arm
cue_validity = 1.0 # 100% valid cues
dependent_outcomes = False

# initialising the environment. see tmaze.py in pymdp/envs for the implementation details.
env = TMaze( 
    batch_size=batch_size, 
    reward_probability=reward_probability,     
    punishment_probability=punishment_probability, 
    cue_validity=cue_validity,          
    reward_condition=reward_condition, 
    dependent_outcomes=dependent_outcomes
)

Next, we make the agent's generative model noisy, borrowing the structure of its tensors from the environment.
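The noise step in the next cell adds a constant to every entry and renormalises, which pulls each likelihood column towards a uniform distribution. A minimal NumPy sketch on a single illustrative column shows the effect as an increase in entropy; the column and the `entropy` helper are hypothetical, chosen only to mirror the `noise_level = 0.3` setting.

```python
import numpy as np

def entropy(p):
    """Shannon entropy (nats) of a discrete distribution, ignoring zero entries."""
    p = p[p > 0]
    return float(np.sum(p * np.log(1.0 / p)))

noise_level = 0.3
# one column of the outcome likelihood: in the rewarding arm, reward is certain
column = np.array([0.0, 1.0, 0.0])  # [no outcome, reward, punishment]
noisy = (column + noise_level) / (column + noise_level).sum()

print(np.round(noisy, 3), entropy(column), round(entropy(noisy), 3))
```

The noisy column still favours reward, but the agent now assigns substantial probability to the other outcomes, which is what gives learning something to correct.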

In [None]:
#  setting A tensors from the environment parameters
A = [jnp.array(a, dtype=jnp.float32) for a in env.params["A"]]
A_dependencies = env.dependencies["A"] 

# adding noise to flatten the distributions (make them more uncertain) 
noise_level = 0.3 
for i in [1, 2]:  # only modifying the outcome (i=1) and cue (i=2) observation likelihood mappings
    # A[i] = jnp.flip(A[i], axis=1) # flipping for testing purposes

    noise = noise_level * jnp.ones_like(A[i])
    A[i] = A[i] + noise
    A[i] = A[i] / jnp.sum(A[i], axis=1, keepdims=True) # normalise to ensure each distribution sums to 1

pA = A

# setting B tensors from the environment parameters
B = [jnp.array(b, dtype=jnp.float32) for b in env.params["B"]]
B_dependencies = env.dependencies["B"]

# B[0] = jnp.flip(B[0], axis=(1,2)) # flipping for testing purposes

# adding noise to flatten the distributions (make more uncertain) 
# key, subkey = jr.split(key)
# noise = noise_level * jr.uniform(subkey, shape=B[0].shape)
# B[0] = B[0] + noise
# B[0] = B[0] / jnp.sum(B[0], axis=1, keepdims=True) # normalise to ensure each distribution sums to 1

key, subkey = jr.split(key)
noise = noise_level * jr.uniform(subkey, shape=B[1].shape)
# B[1] = jnp.flip(B[1], axis=(1,2)) # flipping for testing purposes
B[1] = B[1] + noise
B[1] = B[1] / jnp.sum(B[1], axis=1, keepdims=True) # normalise to ensure each distribution sums to 1

pB = B

# creating C tensors filled with zeros for [location], [reward], [cue] based on A shapes
C = [jnp.zeros((batch_size, a.shape[1]), dtype=jnp.float32) for a in A] 
# setting preferences for outcomes only
C[1] = C[1].at[:,1].set(2.0)    # prefer reward
C[1] = C[1].at[:,2].set(-3.0)   # avoid punishment


# creating D tensors [location], [reward] based on B shapes
D = []
# D[0]: location - all zeros except location 0 (centre) because the agent always starts in the centre
D_loc = jnp.zeros((batch_size, B[0].shape[1]), dtype=jnp.float32) 
D_loc = D_loc.at[:, 0].set(1.0)  # set the centre location to 1.0 for every batch element
D.append(D_loc)

# D[1]: reward location - uniform distribution
D_reward = jnp.ones((batch_size, B[1].shape[1]), dtype=jnp.float32) 
D_reward = D_reward / jnp.sum(D_reward, axis=1, keepdims=True)  # normalise to get uniform distribution
D.append(D_reward)


# initialising the agent
agent = Agent(
    A, B, C, D, 
    pA=pA, # Dirichlet prior over the noisy A tensors (used for A learning)
    pB=pB, # Dirichlet prior over the noisy B tensors (used for B learning)
    policy_len=5, # how long the action sequence is that the agent is evaluating
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    apply_batch=False, 
    learn_A=False,
    learn_B=True,
    gamma=0.1,
    action_selection="stochastic"
)
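`gamma` is the policy precision. Assuming the standard form q(π) = softmax(−γG) with a flat prior over policies, a small value such as `gamma=0.1` flattens the policy posterior, so together with `action_selection="stochastic"` the agent's choices become closer to random. The sketch below uses made-up expected free energies G to illustrate this; it is not `pymdp`'s internal code, and `policy_posterior` is a hypothetical name.

```python
import numpy as np

def policy_posterior(G, gamma):
    """q(pi) = softmax(-gamma * G), assuming a flat prior over policies."""
    logits = -gamma * np.asarray(G, dtype=float)
    z = np.exp(logits - logits.max())
    return z / z.sum()

G = [1.0, 3.0, 5.0]  # illustrative expected free energies (lower is better)
for gamma in (0.1, 1.0, 16.0):
    print(gamma, np.round(policy_posterior(G, gamma), 3))
```

At `gamma=0.1` no policy gets even half the probability mass, while at `gamma=16.0` almost all of it goes to the policy with the lowest G.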

In [None]:
key = jr.PRNGKey(0) # random key for the aif loop
T = 10 # number of timesteps to rollout the aif loop for
_, info, _ = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop

Printing the tensors below shows how the agent's A and B tensors change from the first to the last timestep, moving closer to the environment's A and B tensors. Note that with `learn_A=False` in the cell above only B is updated; set `learn_A=True` to see A change as well.
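A and B learning in active inference is usually framed as accumulating Dirichlet counts. The sketch below shows the standard count update for a likelihood over a single state factor, assuming a learning rate of 1 and no forgetting; it illustrates the general technique rather than `pymdp`'s exact update, and `update_dirichlet_A` is a hypothetical name. Because `pA = A` gives concentration parameters that sum to one per column, even a handful of observations quickly outweighs the prior.

```python
import numpy as np

def update_dirichlet_A(pA, obs_index, qs, lr=1.0):
    """One Dirichlet count update for a likelihood over a single state factor.

    pA: (num_obs, num_states) concentration parameters
    obs_index: index of the observed outcome
    qs: (num_states,) posterior over hidden states
    """
    o = np.eye(pA.shape[0])[obs_index]                  # one-hot observation
    pA_new = pA + lr * np.outer(o, qs)                  # add counts for (observed outcome, believed state)
    A_new = pA_new / pA_new.sum(axis=0, keepdims=True)  # expected likelihood under the Dirichlet
    return pA_new, A_new

# weak prior: each column is flat and its counts sum to 1 (like pA = A above)
pA = np.full((3, 2), 1.0 / 3.0)
qs = np.array([1.0, 0.0])  # the agent is certain it is in state 0
for _ in range(5):         # observe 'reward' (index 1) five times in state 0
    pA, A = update_dirichlet_A(pA, obs_index=1, qs=qs)
print(np.round(A, 3))
```

Only the column for the believed state moves; B learning follows the same pattern, with counts added to the outer product of consecutive state posteriors.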

In [None]:
print("the environment's A tensor")
print(env.params["A"][1][0][:,:,1])
print()
print("the agent's A tensor at t=0")
print(agent.A[1][0][:,:,1])
print()
print(f"the agent's A tensor at t={T}")
print(info["agent"].A[1][-1,0,:,:,1])

In [None]:
print("the agent's location B tensor (action 1) at t=0")
print(info["agent"].B[0][0,0,:,:,1])


In [None]:
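# B[1] (reward location) has a single action, so the last index must be 0
print("the environment's reward-location B tensor")
print(env.params["B"][1][0][:,:,0])
print()
print("the agent's reward-location B tensor at t=0")
print(info["agent"].B[1][0,0,:,:,0])
print()
print(f"the agent's reward-location B tensor at t={T}")
print(info["agent"].B[1][-1,0,:,:,0])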

In [None]:
frames = []
for t in range(info["observation"][0].shape[0]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][t, :, :],
        info["observation"][1][t, :, :],  
        info["observation"][2][t, :, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)

### Incorrect Prior Structure

In [None]:
# Random initialization test
# key = jr.PRNGKey(42)
key = jr.PRNGKey(24)

#  setting A tensors from the environment parameters
A = [jnp.array(a, dtype=jnp.float32) for a in env.params["A"]]
A_dependencies = env.dependencies["A"] 

# replacing the outcome and cue likelihoods with random distributions (incorrect prior structure)
for i in [1, 2]:  # only modifying the outcome (i=1) and cue (i=2) observation likelihood mappings
    # A[i] = jnp.flip(A[i], axis=1) # flipping for testing purposes

    key, subkey = jr.split(key)
    A[i] = jr.uniform(subkey, shape=A[i].shape) # generating random values between 0 and 1

    A[i] = A[i] / jnp.sum(A[i], axis=1, keepdims=True) # normalise to ensure each distribution sums to 1

pA = A



# setting B tensors from the environment parameters
B = [jnp.array(b, dtype=jnp.float32) for b in env.params["B"]]
B_dependencies = env.dependencies["B"]

# B[0] = jnp.flip(B[0], axis=(1,2)) # flipping for testing purposes

# replacing both B tensors with random distributions (incorrect prior structure)
key, subkey = jr.split(key)
B[0] = jr.uniform(subkey, shape=B[0].shape)
B[0] = B[0] / jnp.sum(B[0], axis=1, keepdims=True) # normalise to ensure each distribution sums to 1

key, subkey = jr.split(key)
B[1] = jr.uniform(subkey, shape=B[1].shape)
B[1] = B[1] / jnp.sum(B[1], axis=1, keepdims=True) # normalise to ensure each distribution sums to 1

pB = B


# initialising the agent
agent = Agent(
    A, B, C, D, 
    pA=pA, # Dirichlet prior over the random A tensors (used for A learning)
    pB=pB, # Dirichlet prior over the random B tensors (used for B learning)
    policy_len=5, # how long the action sequence is that the agent is evaluating
    A_dependencies=A_dependencies, 
    B_dependencies=B_dependencies,
    apply_batch=False, 
    learn_A=True,
    learn_B=True,
    gamma=0.1,
    action_selection="stochastic"
)

# running the active inference simulation
key = jr.PRNGKey(0) # random key for the aif loop
T = 50 # number of timesteps to rollout the aif loop for
_, info, _ = rollout(agent, env, num_timesteps=T, rng_key=key) # running the aif loop

In [None]:
print("the environment's A tensor")
print(env.params["A"][1][0][:,:,1])
print()
print("the agent's A tensor at t=0")
print(agent.A[1][0][:,:,1])
print()
print(f"the agent's A tensor at t={T}")
print(info["agent"].A[1][-1,0,:,:,1])

In [None]:
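# B[1] (reward location) has a single action, so the last index must be 0
print("the environment's reward-location B tensor")
print(env.params["B"][1][0][:,:,0])
print()
print("the agent's reward-location B tensor at t=0")
print(info["agent"].B[1][0,0,:,:,0])
print()
print(f"the agent's reward-location B tensor at t={T}")
print(info["agent"].B[1][-1,0,:,:,0])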

In [None]:
frames = []
for t in range(info["observation"][0].shape[0]):  # iterate over timesteps
    # get observations for this timestep
    observations_t = [
        info["observation"][0][t, :, :],
        info["observation"][1][t, :, :],  
        info["observation"][2][t, :, :]   
    ]
       
    frame = env.render(mode="rgb_array", observations=observations_t) # render the environment using the observations for this timestep
    frame = np.asarray(frame, dtype=np.uint8)
    plt.close()  # close the figure to prevent memory leak
    frames.append(frame)

frames = np.array(frames, dtype=np.uint8)
mediapy.show_video(frames, fps=1)