In [4]:
import gym
import torch
import torch.optim as optim
import numpy as np
from collections import namedtuple
import torch.nn.functional as F
from skimage.transform import resize
import math

import matplotlib.pyplot as plt

from utils import ReplayBuffer, label_with_episode_number, save_random_agent_gif, file_exists, write_history, get_last_history, compute_reward, decimal2

from environment_wrapper import EnvironmentWrapper, ObservationWrapper, RewardWrapper, ActionWrapper
from fileIO import FileIO


In [7]:

class SARSAAgent(object):
    def __init__(self, env, env_name):
        self.env = env
        self.device = torch.device("mps")
        self.name = env_name.lower()
        self.action_dim = self.env.action_space.n
        self.state_size = (210, 160, 3)
        # Flatten the state shape
        state_space = int(env.observation_space.high.sum())
        
        # Define hyperparameters
        self.learning_rate = 0.1
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_decay = 0.99
        self.epsilon_min = 0.01
        self.target_update_freq = 10

        self.resume = False
        self.policy_path = f'saves/{self.name}/sarsa_policy.npy'
        self.history_path = f'saves/{self.name}/sarsa_history.csv'
        self.log = ''
        
        # Define the replay buffer
        self.Experience = np.zeros((state_space, int(self.action_dim)))

        # save time by retraining model from a known trained weights
        if (file_exists(self.policy_path)):
            self.resume = True
            self.Experience = np.load(self.policy_path)
            

        # implement evaluation history so traning can be ran in batches to not crash system
        self.current_episode =  get_last_history(self.history_path)
        
    def state_index(self, obs):
        obs_index = obs.sum()
        
        return int(obs_index)

    def act(self, state):
        if np.random.rand() < self.epsilon:
            action = self.env.action_space.sample()  # Explore
        else:
            action = np.argmax(self.Experience[self.state_index(state)])  # Exploit

        return action

    def train(self, num_episodes):
        history_data = []
        save_after = 50
        for episode in range(num_episodes):
            obs = self.env.reset()[0]
            action = self.act(obs)

            # plt.imshow(obs)
            # plt.show()

            done = False
            total_trad_reward = 0
            total_nontrad_reward = 0
            timestep = 0

            while not done:
                # Take action and observe next state and reward
                next_obs, reward, done, err, info = self.env.step(action)
                next_action = self.act(next_obs)
                
                # Render the environment
                # self.env.render()
                
                # get traditional and non-traditional reward
                trad, nontrad = compute_reward(reward, done)
                
                combined_reward = trad + nontrad
                total_trad_reward += trad
                total_nontrad_reward += nontrad
                
                # Update Q-table
                q_value = self.Experience[self.state_index(obs), action]
                next_q_value = self.Experience[self.state_index(next_obs), next_action]
                td_error = combined_reward + self.gamma * next_q_value - q_value
                self.Experience[self.state_index(obs), action] += self.learning_rate * td_error

                # Update the current state
                obs = next_obs
                action = next_action
                timestep += 1
                
            total_nontrad_reward = decimal2(total_nontrad_reward)
            total_reward = total_trad_reward + total_nontrad_reward

            # Print the total reward for the episode
            print(f"Episode: {episode}, Total Reward: {total_reward}")
            
            # persist history and policy after every 20 episodes
            history_data.append(f"{self.current_episode + episode + 1}, {total_reward}, {total_trad_reward}, {total_nontrad_reward}, {timestep}, {action}")
            if ((episode + 1) % save_after == 0):
                # Save the experience
                np.save(self.policy_path, self.Experience)
                
                # write history
                write_history(self.history_path, history_data)
                history_data = []

    def evaluate(self, num_episodes):
        wins = 0
        for episode in range(num_episodes):
            state = self.env.reset()[0]
            done = False
            while not done:
                action = self.act(state)
                next_state, reward, done, _, __ = self.env.step(action)
                state = next_state
                if reward > 0:
                    wins += 1
        print(f'Wins - {wins}; Episodes - {num_episodes}; Average - {wins / num_episodes}') 
    
    def close_session(self):
        self.env.close()

In [None]:
env = gym.make("Breakout-v4")
env_name = env.spec.id.split('-')[0]
agent = SARSAAgent(env, env_name)
agent.train(num_episodes=200)

agent.close_session()

print(f'---- end environment ----')

In [9]:
env = gym.make("Breakout-v4")
env_name = env.spec.id.split('-')[0]
agent = SARSAAgent(env, env_name)
episodes = 20
agent.evaluate(episodes)
agent.close_session()

Wins - 30; Episodes - 20; Average - 1.5
