In [1]:
# general
import argparse
import pickle
import torch
import time
import tempfile
import numpy as np
import random
import os
import wandb
#from ray.rllib.agents import ppo, dqn, pg, a3c, impala
from tqdm import tnrange

# our code
from sigma_graph.envs.figure8.action_lookup import MOVE_LOOKUP, TURN_90_LOOKUP
from sigma_graph.envs.figure8.default_setup import OBS_TOKEN
from sigma_graph.envs.figure8.figure8_squad_rllib import Figure8SquadRLLib
from sigma_graph.envs.figure8.gflow_figure8_squad import GlowFigure8Squad
#from graph_scout.envs.base import ScoutMissionStdRLLib
import sigma_graph.envs.figure8.default_setup as default_setup
import model  # THIS NEEDS TO BE HERE IN ORDER TO RUN __init__.py!
import model.utils as utils
import model.gnn_gflow 
from trajectory import Trajectory
import losses
import torch.optim as optim



In [2]:
WANDB = True
SEED = 0
LEARNING_RATE = 1e-3
EPOCHS = 300000
BATCH_SIZE = 500

In [3]:
if WANDB:
    wandb.login()
    wandb.init(
        project="graph-training-simulation",
        config={
            "learning_rate": LEARNING_RATE,
            "epochs": EPOCHS,
            "batch_size": BATCH_SIZE,
            "seed": SEED
        }
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mr-marr747[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
config = {
    "custom_model": "graph_transformer_policy",
    "custom_model_config": {
        "nred": 1,
        "nblue": 1,
        "aggregation_fn": "agent_node",
        "hidden_size": 10,
        "is_hybrid": False,
        "conv_type": "gcn",
        "layernorm": False,
        "graph_obs_token": {"embed_opt": False, "embed_dir": True},
    },
    "env_config": {
        "env_path": ".",
        "act_masked": True,
        "init_red": None,
        "init_blue": None,
        "init_health_red": 20,
        "init_health_blue": 20,
        "obs_embed": False,
        "obs_dir": False,
        "obs_team": True,
        "obs_sight": False,
        "log_on": False,
        "log_path": "logs/temp/",
        "fixed_start": -1,
        "penalty_stay": 0,
        "threshold_damage_2_blue": 2,
        "threshold_damage_2_red": 5,
    },
}

In [5]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x1084082b0>

In [6]:
gflowfigure8 = GlowFigure8Squad(sampler_config=config)
# 25, '11_0110'

---------------
path_data ./GflowsForSimulation/sigma_graph/data/parsed/
/Users/ryanmarr/Documents/CognArch/GflowsForSimulation


random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

gflowfigure8 = GlowFigure8Squad(sampler_config=config)

In [11]:
optimizer = optim.AdamW(gflowfigure8.sampler_fcn.parameters(), lr=LEARNING_RATE)

batch_loss = 0
batch_num = 0
batch_reward = 0

minibatch_loss = 0
minibatch_reward = 0

for i in range(EPOCHS):
    trajectory = Trajectory()
    gflowfigure8._reset_agents()
    for _ in range(20):   
        for a_id in range(config['custom_model_config']['nred']):
            step = gflowfigure8.step(a_id)
            trajectory.add_step(
                forward_prob=step['forward_prob'],
                backward_prob=step['backward_prob'],
                # flow=step['flow'],
                # action=step['action'],
                reward=step['step_reward'],
                # node=step['node']
            )

    logZ = gflowfigure8.sampler_fcn.logZ
    episode_loss = losses.Losses.trajectory_balance(trajectory, logZ)
    
    episode_reward = trajectory.rewards

    batch_num = batch_num + 1
    batch_loss += episode_loss
    batch_reward += episode_reward

    if batch_num % BATCH_SIZE == 0:
        if WANDB:
            wandb.log({"loss": batch_loss/BATCH_SIZE, "reward":  batch_reward/BATCH_SIZE})
            batch_loss = 0
            batch_reward = 0
            for name, param in gflowfigure8.sampler_fcn.named_parameters():
                wandb.log({f"{name}_mean": param.data.mean().item(), f"{name}_std": param.data.std().item()})

    episode_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

KeyboardInterrupt: 

In [8]:
gflowfigure8._reset_agents()
state_dirs = {}
for node in range(27):
    states = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
    states[0][node] = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state = torch.tensor(np.array([states[0],], dtype=np.int8)).to(device)
    probs = gflowfigure8.sampler_fcn.forward(state)
    # (forward_prob, action) = gflowfigure8.probs_to_action(probs)
    total_probs = {}
    total_probs["NOOP"] = (probs[0]+probs[5]+probs[10]).tolist()
    total_probs["N"] = (probs[1]+probs[6]+probs[11]).tolist()
    total_probs["S"] = (probs[2]+probs[7]+probs[12]).tolist()
    total_probs["W"] = (probs[3]+probs[8]+probs[13]).tolist()
    total_probs["E"] = (probs[4]+probs[9]+probs[14]).tolist()
    
    state_dirs[node] = total_probs
print(state_dirs)

{0: {'NOOP': -8.268583186501372, 'N': -7.372467926049602, 'S': -8.094081168187596, 'W': -8.714105078330345, 'E': -8.34197791149142}, 1: {'NOOP': -7.989169240276983, 'N': -8.429453397781996, 'S': -8.558959676534016, 'W': -7.989027128987622, 'E': -7.734271737759707}, 2: {'NOOP': -8.219692872446817, 'N': -8.228478331694099, 'S': -8.232860806543997, 'W': -8.388799794829199, 'E': -7.614703019079005}, 3: {'NOOP': -8.249323730104553, 'N': -8.489449043159564, 'S': -8.13092456033508, 'W': -7.971445468956761, 'E': -7.826401634263799}, 4: {'NOOP': -8.148759543743694, 'N': -8.25205655393381, 'S': -7.94118358497925, 'W': -8.194708210854905, 'E': -8.094218214944338}, 5: {'NOOP': -8.177092812324275, 'N': -8.326954968024875, 'S': -8.236814066198221, 'W': -7.612063352287366, 'E': -8.33808238117872}, 6: {'NOOP': -8.367796025895304, 'N': -8.780694667949426, 'S': -8.170630384010483, 'W': -6.953959476369189, 'E': -8.762081617012878}, 7: {'NOOP': -8.198084675800025, 'N': -8.367902154390231, 'S': -7.68754842

In [9]:
trajectory = Trajectory()
gflowfigure8._reset_agents()
for _ in range(20):   
    for a_id in range(config['custom_model_config']['nred']):
        step = gflowfigure8.step(a_id)
        trajectory.add_step(
            forward_prob=step['forward_prob'],
            backward_prob=step['backward_prob'],
            # flow=step['flow'],
            # action=step['action'],
            reward=step['step_reward'],
            # node=step['node']
        )
        print(step['node'])
        print(step['action'])


25
[[3, 2]]
19
[[1, 0]]
19
[[0, 2]]
19
[[1, 0]]
19
[[2, 1]]
13
[[2, 1]]
5
[[3, 1]]
5
[[4, 1]]
5
[[3, 2]]
5
[[4, 2]]
5
[[0, 0]]
5
[[0, 0]]
5
[[2, 2]]
4
[[4, 1]]
3
[[4, 0]]
1
[[2, 2]]
1
[[0, 0]]
1
[[1, 1]]
2
[[1, 1]]
12
[[3, 2]]


In [10]:
import argparse
import glob
import os
import re

import networkx as nx
import numpy as np
from PIL import Image
import matplotlib.colors as colors
import matplotlib.pyplot as plt
# from matplotlib.animation import FuncAnimation, PillowWriter
from sigma_graph.data.file_manager import check_dir, find_file_in_dir, load_graph_files

def agent_log_parser(line) -> dict:
    agent_info = {}
    # parse team info
    team_red = re.search(r"red:(\d+)", line)
    team_blue = re.search(r"blue:(\d+)", line)
    if team_red is not None:
        agent_info["team"] = "red"
        agent_info["id"] = team_red[1]
    elif team_blue is not None:
        agent_info["team"] = "blue"
        agent_info["id"] = team_blue[1]
    else:
        assert f"[log] Invalid agent team format: {line}"
    # parse agent info
    agent_pos = re.search(r"HP:\s?(\d+) node:(\d+) dir:(\d) pos:\((\d+), (\d+)\)", line)
    if agent_info is not None:
        agent_info["HP"] = int(agent_pos[1])
        agent_info["node"] = int(agent_pos[2])
        agent_info["dir"] = int(agent_pos[3])
        agent_info["pos"] = (int(agent_pos[4]), int(agent_pos[5]))
    else:
        assert f"[log] Invalid agent info format: {line}"
    return agent_info


def list_nums_log_parser(line):
    pass


def log_file_parser(line):
    print(f'line {line}')
    segments = line.split(" | ")
    print(f'segments {segments}')
    step_num = int(re.search(r"Step #\s?(\d+)", segments[0])[1])

    agents = []
    for str_agents in segments[1:-2]:
        agents.append(agent_log_parser(str_agents))
    actions = segments[-2]
    rewards = segments[-1]

    return step_num, agents, actions, rewards[:-1]


def check_log_files(env_dir, log_dir, log_file):
    # generate a subfolder in the log folder for -> animations (and optional pictures for each step)
    log_file_dir = find_file_in_dir(log_dir, log_file)
    fig_file_dir = os.path.join(log_dir, log_file[:-4])
    if not check_dir(fig_file_dir):
        os.mkdir(fig_file_dir)
    return log_file_dir, fig_file_dir


def generate_picture(env_dir, log_dir, log_file, HP_red, TR_red, HP_blue, TR_blue,
                     color_decay=True, if_froze=False, max_step=40, map_lookup="S"):
    # check file existence
    log_file_dir, fig_folder = check_log_files(env_dir, log_dir, log_file)
    map_info, _ = load_graph_files(env_path=env_dir, map_lookup=map_lookup)
    # load log info
    file = open(log_file_dir, 'r')
    lines = file.readlines()

    # predetermined colors
    col_map_red = ['#200000', '#200000', '#400000', '#800000', '#BF0000', '#FF0000']
    col_map_blue = ['#000020', '#000020', '#000040', '#000080', '#0000BF', '#0000FF']
    if color_decay:
        HP_offset = 0.1
        red_bds = np.append([0], np.linspace(HP_red - TR_red + HP_offset, HP_red + HP_offset, num=len(col_map_red)))
        red_norm = colors.BoundaryNorm(boundaries=red_bds, ncolors=len(col_map_red))
        blue_bds = np.append([0], np.linspace(HP_blue - TR_blue, HP_blue + HP_offset, num=len(col_map_blue)))
        blue_norm = colors.BoundaryNorm(boundaries=blue_bds, ncolors=len(col_map_blue))
    else:
        red_norm = lambda x: -1
        blue_norm = lambda x: -1

    total_reward = 0
    pause_step = 0

    for i, line in enumerate(lines):
        fig = plt.figure()
        # set figure background opacity (alpha) to 0
        fig.patch.set_alpha(0.)
        fig.tight_layout()
        plt.axis('off')

        if i < max_step:
            idx_step, agents, action, reward = log_file_parser(line)
            text_head = f"#{idx_step:2d}/{max_step} {action} {reward} "
        elif i == max_step:
            # get episode rewards from log
            text_head += line[:-1]
        legend_text = [text_head]
        # set color map for agents and waypoints
        col_map = ["gold"] * len(map_info.n_info)
        for agent in agents:
            legend_text += ["{}_{} HP:{} node:{} dir:{} pos:{}".format(agent["team"], agent["id"], agent["HP"],
                                                                       agent["node"], agent["dir"], agent["pos"])]
            if agent["team"] == 'red':
                col_map[agent['node'] - 1] = "red" #col_map_red[red_norm(agent['HP'])]
            elif agent["team"] == 'blue':
                blue_health = agent['HP']
                col_map[agent['node'] - 1] = "blue" #col_map_blue[blue_norm(blue_health)]
        # set pause frame number for gif looping
        if if_froze and (not pause_step) and (blue_health <= HP_blue - TR_blue):
            pause_step = i
        # render fig and save to png
        nx.draw_networkx(map_info.g_acs, map_info.n_info, node_color=col_map, edge_color="grey", arrows=True)
        plt.legend(legend_text, bbox_to_anchor=(0.07, 0.95, 0.83, 0.1), loc='lower left', prop={'size': 8},
                   mode="expand", borderaxespad=0.)
        plt.savefig(os.path.join(fig_folder, f"{i:03d}.png"), dpi=100, transparent=True)
        plt.close()
    return fig_folder, pause_step


def generate_picture_route(env_dir, log_dir, log_file, route_info):
    log_file_dir, fig_folder = check_log_files(env_dir, log_dir, log_file)
    return fig_folder


def frame_add_background(img_dir, gif_file, bg_file, fps, stop_frame=0, wait_frame=5):
    img_files = img_dir + "/*.png"
    imgs = []
    frames = 0
    for f in sorted(glob.glob(img_files)):
        foreground = Image.open(f)
        background = Image.open(bg_file)
        background.paste(foreground, (0, 0), foreground)
        imgs.append(background)
        if stop_frame:
            if frames == stop_frame:
                break
            frames += 1
    # set up additional end frames before looping
    if not stop_frame:
        for i in range(wait_frame):
            imgs.append(imgs[-1])
    imgs[0].save(fp=gif_file, format='GIF', append_images=imgs[1:],
                 save_all=True, duration=(1000 // fps), loop=(stop_frame > 0))


def local_run(env_dir, log_dir, prefix, bg_pic, fps, HP_red, TR_red, HP_blue, TR_blue,
              color_decay=True, froze=False, route_only=False, route_info=None):
    directory = os.fsencode(log_dir)
    for file in os.listdir(directory):
        log_file = os.fsdecode(file)
        print(f'log_file {log_file}')
        if log_file.endswith(".txt") and log_file.startswith(prefix):
            if route_only:
                fig_folder = generate_picture_route(env_dir, log_dir, log_file, route_info)
                pause_frame = 0
            else:
                fig_folder, pause_frame = generate_picture(env_dir, log_dir, log_file,
                                                           HP_red, TR_red, HP_blue, TR_blue, color_decay, froze)
            frame_add_background(fig_folder, os.path.join(log_dir, f"{log_file[:-4]}.gif"), bg_pic, fps,
                                 pause_frame)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_dir', type=str, default='../../', help='path to project root')
    parser.add_argument('--log_dir', type=str, default='../../logs/visuals/demo/', help='path to log file folder')
    parser.add_argument('--prefix', type=str, default='log_', help='log file prefix')
    parser.add_argument('--background', type=str, default='../../logs/visuals/background.png')
    parser.add_argument('--fps', type=int, default=2)  # frame per second in animations

    parser.add_argument('--HP_froze_on', action='store_true', default=False, help='stop animation if agent is dead')
    parser.add_argument('--HP_red', type=int, default=100)
    parser.add_argument('--TR_red', type=int, default=5)
    parser.add_argument('--HP_blue', type=int, default=100)
    parser.add_argument('--TR_blue', type=int, default=10)
    parser.add_argument('--HP_color_off', action='store_false', default=True, help='gradient colors for HP')

    parser.add_argument('--route_only', type=bool, default=False)  # exclude step info
    parser.add_argument('--route_info', type=str, default='name')  # choose from ['name', 'pos', 'idx']
    args = parser.parse_args(args=['--log_dir', '../logs/temp', '--env_dir', '../', '--background', '../logs/background.png'])

    local_run(args.env_dir, args.log_dir, args.prefix, args.background, args.fps,
              args.HP_red, args.TR_red, args.HP_blue, args.TR_blue, args.HP_color_off,
              args.HP_froze_on, args.route_only, args.route_info)

FileNotFoundError: [Errno 2] No such file or directory: b'../logs/temp'