In [1]:
# general
import argparse
import pickle
import torch
import time
import tempfile
import numpy as np
import random
import os
import wandb
#from ray.rllib.agents import ppo, dqn, pg, a3c, impala
from tqdm import tnrange

# our code
from sigma_graph.envs.figure8.action_lookup import MOVE_LOOKUP, TURN_90_LOOKUP
from sigma_graph.envs.figure8.default_setup import OBS_TOKEN
from sigma_graph.envs.figure8.figure8_squad_rllib import Figure8SquadRLLib
from sigma_graph.envs.figure8.gflow_figure8_squad import GlowFigure8Squad
#from graph_scout.envs.base import ScoutMissionStdRLLib
import sigma_graph.envs.figure8.default_setup as default_setup
import model  # THIS NEEDS TO BE HERE IN ORDER TO RUN __init__.py!
import model.utils as utils
import model.gnn_gflow 
from trajectory import Trajectory
import losses
import torch.optim as optim



In [2]:
# Run-level switches and hyper-parameters read by the cells below.

# Experiment tracking / reproducibility
WANDB = True   # when True a wandb run is started and batch metrics are logged to it
SEED = 0       # value handed to the random, numpy and torch seeders

# Optimisation
LEARNING_RATE = 0.001  # lr passed to the AdamW optimiser
EPOCHS = 3_000         # training-loop iterations; each one rolls out and optimises a single trajectory
BATCH_SIZE = 100       # episodes averaged per wandb log entry (the optimiser still steps every episode)

In [3]:
# Open a Weights & Biases run (only when the WANDB switch is on) and record
# the run-level hyper-parameters as the run's config.
if WANDB:
    wandb.login()

    wandb_run_config = {
        "learning_rate": LEARNING_RATE,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "seed": SEED
    }
    wandb.init(project="graph-training-simulation", config=wandb_run_config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mr-marr747[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Sampler configuration, later handed to GlowFigure8Squad as `sampler_config`.
# It is assembled from two named sections so each can be read on its own.

# Policy-model section. "nred" is also read by the rollout loops as the
# number of agent ids stepped per timestep.
sampler_model_config = {
    "nred": 1,
    "nblue": 1,
    "aggregation_fn": "agent_node",
    "hidden_size": 10,
    "is_hybrid": False,
    "conv_type": "gcn",
    "layernorm": False,
    "graph_obs_token": {"embed_opt": False, "embed_dir": True},
}

# Environment section.
sampler_env_config = {
    "env_path": ".",
    "act_masked": True,
    "init_red": None,
    "init_blue": None,
    "init_health_red": 20,
    "init_health_blue": 20,
    "obs_embed": False,
    "obs_dir": False,
    "obs_team": True,
    "obs_sight": False,
    "log_on": False,
    "log_path": "logs/temp/",
    "fixed_start": -1,
    "penalty_stay": 0,
    "threshold_damage_2_blue": 2,
    "threshold_damage_2_red": 5,
}

config = {
    "custom_model": "graph_transformer_policy",
    "custom_model_config": sampler_model_config,
    "env_config": sampler_env_config,
}

In [5]:
# Seed the Python and NumPy global RNGs, then PyTorch. The torch call stays
# last so the cell still evaluates to the torch Generator it returns.
for seed_rng in (random.seed, np.random.seed):
    seed_rng(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fa8624df3b0>

In [6]:
# Build the GFlow figure-8 squad environment, passing `config` as its sampler
# configuration. Later cells train and query its `sampler_fcn` attribute.
gflowfigure8 = GlowFigure8Squad(sampler_config=config)

---------------
path_data ./GflowsForSimulation/sigma_graph/data/parsed/
/home/rmarr/Documents/GflowsForSimulation_env/GflowsForSimulation


# NOTE(review): this cell repeats, statement for statement, the seeding cell
# and the environment-construction cell directly above it. Running it re-seeds
# the three RNGs and rebinds `gflowfigure8` to a newly constructed instance,
# so that second instance is the one the training cell below optimises.
# Presumably a leftover from re-running cells — confirm and keep only one copy.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

gflowfigure8 = GlowFigure8Squad(sampler_config=config)

In [8]:
# Train the sampler with the trajectory-balance loss.
#
# Every iteration rolls out one fresh episode, computes its loss and takes one
# AdamW step. BATCH_SIZE does not batch gradient updates; it only sets how many
# episodes are averaged into each wandb log entry.
optimizer = optim.AdamW(gflowfigure8.sampler_fcn.parameters(), lr=LEARNING_RATE)

STEPS_PER_EPISODE = 20  # timesteps rolled out before the loss is computed
n_red = config['custom_model_config']['nred']  # agent ids stepped each timestep

# Running sums over the current logging window.
batch_loss = 0.0
batch_num = 0
batch_reward = 0

for i in range(EPOCHS):
    # --- roll out one episode -------------------------------------------
    trajectory = Trajectory()
    gflowfigure8._reset_agents()
    for _ in range(STEPS_PER_EPISODE):
        for a_id in range(n_red):
            step = gflowfigure8.step(a_id)
            trajectory.add_step(
                forward_prob=step['forward_prob'],
                backward_prob=step['backward_prob'],
                reward=step['step_reward'],
            )

    episode_loss = losses.Losses.trajectory_balance(trajectory)
    episode_reward = trajectory.rewards
    # The type of `trajectory.rewards` is not visible from this notebook; if it
    # is a tensor, detach it so the running sum never holds an autograd graph.
    if torch.is_tensor(episode_reward):
        episode_reward = episode_reward.detach()

    # --- bookkeeping ----------------------------------------------------
    batch_num += 1
    # .item() stores a plain float. Summing the loss tensor itself would keep
    # every episode's autograd graph alive until the accumulator is reset.
    batch_loss += episode_loss.item()
    batch_reward += episode_reward

    if batch_num % BATCH_SIZE == 0:
        if WANDB:
            # Gather everything for this window into one dict so that a single
            # wandb.log call (one wandb step) carries all of the metrics.
            metrics = {
                "loss": batch_loss / BATCH_SIZE,
                "reward": batch_reward / BATCH_SIZE,
            }
            for name, param in gflowfigure8.sampler_fcn.named_parameters():
                metrics[f"{name}_mean"] = param.data.mean().item()
                metrics[f"{name}_std"] = param.data.std().item()
            wandb.log(metrics)
        # Reset the window whether or not logging is enabled; otherwise the
        # sums would grow for the whole run when WANDB is False.
        batch_loss = 0.0
        batch_reward = 0

    # --- optimise on this episode ----------------------------------------
    episode_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [12]:
# Probe the trained sampler: feed it a one-hot encoding of every node and
# collapse its 15 outputs into one probability per move direction.
gflowfigure8._reset_agents()

NUM_NODES = 27  # length of the one-hot node vector fed to the sampler
# Output index k, k + 5 and k + 10 all belong to MOVE_NAMES[k]; the three groups
# presumably correspond to the turn options (TURN_90_LOOKUP) — TODO confirm.
MOVE_NAMES = ("NOOP", "N", "S", "W", "E")

# Loop-invariant, so resolve the device once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

state_dirs = {}
# Pure inference: no_grad avoids building an autograd graph for each probe.
with torch.no_grad():
    for node in range(NUM_NODES):
        one_hot = np.zeros((1, NUM_NODES), dtype=np.int8)
        one_hot[0, node] = 1
        state = torch.tensor(one_hot).to(device)
        probs = gflowfigure8.sampler_fcn.forward(state)

        total_probs = {}
        for offset, move in enumerate(MOVE_NAMES):
            total_probs[move] = (probs[offset] + probs[offset + 5] + probs[offset + 10]).tolist()

        state_dirs[node] = total_probs
print(state_dirs)

{0: {'NOOP': 0.17820949651079954, 'N': 0.19639840144099877, 'S': 0.20381606319986284, 'W': 0.2254008692661287, 'E': 0.1961751695822102}, 1: {'NOOP': 0.18217598909235266, 'N': 0.19812147970705368, 'S': 0.19669236652488165, 'W': 0.21663757582996856, 'E': 0.2063725888457435}, 2: {'NOOP': 0.18004291409355228, 'N': 0.19487807413016764, 'S': 0.221700904809441, 'W': 0.19689705364023707, 'E': 0.20648105332660185}, 3: {'NOOP': 0.18169170924137149, 'N': 0.18619650300562068, 'S': 0.22019112569198818, 'W': 0.2079897318416776, 'E': 0.20393093021934208}, 4: {'NOOP': 0.18428485523093124, 'N': 0.1944376562795975, 'S': 0.20081464503836322, 'W': 0.22708909770667018, 'E': 0.19337374574443794}, 5: {'NOOP': 0.19413872073804495, 'N': 0.18301615324898782, 'S': 0.20584071761857042, 'W': 0.23640976764821833, 'E': 0.18059464074617837}, 6: {'NOOP': 0.1858327918118302, 'N': 0.16176891499626472, 'S': 0.22773962216945004, 'W': 0.25647981870644243, 'E': 0.16817885231601246}, 7: {'NOOP': 0.18215174715508717, 'N': 0.1

In [None]:
# Roll out a single 20-timestep episode with the current sampler, recording
# each agent step in a fresh Trajectory, and print the 'node' entry of the
# last step taken in every timestep.
trajectory = Trajectory()
gflowfigure8._reset_agents()

for timestep in range(20):
    for a_id in range(config['custom_model_config']['nred']):
        step = gflowfigure8.step(a_id)
        step_record = {
            'forward_prob': step['forward_prob'],
            'backward_prob': step['backward_prob'],
            'reward': step['step_reward'],
        }
        trajectory.add_step(**step_record)
    # Printed once per timestep, after all agents have moved.
    print(step['node'])