In [15]:
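# Dependencies (run once; `%pip` installs into the environment of the running kernel):
# %pip install swig
# %pip install "gymnasium[box2d]"  # LunarLander needs the Box2D extra, which is built with swig
# %pip install matplotlib
# %pip install pygame
# %pip install "imageio[ffmpeg]"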

In [16]:
import torch

import gymnasium as gym

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import random

from IPython.display import HTML

In [17]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Training length and exploration schedule ---
TOTAL_STEPS = 1_000_000              # iterations of the training loop
START_EPSILON = 1                    # exploration rate at step 0
END_EPSILON = 0.05                   # exploration rate once the decay has finished
STEP_DECAY_FINAL_STEP = TOTAL_STEPS  # step at which epsilon reaches END_EPSILON

# --- Optimisation ---
LEARNING_RATE = 1e-5                 # Adam learning rate
MAX_GRAD_NORM = 5_000                # threshold passed to gradient-norm clipping
GAMMA = 0.99                         # discount factor used in the TD target

# --- Data collection ---
TIMESTEPS_PER_EPOCH = 100            # environment steps gathered per training iteration
BATCH_SIZE = 4_096                   # transitions sampled from the replay buffer per update
REPLAY_BUFFER_SIZE = 100_000         # replay buffer capacity (transitions)

# --- Bookkeeping (both measured in training iterations) ---
REFRESH_TARGET_NETWORK_FREQ = 500    # how often the target network is re-synchronised
LOSS_SHOW_FREQ = 50_000              # how often the loss is printed


In [18]:
# Build the LunarLander environment. The "rgb_array" render mode is requested so
# that env.render() can be used to collect image frames for the animations.
env = gym.make("LunarLander-v3", render_mode="rgb_array")

# Q-network input size: length of the first axis of the observation space's shape.
state_dim = env.observation_space.shape[0]
# Q-network output size: number of actions in the (discrete) action space.
n_actions = env.action_space.n

In [19]:
def run_policy(env, agent, max_steps=1000):
  """Roll out a single episode with `agent` and collect the rendered frames.

  Actions come from `agent.sample_actions`, so they follow whatever exploration
  setting the agent currently has.

  The environment is deliberately NOT closed here: it belongs to the caller and
  may be reused afterwards (e.g. for training). It is also not reset once the
  episode ends, so the caller must call `env.reset()` before stepping it again.

  Args:
    env: Gymnasium-style environment exposing `reset()`, `step(action)` and
      `render()`; `render()` is expected to return an image array.
    agent: object exposing `get_qvalues(states)` and `sample_actions(qvalues)`.
    max_steps: upper bound on the number of environment steps to record.

  Returns:
    np.ndarray stacking one rendered frame per environment step taken.
  """
  frames = []

  state, _ = env.reset()
  for _ in range(max_steps):
    qvalues = agent.get_qvalues([state])
    action = agent.sample_actions(qvalues)[0]

    state, _, terminated, truncated, _ = env.step(action)
    frames.append(env.render())

    # Stop at the end of the episode, whether it terminated or hit a time limit.
    if terminated or truncated:
      break

  return np.array(frames)


def update_animation_canvas(num, frames, animation_canvas):
    """Per-frame animation callback: show frame `num` on the image artist.

    Args:
        num: index of the frame to display.
        frames: indexable collection of frames.
        animation_canvas: object with a `set_data` method (the artist being animated).

    Returns:
        The same `animation_canvas` object, after its data has been replaced.
    """
    current_frame = frames[num]
    animation_canvas.set_data(current_frame)
    return animation_canvas


def animate_policy_actions(env, policy=None):
  """Record one episode of `policy` acting in `env` and wrap it in an animation.

  Args:
    env: environment forwarded to `run_policy`.
    policy: agent forwarded to `run_policy`; it must provide the methods that
      `run_policy` calls, so the `None` default is not usable as-is.

  Returns:
    A `matplotlib.animation.FuncAnimation` that shows the recorded frames in
    order, 100 ms apart.
  """
  episode_frames = run_policy(env, policy)

  figure = plt.figure()
  # The first frame initialises the image artist; later frames are swapped in
  # by update_animation_canvas.
  image_artist = plt.imshow(episode_frames[0])
  plt.title("Policy Acting Animation")

  clip = animation.FuncAnimation(
    figure,
    update_animation_canvas,
    frames=len(episode_frames),
    fargs=(episode_frames, image_artist),
    interval=100,
  )
  # Close the current figure so the notebook does not also show a static plot.
  plt.close()

  return clip



In [20]:
class DQNAgent(torch.nn.Module):
  """Epsilon-greedy agent whose Q-function is a small fully connected network.

  The network maps a state vector to one Q-value per action through layers of
  width state_dim -> 192 -> 256 -> 64 -> n_actions with ReLU in between.

  Attributes are documented inline in `__init__`.
  """

  def __init__(self, state_dim, n_actions, epsilon=0):
    super().__init__()

    self.epsilon = epsilon      # probability of picking a uniformly random action
    self.n_actions = n_actions  # size of the network output
    self.state_dim = state_dim  # size of the network input

    # Kept as a ModuleList (rather than e.g. Sequential) so the attribute type
    # and the parameter names in state_dict stay exactly as before.
    self.qvalue_network_estimator = torch.nn.ModuleList([
      torch.nn.Linear(self.state_dim, 192),
      torch.nn.ReLU(),
      torch.nn.Linear(192, 256),
      torch.nn.ReLU(),
      torch.nn.Linear(256, 64),
      torch.nn.ReLU(),
      torch.nn.Linear(64, self.n_actions),
    ])

  def forward(self, state):
    """Return Q-values for a batch of states given as a float tensor."""
    q_values = state
    for layer in self.qvalue_network_estimator:
      q_values = layer(q_values)
    return q_values

  def get_qvalues(self, states):
    """Compute Q-values for action selection and return them as a numpy array.

    Args:
      states: array-like batch of states (e.g. a list of numpy arrays).

    Returns:
      np.ndarray with one row per state and one column per action.
    """
    # Stack into a single array first: building a tensor directly from a list
    # of ndarrays is very slow.
    states = np.asarray(states, dtype=np.float32)
    # Use the device the network actually lives on instead of a global variable.
    own_device = next(self.parameters()).device
    states = torch.as_tensor(states, device=own_device)

    # No gradients are needed when only choosing actions.
    with torch.no_grad():
      q_values = self(states)

    return q_values.cpu().numpy()

  def sample_actions(self, q_values):
    """Pick epsilon-greedy actions for a batch of Q-values.

    Args:
      q_values: np.ndarray of shape (batch_size, n_actions).

    Returns:
      np.ndarray of shape (batch_size,) holding the chosen action indices.
    """
    batch_size, n_actions = q_values.shape
    random_actions = np.random.choice(n_actions, size=batch_size)
    best_actions = q_values.argmax(axis=-1)
    # 1 -> explore (random action), 0 -> exploit (greedy action).
    should_explore = np.random.choice([0, 1], batch_size, p=[1-self.epsilon, self.epsilon])

    return np.where(should_explore, random_actions, best_actions)

In [21]:
class ReplayBuffer():
  """Fixed-capacity store of (state, action, reward, next_state, done) transitions.

  Once the buffer is full, each new transition overwrites the OLDEST one
  (first-in, first-out), so the buffer always holds the most recent `size`
  transitions.
  """

  def __init__(self, size):
    """Create an empty buffer that holds at most `size` transitions.

    Raises:
      ValueError: if `size` is not positive.
    """
    if size <= 0:
      raise ValueError("replay buffer size must be positive")

    self.size = size
    self.buffer = list()
    # Index of the slot that the next `add` will write to once the buffer is full.
    self._next_idx = 0

  def __len__(self):
    """Number of transitions currently stored."""
    return len(self.buffer)

  def add(self, state, action, reward, next_state, done):
    """Store one transition, evicting the oldest one if the buffer is full."""
    item = (state, action, reward, next_state, done)

    if len(self.buffer) < self.size:
      self.buffer.append(item)
    else:
      self.buffer[self._next_idx] = item
    # Advance the write cursor, wrapping around at capacity.
    self._next_idx = (self._next_idx + 1) % self.size

  def sample(self, batch_size):
    """Draw `batch_size` transitions uniformly at random, with replacement.

    Returns:
      Tuple of five np.ndarrays (states, actions, rewards, next_states,
      done_flags), each with `batch_size` entries along the first axis.

    Raises:
      ValueError: if the buffer is empty.
    """
    if not self.buffer:
      raise ValueError("cannot sample from an empty replay buffer")

    idxs = np.random.choice(len(self.buffer), batch_size)
    samples = [self.buffer[i] for i in idxs]
    states, actions, rewards, next_states, done_flags = zip(*samples)

    return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(done_flags)

In [22]:
def run_and_record(start_state, agent, env, exp_replay, n_steps=1):
  """Act in `env` for `n_steps` steps and store every transition in `exp_replay`.

  The environment is reset whenever an episode ends, so collection continues
  seamlessly across episode boundaries.

  Args:
    start_state: state to act from on the first step.
    agent: object exposing `get_qvalues(states)` and `sample_actions(qvalues)`.
    env: Gymnasium-style environment (`step` returns a 5-tuple, `reset` a 2-tuple).
    exp_replay: object exposing `add(state, action, reward, next_state, done)`.
    n_steps: number of environment steps to take.

  Returns:
    (sum_rewards, state): total reward collected over these steps, and the state
    from which the next call should continue.
  """
  s = start_state
  sum_rewards = 0

  for _ in range(n_steps):
    q_values = agent.get_qvalues([s])
    a = agent.sample_actions(q_values)[0]
    next_s, r, terminated, truncated, _ = env.step(a)
    sum_rewards += r

    # Only a genuine terminal state is stored as "done". A time-limit truncation
    # is not a terminal state, so the TD target must still bootstrap from next_s.
    exp_replay.add(s, a, r, next_s, terminated)

    # The episode is over in either case, so start a new one.
    if terminated or truncated:
      s, _ = env.reset()
    else:
      s = next_s

  return sum_rewards, s


def compute_td_loss(agent, target_network, states, actions, rewards, next_states, done_flags, gamma=0.99, device=None):
  """Mean squared one-step TD error for a batch of transitions.

  The target for each transition is
  `reward + gamma * max_a' Q_target(next_state, a') * (1 - done)`,
  and it is treated as a constant (no gradient flows into `target_network`).

  Args:
    agent: network being trained; called on `states`, returns (batch, n_actions).
    target_network: network used to evaluate `next_states`.
    states, next_states: array-likes with one row per transition.
    actions: array-like of integer action indices, one per transition.
    rewards: array-like of rewards, one per transition.
    done_flags: array-like of booleans/0-1 values; 1 disables bootstrapping.
    gamma: discount factor.
    device: device for the batch tensors; when None, the device of `agent`'s
      parameters is used.

  Returns:
    Scalar tensor holding the loss, differentiable w.r.t. `agent`'s parameters.
  """
  if device is None:
    device = next(agent.parameters()).device

  states = torch.as_tensor(states, device=device, dtype=torch.float32)
  actions = torch.as_tensor(actions, device=device, dtype=torch.long)
  rewards = torch.as_tensor(rewards, device=device, dtype=torch.float32)
  next_states = torch.as_tensor(next_states, device=device, dtype=torch.float32)
  done_flags = torch.as_tensor(done_flags, device=device, dtype=torch.float32)

  # Q(s, a) for the action that was actually taken in each transition.
  predicted_qvalues = agent(states)
  predicted_qvalues_for_actions = predicted_qvalues.gather(1, actions.unsqueeze(1)).squeeze(1)

  # The target is a constant for the optimiser, so skip building its graph.
  with torch.no_grad():
    next_state_values = target_network(next_states).max(dim=1).values
    target_qvalues_for_actions = rewards + gamma * next_state_values * (1 - done_flags)

  loss = torch.mean((predicted_qvalues_for_actions - target_qvalues_for_actions) ** 2)

  return loss

def epsilon_schedule(start_eps, end_eps, step, final_step):
  """Linearly move epsilon from `start_eps` to `end_eps` over `final_step` steps.

  Steps beyond `final_step` are clamped, so the value stays at `end_eps` afterwards.
  """
  clamped_step = min(step, final_step)
  eps_span = end_eps - start_eps
  return start_eps + eps_span * clamped_step / final_step

In [23]:
# Online network: the one that acts in the environment and is trained.
# epsilon=0.5 is only an initial value; the training loop overwrites agent.epsilon.
agent = DQNAgent(state_dim, n_actions, epsilon=0.5).to(device)
# Target network: same architecture, initialised as an exact copy of the online weights.
target_network = DQNAgent(state_dim, n_actions, epsilon=0.5).to(device)
target_network.load_state_dict(agent.state_dict())

# Experience replay storage for the transitions collected during training.
exp_replay = ReplayBuffer(REPLAY_BUFFER_SIZE)

# Only the online network's parameters are handed to the optimizer.
opt = torch.optim.Adam(agent.parameters(), lr=LEARNING_RATE)

In [24]:
# Baseline: record the freshly initialised agent so it can be compared with the
# post-training recording.
frames_animation = animate_policy_actions(env, agent)
# To view inline instead of saving a file: HTML(frames_animation.to_html5_video())
frames_animation.save("animation_before_training.gif", writer="pillow")

In [None]:
# Start the first episode; run_and_record returns the state to resume from.
state, _ = env.reset()

for step in range(TOTAL_STEPS):
  # Linearly anneal the exploration rate.
  agent.epsilon = epsilon_schedule(START_EPSILON, END_EPSILON, step, STEP_DECAY_FINAL_STEP)

  # Collect fresh experience with the current policy.
  _, state = run_and_record(state, agent, env, exp_replay, TIMESTEPS_PER_EPOCH)

  # One gradient step on a batch sampled from the replay buffer.
  states, actions, rewards, next_states, done_flags = exp_replay.sample(BATCH_SIZE)
  loss = compute_td_loss(agent, target_network, states, actions, rewards, next_states, done_flags, gamma=GAMMA, device=device)

  opt.zero_grad()
  loss.backward()
  torch.nn.utils.clip_grad_norm_(agent.parameters(), MAX_GRAD_NORM)
  opt.step()

  if step % LOSS_SHOW_FREQ == 0:
    print(f"Step: {step}, loss: {loss.item():.6f}")
  # Periodically copy the online weights into the target network.
  if step % REFRESH_TARGET_NETWORK_FREQ == 0:
    target_network.load_state_dict(agent.state_dict())




Step: 0, loss: 104.813995
Step: 50000, loss: 29.936726
Step: 100000, loss: 20.897337
Step: 150000, loss: 15.799156
Step: 200000, loss: 9.514006
Step: 250000, loss: 50.602657
Step: 300000, loss: 75.415100
Step: 350000, loss: 31.569639
Step: 400000, loss: 42.221207
Step: 450000, loss: 20.224836
Step: 500000, loss: 14.642189
Step: 550000, loss: 16.187210
Step: 600000, loss: 18.033875
Step: 650000, loss: 15.197511
Step: 700000, loss: 8.487288


In [None]:
# Record the agent again after training, for comparison with the baseline recording.
frames_animation = animate_policy_actions(env, agent)
# Persist the learned weights (state_dict only) next to the notebook.
torch.save(agent.state_dict(), "trained_agent.pth")
# To view inline instead of saving a file: HTML(frames_animation.to_html5_video())
frames_animation.save("animation_after_training.gif", writer="pillow")