In [None]:
import numpy as np


class MDP():
  """Three-state, two-action Markov decision process with an absorbing state.

  Instance attributes set by ``__init__``:
    γ: discount factor.
    A: action labels ``[0, 1]`` (0 = don't invest, 1 = invest, per the
       comments on the transition matrices).
    S: state labels ``[0, 1, 2]``; state 0 maps to itself with probability 1
       under both actions.
    P: list indexed by action; ``P[a][s]`` is the next-state distribution.
    R: list indexed by action; ``R[a][s]`` is the reward for taking action
       ``a`` in state ``s``.
  """

  def __init__(self, γ=0.95):
    """Build the transition and reward tables.

    Args:
      γ: Discount factor stored on the instance. Defaults to 0.95.
    """
    # Discount factor
    self.γ = γ
    self.A = [0, 1]
    self.S = [0, 1, 2]

    # Transition matrix if we don't invest (each row sums to 1)
    P0 = np.array([[1, 0, 0],
                   [0.1, .75, 0.15],
                   [0.05, .1, 0.85]])
    R0 = np.array([0, 1, 2])

    # Transition matrix if we invest: lower reward in states 1 and 2, but a
    # smaller chance of falling into state 0.
    P1 = np.array([[1, 0, 0],
                   [0.05, .75, 0.2],
                   [0.02, .06, 0.92]])
    R1 = np.array([0, 0.5, 1.5])

    self.P = [P0, P1]
    self.R = [R0, R1]

  def step(self, s, a):
    """Sample one transition from state ``s`` under action ``a``.

    Args:
      s: Current state index.
      a: Action index.

    Returns:
      A tuple ``(s_prime, R, done)``: the sampled next state, the reward
      ``self.R[a][s]``, and a bool that is True when the next state is 0.
    """
    s_prime = np.random.choice(len(self.S), p=self.P[a][s])
    R = self.R[a][s]
    # Reaching state 0 ends the episode.
    done = bool(s_prime == 0)
    return s_prime, R, done

  def simulate(self, s, a, π):
    """Roll out one episode and record the visited (state, action, reward).

    Args:
      s: Starting state index.
      a: Action forced on the first step only.
      π: Policy indexable by state; ``π[s]`` picks every later action.

    Returns:
      List of ``(s, a, R)`` tuples, one per step taken. The state that ends
      the episode is not itself appended.
    """
    done = False
    t = 0
    history = []
    while not done:
      # The first action is the caller's; afterwards follow the policy.
      if t > 0:
        a = π[s]
      s_prime, R, done = self.step(s, a)
      history.append((s, a, R))
      s = s_prime
      t += 1

    return history

In [None]:
mdp = MDP()

# Estimating Qπ
# Tables are sized from the MDP rather than hard-coded as 3 and (3, 2).
# NOTE(review): S is not read by the training loop in this cell; it is kept
# in case a later cell uses it.
S = np.zeros(len(mdp.S))
q = np.zeros((len(mdp.S), len(mdp.A)))  # q[s, a]: action-value estimate

# Every episode starts from state 1.
s = 1
# TD step size.
α = 0.001
# Exploration probability passed to ε_greedy. Spelled with the same ε glyph
# the rest of the cell uses; Python NFKC-normalizes identifiers, so the
# look-alike lunate epsilon (U+03F5) would bind this very same name.
ε = 0.01


def ε_greedy(q, s, ε=0.01):
  """Choose an action for state ``s`` with an ε-greedy rule over ``q``.

  Args:
    q: Action-value table indexed as ``q[s][a]``.
    s: Current state index (row of ``q``).
    ε: Probability of taking the exploratory branch.

  Returns:
    With probability ``ε`` an action drawn uniformly from every action
    available in ``q[s]`` (this may coincide with the greedy one); otherwise
    ``np.argmax(q[s])``, which resolves ties to the lowest index.
  """
  # Draw a random uniform
  u = np.random.uniform()
  if u < ε:
    # Explore over however many actions the table has, not a fixed 2.
    return np.random.randint(len(q[s]))
  return np.argmax(q[s])


# Tabular TD control on `mdp`: act ε-greedily, then move q[s, a] a step of
# size α toward the one-step TD target. Episodes restart from state 1.
n_iterations = 10 ** 7
report_every = 10 ** 5

for iteration in range(n_iterations):
  # Choose action for today:
  a = ε_greedy(q, s, ε=ε)
  s_prime, R, done = mdp.step(s, a)

  # Q-learning update: bootstrap from the best action in the next state.
  # A transition that ends the episode has no future value to bootstrap from.
  bootstrap = 0.0 if done else np.max(q[s_prime, :])
  TD_target = R + mdp.γ * bootstrap

  # SARSA alternative (on-policy): replace the two lines above with
  #   a_prime = ε_greedy(q, s_prime, ε=ε)
  #   TD_target = R + mdp.γ * q[s_prime, a_prime]
  # The a_prime draw lives here, not in the loop body, so the Q-learning
  # run does not pay for an unused ε-greedy sample on every iteration.

  δ = TD_target - q[s, a]

  # TD update
  q[s, a] = q[s, a] + α * δ

  # Restart from state 1 once the episode has ended.
  s = 1 if done else s_prime

  # Periodic snapshot of the table to watch convergence.
  if iteration % report_every == 0:
    print(q)




[[0.    0.   ]
 [0.001 0.   ]
 [0.    0.   ]]
[[ 0.          0.        ]
 [10.65011844  2.17458558]
 [15.05973152  2.8557018 ]]
[[ 0.          0.        ]
 [11.04810557  4.10582443]
 [15.48841972  5.47889054]]
[[ 0.          0.        ]
 [11.22567641  5.86693087]
 [15.62173416  7.83134999]]
[[ 0.          0.        ]
 [11.27943941  7.05566645]
 [16.01981484  9.73571699]]
[[ 0.          0.        ]
 [11.21797151  8.09156419]
 [15.72665754 11.07754874]]
[[ 0.          0.        ]
 [11.29620872  8.71380643]
 [15.84430382 12.13602133]]
[[ 0.          0.        ]
 [11.26873391  9.42767876]
 [15.81961041 13.06782105]]
[[ 0.          0.        ]
 [11.26367754  9.81101607]
 [15.93136022 13.62906443]]
[[ 0.          0.        ]
 [11.37106551 10.35210426]
 [16.15154794 14.2419797 ]]
[[ 0.          0.        ]
 [11.26285309 10.61459046]
 [15.77598112 14.65604784]]
[[ 0.          0.        ]
 [11.55528041 10.84436947]
 [16.0156699  15.0231247 ]]
[[ 0.          0.        ]
 [10.99418999 12.50640652

KeyboardInterrupt: ignored