In [1]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import namedtuple
import torch.nn.functional as F

import matplotlib.pyplot as plt

from utils import ReplayBuffer, label_with_episode_number, save_random_agent_gif, file_exists, write_history, get_last_history, compute_reward, decimal2,transform

from environment_wrapper import EnvironmentWrapper, ObservationWrapper, RewardWrapper, ActionWrapper
from fileIO import FileIO

class DQN(nn.Module):
    """CNN with 3 convolution layers"""
    def __init__(self, output, name, batch_size):
        super(DQN, self).__init__()
        self.name = f'{name}_agent.h5'
        self.batch_size = batch_size

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.fc = nn.Flatten()
        self.fc2 = nn.Linear(128 * 200, output)
    
    def forward(self, x, batch = None):
        # m = x[0].permute(0, 1, 2).numpy()
        # x = x.reshape(batch or self.batch_size, 3, 210, 160)
        # print(type(x[0]))
        # m = x[0].permute(1, 2, 0).numpy()
        # plt.imshow(m)
        # plt.axis('off')  # Optional: Turn off axis ticks and labels
        # plt.show()
        # return 'uuu'
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.fc2(x)
        # print(x)
        return x

class DQNAgent(object):
    def __init__(self, env, env_name):
        self.env = env
        self.device = torch.device("cpu")
        self.batch_size = env.action_space.n
        self.input_shape = (3, 160, 160)
        self.name = env_name.lower()
        self.wrapped_env = env
        self.action_dim = self.wrapped_env.action_space.n
        
        # Create the DQN model and target network
        self.model = DQN(self.action_dim, self.name, self.batch_size).to(self.device)
        
        # Define hyperparameters
        self.learning_rate = 0.001
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_decay = 0.99
        self.epsilon_min = 0.01
        self.target_update_freq = 10
        self.buffer_capacity = 10000
        
        self.resume = False
        self.policy_path = f'saves/{self.name}/dqn_policy.h5'
        self.history_path = f'saves/{self.name}/dqn_history.csv'
        self.log = ''

        # save time by retraining model from a known trained weights
        if (file_exists(self.policy_path)):
            self.resume = True
            self.model.load_state_dict(torch.load(self.policy_path))

        self.target_model = DQN(self.action_dim, self.name, self.batch_size).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())
        # self.target_model.eval()

        # Define the loss function and optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Define the replay buffer
        self.Experience = namedtuple('Experience', ('state', 'action', 'reward', 'next_state', 'done'))

        # Initialize the replay buffer
        self.replay_buffer = ReplayBuffer(self.buffer_capacity)

        # implement evaluation history so traning can be ran in batches to not crash system
        self.current_episode =  get_last_history(self.history_path)
        

    def act(self, state):
        if np.random.rand() < self.epsilon:
            action = self.env.action_space.sample()  # Explore
        else:
            with torch.no_grad():
                # obs_batch = torch.tensor(np.array(state), device=self.device, dtype=torch.float32)
                state = state.unsqueeze(0)
                q_values = self.model(state, 1).to(self.device)
                action = torch.argmax(q_values).item() % self.batch_size
                # self.log = f'{q_values}, {action}'
        
        return action

    def learn(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        # Sample a minibatch from replay memory
        batch = self.replay_buffer.sample(self.batch_size)
        batch = self.Experience(*zip(*batch))

        action_batch = torch.tensor(batch[1], device=self.device, dtype=torch.long)
        reward_batch = torch.tensor(batch[2], device=self.device, dtype=torch.float)
        done_batch = torch.tensor(batch[4], device=self.device, dtype=torch.float)

        # state_batch = torch.tensor(batch[0], device=self.device, dtype=torch.float32)
        # next_state_batch = torch.tensor(batch[3], device=self.device, dtype=torch.float32)
        state_batch = torch.stack(batch[0])
        next_state_batch = torch.stack(batch[3])

        # Calculate the Q-values for the current state
        q_values = self.model(state_batch).to(self.device)
        q_values = q_values.gather(1, action_batch.unsqueeze(1)).squeeze(1)

        # Calculate the target Q-values using the target network
        with torch.no_grad():
            target_q_values = self.target_model(next_state_batch).to(self.device)
            target_q_values = target_q_values.gather(1, action_batch.unsqueeze(1)).squeeze(1)
            
            target_q_values = reward_batch + self.gamma * (1 - done_batch) * target_q_values

        # Calculate the loss and perform gradient descent
        loss = F.smooth_l1_loss(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, num_episodes):
        history_data = []
        save_after = 10
        for episode in range(num_episodes):
            obs = transform(self.wrapped_env.reset()[0], False)

            # plt.imshow(obs)
            # plt.show()
            
            # Render the environment
            # env.render()

            done = False
            total_trad_reward = 0
            total_nontrad_reward = 0
            timestep = 0

            while not done:
                # Epsilon-greedy action selection
                # obs = (np.rint(obs)).astype(int)
                action = self.act(obs)

                # Take action and observe next state and reward
                next_obs, reward, done, err, info = self.wrapped_env.step(action)
                
                # Render the environment
                # self.env.render()
                
                # get traditional and non-traditional reward
                trad, nontrad = compute_reward(reward, done)
                
                combined_reward = trad + nontrad
                total_trad_reward += trad
                total_nontrad_reward += nontrad
                
                # next_obs = (np.rint(next_obs)).astype(int)
                
                # Store the experience in the replay buffer
                next_obs = transform(next_obs, False)
                experience = self.Experience(obs, action, combined_reward, next_obs, done)
                self.replay_buffer.push(experience)

                # Update the current state
                obs = next_obs
                
                # model learning
                self.learn()
                timestep += 1
                
            # Update the target network
            if episode % self.target_update_freq == 0:
                self.target_model.load_state_dict(self.model.state_dict())

            # Decay epsilon
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            total_nontrad_reward = decimal2(total_nontrad_reward)
            total_reward = total_trad_reward + total_nontrad_reward
            
            # Print the total reward for the episode
            print(f"Episode: {episode}, Total Reward: {total_reward}, Traditional: {total_trad_reward}")
            
            # persist history and policy after every 20 episodes
            history_data.append(f"{int(self.current_episode) + episode + 1}, {total_reward}, {total_trad_reward}, {total_nontrad_reward}, {timestep}, {self.log}")
            if ((episode + 1) % save_after == 0):
                # Save the experiencel
                torch.save(self.model.state_dict(), self.policy_path)
                
                # write history
                write_history(self.history_path, history_data)
                history_data = []

    def evaluate(self, num_episodes):
        wins = 0
        for episode in range(num_episodes):
            state = self.env.reset()[0]
            done = False
            while not done:
                action = self.act(state)
                next_state, reward, done, _, __ = self.env.step(action)
                state = next_state
                if reward > 0:
                    wins += 1
        print(f'Wins - {wins}; Episodes - {num_episodes}; Average - {wins / num_episodes}') 
    
    def close_session(self):
        self.wrapped_env.close()

In [5]:
for i in range(20):
  env = gym.make("Breakout-v4")
  env_name = env.spec.id.split('-')[0]
  agent = DQNAgent(env, env_name)
  agent.train(num_episodes=10)
  agent.close_session()

  del env
  del agent
  
  print(f'---- starting environment {i} ----')

Episode: 0, Total Reward: 16.29, Traditional: 0
Episode: 1, Total Reward: 24.8, Traditional: 1
Episode: 2, Total Reward: 22.2, Traditional: 1
Episode: 3, Total Reward: 20.8, Traditional: 1
Episode: 4, Total Reward: 23.2, Traditional: 1
Episode: 5, Total Reward: 28.9, Traditional: 2
Episode: 6, Total Reward: 28.8, Traditional: 2
Episode: 7, Total Reward: 22.5, Traditional: 1
Episode: 8, Total Reward: 22.0, Traditional: 1
Episode: 9, Total Reward: 17.19, Traditional: 0
---- starting environment 0 ----
Episode: 0, Total Reward: 24.2, Traditional: 1
Episode: 1, Total Reward: 21.7, Traditional: 1
Episode: 2, Total Reward: 28.5, Traditional: 1
Episode: 3, Total Reward: 44.1, Traditional: 4
Episode: 4, Total Reward: 31.6, Traditional: 2
Episode: 5, Total Reward: 31.9, Traditional: 2
Episode: 6, Total Reward: 17.39, Traditional: 0
Episode: 7, Total Reward: 25.1, Traditional: 1
Episode: 8, Total Reward: 23.5, Traditional: 1
Episode: 9, Total Reward: 32.5, Traditional: 2
---- starting environmen

In [11]:
env = gym.make("Breakout-v4")
env_name = env.spec.id.split('-')[0]
agent = DQNAgent(env, env_name)
episodes = 20
agent.evaluate(episodes)
agent.close_session()

Wins - 22; Episodes - 20; Average - 1.1
