<a href="https://colab.research.google.com/github/yasu-k2/multimodal-active-inference/blob/main/AudioVisualMaze.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pygame
# !pip install inferactively-pymdp

In [None]:
%cd /content

In [None]:
!git clone https://github.com/infer-actively/pymdp.git

In [None]:
%cd pymdp
!pip install -r requirements.txt
!pip install -e ./
%cd ..

In [None]:
!sed -i -e 's/actions\[factor\]\]/int(actions\[factor\])]/g' pymdp/pymdp/learning.py

In [None]:
!git clone https://github.com/yasu-k2/multimodal-active-inference.git

In [None]:
%cd multimodal-active-inference

In [None]:
# !git pull origin main

## Imports

In [None]:
import copy
import itertools

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import pymdp
from pymdp import maths
from pymdp import utils
from pymdp.agent import Agent

from AudioVisualMazeEnv import print_maze, fn_create_maze, create_maze, AudioVisualMazeEnv

In [None]:
%load_ext autoreload
%autoreload 2

## Audio Visual Maze Setup

In [None]:
# Configuration for the audio-visual maze; every value below is forwarded to
# AudioVisualMazeEnv as a keyword argument of the same (or similar) name.
maze_dir = './'  # forwarded as DIR — presumably where maze files live; confirm in AudioVisualMazeEnv
n_div_sound = 7  # also reused later as the number of sound observation levels
maze_height = 7
maze_width = 7
create = True  # forwarded as `create` — presumably builds a new maze; TODO confirm
seed = 4

# assert (maze_width < 32), ValueError('`maze_width` must be smaller than 32')
env = AudioVisualMazeEnv(n_div_sound=n_div_sound, maze_height=maze_height, maze_width=maze_width,
                         create=create, seed=seed, start_pos=(1,1), end_pos='LowerRight', DIR=maze_dir)
# Deep copy so the layout used to build the agent's model is decoupled from
# any later change to env.maze_array.
maze_array = copy.deepcopy(env.maze_array)

In [None]:
# Action labels understood by env.step(); an action's position in this list is
# the integer action index used by the transition model and the agent.
actions = ["UP", "DOWN", "LEFT", "RIGHT", "STAY"]
n_actions = len(actions)

In [None]:
# Sanity check: drive the environment with uniformly random actions for T steps
# and print what it returns after each one.
T = 3
obs = env.reset()
# obs[0] is rendered with print_maze and obs[1] is printed as-is.
print('obs :', obs[1])
print_maze(obs[0])
for t in range(T):
  # NOTE(review): no np.random.seed call is visible in this notebook; unless
  # AudioVisualMazeEnv seeds the global NumPy RNG, this rollout differs per run.
  action_index = np.random.choice(len(actions))
  # action_index = 1
  action_label = actions[action_index]
  obs, rewards, done, info = env.step(action_label)
  print('=====Time {}====='.format(t+1))
  env.render()
  print(env.current_state)  # [Y, X, audio]
  print('obs :', obs[1])
  print_maze(obs[0])

## Active Inference Agent

In [None]:
def plot_grid(grid_locations, num_y=21, num_x=21):
  """Draw a heatmap whose cells are annotated with their linear state index.

  Args:
    grid_locations: sequence of (y, x) pairs; the position of a pair in the
      sequence is the number written into that cell.
    num_y: number of rows of the plotted grid.
    num_x: number of columns of the plotted grid.

  Cells that are not listed in `grid_locations` stay at 0.
  """
  index_map = np.zeros((num_y, num_x))
  for order, (row, col) in enumerate(grid_locations):
    index_map[row, col] = order
  sns.heatmap(index_map, annot=True, cbar=False, fmt='.0f', cmap='crest')

def plot_point_on_grid(state_vector, grid_locations, num_y=21, num_x=21):
  """Highlight on a heatmap the grid cell selected by `state_vector`.

  Args:
    state_vector: 1-D array over states; the first non-zero entry picks the state.
    grid_locations: sequence of (y, x) pairs indexed by state.
    num_y: number of rows of the plotted grid.
    num_x: number of columns of the plotted grid.

  Raises:
    IndexError: if `state_vector` has no non-zero entry.
  """
  active = np.where(state_vector)
  first_active = active[0][0]
  print(active)
  row, col = grid_locations[first_active]
  marker_map = np.zeros((num_y, num_x))
  marker_map[row, col] = 1.0
  sns.heatmap(marker_map, cbar=False, fmt='.0f')

def plot_likelihood(matrix, title_str=""):
  """Plot a column-normalized likelihood matrix as a heatmap.

  Args:
    matrix: 2-D array whose columns must each sum to 1.
    title_str: title of the figure.

  Raises:
    ValueError: if any column of `matrix` does not sum to 1.
  """
  if not np.isclose(matrix.sum(axis=0), 1.0).all():
    raise ValueError("Distribution not column-normalized.")
  sns.heatmap(matrix, cmap='OrRd', vmin=0.0, vmax=1.0)
  # Tick positions are derived from the matrix being plotted, not from any
  # notebook-global array, so the function works for every likelihood passed in.
  plt.xticks(range(matrix.shape[1]))
  plt.yticks(range(matrix.shape[0]))
  plt.title(title_str)

def plot_beliefs(belief_dist, title_str=""):
  """Plot a categorical belief distribution as a bar chart.

  Args:
    belief_dist: 1-D array of probabilities that must sum to 1.
    title_str: title of the figure.

  Raises:
    ValueError: if `belief_dist` does not sum to 1.
  """
  if not np.isclose(belief_dist.sum(), 1.0):
    raise ValueError("Distribution not normalized.")
  n_outcomes = belief_dist.shape[0]
  plt.grid(zorder=0)
  # x positions and bar heights are separate arguments of plt.bar; the styling
  # keywords belong to plt.bar as well, not to range().
  plt.bar(range(n_outcomes), belief_dist, color='r', zorder=3)
  plt.xticks(range(n_outcomes))
  plt.title(title_str)

In [None]:
def create_oracle_B_matrix(maze_array, grid_locations, actions):
  """Build the deterministic transition tensor for a fully known maze.

  Args:
    maze_array: 2-D array of the maze; cells equal to 1 are walls.
    grid_locations: sequence of (y, x) pairs; the position of a pair in the
      sequence is its state index.
    actions: sequence of action labels drawn from
      "UP", "DOWN", "LEFT", "RIGHT", "STAY".

  Returns:
    Array B of shape (n_states, n_states, n_actions) where
    B[next_state, curr_state, action_id] is 1.0 for the single successor of
    `curr_state` under that action and 0.0 elsewhere. A move that would leave
    the grid or enter a wall keeps the agent in place.

  Raises:
    ValueError: if an action label is unknown, or a successor location is not
      listed in `grid_locations`.
  """
  maze_height, maze_width = maze_array.shape
  n_states = len(grid_locations)
  # Sized from the `actions` argument rather than from a notebook-global.
  n_actions = len(actions)

  # (dy, dx) displacement for each supported action label.
  moves = {"UP": (-1, 0), "DOWN": (1, 0), "LEFT": (0, -1), "RIGHT": (0, 1), "STAY": (0, 0)}

  # Location -> state index; setdefault keeps the first occurrence, which is
  # what list.index would return.
  state_of = {}
  for state_index, location in enumerate(grid_locations):
    state_of.setdefault(tuple(location), state_index)

  B = np.zeros((n_states, n_states, n_actions))
  for action_id, action_label in enumerate(actions):
    if action_label not in moves:
      raise ValueError("Unknown action label: {!r}".format(action_label))
    dy, dx = moves[action_label]
    for curr_state, (y, x) in enumerate(grid_locations):
      next_y, next_x = y + dy, x + dx

      # Blocked by the grid boundary
      if not (0 <= next_y < maze_height and 0 <= next_x < maze_width):
        next_y, next_x = y, x

      # Blocked by wall
      if maze_array[next_y, next_x] == 1:
        next_y, next_x = y, x

      next_state = state_of.get((next_y, next_x))
      if next_state is None:
        raise ValueError("{!r} is not in grid_locations".format((next_y, next_x)))
      B[next_state, curr_state, action_id] = 1.0
  return B

In [None]:
# 0(path), 1(wall), 2(start), 3(goal)
def categorize_maze_obs(maze_array, gy, gx):
  """Encode which of the four neighbours of cell (gy, gx) are walls.

  Args:
    maze_array: 2-D array of the maze; cells equal to 1 are walls.
    gy: row of the cell.
    gx: column of the cell.

  Returns:
    int in [0, 15]: a bit mask with bit 0 = up, bit 1 = left, bit 2 = right,
    bit 3 = down; a bit is set when that neighbour is a wall. Neighbours that
    fall outside the array count as walls.

  Neighbours are looked up directly instead of slicing a padded 3x3 window,
  so mazes with a single row or column are handled as well.
  """
  maze_height, maze_width = maze_array.shape
  # Order fixes the bit position of each direction: up, left, right, down.
  neighbours = ((gy - 1, gx), (gy, gx - 1), (gy, gx + 1), (gy + 1, gx))
  maze_obs_index = 0
  for bit, (ny, nx) in enumerate(neighbours):
    outside = not (0 <= ny < maze_height and 0 <= nx < maze_width)
    if outside or maze_array[ny, nx] == 1:
      maze_obs_index += 2 ** bit
  return maze_obs_index

In [None]:
# Every (y, x) cell of the maze in row-major order; a cell's position in this
# list is its hidden-state / location-observation index.
grid_locations = [(row, col) for row in range(maze_height) for col in range(maze_width)]

starting_location = (1, 1)  # upper left
desired_location = (maze_height - 2, maze_width - 2)  # lower right

starting_state_index = grid_locations.index(starting_location)
desired_location_index = grid_locations.index(desired_location)

In [None]:
# Define state space
n_location_states = len(grid_locations)  # current position
# no `maze_state` for blocking(0) or non-blocking(1)
# first consider a known maze state. encoded in B matrix.
n_states = [n_location_states]
n_factors = len(n_states)
n_controls = [n_actions]
print(n_states, n_controls)

# Define observation space
n_location_observations = len(grid_locations)  # current position
# n_maze_observations = 2 ** (3 * 3)  # surrounding tiles
n_maze_observations = 2 ** 4  # surrounding tiles (4 ways)
n_sound_observations = n_div_sound  # sound cue intensity
n_reward_observations = 2  # reward or no reward
n_obs = [n_location_observations, n_maze_observations, n_sound_observations, n_reward_observations]
n_modalities = len(n_obs)
print(n_obs)

In [None]:
# Generate the A array
A = utils.obj_array(n_modalities)

A_location = np.zeros((n_location_observations, *n_states))
# print(A_location.shape)
A_location[:, :] = np.eye(n_location_states)
print(utils.is_normalized(A_location))
print(A_location)
A[0] = A_location

A_maze = np.zeros((n_maze_observations, *n_states))
for gli, gl in enumerate(grid_locations):
  gy, gx = gl
  maze_obs_index = categorize_maze_obs(maze_array, gy, gx)
  A_maze[maze_obs_index, gli] = 1.0
print(utils.is_normalized(A_maze))
print(A_maze)
A[1] = A_maze

A_sound = np.zeros((n_sound_observations, *n_states))
# print(env.bins)
for gli, gl in enumerate(grid_locations):
  sli = env.compute_sound(gl[0], gl[1])
  A_sound[sli, gli] = 1.0  # reliable
  # A_sound[:, gli] = maths.softmax(2. * utils.onehot(sli, n_div_sound))  # ambiguous
print(utils.is_normalized(A_sound))
print(A_sound)
A[2] = A_sound

A_reward = np.zeros((n_reward_observations, *n_states))
# The agent knows the rewarding location
A_reward[0, :] = 1.0
A_reward[0, desired_location_index] = 0.0
A_reward[1, desired_location_index] = 1.0
print(utils.is_normalized(A_reward))
print(A_reward)
A[3] = A_reward

In [None]:
# Generate the B array
B = utils.obj_array(n_factors)
B_location = create_oracle_B_matrix(maze_array, grid_locations, actions)  # perfect
# B_location = utils.random_B_matrix(n_states, n_controls)  # random
print(utils.is_normalized(B_location))
print(B_location)
B[0] = B_location

In [None]:
# Generate the C array
C = utils.obj_array_zeros(n_obs)

C_location = np.zeros(n_location_observations)
C_location[desired_location_index] = 1.0
# print(C_location)
# print(utils.is_normalized(C_location))

# C_maze = np.zeros(n_maze_observations)

C_sound = np.zeros(n_sound_observations)
# desired_sound_index = 0  # smallest sound
# C_sound[desired_sound_index] = 1.0
sound_pref = 1. - np.linspace(0., 1., n_div_sound)  # smaller the better
C_sound[:] = sound_pref
print(C_sound)
# print(utils.is_normalized(C_sound))

C_reward = np.zeros(n_reward_observations)
C_reward[1] = 5.0
print(C_reward)
# print(utils.is_normalized(C_reward))

# Choose your preference
# C[0] = C_location
# C[1] = C_maze
C[2] = C_sound
C[3] = C_reward

In [None]:
# Generate the D array
D = utils.obj_array(n_factors)

D_location = utils.onehot(starting_state_index, n_location_states)
# plot_point_on_grid(D_location, grid_locations)
print(utils.is_normalized(D_location))
print(D_location)
D[0] = D_location

In [None]:
def run_active_inference_loop(my_agent, my_env, T=3):
  """Run `my_agent` in `my_env` for T steps, printing each step.

  Every step renders the environment, assembles the four-modality observation
  [location index, wall code, sound, reward], lets the agent infer states and
  policies, samples an action and applies it to the environment.

  Relies on the notebook-level names `maze_array`, `grid_locations`,
  `actions`, `categorize_maze_obs` and `print_maze`.

  Args:
    my_agent: agent exposing infer_states, infer_policies and sample_action.
    my_env: environment exposing reset, render, step and current_state.
    T: number of steps to run.
  """
  obs = my_env.reset()
  reward = 0  # value passed for the reward modality before the first step
  for t in range(T):
    my_env.render()
    print('obs :', obs[1])
    print_maze(obs[0])

    # Assemble the agent's observation from the environment's current cell.
    cell = (my_env.current_state[0], my_env.current_state[1])
    surroundings = categorize_maze_obs(maze_array, *cell)
    sound_level = obs[1]
    agent_observation = [grid_locations.index(cell), surroundings, sound_level, reward]

    # Perception, planning, action selection.
    _beliefs = my_agent.infer_states(agent_observation)
    _q_pi, _efe = my_agent.infer_policies()
    sampled = my_agent.sample_action()
    chosen_action = actions[int(sampled[0])]

    obs, reward, _done, _info = my_env.step(chosen_action)
    print('=====Time {}====='.format(t+1))

    print(f'Action at time {t+1}: {chosen_action}')
    print(f'Reward at time {t+1}: {reward}')

In [None]:
# Settings shared by every Agent constructed below.
policy_len = 4  # forwarded to Agent(policy_len=...)
controllable_indices = [0]  # forwarded as control_fac_idx; 0 is the only (location) factor
action_selection = 'stochastic'  # 'deterministic', 'stochastic'

In [None]:
# Baseline agent: uses the hand-built A, B, C and D directly, with no
# Dirichlet priors and therefore no parameter learning.
my_agent = Agent(A=A, B=B, C=C, D=D,
                 policy_len=policy_len,
                 control_fac_idx=controllable_indices,
                 action_selection=action_selection)

In [None]:
# Single-step smoke test of the baseline agent; increase T for a longer episode.
T = 1
run_active_inference_loop(my_agent, env, T=T)

## ActInf with learning

Choose which matrices to learn. The cells below set up three agents: one that learns only A, one that learns only B, and one that learns both A and B.

In [None]:
# Dirichlet priors (pA, pB, pD) and randomly initialised generative models
# (A_gm_l, B_gm_l) for the learning agents constructed in the following cells.
# NOTE(review): pA and pB are derived from the hand-built A and B, while the
# learning agents receive the randomised A_gm_l / B_gm_l as their A / B —
# confirm this pairing is intended.

# Agent(A=A, pA=pA, modalities_to_learn=learnable_modalities, lr_pA=1.0)
#  agent.update_A(obs): update_obs_likelihood_dirichlet(pA, A, obs, qs)
pA = utils.dirichlet_like(A, scale=1.0)
# TODO: do something with pA
A_gm = utils.norm_dist_obj_arr(pA)  # only referenced by the commented-out "default" option below

# Agent(B=B, pB=pB, factors_to_learn=learnable_factors, lr_pB=1.0)
#   agent.update_B(qs_prev): update_state_likelihood_dirichlet(pB, B, actions, qs, qs_prev)
pB = utils.dirichlet_like(B, scale=1.0)
# TODO: do something with pB
B_gm = utils.norm_dist_obj_arr(pB)  # only referenced by the commented-out "default" option below

# Agent(D=D, pD=pD, factors_to_learn=learnable_factors, lr_pD=1.0)
#   agent.update_D(qs_t0=None): update_state_prior_dirichlet(pD, qs)
pD = utils.dirichlet_like(D, scale=1.0)
# TODO: do something with pD
D_gm = utils.norm_dist_obj_arr(pD)  # pD / D_gm are not passed to any agent in this notebook

# TODO: specify indices
# Modality 1 is the surrounding-tiles (maze) observation.
learnable_modalities = [1]
# Start from the hand-built A and replace only the learnable modalities with
# a random normalised likelihood.
A_gm_l = copy.deepcopy(A)
for lm in learnable_modalities:
  # print([n_obs[lm]] + n_states)
  A_gm_l[lm] = utils.norm_dist(np.random.rand(*[n_obs[lm]] + n_states))  # random
  # A_gm_l[lm] = A_gm[lm].copy()  # default

# TODO: specify indices
# Factor 0 is the location factor.
learnable_factors = [0]
# Same idea for B: only the learnable factors get random normalised transitions.
B_gm_l = copy.deepcopy(B)
for lf in learnable_factors:
  # print(n_states[lf], n_states[lf], n_controls[lf])
  B_gm_l[lf] = utils.norm_dist(np.random.rand(n_states[lf], n_states[lf], n_controls[lf]))  # random
  # B_gm_l[lf] = B_gm[lf].copy()  # default

In [None]:
# Learn only A
my_agent_a = Agent(A=A_gm_l, B=B, C=C, D=D,
                   pA=pA, lr_pA=1.0,
                   policy_len=policy_len,
                   control_fac_idx=controllable_indices,
                   action_selection=action_selection,
                   modalities_to_learn=learnable_modalities,
                   use_param_info_gain=True)

In [None]:
# Learn only B
my_agent_b = Agent(A=A, B=B_gm_l, C=C, D=D,
                   pB=pB, lr_pB=1.0,
                   policy_len=policy_len,
                   control_fac_idx=controllable_indices,
                   action_selection=action_selection,
                   factors_to_learn=learnable_factors,
                   use_param_info_gain=True)

In [None]:
# Learn A and B
my_agent_ab = Agent(A=A_gm_l, B=B_gm_l, C=C, D=D,
                    pA=pA, lr_pA=1.0, pB=pB, lr_pB=1.0,
                    policy_len=policy_len,
                    control_fac_idx=controllable_indices,
                    action_selection=action_selection,
                    modalities_to_learn=learnable_modalities,
                    factors_to_learn=learnable_factors,
                    use_param_info_gain=True)

In [None]:
def run_active_inference_loop_with_learning(my_agent, my_env, T=3, learn_A=False, learn_B=False):
  """Run `my_agent` in `my_env` for T steps with optional parameter learning.

  Every step assembles the four-modality observation
  [location index, wall code, sound, reward], infers states, optionally
  updates the agent's A and/or B, then plans, samples an action and steps
  the environment. The previous and current beliefs are printed each step.

  Relies on the notebook-level names `D_location`, `maze_array`,
  `grid_locations`, `actions` and `categorize_maze_obs`.

  Args:
    my_agent: agent exposing infer_states, update_A, update_B, infer_policies
      and sample_action.
    my_env: environment exposing reset, step and current_state.
    T: number of steps to run.
    learn_A: when true, call my_agent.update_A with the observation every step.
    learn_B: when true, call my_agent.update_B with the previous beliefs on
      every step after the first.
  """
  obs = my_env.reset()
  reward = 0  # value passed for the reward modality before the first step
  beliefs = [D_location]  # placeholder "previous belief" for the first step
  for t in range(T):
    # Assemble the agent's observation from the environment's current cell.
    cell = (my_env.current_state[0], my_env.current_state[1])
    surroundings = categorize_maze_obs(maze_array, *cell)
    sound_level = obs[1]
    agent_observation = [grid_locations.index(cell), surroundings, sound_level, reward]

    previous_beliefs = beliefs.copy()
    beliefs = my_agent.infer_states(agent_observation)
    if learn_A:
      my_agent.update_A(agent_observation)  # update A via pA
    print(previous_beliefs, beliefs)
    # B is only updated from the second step on (t > 0).
    if t > 0 and learn_B:
      my_agent.update_B(previous_beliefs)  # update B via pB

    _q_pi, _efe = my_agent.infer_policies()
    sampled = my_agent.sample_action()
    chosen_action = actions[int(sampled[0])]
    obs, reward, _done, _info = my_env.step(chosen_action)

In [None]:
# Single step with A learning enabled: update_A is called once, at t == 0.
T = 1
run_active_inference_loop_with_learning(my_agent_a, env, T=T, learn_A=True)

In [None]:
# NOTE(review): with T = 1 the loop only visits t == 0, and update_B is guarded
# by `t > 0`, so no B update happens in this run; use T >= 2 to exercise B learning.
T = 1
run_active_inference_loop_with_learning(my_agent_b, env, T=T, learn_B=True)

In [None]:
# NOTE(review): with T = 1 only the A update runs (once, at t == 0); update_B is
# guarded by `t > 0` and is skipped, so use T >= 2 to exercise B learning too.
T = 1
run_active_inference_loop_with_learning(my_agent_ab, env, T=T, learn_A=True, learn_B=True)