In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from tqdm import tqdm
from datetime import datetime
import matplotlib.pyplot as pp
import numpy as np
from ipywidgets import interact, IntSlider, fixed

from sigma_graph.envs.figure8.action_lookup import MOVE_LOOKUP, TURN_90_LOOKUP
from sigma_graph.envs.figure8.default_setup import OBS_TOKEN
from sigma_graph.envs.figure8.figure8_squad_rllib import Figure8SquadRLLib
from sigma_graph.envs.figure8.gflow_figure8_squad import GlowFigure8Squad
#from graph_scout.envs.base import ScoutMissionStdRLLib
import sigma_graph.envs.figure8.default_setup as default_setup
from sigma_graph.data.file_manager import check_dir, find_file_in_dir, load_graph_files
import model  # THIS NEEDS TO BE HERE IN ORDER TO RUN __init__.py!
import model.utils as utils
import model.gnn_gflow 
from trajectory import Trajectory
import losses

import torch.optim as optim
import wandb
import json
import random
import networkx as nx
import matplotlib.pyplot as plt

# --- Experiment hyper-parameters -------------------------------------------
# Number of episodes the training loop is asked to run.
NUM_EPOCHS = 100_000
# Episodes accumulated before each optimizer step (default = 34).
BATCH_SIZE = 100
# Step size handed to the optimizer.
LEARNING_RATE = 3e-4
# When True, the run is registered with and logged to Weights & Biases.
WANDB = True

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Index -> label table for the local move actions; its length is used as the
# radix when a flat action index is split into a multi-discrete pair.
local_action_move = dict(enumerate(("NOOP", "N", "S", "W", "E")))

def state_to_vec(state, num_nodes=27):
    """Return a one-hot float vector encoding a 1-indexed node id.

    Args:
        state: Node id in ``1..num_nodes``; node ``k`` maps to position
            ``k - 1`` of the result.
        num_nodes: Length of the encoding. Defaults to 27, the width that
            was previously hard-coded.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``num_nodes`` holding a
        single 1.0 at index ``state - 1`` and 0.0 elsewhere.

    Raises:
        IndexError: If ``state`` lies outside ``1..num_nodes``. Without this
            check, ``state <= 0`` would wrap around through negative indexing
            and silently encode a different node.
    """
    if not 1 <= state <= num_nodes:
        raise IndexError(
            f"state must be in 1..{num_nodes}, got {state}"
        )
    one_hot = torch.zeros(num_nodes, dtype=torch.float32)
    one_hot[state - 1] = 1.0
    return one_hot

def compute_reward(state, goal_node=17):
    """Return the sparse terminal reward for a node id.

    Args:
        state: Node id to score.
        goal_node: Node id that earns the reward. Defaults to 17, the value
            that was previously hard-coded.

    Returns:
        ``1`` if ``state`` equals ``goal_node``, otherwise ``0``.
    """
    return 1 if state == goal_node else 0

def convert_discrete_action_to_multidiscrete(action):
    """Split a flat action index into a two-element multi-discrete action.

    The flat index is decomposed using ``len(local_action_move)`` as the
    radix.

    Args:
        action: Flat (discrete) action index.

    Returns:
        A two-element list ``[action % n, action // n]`` where ``n`` is the
        number of entries in ``local_action_move``. The first element is the
        position within the move table; the second is the quotient
        (presumably the index along the other action dimension -- TODO
        confirm against the environment's action space).
    """
    radix = len(local_action_move)
    move_component = action % radix
    quotient_component = action // radix
    return [move_component, quotient_component]



In [2]:
# Open TODOs for this notebook:
# - Investigate loss/reward mirror
# - Try real reward
# - Make code cleaner
# - Visualize flows

# Configuration handed unchanged to GlowFigure8Squad (as `sampler_config`) and
# recorded in the W&B run config. The meaning of most keys is defined by the
# project code, not in this notebook; trailing comments list the alternative
# values noted by the author.
config = {
    "custom_model_config": {
        "custom_model": "fcn", # alternatives: fcn, attn_fcn
        "reward": "complex", # alternatives: random_region, random, single, complex
        "reward_interval": "step", # alternative: trajectory
        "trajectory_per_reward": 1,
        "embedding": "number", # alternatives: number, coordinate
        "is_dynamic_embedding": False,
        # Number of environment steps rolled out per episode by the training loop.
        "trajectory_length": 34,
        "nred": 1,
        "nblue": 1,
        "start_node": 22,
        "aggregation_fn": "agent_node",
        "hidden_size": 15,
        "is_hybrid": False,
        "conv_type": "gcn",
        "layernorm": False,
        "graph_obs_token": {"embed_opt": False, "embed_dir": True},
    },
    "env_config": {
        "env_path": ".",
        "act_masked": True,
        "init_red": None,
        "init_blue": None,
        "init_health_red": 20,
        "init_health_blue": 20,
        "obs_embed": False,
        "obs_dir": False,
        "obs_team": True,
        "obs_sight": False,
        "log_on": True,
        "log_path": "logs/temp/",
        "fixed_start": -1,
        "penalty_stay": 0,
        "threshold_damage_2_blue": 2,
        "threshold_damage_2_red": 5,
    },
}

# Run name: "<custom_model>-<reward>-<embedding>-<launch timestamp>".
current_time = datetime.now()
run_name = f"{config['custom_model_config']['custom_model']}-{config['custom_model_config']['reward']}-{config['custom_model_config']['embedding']}-{current_time.strftime('%Y-%m-%d %H:%M:%S')}"
print(run_name)

# Register the run with Weights & Biases only when logging is enabled.
if WANDB:
    wandb.init(
        project="graph-training-simulation",
        config={
                "model_config": config,
                "exp_config": {
                    "learning_rate": LEARNING_RATE,
                    # NOTE(review): the key is spelled "epocs"; renaming it would
                    # change the config schema of already-logged runs.
                    "epocs": NUM_EPOCHS,
                    "batch_size": BATCH_SIZE
            }
        },
        name=run_name
    )

# Build the environment/sampler wrapper and optimize only the parameters of its
# `sampler_fcn_coordinate_time` model.
# NOTE(review): config selects custom_model="fcn" with embedding="number" while
# the optimized sampler is named "fcn_coordinate_time" -- confirm this is the
# sampler that is meant to be trained.
gflowfigure8 = GlowFigure8Squad(sampler_config=config)
optimizer = optim.AdamW(gflowfigure8.sampler_fcn_coordinate_time.parameters(), lr=LEARNING_RATE)

fcn-complex-number-2024-05-06 01:45:51


In [3]:
# FCN coordinate time
#
# Trains gflowfigure8.sampler_fcn_coordinate_time. Every episode rolls out one
# fixed-length trajectory, and an optimizer step is taken on the loss summed
# over each group of BATCH_SIZE episodes.

# Loop invariants, hoisted out of the episode loop.
TEMP_AGENT_INDEX = 0
trajectory_length = config['custom_model_config']['trajectory_length']

# Running sums over the current minibatch; reset after every optimizer step.
minibatch_loss = 0
minibatch_reward = 0
minibatch_z = 0
minibatch_pf = 0
minibatch_pb = 0

# Iterating over range(NUM_EPOCHS) runs exactly NUM_EPOCHS episodes, so the
# progress bar ends at its total and is closed automatically. The previous
# `while episode <= NUM_EPOCHS` ran one extra episode whose loss was
# accumulated but never used for an update.
pbar = tqdm(range(NUM_EPOCHS))
for episode in pbar:
    # Per-trajectory accumulators.
    total_P_F = 0
    total_P_B = 0
    total_reward = 0

    gflowfigure8.reset()

    trajectory_path = []
    action_path = []
    gflowfigure8.reset_state()
    for t in range(trajectory_length):
        step = gflowfigure8.step_fcn_coordinate_time(TEMP_AGENT_INDEX)
        total_P_F += step['forward_prob']
        total_P_B += step['backward_prob']
        total_reward += step['step_reward']
        trajectory_path.append(step['red_node'])
        action_path.append(step['action'])

    logZ = gflowfigure8.sampler_fcn_coordinate_time.logZ

    # The reward is taken from the final node only (the summed step reward is
    # tracked but not used in the loss):
    # clipped_reward = torch.log(torch.tensor(total_reward)).clip(-20)
    last_node = gflowfigure8.team_red[0].get_info()["node"]
    # log(0) is -inf, so the log-reward is floored at -20. The tensor is built
    # as float32 explicitly instead of relying on torch.log promoting an int.
    clipped_reward = torch.log(
        torch.tensor(compute_reward(last_node), dtype=torch.float32)
    ).clamp(min=-20)

    # Squared residual logZ + sum(P_F) - log R - sum(P_B). This has the form of
    # the GFlowNet trajectory-balance objective, assuming forward_prob and
    # backward_prob are log-probabilities -- TODO confirm in the sampler.
    loss = (logZ + total_P_F - clipped_reward - total_P_B).pow(2)

    minibatch_loss += loss
    minibatch_reward += clipped_reward
    minibatch_z += logZ
    minibatch_pf += total_P_F
    minibatch_pb += total_P_B

    if (episode + 1) % BATCH_SIZE == 0:
        if WANDB:
            # Logged values are per-episode means over the minibatch.
            wandb.log({
                "loss": minibatch_loss/BATCH_SIZE,
                "reward": minibatch_reward/BATCH_SIZE,
                "pf": minibatch_pf/BATCH_SIZE,
                "pb": minibatch_pb/BATCH_SIZE,
                "z": minibatch_z/BATCH_SIZE
            })
            # for name, param in sampler.named_parameters():
            #     wandb.log({f"{name}_mean": param.data.mean().item(), f"{name}_std": param.data.std().item()})

        # The gradient is taken of the summed (not averaged) minibatch loss.
        # NOTE(review): retain_graph=True keeps the autograd graph alive after
        # backward; it is only needed if part of the graph is reused across
        # minibatches -- confirm and drop it if not, to save memory.
        minibatch_loss.backward(retain_graph=True)
        optimizer.step()
        optimizer.zero_grad()
        minibatch_loss = 0
        minibatch_reward = 0
        minibatch_z = 0
        minibatch_pf = 0
        minibatch_pb = 0

  0%|          | 0/100000 [00:00<?, ?it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)


  0%|          | 1/100000 [00:00<11:15:55,  2.47it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 3/100000 [00:00<4:37:38,  6.00it/s] 

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 7/100000 [00:00<2:56:52,  9.42it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 9/100000 [00:01<2:41:30, 10.32it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 11/100000 [00:01<2:32:29, 10.93it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 13/100000 [00:01<2:27:42, 11.28it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 15/100000 [00:01<2:25:45, 11.43it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 19/100000 [00:01<2:22:32, 11.69it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,    0.0000,    0.0000,
         

  0%|          | 21/100000 [00:02<2:20:51, 11.83it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 23/100000 [00:02<2:20:28, 11.86it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 25/100000 [00:02<2:21:28, 11.78it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 29/100000 [00:02<2:17:51, 12.09it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 31/100000 [00:02<2:17:38, 12.10it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 33/100000 [00:03<2:17:43, 12.10it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 35/100000 [00:03<2:18:31, 12.03it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 39/100000 [00:03<2:18:54, 11.99it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 41/100000 [00:03<2:17:35, 12.11it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,    0.0000,    0.0000,
         

  0%|          | 43/100000 [00:03<2:18:11, 12.06it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 45/100000 [00:04<2:19:59, 11.90it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 47/100000 [00:04<2:19:57, 11.90it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 51/100000 [00:04<2:19:52, 11.91it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 53/100000 [00:04<2:21:04, 11.81it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 55/100000 [00:04<2:20:27, 11.86it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 57/100000 [00:05<2:20:51, 11.83it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 59/100000 [00:05<2:21:47, 11.75it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 61/100000 [00:05<2:22:16, 11.71it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 65/100000 [00:05<2:22:27, 11.69it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,    0.0000,    0.0000,
         

  0%|          | 67/100000 [00:05<2:21:18, 11.79it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

  0%|          | 69/100000 [00:06<2:21:04, 11.81it/s]

self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         314.7000, 1988.1000,  322.1000, 1990.3000,  314.7000, 1988.1000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
         314.8000, 1982.0000,  314.8000, 1982.0000,  314.8000, 1982.0000,
           0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000], device='cuda:0', dtype=torch.float64)
self.gflow_state tensor([ 322.1000, 1990.3000,  314.7000, 1988.1000,  314.7000, 1988.1000,
         

KeyboardInterrupt: 

  0%|          | 70/100000 [00:19<2:21:04, 11.81it/s]

In [None]:
# Persist the `gflowfigure8` object to models/<run_name>.pt.
# torch.save opens the target path directly and does not create missing
# parent directories, so create them first; otherwise a fresh checkout
# without `models/` would fail here and the trained object would not be saved.
import os  # local import: only needed for this checkpoint step

checkpoint_path = f'models/{run_name}.pt'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
# NOTE(review): this pickles the entire object rather than a state_dict, so
# loading the file later needs the same class definitions to be importable —
# confirm this is intended before relying on these checkpoints long-term.
torch.save(gflowfigure8, checkpoint_path)