In [1]:
%load_ext autoreload

# --------------- #
# region: Imports #
import os
import sys
module_path = os.path.abspath('../../..')
if module_path not in sys.path:
  sys.path.insert(0, module_path)
# endregion   #
# --------------- #

In [2]:
from examples.cleanup.env import Cleanup
from examples.cleanup.agents import Agent
from examples.RPG.utils import load_config
from gem.models.grid_cells import positional_embedding
from gem.models.DDQN import doubleDQN
from gem.utils import visual_field, visual_field_multilayer

from IPython.display import clear_output

import argparse
import matplotlib.pyplot as plt
import random

import torch
import numpy as np

cfg = load_config(argparse.Namespace(config='../configs/config.yaml'))

# Seed every RNG the notebook relies on (agent-order shuffling, numpy, and
# torch weight initialisation / sampling) so a Restart & Run All is repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

N_AGENTS = 3

# Hyperparameters shared by every agent's network. Each agent still receives
# its own doubleDQN instance (and therefore its own weights and replay buffer).
# NOTE(review): input_size must equal the length of the state vector fed to
# the model (flattened agent view + positional embedding) — confirm 5220
# whenever the vision range or embedding size changes.
DQN_KWARGS = dict(
    input_size=5220,
    number_of_actions=4,
    lr=0.001,
    gamma=0.97,
    per=False,
    alpha=0.6,
    beta=0.05,
    beta_increment=0.0006,
    capacity=5000,
)

agents = [
    Agent(
        cfg,
        appearance=cfg.agent.agent.appearance,
        model=doubleDQN(**DQN_KWARGS),
    )
    for _ in range(N_AGENTS)
]

# The environment is constructed around the same list object that the
# training loop later shuffles.
env = Cleanup(cfg, agents)

In [3]:
import jax
import jax.numpy as jnp

# Fixed seed, so the key below (and anything drawn from it) is reproducible.
seed = 42
rng = jax.random.PRNGKey(seed)

# Draw five uniform samples from that key.
random_numbers = jax.random.uniform(key=rng, shape=(5,))

In [4]:
cfg.experiment.epochs = 10000 # override the number of epochs

rewards = []   # total reward collected by all agents, one entry per epoch
losses = 0
epsilon = .9   # exploration rate passed to take_action

# Terminology: each "epoch" here plays one complete game (what RL literature
# usually calls an episode) and is followed by a single training update.
for epoch in range(cfg.experiment.epochs):
    # Reset every agent at the start of the epoch.
    # NOTE(review): the environment itself is not reset here — confirm that
    # world state is meant to carry over from one game to the next.
    for agent in env.agents:
        agent.reset()
    # Randomise the order in which agents act this epoch.
    random.shuffle(agents)

    done = 0
    turn = 0
    losses = 0
    game_points = 0

    while not done:

        turn += 1

        entities = env.get_entities_for_transition()
        # Entity transition
        for entity in entities:
            entity.transition(env)

        # Agent transition
        for agent in agents:
            # NOTE(review): this decay runs once per agent action, but only
            # during every 100th epoch — confirm this is the intended schedule.
            if (epoch+1) % 100 == 0:
                epsilon = epsilon*0.9999

            location_code = positional_embedding(agent.location, env, 3, 3)

            # Current state: flattened agent view followed by the positional
            # embedding, shaped as a single-row batch.
            state = np.concatenate([agent.pov(env).flatten(), location_code]).reshape(1, -1)

            # Take action based on current state
            action = agent.model.take_action(state, epsilon)

            (reward,
            next_state,
            done_
            ) = agent.transition(env, state, action)

            # The game ends when the turn limit is reached or an agent's
            # transition reports completion; remaining agents still finish
            # the current turn.
            if turn >= cfg.experiment.max_turns or done_:
                done = 1

            exp = (1, (state, action, reward, next_state, done))
            agent.episode_memory.append(exp)
            #TODO: decide on memory update procedures
            agent.model.replay_buffer.add(torch.tensor(state), action, reward, torch.tensor(next_state), done)

            game_points += reward

    rewards.append(game_points)

    # After the first few epochs, run one training step per agent. Every agent
    # owns its own model and replay buffer, so each must be trained explicitly;
    # relying on the `agent` variable left over from the turn loop would train
    # only whichever agent happened to act last.
    if epoch > 10:
        for agent in agents:
            loss = agent.model.train_step(batch_size = 64)
            losses += loss

    # Average reward over the most recent (up to) 100 epochs.
    recent_rewards = rewards[-100:]
    avg_last_100_rewards = sum(recent_rewards) / len(recent_rewards)

    if epoch % 150 == 0:
        # `losses` is zeroed at the top of every epoch, so the value shown is
        # the summed loss of this epoch's training steps only.
        print(f'Epoch: {epoch} - Epsilon: {epsilon} - Losses {losses} - Avg. last 100 rewards: {avg_last_100_rewards}')
        losses = 0


Epoch: 0 - Epsilon: 0.9 - Losses 0 - Avg. last 100 rewards: 27.0
