In [4]:
import gym
import torch
import numpy as np
from collections import namedtuple
import torch.nn.functional as F
from skimage.transform import resize

import matplotlib.pyplot as plt

from utils import ReplayBuffer, label_with_episode_number, save_random_agent_gif, file_exists, write_history, get_last_history, compute_reward, decimal2

from fileIO import FileIO

class SARSAAgent(object):
    def __init__(self, env, env_name, eval_mode = False):
        self.env = env
        self.device = torch.device("mps")
        self.name = env_name.lower()
        self.action_dim = self.env.action_space.n
        self.state_size = (210, 160, 3)
        # Flatten the state shape
        state_space = int(env.observation_space.high.sum())
        
        # Define hyperparameters
        self.learning_rate = 0.1
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_decay = 0.99
        self.epsilon_min = 0.01
        self.target_update_freq = 10

        self.resume = False
        self.eval_mode = eval_mode
        self.policy_path = f'saves/{self.name}/sarsa_policy.npy'
        self.history_path = f'saves/{self.name}/sarsa_history.csv'
        self.log = ''
        
        # Define the replay buffer
        self.Experience = np.zeros((state_space, int(self.action_dim)))

        # save time by retraining model from a known trained weights
        if (file_exists(self.policy_path)):
            self.resume = True
            self.Experience = np.load(self.policy_path)
            print('----- loaded saved weights ------')
            

        # implement evaluation history so traning can be ran in batches to not crash system
        self.current_episode =  get_last_history(self.history_path)
        
    def state_index(self, obs):
        obs_index = obs.sum()
        
        return int(obs_index)

    def act(self, state):
        if np.random.rand() < self.epsilon or not self.eval_mode:
            action = self.env.action_space.sample()  # Explore
        else:
            action = np.argmax(self.Experience[self.state_index(state)])  # Exploit

        return action

    def train(self, num_episodes):
        history_data = []
        save_after = 100
        for episode in range(num_episodes):
            obs = self.env.reset()[0]
            action = self.act(obs)

            # plt.imshow(obs)
            # plt.show()

            done = False
            total_trad_reward = 0
            total_nontrad_reward = 0
            timestep = 0

            while not done:
                # Take action and observe next state and reward
                next_obs, reward, done, err, info = self.env.step(action)
                next_action = self.act(next_obs)
                
                # Render the environment
                # self.env.render()
                
                # get traditional and non-traditional reward
                trad, nontrad = compute_reward(reward, done)
                
                combined_reward = trad + nontrad
                total_trad_reward += trad
                total_nontrad_reward += nontrad
                
                # Update Q-table
                q_value = self.Experience[self.state_index(obs), action]
                next_q_value = self.Experience[self.state_index(next_obs), next_action]
                td_error = combined_reward + self.gamma * next_q_value - q_value
                self.Experience[self.state_index(obs), action] += self.learning_rate * td_error

                # Update the current state
                obs = next_obs
                action = next_action
                timestep += 1
                
            total_nontrad_reward = decimal2(total_nontrad_reward)
            total_reward = total_trad_reward + total_nontrad_reward

            # Print the total reward for the episode
            print(f"Episode: {episode + 1}, Total Reward: {total_reward}, Traditional: {total_trad_reward}")
            
            # persist history and policy after every 20 episodes
            history_data.append(f"{self.current_episode + episode + 1}, {total_reward}, {total_trad_reward}, {total_nontrad_reward}, {timestep}, {action}")
            if ((episode + 1) % save_after == 0):
                # Save the experience
                np.save(self.policy_path, self.Experience)
                
                # write history
                write_history(self.history_path, history_data)
                history_data = []

    def evaluate(self, num_episodes):
        wins = 0
        for episode in range(num_episodes):
            state = self.env.reset()[0]
            done = False
            while not done:
                action = self.act(state)
                next_state, reward, done, _, __ = self.env.step(action)
                state = next_state
                if reward > 0:
                    wins += 1
        print(f'Wins - {wins}; Episodes - {num_episodes}; Average - {wins / num_episodes}') 
    
    def close_session(self):
        self.env.close()

In [5]:
# env_name = 'Breakout-v4'
env_name = 'Berzerk-v4'
training_episodes = 100
eval_episodes = 10

In [6]:
env = gym.make(env_name)
env_name = env.spec.id.split('-')[0]
agent = SARSAAgent(env, env_name)
agent.train(num_episodes=training_episodes)

agent.close_session()

print(f'---- end environment ----')

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


----- loaded saved weights ------
Episode: 1, Total Reward: 23.6, Traditional: 1
Episode: 2, Total Reward: 157.09, Traditional: 7
Episode: 3, Total Reward: 28.1, Traditional: 3
Episode: 4, Total Reward: 20.7, Traditional: 0
Episode: 5, Total Reward: 27.1, Traditional: 3
Episode: 6, Total Reward: 29.7, Traditional: 2
Episode: 7, Total Reward: 19.79, Traditional: 2
Episode: 8, Total Reward: 25.5, Traditional: 2
Episode: 9, Total Reward: 32.3, Traditional: 3
Episode: 10, Total Reward: 18.39, Traditional: 0
Episode: 11, Total Reward: 34.7, Traditional: 3
Episode: 12, Total Reward: 21.2, Traditional: 1
Episode: 13, Total Reward: 36.8, Traditional: 2
Episode: 14, Total Reward: 29.1, Traditional: 2
Episode: 15, Total Reward: 29.6, Traditional: 3
Episode: 16, Total Reward: 39.6, Traditional: 4
Episode: 17, Total Reward: 54.0, Traditional: 9
Episode: 18, Total Reward: 23.2, Traditional: 0
Episode: 19, Total Reward: 29.3, Traditional: 4
Episode: 20, Total Reward: 22.8, Traditional: 1
Episode: 21

In [7]:
env = gym.make(env_name)
env_name = env.spec.id.split('-')[0]
agent = SARSAAgent(env, env_name, True)
agent.evaluate(eval_episodes)
agent.close_session()

  logger.warn(


----- loaded saved weights ------
Wins - 42; Episodes - 10; Average - 4.2
