In [1]:
import torch
from torch.nn.functional import one_hot
from torch.optim import Adam
from tqdm import tqdm
from preconditioner import PreconditionerEnv
from policy import ForwardPolicy, BackwardPolicy
from gflownet.gflownet import GFlowNet
from gflownet.utils import sparse_one_hot
from gflownet.utils import trajectory_balance_loss, market_matrix_to_sparse_tensor
import psutil

In [2]:
def log_memory_usage(stage: str):
    process = psutil.Process()
    mem_info = process.memory_info()
    print(f"[{stage}] CPU Memory Usage: {mem_info.rss / (1024 ** 2):.2f} MB")
    if torch.cuda.is_available():
        print(f"[{stage}] GPU Memory Usage: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")


In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
matrix_path = '../LF10/LF10.mtx'  # Update this with your file path
batch_size = 3
num_epochs = 1000
lr = 0.0001

In [5]:
log_memory_usage("Before Loading Initial Matrix")

# Load the initial matrix from a file
initial_matrix = market_matrix_to_sparse_tensor(matrix_path)
matrix_size = initial_matrix.size(0)

log_memory_usage("After Loading Initial Matrix")


[Before Loading Initial Matrix] CPU Memory Usage: 204.35 MB
[After Loading Initial Matrix] CPU Memory Usage: 205.00 MB


In [6]:
# Initialize the environment and policies
env = PreconditionerEnv(matrix_size=matrix_size, initial_matrix=initial_matrix)
env.data.edge_attr.shape

torch.Size([82])

In [7]:

node_features = -1  # Assuming each node has a single feature, can be adjusted
hidden_dim = 8
forward_policy = ForwardPolicy(node_features=node_features, hidden_dim=hidden_dim, num_actions=env.num_actions)
#forward_policy = ForwardPolicy(in_channels=node_features, hidden_channels=hidden_dim, out_channels=env.num_actions)
backward_policy = BackwardPolicy(matrix_size=matrix_size, num_actions=env.num_actions)

In [8]:
env.data.edge_attr.shape

torch.Size([82])

In [9]:
def check_gradients(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            if param.grad is not None:
                print(f"{name}: {param.grad.norm()}")
            else:
                print(f"{name}: No gradient")


In [10]:
import pandas as pd
# Initialize the GFlowNet model
model = GFlowNet(forward_policy, backward_policy, env)
opt = Adam(model.parameters(), lr=lr)

log_memory_usage("After Model Initialization")

report_data = pd.DataFrame(columns=['epoch', 'num_actions', 'loss', 'reward'])

for epoch in (p := tqdm(range(num_epochs))):
   #log_memory_usage(f"Start of Epoch {epoch}")

    model.train()
    #opt.zero_grad()

    # Initialize the starting states
    initial_indices = torch.zeros(batch_size).long()
    #s0 = [sparse_one_hot(initial_indices[i:i+1], env.state_dim).float() for i in range(batch_size)]
    s0 = [initial_matrix.clone() for _ in range(batch_size)]
    #s0 = one_hot(torch.zeros(batch_size).long(), env.state_dim).float()
    # Sample final states and log information
    s, log = model.sample_states(s0, return_log=True)
    
    # Calculate the trajectory balance loss
    loss = trajectory_balance_loss(log.total_flow,
                                    log.rewards,
                                    log.fwd_probs,
                                    log.back_probs)
    
    #print(f"log.total_flow {log.total_flow}")
    #print(f"log.rewards {log.rewards}")
    #print(f"log.fwd_probs {log.fwd_probs}")
    #print(f"log.back_probs {log.back_probs}")
    #print(f"log._actions shape {len(log._actions)}")
    #print(f"Loss Calculation: {loss}")
    # Backpropagation and optimization step
    loss.backward()
    #check_gradients(model)
    opt.step()
    #named_params = model.named_parameters()
    opt.zero_grad()

    #Capture data
    total_length = len(log._actions)
    report_data = report_data.append({'epoch': epoch, 'num_actions': total_length, 'loss': loss.item(), 'reward': log.rewards}, ignore_index=True)
    
    if epoch % 100 == 0:
       tqdm.write(f"Epoch {epoch} Loss: {loss.item():.3f}, Num_Actions {total_length}")
        

[After Model Initialization] CPU Memory Usage: 241.39 MB


  0%|          | 1/5000 [00:02<3:44:27,  2.69s/it]

Epoch 0 Loss: 4983.640, Num_Actions 82


  2%|▏         | 101/5000 [02:17<1:38:12,  1.20s/it]

Epoch 100 Loss: 47.465, Num_Actions 12


  4%|▍         | 201/5000 [04:50<2:15:47,  1.70s/it]

Epoch 200 Loss: 3349.520, Num_Actions 77


  6%|▌         | 301/5000 [07:28<2:12:31,  1.69s/it]

Epoch 300 Loss: 2713.896, Num_Actions 76


  8%|▊         | 401/5000 [10:29<2:27:48,  1.93s/it]

Epoch 400 Loss: 1473.257, Num_Actions 75


 10%|█         | 501/5000 [13:38<2:07:58,  1.71s/it]

Epoch 500 Loss: 40.103, Num_Actions 35


 12%|█▏        | 601/5000 [17:07<3:10:11,  2.59s/it]

Epoch 600 Loss: 3398.164, Num_Actions 81


 14%|█▍        | 701/5000 [20:35<2:52:40,  2.41s/it]

Epoch 700 Loss: 2667.533, Num_Actions 76


 16%|█▌        | 801/5000 [23:52<2:21:39,  2.02s/it]

Epoch 800 Loss: 1546.122, Num_Actions 68


 18%|█▊        | 901/5000 [27:16<2:18:29,  2.03s/it]

Epoch 900 Loss: 476.574, Num_Actions 55


 20%|██        | 1001/5000 [30:31<1:51:43,  1.68s/it]

Epoch 1000 Loss: 1250.143, Num_Actions 63


 22%|██▏       | 1101/5000 [33:40<1:49:40,  1.69s/it]

Epoch 1100 Loss: 4766.391, Num_Actions 81


 24%|██▍       | 1201/5000 [36:40<1:46:03,  1.68s/it]

Epoch 1200 Loss: 108.456, Num_Actions 39


 26%|██▌       | 1301/5000 [39:36<1:59:42,  1.94s/it]

Epoch 1300 Loss: 411.840, Num_Actions 55


 28%|██▊       | 1401/5000 [42:39<2:01:39,  2.03s/it]

Epoch 1400 Loss: 2597.125, Num_Actions 77


 30%|███       | 1501/5000 [45:56<2:47:41,  2.88s/it]

Epoch 1500 Loss: 1293.198, Num_Actions 69


 32%|███▏      | 1601/5000 [48:56<1:30:50,  1.60s/it]

Epoch 1600 Loss: 1446.398, Num_Actions 74


 34%|███▍      | 1701/5000 [51:49<1:49:29,  1.99s/it]

Epoch 1700 Loss: 1052.836, Num_Actions 72


 36%|███▌      | 1801/5000 [54:36<1:07:40,  1.27s/it]

Epoch 1800 Loss: 18.556, Num_Actions 20


 38%|███▊      | 1901/5000 [57:37<1:41:46,  1.97s/it]

Epoch 1900 Loss: 1447.294, Num_Actions 74


 40%|████      | 2001/5000 [1:00:41<1:44:02,  2.08s/it]

Epoch 2000 Loss: 1792.613, Num_Actions 77


 42%|████▏     | 2101/5000 [1:03:54<1:32:27,  1.91s/it]

Epoch 2100 Loss: 551.144, Num_Actions 61


 44%|████▍     | 2201/5000 [1:07:08<1:51:35,  2.39s/it]

Epoch 2200 Loss: 1505.237, Num_Actions 80


 46%|████▌     | 2301/5000 [1:10:24<1:32:32,  2.06s/it]

Epoch 2300 Loss: 3760.808, Num_Actions 82


 48%|████▊     | 2401/5000 [1:13:44<1:27:07,  2.01s/it]

Epoch 2400 Loss: 1421.248, Num_Actions 68


 50%|█████     | 2501/5000 [1:16:50<1:38:02,  2.35s/it]

Epoch 2500 Loss: 3060.491, Num_Actions 79


 52%|█████▏    | 2601/5000 [1:20:04<1:19:49,  2.00s/it]

Epoch 2600 Loss: 1617.362, Num_Actions 69


 54%|█████▍    | 2701/5000 [1:23:21<1:17:08,  2.01s/it]

Epoch 2700 Loss: 2197.577, Num_Actions 80


 56%|█████▌    | 2801/5000 [1:26:43<1:07:39,  1.85s/it]

Epoch 2800 Loss: 1487.862, Num_Actions 67


 58%|█████▊    | 2901/5000 [1:29:48<1:10:38,  2.02s/it]

Epoch 2900 Loss: 4757.352, Num_Actions 79


 60%|██████    | 3001/5000 [1:32:58<57:27,  1.72s/it]  

Epoch 3000 Loss: 613.285, Num_Actions 59


 62%|██████▏   | 3101/5000 [1:36:20<58:54,  1.86s/it]  

Epoch 3100 Loss: 137.557, Num_Actions 38


 64%|██████▍   | 3201/5000 [1:39:44<1:02:17,  2.08s/it]

Epoch 3200 Loss: 1924.310, Num_Actions 74


 66%|██████▌   | 3301/5000 [1:43:08<58:36,  2.07s/it]  

Epoch 3300 Loss: 3900.259, Num_Actions 76


 68%|██████▊   | 3401/5000 [1:46:25<53:48,  2.02s/it]  

Epoch 3400 Loss: 534.440, Num_Actions 57


 70%|███████   | 3501/5000 [1:49:34<49:12,  1.97s/it]  

Epoch 3500 Loss: 797.110, Num_Actions 75


 72%|███████▏  | 3601/5000 [1:52:52<51:45,  2.22s/it]

Epoch 3600 Loss: 4038.603, Num_Actions 83


 74%|███████▍  | 3701/5000 [1:56:04<46:31,  2.15s/it]

Epoch 3700 Loss: 1167.579, Num_Actions 71


 76%|███████▌  | 3801/5000 [1:59:19<42:56,  2.15s/it]

Epoch 3800 Loss: 202.431, Num_Actions 57


 78%|███████▊  | 3901/5000 [2:02:32<34:56,  1.91s/it]

Epoch 3900 Loss: 2876.469, Num_Actions 83


 80%|████████  | 4001/5000 [2:05:47<30:31,  1.83s/it]

Epoch 4000 Loss: 196.131, Num_Actions 51


 82%|████████▏ | 4101/5000 [2:08:50<31:03,  2.07s/it]

Epoch 4100 Loss: 2159.346, Num_Actions 71


 84%|████████▍ | 4201/5000 [2:11:43<21:37,  1.62s/it]

Epoch 4200 Loss: 112.461, Num_Actions 54


 86%|████████▌ | 4301/5000 [2:14:42<18:21,  1.58s/it]

Epoch 4300 Loss: 15.069, Num_Actions 37


 88%|████████▊ | 4401/5000 [2:17:48<16:33,  1.66s/it]

Epoch 4400 Loss: 550.870, Num_Actions 58


 90%|█████████ | 4501/5000 [2:20:58<15:24,  1.85s/it]

Epoch 4500 Loss: 1249.890, Num_Actions 71


 92%|█████████▏| 4601/5000 [2:24:06<11:51,  1.78s/it]

Epoch 4600 Loss: 1389.782, Num_Actions 62


 94%|█████████▍| 4701/5000 [2:27:27<10:38,  2.14s/it]

Epoch 4700 Loss: 3517.770, Num_Actions 83


 96%|█████████▌| 4801/5000 [2:30:47<05:54,  1.78s/it]

Epoch 4800 Loss: 210.254, Num_Actions 46


 98%|█████████▊| 4901/5000 [2:34:11<03:28,  2.11s/it]

Epoch 4900 Loss: 1424.540, Num_Actions 72


100%|██████████| 5000/5000 [2:37:28<00:00,  1.89s/it]


In [11]:
report_data.to_csv('training_log.csv', index=False)

In [19]:
import plotly.graph_objects as go
# Extract the data
epochs = report_data['epoch'].values
num_actions = report_data['num_actions'].values
losses = report_data['loss'].values

# Extract the data
epochs = report_data['epoch'].values
num_actions = report_data['num_actions'].values
losses = report_data['loss'].values

# Create the 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=epochs,
    y=num_actions,
    z=losses,
    mode='markers',
    marker=dict(
        size=5,
        color=losses,
        colorscale='Viridis',
        opacity=0.8
    ),
    text=[f'Epoch: {e}<br>Num Actions: {n}<br>Loss: {l}' for e, n, l in zip(epochs, num_actions, losses)],
    hoverinfo='text'
)])

# Update the layout
fig.update_layout(
    scene=dict(
        xaxis=dict(
            title='Epoch',
            range=[0, max(epochs) * 1.1]  # Extend the range slightly beyond the max epoch
        ),
        yaxis=dict(
            title='Number of Actions'
        ),
        zaxis=dict(
            title='Loss'
        )
    ),
    width=1000,
    height=800
)

# Show the plot
fig.show()

In [20]:
# Extract the data
epochs = report_data['epoch'].values
losses = report_data['loss'].values

# Create the 2D scatter plot
fig = go.Figure(data=go.Scatter(
    x=epochs,
    y=losses,
    mode='lines+markers',
    marker=dict(
        size=5,
        color='blue'
    ),
    text=[f'Epoch: {e}<br>Loss: {l}' for e, l in zip(epochs, losses)],
    hoverinfo='text'
))

# Update the layout
fig.update_layout(
    xaxis=dict(
        title='Epoch'
    ),
    yaxis=dict(
        title='Loss'
    ),
    width=1000,
    height=600,
    title='Epoch vs Loss'
)

# Show the plot
fig.show()

In [21]:
import pandas as pd
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression
import numpy as np

# Extract the data
epochs = report_data['epoch'].values.reshape(-1, 1)
losses = report_data['loss'].values

# Perform linear regression
reg = LinearRegression().fit(epochs, losses)
slope = reg.coef_[0]
intercept = reg.intercept_

# Calculate the regression line
regression_line = reg.predict(epochs)

# Create the 2D scatter plot
fig = go.Figure()

# Add the original data
fig.add_trace(go.Scatter(
    x=report_data['epoch'],
    y=report_data['loss'],
    mode='markers',
    marker=dict(
        size=5,
        color='blue'
    ),
    name='Loss',
    text=[f'Epoch: {e}<br>Loss: {l}' for e, l in zip(report_data['epoch'], report_data['loss'])],
    hoverinfo='text'
))

# Add the regression line
fig.add_trace(go.Scatter(
    x=report_data['epoch'],
    y=regression_line,
    mode='lines',
    line=dict(
        color='red'
    ),
    name='Regression Line'
))

# Update the layout
fig.update_layout(
    xaxis=dict(
        title='Epoch'
    ),
    yaxis=dict(
        title='Loss'
    ),
    width=1000,
    height=600,
    title=f'Epoch vs Loss (Slope: {slope:.4f})'
)

# Show the plot
fig.show()

# Print the slope to determine the trend
print(f"The slope of the regression line is {slope:.4f}")
if slope < 0:
    print("The values are trending down.")
elif slope > 0:
    print("The values are trending up.")
else:
    print("The values are constant.")

The slope of the regression line is -0.0379
The values are trending down.


In [13]:
log._actions


tensor([[ 95,  56, 301],
        [132,  19, 226],
        [187, 111, 211],
        [283, 285, 208],
        [247,  77, 190],
        [ 21, 230, 285],
        [ 36, 322, 135],
        [263, 301, 229],
        [188, 150, 115],
        [192, 153, 268],
        [133, 284, 284],
        [212, 207,  94],
        [324,  59, 169],
        [ -1, 192, 192],
        [ -1, 131, 174],
        [ -1, 229, 173],
        [ -1,  74, 131],
        [ -1,  95, 154],
        [ -1,  40,  21],
        [ -1, 268, 170],
        [ -1,  98, 111],
        [ -1,   1, 302],
        [ -1, 264, 250],
        [ -1, 133, 263],
        [ -1, 187, 287],
        [ -1, 250,  60],
        [ -1, 305, 152],
        [ -1, 208,  98],
        [ -1,  60, 207],
        [ -1,  76, 249],
        [ -1, 226,  77],
        [ -1, 170,  93],
        [ -1, 246, 188],
        [ -1,  78,   2],
        [ -1, 209, 305],
        [ -1, 324,  73],
        [ -1,  -1, 322],
        [ -1,  -1,  39],
        [ -1,  -1, 246],
        [ -1,  -1,  95],


In [14]:
# Function to check for duplicates across columns
def find_column_duplicates(tensor, check_value=None):
    num_columns = tensor.size(1)
    duplicates = {}
    check_value_duplicates = {}
    
    for col in range(num_columns):
        seen = set()
        col_duplicates = set()
        for row in range(tensor.size(0)):
            value = tensor[row, col].item()
            if value in seen:
                col_duplicates.add(value)
            seen.add(value)
        
        if col_duplicates:
            duplicates[col] = col_duplicates
        
        if check_value is not None and check_value in seen:
            check_value_duplicates[col] = check_value in col_duplicates
    
    return duplicates, check_value_duplicates

In [15]:
duplicates, is_negative_one_duplicate = find_column_duplicates(log._actions, check_value=-1)
print("Duplicate values by column:", duplicates)
print("Is -1 a duplicate in each column:", is_negative_one_duplicate)
    


Duplicate values by column: {0: {-1}, 1: {-1}}
Is -1 a duplicate in each column: {0: True, 1: True}


In [16]:
duplicates

{0: {-1}, 1: {-1}}

In [17]:
print(duplicates)

{0: {-1}, 1: {-1}}


In [18]:
# Sample and plot final states
s0 = one_hot(torch.zeros(10**4).long(), env.state_dim).float()
s = model.sample_states(s0, return_log=False)
# Implement your plot function or use another way to visualize the results
# plot(s, env, matrix_size)

AttributeError: 'NoneType' object has no attribute '_actions'