In [59]:
from __future__ import print_function
import numpy as np

from ifqi import envs
from ifqi.evaluation import evaluation
from ifqi.algorithms.fqi.FQI import FQI

from scipy.optimize import curve_fit

In [3]:
# Environment, plus the sizes of its state / action / reward spaces.
mdp = envs.LQG1D()
state_dim, action_dim, reward_dim = envs.get_space_info(mdp)

# In each dataset row the reward starts right after the state and action
# columns.
reward_idx = state_dim + action_dim

# Candidate actions handed to FQI: symmetric around zero, with a finer
# spacing for small magnitudes.
discrete_actions = np.array([
    -8, -7, -6, -5, -4, -3, -2.5, -2, -1.5, -1, -.75, -.5, -.25,
    0,
    .25, .5, .75, 1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8,
])

dataset = evaluation.collect_episodes(mdp, n_episodes=200)

# Reward column used as the regression signal.
r = dataset[:, reward_idx]

# Everything before the reward, followed by everything after it except the
# final column.
# NOTE(review): presumably (state, action, next state, absorbing flag) with a
# trailing end-of-episode column dropped -- confirm against collect_episodes.
sast = np.concatenate(
    (dataset[:, :reward_idx], dataset[:, reward_idx + reward_dim:-1]),
    axis=1)

In [54]:
class Regressor:
    """Two-parameter parametric Q-function approximator.

    The model is ``Q(sa) = b - (sa[:, 1] - k * sa[:, 0]) ** 2``: a downward
    parabola with peak value ``b`` located where column 1 equals ``k`` times
    column 0.  The class exposes ``fit`` / ``predict`` so an instance can be
    passed as the ``estimator`` of FQI.

    NOTE(review): presumably column 0 is the state and column 1 the action
    (matching the ``sa`` name) -- confirm against what FQI passes in.
    """
    __slots__ = ('b', 'k')

    def __init__(self, b=0, k=0):
        """Store the initial parameter values.

        Args:
            b: offset (peak value) of the parabola.
            k: linear coefficient multiplying column 0.
        """
        self.b = b
        self.k = k

    def fit(self, X, y):
        """Estimate ``b`` and ``k`` by non-linear least squares.

        ``curve_fit`` is called without ``p0``, so (per the SciPy docs) every
        call starts from an initial guess of all ones rather than from the
        current ``b`` / ``k``.  The fitted values are printed on one line.

        Args:
            X: 2-D array-like with at least two columns.
            y: 1-D array-like of regression targets, one per row of ``X``.

        Returns:
            self, following the scikit-learn estimator convention so that
            calls can be chained.
        """
        print("fitting: ", end='')
        # Only the optimal parameters are needed; the covariance is dropped.
        self.b, self.k = curve_fit(self.Q, X, y)[0]
        print(self.b, self.k)
        return self

    def Q(self, sa, b, k):
        """Evaluate the parabola for explicit parameters ``b`` and ``k``.

        Args:
            sa: 2-D array-like; columns 0 and 1 are read.
            b: offset of the parabola.
            k: linear coefficient multiplying column 0.

        Returns:
            1-D array with one value per row of ``sa``.
        """
        # Accept plain nested lists as well as ndarrays.
        sa = np.asarray(sa)
        return b - (sa[:, 1] - k * sa[:, 0]) ** 2

    def predict(self, X):
        """Evaluate the model on ``X`` using the stored ``b`` and ``k``."""
        return self.Q(X, self.b, self.k)

regressor = Regressor()

In [57]:
# Fitted Q-Iteration wired to the custom parametric regressor defined above.
fqi = FQI(
    # learner and the action set it chooses from
    estimator=regressor,
    discrete_actions=discrete_actions,
    # problem dimensions
    state_dim=state_dim,
    action_dim=action_dim,
    # discount factor and horizon are taken from the environment
    gamma=mdp.gamma,
    horizon=mdp.horizon,
    # no feature map and no input scaling; verbose output enabled
    features=None,
    scaled=False,
    verbose=True,
)

In [58]:

# The first call supplies the dataset; the later calls pass no arguments.
fqi.partial_fit(sast, r)

iterations = 100
iteration_values = []

# The first fit was done above, so only iterations - 1 remain.
for i in range(iterations - 1):
    fqi.partial_fit()
    # Evaluate the current policy from 10 identical initial states, all 10.
    values = evaluation.evaluate_policy(
        mdp, fqi, initial_states=10 * np.ones((10, 1)))
    print(values)
    # NOTE(review): values[0] is presumably the mean performance -- confirm
    # against evaluate_policy's return contract.
    iteration_values.append(values[0])

Iteration 1
fitting: -12.9660015441 0.941251765111
Iteration 2
fitting: 88.5479129325 1.26193682795
(-9503817.08067533, 17000.457054607654, 100.0, 0.0)
Iteration 3
fitting: 234.534202539 1.49357663036
(-9497696.9616293013, 16215.559895439013, 100.0, 0.0)
Iteration 4
fitting: 409.613751879 1.67885567536
(-9495078.1809839718, 19302.130913726189, 100.0, 0.0)
Iteration 5
fitting: 605.913164004 1.83445705765
(-9492620.9903859086, 13135.702363058468, 100.0, 0.0)
Iteration 6
fitting: 818.508499435 1.96898303929
(-9481152.4081875496, 19624.158474330761, 100.0, 0.0)
Iteration 7
fitting: 1043.99251738 2.0875634253
(-9501345.8320422657, 12279.057171541892, 100.0, 0.0)
Iteration 8
fitting: 1279.85742669 2.1935506204
(-9505282.9331755005, 19358.202794656194, 100.0, 0.0)
Iteration 9
fitting: 1524.17873339 2.28927905413
(-9495503.0500901043, 15789.175529276752, 100.0, 0.0)
Iteration 10
fitting: 1775.43433823 2.37645202475
(-9515547.9073031545, 13902.825641760348, 100.0, 0.0)
Iteration 11
fitting: 203