In [124]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

from rllib.agent import RandomAgent
from rllib.policy import RandomPolicy
from rllib.util import rollout_policy, rollout_agent
from rllib.dataset import TrajectoryDataset

from rllib.environment.systems import InvertedPendulum, CartPole
from rllib.environment import SystemEnvironment, GymEnvironment
import numpy as np 
from torch.utils.data import DataLoader
import torch


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [137]:
def termination_func(state):
    return np.abs(state[..., 0]) >= np.deg2rad(45)

def reward_func(state, action):
    return np.exp(-0.5 / (0.2 ** 2) * state[..., 0] ** 2)  

system = InvertedPendulum(mass=0.3, length=0.5, friction=0.005)
# system = CartPole(pendulum_mass=0.3, cart_mass=1, 
#                   length=0.5, rot_friction=0.005)
# 
initial_state = np.array([np.deg2rad(20), 0.])
# 
environment = SystemEnvironment(system, initial_state=initial_state, max_steps=50)
# environment = GymEnvironment('Pendulum-v0')

policy = RandomPolicy(dim_action=environment.dim_action, dim_state=environment.dim_observation,
                      num_action=environment.num_action
                      )

agent = RandomAgent(dim_action=environment.dim_action, dim_state=environment.dim_observation,
                    num_action=environment.num_action
                    )

dataset = TrajectoryDataset(sequence_length=4)
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)


In [138]:
for episode in range(10):
    trajectory = rollout_policy(environment, policy)
    dataset.append(trajectory)
    print(len(trajectory))
    
for episode in range(10):
    rollout_agent(environment, agent)


50
50
50
50
50
50
50
50
50
50


In [140]:
dataset.shuffle()
dataloader = DataLoader(dataset, batch_size=32, num_workers=0, shuffle=False)
states = []
for epoch in range(2):
    for observation in dataloader:
        state, action, reward, next_state, done = observation
        print(state.shape)
        states.append(state)
        print(epoch)
        break 
    

torch.Size([32, 4, 2])
0
torch.Size([32, 4, 2])
1


In [145]:
print(state)

tensor([[[ 0.6238,  2.3849],
         [ 0.6495,  2.7492],
         [ 0.6769,  2.7343],
         [ 0.7050,  2.8828]],

        [[ 1.0919,  3.9832],
         [ 1.1337,  4.3679],
         [ 1.1779,  4.4898],
         [ 1.2234,  4.6046]],

        [[ 0.7327,  2.6738],
         [ 0.7600,  2.7834],
         [ 0.7898,  3.1786],
         [ 0.8224,  3.3493]],

        [[ 0.3647,  0.1089],
         [ 0.3659,  0.1137],
         [ 0.3669,  0.0882],
         [ 0.3689,  0.3151]],

        [[ 0.4745,  1.4174],
         [ 0.4890,  1.4870],
         [ 0.5041,  1.5326],
         [ 0.5204,  1.7345]],

        [[ 0.7221,  2.7507],
         [ 0.7502,  2.8561],
         [ 0.7789,  2.9021],
         [ 0.8085,  3.0069]],

        [[ 0.3849,  0.9128],
         [ 0.3935,  0.8232],
         [ 0.4024,  0.9533],
         [ 0.4127,  1.1093]],

        [[ 0.4338,  1.3746],
         [ 0.4479,  1.4332],
         [ 0.4626,  1.5067],
         [ 0.4771,  1.4092]],

        [[ 0.3831,  0.9771],
         [ 0.3927,  0.9501]