In [12]:
from spm import *
from spm.__wrapper__ import Runtime
from scipy.linalg import block_diag
import numpy as np

```
% set up and preliminaries
%==========================================================================
rng('default')
```

In [6]:
# Reset the random number generator to its default settings, mirroring
# `rng('default')` in the MATLAB listing. Presumably this is dispatched to
# MATLAB's `rng` through the spm Runtime, so numpy's RNG is NOT seeded here.
# NOTE(review): seed numpy explicitly if Python-side randomness is added.
# Requires the import cell (Runtime) to have run first -- the recorded
# execution counts show this cell ran before it.
# The trailing ';' suppresses display of the call's return value.
Runtime.call("rng", "default");

```
% observation probabilities
%--------------------------------------------------------------------------
a      = .9;
A{1,1} = [.5 .5; .5 .5; 0 0; 0 0];
A{2,2} = [0 0; 0 0; a (1 - a); (1 - a) a];
A{3,3} = [0 0; 0 0; (1 - a) a; a (1 - a)];
A{4,4} = [1 0; 0 1; 0 0; 0 0];
A      = spm_cat(A);
```

In [13]:
# Observation model: four 4x2 likelihood blocks, one per controllable state,
# stacked block-diagonally into a (16, 8) matrix -- the numpy counterpart of
# spm_cat on the diagonal cell array in the MATLAB listing. Within each block
# the two columns correspond to the two context values.
a = 0.9  # reliability of the informative outcomes in blocks 2 and 3

A1, A2, A3, A4 = (
    # block 1: outcomes 1 and 2 equally likely under either context
    np.array([[0.5, 0.5], [0.5, 0.5], [0, 0], [0, 0]]),
    # block 2: outcome 3 has probability a under the first context
    np.array([[0, 0], [0, 0], [a, 1 - a], [1 - a, a]]),
    # block 3: mirror image of block 2
    np.array([[0, 0], [0, 0], [1 - a, a], [a, 1 - a]]),
    # block 4: outcome 1 or 2 identifies the context exactly
    np.array([[1, 0], [0, 1], [0, 0], [0, 0]]),
)
A = block_diag(A1, A2, A3, A4)

```
% transition probabilities
%--------------------------------------------------------------------------
for i = 1:4
    B{i} = zeros(4,4);
    B{i}([2 3],[2 3]) = eye(2);
    B{i}(i,[1 4])     = 1;
    B{i} = kron(B{i},eye(2));
end
```

In [24]:
# Transition model: one 8x8 matrix per control i (0-based here, 1-based in
# the MATLAB listing). On the 4 controllable states, the sources in columns
# 0 and 3 are sent to state i, while states 1 and 2 (0-based) are absorbing.
# np.kron with eye(2) leaves the 2-valued context factor unchanged, so every
# column sums to 1.
B = []
for i in range(4):
    B_i = np.zeros((4, 4))
    B_i[1:3, 1:3] = np.eye(2)  # states 2 and 3 (1-based) stay where they are
    B_i[i, [0, 3]] = 1  # from states 1 and 4 (1-based), control i moves to state i
    B.append(np.kron(B_i, np.eye(2)))

```
% priors: softmax(utility)
%--------------------------------------------------------------------------
c  = 2;
C  = spm_softmax(kron(ones(4,1),[0; 0; c; -c]));

% prior beliefs about initial state
%--------------------------------------------------------------------------
D  = kron([1 0 0 0],[1 1]/2)';

% true initial state
%--------------------------------------------------------------------------
S  = kron([1 0 0 0],[1 0])';


% allowable policies (of depth T)
%--------------------------------------------------------------------------
V  = [1  1  1  1  2  2  2  2  3  3  3  3  4  4  4  4
      1  2  3  4  1  2  3  4  1  2  3  4  1  2  3  4
      1  2  3  4  1  2  3  4  1  2  3  4  1  2  3  4];
```

In [44]:
# Prior preferences over outcomes: the utility column [0, 0, c, -c] repeated
# once for each of the 4 blocks of A (a (16, 1) column), then normalised by
# spm_softmax.
c = 2  # utility of the third outcome in each block (its negative for the fourth)
C = spm_softmax(np.kron(np.ones((4, 1)), np.array([[0], [0], [c], [-c]])))

# Prior beliefs about the initial state: certain of the first controllable
# state, uniform over the two contexts. Reshaped to an (8, 1) column to match
# the transpose (') in the MATLAB listing and the column layout of S; a bare
# 1-D array does not carry that orientation.
D = np.kron([1, 0, 0, 0], [0.5, 0.5]).reshape((-1, 1))

# True initial state: one-hot on the first of the 8 states (first controllable
# state, first context). A kron of two column vectors is itself an (8, 1)
# column, matching the transposed MATLAB expression.
S = np.kron(np.array([[1], [0], [0], [0]]), np.array([[1], [0]]))

# Allowable policies as a (3, 16) array, laid out exactly like the MATLAB
# listing: presumably one column per policy and one row per time step
# (depth 3). The first move ranges over all 4 controls; the second and third
# moves are always equal. Labels stay 1-based because they are handed to the
# MATLAB routine spm_MDP_game. The original flat list produced a 1-D (48,)
# array, losing this 2-D structure.
V = np.array(
    [
        [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
        [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
        [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
    ]
)

```
% MDP Structure
%==========================================================================
MDP.N = 8;                          % number of variational iterations
MDP.S = S;                          % true initial state
MDP.A = A;                          % observation model
MDP.B = B;                          % transition probabilities (priors)
MDP.C = C;                          % terminal cost probabilities (priors)
MDP.D = D;                          % initial state probabilities (priors)
MDP.V = V;                          % allowable policies

MDP.alpha  = 64;                    % gamma hyperparameter
MDP.beta   = 4;                     % gamma hyperparameter
MDP.lambda = 1/4;                   % precision update rate
```

In [48]:
# Assemble the MDP structure consumed by spm_MDP_game, mirroring the MATLAB
# `MDP.*` field assignments one-for-one.
MDP = Struct()

MDP.N = 8  # number of variational iterations
MDP.S = S  # true initial state
MDP.A = A  # observation model
MDP.B = B  # transition probabilities (priors)
MDP.C = C  # terminal cost probabilities (priors)
MDP.D = D  # initial state probabilities (priors)
MDP.V = V  # allowable policies

MDP.alpha = 64  # gamma hyperparameter
MDP.beta = 4  # gamma hyperparameter
# `lambda` is a Python keyword, so `MDP.lambda = ...` is a syntax error; item
# assignment is used instead (presumably it sets the same MATLAB field -- confirm).
MDP["lambda"] = 0.25  # precision update rate

```
% Solve - an example game
%==========================================================================
spm_figure('GetWin','Figure 1'); clf
MDP.plot = gcf;
MDP      = spm_MDP_game(MDP,'FE');
```

In [49]:
# Solve an example game: run spm_MDP_game on the assembled MDP with the "FE"
# option and store the returned structure back in MDP.
# NOTE(review): the MATLAB listing first opens 'Figure 1' via spm_figure and
# sets MDP.plot to that figure handle (gcf); here a plain True is passed --
# confirm that spm_MDP_game accepts a logical flag for `plot`.
MDP["plot"] = True
MDP = spm_MDP_game(MDP, "FE");