In [3]:
import numpy as np
import scipy.stats as stats
from scipy.special import digamma, softmax
import matplotlib.pyplot as plt 
import seaborn as sns
from itertools import product

# Aim

This notebook is meant as a partial re-implementation in numpy of the model presented in the paper "Active inference on discrete state-spaces: A synthesis" by Da Costa et al., which can be found [here](https://arxiv.org/pdf/2001.07203.pdf). The paper is really great, but it's pretty difficult at points, so to help me understand the harder passages I decided to take notes and organize them here. Some bits that I found difficult are explained more slowly here than they are in the paper. In other parts I just take the paper's results as given.

I am doing this because I had a bit of free time in the last few days, but I will get busy again soon. So if somebody wants to expand this, please let me know!

## Generative model

Let's start by reviewing the components of the model:

- $s_\tau$: a one-hot vector modelling the state at time $\tau$. The vector is 1 on the index of the state. The number of components is $m$.
- $o_\tau$: like $s_\tau$, but for observations. Number of components is $n$.
- $\pi$ is a policy, which is just a vector whose $i^{th}$ element, $\pi_i$, contains the action to be performed at timestep $i$.
- $A$: A likelihood matrix with shape $n \times m$. This is used in the generative model to model how states generate observations. Each column corresponds to a state, and each row to an observation. Each column is a probability vector over observations, describing the probabilities that the state corresponding to the column would produce each row/observation. $o_\tau^T A$ is the vector of the probabilities of $o_\tau$ being generated by each state. $A s_\tau$ is the vector with the probability of each observation given state $s_\tau$. $o_\tau^T A s_\tau$ is the probability that state $s_\tau$ would generate observation $o_\tau$. 
- $a$: a matrix encoding the hyperprior over likelihood matrices $A$. The $j^{th}$ column of $a$, $a_{\bullet, j}$, contains the $\alpha$ vector to parameterize a Dirichlet distribution. In practice, each column of $a$ contains the parameters to sample the vector of probabilities of each observation given the state corresponding to the column.
- $B$: Transition matrix between states, which depends on the agent's action. It is indexed by an action, e.g. $B_{\pi_\tau}$ (so really, it's a 3-d array, which at each action index has an $m \times m$ transition matrix). The $i^{th}$ column of $B_j$ contains the probabilities of transitioning from state $i$ to each possible state in the next time step, given action $j$. So $B_{\pi_{\tau-1}} s_{\tau - 1}$ is the vector of probabilities of transitioning to each state at time $\tau$, given the action indicated by the policy at time $\tau-1$, $\pi_{\tau-1}$, and the state at time $\tau-1$, $s_{\tau-1}$.
- $D$: Parameters for the categorical prior over the first state.
- $C$: Parameters for the categorical prior over states or outcomes, depending on how you want to specify the agent's preferences. It encodes prior probabilities that effectively work like preferences for states or sensory input in the free energy framework.
- $E$: Parameters for the categorical prior over policies.
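
To make the matrix products in the descriptions of $A$ and $B$ concrete, here is a small sketch with made-up numbers. The matrices and the chosen state, observation, and action are illustrative assumptions of mine, not taken from the paper:

```python
import numpy as np

# illustrative likelihood matrix: rows are observations (n=3), columns are states (m=2)
A = np.array([
    [0.7, 0.2],
    [0.2, 0.3],
    [0.1, 0.5],
])
# illustrative transition array with shape (n_actions, m, m):
# B[action][:, s] is the distribution over next states when starting from state s
B = np.array([
    [[0.9, 0.1],
     [0.1, 0.9]],
    [[0.5, 0.5],
     [0.5, 0.5]],
])

s = np.eye(2)[1]  # one-hot vector for state 1
o = np.eye(3)[2]  # one-hot vector for observation 2

obs_probs = A @ s            # probability of each observation given state 1
state_likelihoods = o @ A    # probability of observation 2 under each state
p_o_given_s = o @ A @ s      # probability that state 1 generates observation 2
next_state_probs = B[0] @ s  # distribution over the next state after action 0

print(obs_probs, state_likelihoods, p_o_given_s, next_state_probs)
```

Because $s$ and $o$ are one-hot, each product simply picks out a column, a row, or a single entry.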

And now the generative model in all its beauty, screenshotted from the paper:

<img src="img/generativeModelDescription.png" width="400">

<img src="img/generativeModelGraph.png" width="400">

The generative model tells us the joint probability of likelihood matrix ($A$), policy ($\pi$), states, and observations. It is called generative because we can use it to generate possible histories by successively sampling as described in the model:

In [4]:
def generate_input(n_states=2, n_actions=2, n_obs=2, T=10, policy_setting="essential"):
    
    if policy_setting == "exhaustive":
        # sample pi (node 1). This is not the most efficient implementation, but it is the clearest.
        # create set of possible policies, i.e. set of action indices of length T
        set_policies = list(product(range(n_actions), repeat=T))
        # create the hyperprior over policies
        E = np.random.dirichlet(alpha=[1]*len(set_policies))
        # sample the index of the policy to select from set_policies
        index_policy = np.random.choice(len(set_policies), p=E)
        pi = set_policies[index_policy]
    elif policy_setting == "essential":
        pi = np.random.randint(0, n_actions, size=T)

    # sample A (node 4)
    # First sample hyperparameters
    a = stats.halfnorm.rvs(size=(n_obs, n_states))
    # for each column in the hyperparameters matrix a, sample from a Dirichlet
    # (then transpose, because each Dirichlet sample creates a row instead of a column)
    A = np.array([np.random.dirichlet(a[:,i]) for i in range(a.shape[1])]).T

    # sample s_1 (node 5)
    D = np.random.dirichlet(alpha=[1]*n_states)

    # for simplicity, I am coding B as a 3-d array with dimensions (n_actions, n_states, n_states)
    # first I create an array with totally random elements >0 and then normalize over the 2nd dimension
    # (axis=1), so that for each action every column sums to 1: B[action, :, s] is the
    # probability vector over next states given current state s
    B = np.random.rand(n_actions, n_states, n_states)
    B /= B.sum(axis=1, keepdims=True)
    
    return n_states, n_obs, T, B, pi, A, D, a
    

def generate_history(n_states, n_obs, T, B, pi, A, D):

    s = np.random.choice(n_states, p=D)

    s_history, o_history= [], []
    # in a loop which models the time steps
    for i in range(T):
        s_history.append(s)
        
        # sample observation given A and state s
        o = np.random.choice(n_obs, p=A[:,s])
        o_history.append(o)
        
        # update state s_i given policy, B, and previous state
        s = np.random.choice(n_states, p=B[pi[i],:,s])

        
    return o_history, s_history

Example of a run:

In [5]:
model_input = generate_input()
o_history, s_history = generate_history(*model_input[:-1])
print(np.column_stack((o_history, s_history)))

[[1 1]
 [0 0]
 [1 0]
 [0 0]
 [1 1]
 [0 0]
 [0 0]
 [1 1]
 [1 1]
 [0 0]]


## Approximate posterior

Of course, it is easy to produce imagined histories from the generative model. What's difficult is to _invert_ the generative model. That means to observe data and create a posterior distribution for the values of the unobserved variables in the model.

The point of Variational Bayesian methods is to give us a way to estimate the posterior from the generative model and the data. And that's of course the tricky bit! So instead of computing the exact posterior of the generative model, we use another, simpler distribution $Q$ whose parameters we tune to make it as close as possible (in KL-divergence) to the unknown posterior. Before we see that, let me introduce a few more symbols:
- $\boldsymbol{\pi}$: A vector of probabilities that parameterize the categorical distribution from which policies $\pi$ are sampled.
- $\mathbf{s}$: This is a 3-d array which, for every combination of policy and time, gives us a vector of the probabilities of each state. The dimensions are (num policies, T, n_states).
- $\mathbf{a}$: Much like $a$ above, each column parameterizes a Dirichlet distribution corresponding to a state. A sample from the Dirichlet corresponding to a state is the parameters of a categorical distribution over observations.

And finally, the beautiful approximate posterior, which expresses the joint probability of the policy, the likelihood matrix $A$, and the history of states just in terms of a distribution over policies, a distribution over $A$, and a distribution over states for each timestep (_assuming a certain policy_). Basically, only the essentials:

<img src="img/approximatedPosterior.png" width="400">

## Perception

In glorious active inference tradition, perception is thought of as inference about hidden states from sensation. This is done by minimizing the free energy wrt the hidden states in the approximate posterior, for every policy. This gives the best approximation to the true posterior in terms of the hidden states for all timesteps. How to do this is derived in equations (5), (6), and (7) of the paper. What I found a little bit harder to understand is the passage from (6) to (7), so I am going to review it a bit.

Consider first the first summand transformed in terms of the sufficient parameters of $Q$, contained in $\mathbf{s}$:

\begin{align}
\sum_{\tau=1}^T \mathbb{E}_{Q(s_\tau \mid \pi)}\left[ \log Q\left( s_\tau \mid \pi \right) \right] & \implies \sum_{\tau=1}^T \mathbf{s}_{\pi \tau} \cdot \log \mathbf{s}_{\pi \tau}
\end{align}

$Q(s_\tau \mid \pi)$ is a distribution over the state at time $\tau$, and therefore ranges over states. $\mathbb{E}_{Q(s_\tau \mid \pi)}\left[ \bullet \right]$ is an expected value under the distribution over states. Since this is a discrete distribution, it can be calculated as a sum over states:

$$
\sum_{j=1}^{\text{num states}} Q(s_\tau = s_j \mid \pi) \bullet
$$

$\mathbf{s}_{\pi \tau}$ is a vector of the probabilities of each state given policy $\pi$ and time $\tau$. This is indeed the vector of probabilities that we want to sum over to get the expected value (after multiplying it element-wise by the $\bullet$). $\log$ is applied element-wise here. So, assume for instance that for some policy $\pi_1$ and time $\tau_1$:
$$
\mathbf{s}_{\pi_1 \tau_1} = \begin{bmatrix}
0.5 \\
0.25 \\
0.25
\end{bmatrix}
$$

Then (using base-2 logarithms to keep the arithmetic simple; the code below uses natural logarithms):

\begin{align}
& \mathbf{s}_{\pi_1 \tau_1} \cdot \log \mathbf{s}_{\pi_1 \tau_1} \\
= & \begin{bmatrix} 0.5 \\ 0.25 \\ 0.25 \end{bmatrix} \cdot \log_2 \left( \begin{bmatrix} 0.5 \\ 0.25 \\ 0.25 \end{bmatrix} \right) \\
= & \begin{bmatrix} 0.5 \\ 0.25 \\ 0.25 \end{bmatrix} \cdot \begin{bmatrix} -1 \\ -2 \\ -2 \end{bmatrix} \\ 
= & 0.5 \times -1 + 0.25 \times -2 + 0.25 \times -2 = -1.5 
\end{align}

Once we calculate this single number for all timesteps, we just sum over timesteps.
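
The worked example above can be checked numerically. This is a minimal sketch that uses `np.log2` to match the base-2 arithmetic of the example; the variable names are just for illustration:

```python
import numpy as np

s_pi_tau = np.array([0.5, 0.25, 0.25])

# dot product of the probability vector with its element-wise log (base 2)
neg_entropy_bits = s_pi_tau @ np.log2(s_pi_tau)
print(neg_entropy_bits)  # -1.5

# the same quantity with natural logarithms differs only by a factor of log(2)
neg_entropy_nats = s_pi_tau @ np.log(s_pi_tau)
print(np.isclose(neg_entropy_nats, neg_entropy_bits * np.log(2)))  # True
```

The choice of base only rescales the free energy, so it does not change where the minimum is.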

Now onto the second summand! The expectation is expressed in terms of the sufficient parameters of $Q$ as follows:

$$
\sum_{\tau=1}^t \mathbb{E}_{Q( s_\tau \mid \pi )Q(A)}\left[ \log P\left( o_\tau \mid s_\tau , A \right) \right]
\implies \sum_{\tau=1}^{t} o_\tau \cdot \mathrm{\textbf{log}}\mathbf{A s_{\pi\tau}}
$$

$o_\tau$ here is the observation that was in fact made at time $\tau$ - we are not calculating expectations over that. First, note that from the generative model above we know that (where, as above, $o_\tau$ and $s_\tau$ are one-hot vectors, and $A$ is the likelihood matrix, and $\log$ is applied component-wise):

$$
\log P\left( o_\tau \mid s_\tau , A \right) = o_\tau \cdot \log A s_\tau
$$

And therefore:

\begin{align}
& \mathbb{E}_{Q( s_\tau \mid \pi )Q(A)}\left[ \log P\left( o_\tau \mid s_\tau , A \right) \right] \\
= & \sum_s \sum_a \left[ Q(A=a) Q( s_\tau=s \mid \pi ) \log P\left( o_\tau \mid s, a \right) \right] \\
= &  \sum_s \left[ Q( s_\tau=s \mid \pi ) \sum_a \left[ Q(A=a) o_\tau \cdot \log a s \right] \right] \\
= & o_\tau \cdot \sum_s \left[ Q( s_\tau=s \mid \pi ) \sum_a \left[ Q(A=a) \log a \right] s \right] \\
= & o_\tau \cdot \sum_a \left[ Q(A=a) \log a \right] \sum_s \left[ Q( s_\tau=s \mid \pi )  s \right] \\
= & o_\tau \cdot \mathrm{\textbf{log}}\mathbf{A} \sum_s \left[ Q( s_\tau=s \mid \pi )  s \right] \\
= & o_\tau \cdot \mathrm{\textbf{log}}\mathbf{A} \mathbf{s_{\pi\tau}} \\
\end{align}


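$\mathrm{\textbf{log}}\mathbf{A}$ is defined as the expected value under $Q(A)$ of the logarithm of $A$, $\mathbb{E}_{Q(A)} \left[ \log(A) \right]$, and it is an $n \times m$ matrix. It is calculated with the digamma function $\psi$: since each column of $\mathbf{a}$ parameterizes a Dirichlet distribution, entry $(i, j)$ is $\psi(\mathbf{a}_{ij}) - \psi\left(\sum_k \mathbf{a}_{kj}\right)$. See the expected value of the log of a Dirichlet distribution on Wikipedia.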
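
As a sketch of that calculation: for a Dirichlet distribution with parameters $\alpha$, $\mathbb{E}[\log x_i] = \psi(\alpha_i) - \psi(\sum_k \alpha_k)$, applied here to each column of $\mathbf{a}$. The parameter values are made up, and the Monte Carlo comparison is only there as a sanity check:

```python
import numpy as np
from scipy.special import digamma

rng = np.random.default_rng(0)

# illustrative Dirichlet parameters: n=3 observations (rows), m=2 states (columns)
a_bold = np.array([
    [1.0, 4.0],
    [1.0, 2.0],
    [2.0, 1.0],
])

# analytic expectation of log A: digamma of each entry minus digamma of its column total
a_0 = a_bold.sum(axis=0, keepdims=True)
logA = digamma(a_bold) - digamma(a_0)

# Monte Carlo estimate for comparison: sample each column of A from its own Dirichlet
samples = np.stack(
    [rng.dirichlet(a_bold[:, j], size=100_000) for j in range(a_bold.shape[1])],
    axis=-1,
)  # shape (n_samples, n, m)
logA_mc = np.log(samples).mean(axis=0)

print(np.round(logA, 3))
print(np.round(logA_mc, 3))
```

The two printed matrices should agree up to sampling noise. Note that the entries of $\mathrm{\textbf{log}}\mathbf{A}$ are not the logs of the expected probabilities: they are always somewhat lower.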

The last thing to clarify for the second summand is why:

$$
\mathbf{s_{\pi\tau}} = \sum_s \left[ Q( s_\tau=s \mid \pi )  s \right]
$$

To see why, recall that $s$ is not a number, but rather a one-hot vector. Therefore, the sum is a sum over one-hot vectors, where the 1 in each vector becomes the probability of the respective state.
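
A tiny numerical illustration of that sum, with made-up probabilities for three states (the names are only for this sketch):

```python
import numpy as np

# illustrative probabilities Q(s_tau = s | pi) for three states
q = np.array([0.5, 0.25, 0.25])

# the one-hot vectors for the three states are the rows of the identity matrix
one_hots = np.eye(3)

# weighted sum of one-hot vectors: each probability lands in its own component
s_pi_tau = sum(q[j] * one_hots[j] for j in range(3))
print(s_pi_tau)
```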

Now, onto the third summand! The identity is as follows:

\begin{align}
\mathbb{E}_{Q(s_1 \mid \pi)} \left[ \log P(s_1) \right] & = \sum_s \left[ Q(s_1=s \mid \pi) \log P(s) \right] \\
& = \mathbf{s_{\pi 1}} \cdot \log D
\end{align}

And finally, the fourth summand:

$$
\sum_{\tau=2}^T \mathbb{E}_{Q(s_\tau \mid \pi) Q(s_{\tau-1} \mid \pi)}\left[ \log P (s_\tau \mid s_{\tau-1}, \pi) \right] \implies \sum_{\tau=2}^T \mathbf{s_{\pi \tau}} \cdot \log \left( B_{\pi_{\tau-1}} \right) \mathbf{s_{\pi \tau-1}}
$$

Okay, this one looks a little bit nasty. Let's see:

\begin{align}
\mathbb{E}_{Q(s_\tau \mid \pi) Q(s_{\tau-1} \mid \pi)}\left[ \log P (s_\tau \mid s_{\tau-1}, \pi) \right] & = 
\sum_{s_i} \sum_{s_j}\left[ Q(s_\tau = s_i \mid \pi) Q(s_{\tau-1} = s_j \mid \pi)  \log P (s_i \mid s_j, \pi) \right] \\
& = \sum_{s_i} \sum_{s_j} \left[ Q(s_\tau = s_i \mid \pi) Q(s_{\tau-1} = s_j \mid \pi) s_i \cdot \log \left( B_{\pi_{\tau-1}} \right) s_j \right] \\
& = \sum_{s_i} \left[ Q(s_\tau = s_i \mid \pi) s_i \right] \cdot \log \left( B_{\pi_{\tau-1}} \right) \sum_{s_j} \left[ Q(s_{\tau-1} = s_j \mid \pi) s_j \right]  \\
& = \mathbf{s_{\pi \tau}} \cdot \log \left( B_{\pi_{\tau-1}} \right) \mathbf{s_{\pi \tau-1}}
\end{align}
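
To convince myself that the compact form really equals the double sum, here is a brute-force check. It is a sketch with randomly drawn beliefs and a randomly drawn transition matrix (`B_action` stands in for $B_{\pi_{\tau-1}}$; all values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
m = 3

# illustrative beliefs over states at tau and tau-1
s_tau = rng.dirichlet(np.ones(m))
s_prev = rng.dirichlet(np.ones(m))
# column j is the distribution over next states when starting from state j
B_action = rng.dirichlet(np.ones(m), size=m).T

# left-hand side: explicit double sum over next state i and previous state j
lhs = sum(
    s_tau[i] * s_prev[j] * np.log(B_action[i, j])
    for i in range(m)
    for j in range(m)
)

# right-hand side: the compact matrix form
rhs = s_tau @ np.log(B_action) @ s_prev

print(lhs, rhs)
```

The two numbers agree up to floating-point error, which is exactly the step from the second to the last line of the derivation.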

So, to sum up, what do we have? We have a way of calculating the free energy of a parameterization of $Q$ in terms of the probabilities of each state at each timestep (given a policy). Since perception is conceptualized as estimation of the true state, we _almost_ have a model of perception!

In [6]:
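def calculate_free_energy_states(s_pi, T, logA, D, B, pi, o_history):
    
    # first summand: negative entropy of Q(s_tau | pi), one value per timestep.
    # NOTE: this is not summed over tau, so the returned value is an array of length T
    first_summand = np.sum(s_pi * np.log(s_pi), axis=1)
    
    # second summand: expected log-likelihood of the observations made so far (TODO: vectorize)
    second_summand = np.sum([(logA @ s_pi[t].reshape(-1,1))[o] for t, o in enumerate(o_history)])
    
    # third summand: expected log-prior of the first state
    third_summand = s_pi[0] @ np.log(D).reshape(-1,1)
    
    # fourth summand: expected log transition probabilities (TODO: vectorize).
    # The transition into time tau depends on the action taken at tau-1, hence B[pi[tau-1]]
    fourth_summand = np.sum([s_pi[tau] @ np.log(B[pi[tau-1]]) @ s_pi[tau-1] for tau in range(1,T)])
    
    return first_summand - second_summand - third_summand - fourth_summand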


def generate_random_s_pi(T, n_states):
    unnorm = np.random.uniform(size=(T, n_states))
    return unnorm / np.sum(unnorm, axis=1, keepdims=True)

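Example of calculation of the free energy. Note that the result is an array with length T. This is because the first summand is kept as one value per timestep (the negative entropy of the probability vector over states at that timestep), while the other three summands are summed into single numbers and subtracted from each entry. To get the single scalar of the formulas above, the first summand would also have to be summed over timesteps. Either way, lower values mean that our approximation with $s_\pi$ is closer, in free-energy terms, to the true posterior: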

In [375]:
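n_states, n_obs, T, B, pi, A, D, a = generate_input()
s_pi = generate_random_s_pi(T, n_states)
o_history, s_history = generate_history(n_states, n_obs, T, B, pi, A, D)
o_history_until_t = o_history[:5]
# expected log-likelihood under Q(A): digamma of a minus digamma of its column sums
logA = digamma(a) - digamma(np.sum(a, axis=0, keepdims=True))
calculate_free_energy_states(s_pi, T, logA, D, B, pi, o_history_until_t)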

array([14.37663319, 14.32736888, 14.33893073, 14.29530552, 14.31860815,
       14.42665794, 14.51651092, 14.29676454, 14.30397585, 14.29628657])

What's missing? Well, just knowing the free energy of a set of parameters is not enough. What we do in perception is find the set of parameters that _minimizes_ the free energy. This is done by gradient descent. Now that we have the free energy of a parameterization of Q consisting of the probabilities of all states for each timestep, the authors calculate the gradient with respect to the state probabilities at each timestep, which is a vector for each timestep. The total gradient therefore (as it should) has the same shape as the parameters we are updating: a $T \times m$ matrix.
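
One piece of the gradient is easy to check by hand: differentiating the first summand, $\mathbf{s} \cdot \log \mathbf{s}$, with respect to each component gives $1 + \log \mathbf{s}$. Here is a quick finite-difference check. It is a sketch with an arbitrary probability vector, treating the components as unconstrained variables:

```python
import numpy as np

def neg_entropy(s):
    """First summand of the free energy for a single timestep."""
    return s @ np.log(s)

s = np.array([0.5, 0.3, 0.2])
analytic = 1 + np.log(s)

# central finite differences, perturbing one component at a time
eps = 1e-6
numeric = np.array([
    (neg_entropy(s + eps * e) - neg_entropy(s - eps * e)) / (2 * eps)
    for e in np.eye(len(s))
])

print(analytic)
print(numeric)
```

The remaining terms of the gradient are linear in $\mathbf{s}_{\pi\tau}$, so their derivatives are just the coefficients that multiply it in the free energy.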

I am not going to repeat the formula in the paper, but rather just implement it in python:

In [7]:
def gradient_free_energy_perception(s_pi, a, D, o_history_until_t, pi):
    """
    Function to calculate the gradient of the free energy with respect to the sufficient
    parameters for the Q distribution over states at each timestep, s_pi.
    NOTE: the transition array B is read from the notebook's global scope.
    
    Parameters
    ----------
    s_pi: array
        An array with shape (T, m). Encodes the current estimation of the probability 
        of each state at each timestep.
    a: array
        An array with shape (n, m). Encodes the Dirichlet parameters over likelihood matrices in Q.
        Represented with a bold "a" in the paper.
    D: array
        Array of length m, encodes the prior probabilities of each state on the first step.
    o_history_until_t: array
        Vector of observations (as indices) of length t.
    pi: array
        Vector of action indices.
    
    Returns
    -------
    array
        The gradient with respect to s_pi at the current point, with the same shape as s_pi.
    """
    
    T = s_pi.shape[0]
    t = len(o_history_until_t)
    
    # each column of a parameterizes one Dirichlet, so the totals are the column sums
    a_0 = np.sum(a, axis=0, keepdims=True)
    logA = digamma(a) - digamma(a_0)
    logD = np.log(D)
    
    conditional_part = np.zeros(shape=(s_pi.shape))
    
    # TODO: check if this loop is a bottleneck & vectorize if so. Otherwise, keep it: it's clearer this way.
    for tau in range(T):
        x = np.zeros(s_pi.shape[1])
        
        if tau < t:
            # likelihood term, only for timesteps that have already been observed
            # NOTE: o_tau is the index of o at time tau, rather than the one-hot like in the paper
            o_tau = o_history_until_t[tau]
            x += logA[o_tau]
        
        if tau == 0:
            # the first state is constrained by the prior D
            x += logD
        else:
            # message from the past: transition into tau under the action taken at tau-1
            x += np.log(B[pi[tau-1]]) @ s_pi[tau-1]
        
        if tau < T-1:
            # message from the future: transition out of tau under the action taken at tau
            # (the last timestep has no successor, so it only gets the terms above)
            x += s_pi[tau+1] @ np.log(B[pi[tau]])
        
        conditional_part[tau] = x
    
    return 1 + np.log(s_pi) - conditional_part


def gradient_descent(n_states, n_obs, T, B, pi, A, D, a, o_history_until_t, s_history,
                     learning_rate=0.05, feedback=()):
    """
    Perform gradient descent on s_pi.
    """
    
    # careful here, a is represented as the bold a in the paper. Part of Q distribution!
    fixed_input = {"a": a, "D": D, "o_history_until_t": o_history_until_t, "pi": pi}

    # initial guess - can be improved
    s_pi = generate_random_s_pi(T, n_states)
    
    t = len(o_history_until_t)

    for i in range(10):
        gradient = gradient_free_energy_perception(s_pi, **fixed_input)

        # i%1 is always 0, so feedback is printed at every iteration;
        # increase the modulus to print less often
        if not i%1:
            if "FE" in feedback:
                # expected log-likelihood under Q(A), needed for the free energy
                logA = digamma(a) - digamma(np.sum(a, axis=0, keepdims=True))
                print("FE: ", calculate_free_energy_states(s_pi, T, logA, D, B, pi, o_history_until_t))
            if "s_pi" in feedback:
                print("s_pi\n", np.round(s_pi[:10], 3))
            if "gradient" in feedback:
                print("gradient \n", gradient[:10], "\n\n")
            if "accuracy" in feedback:
                print("Accuracy on observed: ", 
                      # np.argmax(s_pi, axis=1) is the vector of predictions on the hidden states
                      # (the absolute difference only counts errors correctly with two states)
                      1-(np.sum(np.absolute(s_history[:t] - np.argmax(s_pi, axis=1)[:t]))/t))
            if "loss" in feedback:
                # negative log-probability assigned to the true states; needs the full state history
                print("Loss: ", -np.sum(np.log(s_pi[np.arange(len(s_pi)), s_history])))

        # gradient step, followed by a softmax so that each row is again a probability vector
        s_pi = softmax(s_pi - learning_rate * gradient, axis=1)
        
    return s_pi

Let's work through a simple example with a generative model that has a predictable behaviour. First define the basic parameters:

In [528]:
T = 80
n_states, n_obs, n_actions = 2,2,2

Probability vector for first state. Almost always starts with state 0:

In [528]:
D = np.array([
    0.9, 0.1
])

Always perform action 0, i.e. only consider the first element of B for all timesteps:

In [528]:
pi = [0]*T

Both actions have the same effect: the state remains the same with high probability:    

In [528]:
B = np.array([
    [[0.9, 0.1],
     [0.1, 0.9]],
    
    [[0.9, 0.1],
     [0.1, 0.9]]
])

Again, A encodes the likelihood of producing each observation (row) given the state (column). If the state is 0, it usually produces observation 0. If the state is 1, it usually produces observation 1:

In [528]:
A = np.array([
    [0.9, 0.1],
    [0.1, 0.9]
])

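$a$ holds the Dirichlet parameters of the agent's beliefs about $A$, and its column-normalized version is the expected value of $A$. Since the columns of $A$ already sum to 1, I'll just set it to $A$ for simplicity, so that the expected likelihood matrix equals the true one.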

In [528]:
a = A

Finally, generate the history and do the gradient descent:

In [528]:
o_history, s_history = generate_history(n_states, n_obs, T, B, pi, A, D)

t = 70 # what's the present timestep?
o_history_until_t = o_history[:t]

s_pi = gradient_descent(
    n_states, n_obs, T, B, pi, A, D, a, o_history_until_t, s_history, learning_rate=0.1, feedback=["accuracy"])

Accuracy on observed:  0.48571428571428577
Accuracy on observed:  0.9142857142857143
Accuracy on observed:  0.9142857142857143
Accuracy on observed:  0.9142857142857143
Accuracy on observed:  0.9142857142857143
Accuracy on observed:  0.9142857142857143
Accuracy on observed:  0.9142857142857143
Accuracy on observed:  0.9142857142857143
Accuracy on observed:  0.9142857142857143
Accuracy on observed:  0.9142857142857143


The effect of the parameters set above is that the generative model is quite predictable in general.

However, since the observations usually track the state so well, the agent is fooled whenever an observation happens to differ from the real state: it effectively guesses the observation rather than the state:

In [529]:
print(f"s history beginning: {np.array(s_history[:20])}")
print(f"o history beginning: {np.array(o_history[:20])}")
print(f"Predicted states:    {np.argmax(s_pi, axis=1)[:20]}")

s history beginning: [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0]
o history beginning: [0 0 0 0 0 0 1 1 0 1 1 1 1 1 1 1 0 0 0 0]
Predicted states:    [0 0 0 0 0 0 1 1 0 1 1 1 1 1 1 1 0 0 0 0]


In a slightly more complicated case, the agent is capable of inferring that the observation is different from the state, even with pretty uninformative observations. To do so, the agent has to know that whenever a certain action is performed, a certain state follows. So we need to change $A$, $B$ and $\pi$:

In [530]:
# in the new policy, alternate between the first and second actions
pi = [0, 1]*(T//2)

# when action 0 is performed, system always tends to go to state 0.
# when action 1 is performed, system always tends to go to state 1.
B = np.array([
    [[0.999, 0.999],
     [0.001, 0.001]],
    
    [[0.001, 0.001],
     [0.999, 0.999]]
])

# make the observation less informative
A = np.array([
    [0.6, 0.4],
    [0.4, 0.6]
])
a = A

o_history, s_history = generate_history(n_states, n_obs, T, B, pi, A, D)
t = 70 # what's the present timestep?
o_history_until_t = o_history[:t]
s_pi = gradient_descent(n_states, n_obs, T, B, pi, A, D, a, o_history_until_t, 
                        s_history, learning_rate=0.1, feedback=["accuracy"])

Accuracy on observed:  0.5285714285714286
Accuracy on observed:  1.0
Accuracy on observed:  1.0
Accuracy on observed:  1.0
Accuracy on observed:  1.0
Accuracy on observed:  1.0
Accuracy on observed:  1.0
Accuracy on observed:  1.0
Accuracy on observed:  1.0
Accuracy on observed:  1.0


Note that the agent is guessing the following: 
1. start with state 0, because of the way $D$ was specified
2. guess state 0 (because action 0 was performed at time 0, leading to state 0 according to B)
3. guess state 1 (because action 1 was performed at time 1, leading to state 1 according to B)
4. Repeat!

Importantly, because $A$ doesn't give much information anymore, the agent disregards the observations in the inference of the states and instead uses the relation between the actions and the resulting states. This can be seen by printing the first few timesteps:

In [531]:
print(f"s history beginning: {np.array(s_history[:20])}")
print(f"o history beginning: {np.array(o_history[:20])}")
print(f"Actions performed:   {np.array(pi[:20])}")
print(f"Predicted states:    {np.argmax(s_pi, axis=1)[:20]}")

s history beginning: [0 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0]
o history beginning: [0 1 1 1 1 1 0 1 0 1 1 0 0 0 1 0 1 0 1 0]
Actions performed:   [0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1]
Predicted states:    [0 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0]


Finally, it's fun to just see how the agent does with totally random histories:

In [587]:
n_states, n_obs, T, B, pi, A, D, a = generate_input(T=100)
o_history, s_history = generate_history(n_states, n_obs, T, B, pi, A, D)

approximate_a = a # for the moment assume that they are the same
num_observed = 70 # what's the present timestep?
o_history_until_t = o_history[:num_observed]

# print(np.column_stack((o_history, s_history)))
s_pi = gradient_descent(n_states, n_obs, T, B, pi, A, D, approximate_a, o_history_until_t, s_history, 
                 learning_rate=0.01, feedback=["accuracy", "loss"])

print(f"s history beginning: {np.array(s_history[:20])}")
print(f"o history beginning: {np.array(o_history[:20])}")
print(f"Actions performed:   {np.array(pi[:20])}")
print(f"Predicted states:    {np.argmax(s_pi, axis=1)[:20]}")

Accuracy on observed:  0.48571428571428577
Loss:  81.34561299262005
Accuracy on observed:  0.48571428571428577
Loss:  71.05903597983296
Accuracy on observed:  0.4571428571428572
Loss:  69.7692718346429
Accuracy on observed:  0.41428571428571426
Loss:  69.64063637824307
Accuracy on observed:  0.3857142857142857
Loss:  69.68994903883177
Accuracy on observed:  0.37142857142857144
Loss:  69.74098737056622
Accuracy on observed:  0.4
Loss:  69.7724734057163
Accuracy on observed:  0.4
Loss:  69.78945676715549
Accuracy on observed:  0.4
Loss:  69.7981494992192
Accuracy on observed:  0.4
Loss:  69.8024963988255
s history beginning: [0 0 0 0 1 1 0 1 1 1 0 0 0 0 0 0 0 1 0 0]
o history beginning: [1 1 0 0 0 0 1 0 0 0 0 0 0 0 1 0 1 0 1 1]
Actions performed:   [0 1 1 0 1 0 1 1 1 1 1 1 1 0 0 1 0 1 1 0]
Predicted states:    [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


In [592]:
for x, xname in zip([T, B, pi, A, D, a], ["T", "B", "pi", "A", "D", "a"]):
    print(xname, ":")
    display(x)

T :


100

B :


array([[[0.63855892, 0.76243884],
        [0.36144108, 0.23756116]],

       [[0.53727141, 0.3423191 ],
        [0.46272859, 0.6576809 ]]])

pi :


array([0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1,
       0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1,
       0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
       1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0])

A :


array([[0.49603389, 0.96549211],
       [0.50396611, 0.03450789]])

D :


array([0.63646255, 0.36353745])

a :


array([[0.55068062, 0.98749887],
       [0.21095903, 0.62625738]])

## Action selection

- $\textbf{a}$: As explained above, an $n \times m$ matrix encoding the Dirichlet parameters from which the approximate likelihood matrix is sampled. Each column is an $\alpha$ parameter for a Dirichlet distribution.
- $\textbf{a_0}$: An $n \times m$ matrix, which consists of one row vector with $m$ components repeated $n$ times. The single value of each column is calculated by summing all the rows of $\textbf{a}$, i.e. it is the total of the Dirichlet parameters for the corresponding state.
- $\textbf{A}$ (Note: different from $A$!): Defined in the paper as the expectation under $Q$ of $A$, i.e. $\sum_a Q(A=a)\left[ a \right]$. Basically it is $\textbf{a}$, but with normalized columns. This normalization is done by component-wise division by the corresponding elements of $\textbf{a_0}$.
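
A small sketch of how these three objects relate, with made-up Dirichlet parameters. Since each column of $\textbf{a}$ parameterizes one Dirichlet, the totals are taken per column; the variable names mirror the ones used in the code below, but the values are only illustrative:

```python
import numpy as np

# illustrative Dirichlet parameters: n=3 observations (rows), m=2 states (columns)
a_bold = np.array([
    [8.0, 1.0],
    [1.0, 8.0],
    [1.0, 1.0],
])

# column totals, repeated down the rows so that a_0 has the same shape as a
a_0_bold = np.broadcast_to(a_bold.sum(axis=0, keepdims=True), a_bold.shape)

# expected likelihood matrix: every column is a probability vector over observations
A_bold = a_bold / a_0_bold

print(a_0_bold)
print(A_bold)
```

In practice numpy broadcasting makes the explicit repetition unnecessary: keeping the totals as a single row of shape `(1, m)` gives the same result.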

The fundamental thing to notice is that expected free energy does not concern all timesteps, but rather only a specific one, which is usually set to the last timestep $T$. This is not emphasized very much in the paper, so I was confused about dimensions when I started implementing this!

In [125]:
def calculate_H(A_bold):
    # a is n x m, therefore A.T @ np.log(A) is m x m, and its diagonal is length m.
    # which is the number of states.
    # this is right, considering it multiplies s_pi_tau, which contains probabilities of states.
    return - np.diag(A_bold.T @ np.log(A_bold))

def calculate_W(a_bold, a_0_bold):
    return 0.5 * ( (1/a_bold) - (1/a_0_bold) )

def calculate_ambiguity(s_pi_tau, A_bold):
    # they are both vectors, so this is simply the dot product
    return calculate_H(A_bold) @ s_pi_tau

def calculate_risk(s_pi_tau, C):
    return s_pi_tau @ (np.log(s_pi_tau) - np.log(C))

def calculate_novelty(A_bold, s_pi_tau, a_bold, a_0_bold):
    return (A_bold @ s_pi_tau.T) @ (calculate_W(a_bold, a_0_bold) @ s_pi_tau)

def expected_free_energy(s_pi_tau, A_bold, C, a_bold, a_0_bold):
    ambiguity = calculate_ambiguity(s_pi_tau, A_bold)
    risk = calculate_risk(s_pi_tau, C)
    novelty = calculate_novelty(A_bold, s_pi_tau, a_bold, a_0_bold)
    return ambiguity + risk - novelty

Problem: To calculate the expected free energy under a policy, I need the approximate posterior over states at each timestep, s_pi_tau. However, I cannot estimate the probability of the states at each timestep (s_pi_tau) without minimizing the free energy over states. The free energy is minimized with respect to a certain history of observations. But producing a history requires a policy!

The solution to this lies in the concept of _active inference_: 
- First sample an initial state, for which no policy is required
- Then, sample an initial observation.
- Based on the initial observation, do free energy minimization on the hidden states and get an s_pi
- Based on the obtained s_pi, do free energy minimization on the policies and pick an action.
- Perform the action, which causes another state, which causes another observation
- Repeat!

The only thing to be careful with is to restrict the set of policies to consider to the ones that are compatible with actions performed in the past: can't change the past!
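
Here is a minimal sketch of those two ingredients in isolation: enumerating only the policies that agree with the actions already taken, and turning expected free energies into a distribution over policies with a softmax. The action history and the values in `G` are made up for illustration:

```python
import numpy as np
from itertools import product
from scipy.special import softmax

n_actions, T = 2, 4
pi_history = (1, 0)  # actions already performed at timesteps 0 and 1

# every candidate policy keeps the past fixed and varies only the remaining actions
candidates = [
    pi_history + remaining
    for remaining in product(range(n_actions), repeat=T - len(pi_history))
]

# made-up expected free energies, one per candidate (lower is better)
G = np.array([2.0, 0.5, 1.0, 3.0])

# policies with lower expected free energy get higher probability
Q_pi = softmax(-G)

print(candidates)
print(np.round(Q_pi, 3))
```

The number of candidates is `n_actions ** (T - i)` at timestep `i`, so this exhaustive enumeration is only workable for short horizons like the one used below.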

In [137]:
def pick_policy(n_states, n_obs, n_actions, i, T, B, A_bold, C, D, 
                a_bold, a_0_bold, o_history, s_history, pi):
    """
    Parameters
    ----------
    pi: array
        The policy decided in the previous timestep. The picked policy has to be consistent with the
        previously adopted policies up to the present time.
    i: integer
        Index of present time. Starts with 0 for the first timestep.
    """
    # history of the actions performed until now
    pi_history = pi[:i]
    
    Gs = []
    PI = []
    # loop over possible continuations of the policy
    for remaining_pi in product(np.arange(n_actions), repeat=T-i):
        
        pi = pi_history + remaining_pi
        s_pi = gradient_descent(n_states, n_obs, T, B, pi, A_bold, D, a_bold, o_history, s_history)
        # s_pi[-1] because the expected free energy is evaluated at the last timestep
        G_pi = expected_free_energy(s_pi[-1], A_bold, C, a_bold, a_0_bold)
        Gs.append(G_pi)
        PI.append(pi)
    # NOTE: pi and s_pi printed here belong to the last candidate policy evaluated in the loop,
    # not to the policy sampled below
    print("time ", i)
    print("pi:   ", np.array(pi))
    print("s_pi: ", np.argmax(s_pi, axis=1), "\n")
    Q_pi = softmax(-np.array(Gs))
    index_pi = np.random.choice(np.arange(len(PI)), p=Q_pi)
    return PI[index_pi]


def generate_history_with_active_inference(n_states, n_obs, n_actions, T, B, A, C, D):

    # assume for simplicity that the agent's prior about A accurately reflects reality
    a_bold = A
    # expected value of A, calculated by normalizing the columns of a
    A_bold = a_bold / np.sum(a_bold, axis=0)
    # column totals of a: one total per Dirichlet, i.e. per state
    a_0_bold = np.sum(a_bold, axis=0, keepdims=True)
    
    s = np.random.choice(n_states, p=D)
    
    # I can initialize it as an empty tuple, it doesn't matter because
    # on the first round none of it is used.
    pi = ()

    s_history, o_history = [], []
    # in a loop which models the time steps
    for i in range(T):
        
        s_history.append(s)
        
        # sample observation given A and state s
        o = np.random.choice(n_obs, p=A[:,s])
        o_history.append(o)
        
        # update policy by performing active inference
        # actions influence state at i+1
        # action at timestep T doesn't matter, because it could only influence state T+1
        pi = pick_policy(n_states, n_obs, n_actions, i, T, B, A_bold, C, D, a_bold, a_0_bold,
                               o_history, s_history, pi)
        
        # update state s_i given policy, B, and previous state
        # not appended on the last timestep, because it always depends on decision in previous timestep
        # so this would be at T+1, in the list index i+1
        s = np.random.choice(n_states, p=B[pi[i],:,s])
        
    return o_history, s_history, pi

In [113]:
T = 10

n_states, n_obs, n_actions = 2,3,2

D = np.array([
    0.9, 0.1
])

B = np.array([
    [[0.999, 0.999],
     [0.001, 0.001]],
    
    [[0.001, 0.001],
     [0.999, 0.999]]
])

A = np.array([
    [0.8, 0.1],
    [0.1, 0.8],
    [0.1, 0.1]
])

a = A

# preference expressed in terms of states (formula is different when preference is wrt outcomes)
C = np.array([0.999, 0.001])

In [138]:
o_history, s_history, pi = generate_history_with_active_inference(n_states, n_obs, n_actions, T, B, A, C, D)

print("s history: ", s_history)
print("pi:        ", pi)
print("o history: ", o_history)

time  0
pi:    [1 1 1 1 1 1 1 1 1 1]
s_pi:  [0 1 1 1 1 1 1 1 1 1] 

time  1
pi:    [1 1 1 1 1 1 1 1 1 1]
s_pi:  [0 1 1 1 1 1 1 1 1 1] 

time  2
pi:    [1 0 1 1 1 1 1 1 1 1]
s_pi:  [0 1 0 1 1 1 1 1 1 1] 

time  3
pi:    [1 0 1 1 1 1 1 1 1 1]
s_pi:  [0 1 0 1 1 1 1 1 1 0] 

time  4
pi:    [1 0 1 1 1 1 1 1 1 1]
s_pi:  [0 1 0 1 1 1 1 1 1 1] 

time  5
pi:    [1 0 1 1 0 1 1 1 1 1]
s_pi:  [0 1 0 1 1 0 1 1 1 0] 

time  6
pi:    [1 0 1 1 0 0 1 1 1 1]
s_pi:  [0 1 0 1 1 0 0 1 1 1] 

time  7
pi:    [1 0 1 1 0 0 0 1 1 1]
s_pi:  [0 1 0 1 1 0 0 0 1 1] 

time  8
pi:    [1 0 1 1 0 0 0 1 1 1]
s_pi:  [0 1 0 1 1 0 0 0 1 0] 

time  9
pi:    [1 0 1 1 0 0 0 1 1 1]
s_pi:  [0 1 0 1 1 0 0 0 1 1] 

s history:  [0, 1, 0, 1, 1, 0, 0, 0, 1, 1]
pi:         (1, 0, 1, 1, 0, 0, 0, 1, 1, 1)
o history:  [0, 1, 0, 1, 1, 0, 0, 2, 1, 2]
