In [None]:
from lle import LLE, WorldState, World
import json
import marl
import marlenv
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import defaultdict
from marl.other.local_graph import LocalGraphTrainer, LocalGraphBottleneckFinder

In [None]:
def get_bottlenecks_stats(finder: LocalGraphBottleneckFinder[WorldState]):
    """Project the edges on the 2D plane"""
    vertex_bottleneck_scores = defaultdict[tuple[int, int], float](int)
    edge_bottleneck_scores = defaultdict[tuple[WorldState, WorldState], float](int)
    for edge, hit_count in finder.hit_count.items():
        score = finder.predict(edge)
        edge_bottleneck_scores[edge] += score
        start, end = edge
        vertices = start.agents_positions + end.agents_positions
        for vertex in vertices:
            vertex_bottleneck_scores[vertex] += score
    return vertex_bottleneck_scores, edge_bottleneck_scores


def compute_heatmap(world: World, finder: LocalGraphBottleneckFinder[WorldState]):
    """Compute the heatmap of the bottlenecks"""
    vertex_scores, _ = get_bottlenecks_stats(finder)
    heatmap = np.zeros((world.height, world.width))
    for (i, j), score in vertex_scores.items():
        heatmap[i, j] = score
    return heatmap

In [None]:
N_SEEDS = 2
N_STEPS = 100

for train_dqn in [False, True]:
    for n_agents in range(1, 5):
        map_name = f"../maps/subgraph-{n_agents}agents.toml"
        env = LLE.from_file(map_name).obs_type("layered").single_objective()
        world = env.world
        env = marlenv.Builder(env).agent_id().build()
        
        total_heatmap = np.zeros((world.height, world.width), dtype=np.float32)
        for seed in range(N_SEEDS):
            print(f"Running seed {seed+1}/{N_SEEDS}")
            finder = LocalGraphBottleneckFinder()
            if train_dqn:
                qnetwork = marl.nn.model_bank.CNN.from_env(env)
                policy = marl.policy.EpsilonGreedy.linear(1.0, 0.05, 100_000)
                algo = marl.algo.DQN(qnetwork, policy)
                dqn_trainer = marl.training.DQNTrainer(
                    qnetwork,
                    policy,
                    marl.models.TransitionMemory(10_000),
                    mixer=marl.algo.VDN.from_env(env),
                    lr=1e-4,
                    gamma=0.95,
                    train_interval=(100, "step"),
                )
                trainer = LocalGraphTrainer(finder, world, dqn_trainer)
                trainer = trainer.to(marl.utils.get_device())
            else:
                algo = None
                trainer = LocalGraphTrainer(finder, world, None)
            exp = marl.Experiment.create(logdir="logs/test", trainer=trainer, algo=algo, n_steps=N_STEPS, test_interval=0, env=env)
            exp.run()
            
            heatmap = compute_heatmap(world, finder)
            total_heatmap += heatmap
            filename = "heatmap-dqn" if train_dqn else "heatmap"
            filename += f"-{n_agents}_agents-{seed}.json"
            with open(filename, "w") as f:
                json.dump(heatmap.tolist(), f)
        sns.heatmap(total_heatmap)
        img_filename = "heatmap-dqn" if train_dqn else "heatmap"
        img_filename += f"-{n_agents}_agents.png"
        plt.savefig(img_filename)
        plt.clf()

In [1]:
import json
import numpy as np

res = None
for i in range(10):
    filename = f"logs/1-agents/heatmap-{i}.json"
    with open(filename, "r") as f:
        heatmap = np.array(json.load(f))
    if res is None:
        res = heatmap
    else:
        res += heatmap
sns.heatmap(res)


FileNotFoundError: [Errno 2] No such file or directory: 'logs/1-agents/heatmap-0.json'