In [2]:
import gymnasium as gym
import ale_py
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation
import torch
from torch import nn
import matplotlib.pyplot as plt
from collections import deque
import random
import numpy as np

In [4]:
class Agent(nn.Module):


    def __init__(self):

        super(Agent, self).__init__()

        self.conv = nn.Sequential(

        #NoisyConv2d(input_shape[0], 32, kernel_size=8, stride=4),

        nn.Conv2d(4,32,8,4),

        nn.ReLU(),

        #NoisyConv2d(32, 64, kernel_size=4, stride=2),

        nn.Conv2d(32, 64, kernel_size=4, stride=2),

        nn.ReLU(),

        #NoisyConv2d(64, 64, kernel_size=3, stride=1),

        nn.Conv2d(64, 64, kernel_size=3, stride=1),

        nn.ReLU()

        )

        conv_out_size = self._get_conv_out((4,84,84))

        self.fc = nn.Sequential(

        #NoisyLinear(conv_out_size, 512),

        nn.Linear(conv_out_size, 512),

        nn.ReLU()

        )

        self.state_layer = nn.Linear(512,1)#NoisyLinear(512,num_bins)

        self.advantage_layer = nn.Linear(512,4)#NoisyLinear(512,num_actions * num_bins)


    def _get_conv_out(self, shape: tuple) -> int:

        o = self.conv(torch.zeros(1, *shape))

        return int(torch.prod(torch.tensor(o.size())))



    def forward(self, x):


        x = x / 255.0

        x = self.conv(x)

        x = torch.flatten(x, start_dim=1)

        x = self.fc(x)

        value = self.state_layer(x)

        advantage = self.advantage_layer(x)


        value = value.expand_as(advantage)

        mean_advantage = advantage.mean(dim=1, keepdim = True)


        state_action = value + advantage - mean_advantage


        return state_action 

device = "mps"
epsilon = 1.0
min_epsilon = 0.1
gamma = 0.99
lr = 2.5e-4
mini_batch_size = 32
max_steps = 10_000_000
cur_step = 0
target_network_update_frequency = 10_000
action_space = [0,1,2,3]
experience_buffer = deque(maxlen = 300_000)

agent = Agent().to(device=device)
target_network = Agent().to(device = device)
target_network.load_state_dict(agent.state_dict())
optim = torch.optim.RMSprop(agent.parameters(), lr, alpha = 0.95, eps = 0.01) 
loss_fn = torch.nn.SmoothL1Loss()

gym.register_envs(ale_py)
env = gym.make("ALE/Breakout-v5", frameskip = 1)
env = AtariPreprocessing(env, frame_skip=4, grayscale_obs=True, scale_obs=False, screen_size=84)
env = FrameStackObservation(env, stack_size = 4)

best_loss = -10
loss_over_time = []

model_state, _ = env.reset()
model_state = torch.tensor(model_state).to(device)

def replay_exp(size = mini_batch_size):
    memories = random.sample(experience_buffer, k = size)
    state, action, reward, next_state = zip(*memories)

    states = torch.stack([s.to(device) for s in state])
    actions = torch.tensor(action, dtype = torch.long, device = device)
    rewards = torch.tensor(reward, dtype = torch.float32, device = device)
    non_terminal_states = [ns is not None for ns in next_state]
    next_states = torch.stack([ns.to(device) if ns is not None else torch.zeros_like(model_state).to(device) for ns in next_state])

    q_values = agent(states).gather(1, actions.unsqueeze(1))
    with torch.no_grad():
        target = torch.zeros(size, device=device)
        next_action = agent(next_states[non_terminal_states]).argmax(1)
        target[non_terminal_states] = gamma * target_network(next_states[non_terminal_states]).gather(1, next_action.unsqueeze(1)).squeeze()
        target += rewards
        target = target.unsqueeze(1)
    loss = loss_fn(q_values, target)

    optim.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(agent.parameters(), 10.0)
    optim.step()

@torch.no_grad()
def select_action(state, test = False):
    global epsilon
    if not test:
        epsilon = max(min_epsilon, epsilon - 9e-7)
    if random.random() < epsilon:
        return random.choice(action_space)
    else:
        with torch.no_grad():
            if not isinstance(state, torch.Tensor):
                state = torch.tensor(state, dtype = torch.float32).unsqueeze(0)
            state = state.to(device=device)
            action_values = agent(state)
            return torch.argmax(action_values)

@torch.no_grad()
def test_model(num_tests = 50):
    total_reward = 0
    for i in range(num_tests):
        state, _ = env.reset()
        action = 1
        while True:
            next_state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
            state = next_state
            action = select_action(state, True)
    return total_reward / num_tests


while cur_step < max_steps:
    state, _ = env.reset()
    action = 1
    while True:
        cur_step += 1
        next_state, reward, terminated, truncated, _ = env.step(action)
        reward = max(min(reward, 1), -1)
        experience_buffer.append((torch.tensor(state, dtype = torch.float32), action, reward, torch.tensor(next_state, dtype = torch.float32) if not terminated and not truncated else None))
        if cur_step > 50_000:
            replay_exp(mini_batch_size)


        if cur_step % target_network_update_frequency == 0:
            target_network.load_state_dict(agent.state_dict())

        if cur_step % 5000 == 0:
            test_val = test_model(10)
            loss_over_time.append(test_val)
            if test_val > best_loss:
                best_loss = test_val                    
                torch.save(agent.state_dict(), "AtariAgent.pt")
            print(f"{cur_step/max_steps * 100:.2f}% done, current score: {test_val:.2f}, average_score: {np.mean(loss_over_time[max(0, (cur_step//5000) - 20):cur_step // 5000]):.2f}, exploration rate: {epsilon * 100:.2f}%")
        if terminated or truncated:
            break

        state = next_state
        action = select_action(state)


    
plt.plot(loss_over_time)

0.05% done, current score: 1.30, average_score: 1.30, exploration rate: 99.55%
0.10% done, current score: 1.40, average_score: 1.35, exploration rate: 99.11%
0.15% done, current score: 1.00, average_score: 1.23, exploration rate: 98.66%
0.20% done, current score: 2.10, average_score: 1.45, exploration rate: 98.21%
0.25% done, current score: 1.70, average_score: 1.50, exploration rate: 97.76%
0.30% done, current score: 0.80, average_score: 1.38, exploration rate: 97.31%
0.35% done, current score: 1.60, average_score: 1.41, exploration rate: 96.87%
0.40% done, current score: 1.10, average_score: 1.38, exploration rate: 96.42%
0.45% done, current score: 1.50, average_score: 1.39, exploration rate: 95.97%
0.50% done, current score: 0.90, average_score: 1.34, exploration rate: 95.53%
0.55% done, current score: 1.30, average_score: 1.34, exploration rate: 95.08%


KeyboardInterrupt: 