In [88]:
import os
import random
import pickle
from typing import Tuple, Union
import warnings
warnings.simplefilter("ignore", UserWarning)
from tqdm import tqdm
import numpy as np
import torch
import torch.multiprocessing as mp
from environment import Environment
from model import Network
import config
import copy
import ollama
from openai import OpenAI
import json

# SECURITY: a secret API key was previously hard-coded on this line. Secrets
# must never live in source or notebooks; the old key should be treated as
# leaked and revoked. Provide the key through the OPENAI_API_KEY environment
# variable (e.g. `export OPENAI_API_KEY=...`) before running this cell.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it in the environment before running."
    )
# OpenAI() is called without arguments, exactly as before; it picks the key up
# from the environment variable checked above.
client = OpenAI()

In [89]:
# Helper functions

# Direction name -> integer action code. 'stay' is listed first to keep the
# original insertion order of the mapping.
directiondict = dict(stay=4, north=0, south=1, west=2, east=3)

# Integer action code -> direction name (inverse of directiondict).
reverse_directiondict = dict(zip(directiondict.values(), directiondict.keys()))

def get_possible_directions(obs, obs_agents, agent_idx, agents_not_exchangeable, agents_fixed):
    """List the moves available to a pushed agent.

    A direction is kept when the matching neighbour cell of
    ``obs[0][agent_idx][1]`` holds 0 and the agent occupying that cell (if
    any) is neither in ``agents_not_exchangeable`` nor in ``agents_fixed``.

    Neighbour cells are read at (3, 4), (5, 4), (4, 3) and (4, 5), i.e. around
    cell (4, 4) -- presumably the agent's own cell in its local view; confirm
    against the environment. ``obs_agents`` stores agent index + 1 per cell,
    so a stored 0 (value -1 after the subtraction) means "no agent".

    Returns:
        A list of ``(direction, pushed_agent)`` tuples in north, south, west,
        east order; ``pushed_agent`` is ``None`` when the cell holds no agent.
    """
    cell_map = obs[0][agent_idx][1]
    neighbours = obs_agents[agent_idx]
    layout = (('north', (3, 4)), ('south', (5, 4)), ('west', (4, 3)), ('east', (4, 5)))

    moves = []
    for name, cell in layout:
        if cell_map[cell] != 0:
            continue
        occupant = neighbours[cell] - 1
        if occupant in agents_not_exchangeable or occupant in agents_fixed:
            continue
        moves.append((name, None if occupant == -1 else occupant))
    return moves


def get_possible_directions_super(obs, obs_agents, agent_idx, agents_fixed):
    """List the moves available to the leading ("super") agent of a push chain.

    A direction is kept when its channel of ``obs[0][agent_idx]`` (2 = north,
    3 = south, 4 = west, 5 = east) holds 1 at cell (4, 4) and the agent in the
    neighbouring cell (if any) is not in ``agents_fixed``. Presumably those
    channels mark the directions that lead towards the goal -- confirm against
    the environment's observation layout.

    Returns:
        A list of ``(direction, pushed_agent)`` tuples in north, south, west,
        east order; ``pushed_agent`` is ``None`` when the cell holds no agent.
    """
    agent_obs = obs[0][agent_idx]
    neighbours = obs_agents[agent_idx]
    layout = (
        ('north', 2, (3, 4)),
        ('south', 3, (5, 4)),
        ('west', 4, (4, 3)),
        ('east', 5, (4, 5)),
    )

    moves = []
    for name, channel, cell in layout:
        if agent_obs[channel][4, 4] != 1:
            continue
        # obs_agents stores agent index + 1, so -1 here means an empty cell.
        occupant = neighbours[cell] - 1
        if occupant in agents_fixed:
            continue
        moves.append((name, None if occupant == -1 else occupant))
    return moves


def push_recursive(obs, obs_agents, agent_super, agents_fixed):
    """Build a chain of pushes that lets ``agent_super`` move.

    Starting from ``agent_super``, a random available direction is chosen; if
    that cell is occupied, the occupant must move too, and so on until some
    agent steps into a cell without an agent. Dead ends are undone by
    backtracking to the most recent agent that still has untried directions.

    Returns:
        A list of ``(agent, direction)`` pairs, one per agent in the chain, or
        the string ``'nope'`` when every alternative has been exhausted.
    """
    chain = []       # (agent, direction) chosen at each depth
    blocked = []     # agents a pushed agent is not allowed to push into
    pending = []     # backtracking stack: (agent, untried options, depth)
    agent = agent_super
    level = 0

    while True:
        # The chain leader uses its own option rule; pushed agents use the generic one.
        if agent == agent_super:
            options = get_possible_directions_super(obs, obs_agents, agent, agents_fixed)
        else:
            options = get_possible_directions(obs, obs_agents, agent, blocked, agents_fixed)

        # Dead end: rewind to the latest agent that still has an untried option.
        while not options:
            if not pending:
                # Nowhere left to backtrack to.
                return 'nope'
            prev_agent, prev_options, prev_level = pending.pop()
            if not prev_options:
                continue
            # Drop the choices made at and below the restored depth.
            del chain[prev_level:]
            agent, options, level = prev_agent, prev_options, prev_level

        # Pick one of the remaining options at random and mark it as tried.
        pick = random.choice(options)
        options.remove(pick)
        direction, pushed = pick
        chain.append((agent, direction))

        # Nobody occupies the target cell: the chain is complete.
        if pushed is None:
            return chain

        # The block list restarts at depth 1, so from there on the leader is
        # no longer excluded.
        # NOTE(review): entries are not removed when backtracking -- confirm
        # this is intended.
        if level == 1:
            blocked = []
        blocked.append(agent)

        # Remember the untried options so this agent can be revisited.
        pending.append((agent, options, level))

        # Continue with the agent that has just been pushed.
        agent = pushed
        level += 1


def get_possible_directions_radiation(obs, obs_agents, center_coordinates, agent_idx, agents_fixed):
    """List the moves that take an agent away from ``center_coordinates``.

    A direction is kept when the matching neighbour cell of
    ``obs[0][agent_idx][1]`` holds 0, the move does not head back towards the
    centre, and the agent occupying the target cell (if any) is not in
    ``agents_fixed``. The agent position is read from ``obs[2][agent_idx]``
    as ``(row, col)``.

    Returns:
        A list of ``(direction, pushed_agent)`` tuples in north, south, west,
        east order; ``pushed_agent`` is ``None`` when the cell holds no agent.
    """
    cell_map = obs[0][agent_idx][1]
    neighbours = obs_agents[agent_idx]

    row_diff = center_coordinates[0] - obs[2][agent_idx][0]
    col_diff = center_coordinates[1] - obs[2][agent_idx][1]

    # Directions that would move the agent back towards the centre.
    inward = []
    if row_diff < 0:        # agent is below the centre: no going north
        inward.append('north')
    elif row_diff > 0:      # agent is above the centre: no going south
        inward.append('south')
    if col_diff < 0:        # agent is right of the centre: no going west
        inward.append('west')
    elif col_diff > 0:      # agent is left of the centre: no going east
        inward.append('east')

    layout = (('north', (3, 4)), ('south', (5, 4)), ('west', (4, 3)), ('east', (4, 5)))
    moves = []
    for name, cell in layout:
        if cell_map[cell] != 0 or name in inward:
            continue
        # obs_agents stores agent index + 1, so -1 here means an empty cell.
        occupant = neighbours[cell] - 1
        if occupant in agents_fixed:
            continue
        moves.append((name, None if occupant == -1 else occupant))
    return moves


def push_recursive_radiation(obs, obs_agents, center_coordinates, agent_idx, agents_fixed):
    """Build a chain of pushes that moves ``agent_idx`` away from a centre point.

    The first agent only considers directions leading away from
    ``center_coordinates``; any agent it pushes uses the generic option rule.
    Dead ends are undone by backtracking to the most recent agent that still
    has untried directions.

    Returns:
        A list of ``(agent, direction)`` pairs, one per agent in the chain, or
        ``[(agent_idx, 'stay')]`` when every alternative has been exhausted.
    """
    chain = []       # (agent, direction) chosen at each depth
    blocked = []     # agents a pushed agent is not allowed to push into
    pending = []     # backtracking stack: (agent, untried options, depth)
    mover = agent_idx
    level = 0

    while True:
        # The starting agent radiates outwards; pushed agents use the generic rule.
        if mover == agent_idx:
            options = get_possible_directions_radiation(obs, obs_agents, center_coordinates, mover, agents_fixed)
        else:
            options = get_possible_directions(obs, obs_agents, mover, blocked, agents_fixed)

        # Dead end: rewind to the latest agent that still has an untried option.
        while not options:
            if not pending:
                # Nowhere left to backtrack to: the agent simply stays put.
                return [(agent_idx, 'stay')]
            prev_agent, prev_options, prev_level = pending.pop()
            if not prev_options:
                continue
            # Drop the choices made at and below the restored depth.
            del chain[prev_level:]
            mover, options, level = prev_agent, prev_options, prev_level

        # Pick one of the remaining options at random and mark it as tried.
        pick = random.choice(options)
        options.remove(pick)
        direction, pushed = pick
        chain.append((mover, direction))

        # Nobody occupies the target cell: the chain is complete.
        if pushed is None:
            return chain

        # The block list restarts at depth 1, so from there on the starting
        # agent is no longer excluded.
        if level == 1:
            blocked = []
        blocked.append(mover)

        # Remember the untried options so this agent can be revisited.
        pending.append((mover, options, level))

        # Continue with the agent that has just been pushed.
        mover = pushed
        level += 1


def get_possible_directions_not_deadlock(obs, obs_agents, agent_idx, agent_action, agents_fixed):
    """Check whether an agent may carry out its already-planned action.

    Only ``agent_action`` is considered. It is accepted when it names one of
    the four movement directions and the agent occupying the target cell (if
    any) is not in ``agents_fixed``. ``obs`` is not used.

    Returns:
        ``[(agent_action, pushed_agent)]`` when the action is allowed
        (``pushed_agent`` is ``None`` for a cell without an agent), otherwise
        an empty list -- which is also the result for ``'stay'``.
    """
    neighbours = obs_agents[agent_idx]
    layout = (('north', (3, 4)), ('south', (5, 4)), ('west', (4, 3)), ('east', (4, 5)))

    for name, cell in layout:
        if name != agent_action:
            continue
        # obs_agents stores agent index + 1, so -1 here means an empty cell.
        occupant = neighbours[cell] - 1
        if occupant in agents_fixed:
            return []
        return [(name, None if occupant == -1 else occupant)]
    return []


def push_recursive_not_deadlock(obs, obs_agents, agent_idx, agent_action, agents_fixed):
    """Build a chain of pushes that lets ``agent_idx`` perform ``agent_action``.

    The first agent may only take ``agent_action``; any agent it pushes uses
    the generic option rule. Dead ends are undone by backtracking to the most
    recent agent that still has untried directions.

    Returns:
        A list of ``(agent, direction)`` pairs, one per agent in the chain, or
        ``[(agent_idx, 'stay')]`` when every alternative has been exhausted.
    """
    chain = []       # (agent, direction) chosen at each depth
    blocked = []     # agents a pushed agent is not allowed to push into
    pending = []     # backtracking stack: (agent, untried options, depth)
    mover = agent_idx
    level = 0

    while True:
        # The starting agent is tied to its planned action; others use the generic rule.
        if mover == agent_idx:
            options = get_possible_directions_not_deadlock(obs, obs_agents, agent_idx, agent_action, agents_fixed)
        else:
            options = get_possible_directions(obs, obs_agents, mover, blocked, agents_fixed)

        # Dead end: rewind to the latest agent that still has an untried option.
        while not options:
            if not pending:
                # Nowhere left to backtrack to: the agent simply stays put.
                return [(agent_idx, 'stay')]
            prev_agent, prev_options, prev_level = pending.pop()
            if not prev_options:
                continue
            # Drop the choices made at and below the restored depth.
            del chain[prev_level:]
            mover, options, level = prev_agent, prev_options, prev_level

        # Pick one of the remaining options at random and mark it as tried.
        pick = random.choice(options)
        options.remove(pick)
        direction, pushed = pick
        chain.append((mover, direction))

        # Nobody occupies the target cell: the chain is complete.
        if pushed is None:
            return chain

        # The block list restarts at depth 1, so from there on the starting
        # agent is no longer excluded.
        if level == 1:
            blocked = []
        blocked.append(mover)

        # Remember the untried options so this agent can be revisited.
        pending.append((mover, options, level))

        # Continue with the agent that has just been pushed.
        mover = pushed
        level += 1


def get_sorted_agents(agent_groups, env):
    """Pick one representative per group and order them by distance to goal.

    For every group the agent with the largest Manhattan distance between
    ``env.agents_pos`` and ``env.goals_pos`` is selected (the first such agent
    on ties). The representatives are then returned farthest-first.

    Args:
        agent_groups: iterable of collections of agent indices. Empty groups
            are skipped, since they have no representative.
        env: object exposing ``agents_pos`` and ``goals_pos`` indexable by
            agent index.

    Returns:
        A list with one agent index per non-empty group, sorted by descending
        distance to its goal. Equal distances keep group order.
    """
    def goal_distance(agent):
        return np.sum(np.abs(env.agents_pos[agent] - env.goals_pos[agent]))

    # Groups may come from LLM output; max() would raise on an empty one.
    super_agents = [
        max(group, key=goal_distance)
        for group in agent_groups
        if len(group) > 0
    ]

    # sorted() is stable, also with reverse=True.
    return sorted(super_agents, key=goal_distance, reverse=True)

In [90]:
# Prompts
class gpt4pathfinding:
    """Wrapper around the GPT-4 chat call used to detect agent deadlocks."""

    def detection(self, agents_state):
        """Ask GPT-4 which agents are deadlocked and how to resolve it.

        Args:
            agents_state: text log of each agent's recent actions and
                positions; it is inserted verbatim into the prompt.

        Returns:
            The raw text of the model's reply. The prompt asks for a JSON list
            of objects with the keys ``agent_id``, ``deadlock`` and (for
            deadlocked groups) ``solution``; the caller must still validate it.
        """
        # The prompt uses the key "solution" throughout and shows a
        # syntactically valid JSON example, so that the reply matches what the
        # consumer looks up and can be parsed by json.loads.
        response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "당신은 MAPF 문제에서 에이전트들의 데드락 여부 감지를 위해 호출된 관리자입니다. 당신은 에이전트들의 움직임을 통해 에이전트들이 각각 어떤 상태에 있는지 추론할 수 있는 능력을 가지고 있습니다."},
            {"role": "user", "content":
                f"""
                당신은 에이전트들의 32번간의 action log를 통해 에이전트들의 데드락 여부를 확인하고, 만약 데드락이면 서로 얽힌 가능성이 높은 에이전트들을 묶고, 데드락에 대한 해결 방법을 제시해야 합니다.
                
                데드락으로 분류되는 상태들은 다음과 같습니다:
                - not arrived인 상태로 이동하지 않는 상태
                - not arrived인 상태로 32번의 행동을 보았을 때 의미있는 좌표의 변화를 나타내지 못하고 배회하는 상태

                데드락으로 분류되지 않는 상태들은 다음과 같습니다:
                - not arrived였지만 특정 시점에 arrived가 되어 계속 멈춰있는 상태
                - 계속 not arrived이지만 일관성 있는 좌표의 변화를 나타내며 이동하는 상태

                만약 데드락인 에이전트들이 서로 매우 가까이 있다면 서로 얽혔을 가능성이 높으므로, 그들을 묶으세요/

                만약 에이전트가 독립으로 데드락이라면, "prime" 방법을 이용합니다.
                서로 얽힌 에이전트들 중, goal이 현재 위치에서 8 이상 떨어져 있는 에이전트가 있다면, "prime" 방법을 이용합니다. 에이전트들의 거리가 가깝더라도, 거리가 먼 goal로 한 에이전트를 이동시켜서 문제를 단순화할 수 있다면, "prime" 방법을 이용합니다.
                서로 얽힌 에이전트들 중, goal도 모두 현재 위치에 가까이 위치한다면, "radiation" 방법을 이용합니다.

                아래는 각 에이전트의 32번의 action log입니다.

                {agents_state}

                설명을 생성하지 말고, 아래 형식의 JSON 결과값만을 반환하세요.

                그 다음, 각 에이전트의 상태를 JSON 형식으로 반환하세요. 각 에이전트의 상태는 다음 형식을 따라야 합니다:
                {{
                    "agent_id": [서로 얽힌 가능성이 높은 <에이전트 ID>들],
                    "deadlock": "yes" 또는 "no",
                    "solution": "prime" 또는 "radiation"
                }}

                결과값에서 중복되는 에이전트가 있어서는 안 됩니다.
                "deadlock" 상태가 "no"인 경우에는 "solution"이 필요하지 않습니다.
                
                예시:
                [
                    {{"agent_id": [1, 24, 32], "deadlock": "yes", "solution": "prime"}},
                    {{"agent_id": [30], "deadlock": "yes", "solution": "prime"}},
                    {{"agent_id": [4, 5], "deadlock": "yes", "solution": "radiation"}},
                    {{"agent_id": [16], "deadlock": "no"}},
                    {{"agent_id": [20], "deadlock": "no"}}
                ]
                """
            }],
        )
        return response.choices[0].message.content
    
# Shared instance used by the evaluation loop.
pathfinder = gpt4pathfinding()

In [91]:
# Evaluation runs on the CPU with a single torch thread.
DEVICE = torch.device('cpu')
torch.set_num_threads(1)

# Seed Python's, NumPy's and torch's RNGs with the configured test seed.
random.seed(config.test_seed)
np.random.seed(config.test_seed)
torch.manual_seed(config.test_seed)

In [92]:
def create_test(test_env_settings: Tuple = config.test_env_settings, num_test_cases: int = config.num_test_cases):
    '''
    Create the pickled test sets under ./test_set.

    For every (map_length, num_agents, density) setting, num_test_cases
    instances are collected as (map, agents_pos, goals_pos) tuples of array
    copies and written to
    ./test_set/{map_length}length_{num_agents}agents_{density}density.pth.

    Args:
        test_env_settings: iterable of (map_length, num_agents, density).
        num_test_cases: number of instances stored per setting.
    '''
    # open() does not create missing directories, so make sure the output
    # folder exists before the first file is written.
    os.makedirs('./test_set', exist_ok=True)

    for map_length, num_agents, density in test_env_settings:

        name = f'./test_set/{map_length}length_{num_agents}agents_{density}density.pth'
        print(f'-----{map_length}length {num_agents}agents {density}density-----')

        env = Environment(fix_density=density, num_agents=num_agents, map_length=map_length)

        tests = []
        for _ in tqdm(range(num_test_cases)):
            # Snapshot the current instance, then reset the environment for the next one.
            tests.append((np.copy(env.map), np.copy(env.agents_pos), np.copy(env.goals_pos)))
            env.reset(num_agents=num_agents, map_length=map_length)
        print()

        with open(name, 'wb') as f:
            pickle.dump(tests, f)

In [93]:
def code_test():
    """Smoke test: run a single network step on a default environment."""
    env = Environment()
    network = Network()
    network.eval()

    obs, last_act, pos = env.observe()

    # Observation and last action are fed as float32 tensors on DEVICE;
    # positions are passed as an integer tensor.
    obs_tensor = torch.as_tensor(obs.astype(np.float32)).to(DEVICE)
    last_act_tensor = torch.as_tensor(last_act.astype(np.float32)).to(DEVICE)
    pos_tensor = torch.as_tensor(pos.astype(int))

    network.step(obs_tensor, last_act_tensor, pos_tensor)

In [94]:
def _load_checkpoint(network, model_name):
    """Load '<config.save_path>/<model_name>.pth' into ``network`` for evaluation."""
    state_dict = torch.load(os.path.join(config.save_path, f'{model_name}.pth'), map_location=DEVICE)
    network.load_state_dict(state_dict)
    network.eval()
    network.share_memory()


def _run_first_case(network, path, instance_id):
    """Run the first test case stored in ``path``; return how many cases the file holds."""
    # NOTE: pickle can execute arbitrary code while loading; only open test
    # sets that were generated locally / come from a trusted source.
    with open(path, 'rb') as f:
        tests = pickle.load(f)

    # Only the first case of every set is evaluated; its result tuple
    # (success, steps, num_comm) is not aggregated or printed here.
    test_one_case((tests[0], network, instance_id))
    return len(tests)


def test_model(model_range: Union[int, tuple], test_set=config.test_env_settings):
    '''
    Test model(s) stored under config.save_path.

    Args:
        model_range: an int selects the single checkpoint '<model_range>.pth';
            a tuple (first, last) selects every checkpoint from first to last
            (inclusive) in steps of config.save_interval.
        test_set: iterable of test-set descriptors. For an int model_range
            the file './test_set/{case[0]}_{case[1]}agents.pth' is used; for
            a tuple the file
            './test_set/{case[0]}length_{case[1]}agents_{case[2]}density.pth'.

    Any other type of model_range results in no evaluation at all.
    '''
    network = Network()
    network.eval()
    network.to(DEVICE)

    # No multiprocessing pool is created: it was never used by this function
    # and, being never closed, only leaked worker processes on every call.

    if isinstance(model_range, int):
        _load_checkpoint(network, model_range)

        print(f'----------test model {model_range}----------')

        instance_id = 0

        for case in test_set:
            print(f"test set: {case[0]} env {case[1]} agents")
            num_cases = _run_first_case(
                network, './test_set/{}_{}agents.pth'.format(case[0], case[1]), instance_id)

            # Instance ids advance by the size of the whole set.
            instance_id += num_cases

    elif isinstance(model_range, tuple):

        for model_name in range(model_range[0], model_range[1]+1, config.save_interval):
            _load_checkpoint(network, model_name)

            print(f'----------test model {model_name}----------')

            instance_id = 0

            for case in test_set:
                print(f"test set: {case[0]} length {case[1]} agents {case[2]} density")
                _run_first_case(
                    network,
                    f'./test_set/{case[0]}length_{case[1]}agents_{case[2]}density.pth',
                    instance_id)

                instance_id += 1

            print('\n')

In [95]:
def _network_step(network, obs, last_act, pos):
    """Run one ``network.step`` on numpy observations converted to tensors."""
    return network.step(torch.as_tensor(obs.astype(np.float32)).to(DEVICE),
                        torch.as_tensor(last_act.astype(np.float32)).to(DEVICE),
                        torch.as_tensor(pos.astype(int)))


def _simulate_plan(env, network, num_agents, horizon=32):
    """Roll the policy out on a deep copy of ``env`` for up to ``horizon`` steps.

    Returns ``(plan, not_arrived)``: ``plan`` holds one
    ``(actions, comm_mask, positions_before_step)`` entry per simulated step
    and ``not_arrived`` the agents seen away from their goal after any step.
    """
    # NOTE(review): the rollout calls network.step on the live network; if the
    # network keeps internal state between steps (it is reset() per case),
    # that state advances here even when the plan is later discarded -- confirm.
    env_copy = copy.deepcopy(env)
    plan = []
    not_arrived = set()
    sim_obs, sim_last_act, sim_pos = env_copy.observe()
    for _ in range(horizon):
        if env_copy.steps >= config.max_episode_length:
            break
        actions, _, _, _, comm_mask = _network_step(network, sim_obs, sim_last_act, sim_pos)
        plan.append((actions, comm_mask, copy.deepcopy(sim_pos)))
        (sim_obs, sim_last_act, sim_pos), _, _, _ = env_copy.step(actions)
        for i in range(num_agents):
            if not np.array_equal(env_copy.agents_pos[i], env_copy.goals_pos[i]):
                not_arrived.add(i)
    return plan, not_arrived


def _format_agent_logs(plan, not_arrived, env):
    """Render the simulated steps of every not-arrived agent as prompt text."""
    agents_state = ""
    for agent_idx in sorted(not_arrived):
        entries = []
        for actions, _, positions in plan:
            position = positions[agent_idx]
            # Compare the position with the goal to report arrival.
            arrived_status = "Arrived" if np.array_equal(position, env.goals_pos[agent_idx]) else "Not arrived"
            direction = reverse_directiondict.get(actions[agent_idx], 'unknown')
            entries.append(
                f"(Action: {direction}, Position: [{position[0]}, {position[1]}], {arrived_status})"
            )
        agent_goal = f" (Goal: [{env.goals_pos[agent_idx][0]}, {env.goals_pos[agent_idx][1]}])"
        agent_log = ", ".join(entries)
        agents_state += f"Agent {agent_idx}{agent_goal}: {agent_log}\n"
    return agents_state


def _parse_detection_response(response_text):
    """Extract the JSON list from the model reply; return [] when none is found."""
    try:
        start_idx = response_text.index('[')
        end_idx = response_text.rindex(']') + 1
        json_data = json.loads(response_text[start_idx:end_idx])
    except (ValueError, TypeError, AttributeError):
        # ValueError covers a missing bracket and json.JSONDecodeError;
        # TypeError / AttributeError cover a reply that is not a string.
        print("JSON 부분을 찾을 수 없으므로 deadlock이 없다고 가정합니다.")
        return []
    print("Extracted JSON:", json_data)
    # Only dict entries can carry a verdict; anything else is ignored.
    return [item for item in json_data if isinstance(item, dict)]


def _agent_groups(json_data, num_agents, deadlock, solution=None):
    """Collect the validated agent-id groups matching a verdict.

    Ids that are not integers in ``[0, num_agents)`` are dropped and groups
    left empty are skipped, because the reply comes from an LLM. Both the
    ``solution`` key and the misspelt ``solusion`` key are accepted.
    """
    groups = []
    for item in json_data:
        if item.get('deadlock') != deadlock:
            continue
        if solution is not None and item.get('solution', item.get('solusion')) != solution:
            continue
        ids = item.get('agent_id')
        if isinstance(ids, int) and not isinstance(ids, bool):
            ids = [ids]
        if not isinstance(ids, list):
            continue
        valid = [a for a in ids
                 if isinstance(a, int) and not isinstance(a, bool) and 0 <= a < num_agents]
        if valid:
            groups.append(valid)
    return groups


def _apply_chain(chain, manual_actions, fixed_agents):
    """Write a push chain into ``manual_actions`` and mark its agents as fixed."""
    for agent, direction in chain:
        manual_actions[agent] = directiondict[direction]
        fixed_agents.append(agent)


def test_one_case(args):
    """Run one test instance, using GPT-4 deadlock handling in its second half.

    The first half of the episode (up to ``config.max_episode_length // 2``
    environment steps) is driven by the network alone. Afterwards the policy
    is rolled out 32 steps ahead on a copy of the environment and the
    resulting log is sent to ``pathfinder.detection``. Without a reported
    deadlock the simulated actions are replayed on the real environment;
    otherwise up to 16 steps of manually constructed push actions are run.

    Args:
        args: tuple ``(env_set, network, instance_id)`` where ``env_set`` is
            ``(map, agents_pos, goals_pos)``.

    Returns:
        ``(success, step, num_comm)``: whether all agents are on their goals,
        the number of executed environment steps and the summed communication
        mask.
    """
    env_set, network, instance_id = args

    env = Environment()
    env.load(np.array(env_set[0]), np.array(env_set[1]), np.array(env_set[2]))
    obs, last_act, pos = env.observe()

    done = False
    network.reset()

    num_agents = len(env_set[1])

    step = 0
    num_comm = 0

    # Phase 1: plain policy execution.
    while not done and env.steps < config.max_episode_length // 2:
        actions, _, _, _, comm_mask = _network_step(network, obs, last_act, pos)
        (obs, last_act, pos), _, done, _ = env.step(actions)
        env.save_frame(step, instance_id)
        step += 1
        num_comm += np.sum(comm_mask)

    # Phase 2: look ahead, ask the LLM for deadlocks, then act.
    while not done and env.steps < config.max_episode_length:
        plan, not_arrived = _simulate_plan(env, network, num_agents)

        agents_state = _format_agent_logs(plan, not_arrived, env)
        print(agents_state)
        response_text = pathfinder.detection(agents_state)
        print(response_text)
        json_data = _parse_detection_response(response_text)

        deadlock_exists = any(item.get('deadlock') == 'yes' for item in json_data)

        if not deadlock_exists:
            # Replay the simulated actions; stop as soon as the episode ends,
            # as phase 1 does.
            for actions, comm_mask, _ in plan:
                if done or env.steps >= config.max_episode_length:
                    break
                (obs, last_act, pos), _, done, _ = env.step(actions)
                env.save_frame(step, instance_id)
                step += 1
                num_comm += np.sum(comm_mask)
        else:
            prime_agents = _agent_groups(json_data, num_agents, 'yes', 'prime')
            radiation_agents = _agent_groups(json_data, num_agents, 'yes', 'radiation')
            no_deadlock_agents = _agent_groups(json_data, num_agents, 'no')
            sorted_prime_agents = get_sorted_agents(prime_agents, env)
            sorted_no_deadlock_agents = get_sorted_agents(no_deadlock_agents, env)

            for _ in range(16):
                if done or env.steps >= config.max_episode_length:
                    break
                obs_agents = env.observe_agents()
                observation = env.observe()

                # Everybody stays unless a push chain assigns a move.
                manual_actions = [directiondict['stay']] * num_agents
                # The comm mask of this step is the one counted below, not a
                # leftover from the look-ahead rollout.
                ml_planned_actions, _, _, _, comm_mask = _network_step(network, obs, last_act, pos)

                fixed_agents = []
                for super_agent in sorted_prime_agents:
                    chain = push_recursive(observation, obs_agents, super_agent, fixed_agents)
                    # push_recursive returns the string 'nope' when no chain
                    # exists; iterating it would index actions by characters.
                    if isinstance(chain, str):
                        continue
                    _apply_chain(chain, manual_actions, fixed_agents)

                for set_of_agents in radiation_agents:
                    # Centre of the group: mean of the two position coordinates.
                    first_coords = [observation[2][agent_idx][0] for agent_idx in set_of_agents]
                    second_coords = [observation[2][agent_idx][1] for agent_idx in set_of_agents]
                    average_position = (sum(first_coords) / len(first_coords),
                                        sum(second_coords) / len(second_coords))

                    for radiation_agent in set_of_agents:
                        chain = push_recursive_radiation(
                            observation, obs_agents, average_position, radiation_agent, fixed_agents)
                        _apply_chain(chain, manual_actions, fixed_agents)

                for no_deadlock_agent in sorted_no_deadlock_agents:
                    planned = reverse_directiondict[ml_planned_actions[no_deadlock_agent]]
                    chain = push_recursive_not_deadlock(
                        observation, obs_agents, no_deadlock_agent, planned, fixed_agents)
                    _apply_chain(chain, manual_actions, fixed_agents)

                (obs, last_act, pos), _, done, _ = env.step(manual_actions)
                env.save_frame(step, instance_id)
                step += 1
                num_comm += np.sum(comm_mask)

    return np.array_equal(env.agents_pos, env.goals_pos), step, num_comm

In [96]:
# Evaluate the saved model '128000.pth' on the default test sets from config.
test_model(128000)

  state_dict = torch.load(os.path.join(config.save_path, f'{model_range}.pth'), map_location=DEVICE)


----------test model 128000----------
test set: warehouse env 64 agents
Agent 24 (Goal: [5, 157]): (Action: east, Position: [4, 116], Not arrived), (Action: east, Position: [4, 117], Not arrived), (Action: east, Position: [4, 118], Not arrived), (Action: east, Position: [4, 119], Not arrived), (Action: east, Position: [4, 120], Not arrived), (Action: east, Position: [4, 121], Not arrived), (Action: east, Position: [4, 122], Not arrived), (Action: east, Position: [4, 123], Not arrived), (Action: east, Position: [4, 124], Not arrived), (Action: east, Position: [4, 125], Not arrived), (Action: east, Position: [4, 126], Not arrived), (Action: east, Position: [4, 127], Not arrived), (Action: east, Position: [4, 128], Not arrived), (Action: east, Position: [4, 129], Not arrived), (Action: east, Position: [4, 130], Not arrived), (Action: east, Position: [4, 131], Not arrived), (Action: east, Position: [4, 132], Not arrived), (Action: east, Position: [4, 133], Not arrived), (Action: east, Posi