In [1]:
import numpy as np
from numpy import dot
from numpy.linalg import pinv

from functools import reduce

from features import *
from mdptools import *

In [2]:
def td_solution(P, R, phi_func, gm_func, lm_func):
    """Compute the TD(lambda) fixed point for a Markov chain with linear features.

    Solves ``A theta = b`` (minimum-norm solution via ``pinv``) where

        A = X^T D (I - P G L)^{-1} (I - P G) X
        b = X^T D (I - P G L)^{-1} R

    ``X`` is the feature matrix. ``G`` and ``L`` are diagonal matrices of the
    state-dependent gamma and lambda values. ``D`` is the state-weighting
    (distribution) matrix.

    Parameters
    ----------
    P : array_like, shape (ns, ns)
        Transition matrix.
    R : array_like, length ns
        Expected reward vector, one entry per state.
    phi_func : callable
        Maps a state (as returned by ``state_vectors``) to its feature vector.
    gm_func : callable
        Maps a state to its discount (gamma) value.
    lm_func : callable
        Maps a state to its bootstrapping (lambda) value.

    Returns
    -------
    numpy.ndarray
        The weight vector ``theta``.

    Notes
    -----
    ``D`` is currently the identity, a placeholder until the state
    distribution can be computed. With lambda == 0 everywhere, the trace
    factor is the identity and this reduces to ``A = X^T (I - P G) X``,
    ``b = X^T R``.
    """
    # TODO: check parameters
    ns = len(P)
    states = state_vectors(P)
    I = np.eye(ns)
    X = feature_matrix(states, phi_func)
    G = np.diag([gm_func(s) for s in states])
    L = np.diag([lm_func(s) for s in states])

    # TODO: replace with the state distribution once it can be computed.
    D = np.eye(ns)

    # (I - P G L)^{-1}: accumulates the eligibility-trace weighting; this is
    # what makes the solution depend on lambda (previously L was ignored).
    P_trace = pinv(I - mult(P, G, L))
    P_gm = I - np.dot(P, G)              # (I - P G)

    # Solve the system of equations
    b = mult(X.T, D, P_trace, R)
    A = mult(X.T, D, P_trace, P_gm, X)
    return np.dot(pinv(A), b)

In [3]:
def etd_solution(P, R, phi_func, gm_func, lm_func, i_func):
    """Compute the emphatic TD(lambda) fixed point for a Markov chain.

    Solves ``A theta = b`` (minimum-norm solution via ``pinv``) where

        A = X^T M (I - P G L)^{-1} (I - P G) X
        b = X^T M (I - P G L)^{-1} R

    and ``M = diag(m)`` is the emphasis matrix. The emphasis is the row vector

        m^T = d_i^T (I - P^lambda)^{-1},
        P^lambda = I - (I - P G L)^{-1} (I - P G),

    with ``d_i = D i`` the interest-weighted state distribution.

    Parameters
    ----------
    P : array_like, shape (ns, ns)
        Transition matrix.
    R : array_like, length ns
        Expected reward vector, one entry per state.
    phi_func : callable
        Maps a state (as returned by ``state_vectors``) to its feature vector.
    gm_func : callable
        Maps a state to its discount (gamma) value.
    lm_func : callable
        Maps a state to its bootstrapping (lambda) value.
    i_func : callable
        Maps a state to its interest value.

    Returns
    -------
    numpy.ndarray
        The weight vector ``theta``.

    Notes
    -----
    ``D`` is currently the identity, a placeholder until the state
    distribution can be computed.
    """
    ns = len(P)
    states = state_vectors(P)

    # Compute matrices/vectors for state-dependent parameter functions
    I = np.eye(ns)
    X = feature_matrix(states, phi_func)
    G = np.diag([gm_func(s) for s in states])
    L = np.diag([lm_func(s) for s in states])
    ivec = np.array([i_func(s) for s in states])

    # Compute intermediate values
    # TODO: replace with the state distribution once it can be computed.
    D = np.eye(ns)                          # placeholder distribution matrix
    d_i = np.dot(D, ivec)                   # interest-weighted distribution
    P_trace = pinv(I - mult(P, G, L))       # (I - P G L)^{-1}
    P_gm = I - np.dot(P, G)                 # (I - P G)
    P_lambda = I - np.dot(P_trace, P_gm)    # lambda-return transition matrix
    # The emphasis is a ROW vector: m^T = d_i^T (I - P^lambda)^{-1}.
    # Left-multiplying by d_i applies the required transpose; the previous
    # (I - P^lambda)^{-1} d_i was only correct for symmetric P^lambda.
    mvec = np.dot(d_i, pinv(I - P_lambda))  # emphasis vector
    M = np.diag(mvec)                       # emphasis matrix

    # Solve the system of equations
    b = mult(X.T, M, P_trace, R)
    A = mult(X.T, M, P_trace, P_gm, X)

    return np.dot(pinv(A), b)

In [8]:
# Five-state chain; states 3 and 4 are absorbing (self-transition prob. 1).
P = np.array([
    [0.0, 0.5, 0.0, 0.5, 0.0],
    [0.5, 0.0, 0.5, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.0, 0.5],
    [0.0, 0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
])
# Expected reward vector: only state 2 has a nonzero entry.
R = [0, 0, 0.5, 0, 0]

indices = state_indices(P)
states = state_vectors(P)
ns = len(states)
terminals = list(map(as_tuple, find_terminals(P)))

# Bias feature; use `Wrap(Identity(ns))` instead for tabular features.
phi = Wrap(Bias(), terminals=terminals)
# Constant state-dependent parameters, each told which states are terminal.
gmfunc = Constant(1.0, terminals=terminals)   # discount (gamma)
lmfunc = Constant(0.0, terminals=terminals)   # bootstrapping (lambda)
ifunc = Constant(1.0, terminals=terminals)    # interest



In [9]:
etd_solution(P, R, phi_func=phi, gm_func=gmfunc, lm_func=lmfunc, i_func=ifunc)

array([ 0.5])

In [10]:
td_solution(P, R, phi_func=phi, gm_func=gmfunc, lm_func=lmfunc)

array([ 0.5])

In [18]:
# Expected number of visits: N = (I - Q)^{-1}, where Q is the block of P
# restricted to nonterminal states.
# NOTE(review): slicing the first `nn` rows/columns assumes the nonterminal
# states occupy indices 0..nn-1. That holds for the P above but not in general.
nn = len(find_nonterminals(P))
Q = P[:nn][:, :nn]
N = pinv(np.identity(nn) - Q)

In [19]:
N[1, :]

array([ 1.,  2.,  1.])