## Notes on [Wang08](https://webdocs.cs.ualberta.ca/~dale/papers/dualdp.pdf)

Related: [Peter Dayan (1993)](http://www.gatsby.ucl.ac.uk/~dayan/papers/d93b.pdf)'s successor representation.

```python
import numpy as np, matplotlib.pyplot as pl
from rl.mdp import random_MDP, random_dist
from arsenal import iterview

mdp = random_MDP(S=10, A=3, b=2, r=1, γ=0.85)

π = policy = random_dist(mdp.S, mdp.A)       # π[s, a] = π(a | s)

S = range(mdp.S); A = range(mdp.A)          # set of states and set of actions
γ = mdp.γ
V = mdp.V(π)
Q = mdp.Q(π)
Adv = mdp.Advantage(π)
R = mdp.r                                    # R[s, a]
P = mdp.P                                    # P[i, a, j] = P(j | i, a)

Π = mdp.Π(π)                                 # (S, S*A) policy matrix, so that Π @ Q.flat = V
M = mdp.successor_representation(π, normalize=True)   # normalized successor representation Φ, shape (S, S)
```
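For readers without `rl.mdp`, here is a minimal NumPy sketch of the array layouts I assume `mdp.P` and `mdp.Π(π)` use: `P` with shape `(S, A, S)`, and `Π` an `(S, S*A)` matrix with `Π[s, s*A + a] = π(a | s)`, which is the layout Lemma 13 (`V = Π @ Q.flat`) requires. The random MDP and the helper `policy_matrix` are hypothetical stand-ins, not the library's API.

```python
import numpy as np

rng = np.random.default_rng(0)
S, A = 10, 3

# Hypothetical random MDP and policy (not rl.mdp's random_MDP / random_dist)
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # P[i, a, j] = P(j | i, a)
π = rng.random((S, A));    π /= π.sum(axis=1, keepdims=True)   # π[s, a] = π(a | s)

def policy_matrix(π):
    """Hypothetical helper: (S, S*A) matrix with Π[s, s*A + a] = π(a | s)."""
    S, A = π.shape
    Π = np.zeros((S, S * A))
    for s in range(S):
        Π[s, s * A:(s + 1) * A] = π[s]
    return Π

Π = policy_matrix(π)

# State-to-state chain under π: P_π[i, j] = Σ_a π(a | i) P(j | i, a)
P_π = Π @ P.reshape(S * A, S)
assert np.allclose(P_π, np.einsum('ia,iaj->ij', π, P))
```

Row-major `reshape` flattens `(s, a)` to `s*A + a`, which is why `Π @ P.reshape(S*A, S)` gives the state chain without any explicit loop over actions.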

The ordinary (state-to-state) successor representation, normalized by $(1-\gamma)$:
$$
\newcommand{\tup}[1]{\langle #1 \rangle}
\Phi(i,k) = (1-\gamma)\, 1(i=k) + \gamma \, \sum_a \sum_j \Phi(j, k) \, P(j \mid i, a)\, \pi(a \mid i)
$$
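Because the right-hand side is linear in $\Phi$, the recursion has the closed form $\Phi = (1-\gamma)(I - \gamma P_\pi)^{-1}$, where $P_\pi(i,j) = \sum_a \pi(a \mid i)\, P(j \mid i, a)$. The sketch below, on a made-up random MDP, checks that a direct linear solve agrees with simply iterating the equation (a $\gamma$-contraction), and that every row of $\Phi$ sums to one. I am assuming this $(1-\gamma)$ scaling is what `normalize=True` means.

```python
import numpy as np

rng = np.random.default_rng(0)
S, A, γ = 10, 3, 0.85

# Hypothetical random MDP and policy
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # P[i, a, j] = P(j | i, a)
π = rng.random((S, A));    π /= π.sum(axis=1, keepdims=True)   # π[i, a] = π(a | i)
P_π = np.einsum('ia,iaj->ij', π, P)                             # Σ_a π(a | i) P(j | i, a)

# Φ = (1 - γ) I + γ P_π Φ   ⇒   Φ = (1 - γ) (I - γ P_π)^{-1}
Φ = (1 - γ) * np.linalg.solve(np.eye(S) - γ * P_π, np.eye(S))

# The same fixed point, reached by repeatedly applying the right-hand side
Φ_iter = np.zeros((S, S))
for _ in range(500):
    Φ_iter = (1 - γ) * np.eye(S) + γ * P_π @ Φ_iter
```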

The state-action extension:
$$
W(\tup{k,c} \mid \tup{i,a}) = (1-\gamma) \, 1(\tup{i,a}=\tup{k,c}) + \gamma \, \sum_{\tup{j,b}} W(\tup{k,c} \mid \tup{j,b}) \, P(\tup{j,b} \mid \tup{i,a})
$$

where the transition kernel over state-action pairs is
$$
P(\tup{j,b} \mid \tup{i,a}) = \pi(b \mid j) \, P(j \mid i, a)
$$
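Flattening each pair $\langle s,a \rangle$ to the index `s*A + a` turns the kernel above into an $SA \times SA$ row-stochastic matrix $T$, and the recursion for $W$ into $W = (1-\gamma) I + \gamma T W$, with the same closed form as before, $W = (1-\gamma)(I - \gamma T)^{-1}$. A minimal sketch on a random MDP follows; it assumes the row-major flattening that Lemma 14's `W.reshape(mdp.S*mdp.A, mdp.S*mdp.A)` suggests `mdp.sasa_matrix` also uses, which I have not checked against the library.

```python
import numpy as np

rng = np.random.default_rng(1)
S, A, γ = 6, 2, 0.9
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # P[i, a, j] = P(j | i, a)
π = rng.random((S, A));    π /= π.sum(axis=1, keepdims=True)   # π[j, b] = π(b | j)

# T[<i,a>, <j,b>] = π(b | j) P(j | i, a), with <s,a> flattened to s*A + a
T = np.einsum('iaj,jb->iajb', P, π).reshape(S * A, S * A)
assert np.allclose(T.sum(axis=1), 1.0)   # a Markov chain over state-action pairs

# W = (1 - γ) I + γ T W   ⇒   W = (1 - γ) (I - γ T)^{-1}
I = np.eye(S * A)
W = (1 - γ) * np.linalg.solve(I - γ * T, I)
assert np.allclose(W, (1 - γ) * I + γ * T @ W)
```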

```python
# In matrix form, viewing W as an (S*A, S*A) matrix and P as reshaped to (S*A, S):
#   W = (1 - γ) I + γ (P @ Π) W
#   (I - γ P @ Π) W = (1 - γ) I

# Wang08's H matrix, which I'll call W, is a row-stochastic matrix (a Markov chain)
# over transitions (s, a) -> (s', a'); here it is stored as a 4-d array W[i, a, k, c].
W = mdp.sasa_matrix(π, normalize=True)

# Lemma 10: W ≥ 0 and W @ 1 = 1
assert np.all(W >= 0)
assert np.allclose(1.0, np.einsum('iakc->ia', W))

# (1 - γ) Q = W R, checked entry by entry and then vectorized
for i in S:
    for a in A:
        assert np.allclose(Q[i,a]*(1-γ), sum(W[i,a,k,b] * R[k,b] for k in S for b in A))
assert np.allclose(Q*(1-γ), np.einsum('iakb,kb->ia',W,R))

# Check that W solves our equations
for k in S:
    for c in A:
        for i in S:
            for a in A:
                assert np.allclose(
                    W[i,a,k,c],
                    (1-γ)*((i,a)==(k,c)) + γ*sum(W[j,b,k,c] * π[j,b] * P[i,a,j] for j in S for b in A)
                )
```
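The quadruple loop gets slow as the MDP grows; the same fixed-point check, together with Lemma 10 and the identity $(1-\gamma)Q = WR$, can be written with `np.einsum`. Lemma 10 also follows from the Neumann series: $W = (1-\gamma)\sum_{t \ge 0} \gamma^t T^t$ has nonnegative entries, and since $T\mathbf{1} = \mathbf{1}$, $W\mathbf{1} = (1-\gamma)\sum_{t \ge 0} \gamma^t \mathbf{1} = \mathbf{1}$. This sketch generates its own random MDP and builds $W$ by a direct solve instead of calling `mdp.sasa_matrix`.

```python
import numpy as np

rng = np.random.default_rng(2)
S, A, γ = 6, 2, 0.9
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # P[i, a, j]
R = rng.random((S, A))                                          # R[s, a]
π = rng.random((S, A));    π /= π.sum(axis=1, keepdims=True)   # π[j, b]

T = np.einsum('iaj,jb->iajb', P, π).reshape(S * A, S * A)
I = np.eye(S * A)
W = ((1 - γ) * np.linalg.solve(I - γ * T, I)).reshape(S, A, S, A)   # W[i, a, k, c]

# Fixed-point equation for every (i, a, k, c) at once, replacing the quadruple loop
rhs = (1 - γ) * I.reshape(S, A, S, A) + γ * np.einsum('iaj,jb,jbkc->iakc', P, π, W)
assert np.allclose(W, rhs)

# Lemma 10: W ≥ 0 and each row sums to one
assert np.all(W >= 0)
assert np.allclose(W.sum(axis=(2, 3)), 1.0)

# Q from its Bellman equation Q = R + γ T Q, compared with (1 - γ) Q = W R
Q = np.linalg.solve(I - γ * T, R.ravel()).reshape(S, A)
assert np.allclose((1 - γ) * Q, np.einsum('iakc,kc->ia', W, R))
```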

```python
# Lemma 13: V = Π Q
assert np.allclose(V, Π @ Q.ravel())
```
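Lemma 13 says the state values are the policy-weighted action values, $V(s) = \sum_a \pi(a \mid s)\, Q(s,a)$. One way to see it, writing $\mathbf{P}$ for `P.reshape(S*A, S)` (my notation): $Q = R + \gamma \mathbf{P}\Pi Q$ implies $\Pi Q = r_\pi + \gamma P_\pi (\Pi Q)$ with $P_\pi = \Pi\mathbf{P}$, which is exactly the Bellman equation for $V$. The self-contained sketch below computes both sides independently on a random MDP.

```python
import numpy as np

rng = np.random.default_rng(3)
S, A, γ = 6, 2, 0.9
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # P[i, a, j]
R = rng.random((S, A))                                          # R[s, a]
π = rng.random((S, A));    π /= π.sum(axis=1, keepdims=True)   # π[s, a]

Π = np.zeros((S, S * A))                     # Π[s, s*A + a] = π(a | s)
for s in range(S):
    Π[s, s * A:(s + 1) * A] = π[s]

T = P.reshape(S * A, S) @ Π                  # T[<i,a>, <j,b>] = P(j | i, a) π(b | j)
Q = np.linalg.solve(np.eye(S * A) - γ * T, R.ravel())   # flattened Q[<s,a>]

P_π = Π @ P.reshape(S * A, S)                # state chain under π
r_π = Π @ R.ravel()                          # expected one-step reward under π
V = np.linalg.solve(np.eye(S) - γ * P_π, r_π)
```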

```python
# Lemma 14: M Π = Π W
assert np.allclose(M @ Π, Π @ W.reshape(mdp.S*mdp.A, mdp.S*mdp.A))
```
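A short derivation of Lemma 14, sketched via the push-through identity (my own argument, not necessarily the one in Wang08). It uses the closed forms $M = (1-\gamma)(I - \gamma P_\pi)^{-1}$ and $W = (1-\gamma)(I - \gamma T)^{-1}$ implied by the two fixed-point equations, with $\mathbf{P}$ the $SA \times S$ reshaping of $P$, so that $P_\pi = \Pi\mathbf{P}$ and $T = \mathbf{P}\Pi$. Both inverses exist because $\gamma < 1$ and $P_\pi$, $T$ are row-stochastic.

$$
\begin{aligned}
\Pi\,(I - \gamma\, \mathbf{P}\Pi) &= (I - \gamma\, \Pi\mathbf{P})\,\Pi \\
\Longrightarrow\quad (I - \gamma\, \Pi\mathbf{P})^{-1}\,\Pi &= \Pi\,(I - \gamma\, \mathbf{P}\Pi)^{-1} \\
\Longrightarrow\quad \underbrace{(1-\gamma)(I - \gamma\, P_\pi)^{-1}}_{M}\,\Pi &= \Pi\,\underbrace{(1-\gamma)(I - \gamma\, T)^{-1}}_{W}
\end{aligned}
$$

The second line comes from multiplying the first on the left by $(I - \gamma\,\Pi\mathbf{P})^{-1}$ and on the right by $(I - \gamma\,\mathbf{P}\Pi)^{-1}$.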