In [1]:
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import dgl

from utils.aug import load

# Keep this after the `utils` import and before `arch`: `arch` is presumably
# resolved through the parent directory added here.
sys.path.append('..')

from arch import IOGCN, IOGAT, IOMLP, GCN, GAT

In [2]:
# Prefer the second GPU when more than one is present, then the first GPU,
# then the CPU. (A hard-coded 'cuda:1' fails on single-GPU machines.)
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Seed the RNGs used below (numpy for neigh_sampling's seed nodes, torch for
# randperm splits and weight initialisation) so Restart & Run All reproduces
# the sampled subgraphs and splits.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data

In [None]:
# Load Cora. The train/val/test indices returned by `load` are overwritten by
# the custom split of the output nodes further below.
(graph, feat, labels, num_class,
 train_idx, val_idx, test_idx) = load("cora")

in_dim = feat.shape[1]  # number of input features per node

In [4]:
# --- optimisation ---
lr = 1e-3          # learning rate used for the IO models
wd = 0             # Adam weight decay read by test()
lambd = 1e-3       # NOTE(review): not used by any of the cells below
n_epochs = 500

# --- architecture ---
n_layers = 1
hid_dim = 16
out_dim = num_class
nonlin = nn.ReLU   # activation *class*, not an instance

# --- training loop ---
eval_freq = 20     # test() prints progress every eval_freq epochs when verbose
es_patience = 200  # early-stopping patience in epochs (<= 0 disables it)

# --- subgraph sampling / IO mapping ---
subgraph = "neigh_sampling"  # "random" or "neigh_sampling"
method = "transpose"         # IO mapping; the sweeps below use "linear", "transpose", "common"
N = graph.num_nodes()

In [5]:
# Keyword arguments forwarded to the GCN- and GAT-based layers.
build_params_gcn = dict(norm='both', bias=False)
build_params_gat = dict(num_heads=8, feat_drop=0.2, attn_drop=0.2)

In [6]:
def neigh_sampling(A, K=6):
    """Snowball-sample two node sets by K-hop expansion from random seeds.

    Two seed nodes are drawn uniformly at random (np.random.randint, X seed
    first, then Y seed). Each seed is then expanded K times by adding every
    neighbour of the nodes reached so far.

    Args:
        A: dense (N, N) adjacency matrix, as a numpy array or CPU torch tensor.
           Any non-zero entry A[i, j] is treated as an edge i -> j (so
           weighted or duplicated edges are handled too).
        K: number of hops to expand.

    Returns:
        (nodes_x, nodes_y): sorted lists of Python ints holding the K-hop
        neighbourhoods of the two seeds; each contains its own seed.
    """
    A = np.asarray(A)
    N = A.shape[0]

    def _k_hop(seed):
        visited = {seed}
        frontier = [seed]
        for _ in range(K):
            # Neighbours of previously expanded nodes are already visited, so
            # only the newly reached frontier needs expanding.
            reached = set()
            for node in frontier:
                reached.update(np.flatnonzero(A[node] != 0).tolist())
            frontier = list(reached - visited)
            if not frontier:
                break
            visited.update(frontier)
        return sorted(visited)

    init_node_X = np.random.randint(N)
    init_node_Y = np.random.randint(N)
    return _k_hop(init_node_X), _k_hop(init_node_Y)

In [None]:
# Draw the input node set X and output node set Y, then build their induced
# subgraphs (with self-loops), the input features on X and the labels on Y.
if subgraph == "random":
    Nx = Ny = 512
    # Plain lists (like neigh_sampling returns) so the list-based bookkeeping
    # below and in later cells works for both sampling strategies.
    idxs_X = torch.randperm(N)[:Nx].tolist()
    idxs_Y = torch.randperm(N)[:Ny].tolist()
elif subgraph == "neigh_sampling":
    A_dense = graph.cpu().adj().to_dense()  # computed once, not on every retry
    Nx = Ny = 0
    while Nx < 10 or Ny < 10:  # resample until both sets have >= 10 nodes
        idxs_X, idxs_Y = neigh_sampling(A_dense, K=2)
        Nx = len(idxs_X)
        Ny = len(idxs_Y)
        print(f"{Nx=}, {Ny=} - ", end="")

# Nodes present in both sets, as (positions in idxs_X, positions in idxs_Y).
# Built for both strategies; the dict keeps this linear instead of O(Nx*Ny).
pos_in_y = {int(n): j for j, n in enumerate(idxs_Y)}
common_nodes = [n for n in idxs_X if int(n) in pos_in_y]
idxs_common_x = [i for i, n in enumerate(idxs_X) if int(n) in pos_in_y]
idxs_common_y = [pos_in_y[int(n)] for n in common_nodes]
idxs_common = (idxs_common_x, idxs_common_y)

gx = graph.subgraph(idxs_X).add_self_loop().to(device)
gy = graph.subgraph(idxs_Y).add_self_loop().to(device)
x = feat[idxs_X,:].to(device)
graph = graph.to(device)
labels_y = labels[idxs_Y].to(device)

In [8]:
# Random split of the output nodes (positions 0..Ny-1 in gy): the first
# N_train fraction is train, the next N_val fraction is validation, and the
# rest is test. The cut points are the same as those used for the *_orig
# indices, so both views describe the same split. An empty val_idx would make
# the validation accuracy NaN and disable best-epoch selection in test().
idxs_y_split = torch.randperm(Ny)
N_train = 0.3
N_val = 0.2
train_idx = idxs_y_split[:int(N_train*Ny)]
val_idx = idxs_y_split[int(N_train*Ny):int((N_train+N_val)*Ny)]
test_idx = idxs_y_split[int((N_train+N_val)*Ny):]

In [9]:
# Full-graph input signal for the "fill"/"linear" and "graph" variants.
# NOTE(review): x_gcn is not read by any later cell, and for any other method
# (including the configured "transpose") this cell defines nothing.
if method in ("fill", "linear"):
    # Placeholder: an all-ones signal on every node of the full graph.
    x_gcn = torch.ones((N, x.shape[1]), device=device)
elif method == "graph":
    # Push the features on X to every node through the adjacency columns of X.
    adj = graph.adjacency_matrix().to_dense().to(device)
    x_gcn = adj[:, idxs_X] @ x

In [10]:
def test(model, x, y, gx, gy, graph, lr, train_idx, val_idx, test_idx, model_name="io", es_patience=-1, verbose=True):

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    loss_fn = nn.CrossEntropyLoss()

    losses, acc_train, acc_val, acc_test = [np.zeros(n_epochs) for _ in range(4)]

    best_val_acc = 0
    best_test_acc = 0
    es_count = 0

    for i in range(n_epochs):
        if model_name == "io":
            yhat = model(gx, gy, x)
        else:
            yhat = model(graph, x)
        loss = loss_fn(yhat[train_idx], y[train_idx])

        opt.zero_grad()
        loss.backward()
        opt.step()

        preds = torch.argmax(yhat, 1)
        results = (preds == y).type(torch.float32)
        acc_train[i] = results[train_idx].mean().item()
        acc_val[i] = results[val_idx].mean().item()
        acc_test[i] = results[test_idx].mean().item()

        if acc_val[i] > best_val_acc:
            es_count = 0
            best_val_acc = acc_val[i]
            best_test_acc = acc_test[i]
        else:
            es_count += 1

        if es_patience > 0 and es_count > es_patience:
            break

        losses[i] = loss.item()

        if (i == 0 or (i+1) % eval_freq == 0) and verbose:
            print(f"Epoch {i+1}/{n_epochs} - Loss: {losses[i]} - Train Acc: {acc_train[i]} - Test Acc: {acc_test[i]}", flush=True)

    return losses, acc_train, acc_val, acc_test, best_test_acc

In [None]:
# IOGCN on the single sampled (X, Y) pair; an empty list is passed where the
# sweeps below pass idxs_common.
iogcn = IOGCN(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, [], method, build_params_gcn, nonlin=nonlin)
iogcn = iogcn.to(device)
iogcn_run = test(iogcn, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx)
_, acc_train_iogcn, acc_val_iogcn, acc_test_iogcn, _ = iogcn_run

In [None]:
# IOMLP on the single sampled (X, Y) pair (GCN build params, empty common set).
iomlp = IOMLP(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, [], method, build_params_gcn, nonlin=nonlin)
iomlp = iomlp.to(device)
iomlp_run = test(iomlp, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx)
_, acc_train_iomlp, acc_val_iomlp, acc_test_iomlp, _ = iomlp_run

In [None]:
# IOGAT on the single sampled (X, Y) pair (GAT build params, empty common set).
iogat = IOGAT(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, [], method, build_params_gat, nonlin=nonlin)
iogat = iogat.to(device)
iogat_run = test(iogat, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx)
_, acc_train_iogat, acc_val_iogat, acc_test_iogat, _ = iogat_run

In [None]:
# Per-epoch test accuracy (a fraction in [0, 1]) of the three IO models.
f, ax = plt.subplots(figsize=(12,8))

epochs = np.arange(n_epochs)
for curve in (acc_test_iogcn, acc_test_iomlp, acc_test_iogat):
    ax.plot(epochs, curve)

ax.legend(["IOGCN", "IOMLP", "IOGAT"], fontsize=14)

In [15]:
# Express the positional split of gy as node ids of the full graph, cut at
# the same N_train / N_val fractions.
idxs_orig_reorder = np.array(idxs_Y)[idxs_y_split.numpy()]
train_end = int(N_train*Ny)
val_end = int((N_train+N_val)*Ny)
train_idx_orig, val_idx_orig, test_idx_orig = (
    torch.from_numpy(chunk) for chunk in np.split(idxs_orig_reorder, [train_end, val_end])
)

In [None]:
# "GCN-limited" baseline: a GCN on the full node set whose graph keeps only the
# edges with both endpoints in Y and whose features are zero outside X.
nodes_y = {int(n) for n in idxs_Y}  # set lookup: O(1) per endpoint instead of O(Ny)
start_edges, end_edges = (e.tolist() for e in graph.edges())
edges_to_remove = [
    eid for eid, (u, v) in enumerate(zip(start_edges, end_edges))
    if not (u in nodes_y and v in nodes_y)
]

graph_pruned = graph.clone()
graph_pruned.remove_edges(edges_to_remove)
# gy carries one added self-loop per node, hence the subtraction.
# NOTE(review): assumes `graph` itself contains no self-loops.
assert graph_pruned.number_of_edges() == gy.number_of_edges() - gy.number_of_nodes()

graph_pruned = graph_pruned.add_self_loop().to(device)

x_pruned = torch.zeros(feat.shape)
x_pruned[idxs_X,:] = feat[idxs_X,:]

gcn = GCN(in_dim, hid_dim, out_dim, 2*n_layers, nonlin, build_params_gcn).to(device)
loss_gcn_limited, acc_train_gcn_limited, acc_val_gcn_limited, acc_test_gcn_limited, best_acc_test_gcn_limited = test(gcn, x_pruned.to(device), labels.to(device), gx, gy, graph_pruned, 1e-2, train_idx_orig, val_idx_orig, test_idx_orig, model_name="gcn", es_patience=es_patience, verbose=False)

best_acc_test_gcn_limited

# Exhaustive tests (repeated random subgraphs, all IO mapping methods)

In [None]:
N_SIMS = 10
methods = ["linear", "transpose", "common"]
accs_test_iogcn, accs_test_iomlp, accs_test_iogat, losses_iogcn, losses_iomlp, losses_iogat = [np.zeros((N_SIMS, len(methods), n_epochs)) for _ in range(6)]
accs_test_gcn = np.zeros((N_SIMS, n_epochs))
losses_gcn = np.zeros((N_SIMS, n_epochs))
accs_test_gat = np.zeros((N_SIMS, n_epochs))
losses_gat = np.zeros((N_SIMS, n_epochs))

# Loop-invariant inputs, prepared once instead of on every simulation.
lr_gcn = 1e-2
A_dense = graph.cpu().adj().to_dense()  # dense adjacency for neigh_sampling
graph = graph.to(device)
graph_sl = graph.add_self_loop()        # full graph + self-loops for GCN / GAT
feat_dev = feat.to(device)
labels_dev = labels.to(device)

for sim in range(N_SIMS):
    print(f"Simulation {sim+1} ", end="")

    # 1) Sample the input (X) and output (Y) node sets.
    if subgraph == "random":
        Nx = Ny = 512
        # Plain CPU lists (as neigh_sampling returns) so np.array(idxs_Y) and
        # the list-based bookkeeping below work for both strategies.
        idxs_X = torch.randperm(N)[:Nx].tolist()
        idxs_Y = torch.randperm(N)[:Ny].tolist()
    elif subgraph == "neigh_sampling":
        Nx = Ny = 0
        while Nx < 10 or Ny < 10:  # resample until both sets have >= 10 nodes
            idxs_X, idxs_Y = neigh_sampling(A_dense, K=2)
            Nx = len(idxs_X)
            Ny = len(idxs_Y)
            print(f"{Nx=}, {Ny=} - ", end="")

    # Shared nodes as (positions in idxs_X, positions in idxs_Y); defined for
    # both strategies because the IO models receive idxs_common.
    pos_in_y = {int(n): j for j, n in enumerate(idxs_Y)}
    common_nodes = [n for n in idxs_X if int(n) in pos_in_y]
    idxs_common_x = [i for i, n in enumerate(idxs_X) if int(n) in pos_in_y]
    idxs_common_y = [pos_in_y[int(n)] for n in common_nodes]
    idxs_common = (idxs_common_x, idxs_common_y)

    # 2) Induced subgraphs (with self-loops), inputs on X and targets on Y.
    gx = graph.subgraph(idxs_X).add_self_loop().to(device)
    gy = graph.subgraph(idxs_Y).add_self_loop().to(device)
    x = feat[idxs_X,:].to(device)
    labels_y = labels[idxs_Y].to(device)

    # NOTE(review): Sx / Sy are not read by any later cell. They now follow
    # `device` instead of a hard-coded "cuda" (which failed on CPU-only hosts).
    Sx = gx.adj().to_dense().to(device)
    Sy = gy.adj().to_dense().to(device)

    # 3) Train/val/test split of Y as positions in gy, and the same split as
    #    node ids of the full graph for the GCN / GAT baselines.
    idxs_y_split = torch.randperm(Ny)
    N_train = 0.3
    N_val = 0.2
    train_end = int(N_train*Ny)
    val_end = int((N_train+N_val)*Ny)
    train_idx = idxs_y_split[:train_end]
    val_idx = idxs_y_split[train_end:val_end]
    test_idx = idxs_y_split[val_end:]  # starts after val (it used to overlap val_idx)

    idxs_orig_reorder = np.array(idxs_Y)[idxs_y_split.numpy()]
    train_idx_orig = torch.from_numpy(idxs_orig_reorder[:train_end])
    val_idx_orig = torch.from_numpy(idxs_orig_reorder[train_end:val_end])
    test_idx_orig = torch.from_numpy(idxs_orig_reorder[val_end:])

    # 4) One run of each IO model per IO mapping method.
    for j, m in enumerate(methods):

        iogcn = IOGCN(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, idxs_common, m, build_params_gcn, nonlin=nonlin).to(device)
        iogat = IOGAT(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, idxs_common, m, build_params_gat, nonlin=nonlin).to(device)
        iomlp = IOMLP(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, idxs_common, m, build_params_gcn, nonlin=nonlin).to(device)

        loss_iogcn, acc_train_iogcn, acc_val_iogcn, acc_test_iogcn, _ = test(iogcn, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx, verbose=False)
        loss_iogat, acc_train_iogat, acc_val_iogat, acc_test_iogat, _ = test(iogat, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx, verbose=False)
        loss_iomlp, acc_train_iomlp, acc_val_iomlp, acc_test_iomlp, _ = test(iomlp, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx, verbose=False)

        print(f"{acc_test_iogcn[-1]=:.6f}, {acc_test_iogat[-1]=:.6f}, {loss_iogcn[-1]=:.6f}, {loss_iogat[-1]=:.6f}")

        accs_test_iogcn[sim,j,:] = acc_test_iogcn
        accs_test_iogat[sim,j,:] = acc_test_iogat
        accs_test_iomlp[sim,j,:] = acc_test_iomlp
        losses_iogcn[sim,j,:] = loss_iogcn
        losses_iogat[sim,j,:] = loss_iogat
        losses_iomlp[sim,j,:] = loss_iomlp

    # 5) Full-graph baselines evaluated on the same output nodes.
    gcn = GCN(in_dim, hid_dim, out_dim, 2*n_layers, nonlin, build_params_gcn).to(device)
    loss_gcn, acc_train_gcn, acc_val_gcn, acc_test_gcn, _ = test(gcn, feat_dev, labels_dev, gx, gy, graph_sl, lr_gcn, train_idx_orig, val_idx_orig, test_idx_orig, model_name="gcn", verbose=False)

    gat = GAT(in_dim, hid_dim, out_dim, 2*n_layers, nonlin, build_params_gat).to(device)
    loss_gat, acc_train_gat, acc_val_gat, acc_test_gat, _ = test(gat, feat_dev, labels_dev, gx, gy, graph_sl, lr_gcn, train_idx_orig, val_idx_orig, test_idx_orig, model_name="gat", verbose=False)

    accs_test_gcn[sim,:] = acc_test_gcn
    losses_gcn[sim,:] = loss_gcn
    accs_test_gat[sim,:] = acc_test_gat
    losses_gat[sim,:] = loss_gat

In [None]:
f = plt.figure(figsize=(12,8))

# Average over the N_SIMS simulations (axis 0) and scale by 100. The IO arrays
# keep a leading method axis ordered like `methods`. NOTE: losses are scaled
# by 100 as well.
(acc_test_iogcn, acc_test_iogat, acc_test_iomlp, acc_test_gcn, acc_test_gat,
 loss_iogcn, loss_iogat, loss_iomlp, loss_gcn, loss_gat) = (
    100*np.mean(arr, axis=0)
    for arr in (accs_test_iogcn, accs_test_iogat, accs_test_iomlp, accs_test_gcn, accs_test_gat,
                losses_iogcn, losses_iogat, losses_iomlp, losses_gcn, losses_gat)
)

epochs = np.arange(n_epochs)
io_curves = (acc_test_iogcn, acc_test_iogat, acc_test_iomlp)
for method_idx, row_labels in ((0, ("IOGCN-W", "IOGAT-W", "IOMLP-W")),
                               (1, ("IOGCN-T", "IOGAT-T", "IOMLP-T"))):
    for curves, label in zip(io_curves, row_labels):
        plt.plot(epochs, curves[method_idx, :], label=label)

plt.plot(epochs, acc_test_gcn, label="GCN")
plt.plot(epochs, acc_test_gat, label="GAT")

f.legend(fontsize=14)

In [None]:
plt.figure(figsize=(8,6))

# One curve per (IO model, IO method); rows follow the order of `methods`.
epochs = np.arange(n_epochs)
io_curves = (acc_test_iogcn, acc_test_iogat, acc_test_iomlp)
io_rows = (
    (0, ("IOGCN-W", "IOGAT-W", "IOMLP-W")),
    (1, ("IOGCN-T", "IOGAT-T", "IOMLP-T")),
    (2, ("IOGCN-C", "IOGAT-C", "IOMLP-C")),
)
for method_idx, row_labels in io_rows:
    for curves, label in zip(io_curves, row_labels):
        plt.plot(epochs, curves[method_idx, :], label=label)

plt.plot(epochs, acc_test_gcn, label="GCN")
plt.plot(epochs, acc_test_gat, color='c', label="GAT")

plt.legend(fontsize=14)
plt.title("Evolution of the accuracy measured on the test node set", fontsize=16)
plt.xlabel("Epoch", fontsize=14)
plt.ylabel("Accuracy (%)", fontsize=14)

# Changing the number of nodes (varying the number of snowball-sampling hops K)

In [None]:
N_SIMS = 25
methods = ["linear", "transpose", "common"]
Ks = [2,3,4,5,6,7,8]
subgraph = "neigh_sampling"

# Result buffers indexed [hop setting k, simulation, (IO method j,) epoch].
# NOTE(review): the *_ogcn buffers are never written below (they stay 0);
# they are kept because the plotting cells read best_accs_test_ogcn.
accs_test_iogcn, accs_test_ogcn, accs_test_iomlp, accs_test_iogat,\
    losses_iogcn, losses_ogcn, losses_iomlp, losses_iogat =\
    [np.zeros((len(Ks), N_SIMS, len(methods),n_epochs)) for _ in range(8)]

best_accs_test_iogcn, best_accs_test_ogcn, best_accs_test_iomlp, best_accs_test_iogat = \
    [np.zeros((len(Ks), N_SIMS, len(methods))) for _ in range(4)]

(accs_test_gcn, accs_test_gcn_limited_x, accs_test_gcn_limited_y, accs_test_gcn_limited_xy, accs_test_gat,
 losses_gcn, losses_gcn_limited_x, losses_gcn_limited_y, losses_gcn_limited_xy, losses_gat) = \
    [np.zeros((len(Ks), N_SIMS, n_epochs)) for _ in range(10)]
(best_accs_test_gcn, best_accs_test_gcn_limited_x, best_accs_test_gcn_limited_y,
 best_accs_test_gcn_limited_xy, best_accs_test_gat) = \
    [np.zeros((len(Ks), N_SIMS)) for _ in range(5)]

# Loop-invariant inputs, prepared once instead of on every simulation.
lr_gcn = 1e-2
A_dense = graph.cpu().adj().to_dense()  # dense adjacency for neigh_sampling
graph = graph.to(device)
graph_sl = graph.add_self_loop()        # full graph + self-loops for GCN / GAT
feat_dev = feat.to(device)
labels_dev = labels.to(device)
start_edges, end_edges = (e.tolist() for e in graph.edges())


def _prune_graph_to(keep_nodes):
    """Return a copy of `graph` keeping only edges with both endpoints in keep_nodes.

    Membership is checked against a set (the previous per-edge list lookups
    cost O(len(keep_nodes)) each). Self-loops are not added here.
    """
    keep = {int(n) for n in keep_nodes}
    outside = [eid for eid, (u, v) in enumerate(zip(start_edges, end_edges))
               if not (u in keep and v in keep)]
    pruned = graph.clone()
    pruned.remove_edges(outside)
    return pruned


for k, n_neigh in enumerate(Ks):
    print(f"****** Starting {n_neigh} neighbors ****** (simulation out of {N_SIMS}): ", end="")

    for sim in range(N_SIMS):

        print(f"{sim+1} ", end="", flush=True)

        # Sample the input (X) and output (Y) node sets.
        if subgraph == "random":
            Nx = Ny = 512
            # Plain CPU lists, like neigh_sampling returns (needed for
            # np.array(idxs_Y) and the list concatenation idxs_X + idxs_Y).
            idxs_X = torch.randperm(N)[:Nx].tolist()
            idxs_Y = torch.randperm(N)[:Ny].tolist()
        elif subgraph == "neigh_sampling":
            Nx = Ny = 0
            while Nx < 10 or Ny < 10: # Ensure at least 10 nodes in each graph
                idxs_X, idxs_Y = neigh_sampling(A_dense, K=n_neigh)
                Nx = len(idxs_X)
                Ny = len(idxs_Y)

        # Shared nodes as (positions in idxs_X, positions in idxs_Y), for both
        # strategies; the dict keeps this linear instead of O(Nx*Ny).
        pos_in_y = {int(n): j for j, n in enumerate(idxs_Y)}
        common_nodes = [n for n in idxs_X if int(n) in pos_in_y]
        idxs_common_x = [i for i, n in enumerate(idxs_X) if int(n) in pos_in_y]
        idxs_common_y = [pos_in_y[int(n)] for n in common_nodes]
        idxs_common = (idxs_common_x, idxs_common_y)

        gx = graph.subgraph(idxs_X).add_self_loop().to(device)
        gy = graph.subgraph(idxs_Y).add_self_loop().to(device)
        x = feat[idxs_X,:].to(device)
        labels_y = labels[idxs_Y].to(device)

        # Train/val/test split of Y as positions in gy ...
        idxs_y_split = torch.randperm(Ny)
        N_train = 0.3
        N_val = 0.2
        train_end = int(N_train*Ny)
        val_end = int((N_train+N_val)*Ny)
        train_idx = idxs_y_split[:train_end]
        val_idx = idxs_y_split[train_end:val_end]
        test_idx = idxs_y_split[val_end:]

        # ... and the same split as node ids of the full graph (baselines).
        idxs_orig_reorder = np.array(idxs_Y)[idxs_y_split.numpy()]
        train_idx_orig = torch.from_numpy(idxs_orig_reorder[:train_end])
        val_idx_orig = torch.from_numpy(idxs_orig_reorder[train_end:val_end])
        test_idx_orig = torch.from_numpy(idxs_orig_reorder[val_end:])

        for j, m in enumerate(methods):

            iogcn = IOGCN(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, idxs_common, m, build_params_gcn, nonlin=nonlin).to(device)
            iogat = IOGAT(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, idxs_common, m, build_params_gat, nonlin=nonlin).to(device)
            iomlp = IOMLP(in_dim, hid_dim, out_dim, Nx, Ny, n_layers, idxs_common, m, build_params_gcn, nonlin=nonlin).to(device)

            loss_iogcn, acc_train_iogcn, acc_val_iogcn, acc_test_iogcn, best_acc_test_iogcn = test(iogcn, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx, es_patience=es_patience, verbose=False)
            loss_iogat, acc_train_iogat, acc_val_iogat, acc_test_iogat, best_acc_test_iogat = test(iogat, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx, es_patience=es_patience, verbose=False)
            loss_iomlp, acc_train_iomlp, acc_val_iomlp, acc_test_iomlp, best_acc_test_iomlp = test(iomlp, x, labels_y, gx, gy, graph, lr, train_idx, val_idx, test_idx, es_patience=es_patience, verbose=False)

            accs_test_iogcn[k,sim,j,:] = acc_test_iogcn
            accs_test_iogat[k,sim,j,:] = acc_test_iogat
            accs_test_iomlp[k,sim,j,:] = acc_test_iomlp
            best_accs_test_iogcn[k,sim,j] = best_acc_test_iogcn
            best_accs_test_iogat[k,sim,j] = best_acc_test_iogat
            best_accs_test_iomlp[k,sim,j] = best_acc_test_iomlp
            losses_iogcn[k,sim,j,:] = loss_iogcn
            losses_iogat[k,sim,j,:] = loss_iogat
            losses_iomlp[k,sim,j,:] = loss_iomlp

        # Baseline: GCN on the full graph with all features.
        gcn = GCN(in_dim, hid_dim, out_dim, 2*n_layers, nonlin, build_params_gcn).to(device)
        loss_gcn, acc_train_gcn, acc_val_gcn, acc_test_gcn, best_acc_test_gcn = test(gcn, feat_dev, labels_dev, gx, gy, graph_sl, lr_gcn, train_idx_orig, val_idx_orig, test_idx_orig, model_name="gcn", es_patience=es_patience, verbose=False)

        # "Limited" GCN baselines: full node set, but only the edges inside X,
        # inside Y, or inside X u Y, and features only on X.
        gx_pruned = _prune_graph_to(idxs_X)
        # gx / gy carry one added self-loop per node, hence the subtraction.
        assert gx_pruned.number_of_edges() == gx.number_of_edges() - gx.number_of_nodes()
        gx_pruned = gx_pruned.add_self_loop().to(device)

        gy_pruned = _prune_graph_to(idxs_Y)
        assert gy_pruned.number_of_edges() == gy.number_of_edges() - gy.number_of_nodes()
        gy_pruned = gy_pruned.add_self_loop().to(device)

        idxs_XY = list(set(idxs_X + idxs_Y))
        gxy_pruned = _prune_graph_to(idxs_XY).add_self_loop().to(device)

        x_pruned = torch.zeros(feat.shape)
        x_pruned[idxs_X,:] = feat[idxs_X,:]
        x_pruned_dev = x_pruned.to(device)

        gcn = GCN(in_dim, hid_dim, out_dim, 2*n_layers, nonlin, build_params_gcn).to(device)
        loss_gcn_limited_x, acc_train_gcn_limited_x, acc_val_gcn_limited_x, acc_test_gcn_limited_x, best_acc_test_gcn_limited_x = test(gcn, x_pruned_dev, labels_dev, gx, gy, gx_pruned, lr_gcn, train_idx_orig, val_idx_orig, test_idx_orig, model_name="gcn", es_patience=es_patience, verbose=False)

        gcn = GCN(in_dim, hid_dim, out_dim, 2*n_layers, nonlin, build_params_gcn).to(device)
        loss_gcn_limited_y, acc_train_gcn_limited_y, acc_val_gcn_limited_y, acc_test_gcn_limited_y, best_acc_test_gcn_limited_y = test(gcn, x_pruned_dev, labels_dev, gx, gy, gy_pruned, lr_gcn, train_idx_orig, val_idx_orig, test_idx_orig, model_name="gcn", es_patience=es_patience, verbose=False)

        gcn = GCN(in_dim, hid_dim, out_dim, 2*n_layers, nonlin, build_params_gcn).to(device)
        loss_gcn_limited_xy, acc_train_gcn_limited_xy, acc_val_gcn_limited_xy, acc_test_gcn_limited_xy, best_acc_test_gcn_limited_xy = test(gcn, x_pruned_dev, labels_dev, gx, gy, gxy_pruned, lr_gcn, train_idx_orig, val_idx_orig, test_idx_orig, model_name="gcn", es_patience=es_patience, verbose=False)

        # Baseline: GAT on the full graph with all features.
        gat = GAT(in_dim, hid_dim, out_dim, 2*n_layers, nonlin, build_params_gat).to(device)
        loss_gat, acc_train_gat, acc_val_gat, acc_test_gat, best_acc_test_gat = test(gat, feat_dev, labels_dev, gx, gy, graph_sl, lr_gcn, train_idx_orig, val_idx_orig, test_idx_orig, model_name="gat", es_patience=es_patience, verbose=False)

        accs_test_gcn[k,sim,:] = acc_test_gcn
        best_accs_test_gcn[k,sim] = best_acc_test_gcn
        losses_gcn[k,sim,:] = loss_gcn
        accs_test_gcn_limited_x[k,sim,:] = acc_test_gcn_limited_x
        best_accs_test_gcn_limited_x[k,sim] = best_acc_test_gcn_limited_x
        losses_gcn_limited_x[k,sim,:] = loss_gcn_limited_x
        accs_test_gcn_limited_y[k,sim,:] = acc_test_gcn_limited_y
        best_accs_test_gcn_limited_y[k,sim] = best_acc_test_gcn_limited_y
        losses_gcn_limited_y[k,sim,:] = loss_gcn_limited_y
        accs_test_gcn_limited_xy[k,sim,:] = acc_test_gcn_limited_xy
        best_accs_test_gcn_limited_xy[k,sim] = best_acc_test_gcn_limited_xy
        losses_gcn_limited_xy[k,sim,:] = loss_gcn_limited_xy
        accs_test_gat[k,sim,:] = acc_test_gat
        best_accs_test_gat[k,sim] = best_acc_test_gat
        losses_gat[k,sim,:] = loss_gat

    print("DONE")

In [26]:
# Publication figure style. The scienceplots module is imported for its side
# effect only (the name is not used afterwards); the 'science' / 'ieee' style
# names below presumably come from it.
import scienceplots

paper_styles = ['science', 'ieee']
plt.style.use(paper_styles)

In [None]:
plt.figure(figsize=(12,8))

# Mean over the N_SIMS simulations (axis 1), scaled by 100; IO arrays keep
# the method axis ordered like `methods`.
best_acc_test_iogcn, best_acc_test_iogat, best_acc_test_ogcn, best_acc_test_iomlp, best_acc_test_gcn, best_acc_test_gcn_limited_x, best_acc_test_gcn_limited_y, best_acc_test_gcn_limited_xy, best_acc_test_gat = \
    [100*np.mean(elem, 1) for elem in [best_accs_test_iogcn, best_accs_test_iogat, best_accs_test_ogcn, best_accs_test_iomlp, best_accs_test_gcn, best_accs_test_gcn_limited_x, best_accs_test_gcn_limited_y, best_accs_test_gcn_limited_xy, best_accs_test_gat]]

# NOTE(review): this figure leaves out the last hop setting (Ks[:-1]).
Ks_shown = Ks[:-1]

plt.plot(Ks_shown, best_acc_test_iogcn[:-1,0], 'o-', color='b', linewidth=2, label="IOGCN-W")
plt.plot(Ks_shown, best_acc_test_iomlp[:-1,0], 'v-', color='g', linewidth=2, label="IOMLP-W")

plt.plot(Ks_shown, best_acc_test_iogcn[:-1,1], 'o--', color='b', linewidth=2, label="IOGCN-T")

plt.plot(Ks_shown, best_acc_test_iogcn[:-1,2], 'o-.', color='b', linewidth=2, label="IOGCN-C")
plt.plot(Ks_shown, best_acc_test_iomlp[:-1,2], 'v-.', color='g', linewidth=2, label="IOMLP-C")

# Raw strings: "\m" is an invalid escape sequence in a normal string literal
# (SyntaxWarning on recent Pythons); the resulting label text is unchanged.
plt.plot(Ks_shown, best_acc_test_gcn[:-1], 'o-', linewidth=2, color='m', label="GCN")
plt.plot(Ks_shown, best_acc_test_gcn_limited_x[:-1], 'o-', linewidth=2, color='y', label=r"GCN-Limited-$\mathcal{G}_X$")
plt.plot(Ks_shown, best_acc_test_gcn_limited_y[:-1], 'o-', linewidth=2, color='orange', label=r"GCN-Limited-$\mathcal{G}_Y$")
plt.plot(Ks_shown, best_acc_test_gcn_limited_xy[:-1], 'o-', linewidth=2, color='gray', label=r"GCN-Limited-$\mathcal{G}_{XY}$")

plt.xlabel("Number of hops for snowball sampling", fontsize=20)
plt.ylabel("Mean accuracy over the test node set", fontsize=20)

plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

plt.grid()

plt.legend(fontsize=16)

os.makedirs('results', exist_ok=True)  # savefig does not create directories
plt.savefig('results/20240111-selection_nodes.pdf')

In [None]:
plt.figure(figsize=(12,8))

# Mean and standard deviation over the N_SIMS simulations (axis 1), scaled
# by 100. The source list must match the 8 target names: the never-filled
# best_accs_test_ogcn buffer is left out (including it handed 9 arrays to an
# 8-name unpacking and raised ValueError).
best_acc_sources = [best_accs_test_iogcn, best_accs_test_iogat, best_accs_test_iomlp, best_accs_test_gcn, best_accs_test_gcn_limited_x, best_accs_test_gcn_limited_y, best_accs_test_gcn_limited_xy, best_accs_test_gat]
best_acc_test_iogcn, best_acc_test_iogat, best_acc_test_iomlp, best_acc_test_gcn, best_acc_test_gcn_limited_x, best_acc_test_gcn_limited_y, best_acc_test_gcn_limited_xy, best_acc_test_gat = \
    [100*np.mean(elem, 1) for elem in best_acc_sources]
best_acc_test_iogcn_std, best_acc_test_iogat_std, best_acc_test_iomlp_std, best_acc_test_gcn_std, best_acc_test_gcn_limited_x_std, best_acc_test_gcn_limited_y_std, best_acc_test_gcn_limited_xy_std, best_acc_test_gat_std = \
    [100*np.std(elem, 1) for elem in best_acc_sources]

errorbar_style = dict(linewidth=3, markersize=12, capsize=5)

plt.errorbar(Ks, best_acc_test_iogcn[:,0], yerr=best_acc_test_iogcn_std[:,0], fmt='o-', color='b', label="IOGCN-W", **errorbar_style)
plt.errorbar(Ks, best_acc_test_iomlp[:,0], yerr=best_acc_test_iomlp_std[:,0], fmt='v-', color='g', label="IOMLP-W", **errorbar_style)

plt.errorbar(Ks, best_acc_test_iogcn[:,1], yerr=best_acc_test_iogcn_std[:,1], fmt='o--', color='b', label="IOGCN-T", **errorbar_style)

plt.errorbar(Ks, best_acc_test_iogcn[:,2], yerr=best_acc_test_iogcn_std[:,2], fmt='o-.', color='b', label="IOGCN-C", **errorbar_style)
plt.errorbar(Ks, best_acc_test_iomlp[:,2], yerr=best_acc_test_iomlp_std[:,2], fmt='v-.', color='g', label="IOMLP-C", **errorbar_style)

# Raw strings: "\m" is an invalid escape sequence in a normal string literal.
plt.errorbar(Ks, best_acc_test_gcn, yerr=best_acc_test_gcn_std, fmt='o-', color='m', label="GCN", **errorbar_style)
plt.errorbar(Ks, best_acc_test_gcn_limited_x, yerr=best_acc_test_gcn_limited_x_std, fmt='o-', color='y', label=r"GCN-Limited-$\mathcal{G}_X$", **errorbar_style)
plt.errorbar(Ks, best_acc_test_gcn_limited_y, yerr=best_acc_test_gcn_limited_y_std, fmt='o-', color='orange', label=r"GCN-Limited-$\mathcal{G}_Y$", **errorbar_style)
plt.errorbar(Ks, best_acc_test_gcn_limited_xy, yerr=best_acc_test_gcn_limited_xy_std, fmt='o-', color='gray', label=r"GCN-Limited-$\mathcal{G}_{XY}$", **errorbar_style)

plt.xlabel("Number of hops for snowball sampling", fontsize=20)
plt.ylabel("Mean accuracy over the test node set", fontsize=20)

plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

plt.grid()

plt.legend(fontsize=16)

# NOTE(review): same path as the previous figure cell, so this figure
# overwrites that PDF.
os.makedirs('results', exist_ok=True)  # savefig does not create directories
plt.savefig('results/20240111-selection_nodes.pdf')