First, download the CartPole dataset:
$ wget "https://www.dropbox.com/s/vc7fm7qdnu0kh01/cartpole.csv?dl=1" -O cartpole.csv
Alternatively, open https://www.dropbox.com/s/vc7fm7qdnu0kh01/cartpole.csv in a browser.
Follow the instructions from upload_dataset to start_training.
Finally, download the trained policy as described in export_policy_function. Two model formats are currently available: TorchScript and ONNX.
With PyTorch alone, you can load the TorchScript policy in two lines of code.
import torch
policy = torch.jit.load('policy.pt')
It's easy, right?
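To see what calling a loaded TorchScript module looks like without the trained policy.pt at hand, the sketch below traces a small stand-in network, saves it, and loads it back with torch.jit.load. StandInPolicy, its layer sizes, and the temporary file path are assumptions made for illustration; the real exported policy may differ in architecture and output format.

```python
import os
import tempfile

import torch


class StandInPolicy(torch.nn.Module):
    """Hypothetical stand-in: 4 observation values in, 1 discrete action out."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        # pick the index of the larger of the two outputs for each batch row
        return self.fc(x).argmax(dim=1)


# trace the stand-in with a dummy batch and save it as TorchScript
traced = torch.jit.trace(StandInPolicy().eval(), torch.zeros(1, 4))
path = os.path.join(tempfile.mkdtemp(), "policy.pt")
traced.save(path)

# loading needs nothing but PyTorch; the Python class is not required here
policy = torch.jit.load(path)

# a batch holding one CartPole-sized observation
observation = torch.tensor([[0.0, 0.1, 0.0, -0.1]], dtype=torch.float32)
action = policy(observation)
print(action.shape, action.dtype)
```

The input carries a leading batch dimension and the output has one entry per batch row, which is why interaction code typically wraps the observation in a list and reads the first element of the result.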
Then write the rest of the interaction code as usual.
import gym
env = gym.make('CartPole-v0')
# classic Gym API (gym < 0.26): reset() returns only the observation
observation = env.reset()
while True:
    # feed the observation to the policy as a batch of size 1
    action = policy(torch.tensor([observation], dtype=torch.float32))
    # take the action to get the next observation
    observation, _, done, _ = env.step(action[0].numpy())
    # render the environment
    env.render()
    # stop when the episode terminates
    if done:
        break
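The loop above discards the reward. A minimal sketch of how the same loop could report an episode return is shown below. StubEnv and run_episode are hypothetical names introduced here; StubEnv only mimics the classic reset()/step() interface so that the example runs without Gym or a trained policy.

```python
class StubEnv:
    """Hypothetical stand-in that mimics the classic Gym reset()/step() API."""

    def __init__(self, max_steps=5):
        self.max_steps = max_steps
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0, 0.0, 0.0, 0.0]

    def step(self, action):
        self.t += 1
        observation = [float(self.t)] * 4
        reward = 1.0  # one point per step, as in CartPole
        done = self.t >= self.max_steps
        return observation, reward, done, {}


def run_episode(env, act, max_steps=1000):
    """Roll out one episode with `act` and return the summed reward."""
    observation = env.reset()
    total_reward = 0.0
    for _ in range(max_steps):
        action = act(observation)
        observation, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward


# a constant action stands in for the trained policy here
episode_return = run_episode(StubEnv(max_steps=5), act=lambda obs: 0)
print(episode_return)  # 5.0
```

With the TorchScript policy, `act` could be a small function that wraps the observation in a float32 tensor, calls the policy, and returns `action[0].numpy()`, as in the loop above.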
In this tutorial, onnxruntime is used to load the model.
import onnxruntime as ort
ort_session = ort.InferenceSession('policy.onnx')
Loading an ONNX model is just as easy.
Then write the rest of the interaction code as above.
import gym
env = gym.make('CartPole-v0')
# classic Gym API (gym < 0.26): reset() returns only the observation
observation = env.reset()
while True:
    # cast strictly to float32 and add a batch dimension
    observation = observation.astype('f4').reshape((1, -1))
    # feed the observation to the policy
    action = ort_session.run(None, {'input_0': observation})[0]
    # take the action to get the next observation
    observation, _, done, _ = env.step(action[0])
    # render the environment
    env.render()
    # stop when the episode terminates
    if done:
        break
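The dtype and shape conversion in the ONNX loop is easy to get wrong, so here is a small sketch that isolates it. The helper name to_model_input and the sample values are assumptions for illustration; only NumPy is needed to run it.

```python
import numpy as np


def to_model_input(observation):
    """Convert one observation to a float32 batch of size 1."""
    return np.asarray(observation, dtype=np.float32).reshape((1, -1))


raw = np.array([0.01, -0.02, 0.03, 0.04])  # NumPy defaults to float64
batch = to_model_input(raw)

print(raw.dtype, raw.shape)      # float64 (4,)
print(batch.dtype, batch.shape)  # float32 (1, 4)
```

ONNX Runtime generally rejects inputs whose dtype does not match the type declared by the model, which is why the cast is done explicitly. If the input name of your exported model differs from 'input_0', it can be looked up with `ort_session.get_inputs()[0].name`.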