In [3]:
import torch
from torch.nn import Linear, LSTM
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch.nn.functional as F
import os

class LSTMAggregator(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(LSTMAggregator, self).__init__()
        self.lstm = LSTM(in_channels, hidden_channels, batch_first=True)

    def forward(self, x, edge_index):
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        
        x_j = x[col]
        x_j = x_j.view(-1, 1, x.size(-1))  # Reshape for LSTM
        x_j, _ = self.lstm(x_j)
        x_j = x_j.squeeze(1)
        
        out = x.new_zeros(x.size(0), x.size(1))
        out.index_add_(0, row, x_j)
        
        out = out / deg.view(-1, 1)
        
        return out

class CustomSAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels, aggr='add'):
        super(CustomSAGEConv, self).__init__(aggr=aggr)
        self.lstm_agg = LSTMAggregator(in_channels, out_channels)
        self.lin = Linear(in_channels, out_channels)
        self.update_lin = Linear(in_channels + out_channels, out_channels)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        x = self.propagate(edge_index, x=x)
        return x

    def message(self, x_j, edge_index, size):
        return x_j

    def update(self, aggr_out, x):
        return F.relu(self.update_lin(torch.cat([x, aggr_out], dim=1)))

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        return self.lstm_agg(inputs, index)

class EGraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = CustomSAGEConv(in_channels, hidden_channels)
        self.conv2 = CustomSAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.lin(x)
        return F.log_softmax(x, dim=-1)


In [2]:
# Main process
output_dir = 'output'
visualization_dir = 'visualizations'
os.makedirs(output_dir, exist_ok=True)
os.makedirs(visualization_dir, exist_ok=True)

# Uncomment the following line only for the initial preprocessing step
# pyg_data_list = preprocess_data(data, output_dir=output_dir, visualization_dir=visualization_dir)

# Load saved graphs
# pyg_data_list = load_saved_graphs(output_dir)

train_size = int(0.7 * len(pyg_data_list))
train_data = pyg_data_list[:train_size]
test_data = pyg_data_list[train_size:]

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

# EGraphSAGE Model
model = EGraphSAGE(in_channels=1, hidden_channels=128, out_channels=2)  # in_channels=1 for PageRank only
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# # Training EGraphSAGE Model
# print("Training EGraphSAGE Model...")
# train_model(model, train_loader, optimizer, criterion, output_dir=output_dir)

# # Evaluation EGraphSAGE Model
# print("Evaluating EGraphSAGE Model...")
# cm, report, accuracy = evaluate_model(model, test_loader, "EGraphSAGE", output_dir=output_dir)
# print('Confusion Matrix (EGraphSAGE):\n', cm)
# print('Classification Report (EGraphSAGE):\n', report)

# # Save the EGraphSAGE model
# save_model(model, os.path.join(output_dir, 'e_graphsage_model.pth'))

# # Store EGraphSAGE accuracy in the DataFrame
# accuracy_df.loc['EGraphSAGE', 'e_graphsage_model.pth'] = accuracy

# # Save the accuracy DataFrame to a CSV file
# accuracy_df.to_csv(os.path.join(output_dir, 'accuracy_results.csv'))


NameError: name 'os' is not defined