In [None]:
# import modules
import copy
import os
import sys
import torch
import numpy as np

# --------------------------- #
# region: Project import path #
# `os` and `sys` are already imported at the top of this cell, so they are not
# imported a second time here.
#
# Put the directory three levels above the current working directory at the
# front of sys.path (only if it is not already listed), presumably so that the
# `examples.*` and `agentarium.*` imports that follow can be resolved.
# NOTE(review): "../../.." is resolved against the kernel's working directory,
# which is assumed to be this notebook's own folder - confirm before running
# the notebook from anywhere else.
module_path = os.path.abspath("../../..")
if module_path not in sys.path:
    sys.path.insert(0, module_path)
# endregion                   #
# --------------------------- #

from examples.state_punishment.utils import (
    init_log,
    parse_args,
    load_config,
    create_models,
    create_agents,
    create_entities,
)
from examples.state_punishment import agents, entities
from examples.state_punishment.env import state_punishment
from examples.state_punishment.utils import inspect_the_env
from agentarium.logging_utils import GameLogger
from agentarium.models import human_player
from agentarium.utils import visual_field_sprite, image_from_array
from examples.state_punishment.state_sys import state_sys, Monitor
from IPython.display import clear_output
from matplotlib import pyplot as plt
import random
import argparse
import time
from copy import deepcopy

In [None]:
# Load the experiment configuration.
# load_config is given an argparse.Namespace that carries a single `config`
# attribute; it is built by hand here instead of being parsed from a command
# line.
config_path = "../configs/config_fixed_rate_no_vote.yaml"
config_args = argparse.Namespace(config=config_path)
cfg = load_config(config_args)

In [None]:
# Build the models, agents and entities described by the config.
# NOTE(review): `agents` and `entities` are also the names of the modules
# imported from examples.state_punishment in the first cell; from this point on
# the names refer to the objects returned by create_agents / create_entities.
models = create_models(cfg)
agents = create_agents(cfg, models)
entities = create_entities(cfg)

# One single-agent environment per agent, each constructed with its own deep
# copy of the entity collection.
envs = [state_punishment(cfg, [agent], deepcopy(entities)) for agent in agents]


# Replace every agent's model with a ModelHumanPlayer sized from the IQN
# section of the config, set its epsilon to 0.01, and load that agent's saved
# checkpoint into it.
# NOTE(review): the checkpoint file names contain "iRainbowModel" while the
# object they are loaded into is a ModelHumanPlayer - confirm that load()
# accepts these files.
iqn_params = cfg.model.iqn.parameters
for ixs, agent in enumerate(agents):
    agent.model = human_player.ModelHumanPlayer(
        action_space=8,
        state_size=iqn_params.state_size,
        extra_percept_size=iqn_params.extra_percept_size,
        memory_size=1,
        name=f"human {ixs}",
    )
    agent.model.epsilon = 0.01

    # The checkpoint file is selected by the agent's index.
    checkpoint_path = f"../models/checkpoints/fixed_punishment_rate_0.75_twoAs_extra_percept_v2_higher_harm_gem_has_value_save_model_agent{ixs}_iRainbowModel_20241127-04111732699956.pkl"
    agent.model.load(checkpoint_path)

# Replace the config's prob_list with a plain dict that holds only the Gem,
# Coin and Bone entries.
# NOTE(review): this overwrites cfg.state_sys.prob_list in place. If the config
# object stores the dict as-is, the `.Gem` attribute access below would fail
# the second time this cell is run - confirm against the config class.
state_cfg = cfg.state_sys
original_prob_list = state_cfg.prob_list
state_cfg.prob_list = {
    "Gem": original_prob_list.Gem,
    "Coin": original_prob_list.Coin,
    "Bone": original_prob_list.Bone,
}

# Build the state system from the state_sys section of the config.
state_entity = state_sys(
    state_cfg.init_prob,
    state_cfg.prob_list,
    state_cfg.magnitude,
    state_cfg.taboo,
    state_cfg.change_per_vote,
)

# NOTE(review): this assigns the literal Ellipsis object (`...`) to
# state_entity.prob. It looks like a placeholder for a concrete value;
# replace it with the intended value before relying on results.
state_entity.prob = ...


# Bookkeeping values: three counters starting at zero and one score slot per
# agent.
done = 0
turn = 0
losses = 0
game_points = [0] * len(agents)

# Environment snapshots ("templates") collected for later evaluation.
env_templates = []

# Start from cleared worlds in every environment.
for env in envs:
    env.clear_world()

# Put the first entity whose type is "gem" at world index (7, 7, 0) of the
# first environment.
gem_entities = [candidate for candidate in entities if candidate.type == "gem"]
envs[0].world[7, 7, 0] = gem_entities[0]

# Store an independent copy of that environment as a template.
env_templates.append(copy.deepcopy(envs[0]))

In [None]:
# Locations at which the agent is placed for evaluation. Each entry is a
# three-element index tuple into env.world, e.g. (3, 4, 0); its first two
# elements are also used to index the heatmap.
# NOTE(review): the list starts out empty - check how the evaluation cell
# treats an empty list before relying on its output.
target_locs = []

In [None]:
# parameters
# Height of the first template; kept as a module-level value.
world_size = env_templates[0].height
# Number of take_action calls whose value estimates are averaged per location.
num_reps = 10

# One list of heatmaps per agent, in template order:
# v_heatmap_lst[agent_ix][template_ix] is a 2-D array.
# Sized by `agents` (not `models`) because it is indexed by agent position.
v_heatmap_lst = [[] for _ in range(len(agents))]

# main
for agent_ix, agent in enumerate(agents):
    for env in env_templates:
        n_rows, n_cols = env.world.shape[0], env.world.shape[1]

        # find the empty locations (layer 0 cells whose string form is
        # "EmptyObject")
        empty_locs = [
            (i, j, 0)
            for i in range(n_rows)
            for j in range(n_cols)
            if str(env.world[i, j, 0]) == "EmptyObject"
        ]

        # NaN-filled heatmap with the same row/column extent as the world, so
        # every location index is valid and unvisited cells stay NaN.
        v_heatmap = np.full((n_rows, n_cols), np.nan)

        # Use the explicit target locations when any were given; otherwise
        # fall back to every empty location. The check must treat an empty
        # list like "not given": `target_locs is not None` is true for `[]`
        # and would leave nothing to evaluate.
        if target_locs is not None and len(target_locs) > 0:
            available_locs = target_locs
        else:
            available_locs = empty_locs

        # iteratively place the agent on every selected location in a copy of
        # the env, record the model's output, and store it in the heatmap
        for loc in available_locs:
            # work on a copy so the template itself is never modified
            env_ = copy.deepcopy(env)
            # place the agent
            env_.world[loc] = agent
            # NOTE(review): agent.location is left at the last evaluated
            # location once the loops finish.
            agent.location = loc
            # generate the agent's state
            model_input = agent.get_model_input(
                env_, state_sys=state_entity, state_is_composite=True, envs=envs
            )

            val_estimation_trials = []

            # this loop is tailored for the model used in this experiment
            # NOTE(review): repeating the call presumably averages out
            # randomness in take_action - confirm. float() requires the mean
            # over dim 0 to be a single value, i.e. Qs is assumed to be 1-D.
            for _ in range(num_reps):
                action, Qs = agent.model.take_action(model_input, eval=True)
                v = float(torch.mean(Qs, dim=0))
                val_estimation_trials.append(v)

            # store the value of interest in the heatmap
            v_heatmap[loc[0], loc[1]] = np.mean(val_estimation_trials)

        # file the heatmap under the agent it was computed for
        v_heatmap_lst[agent_ix].append(v_heatmap)