In [1]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import numpy as np 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

from tqdm import tqdm, trange
import wandb

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
	from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

RAM_mask = np.asarray([14, 16, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117])

In [2]:
def make_env(env_id, seed, idx, run_name): 
	def thunk(): 
		env = gym.make(env_id)
		env.action_space.seed(seed) 
		env.observation_space.seed(seed) 
		return env 
	return thunk() 

In [3]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 
	nn.init.orthogonal_(layer.weight, std) 
	nn.init.constant_(layer.bias, bias_const)
	return layer 

class Agent(nn.Module): 
	def __init__(self, envs): 
		super().__init__() 
		self.critic = nn.Sequential(
			layer_init(nn.Linear(RAM_mask.size, 64)), #np.array(envs.single_observation_space.shape).prod(), 64)), 
			nn.Tanh(), 
			layer_init(nn.Linear(64, 64)), 
			nn.Tanh(), 
			layer_init(nn.Linear(64, 1), std=1.0)
		)

		self.actor = nn.Sequential(
			layer_init(nn.Linear(RAM_mask.size, 64)), 
			nn.Tanh(), 
			layer_init(nn.Linear(64, 64)), 
			nn.Tanh(), 
			layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01)
		)

	def get_value(self, x): 
		return self.critic(x) 
	
	def get_action_and_value(self, x, action=None): 
		logits = self.actor(x) 
		probs = torch.distributions.categorical.Categorical(logits=logits) 
		if action is None: 
			action = probs.sample() 
		return action, probs.log_prob(action), probs.entropy(), self.critic(x) 

In [4]:
seed = 1337 
num_envs = 4
num_steps = 2048

envs = gym.vector.SyncVectorEnv([
    lambda: make_env('ALE/Freeway-ram-v5', seed+i, i, 'the first') for i in range(num_envs)
])

agent = Agent(envs).to(device) 
optimizer = optim.Adam(agent.parameters(), lr=2.5e-4, eps=1e-5)

In [5]:
obs = torch.zeros((num_steps, num_envs, RAM_mask.size)).to(device)
actions = torch.zeros((num_steps, num_envs) + envs.single_action_space.shape).to(device) 
logprobs = torch.zeros((num_steps, num_envs)).to(device) 
rewards = torch.zeros((num_steps, num_envs)).to(device) 
dones = torch.zeros((num_steps, num_envs)).to(device) 
values = torch.zeros((num_steps, num_envs)).to(device) 

total_timesteps = 64*4096
bs = 1024
gamma = 0.99
update_epochs = 4 
mini_bs = 64 #4 
clip_coef = 0.2
ent_coef = 0.01 
vf_coef = 0.5

global_step = 0 
next_obs = torch.tensor(envs.reset()[0][:, RAM_mask]/255., dtype=torch.float32).to(device) 
next_done = torch.zeros(num_envs).to(device) 
num_updates = total_timesteps // bs

highest_score = 0 

for update in (loop := trange(1, num_updates + 1)): 
	score = np.zeros(4)
	for step in range(0, num_steps):
		# Simulate first, learn later. 
		global_step = 1 * num_envs 
		obs[step] = next_obs 
		dones[step] = next_done 

		with torch.no_grad(): 
			action, logprob, _, value = agent.get_action_and_value(next_obs) 
			values[step] = value.flatten()
		actions[step] = action 
		logprobs[step] = logprob 

		next_obs, reward, terminated, truncated, info = envs.step(action.cpu().numpy())
		done = np.bitwise_or(terminated, truncated).astype(np.float32)
		next_obs, next_done = torch.tensor(next_obs[:, RAM_mask]/255., dtype=torch.float32).to(device), torch.tensor(done).to(device) 
		
		score[reward == np.ones(4)] += 1

		# for success
		reward[reward == np.ones(4)] += 10 
		# for staying in one spot
		reward[(obs[step][:, 0] - next_obs[:, 0]).detach().cpu().numpy() == np.zeros(4)] -= 1
		# for collision 
		reward[next_obs[:, 1] != np.ones(4)] -= 100

		rewards[step] = torch.tensor(reward).to(device).view(-1) 
	if score.max() > highest_score: 
		highest_score = score.max()
	
	with torch.no_grad(): 
		next_value = agent.get_value(next_obs).reshape(1, -1)
		returns = torch.zeros_like(rewards).to(device)
		for t in reversed(range(num_steps)):
			if t == num_steps - 1: 
				# if it is the last step then look at the next step to come (not stored)
				next_non_terminal = 1.0 - next_done 
				next_return = next_value 
			else: # if it is not the last step, it is already stored. 
				next_non_terminal = 1.0 - dones[t + 1]
				next_return = returns[t + 1]
			# if it is done, 'next_non_terminal' is 0. 
			# return = rewards_t + \sum_{t+1}^{t+n} \gamma(return_{t+1})(next_non_terminal)
			returns[t] = rewards[t] + gamma * next_non_terminal * next_return 
		advantages = returns - values 
		
	b_obs = obs.reshape((-1, RAM_mask.size)) 
	b_logprobs = logprobs.reshape(-1) 
	b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
	b_advantages = advantages.reshape(-1) 
	b_returns = returns.reshape(-1) 
	b_values = values.reshape(-1) 
	b_inds = np.arange(bs) 
	clipfracs = list() 

	for epoch in range(update_epochs): 
		np.random.shuffle(b_inds)
		for start in range(0, bs, mini_bs): 
			end = start + mini_bs 
			mb_inds = b_inds[start:end]
			_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
			logratio = newlogprob - b_logprobs[mb_inds]
			ratio = logratio.exp() # Ratio between current policy and old policy 

			with torch.no_grad(): 
				old_approx_kl = (-logprob).mean() 
				approx_kl = ((ratio - 1) - logratio).mean() 
				clipfracs += [((ratio - 1.0).abs() > clip_coef).float().mean().item()]

			mb_advantages = b_advantages[mb_inds]
			mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

			pg_loss1 = -mb_advantages * ratio
			pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
			pg_loss = torch.max(pg_loss1, pg_loss2).mean()
			# max because we are taking gradient ascent 

			v_loss_unclipped = (newvalue - b_returns[mb_inds])**2 
			v_clipped = b_values[mb_inds] + torch.clamp(newvalue - b_values[mb_inds], -clip_coef, clip_coef)
			v_loss_clipped = (v_clipped - b_returns[mb_inds])**2 
			v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
			v_loss = 0.5 * v_loss_max.mean() 

			entropy_loss = entropy.mean() 
			loss = pg_loss - ent_coef * entropy_loss + vf_coef * v_loss
			# 'entropy_loss' maintains exploratory behavior of the model. 
			# 'ent_coef' controls the explotory behavior 

			optimizer.zero_grad() 
			loss.backward() 
			nn.utils.clip_grad_norm_(agent.parameters(), 0.5) 
			optimizer.step()

			loop.set_postfix(highest_score=highest_score)

  0%|          | 0/256 [00:00<?, ?it/s]

100%|██████████| 256/256 [52:39<00:00, 12.34s/it, highest_score=23]


In [6]:
one_env = gym.make('ALE/Freeway-ram-v5')
state, info = one_env.reset() 
state = torch.tensor(state[RAM_mask], dtype=torch.float32, device=device).unsqueeze(0) 

total_score = 0 
while True: 
  with torch.no_grad(): 
    action = agent.actor(state).argmax().item()
  observation, rewards, terminated, truncated, _ = one_env.step(action)
  state = torch.tensor(observation[RAM_mask], dtype=torch.float32, device=device)
  if rewards == 1: 
    total_score += 1

  if terminated or truncated: 
    break 
print(total_score)

21


In [7]:
torch.save(agent.state_dict(), 'models/freeway_ppo/minus_100.pth')