In [8]:
import torch
import torch.nn as nn
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch_geometric.loader import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, GATv2Conv
import numpy as np


In [9]:
class GAT(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GAT, self).__init__()
        self.conv1 = GATv2Conv(input_size, hidden_size, edge_dim=1)
        self.conv2 = GATv2Conv(hidden_size, hidden_size, edge_dim=1)
        self.lin = nn.Linear(hidden_size, output_size)

    def forward(self, x, edge_index, edge_weight, batch):
        x = self.conv1(x, edge_index, edge_attr=edge_weight)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr=edge_weight)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return x

In [10]:
input_size = 7
hidden_size = 64
output_size = 7
epochs = 100
learning_rate = 1e-2
graph_name = "cb_nabil_wt"
batch_size = 32


In [11]:
graphs = torch.load(f"../graphs/{graph_name}.pt", weights_only=False)
train, val = train_test_split(graphs, test_size=0.2, random_state=12)
train_graphs = DataLoader(train, batch_size=batch_size, shuffle=True)
val_graphs = DataLoader(val, batch_size=batch_size, shuffle=False)


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GAT(input_size, hidden_size, output_size)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
ls_fn = nn.MSELoss()

model.train()
for epoch in range(epochs):
    total_loss = 0
    for batch in train_graphs:
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch.x, batch.edge_index, batch.edge_weight, batch.batch)
        output = output.view(-1)
        loss = ls_fn(output, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss:.4f}")

model.eval()
predictions = []
ground_truth = []
for batch in val_graphs:
    batch = batch.to(device)
    with torch.no_grad():
        output = model(batch.x, batch.edge_index, batch.edge_weight, batch.batch)
        output = output.view(-1)
        predictions.append(output.cpu().numpy())
        ground_truth.append(batch.y.cpu().numpy())

predictions = np.concatenate(predictions, axis=0)
ground_truth = np.concatenate(ground_truth, axis=0)

mse = mean_squared_error(ground_truth, predictions)
mae = mean_absolute_error(ground_truth, predictions)

print(f"Mean Squared Error (MSE): {mse:.4f}")
print(f"Mean Absolute Error (MAE): {mae:.4f}")



Epoch 1/100, Loss: 9.6065
Epoch 2/100, Loss: 6.4011
Epoch 3/100, Loss: 6.2938
Epoch 4/100, Loss: 6.0352
Epoch 5/100, Loss: 5.6747
Epoch 6/100, Loss: 5.5900
Epoch 7/100, Loss: 5.7147
Epoch 8/100, Loss: 5.5522
Epoch 9/100, Loss: 5.6075
Epoch 10/100, Loss: 5.4581
Epoch 11/100, Loss: 5.4223
Epoch 12/100, Loss: 5.7962
Epoch 13/100, Loss: 5.7236
Epoch 14/100, Loss: 5.4017
Epoch 15/100, Loss: 5.4582
Epoch 16/100, Loss: 5.6315
Epoch 17/100, Loss: 5.3471
Epoch 18/100, Loss: 5.8085
Epoch 19/100, Loss: 5.3579
Epoch 20/100, Loss: 5.3299
Epoch 21/100, Loss: 5.2867
Epoch 22/100, Loss: 5.2651
Epoch 23/100, Loss: 5.5022
Epoch 24/100, Loss: 5.5573
Epoch 25/100, Loss: 5.2115
Epoch 26/100, Loss: 5.1041
Epoch 27/100, Loss: 5.2728
Epoch 28/100, Loss: 5.2225
Epoch 29/100, Loss: 5.2670
Epoch 30/100, Loss: 5.3925
Epoch 31/100, Loss: 5.1118
Epoch 32/100, Loss: 5.1144
Epoch 33/100, Loss: 5.2966
Epoch 34/100, Loss: 5.1429
Epoch 35/100, Loss: 5.1481
Epoch 36/100, Loss: 5.0245
Epoch 37/100, Loss: 5.1698
Epoch 38/1

In [14]:
model_path = "../models/GAT_nabil.pth"
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Model saved to ../models/GAT_nabil.pth


In [7]:
# # Training loop with validation and early stopping
# best_val_loss = float('inf')
# patience = 10
# counter = 0

# for epoch in range(epochs):
#     # Training phase
#     model.train()
#     total_loss = 0
#     for batch in train_graphs:
#         batch = batch.to(device)
#         optimizer.zero_grad()
#         output = model(batch.x, batch.edge_index, batch.batch)
#         output = output.view(-1)
#         loss = ls_fn(output, batch.y)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()

#     # Print training loss after the 400th epoch
#     if epoch >= 400:
#         print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {total_loss:.4f}")

#     # Validation phase
#     model.eval()
#     val_loss = 0
#     predictions = []
#     ground_truth = []
#     for batch in val_graphs:
#         batch = batch.to(device)
#         with torch.no_grad():
#             output = model(batch.x, batch.edge_index, batch.batch)
#             output = output.view(-1)
#             val_loss += ls_fn(output, batch.y).item()
#             predictions.append(output.cpu().numpy())
#             ground_truth.append(batch.y.cpu().numpy())

#     val_loss /= len(val_graphs)
#     predictions = np.concatenate(predictions, axis=0)
#     ground_truth = np.concatenate(ground_truth, axis=0)

#     val_mse = mean_squared_error(ground_truth, predictions)
#     val_mae = mean_absolute_error(ground_truth, predictions)

#     print(f"Validation Loss: {val_loss:.4f}, Validation MSE: {val_mse:.4f}, Validation MAE: {val_mae:.4f}")

#     # Learning rate scheduling
#     scheduler.step(val_loss)

#     # Early stopping logic
#     if val_loss < best_val_loss:
#         best_val_loss = val_loss
#         counter = 0
#         torch.save(model.state_dict(), "gcn.pth")  # Save the best model
#     else:
#         counter += 1
#         if counter >= patience:
#             print("Early stopping triggered.")
#             break

# # Load the best model for testing
# model.load_state_dict(torch.load("gcn.pth"))
# model.to(device)
# # Test phase
# model.eval()
# test_predictions = []
# test_ground_truth = []
# test_loss = 0
# for batch in test_graphs:
#     batch = batch.to(device)
#     with torch.no_grad():
#         output = model(batch.x, batch.edge_index, batch.batch)
#         output = output.view(-1)
#         test_loss += ls_fn(output, batch.y).item()
#         test_predictions.append(output.cpu().numpy())
#         test_ground_truth.append(batch.y.cpu().numpy())

# test_loss /= len(test_graphs)
# test_predictions = np.concatenate(test_predictions, axis=0)
# test_ground_truth = np.concatenate(test_ground_truth, axis=0)

# test_mse = mean_squared_error(test_ground_truth, test_predictions)
# test_mae = mean_absolute_error(test_ground_truth, test_predictions)

# print(f"Test Loss: {test_loss:.4f}, Test MSE: {test_mse:.4f}, Test MAE: {test_mae:.4f}")