Tutorials

CartPole

Download dataset

First, download the CartPole dataset:

$ wget "https://www.dropbox.com/s/vc7fm7qdnu0kh01/cartpole.csv?dl=1" -O cartpole.csv

Alternatively, open https://www.dropbox.com/s/vc7fm7qdnu0kh01/cartpole.csv in a browser and download the file from there.
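If wget is not available, the same download can be done from Python. The following is a minimal sketch using only the standard library; the helper names download_dataset and preview_csv are hypothetical, and the preview assumes the first line of the CSV is a header row, which this tutorial does not state.

```python
import csv
import urllib.request

# URL taken from the wget command in this tutorial.
DATASET_URL = "https://www.dropbox.com/s/vc7fm7qdnu0kh01/cartpole.csv?dl=1"


def download_dataset(url=DATASET_URL, path="cartpole.csv"):
    """Save the file at `url` to `path` and return the path (hypothetical helper)."""
    urllib.request.urlretrieve(url, path)
    return path


def preview_csv(path, n_rows=3):
    """Return the first line (assumed to be a header) and the next `n_rows` rows."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader, [])
        rows = [row for _, row in zip(range(n_rows), reader)]
    return header, rows


# Usage (requires network access):
#   header, rows = preview_csv(download_dataset())
#   print(header)
```

Previewing a few rows before uploading is a cheap way to confirm that the download produced a CSV file rather than an HTML error page.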

Train

Follow the instructions from upload_dataset through start_training.

Deploy

Finally, download the trained policy as described in export_policy_function. Two model formats are currently available: TorchScript and ONNX.

TorchScript

With PyTorch alone, you can load the policy in two lines of code.

import torch

policy = torch.jit.load('policy.pt')

It's easy, right?

Then write the rest of the interaction code as usual.

import gym

env = gym.make('CartPole-v0')

# assumes the classic gym API (gym < 0.26), where reset() returns only the observation
observation = env.reset()

while True:
    # feed observation to the policy
    action = policy(torch.tensor([observation], dtype=torch.float32))

    # take action to get next observation
    observation, _, done, _ = env.step(action[0].numpy())

    # render the environment
    env.render()

    # stop when the episode terminates
    if done:
        break

ONNX

In this tutorial, onnxruntime is used to load the model.

import onnxruntime as ort

ort_session = ort.InferenceSession('policy.onnx')

Loading the ONNX model is just as easy.
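The interaction code in this section feeds the observation under the input name 'input_0'. If an exported model uses a different name, the session can report it. The sketch below relies on InferenceSession.get_inputs() from onnxruntime; the helper name describe_inputs is hypothetical.

```python
def describe_inputs(session):
    """Return (name, shape, type) for every input of an onnxruntime InferenceSession."""
    return [(arg.name, arg.shape, arg.type) for arg in session.get_inputs()]


# Usage with the session loaded above:
#   print(describe_inputs(ort_session))
#   input_name = ort_session.get_inputs()[0].name
```

Looking the name up once and reusing it avoids hardcoding a string that depends on how the model was exported.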

Then write the rest of the interaction code as in the TorchScript example.

import gym

env = gym.make('CartPole-v0')

# assumes the classic gym API (gym < 0.26), where reset() returns only the observation
observation = env.reset()

while True:
    # cast to float32 and add a batch dimension, giving shape (1, obs_dim)
    observation = observation.astype('f4').reshape((1, -1))

    # feed observation to the policy
    action = ort_session.run(None, {'input_0': observation})[0]

    # take action to get next observation
    observation, _, done, _ = env.step(action[0])

    # render the environment
    env.render()

    # stop when the episode terminates
    if done:
        break
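Both interaction loops discard the reward, so they show the policy running but not how well it performs. The sketch below wraps the same loop in a function that returns the episode return and works with either model format. The names run_episode and select_action are hypothetical, and the function accepts both the classic gym API used in this tutorial and the newer one (gym 0.26 and later), where reset() returns (observation, info) and step() returns five values.

```python
def run_episode(env, select_action, max_steps=1000, render=False):
    """Run one episode and return (total_reward, steps).

    `select_action` is a hypothetical callable mapping an observation to an action.
    """
    reset_result = env.reset()
    # gym >= 0.26 returns (observation, info); older versions return the observation only
    observation = reset_result[0] if isinstance(reset_result, tuple) else reset_result

    total_reward, steps = 0.0, 0
    while steps < max_steps:
        result = env.step(select_action(observation))
        if len(result) == 5:
            # newer API: (observation, reward, terminated, truncated, info)
            observation, reward, terminated, truncated, _ = result
            done = terminated or truncated
        else:
            # classic API: (observation, reward, done, info)
            observation, reward, done, _ = result
        total_reward += float(reward)
        steps += 1
        if render:
            env.render()
        if done:
            break
    return total_reward, steps


# Usage with the TorchScript policy from this tutorial:
#   run_episode(env, lambda obs: policy(torch.tensor([obs], dtype=torch.float32))[0].numpy())
# Usage with the ONNX session from this tutorial:
#   run_episode(env, lambda obs: ort_session.run(
#       None, {'input_0': obs.astype('f4').reshape((1, -1))})[0][0])
```

The max_steps cap is an added safeguard so that the loop ends even if the environment never signals termination.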