In [15]:
import torch
from torch.nn.functional import one_hot
from torch.optim import Adam
from tqdm import tqdm
from preconditioner import PreconditionerEnv
from policy import ForwardPolicy, BackwardPolicy
from gflownet.gflownet import GFlowNet
from gflownet.utils import sparse_one_hot
from gflownet.utils import trajectory_balance_loss, market_matrix_to_sparse_tensor
import psutil

In [16]:
def log_memory_usage(stage: str):
    """Print the current process's resident memory and, when CUDA is available, torch's allocated GPU memory.

    Args:
        stage: Label shown in square brackets at the start of every printed line.
    """
    bytes_per_mb = 1024 ** 2
    rss_bytes = psutil.Process().memory_info().rss
    print(f"[{stage}] CPU Memory Usage: {rss_bytes / bytes_per_mb:.2f} MB")
    if not torch.cuda.is_available():
        return
    gpu_bytes = torch.cuda.memory_allocated()
    print(f"[{stage}] GPU Memory Usage: {gpu_bytes / bytes_per_mb:.2f} MB")


In [17]:
import warnings
# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation notices raised by the libraries used below;
# consider narrowing the filter to the specific categories that are known to be noisy.
warnings.filterwarnings('ignore')

In [18]:
# Run configuration shared by the cells below.
matrix_path = '../pivtol/pivtol.mtx'  # Update this with your file path
batch_size = 3  # number of starting states sampled per training epoch
num_epochs = 3000  # number of iterations of the training loop
lr = 0.00005  # learning rate handed to the Adam optimizer

# Run GMRES Without Preconditioner As Baseline

In [23]:
import numpy as np
from scipy.io import mmread
from scipy.sparse.linalg import gmres, spilu, LinearOperator
from scipy.sparse import csr_matrix
import time

# Read matrix A from a Matrix Market (.mtx) file
def load_mtx_file(file_path):
    """Load the Matrix Market file at ``file_path`` and return its contents as a SciPy CSR matrix."""
    loaded = mmread(file_path)
    as_csr = csr_matrix(loaded)
    return as_csr

# Read vector b from a Matrix Market (.mtx) file as a flat array
def load_vector_mtx(file_path):
    """Load the Matrix Market file at ``file_path`` and return its entries as a 1-D array.

    Sparse results from ``mmread`` (anything exposing ``toarray``) are densified first;
    the dense data is then flattened regardless of whether it was a row, a column or a matrix.
    """
    loaded = mmread(file_path)
    dense = loaded.toarray() if hasattr(loaded, "toarray") else loaded
    return dense.flatten()

# Function to solve the system using GMRES with an optional preconditioner
def solve_with_gmres(A, b, M=None):
    """Solve ``A x = b`` with SciPy's GMRES, recording the callback values and the elapsed time.

    Args:
        A: Matrix or linear operator accepted by ``scipy.sparse.linalg.gmres``; only
            ``A.shape[0]`` is inspected here.
        b: Right-hand side. NumPy arrays, nested sequences, ``np.matrix`` objects and
            sparse matrices (anything exposing ``toarray``) are accepted and flattened to 1-D.
        M: Optional preconditioner, forwarded to ``gmres`` unchanged.

    Returns:
        Tuple ``(x, residuals, num_iterations, elapsed_time)`` where ``residuals`` holds every
        value ``gmres`` passed to the callback, ``num_iterations`` is the length of that list
        and ``elapsed_time`` is the duration of the ``gmres`` call in seconds.

    Raises:
        ValueError: If the flattened ``b`` does not have one entry per row of ``A``.
    """
    # Normalise b to a flat ndarray; calling .flatten() directly fails on lists and sparse
    # matrices and leaves np.matrix inputs two-dimensional.
    if hasattr(b, "toarray"):
        b = b.toarray()
    b = np.asarray(b).flatten()
    if b.shape[0] != A.shape[0]:
        raise ValueError(f"Shape mismatch: A is {A.shape}, but b is {b.shape}")

    # Initial guess (zero vector)
    x0 = np.zeros(b.shape)

    # Values reported by GMRES, one per callback invocation
    residuals = []

    def callback(rk):
        residuals.append(rk)

    # perf_counter is monotonic and high-resolution, unlike the wall clock returned by time.time()
    start_time = time.perf_counter()

    # Use GMRES to solve the system Ax = b with preconditioner M
    x, exitCode = gmres(A, b, x0=x0, M=M, maxiter=1000000, callback=callback)

    elapsed_time = time.perf_counter() - start_time

    if exitCode == 0:
        print("GMRES converged successfully.")
    else:
        print(f"GMRES did not converge. Exit code: {exitCode}")

    # Number of iterations is the number of callback invocations
    num_iterations = len(residuals)

    return x, residuals, num_iterations, elapsed_time


In [25]:
# Example usage
mtx_file_path_A = matrix_path  # Replace with your actual matrix file path
mtx_file_path_b = '../pivtol/pivtol_b.mtx'  # Replace with your actual vector file path

# Load the right-hand-side data. mmread returns a sparse matrix for coordinate-format files,
# which cannot be sliced with [:, 0], so densify before extracting a column.
b_data = mmread(mtx_file_path_b)
if hasattr(b_data, "toarray"):
    b_data = b_data.toarray()
b_data = np.asarray(b_data)

# The file is expected to hold several right-hand sides side by side; use the first column.
if b_data.ndim == 2 and b_data.shape[1] > 1:
    b_single = b_data[:, 0]
else:
    raise ValueError("The loaded data is not in the expected format or does not have multiple columns.")

# Load A from its .mtx file and pair it with the extracted right-hand side
A = load_mtx_file(mtx_file_path_A)
b = b_single

# Solve the system (solve_with_gmres verifies that b has one entry per row of A)
x_np, res_np, iter_np, time_np = solve_with_gmres(A, b)

# Output the solution
print(f"Residual Norm No Preconditioner: {res_np}")
print(f"No iterations No Preconditioner: {iter_np}")
print(f"Elapsed Time No Preconditioner: {time_np}")

ValueError: Shape mismatch: A is (102, 102), but b is (306,)

In [32]:
# Print memory usage before the load so it can be compared with the figure printed afterwards.
log_memory_usage("Before Loading Initial Matrix")

# Load the initial matrix from a file
# (the result is later handed to structured_sampling, which requires a PyTorch sparse tensor)
original_matrix = market_matrix_to_sparse_tensor(matrix_path)

log_memory_usage("After Loading Initial Matrix")


[Before Loading Initial Matrix] CPU Memory Usage: 78.45 MB
[After Loading Initial Matrix] CPU Memory Usage: 80.70 MB


# Structured Sampling Preconditioner

In [17]:
# Takes in a PyTorch sparse tensor, partitions it into square blocks and keeps only the
# largest-magnitude entries of each non-empty block (keep_ratio is the fraction that is kept).
def structured_sampling(matrix, block_size, keep_ratio=0.5):
    """Sparsify a 2-D sparse COO tensor block by block, keeping the largest-magnitude entries.

    The matrix is tiled with ``block_size x block_size`` blocks starting at (0, 0). In every
    block that contains at least one stored entry, ``int(count * keep_ratio)`` entries (never
    fewer than one, never more than ``count``) with the largest absolute value are retained.

    Args:
        matrix: 2-D PyTorch sparse COO tensor. It is coalesced first so that duplicate
            coordinates are summed before magnitudes are compared.
        block_size: Positive integer edge length of the square blocks.
        keep_ratio: Fraction of each block's stored entries to retain.

    Returns:
        A sparse COO tensor with the same shape, dtype and device as ``matrix`` holding only
        the retained entries, ordered block by block (row-major over blocks) and by
        descending magnitude within each block.

    Raises:
        ValueError: If ``matrix`` is not a sparse tensor or ``block_size`` is not positive.
    """
    if not matrix.is_sparse:
        raise ValueError("Input matrix must be a PyTorch sparse tensor")
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")

    # Coalescing merges duplicate coordinates; ranking un-merged duplicates would compare
    # partial values instead of the entry's actual magnitude.
    matrix = matrix.coalesce()
    indices = matrix.indices()
    values = matrix.values()
    n, m = matrix.shape

    if values.numel() == 0:
        # Nothing stored: return an empty sparse matrix that still matches the input's dtype/device.
        return torch.sparse_coo_tensor(indices, values, (n, m))

    # Give every stored entry the id of the block it falls in. Ids grow row-major over the
    # block grid, so iterating them in sorted order visits blocks top-to-bottom, left-to-right.
    blocks_per_row = (m + block_size - 1) // block_size
    block_ids = (
        torch.div(indices[0], block_size, rounding_mode="floor") * blocks_per_row
        + torch.div(indices[1], block_size, rounding_mode="floor")
    )

    kept_positions = []
    # torch.unique returns sorted ids and only those of non-empty blocks, so empty blocks cost nothing.
    for block_id in torch.unique(block_ids):
        members = torch.nonzero(block_ids == block_id).flatten()
        num_non_zeros = members.numel()
        # Clamp to [1, num_non_zeros] so keep_ratio values above 1 keep everything instead of crashing topk.
        num_to_keep = min(num_non_zeros, max(1, int(num_non_zeros * keep_ratio)))
        _, top_indices = torch.topk(values[members].abs(), num_to_keep)
        kept_positions.append(members[top_indices])

    kept = torch.cat(kept_positions)
    sparse_subset = torch.sparse_coo_tensor(indices[:, kept], values[kept], (n, m))

    return sparse_subset

In [18]:
# Build the starting matrix: within each 4x4 block of the original, keep the largest-magnitude
# entries, up to 75% of the block's stored entries (rounded down, at least one).
initial_matrix = structured_sampling(original_matrix, 4, 0.75)
matrix_size = initial_matrix.size(0)  # number of rows of the sampled matrix

In [22]:
# Assuming initial_matrix is a PyTorch tensor
# Convert the PyTorch tensor to a NumPy array
initial_matrix_dense = initial_matrix.to_dense()
initial_matrix_np = initial_matrix_dense.detach().cpu().numpy()

# Define the matvec function for the LinearOperator
def matvec(x):
    """Return the product of the dense sampled matrix with ``x``."""
    return initial_matrix_np.dot(x)

# Create the LinearOperator using the matvec function
# NOTE(review): SciPy documents the M argument of gmres as an approximation of the inverse of A,
# whereas this operator multiplies by the sampled matrix itself - confirm this is the intended
# preconditioner before interpreting the iteration counts.
M = LinearOperator(shape=A.shape, matvec=matvec)

In [23]:
# Solve the system without preconditioner
solution_no_prec, residuals_no_prec, num_iterations_no_prec, elapsed_time_no_prec = solve_with_gmres(A, b)

# Solve the system with the structured-sampling operator M as the preconditioner
# (this is not an ILU factorization; M wraps the block-sampled matrix)
solution_prec, residuals_prec, num_iterations_prec, elapsed_time_prec = solve_with_gmres(A, b, M=M)

# Output the results for comparison, one report per run
for heading, run_residuals, run_iterations, run_elapsed in (
    ("\nGMRES without preconditioner:", residuals_no_prec, num_iterations_no_prec, elapsed_time_no_prec),
    ("\nGMRES with Sparse preconditioner:", residuals_prec, num_iterations_prec, elapsed_time_prec),
):
    print(heading)
    print(f"Number of iterations: {run_iterations}")
    # The residual list is empty when GMRES never invoked the callback; indexing [-1] would raise.
    final_residual = run_residuals[-1] if run_residuals else "n/a"
    print(f"Final residual norm: {final_residual}")
    print(f"Elapsed time: {run_elapsed:.4f} seconds")

GMRES converged successfully.
GMRES converged successfully.

GMRES without preconditioner:
Number of iterations: 17
Final residual norm: 1.4499262075380584e-06
Elapsed time: 0.0016 seconds

GMRES with ILU preconditioner:
Number of iterations: 26
Final residual norm: 1.2296403292993637e-09
Elapsed time: 0.0021 seconds


In [None]:
# Show the sampled sparse tensor (indices, values and size) as a sanity check.
print(initial_matrix)

In [None]:
# Initialize the environment from the sampled and the original matrices, then display the
# shape of its edge attributes (the policies are created in a separate cell).
env = PreconditionerEnv(matrix_size=matrix_size, initial_matrix=initial_matrix, original_matrix=original_matrix)
env.data.edge_attr.shape

In [None]:

# Hyperparameters for the two policy networks.
# NOTE(review): -1 presumably asks ForwardPolicy to infer the node feature size - confirm against policy.py.
node_features = -1
input_dim = 1
hidden_dim = 8
# Both policies are sized to the environment's action count.
forward_policy = ForwardPolicy(node_features=node_features, hidden_dim=hidden_dim, num_actions=env.num_actions)
#forward_policy = ForwardPolicy(in_channels=node_features, hidden_channels=hidden_dim, out_channels=env.num_actions)
backward_policy = BackwardPolicy(input_dim=input_dim, hidden_dim=hidden_dim, num_actions=env.num_actions)

In [None]:
# Display the shape of the environment's edge attributes.
env.data.edge_attr.shape

In [None]:
# Display the number of actions the environment exposes (used to size both policies).
env.num_actions

In [None]:
def check_gradients(model):
    """Print one line per trainable parameter: its gradient norm, or a placeholder when no gradient is set.

    Parameters with ``requires_grad`` set to False are skipped entirely.
    """
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        grad = param.grad
        status = "No gradient" if grad is None else grad.norm()
        print(f"{name}: {status}")


In [None]:
import pandas as pd
# Initialize the GFlowNet model
model = GFlowNet(forward_policy, backward_policy, env)
opt = Adam(model.parameters(), lr=lr)

log_memory_usage("After Model Initialization")

report_columns = ['epoch', 'num_actions', 'loss', 'reward']
detailed_report_columns = ['epoch', 'sample_number', 'num_actions', 'loss', 'reward']

# Rows are collected in plain lists and converted to DataFrames once, after the loop:
# DataFrame.append was removed in pandas 2.0 and copied the whole frame on every call.
report_rows = []
detailed_report_rows = []

try:
    for epoch in (p := tqdm(range(num_epochs))):
        model.train()

        # Every trajectory in the batch starts from its own copy of the sampled matrix
        s0 = [initial_matrix.clone() for _ in range(batch_size)]

        # Sample final states and log information
        s, log = model.sample_states(s0, return_log=True)

        # Calculate the trajectory balance loss
        loss = trajectory_balance_loss(log.total_flow,
                                       log.rewards,
                                       log.fwd_probs,
                                       log.back_probs)

        # Backpropagation and optimization step; gradients are cleared afterwards for the next epoch
        loss.backward()
        opt.step()
        opt.zero_grad()

        # Read the scalars once instead of once per logged row
        loss_value = loss.item()
        rewards = log.rewards
        rewards_are_tensor = isinstance(rewards, torch.Tensor)

        # Per-epoch summary; tensor rewards are stored as plain lists so the CSV export
        # contains numbers rather than tensor reprs
        total_length = len(log._actions)
        report_rows.append({
            'epoch': epoch,
            'num_actions': total_length,
            'loss': loss_value,
            'reward': rewards.detach().cpu().tolist() if rewards_are_tensor else rewards,
        })

        # Per-sample detail: after transposing, row i of the action log belongs to sample i;
        # entries equal to -1 are excluded from the action count
        actions_by_sample = log._actions.t()
        for sample_id in range(batch_size):
            num_actions = (actions_by_sample[sample_id] != -1).sum()
            reward = rewards[sample_id].item() if rewards_are_tensor else rewards[sample_id]
            detailed_report_rows.append({
                'epoch': epoch,
                'sample_number': sample_id + 1,  # Sample number within the batch/epoch
                'num_actions': num_actions.item(),
                'loss': loss_value,
                'reward': reward,
            })

        if epoch % 100 == 0:
            tqdm.write(f"Epoch {epoch} Loss: {loss_value:.3f}, Num_Actions {total_length}")
finally:
    # Build the reports even if training is interrupted, so partial runs can still be inspected
    report_data = pd.DataFrame(report_rows, columns=report_columns)
    detailed_report_data = pd.DataFrame(detailed_report_rows, columns=detailed_report_columns)
        

In [None]:
# Write the per-epoch report to training_log.csv in the working directory, without the index column.
report_data.to_csv('training_log.csv', index=False)

In [None]:
# Write the per-sample report to detailed_training_log.csv in the working directory, without the index column.
detailed_report_data.to_csv('detailed_training_log.csv', index=False)

In [None]:
import plotly.graph_objects as go

# Pull the per-epoch series that make up the three axes
epochs = report_data['epoch'].values
num_actions = report_data['num_actions'].values
losses = report_data['loss'].values

# One hover label per point
hover_labels = [
    f'Epoch: {e}<br>Num Actions: {n}<br>Loss: {l}'
    for e, n, l in zip(epochs, num_actions, losses)
]

# 3D scatter: epoch on x, number of actions on y, loss on z (also used for the colour scale)
marker_style = {'size': 5, 'color': losses, 'colorscale': 'Viridis', 'opacity': 0.8}
scatter_3d = go.Scatter3d(
    x=epochs,
    y=num_actions,
    z=losses,
    mode='markers',
    marker=marker_style,
    text=hover_labels,
    hoverinfo='text',
)
fig = go.Figure(data=[scatter_3d])

# Axis titles and figure size; the x range extends slightly beyond the last epoch
fig.update_layout(
    scene={
        'xaxis': {'title': 'Epoch', 'range': [0, max(epochs) * 1.1]},
        'yaxis': {'title': 'Number of Actions'},
        'zaxis': {'title': 'Loss'},
    },
    width=1000,
    height=800,
)

# Render the figure
fig.show()

In [None]:
# Per-epoch series for the loss curve
epochs = report_data['epoch'].values
losses = report_data['loss'].values

# One hover label per point
hover_labels = [f'Epoch: {e}<br>Loss: {l}' for e, l in zip(epochs, losses)]

# Line-and-marker trace of loss against epoch
loss_trace = go.Scatter(
    x=epochs,
    y=losses,
    mode='lines+markers',
    marker={'size': 5, 'color': 'blue'},
    text=hover_labels,
    hoverinfo='text',
)
fig = go.Figure(data=loss_trace)

# Axis titles, figure size and heading
fig.update_layout(
    xaxis={'title': 'Epoch'},
    yaxis={'title': 'Loss'},
    width=1000,
    height=600,
    title='Epoch vs Loss',
)

# Render the figure
fig.show()

In [None]:
import pandas as pd
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression
import numpy as np

# Regression inputs: epochs as a single-feature column, losses as the target
epochs = report_data['epoch'].values.reshape(-1, 1)
losses = report_data['loss'].values

# Fit a straight line to loss as a function of epoch
reg = LinearRegression().fit(epochs, losses)
slope = reg.coef_[0]
intercept = reg.intercept_

# Fitted loss value at every recorded epoch
regression_line = reg.predict(epochs)

epoch_series = report_data['epoch']
loss_series = report_data['loss']

fig = go.Figure()

# Raw losses as markers
fig.add_trace(go.Scatter(
    x=epoch_series,
    y=loss_series,
    mode='markers',
    marker={'size': 5, 'color': 'blue'},
    name='Loss',
    text=[f'Epoch: {e}<br>Loss: {l}' for e, l in zip(epoch_series, loss_series)],
    hoverinfo='text',
))

# Fitted line on top
fig.add_trace(go.Scatter(
    x=epoch_series,
    y=regression_line,
    mode='lines',
    line={'color': 'red'},
    name='Regression Line',
))

# Axis titles, figure size and a heading that reports the fitted slope
fig.update_layout(
    xaxis={'title': 'Epoch'},
    yaxis={'title': 'Loss'},
    width=1000,
    height=600,
    title=f'Epoch vs Loss (Slope: {slope:.4f})',
)

# Render the figure
fig.show()

# Report the slope and the direction of the trend it implies
print(f"The slope of the regression line is {slope:.4f}")
if slope < 0:
    trend_message = "The values are trending down."
elif slope > 0:
    trend_message = "The values are trending up."
else:
    trend_message = "The values are constant."
print(trend_message)

In [None]:
# Report, column by column, which values of a 2-D tensor occur more than once
def find_column_duplicates(tensor, check_value=None):
    """Find repeated values within each column of ``tensor``.

    Args:
        tensor: Tensor indexed as ``tensor[row, col]``; every element is read via ``.item()``.
        check_value: Optional value of special interest.

    Returns:
        Tuple ``(duplicates, check_value_duplicates)``. ``duplicates`` maps a column index to
        the set of values occurring more than once in that column (columns without repeats
        are omitted). ``check_value_duplicates`` maps each column that contains
        ``check_value`` to whether that value is repeated there.
    """
    duplicates = {}
    check_value_duplicates = {}

    for col in range(tensor.size(1)):
        seen = set()
        repeated = set()
        for cell in tensor[:, col]:
            value = cell.item()
            if value in seen:
                repeated.add(value)
            else:
                seen.add(value)

        if repeated:
            duplicates[col] = repeated

        if check_value is not None and check_value in seen:
            check_value_duplicates[col] = check_value in repeated

    return duplicates, check_value_duplicates

In [None]:
# Look for repeated action ids within each column of the action log of the last sampled batch,
# and report separately whether -1 is among the repeated values of the columns that contain it.
duplicates, is_negative_one_duplicate = find_column_duplicates(log._actions, check_value=-1)
print("Duplicate values by column:", duplicates)
print("Is -1 a duplicate in each column:", is_negative_one_duplicate)
    


In [None]:
# Display the column -> repeated-values mapping using the notebook's rich output.
duplicates

In [None]:
# Plain-text dump of the same mapping.
print(duplicates)

In [None]:
# Print the shapes of the action and trajectory tensors recorded in the last sampling log.
print(log._actions.shape)
print(log._traj.shape)

In [None]:
# Sample and plot final states
# The starting states are built the same way as in the training loop (one clone of
# initial_matrix per sample); the earlier one-hot encoding is no longer what
# sample_states is given during training.
num_final_samples = 10**4
s0 = [initial_matrix.clone() for _ in range(num_final_samples)]
s = model.sample_states(s0, return_log=False)
# Implement your plot function or use another way to visualize the results
# plot(s, env, matrix_size)