In [1]:
%load_ext autoreload

# --------------- #
# region: Imports #
import os
import sys
module_path = os.path.abspath('../../..')
if module_path not in sys.path:
  sys.path.insert(0, module_path)
# endregion   #
# --------------- #

In [2]:
from examples.cleanup.env import Cleanup
from examples.cleanup.agents import Agent
from examples.RPG.utils import load_config
from gem.models.grid_cells import positional_embedding
from gem.models.DDQN import doubleDQN
from gem.models.iqn import iRainbowModel
from gem.utils import visual_field_sprite, image_from_array, animate, one_hot_encode

from IPython.display import clear_output

import argparse
import matplotlib.pyplot as plt
import random

import torch
import numpy as np

cfg = load_config(argparse.Namespace(config='../configs/config.yaml'))

seed = random.randint(1,100)

N_AGENTS = 1
agents = []
for i in range(N_AGENTS):
  agents.append(
  Agent(cfg, appearance = cfg.agent.agent.appearance, 
    #     model = doubleDQN(
    #     input_size=5224,
    #     number_of_actions=4,
    #     lr=0.001,
    #     gamma=0.97,
    #     per=False,
    #     alpha=0.6,
    #     beta=0.05,
    #     beta_increment=0.0006,
    #     capacity=5000,
    # )
    model = iRainbowModel(
      state_size=[8,9,9],
      action_size=4,
      layer_size=250,
      epsilon=.9,
      device="mps",
      seed=seed,
      num_frames=5,
      n_step=3,
      BATCH_SIZE= 64,
      BUFFER_SIZE= 1024,
      LR=0.001,
      TAU=.001,
      GAMMA=0.99,
      N=12,
      sync_freq=200,
      model_update_freq=4
    )
          )
  )




env = Cleanup(
  cfg, agents
)

In [4]:
cfg.experiment.epochs = 10000 # override the number of epochs

rewards = []
losses = 0
epsilon = .9

EPOCH_PRINT_FREQ = 25
EPSILON_DECAY_RATE = 0.999
EPSILON_DECAY_FREQ = 1
EVAL_EPSILON = 0.05

for epoch in range(cfg.experiment.epochs): # note that the language is not right. epoch is training. episode is the game
    # Reset the environment at the start of each epoch
        env.reset()
        for agent in env.agents:
            agent.reset()
            agent.init_replay()
            agent.model.start_epoch_action(**locals())
        random.shuffle(agents)

        done = 0 
        turn = 0
        losses = 0
        game_points = 0

        images = []
        if epoch % EPSILON_DECAY_FREQ == 0:
            epsilon = epsilon*EPSILON_DECAY_RATE

        while not done:

            turn = turn + 1

            entities = env.get_entities_for_transition()
            # Entity transition
            for entity in entities:
                entity.transition(env)

            # Agent transition
            for agent in agents:
                state = agent.pov_stack(env)

                if epoch % EPOCH_PRINT_FREQ == 0:
                    _image = visual_field_sprite(env.world)
                    image = image_from_array(_image)
                    images.append(image)

                # Take action based on current state
                if epoch % EPOCH_PRINT_FREQ == 0:
                    action = agent.model.take_action(state, EVAL_EPSILON)
                else:
                    action = agent.model.take_action(state, epsilon)

                (reward,
                next_state,
                done_
                ) = agent.transition(env, state, action)

                if turn >= cfg.experiment.max_turns or done_:
                    done = 1

                exp = (1, (state, action, reward, next_state, done))
                agent.episode_memory.append(exp)
                #TODO: decide on memory update procedures
                agent.model.end_epoch_action(**locals())

                game_points += reward

        rewards.append(game_points)
        
        # At the end of each epoch, train as long as the batch size is large enough.
        if epoch > 10:
            loss = agent.model.train_model()
            losses += loss
            
        # Calculate the average of the last 100 rewards
        if len(rewards) >= 100:
            avg_last_100_rewards = sum(rewards[-100:]) / 100
        else:
            avg_last_100_rewards = sum(rewards) / len(rewards)
        if epoch % EPOCH_PRINT_FREQ == 0:
            print(f'Epoch: {epoch} - Epsilon: {epsilon} - Losses {losses} - Avg. last 100 rewards: {avg_last_100_rewards}')
            animate(
                images, filename = f"cleanup_epoch{epoch}", folder = f"{cfg.root}/examples/cleanup/data/"
            )
            losses = 0


Epoch: 0 - Epsilon: 0.8991 - Losses 0 - Avg. last 100 rewards: 5.0
Epoch: 25 - Epsilon: 0.8768901733960046 - Losses 0.2443152368068695 - Avg. last 100 rewards: 4.3076923076923075
Epoch: 50 - Epsilon: 0.8552289803119505 - Losses 0.7499598264694214 - Avg. last 100 rewards: 4.0588235294117645
Epoch: 75 - Epsilon: 0.8341028682449498 - Losses 0.06996339559555054 - Avg. last 100 rewards: 3.9342105263157894
Epoch: 100 - Epsilon: 0.8134986194699355 - Losses 0.2628120183944702 - Avg. last 100 rewards: 3.87
Epoch: 125 - Epsilon: 0.7934033427698834 - Losses 0.32234248518943787 - Avg. last 100 rewards: 3.99
Epoch: 150 - Epsilon: 0.7738044653703178 - Losses 0.34577804803848267 - Avg. last 100 rewards: 3.91
Epoch: 175 - Epsilon: 0.7546897250730513 - Losses 0.2992754876613617 - Avg. last 100 rewards: 3.99
Epoch: 200 - Epsilon: 0.7360471625842407 - Losses 0.8197137117385864 - Avg. last 100 rewards: 4.21
Epoch: 225 - Epsilon: 0.7178651140319564 - Losses 0.4004423916339874 - Avg. last 100 rewards: 3.83


KeyboardInterrupt: 