In [142]:
# Testing GridEnvLocal - Local Exploration from Given State

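This notebook demonstrates the new `GridEnvLocal` environment, which performs local exploration starting from a given initial state. The actions available in each dimension depend on the sign of that dimension's initial value and on the first action taken in it.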


In [143]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import numpy as np
import torch
import matplotlib.pyplot as plt
from types import SimpleNamespace

# Import the new environment
from disc_gflownet.envs.grid_env_local import GridEnvLocal
from reward_func.evo_devo import somitogenesis_reward_func
from graph.graph import draw_network_motif


In [144]:
# Build a small test environment whose initial state has a single non-zero
# (negative) weight; every other entry is zero.
# A 3-node network keeps the state small: 3^2 weights + 3 diagonal terms = 12 dims.
N_NODES = 3
N_DIMS = N_NODES ** 2 + N_NODES

test_initial_state = [
    # Weight matrix (3x3 = 9 values): w11, w12, w13, w21, w22, w23, w31, w32, w33
    -901,  0,   0,    # First row: negative, zero, zero
       0,  0,   0,    # Second row: all zero
       0,  0,   0,    # Third row: all zero
    # Diagonal values (3 values): d1, d2, d3
       0,  0,   0,    # Diagonal: all zero
]
# -901 is within 100 of the lower weight bound (-1000) configured further down,
# so one more -100 step from it would land outside the grid.

# Fail fast if the hand-written state does not match the declared dimensionality.
assert len(test_initial_state) == N_DIMS, (
    f"initial state has {len(test_initial_state)} entries, expected {N_DIMS}"
)

print(f"Test initial state: {test_initial_state}")
print(f"State length: {len(test_initial_state)} (should be {N_DIMS} for {N_NODES}-node network)")

# Create config for test environment
config = SimpleNamespace(
    n_workers=1,
    cache_max_size=1000,
    min_reward=0.001,
    custom_reward_fn=lambda x: np.sum(np.abs(x)) * 0.01,  # Simple reward: 1% of the L1 norm of the state
    n_steps=20,
    n_dims=N_DIMS,  # n_nodes^2 + n_nodes = 12 for the 3-node network
    initial_state=test_initial_state,
    # Two action values per dimension; which pair applies is keyed by the sign
    # of that dimension's initial value.
    actions_per_dim={
        'weight': {
            'positive': [100, -10],    # For positive initial values
            'negative': [-100, 10]     # For negative initial values
        },
        'diagonal': {
            'positive': [50, -5],      # For positive initial values
            'negative': [-50, 5]       # For negative initial values
        }
    },
    grid_bound={
        'weight': {'min': -1000, 'max': 1000},
        'diagonal': {'min': -1000, 'max': 1000}
    },
    enable_time=False
)

# Create environment
env = GridEnvLocal(config)
print("\nEnvironment created successfully!")
print(f"Number of nodes: {env.n_nodes}")
print(f"Action space size: {env.action_dim}")
print(f"Encoding dimension: {env.encoding_dim}")

# Print action configuration
env.print_actions()


Test initial state: [-901, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
State length: 12 (should be 12 for 3-node network)

Environment created successfully!
Number of nodes: 3
Action space size: 24
Encoding dimension: 24012
--------------------------------------------------
GridEnvLocal: Actions depend on initial state and first action taken
Weight actions:
  For positive initial values: [100, -10]
  For negative initial values: [-100, 10]
Diagonal actions:
  For positive initial values: [50, -5]
  For negative initial values: [-50, 5]
Total action dimension: 24
Action indexing: 2 slots per dimension
--------------------------------------------------


In [145]:
# Start a fresh episode and inspect what the environment exposes before any action is taken.
env.reset()
print("Testing", "=" * 50, sep="\n")

print(f"Initial state: {env._state}")
print(f"Action directions (all should be None): {env.action_directions}")

# Per-dimension table of the action values each dimension offers at the start.
print("\nAction availability by dimension:")
print("Dim | Type     | Initial Val | Available Actions")
print("-" * 45)

# The first n_nodes^2 state entries are weights; the remaining ones are diagonal terms.
n_weight_params = env.n_nodes * env.n_nodes
for dim in range(env.n_dims):
    val = test_initial_state[dim]
    if dim < n_weight_params:
        dim_type = "weight"
    else:
        dim_type = "diagonal"
    available_actions = env._get_available_actions(dim)
    print(f"{dim:3d} | {dim_type:8s} | {val:11d} | {available_actions}")

# Forward mask over the flat action space for the initial state.
initial_mask = env.get_forward_mask(env._state)
print("\nInitial action mask:")
print(f"Total actions available: {np.sum(initial_mask)} out of {env.action_dim}")

# Indices of the actions the mask allows.
available_indices = np.where(initial_mask)[0]
print(f"Available action indices: {available_indices}")


Testing
Initial state: [-901, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Action directions (all should be None): [None, None, None, None, None, None, None, None, None, None, None, None]

Action availability by dimension:
Dim | Type     | Initial Val | Available Actions
---------------------------------------------
  0 | weight   |        -901 | [-100, 10]
  1 | weight   |           0 | []
  2 | weight   |           0 | []
  3 | weight   |           0 | []
  4 | weight   |           0 | []
  5 | weight   |           0 | []
  6 | weight   |           0 | []
  7 | weight   |           0 | []
  8 | weight   |           0 | []
  9 | diagonal |           0 | []
 10 | diagonal |           0 | []
 11 | diagonal |           0 | []

Initial action mask:
Total actions available: 1 out of 24
Available action indices: [1]


In [141]:
# Begin a fresh episode for the step-by-step walkthrough.
env.reset()
print("Testing Step-by-Step Exploration", "=" * 40, sep="\n")

def print_step_info():
    """Print a snapshot of the global ``env`` and return the valid action indices.

    Reports the step counter, the current state, the action indices allowed
    by the forward mask for that state, how many of the ``env.action_dim``
    actions that is, and the per-dimension action directions.

    Returns:
        The indices at which the forward mask of the current state is truthy.
    """
    valid_indices = np.where(env.get_forward_mask(env._state))[0]
    summary_lines = (
        f"\nTotal steps taken: {env._step}",
        f"Current state: {env._state}",
        f"All available action indices: {valid_indices}",
        f"Available actions: {len(valid_indices)} out of {env.action_dim}",
        f"Action directions: {env.action_directions}",
    )
    for line in summary_lines:
        print(line)
    return valid_indices


# Walk the environment forward a few steps, always taking the first action the
# forward mask currently allows, and print a snapshot after every step.
N_DEMO_STEPS = 4

available_indices = print_step_info()

for _ in range(N_DEMO_STEPS):
    # Nothing left to take: indexing [0] on an empty array would raise IndexError.
    if len(available_indices) == 0:
        print("\n---No valid actions available; stopping exploration")
        break

    action_idx = available_indices[0]
    # Flat action index -> (state dimension, slot within that dimension).
    dim, slot = divmod(action_idx, env.slots_per_dim)
    available_actions = env._get_available_actions(dim)
    # The slot may not map to a listed action value; show None in that case.
    action_val = available_actions[slot] if slot < len(available_actions) else None
    print(f"\n---Taking action idx {action_idx} (dim={dim}, slot={slot}, value={action_val})")
    obs, reward, done = env.step(action_idx)
    print(f"---Action taken, reward: {reward:.4f}, done: {done}")
    available_indices = print_step_info()

    # Do not keep stepping an episode the environment has reported as finished.
    if done:
        print("\n---Episode finished; stopping exploration")
        break


Testing Step-by-Step Exploration

Total steps taken: 0
Current state: [-901, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
All available action indices: [1]
Available actions: 1 out of 24
Action directions: [None, None, None, None, None, None, None, None, None, None, None, None]

---Taking action idx 1 (dim=0, slot=1, value=10)
---Action taken, reward: 8.9110, done: False

Total steps taken: 1
Current state: [-891, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
All available action indices: [1]
Available actions: 1 out of 24
Action directions: ['slot_1', None, None, None, None, None, None, None, None, None, None, None]

---Taking action idx 1 (dim=0, slot=1, value=10)
---Action taken, reward: 8.8110, done: False

Total steps taken: 2
Current state: [-881, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
All available action indices: [1]
Available actions: 1 out of 24
Action directions: ['slot_1', None, None, None, None, None, None, None, None, None, None, None]

---Taking action idx 1 (dim=0, slot=1, value=10)
---Action taken, 