In [42]:
import math
import os
import random
from collections import namedtuple, deque
from itertools import count

import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm, trange

# Environment to train on; swap in "CartPole-v1" for a quicker sanity check.
# NOTE(review): recent gymnasium releases renamed this task to 'LunarLander-v3' — confirm against the installed version.
env = gym.make('LunarLander-v2')

# Seed every RNG the run draws from so that Restart & Run All is reproducible:
# Python `random` (epsilon-greedy draw, replay sampling), torch (network init),
# numpy, the environment itself, and its action space (random exploration actions).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
env.reset(seed=SEED)
env.action_space.seed(SEED)

# set up matplotlib: with an inline backend, live plots are pushed through IPython's display helpers
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
	from IPython import display

plt.ion()

# Training is pinned to the CPU; use the commented expression to pick a GPU when one is available.
device = 'cpu'  # torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [43]:
transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class replay_memory:
	"""Bounded FIFO buffer of ``transition`` records used for experience replay.

	Once ``capacity`` records are held, every new push silently drops the oldest one.
	"""

	def __init__(self, capacity):
		# A deque with maxlen gives O(1) appends with automatic eviction of the oldest entry.
		self.memory = deque(maxlen=capacity)

	def push(self, *args):
		"""Store one record built positionally as (state, action, next_state, reward)."""
		record = transition(*args)
		self.memory.append(record)

	def sample(self, batch_size):
		"""Return ``batch_size`` distinct stored records drawn uniformly at random."""
		picked = random.sample(self.memory, batch_size)
		return picked

	def __len__(self):
		size = len(self.memory)
		return size

In [44]:
class dqn(nn.Module):
	"""Q-network: an MLP with two 128-unit ReLU hidden layers that maps an
	observation vector to one unnormalised Q-value per discrete action."""

	def __init__(self, n_observations, n_actions):
		super(dqn, self).__init__()
		width = 128
		# Creation order fixes how weight initialisation consumes the torch RNG,
		# and the attribute names become the state_dict keys, so keep both stable.
		self.layer1 = nn.Linear(n_observations, width)
		self.layer2 = nn.Linear(width, width)
		self.layer3 = nn.Linear(width, n_actions)

	def forward(self, x):
		"""Map inputs of shape ``(*, n_observations)`` to Q-values of shape ``(*, n_actions)``."""
		for hidden_layer in (self.layer1, self.layer2):
			x = hidden_layer(x).relu()
		return self.layer3(x)

In [46]:
# --- DQN hyper-parameters ---
bs = 128          # minibatch size drawn from replay memory per update
gamma = 0.99      # discount factor for the one-step TD target
eps_start = 0.9   # initial exploration probability
eps_end = 0.05    # floor the exploration probability decays towards
eps_decay = 1000  # time constant (in action selections) of the exponential epsilon decay
tau = 0.005       # soft-update rate of the target network
lr = 1e-4         # AdamW learning rate

# Network sizes are read off the environment: one input per observation
# component returned by reset(), one output per discrete action.
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

policy_net = dqn(n_observations, n_actions).to(device)
target_net = dqn(n_observations, n_actions).to(device)
# The target network starts as an exact copy of the policy network.
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=lr, amsgrad=True)
memory = replay_memory(capacity=10000)  # keeps only the 10k most recent transitions

steps_done = 0  # global step counter, read and incremented by select_action

def select_action(state):
	"""Pick an action epsilon-greedily and advance the global ``steps_done`` counter.

	Epsilon decays exponentially from ``eps_start`` towards ``eps_end`` with time
	constant ``eps_decay``, measured in calls to this function.

	Args:
		state: float tensor of observations with a leading batch dimension
			(the training loop passes shape ``(1, n_observations)``).

	Returns:
		A ``torch.long`` tensor of action indices with shape ``(batch, 1)`` when
		acting greedily. The exploration branch always returns a single action of
		shape ``(1, 1)``.
	"""
	global steps_done
	sample = random.random()
	eps_threshold = eps_end + (eps_start - eps_end) * math.exp(-1. * steps_done / eps_decay)
	steps_done += 1
	if sample > eps_threshold:
		with torch.no_grad():
			# Reduce over the action dimension only. A dim-less argmax flattens the whole
			# tensor, which is only correct for a single state and gives no (batch, 1) shape.
			return policy_net(state).argmax(dim=1, keepdim=True)
	# Explore: one uniformly sampled action from the environment's action space.
	return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

def plot_durations(show_result=False):
	"""Plot episode durations (blue, left axis) and per-episode reward (red, right axis).

	Reads the global ``episode_duration`` and ``rewards`` lists. Figure number 1 is
	reused and cleared on every call, so repeated calls during training do not
	accumulate open figures; matplotlib keeps every unclosed figure alive and warns
	once more than 20 are open.

	Args:
		show_result: if True, title the plot 'Result' and leave it displayed.
			Otherwise title it 'Training...' and let the next call replace it.
	"""
	fig = plt.figure(1)
	fig.clf()
	ax1 = fig.add_subplot(111)
	ax1.set_title('Result' if show_result else 'Training...')
	ax1.set_xlabel('Episode')
	ax1.set_ylabel('Duration')
	ax1.plot(episode_duration, color='blue')

	# Second y-axis sharing the episode axis, so both series keep their own scale.
	ax2 = ax1.twinx()
	ax2.set_ylabel('Reward')
	ax2.plot(rewards, color='red')

	if is_ipython:
		display.display(fig)
		if not show_result:
			# wait=True clears only when the next output arrives, avoiding flicker between frames.
			display.clear_output(wait=True)

In [47]:
criterion = nn.SmoothL1Loss()

def optimize_model():
	"""Run one DQN gradient step on a random minibatch of ``bs`` transitions from ``memory``.

	Q(s, a) from ``policy_net`` is regressed (Smooth-L1 / Huber loss) onto the one-step
	target ``r + gamma * max_a' Q_target(s', a')``. The bootstrap term is zero for
	terminal transitions, which are stored with ``next_state is None``.

	Returns:
		False if ``memory`` holds fewer than ``bs`` transitions (no update is made);
		True after an optimizer step.
	"""
	if len(memory) < bs:
		return False
	transitions = memory.sample(bs)
	# Transpose a list of transitions into one transition whose fields are tuples.
	batch = transition(*zip(*transitions))

	non_final_mask = torch.tensor([s is not None for s in batch.next_state], device=device, dtype=torch.bool)
	state_batch = torch.cat(batch.state)
	action_batch = torch.cat(batch.action)
	reward_batch = torch.cat(batch.reward)

	state_action_values = policy_net(state_batch).gather(1, action_batch)

	next_state_values = torch.zeros(bs, device=device, dtype=state_action_values.dtype)
	# torch.cat raises on an empty list, so skip the target pass when every sampled transition is terminal.
	if non_final_mask.any():
		non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
		with torch.no_grad():
			next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

	# Cast rewards to the Q-value dtype so a float64 reward tensor cannot mix with float32 predictions.
	expected_state_action_values = next_state_values * gamma + reward_batch.to(state_action_values.dtype)

	loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

	optimizer.zero_grad()
	loss.backward()
	# Clamp every gradient element to [-100, 100] before stepping.
	nn.utils.clip_grad_value_(policy_net.parameters(), 100)
	optimizer.step()
	return True

In [None]:
num_episodes = 600
episode_duration = list()  # number of steps in each episode (blue curve)
rewards, reward_agg = list(), 0  # per-episode mean reward per step (red curve); running total within an episode
# NOTE(review): LunarLander is usually judged by total episode return (reward_agg itself); this logs the per-step mean.

for i_episode in range(num_episodes):
	state, info = env.reset()
	state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
	# Reset per episode so each point on the reward curve reflects only that episode.
	reward_agg = 0

	for t in count():
		action = select_action(state)
		observation, reward, terminated, truncated, _ = env.step(action.item())
		reward_agg += reward
		# Force float32: the env may return NumPy float64 rewards, which torch.tensor would keep as float64.
		reward = torch.tensor([reward], device=device, dtype=torch.float32)
		done = terminated or truncated

		# Only true termination has no successor; on truncation the next state is still bootstrapped.
		if terminated:
			next_state = None
		else:
			next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

		memory.push(state, action, next_state, reward)

		state = next_state

		optimize_model()

		# Soft update of the target network: theta_target <- tau * theta_policy + (1 - tau) * theta_target.
		target_net_state_dict = target_net.state_dict()
		policy_net_state_dict = policy_net.state_dict()
		for key in policy_net_state_dict:
			target_net_state_dict[key] = policy_net_state_dict[key]*tau + target_net_state_dict[key]*(1-tau)

		target_net.load_state_dict(target_net_state_dict)

		if done:
			# t is 0-based, so the episode lasted t + 1 steps (dividing by t fails when t == 0).
			episode_duration.append(t + 1)
			# Record the reward before plotting so both curves include the episode just finished.
			rewards.append(reward_agg / (t + 1))
			plot_durations()
			break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()
# duration: blue
# reward: red

In [52]:
# Persist only the learned weights (state_dict); reload with
# `policy_net.load_state_dict(torch.load(MODEL_PATH))`.
MODEL_PATH = os.path.join('models', 'policy_net_lunar_lander')
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(policy_net.state_dict(), MODEL_PATH)