In [None]:
# Imports

import stable_baselines3
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement, CallbackList, CheckpointCallback
import sb3_contrib
import panda_gym.envs
import gymnasium as gym

import numpy as np
%matplotlib inline

In [None]:
# Build a Panda reach environment with control_type="joints" and show its
# observation and action spaces, one per line.
env = panda_gym.envs.PandaReachEnv(control_type="joints")
print(env.observation_space, env.action_space, sep="\n")

In [None]:
class ActionPenalizerWrapper(gym.Wrapper):
    """Reward-shaping wrapper that charges a cost proportional to action size.

    On every step the mean absolute value of the action, scaled by
    ``penalty_coef``, is subtracted from the reward returned by the wrapped
    environment. Observation, termination flags and info are passed through
    unchanged.

    Args:
        env: Environment to wrap.
        penalty_coef: Weight applied to the mean absolute action. Defaults to
            0.01, the value that was previously hard-coded.
    """

    def __init__(self, env, penalty_coef=0.01):
        super().__init__(env)
        self.penalty_coef = penalty_coef

    def step(self, action):
        """Step the wrapped env and return its transition with the penalised reward.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple of the
            wrapped environment, with the action penalty subtracted from
            ``reward``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        # np.abs (unlike the builtin abs) also accepts plain lists and tuples.
        penalty = np.mean(np.abs(action)) * self.penalty_coef
        # Build a new value instead of using `-=` so that an array-valued
        # reward handed out by the wrapped env is never mutated in place.
        reward = reward - penalty
        return obs, reward, terminated, truncated, info

def getenv():
    """Create the wrapped Panda push environment used by this notebook.

    Returns:
        A ``PandaPushEnv`` (joint control, dense reward, ``rgb_array``
        rendering) with flattened observations, a 100-step time limit and
        the ``ActionPenalizerWrapper`` applied on top.
    """
    base_env = panda_gym.envs.PandaPushEnv(
        render_mode="rgb_array",
        reward_type="dense",
        control_type="joints",
    )
    # Negative distance threshold; presumably this keeps the task's success
    # check from ever passing so only the time limit ends an episode —
    # TODO confirm against the panda_gym task implementation.
    base_env.task.distance_threshold = -1

    flat_env = gym.wrappers.FlattenObservation(base_env)
    limited_env = gym.wrappers.TimeLimit(flat_env, max_episode_steps=100)

    # Disabled experiments, kept for reference:
    #   gym.wrappers.TransformReward(env, lambda r: -r**2)
    #   gym.wrappers.RecordVideo(env, "./runs/RecurrentPPO", lambda ep: ep % 100 == 0)
    return ActionPenalizerWrapper(limited_env)

# --- Training configuration --------------------------------------------------
n_envs = 1
TRAIN_MODEL = True  # set to False to load the saved model instead of training
RUN_DIR = "./runs/RecurrentPPO"
# Single path used for BOTH saving and loading, so the load branch always picks
# up what the training branch produced (SB3 appends ".zip" automatically).
MODEL_PATH = "./runs/RecurrentPPO2"
EVAL_AND_SAVE_EVERY = 20000  # in environment timesteps

vec_env = stable_baselines3.common.env_util.make_vec_env(getenv, n_envs=n_envs)
# vec_env = stable_baselines3.common.vec_env.VecVideoRecorder(vec_env, "./runs/RecurrentPPO", lambda ep: ep % 10 == 0, video_length=100)

if TRAIN_MODEL:
    model = sb3_contrib.RecurrentPPO(
        'MlpLstmPolicy',
        vec_env,
        verbose=0,
        tensorboard_log=f"{RUN_DIR}/",
        policy_kwargs=dict(net_arch=[128, 64]),
        n_steps=2048 // n_envs,
        gamma=0.99,
        ent_coef=.01,
    )

    # Evaluate on a dedicated environment: running evaluation episodes on the
    # training env would reset and step it in the middle of rollout collection.
    eval_env = stable_baselines3.common.env_util.make_vec_env(getenv, n_envs=1)

    # Callbacks count calls to env.step(), each of which advances n_envs
    # timesteps, so convert the timestep budget into a number of calls.
    callback_freq = max(EVAL_AND_SAVE_EVERY // n_envs, 1)

    # Stop early once 2 consecutive evaluations bring no improvement
    # (checked only after at least 5 evaluations).
    stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=5, verbose=1)
    eval_callback = EvalCallback(eval_env, eval_freq=callback_freq, callback_after_eval=stop_train_callback, verbose=1)
    checkpoint_callback = CheckpointCallback(save_freq=callback_freq, save_path=RUN_DIR, name_prefix="RecurrentPPO")
    callback = CallbackList([eval_callback, checkpoint_callback])

    try:
        model.learn(total_timesteps=1000000, tb_log_name="versus random", progress_bar=False, log_interval=1, callback=callback)
    finally:
        # Release the evaluation simulator even if training is interrupted.
        eval_env.close()
    model.save(MODEL_PATH)
else:
    model = sb3_contrib.RecurrentPPO.load(MODEL_PATH)

In [None]:
# Roll out one deterministic episode with the trained policy while recording
# a video of it, then print the accumulated reward.
env = gym.wrappers.RecordVideo(getenv(), "./runs/RecurrentPPO")
vec_env = stable_baselines3.common.env_util.make_vec_env(lambda: env)

score = 0
try:
    obs = vec_env.reset()
    # cell and hidden state of the LSTM
    lstm_states = None
    # Taken from the vectorised env instead of being hard-coded.
    num_envs = vec_env.num_envs
    # Episode start signals are used to reset the lstm states
    episode_starts = np.ones((num_envs,), dtype=bool)
    while True:
        action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
        obs, rewards, dones, info = vec_env.step(action)
        score += rewards.mean()
        episode_starts = dones
        if dones.all():
            break
finally:
    # Always close the env so the video recorder and the simulator are
    # released, even if the rollout raises.
    vec_env.close()
print(score)