In [1]:
%load_ext autoreload

# --------------- #
# region: Imports #
import os
import sys
module_path = os.path.abspath('../../..')
if module_path not in sys.path:
  sys.path.insert(0, module_path)
# endregion   #
# --------------- #

In [None]:
from examples.cleanup.env import Cleanup
from examples.cleanup.agents import Agent
from examples.RPG.utils import load_config
from agentarium.embedding import positional_embedding
from agentarium.models.DDQN import doubleDQN
from agentarium.models.iqn import iRainbowModel
from agentarium.utils import visual_field_sprite, image_from_array, animate, one_hot_encode

from IPython.display import clear_output
from datetime import datetime

import argparse
import matplotlib.pyplot as plt
import random

import torch
import numpy as np

cfg = load_config(argparse.Namespace(config='../configs/config.yaml'))

seed = random.randint(1,100)

N_AGENTS = 1
agents = []
for i in range(N_AGENTS):
  agents.append(
  Agent(cfg, appearance = cfg.agent.agent.appearance, 
    model = iRainbowModel(
      state_size=[8,11,11],
      action_size=4,
      layer_size=250,
      epsilon=.9,
      device="mps",
      seed=seed,
      num_frames=5,
      n_step=3,
      BATCH_SIZE= 64,
      BUFFER_SIZE= 1024,
      LR=0.00025,
      TAU=.001,
      GAMMA=0.95,
      N=12,
      sync_freq=200,
      model_update_freq=4
    )
          )
  )

# Set up tensorboard logging
if cfg.log:
    log_dir = os.path.abspath(f'../runs/{datetime.now().strftime("%Y%m%d-%H%m%s")}/')
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(
        log_dir=log_dir
    )


env = Cleanup(
  cfg, agents
)

In [48]:
cfg.experiment.epochs = 4000 # override the number of epochs

rewards = []
losses = 0
epsilon = .9

EPOCH_EVAL_FREQ = 50
EPSILON_SCHEDULE = [0.9, 0.7, 0.5, 0.3, 0.1, 0.05, 0.05, 0.05, 0.05]
EVAL_EPSILON = 0.05

for epoch in range(cfg.experiment.epochs): # note that the language is not right. epoch is training. episode is the game
    # Reset the environment at the start of each epoch
        env.reset()
        for agent in env.agents:
            agent.reset()
            agent.init_replay()
            agent.model.start_epoch_action(**locals())
        random.shuffle(agents)

        epsilon = EPSILON_SCHEDULE[(epoch // 500)]

        done = 0 
        turn = 0
        loss = 0
        losses = 0
        game_points = 0

        images = []

        while not done:

            turn = turn + 1

            entities = env.get_entities_for_transition()
            # Entity transition
            for entity in entities:
                entity.transition(env)

            # Agent transition
            for agent in agents:

                print(agent.pov(env).size())
                state = agent.pov_stack(env)

                if epoch % EPOCH_EVAL_FREQ == 0:
                    _image = visual_field_sprite(env.world)
                    image = image_from_array(_image)
                    images.append(image)

                # Take action based on current state
                if epoch % EPOCH_EVAL_FREQ == 0:
                    action = agent.model.take_action(state, EVAL_EPSILON)
                else:
                    action = agent.model.take_action(state, epsilon)

                (reward,
                next_state,
                done_
                ) = agent.transition(env, state, action)

                if turn >= cfg.experiment.max_turns or done_:
                    done = 1

                exp = (1, (state, action, reward, next_state, done))
                agent.episode_memory.append(exp)
                #TODO: decide on memory update procedures
                agent.model.end_epoch_action(**locals())

                game_points += reward

        rewards.append(game_points)
        
        # At the end of each epoch, train as long as the batch size is large enough.
        if epoch > 10:
            loss = agent.model.train_model()
            losses += loss
            
        # Calculate the average of the last 100 rewards
        if len(rewards) >= 100:
            avg_last_100_rewards = sum(rewards[-100:]) / 100
        else:
            avg_last_100_rewards = sum(rewards) / len(rewards)

        if cfg.log:
            writer.add_scalar('Loss', loss, epoch)
            writer.add_scalar('Reward', game_points, epoch)

        if epoch % EPOCH_EVAL_FREQ == 0:
            # print(f'Epoch: {epoch} - Epsilon: {epsilon} - Losses {losses} - Avg. last 100 rewards: {avg_last_100_rewards}')
            animate(
                images, filename = f"cleanup_epoch{epoch}", folder = f"{cfg.root}/examples/cleanup/data/"
            )
            losses = 0


torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])
torch.Size([968])


KeyboardInterrupt: 

In [4]:
for agent in agents:
  agent.model.save(f"../data/models/{datetime.now().strftime("%Y%m%d-%H%m%s")}")

In [33]:
%autoreload 2
from gem.models.DDQN import ClaasyReplayBuffer as Buffer
import numpy as np

buffer = Buffer(
  capacity=100,
  obs_shape=(100, )
)

for i in range(20):
  buffer.add(
    obs=np.zeros(shape=(100)),
    action=0,
    reward=0,
    done=False
  )
buffer.add(
    obs=np.zeros(shape=(100)),
    action=0,
    reward=0,
    done=True
)


# S, A, R, S', D, valid
tuple( buffer.sample(2, stacked_frames=3)[i][0] for i in range(6) )

(array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,