In [1]:
%load_ext autoreload

# --------------- #
# region: Imports #
import os
import sys
module_path = os.path.abspath('../../..')
if module_path not in sys.path:
  sys.path.insert(0, module_path)
# endregion   #
# --------------- #

In [2]:
from examples.cleanup.env import Cleanup
from examples.cleanup.agents import Agent
from examples.RPG.utils import load_config
from gem.models.grid_cells import positional_embedding
from gem.models.DDQN import doubleDQN
from gem.utils import visual_field_sprite, image_from_array, animate, one_hot_encode

from IPython.display import clear_output

import argparse
import matplotlib.pyplot as plt
import random

import torch
import numpy as np

cfg = load_config(argparse.Namespace(config='../configs/config.yaml'))

N_AGENTS = 1
agents = []
for i in range(N_AGENTS):
  agents.append(
  Agent(cfg, appearance = cfg.agent.agent.appearance, 
        model = doubleDQN(
        input_size=5224,
        number_of_actions=5,
        lr=0.001,
        gamma=0.97,
        per=False,
        alpha=0.6,
        beta=0.05,
        beta_increment=0.0006,
        capacity=5000,
    )
          )
  )




env = Cleanup(
  cfg, agents
)

In [3]:
import jax
import jax.numpy as jnp

seed = 42
rng = jax.random.PRNGKey(seed)

# Generate random numbers
random_numbers = jax.random.uniform(rng, shape=(5,))

In [4]:
cfg.experiment.epochs = 10000 # override the number of epochs

rewards = []
losses = 0
epsilon = .9

EPOCH_PRINT_FREQ = 10
EPSILON_DECAY_RATE = 0.999
EPSILON_DECAY_FREQ = 1
EVAL_EPSILON = 0.5

for epoch in range(cfg.experiment.epochs): # note that the language is not right. epoch is training. episode is the game
    # Reset the environment at the start of each epoch
        env.reset()
        for agent in env.agents:
            agent.reset()
        random.shuffle(agents)

        done = 0 
        turn = 0
        losses = 0
        game_points = 0

        images = []
        if epoch % EPSILON_DECAY_FREQ == 0:
            epsilon = epsilon*EPSILON_DECAY_RATE

        while not done:

            turn = turn + 1

            entities = env.get_entities_for_transition()
            # Entity transition
            for entity in entities:
                entity.transition(env)

            # Agent transition
            for agent in agents:

                location_code = positional_embedding(agent.location, env, 3, 3)
                direction = one_hot_encode(agent.direction, 4)

                # Get current state
                state = np.concatenate([agent.pov(env).flatten(), location_code, direction]).reshape(1, -1)

                if epoch % EPOCH_PRINT_FREQ == 0:
                    _image = visual_field_sprite(env.world)
                    image = image_from_array(_image)
                    images.append(image)

                # Take action based on current state
                if epoch % EPOCH_PRINT_FREQ == 0:
                    action = agent.model.take_action(state, EVAL_EPSILON)
                else:
                    action = agent.model.take_action(state, epsilon)

                (reward,
                next_state,
                done_
                ) = agent.transition(env, state, action)

                if turn >= cfg.experiment.max_turns or done_:
                    done = 1

                exp = (1, (state, action, reward, next_state, done))
                agent.episode_memory.append(exp)
                #TODO: decide on memory update procedures
                agent.model.replay_buffer.add(torch.tensor(state), action, reward, torch.tensor(next_state), done)

                game_points += reward

        rewards.append(game_points)
        
        # At the end of each epoch, train as long as the batch size is large enough.
        if epoch > 10:
            loss = agent.model.train_step(batch_size = 64)
            losses += loss
            
        # Calculate the average of the last 100 rewards
        if len(rewards) >= 100:
            avg_last_100_rewards = sum(rewards[-100:]) / 100
        else:
            avg_last_100_rewards = sum(rewards) / len(rewards)
        if epoch % EPOCH_PRINT_FREQ == 0:
            print(f'Epoch: {epoch} - Epsilon: {epsilon} - Losses {losses} - Avg. last 100 rewards: {avg_last_100_rewards}')
            animate(
                images, filename = f"cleanup_epoch{epoch}", folder = f"{cfg.root}/examples/cleanup/data/"
            )
            losses = 0


Epoch: 0 - Epsilon: 0.8991 - Losses 0 - Avg. last 100 rewards: 0.0
Epoch: 10 - Epsilon: 0.8901493517965845 - Losses 0 - Avg. last 100 rewards: 0.2727272727272727
Epoch: 20 - Epsilon: 0.8812878083682346 - Losses 581132.9375 - Avg. last 100 rewards: 0.9047619047619048
Epoch: 30 - Epsilon: 0.8725144826662403 - Losses 95823.0390625 - Avg. last 100 rewards: 2.0
Epoch: 40 - Epsilon: 0.8638284964725682 - Losses 631.8280639648438 - Avg. last 100 rewards: 2.048780487804878
Epoch: 50 - Epsilon: 0.8552289803119505 - Losses 14682.884765625 - Avg. last 100 rewards: 2.3333333333333335
Epoch: 60 - Epsilon: 0.8467150733648501 - Losses 1084.8265380859375 - Avg. last 100 rewards: 2.2459016393442623
Epoch: 70 - Epsilon: 0.8382859233812912 - Losses 454.34765625 - Avg. last 100 rewards: 2.535211267605634
Epoch: 80 - Epsilon: 0.8299406865955485 - Losses 322.9467468261719 - Avg. last 100 rewards: 2.5308641975308643
Epoch: 90 - Epsilon: 0.8216785276416861 - Losses 749.4376831054688 - Avg. last 100 rewards: 2.