In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from tqdm import tqdm
from datetime import datetime
import matplotlib.pyplot as pp
import numpy as np
from ipywidgets import interact, IntSlider, fixed

from sigma_graph.data.file_manager import load_graph_files, save_log_2_file, log_done_reward
from sigma_graph.envs.figure8.action_lookup import MOVE_LOOKUP, TURN_90_LOOKUP
from sigma_graph.envs.figure8.default_setup import OBS_TOKEN
from sigma_graph.envs.figure8.figure8_squad_rllib import Figure8SquadRLLib
from sigma_graph.envs.figure8.gflow_figure8_squad import GlowFigure8Squad
#from graph_scout.envs.base import ScoutMissionStdRLLib
import sigma_graph.envs.figure8.default_setup as default_setup
from sigma_graph.data.file_manager import check_dir, find_file_in_dir, load_graph_files
import model  # THIS NEEDS TO BE HERE IN ORDER TO RUN __init__.py!
import model.utils as utils
import model.gnn_gflow 
from trajectory import Trajectory
import losses

import torch.optim as optim
import wandb
import json
import random
import networkx as nx
import matplotlib.pyplot as plt

NUM_EPOCHS = 100000
# default = 34
BATCH_SIZE = 100
LEARNING_RATE = 3e-4
WANDB = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

local_action_move = {
    0: "NOOP",
    1: "N",
    2: "S",
    3: "W",
    4: "E",
}

def state_to_vec(state):
    result = [0]*27
    result[state-1] = 1
    return torch.tensor(result).float()

def compute_reward(state):
    if state == 17:
        return 1
    return 0

def convert_discrete_action_to_multidiscrete(action):
        return [action % len(local_action_move), action // len(local_action_move)]



In [2]:
# Investigate loss rewar mirror
# Try real reward
# Make code cleaner 
# visualize flows 

config = {
    "custom_model_config": {
        "custom_model": "fcn", #fcn #attn_fcn
        "reward": "complex", #random_region random single complex
        "reward_interval": "step", #trajectory 
        "trajectory_per_reward": 1,
        "embedding": "number", #number #coordinate
        "is_dynamic_embedding": False,
        "trajectory_length": 34,
        "nred": 1,
        "nblue": 1,
        "start_node": 22,
        "aggregation_fn": "agent_node",
        "hidden_size": 15,
        "is_hybrid": False,
        "conv_type": "gcn",
        "layernorm": False,
        "graph_obs_token": {"embed_opt": False, "embed_dir": True},
    },
    "env_config": {
        "env_path": ".",
        "act_masked": True,
        "init_red": None,
        "init_blue": None,
        "init_health_red": 20,
        "init_health_blue": 20,
        "obs_embed": False,
        "obs_dir": False,
        "obs_team": True,
        "obs_sight": False,
        "log_on": True,
        "log_path": "logs/temp/",
        "fixed_start": -1,
        "penalty_stay": 0,
        "threshold_damage_2_blue": 2,
        "threshold_damage_2_red": 5,
    },
}

current_time = datetime.now()
run_name = f"{config['custom_model_config']['custom_model']}-{config['custom_model_config']['reward']}-{config['custom_model_config']['embedding']}-{current_time.strftime('%Y-%m-%d %H:%M:%S')}"
print(run_name)

if WANDB:
    wandb.init(
        project="graph-training-simulation",
        config={
                "model_config": config,
                "exp_config": {
                    "learning_rate": LEARNING_RATE,
                    "epocs": NUM_EPOCHS,
                    "batch_size": BATCH_SIZE
            }
        },
        name=run_name
    )

gflowfigure8 = GlowFigure8Squad(sampler_config=config)
optimizer = optim.AdamW(gflowfigure8.sampler_fig8_gat_coordinate_time.parameters(), lr=LEARNING_RATE)

fcn-complex-number-2024-05-16 15:53:08


In [3]:
# FCN coordinate time

minibatch_loss = 0
minibatch_reward = 0
minibatch_z = 0
minibatch_pf = 0
minibatch_pb = 0

pbar = tqdm(total=NUM_EPOCHS)
episode = 0

while episode <= NUM_EPOCHS:
  
  TEMP_AGENT_INDEX = 0
  
  total_P_F = 0
  total_P_B = 0
  total_reward = 0
    
  trajectory_length = config['custom_model_config']['trajectory_length']

  trajectory_path = []
  action_path = []
  gflowfigure8.reset()
  gflowfigure8.reset_state_gat_coordinate_time()
  for t in range(trajectory_length):
    step = gflowfigure8.step_gat_coordinate_time(TEMP_AGENT_INDEX)  
    total_P_F += step['forward_prob']
    total_P_B += step['backward_prob']
    total_reward += step['step_reward']
    trajectory_path.append(step['red_node'])
    action_path.append(step['action'])

  logZ = gflowfigure8.sampler_fcn_coordinate_time.logZ
    
  clipped_reward = torch.log(torch.tensor(total_reward)).clip(-20)
  #last_node = gflowfigure8.team_red[0].get_info()["node"]
  #clipped_reward = torch.log(torch.tensor(compute_reward(last_node))).clip(-20)

  loss = (logZ + total_P_F - clipped_reward - total_P_B).pow(2)
  
  minibatch_loss += loss
  minibatch_reward += clipped_reward
  minibatch_z += logZ
  minibatch_pf += total_P_F
  minibatch_pb += total_P_B

  if (episode + 1) % BATCH_SIZE == 0:
    if WANDB:
      wandb.log({
          "loss": minibatch_loss/BATCH_SIZE, 
          "reward":  minibatch_reward/BATCH_SIZE,
          "pf": minibatch_pf/BATCH_SIZE,
          "pb": minibatch_pb/BATCH_SIZE,
          "z": minibatch_z/BATCH_SIZE
        })
      # for name, param in sampler.named_parameters():
      #     wandb.log({f"{name}_mean": param.data.mean().item(), f"{name}_std": param.data.std().item()})
    
    minibatch_loss.backward(retain_graph=True)
    optimizer.step()
    optimizer.zero_grad()
    minibatch_loss = 0
    minibatch_reward = 0
    minibatch_z = 0 
    minibatch_pf = 0 
    minibatch_pb = 0 

  pbar.update(1)
  episode = episode + 1

  0%|          | 0/100000 [00:00<?, ?it/s]

0
prev_blue None
tensor([[ 322.1000, 1990.3000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  319.6000, 1979.8000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000

  0%|          | 1/100000 [00:00<16:32:41,  1.68it/s]

tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.0036,  0.0135,  0.0612,
         0.0276,  0.0089, -0.0237, -0.0478, -0.0012, -0.0545,  0.0357],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)
discrete_action 1
25
prev_blue 26
tensor([[ 322.1000, 1990.3000,  327.1000, 1988.8000,  322.1000, 1990.3000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  319.6000, 1979.8000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  319.6000, 1979.8000,
          319.7000, 1975.0000,  319.7000, 1975.0000,  319.7000, 1975.0000,
          319.6000, 1979.8000,  319.7000, 1975.0000,  324.8000, 1974.2000,
          319.7000, 1975.0000,  319.7000, 1975.0000,  319.7000, 1975.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.

  0%|          | 2/100000 [00:00<10:15:35,  2.71it/s]

tensor([[ 322.1000, 1990.3000,  327.1000, 1988.8000,  332.0000, 1988.0000,
          332.0000, 1988.0000,  332.0000, 1988.0000,  332.0000, 1988.0000,
          332.0000, 1988.0000,  332.0000, 1988.0000,  332.0000, 1988.0000,
          332.0000, 1988.0000,  336.8000, 1987.6000,  332.0000, 1988.0000,
          327.1000, 1988.8000,  332.0000, 1988.0000,  332.0000, 1988.0000,
          332.0000, 1988.0000,  332.0000, 1988.0000,  336.8000, 1987.6000,
          336.8000, 1987.6000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  327.1000, 1988.8000,  332.0000, 1988.0000,
          332.0000, 1988.0000,  332.0000, 1988.0000,  332

  0%|          | 3/100000 [00:01<8:22:01,  3.32it/s] 

tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.0036,  0.0135,  0.0612,
         0.0276,  0.0089, -0.0237, -0.0478, -0.0012, -0.0545,  0.0357],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)
discrete_action 5
19
prev_blue 1
tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  319.6000, 1979.8000,
          319.7000, 1975.0000,  319.6000, 1979.8000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.8000, 1982.0000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  322.1000, 1990.3000,  322.1000, 1990.3000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0

  0%|          | 4/100000 [00:01<7:21:38,  3.77it/s]

tensor([-1.4157e-19, -3.2372e-19, -2.8255e-19, -6.3127e-19, -4.1479e-19,
        -1.8999e-19, -1.8562e-19, -5.3951e-20, -2.7150e-19,  4.6713e-19,
         4.0762e-20,  2.7900e-19,  4.1470e-19, -1.4223e-19, -9.0716e-20,
        -2.5758e-19, -8.2655e-20, -8.2204e-20,  1.1782e-20, -3.7623e-20,
         5.0837e-19, -1.8213e-20, -5.4915e-20, -7.2708e-20,  1.9508e-19,
         3.1273e-19,  5.9487e-20, -2.2462e-19,  2.2858e-19,  1.4132e-19,
         1.3160e-19,  4.7856e-19,  8.4144e-20, -7.6149e-19, -3.7058e-19,
         5.8752e-19,  3.3265e-19,  4.9192e-20,  6.1380e-20,  2.8629e-19,
         8.0180e-20,  1.7442e-20, -1.5355e-19,  1.3546e-19,  7.9472e-20,
         2.6554e-20, -5.1816e-20,  1.5906e-19,  3.9573e-19, -2.5845e-19,
        -2.3029e-19,  3.7449e-20,  1.9447e-19,  1.6773e-19,  3.0522e-19,
        -1.9642e-19, -2.5961e-19,  4.6003e-19, -2.4226e-19, -3.6523e-20,
        -1.4309e-19, -2.2282e-19,  3.9258e-20, -5.1718e-19, -8.1405e-20,
         1.1402e-19, -3.2110e-19,  9.2076e-20, -1.8

  0%|          | 5/100000 [00:01<6:50:02,  4.06it/s]

tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.7000, 1988.1000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          319.6000, 1979.8000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  314.7000, 1988.1000,  314.7000, 1988.1000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  314

  0%|          | 6/100000 [00:01<6:39:13,  4.17it/s]

13
prev_blue 13
tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          327.1000, 1988.8000,  322.1000, 1990.3000,  319.6000, 1979.8000,
          319.6000, 1979.8000,  314.8000, 1982.0000,  319.6000, 1979.8000,
          322.1000, 1990.3000,  327.1000, 1988.8000,  332.0000, 1988.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          327.1000, 1988.8000,  322.1000,

  0%|          | 7/100000 [00:01<6:25:22,  4.32it/s]

tensor([[[0.],
         [1.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.],
         [1.]],

        [[0.],
         [1.],
         [0.],
         [0.]],

        [[0.],
         [1.],
         [0.],
         [0.]]], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.0036,  0.0135,  0.0612,
         0.0276,  0.0089, -0.0237, -0.0478, -0.0012, -0.0545,  0.0357],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)
discrete_action 7
10
prev_blue 6
tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  319.6000, 1979.8000,
          324.3000, 1979.2000,  

  0%|          | 8/100000 [00:02<6:20:06,  4.38it/s]

8
prev_blue 8
tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  322.1000, 1990.3000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  314.7000, 1988.1000,  322.1000, 1990.3000,
          314.7000, 1988.1000,  314.7000, 1

  0%|          | 9/100000 [00:02<6:12:34,  4.47it/s]

tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  319.6000, 1979.8000,    0

  0%|          | 10/100000 [00:02<6:06:27,  4.55it/s]

tensor([ -69.7951,   38.3313,  -46.5285,  117.7455,  -12.8158,  -27.4807,
          24.6544,   54.5816,    1.0731,   76.1421, -100.1815,   92.8772,
         -27.8703,   63.0707,  -76.6610,  -54.8234,   62.8433,  -61.9647,
          35.4404,   -7.1171,   42.2823,  -36.1327,   49.1367,  -19.3771,
          27.4013,   -0.6074,   -6.0619,   67.7835, -102.7899,  -79.8001,
         -45.8892,   11.4874,  -46.6974,   40.5250,  119.5813,  -52.4703,
          29.6111,  -21.7713,  -41.0565,  -29.2273,   15.3598,   17.5833,
           2.3987,   15.8504,   43.4923,  -57.9175,  -10.7271,  -64.6690,
          80.1943,   -0.9560,   95.4584,   36.8315,  -11.3881,  -63.0563,
          91.0873,  -32.6572,  -14.4932,  -55.6293,  113.3433,   45.6122,
          32.0620,   18.4802,   35.7780,   30.2611,  -43.5460,   55.4326,
         104.5865,   25.8983,   47.3885], device='cuda:0', dtype=torch.float64,
       grad_fn=<SelectBackward0>)
tensor([ 22.3652, -13.0786,   4.1726,  -6.5086,  -7.2236, -45.3061,   8.

  0%|          | 11/100000 [00:02<6:05:58,  4.55it/s]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.0036,  0.0135,  0.0612,
         0.0276,  0.0089, -0.0237, -0.0478, -0.0012, -0.0545,  0.0357],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)
discrete_action 9
33
prev_blue 16
tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  319.6000, 1979.8000,
          324.3000, 1979.2000,  324.8000, 1974.2000,  319.7000, 1975.0000,
          324.8000, 1974.2000,  324.8000, 1974.2000,  324.8000, 1974.2000,
          324.3000, 1979.2000,  324.8000, 1974.2000,  324.8000, 1974.2000,
          319.7000, 1975.0000,  324.8000, 1974.2000,  319.7000, 1975.00

  0%|          | 12/100000 [00:02<6:07:05,  4.54it/s]

tensor([[ 322.1000, 1990.3000,  327.1000, 1988.8000,  322.1000, 1990.3000,
          327.1000, 1988.8000,  327.1000, 1988.8000,  324.3000, 1979.2000,
          324.3000, 1979.2000,  324.3000, 1979.2000,  324.3000, 1979.2000,
          327.1000, 1988.8000,  327.1000, 1988.8000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  327.1000, 1988.8000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  314.8000, 1982.0000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          322.1000, 1990.3000,  319.6000, 1979.8000,  324.3000, 1979.2000,
          319.6000, 1979.8000,  322.1000, 1990.3000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  327.1000, 1988.8000,  322.1000, 1990.3000,
          327.1000, 1988.8000,  327.1000, 1988.8000,  324

  0%|          | 13/100000 [00:03<6:04:58,  4.57it/s]

28
prev_blue 8
tensor([[3.2210e+02, 1.9903e+03, 3.2210e+02, 1.9903e+03, 3.1960e+02, 1.9798e+03,
         3.1480e+02, 1.9820e+03, 3.1470e+02, 1.9881e+03, 3.1480e+02, 1.9820e+03,
         3.1470e+02, 1.9881e+03, 3.1480e+02, 1.9820e+03, 3.1470e+02, 1.9881e+03,
         3.2210e+02, 1.9903e+03, 3.1960e+02, 1.9798e+03, 3.1960e+02, 1.9798e+03,
         3.1960e+02, 1.9798e+03, 3.1960e+02, 1.9798e+03, 3.1970e+02, 1.9750e+03,
         3.1970e+02, 1.9750e+03, 3.1970e+02, 1.9750e+03, 3.2480e+02, 1.9742e+03,
         3.1970e+02, 1.9750e+03, 3.1970e+02, 1.9750e+03, 3.1970e+02, 1.9750e+03,
         3.1970e+02, 1.9750e+03, 3.1970e+02, 1.9750e+03, 3.1970e+02, 1.9750e+03,
         3.2480e+02, 1.9742e+03, 3.2450e+02, 1.9694e+03, 3.2450e+02, 1.9694e+03,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.2210e+02, 1.9903e+03, 3.2210e+02, 1.9

  0%|          | 14/100000 [00:03<6:03:35,  4.58it/s]

tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  322.1000, 1990.3000,
          319.6000, 1979.8000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  314.7000, 1988.1000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          327.1000, 1988.8000,  324.3000, 1979.2000,  324.3000, 1979.2000,
          324.3000, 1979.2000,  319.6000, 1979.8000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  322.1000, 1990.3000,  314.7000, 1988.1000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  319.6000, 1979.8000,
          319.6000, 1979.8000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  314.7000, 1988.1000,  322.1000, 1990.3000,
          319.6000, 1979.8000,  314.8000, 1982.0000,  314

  0%|          | 15/100000 [00:03<5:59:23,  4.64it/s]

tensor([[[0.],
         [1.],
         [0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.],
         [1.],
         [1.]],

        [[0.],
         [1.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [1.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [1.],
         [0.],
         [0.],
         [0.]]], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.0036,  0.0135,  0.0612,
         0.0276,  0.0089, -0.0237, -0.0478, -0.0012, -0.0545,  0.0357],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)
dis

  0%|          | 16/100000 [00:03<5:56:42,  4.67it/s]

tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  319.6000, 1979.8000,
          319.7000, 1975.0000,  324.8000, 1974.2000,  324.3000, 1979.2000,
          327.1000, 1988.8000,  324.3000, 1979.2000,  324.8000, 1974.2000,
          324.8000, 1974.2000,  324.5000, 1969.4000,  324.8000, 1974.2000,
          324.5000, 1969.4000,  330.4000, 1968.5000,  324.5000, 1969.4000,
          330.4000, 1968.5000,  330.4000, 1968.5000,  330.4000, 1968.5000,
          330.4000, 1968.5000,  330.4000, 1968.5000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  319

  0%|          | 17/100000 [00:04<6:01:15,  4.61it/s]

tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  319.6000, 1979.8000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.8000, 1982.0000,  319.6000, 1979.8000,  322.1000, 1990.3000,
          327.1000, 1988.8000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  314

  0%|          | 18/100000 [00:04<6:08:02,  4.53it/s]

tensor([[[0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [0.]],

        [[1.],
         [0.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [0.]]], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.00

  0%|          | 19/100000 [00:04<6:09:03,  4.52it/s]

tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  319.6000, 1979.8000,
          319.7000, 1975.0000,  319.6000, 1979.8000,  319.7000, 1975.0000,
          319.6000, 1979.8000,  319.6000, 1979.8000,  324.3000, 1979.2000,
          324.8000, 1974.2000,  324.8000, 1974.2000,  319.7000, 1975.0000,
          319.7000, 1975.0000,  319.6000, 1979.8000,  319.6000, 1979.8000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  322.1000, 1990.3000,  319.6000, 1979.8000,
          319.7000, 1975.0000,  319.6000, 1979.8000,  319

  0%|          | 20/100000 [00:04<6:05:01,  4.56it/s]

tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  327.1000, 1988.8000,  322.1000, 1990.3000,
          319.6000, 1979.8000,  324.3000, 1979.2000,  324.8000, 1974.2000,
          324.3000, 1979.2000,  324.8000, 1974.2000,  324.5000, 1969.4000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  327.1000, 1988.8000,  322

  0%|          | 21/100000 [00:04<6:04:02,  4.58it/s]

tensor([  130.2756, -2399.6911,   298.1235,  1891.9796,  1295.3317, -2368.2116,
         -466.6295,  2829.8049, -1908.7599,  2718.0970,  -380.5298,  5770.1873,
        -4872.4768,  -439.6975,   737.8804, -1938.7696,  -466.7934, -5901.5140,
         -556.4475,  -977.9677, -5257.0957, -2178.1428,  -658.1715, -4265.5254,
          642.3653,  -171.3864,  2230.2650, -2635.9788,  1610.3492, -2077.5482,
         5401.9085,  1967.5138,    34.0498,  -239.9268,  2598.9704,   375.3582,
        -1652.3358, -6140.4219,   294.6462, -2019.1585, -4116.3406,   123.3791,
         -451.3917, -3585.1927, -2761.2807,  -591.0236, -1068.1998,  6691.8645,
         3850.4187,  2365.4164,  1189.6736, -5500.1564, -2222.5027, -1750.5593,
         2770.7829, -3121.6361,   620.9032,   464.6018,  2612.7188,   677.2404,
         2206.0811,  4576.6042, -2384.7463,  1727.9622, -4691.2154,  -168.4110,
         -507.0342, -2832.9500,  4214.4624], device='cuda:0',
       dtype=torch.float64, grad_fn=<SelectBackward0>)
ten

  0%|          | 22/100000 [00:05<6:01:18,  4.61it/s]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.0036,  0.0135,  0.0612,
         0.0276,  0.0089, -0.0237, -0.0478, -0.0012, -0.0545,  0.0357],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)
discrete_action 12
10
prev_blue 6
tensor([[ 322.1000, 1990.3000,  319.6000, 1979.8000,  324.3000, 1979.2000,
          324.8000, 1974.2000,  324.8000, 1974.2000,  324.8000, 1974.2000,
          324.8000, 1974.2000,  324.8000, 1974.2000,  324.5000, 1969.4000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.00

  0%|          | 23/100000 [00:05<6:01:22,  4.61it/s]

tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.0036,  0.0135,  0.0612,
         0.0276,  0.0089, -0.0237, -0.0478, -0.0012, -0.0545,  0.0357],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)
discrete_action 8
8
prev_blue 8
tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  319.6000, 1979.8000,  319.6000, 1979.8000,
          319.7000, 1975.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.00

  0%|          | 24/100000 [00:05<6:00:29,  4.62it/s]

tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          322.1000, 1990.3000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          322.1000, 1990.3000,  322.1000, 1990.3000,  314

  0%|          | 25/100000 [00:05<6:03:29,  4.58it/s]

0
prev_blue None
tensor([[ 322.1000, 1990.3000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  319.6000, 1979.8000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000

  0%|          | 26/100000 [00:06<6:05:21,  4.56it/s]

tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  314.7000, 1988.1000,
          322.1000, 1990.3000,  319.6000, 1979.8000,  324.3000, 1979.2000,
          327.1000, 1988.8000,  322.1000, 1990.3000,  327.1000, 1988.8000,
          332.0000, 1988.0000,  332.0000, 1988.0000,  327.1000, 1988.8000,
          332.0000, 1988.0000,  327.1000, 1988.8000,  324.3000, 1979.2000,
          324.8000, 1974.2000,  319.7000, 1975.0000,  319.7000, 1975.0000,
          324.8000, 1974.2000,  324.5000, 1969.4000,  324.8000, 1974.2000,
          324.5000, 1969.4000,  324.8000, 1974.2000,  324.5000, 1969.4000,
          324.5000, 1969.4000,  330.4000, 1968.5000,  330.4000, 1968.5000,
          330.4000, 1968.5000,  337.2000, 1968.2000,  342.5000, 1967.5000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  322.1000, 1990.3000,  314.7000, 1988.1000,
          322.1000, 1990.3000,  319.6000, 1979.8000,  324

  0%|          | 27/100000 [00:06<6:06:47,  4.54it/s]

tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  322.1000, 1990.3000,  319.6000, 1979.8000,
          322.1000, 1990.3000,  327.1000, 1988.8000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          327.1000, 1988.8000,  327.1000, 1988.8000,  332.0000, 1988.0000,
          332.0000, 1988.0000,  332.0000, 1988.0000,  336.8000, 1987.6000,
          342.7000, 1987.9000,  342.7000, 1987.9000,  347.9000, 1987.8000,
          347.9000, 1987.8000,  347.9000, 1987.8000,  347.9000, 1987.8000,
          347.9000, 1987.8000,  342.7000, 1987.9000,  336.8000, 1987.6000,
          342.7000, 1987.9000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  322.1000, 1990.3000,  319

  0%|          | 28/100000 [00:06<6:08:21,  4.52it/s]

tensor([[ 322.1000, 1990.3000,  327.1000, 1988.8000,  332.0000, 1988.0000,
          336.8000, 1987.6000,  332.0000, 1988.0000,  332.0000, 1988.0000,
          332.0000, 1988.0000,  336.8000, 1987.6000,  332.0000, 1988.0000,
          332.0000, 1988.0000,  327.1000, 1988.8000,  327.1000, 1988.8000,
          324.3000, 1979.2000,  319.6000, 1979.8000,  322.1000, 1990.3000,
          327.1000, 1988.8000,  324.3000, 1979.2000,  324.8000, 1974.2000,
          324.8000, 1974.2000,  324.8000, 1974.2000,  324.8000, 1974.2000,
          324.8000, 1974.2000,  324.5000, 1969.4000,  324.5000, 1969.4000,
          324.5000, 1969.4000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  327.1000, 1988.8000,  332.0000, 1988.0000,
          336.8000, 1987.6000,  332.0000, 1988.0000,  332

  0%|          | 29/100000 [00:06<6:12:00,  4.48it/s]

tensor([-5.5578e-70,  3.1260e-69,  2.2799e-69,  5.9779e-69, -2.3567e-69,
         2.1919e-69, -1.1742e-69,  2.7213e-69, -8.6379e-70,  1.2160e-71,
         1.2347e-69, -2.6078e-69,  4.5872e-70, -5.8182e-69,  6.4847e-69,
         9.3348e-70, -6.9244e-69,  5.2792e-70, -3.5759e-69, -2.9792e-69,
        -8.4975e-70,  2.7281e-70, -2.3156e-69, -3.5657e-69,  3.7637e-69,
         4.2609e-70, -2.6718e-69, -4.5277e-70,  5.3199e-69,  7.5222e-69,
         4.7583e-69,  4.2456e-70,  3.2515e-69, -1.8949e-69, -3.0376e-69,
         2.5513e-69,  9.2134e-69, -1.8222e-69,  4.5115e-69, -5.7915e-69,
        -6.8418e-70,  1.5454e-70, -6.8640e-69,  9.1964e-70, -1.0244e-70,
         4.9377e-69,  3.2299e-69,  2.0298e-69, -2.2194e-69, -7.0603e-70,
        -6.2680e-69, -1.9394e-69, -7.8931e-69, -4.2555e-69, -1.8686e-69,
        -5.7029e-69, -6.0850e-69,  4.9629e-69, -2.9378e-69, -1.0300e-69,
        -8.8739e-70,  1.3462e-69, -4.7161e-70, -3.9446e-69, -6.8339e-69,
        -1.4706e-69,  5.5409e-70,  4.3799e-71,  1.5

  0%|          | 30/100000 [00:06<6:07:49,  4.53it/s]

tensor([[ 322.1000, 1990.3000,  319.6000, 1979.8000,  319.7000, 1975.0000,
          319.6000, 1979.8000,  319.6000, 1979.8000,  319.6000, 1979.8000,
          324.3000, 1979.2000,  324.3000, 1979.2000,  327.1000, 1988.8000,
          322.1000, 1990.3000,  319.6000, 1979.8000,  319.6000, 1979.8000,
          319.6000, 1979.8000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  314.7000, 1988.1000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  319.6000, 1979.8000,  319.7000, 1975.0000,
          319.6000, 1979.8000,  319.6000, 1979.8000,  319

  0%|          | 31/100000 [00:07<6:07:54,  4.53it/s]

tensor([[ 322.1000, 1990.3000,  322.1000, 1990.3000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  322.1000, 1990.3000,
          322.1000, 1990.3000,  322.1000, 1990.3000,  327.1000, 1988.8000,
          332.0000, 1988.0000,  332.0000, 1988.0000,  336.8000, 1987.6000,
          342.7000, 1987.9000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  322.1000, 1990.3000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  314.7000, 1988.1000,  314

  0%|          | 32/100000 [00:07<6:04:55,  4.57it/s]

tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          314.7000, 1988.1000,  314.8000, 1982.0000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  322.1000, 1990.3000,  327.1000, 1988.8000,
          332.0000, 1988.0000,  327.1000, 1988.8000,  327.1000, 1988.8000,
          327.1000, 1988.8000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  322.1000, 1990.3000,  322

  0%|          | 33/100000 [00:07<6:05:47,  4.55it/s]

tensor([-0.0105,  0.0225,  0.0137,  0.0303, -0.0384,  0.0036,  0.0135,  0.0612,
         0.0276,  0.0089, -0.0237, -0.0478, -0.0012, -0.0545,  0.0357],
       device='cuda:0', dtype=torch.float64, grad_fn=<ViewBackward0>)
discrete_action 8
15
prev_blue 25
tensor([[ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
          314.7000, 1988.1000,  322.1000, 1990.3000,  322.1000, 1990.3000,
          319.6000, 1979.8000,  324.3000, 1979.2000,  324.3000, 1979.2000,
          324.3000, 1979.2000,  324.8000, 1974.2000,  319.7000, 1975.0000,
          319.6000, 1979.8000,  319.6000, 1979.8000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.

KeyboardInterrupt: 

In [None]:
torch.save(gflowfigure8, f'models/{run_name}.pt')

In [None]:
gflowfigure8.reset()
red_path = []
blue_path = []
step_rewards = []
total_reward = 0
for r in range(34):   
    step = gflowfigure8.step_fcn_coordinate_time(0)
    red_path.append(step['red_node'])
    blue_path.append(step['blue_node'])
    step_rewards.append(step['step_reward'])
    total_reward += step['step_reward']

print(f'step_rewards {total_reward}')
episode_reward = gflowfigure8._episode_rewards_aggressive()[0]
print(f'episode_reward {episode_reward}')
total_reward += episode_reward
print(f'total_reward {total_reward}')

step_rewards 4
episode_reward 0
total_reward 4


In [None]:
map_info, _ = load_graph_files(map_lookup="S")
col_map = ["gold"] * len(map_info.n_info)

def display_graph(index):
    fig, ax = plt.subplots(figsize=(8, 6))
    cur_col_map = col_map[:]
    cur_col_map[red_path[index]-1] = "red"
    cur_col_map[blue_path[index]-1] = "blue"
    nx.draw_networkx(map_info.g_acs, pos=map_info.n_info, node_color=cur_col_map, edge_color="blue", arrows=True, ax=ax)
    ax.set_title(f"step rewards {step_rewards[index]}")
    plt.axis('off')
    plt.show()

# Create an interactive widget to display different graphs
slider = IntSlider(min=0, max=33-1, step=1, value=0, description='Graph Index')
interact(display_graph, index=slider)

interactive(children=(IntSlider(value=0, description='Graph Index', max=32), Output()), _dom_classes=('widget-…

<function __main__.display_graph(index)>

In [None]:
gflowfigure8.reset()
red_path = []
blue_path = []
step_rewards = []
total_reward = 0
for r in range(34):   
    step = gflowfigure8.step_fcn_coordinate_time(0)
    red_path.append(step['red_node'])
    blue_path.append(step['blue_node'])
    step_rewards.append(step['step_reward'])
    total_reward += step['step_reward']

print(f'step_rewards {total_reward}')
episode_reward = gflowfigure8._episode_rewards_aggressive()[0]
print(f'episode_reward {episode_reward}')
total_reward += episode_reward
print(f'total_reward {total_reward}')

step_rewards 2
episode_reward 8
total_reward 10


In [None]:
map_info, _ = load_graph_files(map_lookup="S")
col_map = ["gold"] * len(map_info.n_info)

def display_graph(index):
    fig, ax = plt.subplots(figsize=(8, 6))
    cur_col_map = col_map[:]
    cur_col_map[red_path[index]-1] = "red"
    cur_col_map[blue_path[index]-1] = "blue"
    print(map_info.g_acs)
    print(map_info.n_info)
    nx.draw_networkx(map_info.g_acs, pos=map_info.n_info, node_color=cur_col_map, edge_color="blue", arrows=True, ax=ax)
    ax.set_title(f"step rewards {step_rewards[index+1]}")
    plt.axis('off')
    plt.show()

# Create an interactive widget to display different graphs
slider = IntSlider(min=0, max=33, step=1, value=0, description='Graph Index')
interact(display_graph, index=slider)

interactive(children=(IntSlider(value=0, description='Graph Index', max=33), Output()), _dom_classes=('widget-…

<function __main__.display_graph(index)>

100001it [1:45:20, 22.03it/s]                           