In [None]:
from env.balancebot_env import BalancebotEnv
from stable_baselines import PPO2
from stable_baselines.common.policies import FeedForwardPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.bench import Monitor

import os
import time
import torch

In [None]:
# Per-run log directory named after the current Unix time in whole seconds
log_dir = "/tmp/gym/%d" % int(time.time())
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Create the environment
def make_env(rank):
    """Return a zero-argument factory for one monitored Balancebot environment.

    Nothing is constructed by this call itself; the environment is built only
    when the returned callable is invoked.

    :param rank: identifier of the environment copy; its string form is used
        as the Monitor path component inside ``log_dir``
    :return: callable that builds and returns the Monitor-wrapped environment
    """
    def _build_monitored_env():
        # render=False: build the env without rendering
        balancebot = BalancebotEnv(render=False)
        # Monitor is pointed at <log_dir>/<rank>
        return Monitor(balancebot, os.path.join(log_dir, str(rank)))

    return _build_monitored_env

# Number of environment copies; SubprocVecEnv receives one factory per copy,
# with ranks 0 .. num_cpu - 1
num_cpu = 16
env = SubprocVecEnv(list(map(make_env, range(num_cpu))))


In [None]:
# Widths of the first and second hidden layers used by the custom policy
hidden_1_dim, hidden_2_dim = 32, 16

In [None]:
# Create the RL Agent
class CustomPolicy(FeedForwardPolicy):
    """FeedForwardPolicy preset with two hidden layers and MLP feature extraction.

    All positional and keyword arguments are forwarded unchanged to
    ``FeedForwardPolicy``; only ``layers`` and ``feature_extraction`` are fixed
    here, so callers must not pass those two themselves.
    """

    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(
            *args,
            layers=[hidden_1_dim, hidden_2_dim],  # hidden layer widths
            feature_extraction="mlp",
            **kwargs
        )

# PPO2 agent using the custom policy on the vectorized env; TensorBoard
# summaries are directed to <log_dir>/tensorboard
model = PPO2(
    CustomPolicy,
    env,
    verbose=1,
    tensorboard_log=log_dir + "/tensorboard",
)


In [None]:
# Size of the first axis of the observation space; used below as the input
# dimension of the encoder weight matrix
observation_dim = env.observation_space.shape[0]

In [None]:
# Placeholder encoder weights: uniform random values of shape
# (observation_dim, hidden_1_dim), converted to a NumPy array
encoder_tensor = torch.rand(observation_dim, hidden_1_dim, requires_grad=False)
weight_encoder = encoder_tensor.numpy()

In [None]:
# Sanity check: should display (observation_dim, hidden_1_dim)
weight_encoder.shape

## How do I load the encoder weights into the model?
## I want to load them into the 'pi_fc0' and 'vf_fc0' layers.
 

![Network](assets/network.png)

In [None]:
# Train and save the agent.
# total_timesteps must be an int: PPO2.learn floor-divides it by the batch size
# and feeds the result to range(), so a float such as 1e3 raises TypeError.
# NOTE(review): with the PPO2 default n_steps=128 and 16 envs, one rollout batch
# would be 2048 steps, so a 1000-step budget may perform zero updates — confirm
# the intended training budget.
model.learn(total_timesteps=int(1e3), tb_log_name="PPO2")
model.save("ppo_save")

In [None]:
# Display the directory that was passed to PPO2 as tensorboard_log
log_dir+"/tensorboard/"

## You can open TensorBoard from a terminal
## For example:
### `tensorboard --logdir <log_dir>/tensorboard` (replace `<log_dir>` with the path shown above)

In [None]:
# Delete the trained model to demonstrate loading: after this cell the name
# `model` is unbound until it is re-created from the file written by model.save
del model 

In [None]:
# Create the evaluation env.
# The name `env` still refers to the vectorized training env at this point;
# close it first so its worker processes are shut down instead of being left
# running once the name is rebound.
env.close()
# Single, non-rendering environment wrapped in a DummyVecEnv
env = DummyVecEnv([lambda: BalancebotEnv(render=False)])

In [None]:
# Load the trained agent from "ppo_save" and attach it to the evaluation env;
# the custom policy class is passed explicitly
model = PPO2.load(
    "ppo_save",
    env=env,
    policy=CustomPolicy,
)

In [None]:

# Enjoy trained agent: run 10 episodes, stepping until the env reports done
for ep in range(10):
    obs = env.reset()
    while True:
        # Ask the loaded policy for an action for the current observation
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        if dones:
            break