In [None]:
!pip install --upgrade pip
!pip install tinyrl gymnasium matplotlib stable-baselines3

In [None]:
# Hide all CUDA devices from this process. The variable is exported before any
# other library is imported so that it is already in place whenever torch gets
# loaded (stable_baselines3 depends on torch and pulls it in on import).
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# TinyRL
from tinyrl import SAC
# Stable Baselines3
from stable_baselines3 import SAC as SB3_SAC
from stable_baselines3.sac import MlpPolicy
import torch
# Other
import random
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

# Settings handed to TinyRL's `SAC(...)` in `train_tinyrl` when
# `use_python_environment` is False. `path` is resolved to an absolute path
# relative to the current working directory.
custom_environment = dict(
    path=os.path.abspath("../custom_environment"),
    action_dim=1,
    observation_dim=3,
)

In [None]:
# --- Experiment configuration (read by train_sb3 / train_tinyrl and the plot) ---

# Environment id passed to `gym.make`.
environment_name = "Pendulum-v1"
# True: TinyRL trains on a Gymnasium environment built by a Python factory;
# False: TinyRL is given the `custom_environment` settings instead.
use_python_environment = True

# Training is repeated for seeds 0 .. n_seeds - 1.
n_seeds = 2
# Number of training steps per seed.
n_steps = 10_000
# The policy is evaluated once every `evaluation_interval` training steps.
evaluation_interval = 1_000

In [None]:
def evaluate_policy(policy, environment_name="Pendulum-v1", seed=0xf00d, n_episodes=10):
    """Roll out `policy` in a fresh environment and return the return of every episode.

    Args:
        policy: Callable that maps an observation to an action accepted by the
            environment's `step` method.
        environment_name: Gymnasium id used to create the evaluation environment.
        seed: Seed passed to the initial `reset`; the per-episode resets that
            follow are issued without a seed.
        n_episodes: Number of episodes to run.

    Returns:
        A list with one entry per episode: the (undiscounted) sum of the
        rewards collected until the episode terminated or was truncated.
    """
    env_replay = gym.make(environment_name)
    try:
        env_replay.reset(seed=seed)
        returns = []
        for episode_i in range(n_episodes):
            observation, _ = env_replay.reset()
            finished = False
            rewards = 0
            while not finished:
                action = policy(observation)
                observation, reward, terminated, truncated, _ = env_replay.step(action)
                rewards += reward
                finished = terminated or truncated
            returns.append(rewards)
    finally:
        # This function is called once per evaluation point, so release the
        # environment instead of leaking one instance per call.
        env_replay.close()
    return returns

def scale_action(action, env):
    """Affinely map `action` onto the bounds of `env.action_space`.

    With `low` / `high` being the action-space bounds the result is
    `action * (high - low) / 2 + (high + low) / 2`, i.e. -1 maps to `low`
    and +1 maps to `high`.

    Args:
        action: Action value(s) to rescale.
        env: Object exposing `action_space.low` and `action_space.high`.

    Returns:
        The rescaled action.
    """
    bounds = env.action_space
    span = bounds.high - bounds.low
    midpoint = (bounds.high + bounds.low) / 2.0
    return action * span / 2.0 + midpoint

In [None]:
def train_sb3():
    """Train Stable Baselines3 SAC once per seed and record periodic evaluations.

    Reads the notebook-level settings `n_seeds`, `n_steps`,
    `evaluation_interval` and `environment_name`.

    The policy is evaluated *before* each training interval, i.e. after
    0, evaluation_interval, 2 * evaluation_interval, ... training steps. This
    matches the schedule used by `train_tinyrl` and the x-axis
    `range(0, n_steps, evaluation_interval)` used for plotting.

    Returns:
        Nested list indexed as [seed][evaluation][episode] containing the
        episode returns produced by `evaluate_policy`.
    """
    returns = []
    for seed in range(n_seeds):
        # Seed every global RNG the training run may draw from.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        env = gym.make(environment_name)
        try:
            env.reset(seed=seed)
            def policy_factory(obs_dim, action_dim, lr_schedule, **kwargs):
                # Additional keyword arguments are accepted but not forwarded;
                # only the hidden-layer sizes are customised.
                return MlpPolicy(obs_dim, action_dim, lr_schedule, net_arch=[64, 64])
            model = SB3_SAC(policy_factory, env, learning_starts = 0, batch_size=100)
            def policy(observation):
                # `predict` returns a tuple; the first element is the action.
                return model.predict(observation, deterministic=True)[0]
            returns_step = []
            for evaluation_step_i in range(n_steps // evaluation_interval):
                # Evaluate first, then train for one interval, so that the
                # k-th evaluation reflects k * evaluation_interval steps.
                returns_step.append(evaluate_policy(policy, environment_name, seed))
                # reset_num_timesteps=False continues the same run across calls.
                model.learn(total_timesteps=evaluation_interval, reset_num_timesteps=False)
        finally:
            env.close()
        returns.append(returns_step)
    return returns

In [None]:
def train_tinyrl():
    """Train TinyRL SAC once per seed and record periodic evaluations.

    Reads the notebook-level settings `n_seeds`, `n_steps`,
    `evaluation_interval`, `environment_name`, `use_python_environment` and
    `custom_environment`. The policy is evaluated before every training step
    whose index is a multiple of `evaluation_interval` (starting at step 0).

    Returns:
        Nested list indexed as [seed][evaluation][episode] containing the
        episode returns produced by `evaluate_policy`.
    """
    # Only used by `scale_action` to read the action-space bounds.
    example_env = gym.make(environment_name)
    all_returns = []
    for seed in range(n_seeds):
        if not use_python_environment:
            sac = SAC(custom_environment)
        else:
            def env_factory():
                env = gym.make(environment_name)
                env.reset(seed=seed)
                return env
            sac = SAC(env_factory, force_recompile=True)
        state = sac.State(seed)

        def scaled_policy(observation):
            # Rescale the action to the environment's bounds; presumably
            # `state.action` yields normalised actions -- TODO confirm.
            return scale_action(state.action(observation), example_env)

        seed_returns = []
        for step_i in range(n_steps):
            is_evaluation_step = step_i % evaluation_interval == 0
            if is_evaluation_step:
                seed_returns.append(
                    evaluate_policy(scaled_policy, environment_name=environment_name, seed=seed)
                )
            state.step()
        all_returns.append(seed_returns)
    return all_returns

In [None]:
# Run both training procedures; each returns a nested list indexed as
# [seed][evaluation][episode] of episode returns.
# The two calls are kept as separate statements so that the Stable Baselines3
# results remain bound even if the TinyRL run raises.
returns_sb3 = train_sb3()
returns_tinyrl = train_tinyrl()

In [None]:
# Convert the nested [seed][evaluation][episode] lists into arrays.
returns_tinyrl = np.array(returns_tinyrl)
# Fixed: this previously wrapped `returns_tinyrl`, so the Stable Baselines3
# curve was just a copy of the TinyRL data.
returns_sb3 = np.array(returns_sb3)

# Average over the evaluation episodes (last axis) -> [seed][evaluation].
returns_tinyrl_aggregate = returns_tinyrl.mean(axis=-1)
returns_sb3_aggregate = returns_sb3.mean(axis=-1)

# Mean and standard deviation across seeds (first axis) -> [evaluation].
returns_tinyrl_mean = returns_tinyrl_aggregate.mean(axis=0)
returns_sb3_mean = returns_sb3_aggregate.mean(axis=0)
returns_tinyrl_std = returns_tinyrl_aggregate.std(axis=0)
returns_sb3_std = returns_sb3_aggregate.std(axis=0)

# Training step at which each evaluation was taken.
horizontal = range(0, n_steps, evaluation_interval)
# Shaded band: mean +/- one standard deviation across seeds.
plt.fill_between(horizontal, returns_tinyrl_mean - returns_tinyrl_std, returns_tinyrl_mean + returns_tinyrl_std, alpha=0.1)
plt.plot(horizontal, returns_tinyrl_mean, label="TinyRL")
plt.fill_between(horizontal, returns_sb3_mean - returns_sb3_std, returns_sb3_mean + returns_sb3_std, alpha=0.1)
plt.plot(horizontal, returns_sb3_mean, label="Stable Baselines3")
plt.xlabel("Steps")
plt.ylabel("Returns")
plt.legend()