In [1]:
import torch
from torch.nn.functional import one_hot
from torch.optim import Adam
from tqdm import tqdm
from preconditioner import PreconditionerEnv
from policy import ForwardPolicy, BackwardPolicy
from gflownet.gflownet import GFlowNet
from gflownet.utils import sparse_one_hot
from gflownet.utils import trajectory_balance_loss, market_matrix_to_sparse_tensor
import psutil

In [2]:
def log_memory_usage(stage: str):
    process = psutil.Process()
    mem_info = process.memory_info()
    print(f"[{stage}] CPU Memory Usage: {mem_info.rss / (1024 ** 2):.2f} MB")
    if torch.cuda.is_available():
        print(f"[{stage}] GPU Memory Usage: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")


In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
matrix_path = '../LF10/LF10.mtx'  # Update this with your file path
batch_size = 3
num_epochs = 3000
lr = 0.00005

In [5]:
log_memory_usage("Before Loading Initial Matrix")

# Load the initial matrix from a file
initial_matrix = market_matrix_to_sparse_tensor(matrix_path)
matrix_size = initial_matrix.size(0)

log_memory_usage("After Loading Initial Matrix")


[Before Loading Initial Matrix] CPU Memory Usage: 204.13 MB
[After Loading Initial Matrix] CPU Memory Usage: 204.77 MB


In [6]:
# Initialize the environment and policies
env = PreconditionerEnv(matrix_size=matrix_size, initial_matrix=initial_matrix)
env.data.edge_attr.shape

torch.Size([82])

In [7]:

node_features = -1
input_dim = 1
hidden_dim = 8
forward_policy = ForwardPolicy(node_features=node_features, hidden_dim=hidden_dim, num_actions=env.num_actions)
#forward_policy = ForwardPolicy(in_channels=node_features, hidden_channels=hidden_dim, out_channels=env.num_actions)
backward_policy = BackwardPolicy(input_dim=input_dim, hidden_dim=hidden_dim, num_actions=env.num_actions)

In [8]:
env.data.edge_attr.shape

torch.Size([82])

In [9]:
def check_gradients(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            if param.grad is not None:
                print(f"{name}: {param.grad.norm()}")
            else:
                print(f"{name}: No gradient")


In [10]:
import pandas as pd
# Initialize the GFlowNet model
model = GFlowNet(forward_policy, backward_policy, env)
opt = Adam(model.parameters(), lr=lr)

log_memory_usage("After Model Initialization")

report_data = pd.DataFrame(columns=['epoch', 'num_actions', 'loss', 'reward'])

detailed_report_data = pd.DataFrame(columns=['epoch', 'sample_number', 'num_actions', 'loss', 'reward'])

for epoch in (p := tqdm(range(num_epochs))):
   #log_memory_usage(f"Start of Epoch {epoch}")

    model.train()
    #opt.zero_grad()

    # Initialize the starting states
    initial_indices = torch.zeros(batch_size).long()
    #s0 = [sparse_one_hot(initial_indices[i:i+1], env.state_dim).float() for i in range(batch_size)]
    s0 = [initial_matrix.clone() for _ in range(batch_size)]
    #s0 = one_hot(torch.zeros(batch_size).long(), env.state_dim).float()
    # Sample final states and log information
    s, log = model.sample_states(s0, return_log=True)
    
    # Calculate the trajectory balance loss
    loss = trajectory_balance_loss(log.total_flow,
                                    log.rewards,
                                    log.fwd_probs,
                                    log.back_probs)
    
    #print(f"log.total_flow {log.total_flow}")
    #print(f"log.rewards {log.rewards}")
    #print(f"log.fwd_probs {log.fwd_probs}")
    #print(f"log.back_probs {log.back_probs}")
    #print(f"log._actions shape {len(log._actions)}")
    #print(f"Loss Calculation: {loss}")
    # Backpropagation and optimization step
    loss.backward()
    #check_gradients(model)
    opt.step()
    #named_params = model.named_parameters()
    opt.zero_grad()

    #Capture data
    total_length = len(log._actions)
    report_data = report_data.append({'epoch': epoch, 'num_actions': total_length, 'loss': loss.item(), 'reward': log.rewards}, ignore_index=True)

        # Capture data for each sample in the batch
    for sample_id in range(batch_size):
        sum_actions = log._actions.t()[sample_id]
        mask_actions = sum_actions != -1
        num_actions = mask_actions.sum()
        reward = log.rewards[sample_id].item() if isinstance(log.rewards, torch.Tensor) else log.rewards[sample_id]
        detailed_report_data = detailed_report_data.append({
            'epoch': epoch,
            'sample_number': sample_id + 1,  # Sample number within the batch/epoch
            'num_actions': num_actions.item(),
            'loss': loss.item(),
            'reward': reward
        }, ignore_index=True)
    
    if epoch % 100 == 0:
       tqdm.write(f"Epoch {epoch} Loss: {loss.item():.3f}, Num_Actions {total_length}")
        

[After Model Initialization] CPU Memory Usage: 242.67 MB


  0%|          | 1/2000 [00:01<1:00:00,  1.80s/it]

Epoch 0 Loss: 2689.235, Num_Actions 75


  5%|▌         | 101/2000 [01:57<51:23,  1.62s/it]

Epoch 100 Loss: 10725.562, Num_Actions 68


 10%|█         | 201/2000 [04:18<42:34,  1.42s/it]

Epoch 200 Loss: 16409.486, Num_Actions 48


 15%|█▌        | 301/2000 [07:18<1:01:06,  2.16s/it]

Epoch 300 Loss: 15935.632, Num_Actions 63


 20%|██        | 401/2000 [11:24<1:23:32,  3.14s/it]

Epoch 400 Loss: 10908.463, Num_Actions 43


 25%|██▌       | 501/2000 [17:31<1:29:11,  3.57s/it]

Epoch 500 Loss: 60.509, Num_Actions 14


 30%|███       | 601/2000 [24:58<1:45:47,  4.54s/it]

Epoch 600 Loss: 10318.387, Num_Actions 79


 35%|███▌      | 701/2000 [31:08<1:39:59,  4.62s/it]

Epoch 700 Loss: 27729.088, Num_Actions 83


 40%|████      | 801/2000 [37:10<58:08,  2.91s/it]  

Epoch 800 Loss: 5332.296, Num_Actions 50


 45%|████▌     | 901/2000 [42:58<1:02:11,  3.39s/it]

Epoch 900 Loss: 3407.259, Num_Actions 79


 50%|█████     | 1001/2000 [48:09<1:05:15,  3.92s/it]

Epoch 1000 Loss: 16375.021, Num_Actions 83


 55%|█████▌    | 1101/2000 [54:00<54:06,  3.61s/it]  

Epoch 1100 Loss: 2376.098, Num_Actions 61


 60%|██████    | 1201/2000 [1:00:16<41:14,  3.10s/it]

Epoch 1200 Loss: 8836.380, Num_Actions 39


 65%|██████▌   | 1301/2000 [1:06:33<43:31,  3.74s/it]  

Epoch 1300 Loss: 2108.276, Num_Actions 65


 70%|███████   | 1401/2000 [1:11:50<25:25,  2.55s/it]

Epoch 1400 Loss: 2393.307, Num_Actions 29


 75%|███████▌  | 1501/2000 [1:17:19<29:48,  3.59s/it]

Epoch 1500 Loss: 91.256, Num_Actions 83


 80%|████████  | 1601/2000 [1:22:55<24:43,  3.72s/it]

Epoch 1600 Loss: 1557.445, Num_Actions 50


 85%|████████▌ | 1701/2000 [1:29:14<16:39,  3.34s/it]

Epoch 1700 Loss: 3499.307, Num_Actions 72


 90%|█████████ | 1801/2000 [1:35:52<10:41,  3.22s/it]

Epoch 1800 Loss: 10897.868, Num_Actions 47


 95%|█████████▌| 1901/2000 [1:41:09<05:43,  3.47s/it]

Epoch 1900 Loss: 22616.217, Num_Actions 81


100%|██████████| 2000/2000 [1:47:09<00:00,  3.21s/it]


In [11]:
report_data.to_csv('training_log.csv', index=False)

In [12]:
detailed_report_data.to_csv('detailed_training_log.csv', index=False)

In [13]:
import plotly.graph_objects as go
# Extract the data
epochs = report_data['epoch'].values
num_actions = report_data['num_actions'].values
losses = report_data['loss'].values

# Extract the data
epochs = report_data['epoch'].values
num_actions = report_data['num_actions'].values
losses = report_data['loss'].values

# Create the 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=epochs,
    y=num_actions,
    z=losses,
    mode='markers',
    marker=dict(
        size=5,
        color=losses,
        colorscale='Viridis',
        opacity=0.8
    ),
    text=[f'Epoch: {e}<br>Num Actions: {n}<br>Loss: {l}' for e, n, l in zip(epochs, num_actions, losses)],
    hoverinfo='text'
)])

# Update the layout
fig.update_layout(
    scene=dict(
        xaxis=dict(
            title='Epoch',
            range=[0, max(epochs) * 1.1]  # Extend the range slightly beyond the max epoch
        ),
        yaxis=dict(
            title='Number of Actions'
        ),
        zaxis=dict(
            title='Loss'
        )
    ),
    width=1000,
    height=800
)

# Show the plot
fig.show()

In [14]:
# Extract the data
epochs = report_data['epoch'].values
losses = report_data['loss'].values

# Create the 2D scatter plot
fig = go.Figure(data=go.Scatter(
    x=epochs,
    y=losses,
    mode='lines+markers',
    marker=dict(
        size=5,
        color='blue'
    ),
    text=[f'Epoch: {e}<br>Loss: {l}' for e, l in zip(epochs, losses)],
    hoverinfo='text'
))

# Update the layout
fig.update_layout(
    xaxis=dict(
        title='Epoch'
    ),
    yaxis=dict(
        title='Loss'
    ),
    width=1000,
    height=600,
    title='Epoch vs Loss'
)

# Show the plot
fig.show()

In [15]:
import pandas as pd
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression
import numpy as np

# Extract the data
epochs = report_data['epoch'].values.reshape(-1, 1)
losses = report_data['loss'].values

# Perform linear regression
reg = LinearRegression().fit(epochs, losses)
slope = reg.coef_[0]
intercept = reg.intercept_

# Calculate the regression line
regression_line = reg.predict(epochs)

# Create the 2D scatter plot
fig = go.Figure()

# Add the original data
fig.add_trace(go.Scatter(
    x=report_data['epoch'],
    y=report_data['loss'],
    mode='markers',
    marker=dict(
        size=5,
        color='blue'
    ),
    name='Loss',
    text=[f'Epoch: {e}<br>Loss: {l}' for e, l in zip(report_data['epoch'], report_data['loss'])],
    hoverinfo='text'
))

# Add the regression line
fig.add_trace(go.Scatter(
    x=report_data['epoch'],
    y=regression_line,
    mode='lines',
    line=dict(
        color='red'
    ),
    name='Regression Line'
))

# Update the layout
fig.update_layout(
    xaxis=dict(
        title='Epoch'
    ),
    yaxis=dict(
        title='Loss'
    ),
    width=1000,
    height=600,
    title=f'Epoch vs Loss (Slope: {slope:.4f})'
)

# Show the plot
fig.show()

# Print the slope to determine the trend
print(f"The slope of the regression line is {slope:.4f}")
if slope < 0:
    print("The values are trending down.")
elif slope > 0:
    print("The values are trending up.")
else:
    print("The values are constant.")

The slope of the regression line is 0.0690
The values are trending up.


In [16]:
# Function to check for duplicates across columns
def find_column_duplicates(tensor, check_value=None):
    num_columns = tensor.size(1)
    duplicates = {}
    check_value_duplicates = {}
    
    for col in range(num_columns):
        seen = set()
        col_duplicates = set()
        for row in range(tensor.size(0)):
            value = tensor[row, col].item()
            if value in seen:
                col_duplicates.add(value)
            seen.add(value)
        
        if col_duplicates:
            duplicates[col] = col_duplicates
        
        if check_value is not None and check_value in seen:
            check_value_duplicates[col] = check_value in col_duplicates
    
    return duplicates, check_value_duplicates

In [17]:
duplicates, is_negative_one_duplicate = find_column_duplicates(log._actions, check_value=-1)
print("Duplicate values by column:", duplicates)
print("Is -1 a duplicate in each column:", is_negative_one_duplicate)
    


Duplicate values by column: {0: {-1}}
Is -1 a duplicate in each column: {0: True, 2: False}


In [18]:
duplicates

{0: {-1}}

In [19]:
print(duplicates)

{0: {-1}}


In [20]:
# Sample and plot final states
s0 = one_hot(torch.zeros(10**4).long(), env.state_dim).float()
s = model.sample_states(s0, return_log=False)
# Implement your plot function or use another way to visualize the results
# plot(s, env, matrix_size)

AttributeError: 'NoneType' object has no attribute '_actions'