In [1]:
from MDPDataset.old_dataset import *
from MDP.old_MDP import MDP
from MDP.ChainBandit import ChainBanditMDP
from MDP.ChainBanditState import ChainBanditState
from BPolicy.ChainBanditPolicy import ChainBanditPolicy
import copy

In [2]:
horizon = 3
mdp = ChainBanditMDP(num_states = horizon)
policy = ChainBanditPolicy(mdp)

In [3]:
neps = 1000
observations = []
actions = []
rewards = []
terminals = []
for eps in range(neps):
    for timestep in range(horizon+1):
        # Get state.
        # Add state to list.
        cur_state = copy.deepcopy(mdp.cur_state)
        observations.append(copy.deepcopy(cur_state.num_list))

        # Get action
        # Add action to list
        cur_action = policy._get_action(state = cur_state)
        actions.append(copy.deepcopy(cur_action))

        # Execute action
        reward, next_state = mdp.execute_agent_action(cur_action)
        # Add reward
        rewards.append(copy.deepcopy(reward))
        
        terminals.append(0)
    mdp.reset()
    terminals[-1] = 1

In [4]:
import numpy as np

observations = np.array(observations)
actions = np.array(actions)
rewards = np.array(rewards)
terminals = np.array(terminals)

dataset = MDPDataset(
    observations=observations,
    actions=actions,
    rewards=rewards,
    terminals=terminals,
)

In [5]:
# first episode
episode = dataset.episodes[20]

# access to episode data
print(episode.observations)
print(episode.actions)
print(episode.rewards)

[[ 1  0]
 [ 2  1]
 [ 3  1]
 [-1 -1]]
[[-1]
 [-1]
 [-1]
 [-1]]
[0.5 0.  0.  0. ]


In [6]:
# first transition
transition = episode.transitions[0]

# access to tuple
print(transition.observation)
print(transition.action)
print(transition.reward)
print(transition.next_observation)
print(transition.terminal)

[1 0]
[-1]
0.5
[2 1]
0.0


In [7]:
episode.transitions[2].next_observation

array([-1, -1])

In [8]:
from OfflineLearners.offlineLearners import VI, PVI, SPVI

In [9]:
vi = VI(name = "vi", states = observations, actions = policy.actions, epLen = horizon)
vi.fit(dataset)

In [10]:
mdp = ChainBanditMDP(num_states = horizon)
neps = 10000
viobservations = []
viactions = []
virewards = []
viterminals = []
for eps in range(neps):
    for timestep in range(horizon):
        # Get state.
        # Add state to list.
        cur_state = copy.deepcopy(mdp.cur_state)
        viobservations.append(copy.deepcopy(cur_state.num_list))

        # Get action
        # Add action to list
        cur_action = vi.act(copy.deepcopy(cur_state.num_list), timestep)
        viactions.append(copy.deepcopy(cur_action))

        # Execute action
        reward, next_state = mdp.execute_agent_action(cur_action)
        # Add reward
        virewards.append(copy.deepcopy(reward))
        
        viterminals.append(0)
    mdp.reset()
    viterminals[-1] = 1
print("vi rewards: ",np.sum(np.array(rewards))/neps)

vi rewards:  0.0332


In [11]:
pvi = PVI(name = "pvi", states = observations, actions = policy.actions, epLen = horizon)
pvi.fit(dataset)

In [12]:
mdp = ChainBanditMDP(num_states = horizon)
neps = 10000
pviobservations = []
pviactions = []
pvirewards = []
pviterminals = []
for eps in range(neps):
    for timestep in range(horizon):
        # Get state.
        # Add state to list.
        cur_state = copy.deepcopy(mdp.cur_state)
        pviobservations.append(copy.deepcopy(cur_state.num_list))

        # Get action
        # Add action to list
        cur_action = pvi.act(copy.deepcopy(cur_state.num_list), timestep)
        pviactions.append(copy.deepcopy(cur_action))

        # Execute action
        reward, next_state = mdp.execute_agent_action(cur_action)
        # Add reward
        pvirewards.append(copy.deepcopy(reward))
        
        pviterminals.append(0)
    mdp.reset()
    pviterminals[-1] = 1
print("pvi rewards: ",np.sum(np.array(pvirewards))/neps)

pvi rewards:  0.5316


In [13]:
spvi = SPVI(name = "pvi", states = observations, actions = policy.actions, epLen = horizon, bpolicy = policy)
spvi.fit(dataset)

In [14]:
mdp = ChainBanditMDP(num_states = horizon)
neps = 10000
spviobservations = []
spviactions = []
spvirewards = []
spviterminals = []
for eps in range(neps):
    for timestep in range(horizon):
        # Get state.
        # Add state to list.
        cur_state = copy.deepcopy(mdp.cur_state)
        spviobservations.append(copy.deepcopy(cur_state.num_list))

        # Get action
        # Add action to list
        cur_action = spvi.act(copy.deepcopy(cur_state.num_list), timestep)
        spviactions.append(copy.deepcopy(cur_action))

        # Execute action
        reward, next_state = mdp.execute_agent_action(cur_action)
        # Add reward
        spvirewards.append(copy.deepcopy(reward))
        
        spviterminals.append(0)
    mdp.reset()
    spviterminals[-1] = 1
print("spvi rewards: ",np.sum(np.array(spvirewards))/neps)

spvi rewards:  0.2815


In [15]:
spvi.agent.shift_prior

{(0,
  0,
  0): array([-1.4773244e-04, -5.7551813e-01, -1.4773244e-04, -1.4773244e-04,
        -1.4773244e-04,  5.7625675e-01, -1.4773244e-04], dtype=float32),
 (0,
  1,
  0): array([-0.00113608,  0.3910861 , -0.00113608, -0.00113608, -0.00113608,
        -0.38540575, -0.00113608], dtype=float32),
 (0,
  2,
  0): array([ 0.00256762,  0.36886388,  0.00256762,  0.00256762,  0.00256762,
        -0.38170207,  0.00256762], dtype=float32),
 (1,
  0,
  0): array([-4.3785572e-04, -4.3785572e-04, -5.6322658e-01, -4.3785572e-04,
        -4.3785572e-04, -4.3785572e-04,  5.6541574e-01], dtype=float32),
 (1,
  1,
  0): array([-0.00116053, -0.00116053,  0.3841989 , -0.00116053, -0.00116053,
        -0.00116053, -0.37839633], dtype=float32),
 (1,
  2,
  0): array([ 0.00319677,  0.00319677,  0.3580551 ,  0.00319677,  0.00319677,
         0.00319677, -0.37403902], dtype=float32),
 (2,
  0,
  0): array([-0.00406211, -0.00406211, -0.00406211,  0.02437264, -0.00406211,
        -0.00406211, -0.00406211], d

In [16]:
R , P = spvi.agent.map_mdp()

In [17]:
P

{(0,
  0): array([0.00561798, 0.00561798, 0.00561798, 0.00561798, 0.00561798,
        0.96629214, 0.00561798], dtype=float32),
 (0,
  1): array([0.00462963, 0.9722222 , 0.00462963, 0.00462963, 0.00462963,
        0.00462963, 0.00462963], dtype=float32),
 (0,
  2): array([0.00833333, 0.95      , 0.00833333, 0.00833333, 0.00833333,
        0.00833333, 0.00833333], dtype=float32),
 (1,
  0): array([0.00813008, 0.00813008, 0.00813008, 0.00813008, 0.00813008,
        0.00813008, 0.9512195 ], dtype=float32),
 (1,
  1): array([0.00740741, 0.00740741, 0.95555556, 0.00740741, 0.00740741,
        0.00740741, 0.00740741], dtype=float32),
 (1,
  2): array([0.01176471, 0.01176471, 0.92941177, 0.01176471, 0.01176471,
        0.01176471, 0.01176471], dtype=float32),
 (2, 0): array([0.01, 0.01, 0.01, 0.94, 0.01, 0.01, 0.01], dtype=float32),
 (2,
  1): array([0.01086957, 0.01086957, 0.01086957, 0.9347826 , 0.01086957,
        0.01086957, 0.01086957], dtype=float32),
 (2,
  2): array([0.02857143, 0.0285