In [None]:
# !git clone https://github.com/sami-ka/gnn_slotting_optimization.git

In [None]:
# ! pip install polars>=1.36.1
# ! pip install pyzmq>=27.1.0 scipy>=1.16.3 torch>=2.9.1 torch-geometric==2.4.0

In [3]:
from pathlib import Path
import sys

root = Path("/")  # NOTE(review): not used by any visible cell; kept for compatibility.

# Location of the cloned repository (Colab default, matching the `git clone` cell).
REPO_DIR = Path("/content/gnn_slotting_optimization")

# Put the repo first on sys.path so `slotting_optimization` is importable.
# Any existing copies are dropped first, so re-running this cell keeps exactly
# one entry at the front instead of stacking duplicates.
sys.path[:] = [entry for entry in sys.path if entry != str(REPO_DIR)]
sys.path.insert(0, str(REPO_DIR))

In [4]:
from slotting_optimization.generator import DataGenerator
from slotting_optimization.order_book import OrderBook
from slotting_optimization.item_locations import ItemLocations
from slotting_optimization.warehouse import Warehouse

# Generate the synthetic dataset. Each sample is unpacked later as
# (order_book, item_locations, warehouse) when graphs are built.
# NOTE(review): the meaning of the positional args (20, 20, 300, 1, 10) is defined by
# DataGenerator.generate_samples, which is not visible here — confirm and consider
# passing them by keyword. `seed=5` makes the generated data reproducible.
gen = DataGenerator()
samples = gen.generate_samples(20, 20, 300, 1, 10, n_samples=10000, distances_fixed=True, seed=5)

In [5]:
# ===== CHANGE 1: Switch to 3D Edge Attributes =====
# Turn every generated (order_book, item_locations, warehouse) sample into a PyG
# graph whose edges carry a 3-dimensional feature vector.
from slotting_optimization.gnn_builder import build_graph_3d_sparse  # CHANGED
from slotting_optimization.simulator import Simulator

# A new Simulator is constructed for every sample.
# NOTE(review): if Simulator is stateless it could be created once and reused.
list_data = [
    build_graph_3d_sparse(  # CHANGED from build_graph_sparse
        order_book=ob,
        item_locations=il,
        warehouse=w,
        simulator=Simulator().simulate,
    )
    for (ob, il, w) in samples
]

print(f"✓ Generated {len(list_data)} graphs with 3D edge attributes")
print(f"  First graph: {list_data[0].num_nodes} nodes, {list_data[0].edge_index.shape[1]} edges")
print(f"  Edge attributes shape: {list_data[0].edge_attr.shape}")  # Should be [num_edges, 3]

In [6]:
import torch
from torch_geometric.loader import DataLoader

# Seed torch's global RNG: the shuffled train loader (and later weight init) draw from it.
torch.manual_seed(12345)

# Sequential 80/20 split with no shuffling before the cut.
# NOTE(review): assumes the generated samples are i.i.d., so their order carries no signal — confirm.
train_split_idx = int(0.8 * len(list_data))
train_dataset, test_dataset = list_data[:train_split_idx], list_data[train_split_idx:]

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [7]:
# ===== STEP 1: Compute Baseline Performance =====
# Reference predictor: always output the TRAINING mean. On the training set its MSE
# is just the training variance, but validation MSE must be compared with the same
# predictor scored on the held-out set:
#     MSE_test(mean_train) = Var(test) + (mean(test) - mean(train))**2
import numpy as np

# Extract all (raw, un-normalized) target values from datasets
train_targets = np.array([data.y.item() for data in train_dataset])
test_targets = np.array([data.y.item() for data in test_dataset])
all_targets = np.concatenate([train_targets, test_targets])

# Compute statistics
train_mean = train_targets.mean()
train_std = train_targets.std()
train_var = train_targets.var()
train_min = train_targets.min()
train_max = train_targets.max()

# Baseline MSE: error of always predicting the training mean, measured on the
# held-out set (this is what the validation MSE is compared against later).
baseline_mse = np.mean((test_targets - train_mean) ** 2)
baseline_rmse = np.sqrt(baseline_mse)
normalized_rmse = baseline_rmse / train_mean  # RMSE as a fraction of the mean

print("=" * 60)
print("BASELINE PERFORMANCE")
print("=" * 60)
print("Training set statistics:")
print(f"  Mean distance:     {train_mean:,.2f}")
print(f"  Std deviation:     {train_std:,.2f}")
print(f"  Min distance:      {train_min:,.2f}")
print(f"  Max distance:      {train_max:,.2f}")
print(f"  Variance:          {train_var:,.2f}")
print("\nBaseline (always predict training mean, scored on test set):")
print(f"  MSE:               {baseline_mse:,.2f}")
print(f"  RMSE:              {baseline_rmse:,.2f}")
print(f"  Normalized RMSE:   {normalized_rmse:.2%} of mean")
print(f"\nTest set mean:       {test_targets.mean():,.2f}")
print(f"Test set std:        {test_targets.std():,.2f}")
print("=" * 60)
print(f"\n⚠️  Goal: beat the baseline by achieving val MSE < {baseline_mse:.2e}")
print("=" * 60)

In [8]:
# ===== STEP 2: Normalize Targets =====
# Standardize targets to mean=0, std=1 using TRAINING statistics only; this keeps
# the loss on an O(1) scale and makes optimization much easier.
# The Data objects are modified in place, so blindly re-running this cell would
# normalize already-normalized targets a second time. The guard below only
# normalizes while the training targets still have their raw mean (mean_y).

# Compute normalization parameters from training set ONLY
mean_y = train_mean
std_y = train_std

print(f"\nTarget normalization parameters:")
print(f"  mean_y = {mean_y:,.2f}")
print(f"  std_y  = {std_y:,.2f}")

current_train_mean = np.mean([data.y.item() for data in train_dataset])
if np.isclose(current_train_mean, mean_y):
    # Normalize all targets in both train and test datasets
    for data in train_dataset:
        data.y = (data.y - mean_y) / std_y

    for data in test_dataset:
        data.y = (data.y - mean_y) / std_y
else:
    print("\n⚠️  Targets already normalized (train mean != mean_y); skipping.")

# Verify normalization
train_y_norm = np.array([data.y.item() for data in train_dataset])
test_y_norm = np.array([data.y.item() for data in test_dataset])

print(f"\nAfter normalization:")
print(f"  Train: mean={train_y_norm.mean():.4f}, std={train_y_norm.std():.4f}")
print(f"  Test:  mean={test_y_norm.mean():.4f}, std={test_y_norm.std():.4f}")
print(f"\n✓ Targets normalized! Loss scale should now be ~1.0")
print("  (Remember to denormalize predictions: y_pred * std_y + mean_y)")


In [9]:
# ===== CHANGE 2: Normalize Edge Attributes =====
# Standardize each edge-feature column to mean 0 / std 1 with statistics taken
# from the TRAINING graphs only, then apply the identical map to the test graphs.
# NOTE: the Data objects are modified in place, so this cell is NOT idempotent —
# re-run from the graph-building cell rather than re-running it on its own.

# Per-column statistics over every training edge (rows = edges)
all_edge_attrs = torch.cat([g.edge_attr for g in train_dataset], dim=0)
edge_mean = all_edge_attrs.mean(dim=0)  # [3]
edge_std = all_edge_attrs.std(dim=0)    # [3]

print("=" * 60)
print("EDGE ATTRIBUTE NORMALIZATION")
print("=" * 60)
print(f"Edge attribute dimensions: [distance, sequence_count, assignment]")
print(f"  Mean: {edge_mean}")
print(f"  Std:  {edge_std}")

# Same affine map for every graph, training graphs first, then test graphs.
# The epsilon avoids dividing by zero for a constant column.
for data in [*train_dataset, *test_dataset]:
    data.edge_attr = (data.edge_attr - edge_mean) / (edge_std + 1e-8)

# Verify normalization on the training edges
all_edge_attrs_norm = torch.cat([g.edge_attr for g in train_dataset], dim=0)
print(f"\nAfter normalization:")
print(f"  Mean: {all_edge_attrs_norm.mean(dim=0)}")
print(f"  Std:  {all_edge_attrs_norm.std(dim=0)}")
print("=" * 60)
print("✓ Edge attributes normalized!")

In [10]:
# Peek at the first training batch to see how PyG collates several graphs.
# NOTE: `step` and `data` stay defined after this cell; a later cell reuses `data`.
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:\n=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data, end='\n\n')
    break

In [None]:
from torch import nn
from torch_geometric.nn import MessagePassing, global_add_pool

In [None]:
class EdgeThenNodeLayer(MessagePassing):
    """Message-passing layer that refreshes edge features first, then node features.

    1. Each edge ``(src, dst)`` gets a new feature vector from ``edge_mlp`` applied
       to ``[x[src], x[dst], edge_attr]``.
    2. Each node receives ``node_mlp([x_j, new_edge_attr])`` per edge and the
       messages are summed (``aggr="add"``). The node's previous state is not
       added back (no residual / self term).

    Returns ``(new_x, new_edge_attr)`` with widths ``node_dim`` and ``edge_dim``.
    """

    def __init__(self, node_dim, edge_dim):
        super().__init__(aggr="add")
        # edge_mlp is registered before node_mlp on purpose: that order fixes the
        # RNG draws used for weight init and the order of parameters().
        edge_in = 2 * node_dim + edge_dim
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in, edge_dim), nn.ReLU(), nn.Linear(edge_dim, edge_dim)
        )
        node_in = node_dim + edge_dim
        self.node_mlp = nn.Sequential(
            nn.Linear(node_in, node_dim), nn.ReLU(), nn.Linear(node_dim, node_dim)
        )

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index
        edge_inputs = torch.cat([x[src], x[dst], edge_attr], dim=1)
        new_edge_attr = self.edge_mlp(edge_inputs)
        # Messages are built from the refreshed edge features.
        new_x = self.propagate(edge_index, x=x, edge_attr=new_edge_attr)
        return new_x, new_edge_attr

    def message(self, x_j, edge_attr):
        # x_j: per-edge neighbour features (PyG's default flow uses edge_index[0]).
        return self.node_mlp(torch.cat([x_j, edge_attr], dim=1))

In [13]:
class NodeThenEdgeLayer(MessagePassing):
    """Message-passing layer that refreshes node features first, then edge features.

    1. Each node receives ``node_mlp([x_j, edge_attr])`` per edge (using the
       incoming ``edge_attr``) and the messages are summed (``aggr="add"``).
    2. Each edge ``(src, dst)`` gets ``edge_mlp([new_x[src], new_x[dst], edge_attr])``,
       i.e. it sees the freshly updated node features.

    Returns ``(new_x, new_edge_attr)`` with widths ``node_dim`` and ``edge_dim``.
    """

    def __init__(self, node_dim, edge_dim):
        super().__init__(aggr="add")
        # node_mlp is registered before edge_mlp on purpose: that order fixes the
        # RNG draws used for weight init and the order of parameters().
        node_in = node_dim + edge_dim
        self.node_mlp = nn.Sequential(
            nn.Linear(node_in, node_dim), nn.ReLU(), nn.Linear(node_dim, node_dim)
        )
        edge_in = 2 * node_dim + edge_dim
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in, edge_dim), nn.ReLU(), nn.Linear(edge_dim, edge_dim)
        )

    def forward(self, x, edge_index, edge_attr):
        new_x = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        src, dst = edge_index
        edge_inputs = torch.cat([new_x[src], new_x[dst], edge_attr], dim=1)
        new_edge_attr = self.edge_mlp(edge_inputs)
        return new_x, new_edge_attr

    def message(self, x_j, edge_attr):
        # x_j: per-edge neighbour features (PyG's default flow uses edge_index[0]).
        return self.node_mlp(torch.cat([x_j, edge_attr], dim=1))

In [None]:
class GCNBlock(nn.Module):
    """One edge->node pass followed by one node->edge pass.

    Both sub-layers take and return ``(x, edge_attr)``, so the pair is threaded
    straight through them in order.
    """

    def __init__(self, node_dim, edge_dim):
        super().__init__()
        # Construction order matters for RNG-based weight init and parameters() order.
        self.edge_then_node = EdgeThenNodeLayer(node_dim, edge_dim)
        self.node_then_edge = NodeThenEdgeLayer(node_dim, edge_dim)

    def forward(self, x, edge_index, edge_attr):
        for sublayer in (self.edge_then_node, self.node_then_edge):
            x, edge_attr = sublayer(x, edge_index, edge_attr)
        return x, edge_attr

In [None]:
# ===== CHANGE 3: Fix GraphRegressionModel =====
# No batch-dependent node embedding table. Nodes start from small Gaussian noise
# while training, and from zeros (the noise's mean) in eval mode, so inference is
# deterministic — the same graph always gets the same prediction.

class GraphRegressionModel(nn.Module):
    """Graph-level regressor: edge encoder -> GCN blocks -> sum pooling -> MLP head.

    Args:
        hidden_dim: width of node features and encoded edge features in the GCN blocks.
        edge_dim: number of columns in the raw ``data.edge_attr``.
        num_layers: number of ``GCNBlock`` rounds of message passing.
        init_noise_std: std of the Gaussian initial node features used in training
            mode. In eval mode nodes start at zero instead (the expectation of that
            noise), analogous to dropout being disabled at inference.

    ``forward(data)`` returns one scalar per graph in the batch.
    """

    def __init__(self, hidden_dim, edge_dim, num_layers, init_noise_std=0.01):  # REMOVED num_nodes
        super().__init__()

        self.hidden_dim = hidden_dim  # width of the initial node features built in forward()
        # Plain float (not a buffer/parameter) so state_dict keys are unchanged and
        # previously saved checkpoints still load.
        self.init_noise_std = init_noise_std
        self.edge_encoder = nn.Linear(edge_dim, hidden_dim)

        self.layers = nn.ModuleList(
            [GCNBlock(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )

        self.regressor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def _initial_node_features(self, num_nodes, like):
        """Starting node features: scaled noise in training mode, zeros in eval mode."""
        shape = (num_nodes, self.hidden_dim)
        if self.training:
            return torch.randn(shape, device=like.device, dtype=like.dtype) * self.init_noise_std
        return torch.zeros(shape, device=like.device, dtype=like.dtype)

    def forward(self, data):
        edge_attr = self.edge_encoder(data.edge_attr)
        x = self._initial_node_features(data.num_nodes, edge_attr)

        for layer in self.layers:
            x, edge_attr = layer(x, data.edge_index, edge_attr)

        graph_emb = global_add_pool(x, data.batch)
        out = self.regressor(graph_emb)
        return out.squeeze(-1)

In [None]:
# Optionally restore artifacts from a PREVIOUS run.
# `final_model.pt` is only written near the end of this notebook and nothing here
# writes `test_dataset.pt`, so on a fresh kernel these files may not exist; skip
# instead of crashing Restart & Run All.
# NOTE: the model-instantiation cell below re-creates `model`, discarding these weights.
from pathlib import Path

import torch

CHECKPOINT_PATH = Path("final_model.pt")
TEST_DATASET_PATH = Path("test_dataset.pt")

if CHECKPOINT_PATH.exists():
    # weights_only=False unpickles arbitrary objects (needed: the dict also holds
    # NumPy scalars such as mean_y). Only load checkpoints you produced yourself.
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only runtime.
    checkpoint = torch.load(CHECKPOINT_PATH, weights_only=False, map_location="cpu")

    # Recreate the architecture used in training, then load the weights into it.
    model = GraphRegressionModel(
        hidden_dim=64,
        edge_dim=3,
        num_layers=5
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Normalization parameters needed to denormalize predictions
    mean_y = checkpoint["mean_y"]
    std_y = checkpoint["std_y"]
else:
    print(f"No checkpoint found at {CHECKPOINT_PATH}; skipping restore.")

if TEST_DATASET_PATH.exists():
    test_data = torch.load(TEST_DATASET_PATH, weights_only=False, map_location="cpu")
else:
    print(f"No saved test set found at {TEST_DATASET_PATH}; skipping.")

In [16]:
# ===== CHANGE 4: Update Model Instantiation =====
# 3-column edge attributes in, no per-node embedding table, five GCN blocks deep.

model = GraphRegressionModel(hidden_dim=64, edge_dim=3, num_layers=5)
print(model)
print(
    "\n".join(
        (
            "\n✓ Model updated:",
            "  hidden_dim = 64",
            "  edge_dim = 3 (3D edge attributes)",
            "  num_layers = 5 (increased for better information flow)",
            "  Node features: small random initialization (0.01 * randn)",
        )
    )
)


In [17]:
import torch

# Report whether a CUDA device is visible and, if so, which one.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("No GPU")


In [18]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# NOTE(review): `data` is the batch left over from the batch-inspection loop
# (hidden state); `inputs` / `targets` are not read by any later visible cell.
inputs, targets = data.to(device), data.y.to(device)

In [19]:
# ===== STEPS 4-5: Updated Training with Diagnostics =====
import torch
import matplotlib.pyplot as plt

# Adam at lr=1e-4 (lowered from 1e-3, which made training explode) on an MSE loss
# over normalized targets; halve the LR after 10 epochs without val improvement.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=10
)

# Histories filled by train() and the epoch loop below
train_batch_losses, train_epoch_losses, val_epoch_losses = [], [], []
grad_norms, learning_rates = [], []


def compute_gradient_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped; with no gradients at all
    the result is ``0.0``. The per-parameter norms are combined as
    ``sqrt(sum(norm_i ** 2))``, i.e. the norm of all gradients concatenated.
    """
    sum_of_squares = 0.0
    for grad in (p.grad for p in model.parameters() if p.grad is not None):
        sum_of_squares += grad.data.norm(2).item() ** 2
    return sum_of_squares ** 0.5


def train():
    """Run one epoch over ``train_loader``, updating the global ``model`` in place.

    Per batch: forward, MSE loss, backward, clip the global grad norm to 1.0,
    optimizer step. Every batch loss is appended to ``train_batch_losses``; the
    epoch's mean loss (weighted by ``y.size(0)`` per batch) goes to ``train_epoch_losses``.
    """
    model.train()
    weighted_loss_sum = 0.0
    seen = 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        loss = criterion(model(batch), batch.y)
        loss.backward()
        # Gradient clipping keeps each update bounded
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_value = loss.item()
        count = batch.y.size(0)
        train_batch_losses.append(loss_value)
        weighted_loss_sum += loss_value * count
        seen += count

    train_epoch_losses.append(weighted_loss_sum / seen)


def test(loader):
    """Return the mean squared error of the global ``model`` over ``loader``.

    Runs in eval mode without gradients. Squared errors are summed over all targets
    and divided by ``len(loader.dataset)`` (assumes one target per dataset item —
    TODO confirm).
    """
    model.eval()
    with torch.no_grad():
        squared_error = 0.0
        for batch in loader:
            batch = batch.to(device)
            squared_error += (model(batch) - batch.y).pow(2).sum().item()
    return squared_error / len(loader.dataset)


def log_diagnostics(epoch, train_mse, val_mse, grad_norm, lr):
    """Print a diagnostics report on epoch 1 and every 10th epoch; no-op otherwise.

    ``train_mse`` / ``val_mse`` are in normalized-target units; they are also shown
    in original units (multiplied by ``std_y ** 2``) next to the global
    ``baseline_mse``. Warns when ``grad_norm`` is below 1e-6 or above 100.
    """
    if epoch != 1 and epoch % 10 != 0:
        return

    # MSE scales with the variance, so undo standardization with std_y squared.
    variance_scale = std_y ** 2
    train_mse_denorm = train_mse * variance_scale
    val_mse_denorm = val_mse * variance_scale
    improvement_pct = (1 - val_mse_denorm/baseline_mse)*100

    report = (
        f"\nEpoch {epoch:03d}:",
        f"  Train MSE (norm):   {train_mse:.6f}",
        f"  Val MSE (norm):     {val_mse:.6f}",
        f"  Train MSE (orig):   {train_mse_denorm:,.2f}",
        f"  Val MSE (orig):     {val_mse_denorm:,.2f}",
        f"  Baseline MSE:       {baseline_mse:,.2f}",
        f"  Improvement:        {improvement_pct:.2f}%",
        f"  Grad norm:          {grad_norm:.6f}",
        f"  Learning rate:      {lr:.6f}",
    )
    for line in report:
        print(line)

    # Flag vanishing / exploding gradients
    if grad_norm < 1e-6:
        print(f"  ⚠️  WARNING: Gradients very small ({grad_norm:.2e})")
    elif grad_norm > 100:
        print(f"  ⚠️  WARNING: Gradients large ({grad_norm:.2e})")


best_val_mse = float("inf")
best_epoch = 0

NUM_EPOCHS = 200

# The gradient-norm diagnostic is measured on ONE fixed training batch, so the
# curve is comparable from epoch to epoch (drawing a freshly shuffled batch each
# epoch adds sampling noise and builds a new loader iterator every time).
probe_batch = next(iter(train_loader)).to(device)

print("=" * 60)
print("TRAINING START - ATTEMPT 2")
print("=" * 60)
print(f"Learning rate: {optimizer.param_groups[0]['lr']} (LOWERED to prevent explosion)")
print("Gradient clipping: max_norm=1.0")
print(f"Num layers: {len(model.layers)} (increased for better message passing)")
print("Node init: small random (0.01 * randn)")
print(f"Target normalization: mean={mean_y:.2f}, std={std_y:.2f}")
print(f"Baseline MSE to beat: {baseline_mse:,.2f}")
print(f"Epochs: {NUM_EPOCHS}")
print("=" * 60)

for epoch in range(1, NUM_EPOCHS + 1):
    train()

    # Gradient norm of the current weights on the probe batch: one extra
    # forward/backward in train mode, with no optimizer step.
    model.train()
    optimizer.zero_grad()
    probe_loss = criterion(model(probe_batch), probe_batch.y)
    probe_loss.backward()
    grad_norm = compute_gradient_norm(model)
    grad_norms.append(grad_norm)
    optimizer.zero_grad()  # don't leave probe gradients behind on the parameters

    # Evaluate (MSE in normalized-target units)
    train_mse = test(train_loader)
    val_mse = test(test_loader)
    val_epoch_losses.append(val_mse)

    # Get current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

    # Log diagnostics (prints only on epoch 1 and every 10th epoch)
    log_diagnostics(epoch, train_mse, val_mse, grad_norm, current_lr)

    # ReduceLROnPlateau watches the validation MSE
    scheduler.step(val_mse)

    # Keep the checkpoint with the lowest validation MSE seen so far
    if val_mse < best_val_mse:
        best_val_mse = val_mse
        best_epoch = epoch
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_mse": val_mse,
                "mean_y": mean_y,
                "std_y": std_y,
            },
            "best_model.pt"
        )

print("\n" + "=" * 60)
print("TRAINING COMPLETE")
print("=" * 60)
print(f"Best epoch: {best_epoch}")
print(f"Best val MSE (normalized): {best_val_mse:.6f}")
print(f"Best val MSE (original):   {best_val_mse * std_y**2:,.2f}")
print(f"Baseline MSE:              {baseline_mse:,.2f}")
print(f"Improvement over baseline: {(1 - (best_val_mse * std_y**2)/baseline_mse)*100:.2f}%")
print("=" * 60)

In [20]:
# ===== STEP 6: Enhanced Visualizations =====
import matplotlib.pyplot as plt
import numpy as np

# 2x2 dashboard: epoch losses, batch losses, gradient norms, learning rate.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1. Training and Validation Loss (normalized units; predicting the mean scores ~1.0)
ax1 = axes[0, 0]
epochs_range = range(1, len(train_epoch_losses) + 1)
ax1.plot(epochs_range, train_epoch_losses, label='Train Loss', marker='o', alpha=0.7)
ax1.plot(epochs_range, val_epoch_losses, label='Val Loss', marker='s', alpha=0.7)
ax1.axhline(y=1.0, color='r', linestyle='--', label='Normalized Baseline (variance=1.0)', alpha=0.5)
ax1.set(xlabel='Epoch', ylabel='MSE (Normalized)', title='Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Batch Loss (smoothed)
ax2 = axes[0, 1]
def moving_average(values, window):
    """Trailing moving average of ``values`` over at most ``window`` points.

    Element ``i`` of the result is the mean of
    ``values[max(0, i - window + 1) : i + 1]``. The output therefore has the same
    length as ``values``, and the first ``window - 1`` entries average fewer points.
    (The previous slice bound included ``window + 1`` values per average.)

    Args:
        values: indexable sequence of numbers (e.g. a list of losses).
        window: number of trailing points to average; must be >= 1.

    Returns:
        List of floats, one per input value.

    Raises:
        ValueError: if ``window`` < 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    return [
        sum(values[max(0, i - window + 1):i + 1]) / min(i + 1, window)
        for i in range(len(values))
    ]
smoothed_batch_losses = moving_average(train_batch_losses, window=50)
ax2.plot(train_batch_losses, alpha=0.2, label='Batch Loss', color='blue')
ax2.plot(smoothed_batch_losses, label='Smoothed Batch Loss', color='orange', linewidth=2)
ax2.set(xlabel='Training Step', ylabel='MSE (Normalized)', title='Batch-Level Training Loss')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Gradient Norms (log scale)
ax3 = axes[1, 0]
ax3.plot(epochs_range, grad_norms, marker='o', color='green')
ax3.axhline(y=1.0, color='r', linestyle='--', alpha=0.3, label='Target norm (~1.0)')
ax3.set(xlabel='Epoch', ylabel='Gradient Norm', title='Gradient Norm Over Time')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_yscale('log')

# 4. Learning Rate Schedule (log scale)
ax4 = axes[1, 1]
ax4.plot(epochs_range, learning_rates, marker='o', color='purple')
ax4.set(xlabel='Epoch', ylabel='Learning Rate', title='Learning Rate Schedule')
ax4.grid(True, alpha=0.3)
ax4.set_yscale('log')

plt.tight_layout()
plt.savefig('training_diagnostics.png', dpi=150, bbox_inches='tight')
plt.show()

# Text summary of the same histories
print(
    "\n".join(
        [
            "\n" + "=" * 60,
            "TRAINING SUMMARY",
            "=" * 60,
            "\nLoss progression:",
            f"  Initial train loss: {train_epoch_losses[0]:.6f}",
            f"  Final train loss:   {train_epoch_losses[-1]:.6f}",
            f"  Initial val loss:   {val_epoch_losses[0]:.6f}",
            f"  Final val loss:     {val_epoch_losses[-1]:.6f}",
            "\nLoss reduction:",
            f"  Train: {(1 - train_epoch_losses[-1]/train_epoch_losses[0])*100:.2f}%",
            f"  Val:   {(1 - val_epoch_losses[-1]/val_epoch_losses[0])*100:.2f}%",
            "\nGradient norms:",
            f"  Mean: {np.mean(grad_norms):.6f}",
            f"  Min:  {np.min(grad_norms):.6f}",
            f"  Max:  {np.max(grad_norms):.6f}",
            "\nLearning rate:",
            f"  Initial: {learning_rates[0]:.6f}",
            f"  Final:   {learning_rates[-1]:.6f}",
            "=" * 60,
        ]
    )
)

In [21]:
# Zoom in on training after the initial transient. Batch steps and epochs are
# different units, so each gets its own skip count. The x-axes keep the true
# step / epoch numbers instead of restarting at 0.
SKIP_STEPS = 50    # batch steps hidden from the first plot
SKIP_EPOCHS = 50   # epochs hidden from the second plot

later_batch_losses = train_batch_losses[SKIP_STEPS:]
smoothed_batch_losses = moving_average(later_batch_losses, window=50)
step_axis = range(SKIP_STEPS, SKIP_STEPS + len(later_batch_losses))  # steps are 0-based
plt.figure()
plt.plot(step_axis, later_batch_losses, alpha=0.3, label="Batch loss")
plt.plot(step_axis, smoothed_batch_losses, label="Smoothed batch loss")
plt.xlabel("Training step")
plt.ylabel("MSE (normalized)")
plt.title(f"Batch-level training loss (from step {SKIP_STEPS})")
plt.legend()
plt.show()

later_epoch_losses = train_epoch_losses[SKIP_EPOCHS:]
epoch_axis = range(SKIP_EPOCHS + 1, SKIP_EPOCHS + 1 + len(later_epoch_losses))  # epochs are 1-based
plt.figure()
plt.plot(epoch_axis, later_epoch_losses, marker="o")
plt.xlabel("Epoch")
plt.ylabel("Train MSE (normalized)")
plt.title("Epoch level training loss")
plt.show()

In [22]:
# Persist the weights as they are after the LAST epoch (the best-validation weights
# were already written to best_model.pt during training). `epoch` and `val_mse` are
# the values left over from the final loop iteration.
torch.save(
    dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        val_mse=val_mse,
        mean_y=mean_y,
        std_y=std_y,
    ),
    "final_model.pt",
)

In [24]:
import os

# Working directory: relative paths such as "final_model.pt" resolve against it.
os.path.abspath(os.curdir)

In [25]:
# List the working directory (os.listdir defaults to ".") to confirm the artifacts exist.
os.listdir()

In [28]:
# Mount Google Drive at /content/drive (Colab only) so the repository and the
# training artifacts can be copied there by the next cell.
from google.colab import drive
drive.mount("/content/drive")


In [29]:
# Back up the cloned repository and the training artifacts to Google Drive
# (requires the Drive mount above). final_model.pt, best_model.pt and
# training_diagnostics.png are written by earlier cells of this notebook.
!cp -r /content/gnn_slotting_optimization /content/drive/MyDrive/
!cp final_model.pt best_model.pt training_diagnostics.png /content/drive/MyDrive/


In [30]:
# NOTE(review): duplicate of the earlier Drive-mount cell. If that cell already ran,
# Drive is mounted at this point and this cell looks redundant — consider deleting it.
from google.colab import drive
drive.mount("/content/drive")

