In [1]:
import math
import os
import random
from collections import namedtuple, deque
from itertools import count

import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from tqdm import tqdm, trange

# Set up matplotlib: when the active backend's name contains 'inline',
# import IPython's display helpers (used by plot_durations to render figures).
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
	from IPython import display

# Switch matplotlib to interactive mode.
plt.ion()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Freeway environment whose observations are indexed below as a RAM vector.
env = gym.make('ALE/Freeway-ram-v5')

# Indices of the 12 observation bytes that are kept as the network input;
# every other byte of the observation is discarded by `observation[RAM_mask]`.
# NOTE(review): the meaning of the individual indices (14, 16, 108-117) is not
# established in this notebook -- confirm against an ALE Freeway RAM map.
RAM_mask = np.asarray([14, 16, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117])

In [3]:
# One step of experience, stored as (state, action, next_state, reward).
transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class replay_memory(object):
	"""Bounded FIFO buffer of `transition` records used for experience replay."""

	def __init__(self, capacity):
		"""Create an empty buffer that keeps at most `capacity` transitions."""
		# A deque with `maxlen` silently drops its oldest entry when full.
		self.memory = deque(maxlen=capacity)

	def push(self, *args):
		"""Store one transition built from (state, action, next_state, reward)."""
		record = transition(*args)
		self.memory.append(record)

	def sample(self, batch_size):
		"""Return a list of `batch_size` stored transitions drawn without replacement.

		`random.sample` raises ValueError when fewer than `batch_size`
		transitions are stored.
		"""
		return random.sample(self.memory, k=batch_size)

	def __len__(self):
		"""Return the number of transitions currently stored."""
		return len(self.memory)

In [4]:
# Number of actions selected so far; drives the epsilon decay in select_action.
steps_done = 0

def select_action(state):
	"""Pick an action for `state` with an epsilon-greedy rule.

	Epsilon decays exponentially from `eps_start` towards `eps_end` as the
	module-level `steps_done` counter grows (time constant `eps_decay`); the
	counter is incremented on every call.

	Returns:
		On exploration, a 1x1 long tensor on `device` holding an action drawn
		from `env.action_space`. On exploitation, the argmax of
		`policy_net(state)` computed without gradient tracking
		(expected to be 1x1 for a single-state input -- TODO confirm).
	"""
	global steps_done
	roll = random.random()
	decay = math.exp(-1. * steps_done / eps_decay)
	eps_threshold = eps_end + (eps_start - eps_end) * decay
	steps_done += 1

	# Explore: the roll did not exceed the current epsilon.
	if not (roll > eps_threshold):
		return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

	# Exploit: act greedily with respect to the policy network.
	with torch.no_grad():
		return policy_net(state).argmax(keepdims=True)

def plot_durations(show_result=False):
	"""Plot the logged durations and rewards on twin y-axes sharing one x-axis.

	Reads the module-level lists `episode_duration` (blue, left axis) and
	`rewards` (red, right axis).

	Args:
		show_result: When True the figure is titled 'Result'; otherwise it is
			titled 'Training...' and, under an inline backend, the cell output
			is cleared (deferred until new output arrives) so the next call
			replaces the figure instead of stacking below it.
	"""
	# Reuse a single named figure and clear it on every call. A bare
	# plt.subplots() would open a brand-new figure per call and never close
	# it, so calling this repeatedly during training leaks figures.
	fig, ax1 = plt.subplots(num='DQN training progress', clear=True)
	durations_t = torch.tensor(episode_duration, dtype=torch.float)
	reward_t = torch.tensor(rewards, dtype=torch.float)

	ax1.set_title('Result' if show_result else 'Training...')
	ax1.set_xlabel('Episode')
	ax1.set_ylabel('Duration')
	ax1.plot(durations_t.numpy(), color='blue')

	# Second y-axis for rewards, sharing the x-axis with the durations.
	ax2 = ax1.twinx()
	ax2.set_ylabel('Reward')
	ax2.plot(reward_t.numpy(), color='red')

	if is_ipython:
		display.display(fig)
		if not show_result:
			display.clear_output(wait=True)

def optimize_model():
	"""Perform one optimisation step of `policy_net` on a replay minibatch.

	Samples `bs` transitions from `memory`, computes Q(s, a) with `policy_net`,
	builds the target `reward + gamma * max_a' Q_target(s', a')` (with the
	bootstrap term left at 0 for transitions whose `next_state` is None),
	and applies `loss_fn` followed by an `optimizer` step with element-wise
	gradient clipping.

	Returns:
		False when `memory` holds fewer than `bs` transitions (no update is
		made); otherwise None after the optimizer step.
	"""
	if len(memory) < bs:
		return False
	transitions = memory.sample(bs)
	# Transpose a list of transitions into one transition of per-field tuples.
	batch = transition(*zip(*transitions))

	non_final_next = [s for s in batch.next_state if s is not None]
	non_final_mask = torch.tensor([s is not None for s in batch.next_state], device=device, dtype=torch.bool)
	state_batch = torch.cat(batch.state)
	action_batch = torch.cat(batch.action)
	reward_batch = torch.cat(batch.reward)

	# Q-values of the actions that were actually taken.
	state_action_values = policy_net(state_batch).gather(1, action_batch)

	# Bootstrap values default to 0 and are filled in only for non-terminal
	# next states. The guard matters: torch.cat raises on an empty list, which
	# would happen if every sampled transition were terminal.
	next_state_values = torch.zeros(bs, device=device)
	if non_final_next:
		non_final_next_states = torch.cat(non_final_next).to(device)
		with torch.no_grad():
			next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

	expected_state_action_values = (next_state_values * gamma) + reward_batch

	loss = loss_fn(state_action_values, expected_state_action_values.unsqueeze(1))

	optimizer.zero_grad()
	loss.backward()

	# Clamp every gradient element to [-100, 100] before stepping.
	nn.utils.clip_grad_value_(policy_net.parameters(), 100)
	optimizer.step()

In [5]:
class dqn(nn.Module):
	"""Small fully connected Q-network: 12 inputs -> 8 -> 4 -> `n_action` outputs."""

	def __init__(self, n_action=env.action_space.n):
		"""Build the network.

		Args:
			n_action: Number of output units. The default is read from
				`env.action_space.n` once, when this class is defined.
		"""
		super().__init__()

		layers = [
			nn.Linear(12, 8), nn.ReLU(),
			nn.Linear(8, 4), nn.ReLU(),
			nn.Linear(4, n_action),
		]
		# The attribute name `net` and the layer order determine the
		# state_dict keys, so both must stay stable for saved checkpoints.
		self.net = nn.Sequential(*layers)

	def forward(self, x):
		"""Return the network output for `x` after dividing the input by 255.

		NOTE(review): presumably `x` holds raw byte values in [0, 255] with a
		trailing dimension of 12 -- confirm against the caller.
		"""
		scaled = x / 255.
		return self.net(scaled)

In [6]:
# --- Hyperparameters -------------------------------------------------------
bs = 128                    # minibatch size sampled from the replay memory
gamma = 0.99                # discount factor
eps_start, eps_end = 0.9, 0.05
eps_decay = 100000          # time constant (in steps) of the epsilon decay
tau = 0.005                 # soft-update rate; unused while the hard target sync below is active
lr = 1e-4
target_update_every = 50    # within an episode, sync target_net every this many steps
checkpoint_dir = 'models/freeway'

# --- Networks, optimizer, replay memory ------------------------------------
policy_net = dqn().to(device)
target_net = dqn().to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=lr, amsgrad=True)
memory = replay_memory(100000)

loss_fn = nn.SmoothL1Loss()

# --- Bookkeeping -----------------------------------------------------------
num_episodes = 500
episode_duration = list()          # steps taken in each finished episode
rewards, rewards_agg = list(), 0   # per-episode shaped-reward totals / running total

score = 0
highest_score = 0

# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

for i_episode in (loop := trange(num_episodes)):
	# The description only changes once per episode, so set it here rather
	# than on every environment step.
	loop.set_description(f'Iterations {i_episode+1}/{num_episodes}')
	state, info = env.reset()
	state = torch.tensor(state[RAM_mask], dtype=torch.float32, device=device).unsqueeze(0)
	rewards_agg = 0
	for t in count():
		action = select_action(state)
		observation, reward, terminated, truncated, _ = env.step(action.item())

		# Reward shaping.
		# NOTE(review): the meaning of RAM bytes 14 and 16 is not established
		# in this notebook -- confirm against an ALE Freeway RAM map.
		if reward == 1:
			# Environment reward of 1 is boosted to 500 and counted as a point.
			reward += 499
			score += 1
		# state[0, 0] is the previous observation's byte 14 (RAM_mask[0] == 14):
		# penalise steps on which that byte did not change.
		if observation[14] == state[0, 0].item():
			reward -= 1
		if observation[16] != 255:
			reward -= 100
		rewards_agg += reward

		reward = torch.tensor([reward], dtype=torch.float32, device=device)

		if terminated:
			next_state = None
		else:
			next_state = torch.tensor(observation[RAM_mask], dtype=torch.float32, device=device).unsqueeze(0)

		memory.push(state, action, next_state, reward)

		state = next_state

		optimize_model()

		# Hard update: copy the policy weights into the target network.
		if t % target_update_every == 0:
			target_net.load_state_dict(policy_net.state_dict())

		if terminated or truncated:
			# Log once per episode (including the final step) so that
			# plot_durations, whose x-axis is 'Episode', has aligned series.
			episode_duration.append(t + 1)
			rewards.append(float(rewards_agg))
			if score > highest_score:
				torch.save(policy_net.state_dict(), f'models/freeway/{i_episode}_{t}_highest_score.pth')
				highest_score = score
			score = 0
			break
		loop.set_postfix(score=score, highest_score=highest_score)

	# Periodic checkpoint every 10 episodes.
	if (i_episode + 1) % 10 == 0:
		torch.save(policy_net.state_dict(), f'models/freeway/episode_{i_episode+1}.pth')

print('Complete')
# plot_durations(show_result=True) 
# plt.ioff() 
# plt.show()

Iterations 500/500: 100%|██████████| 500/500 [3:00:54<00:00, 21.71s/it, highest_score=27, score=21]  

Complete





In [7]:
steps_done

1024000