In [1]:
# Import necessary libraries
import torch
from graph_classification import *
from node_classification import *
from torch_geometric.datasets import TUDataset, Planetoid
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

In [2]:
# Graph Classification Example
def run_graph_classification(num_epochs=100, log_every=10, seed=42):
    """Train and evaluate a GCN graph classifier on the MUTAG dataset.

    Loads MUTAG via ``TUDataset``, splits it 80/20 into train/test subsets,
    trains ``GCN_graph`` with Adam (lr=0.01) and cross-entropy loss, prints
    train/test accuracy periodically, and finally hands the most recent
    test-set evaluation to ``plot_roc_curve`` and ``plot_confusion_matrix``.

    Args:
        num_epochs: Number of training epochs (default 100).
        log_every: Evaluate and print accuracy every this many epochs
            (default 10). The last epoch is always evaluated as well, so the
            plots are built from the model state after training finishes.
        seed: Value passed to ``torch.manual_seed`` before the split so that
            the split, loader shuffling and weight initialisation are
            repeatable across kernel restarts. Note this seeds torch's
            *global* RNG. Pass ``None`` to leave the RNG untouched.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be a positive integer")
    if seed is not None:
        # random_split, DataLoader shuffling and parameter init all draw from
        # the global torch RNG, so seed it once up front.
        torch.manual_seed(seed)

    dataset = TUDataset(root='./data', name='MUTAG')
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GCN_graph(input_dim=dataset.num_node_features, hidden_dim=64, num_classes=dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # Reference the loss through ``torch.nn`` rather than a bare ``nn`` that is
    # only in scope if one of the wildcard imports happens to re-export it.
    criterion = torch.nn.CrossEntropyLoss()

    # Bound up front so the plotting calls below never hit an unbound name.
    y_true = y_pred = y_probs = None

    for epoch in range(1, num_epochs + 1):
        train(model, train_loader, optimizer, criterion, device)
        if epoch % log_every == 0 or epoch == num_epochs:
            train_acc, _, _, _ = evaluate_with_metrics(model, train_loader, device)
            test_acc, y_true, y_pred, y_probs = evaluate_with_metrics(model, test_loader, device)
            print(f"Epoch {epoch}: Train accuracy: {train_acc:.5f}, Test accuracy: {test_acc:.5f}")

    if y_true is None:
        # No epoch was run (num_epochs < 1): evaluate the untrained model once
        # so there is still something to plot.
        _, y_true, y_pred, y_probs = evaluate_with_metrics(model, test_loader, device)

    plot_roc_curve(y_true, y_probs, dataset.num_classes)
    plot_confusion_matrix(y_true, y_pred, dataset.num_classes)



# Node Classification Example
def run_node_classification(num_epochs=100, log_every=10, seed=42):
    """Train and evaluate a GCN node classifier on the Cora dataset.

    Loads Cora via ``Planetoid`` with ``NormalizeFeatures`` applied, trains
    ``GCN_node`` with Adam (lr=0.01) and cross-entropy loss, periodically
    prints the training loss and the accuracy returned by ``evaluate_node``,
    then passes the output of ``extract_embeddings`` together with the node
    labels to ``visualize_tsne``.

    Args:
        num_epochs: Number of training epochs (default 100).
        log_every: Evaluate and print every this many epochs (default 10).
            The last epoch is always reported as well.
        seed: Value passed to ``torch.manual_seed`` before the model is built
            so that weight initialisation is repeatable across kernel
            restarts. Note this seeds torch's *global* RNG. Pass ``None`` to
            leave the RNG untouched.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be a positive integer")
    if seed is not None:
        torch.manual_seed(seed)

    # Resolve the device once and reuse it for both the data and the model,
    # instead of reading it back from the data object afterwards.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = Planetoid(root='./data', name='Cora', transform=T.NormalizeFeatures())
    data = dataset[0].to(device)

    model = GCN_node(input_dim=dataset.num_node_features, hidden_dim=64, num_classes=dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # Reference the loss through ``torch.nn`` rather than a bare ``nn`` that is
    # only in scope if one of the wildcard imports happens to re-export it.
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        loss = train_node(model, data, optimizer, criterion)
        if epoch % log_every == 0 or epoch == num_epochs:
            test_acc = evaluate_node(model, data)
            print(f"Epoch {epoch}: Loss: {loss:.4f}, Test accuracy: {test_acc:.4f}")

    embeddings = extract_embeddings(model, data)
    visualize_tsne(embeddings, data.y.cpu().numpy())

In [2]:
# Run the graph-classification example defined above: it trains GCN_graph on
# MUTAG, prints train/test accuracy as training progresses, and then calls
# plot_roc_curve and plot_confusion_matrix on the test-set predictions.
run_graph_classification()


In [2]:
# Run the node-classification example defined above: it trains GCN_node on
# Cora, prints the loss and the accuracy from evaluate_node as training
# progresses, and then passes the extracted embeddings to visualize_tsne.
run_node_classification()