In [None]:
from framework import Featurizer
from dqn import DQNAgent, DQNTrainer

from PIL import Image
import gym
import numpy as np

In [None]:
def to_gif(matrices, filepath, duration=25, loop=0):
    """Save a sequence of image arrays as an animated GIF.

    Args:
        matrices: iterable of arrays, each accepted by ``PIL.Image.fromarray``
            (presumably the uint8 RGB frames of an ``rgb_array`` render —
            confirm for other sources). May be a one-shot iterator.
        filepath: destination path of the GIF file.
        duration: display time of each frame, in milliseconds.
        loop: how many times the animation repeats; 0 loops forever.

    Raises:
        ValueError: if ``matrices`` yields no frames.
    """
    frames = [Image.fromarray(matrix) for matrix in matrices]
    if not frames:
        # Without this guard, indexing the first frame fails with an opaque IndexError.
        raise ValueError("to_gif() needs at least one frame")
    first, *rest = frames
    # PIL writes an animated GIF from the first frame plus ``append_images``.
    first.save(filepath, save_all=True, append_images=rest, duration=duration, loop=loop)

## Inverted Double Pendulum

In [None]:
env = gym.make('InvertedDoublePendulum-v4', render_mode='rgb_array')

# DQN chooses among a finite set of actions, so the bounded action range is
# replaced by 21 evenly spaced points between its low and high bounds.
# np.linspace stacks the points along a new first axis, so shape[0] below is
# the number of candidate actions.
env.action_space = np.linspace(
    env.action_space.low,
    env.action_space.high,
    21,
)
n_actions = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]

# The featurizer applies adaptive scaling to the raw state vectors.
featurizer = Featurizer(state_dim)

agent = DQNAgent(
    # Q-network size and training batch
    input_dim=state_dim,
    output_dim=n_actions,
    hidden_dim=128,
    hidden_layers=5,
    batch_size=256,
    # learning hyperparameters
    gamma=0.99,                            # discount factor
    min_epsilon=0.1, epsilon_decay=0.999,  # exploration-rate floor and its decay
    tau=0.005,                             # update rate of the target network
)

trainer = DQNTrainer(env, agent, featurizer)

In [None]:
# Train the DQN agent on the discretized double-pendulum environment for 1000 episodes.
trainer.train(episodes=1000)

In [None]:
# Roll out one episode with the trained agent and report its totals.
# NOTE(review): the meaning of the positional False is not visible here (the
# ball-and-beam cell calls run_episode() and then reads 'rgb_arrays'); it
# presumably disables frame capture — confirm and pass it by keyword.
info = trainer.run_episode(False)
episode_reward, episode_steps = info['reward'], info['steps']
print(f"cumulative reward: {episode_reward:.2f}, steps: {episode_steps}")
trainer.plot_losses()

## Ball and Beam Problem

In [None]:
import ballbeam_gym.envs
import warnings
# NOTE(review): this installs a process-wide "ignore" filter, so it silences
# every warning for the rest of the kernel session, not only those from the
# ball-and-beam environment. Narrow it (category=/module=) once the noisy
# source is known.
warnings.filterwarnings('ignore')

env = ballbeam_gym.envs.BallBeamSetpointEnv(
    timestep=0.02,
    setpoint=-0.8,
    beam_length=2.0,
    max_angle=0.5,
    max_timesteps=500,
    action_mode='discrete',
)
# Replace the action space with an indexable array of 3 action ids, matching
# how the pendulum section exposes its actions as an array.
# NOTE(review): 3 presumably equals the env's discrete action count — confirm.
env.action_space = np.arange(3)
state_dim = env.observation_space.shape[0]

featurizer = Featurizer(state_dim)
# Keyword arguments mirror the pendulum cell's DQNAgent call; unspecified
# hyperparameters fall back to DQNAgent's own defaults.
agent = DQNAgent(
    input_dim=state_dim,
    output_dim=env.action_space.shape[0],
    batch_size=128,
    epsilon_decay=0.9995,
)

In [None]:
# Build a trainer from the ball-and-beam env, agent and featurizer defined in
# the previous cell. Train it for 1000 episodes, then plot its losses.
trainer = DQNTrainer(env, agent, featurizer)
trainer.train(episodes=1000)
trainer.plot_losses()

In [None]:
# Roll out one episode and save the captured frames as an animated GIF.
# NOTE(review): relies on run_episode() returning frames under 'rgb_arrays'
# when called with its default argument — confirm against DQNTrainer.
info = trainer.run_episode()
frames = info['rgb_arrays']
to_gif(frames, 'ball_and_beam.gif', duration=25)