In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.insert(0, '../src')
import math
import numpy as np
from itertools import product
from data import make_dataset
from client import Client
from server import Server
from torch_geometric.transforms import RandomNodeSplit

# Load the LastFM Asia graph and draw a random node split:
# 25% validation, 25% test, remainder train (PyG's default "train_rest" mode).
dataset = make_dataset("lastfm", root="../data")
graph = dataset[0]
# Keep the transform's return value: recent PyG versions shallow-copy the input
# inside BaseTransform.__call__, so the masks are only set on the returned object.
graph = RandomNodeSplit(num_val=0.25, num_test=0.25)(graph)
# Edge-free copy of the graph handed to the LDP server, which only learns about
# links through the perturbed adjacency it receives from the client.
linkless_graph = graph.clone()
linkless_graph.edge_index = None

## Baselines

First, let's train an MLP without any link information. Since it never sees the graph structure, it serves as a lower bound on the performance of our LDP GNN.

In [None]:
# Grid-search learning rate and weight decay for the link-free MLP baseline,
# keeping the configuration whose best epoch reaches the lowest validation loss
# (column 1 of the training log).
mlp_hparam_space = {
    "lr": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    "wd": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
}
mlp_keys = list(mlp_hparam_space)
mlp_hparam_list = [
    dict(zip(mlp_keys, combo))
    for combo in product(*(mlp_hparam_space[key] for key in mlp_keys))
]

best_mlp_hp = None
min_mlp_loss = math.inf

for candidate in mlp_hparam_list:
    mlp_server = Server(None, None, graph)
    log = mlp_server.fit("mlp", dataset.num_features, dataset.num_classes, hparam=candidate)
    candidate_loss = log[np.argmin(log[:, 1]), 1]
    if candidate_loss < min_mlp_loss:
        min_mlp_loss, best_mlp_hp = candidate_loss, candidate

print("Best hparam found:", best_mlp_hp, "with validation loss", min_mlp_loss)

In [None]:
# Retrain the MLP 30 times with the selected hyperparameters and report the
# mean / std of the test accuracy (log column 2) at each run's best-validation epoch.
mlp_res = np.zeros(30)
for run in range(mlp_res.size):
    mlp_server = Server(None, None, graph)
    log = mlp_server.fit("mlp", dataset.num_features, dataset.num_classes, hparam=best_mlp_hp)
    best_epoch = np.argmin(log[:, 1])
    mlp_res[run] = log[best_epoch, 2]
print(mlp_res.mean())
print(mlp_res.std())

Then, let's train a GCN with access to the full, unperturbed graph. This provides the non-private upper bound on the performance of our LDP GNN.

In [None]:
# Grid-search learning rate and weight decay for the non-private GCN (trained on
# the full graph), keeping the configuration whose best epoch reaches the lowest
# validation loss (column 1 of the training log).
gcn_hparam_space = {
    "lr": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    "wd": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
}
gcn_hparam_list = [dict(zip(gcn_hparam_space.keys(), values)) for values in product(*gcn_hparam_space.values())]

min_gcn_loss = math.inf
best_gcn_hp = None

for hp in gcn_hparam_list:
    gcn_server = Server(None, None, graph)
    # Fit the server built for this config. The original called the stale
    # `mlp_server` left over from the MLP cells instead, which depends on
    # execution order and fails on a fresh kernel.
    log = gcn_server.fit("gcn", dataset.num_features, dataset.num_classes, hparam=hp)
    val_loss = log[np.argmin(log[:, 1]), 1]
    if val_loss < min_gcn_loss:
        best_gcn_hp = hp
        min_gcn_loss = val_loss

print("Best hparam found:", best_gcn_hp, "with validation loss", min_gcn_loss)

In [None]:
# Retrain the non-private GCN 30 times with the selected hyperparameters and
# report the mean / std of the test accuracy (log column 2) at each run's
# best-validation epoch.
gcn_res = np.zeros(30)
for run in range(gcn_res.size):
    gcn_server = Server(None, None, graph)
    log = gcn_server.fit("gcn", dataset.num_features, dataset.num_classes, hparam=best_gcn_hp)
    best_epoch = np.argmin(log[:, 1])
    gcn_res[run] = log[best_epoch, 2]
print(gcn_res.mean())
print(gcn_res.std())

## Grid search for hyperparameter tuning

In [None]:
# Search space for the LDP GCN: `delta` is passed to both Client and Server,
# `lr` / `wd` are forwarded to Server.fit as training hyperparameters.
hparam_space = {
    "delta": [0.1, 0.3, 0.5],
    "lr": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    "wd": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
}

# Cartesian product of all value lists, one dict per configuration
# (iterating a dict yields its keys in insertion order).
hparam_list = [
    dict(zip(hparam_space, combo))
    for combo in product(*hparam_space.values())
]

In [None]:
# LDP GCN grid search. For each privacy budget eps, every configuration runs the
# full pipeline (client perturbation -> server estimation -> training)
# `num_trials` times; the config with the lowest mean best-epoch validation loss
# is kept. `math` is imported in the first cell.
num_trials = 10
min_val_loss = {}
best_hparam = {}
best_hparam_res = {}

for eps in [0.5, 1, 2, 4, 8]:
    print("Epsilon = ", eps)
    min_val_loss[eps] = math.inf
    best_hparam[eps] = None
    best_hparam_res[eps] = None
    for hparam in hparam_list:
        hparam_res = np.zeros((num_trials, 2))  # one row per trial: (val_loss, test_acc)
        for i in range(num_trials):
            client = Client(eps=eps, delta=hparam["delta"], data=graph)
            server = Server(eps=eps, delta=hparam["delta"], data=linkless_graph)

            priv_adj, priv_deg = client.AddLDP()
            server.receive(priv_adj, priv_deg)
            server.estimate()
            log = server.fit("gcn", d=dataset.num_features, c=dataset.num_classes, hparam=hparam)
            best_epoch = np.argmin(log[:, 1])
            hparam_res[i] = log[best_epoch, 1], log[best_epoch, 2]
        val_loss = hparam_res[:, 0].mean()
        # A NaN mean (any diverged trial) compares False, so it is never selected.
        if val_loss < min_val_loss[eps]:
            min_val_loss[eps] = val_loss
            best_hparam_res[eps] = hparam_res
            best_hparam[eps] = hparam
    if best_hparam[eps] is None:
        # Every config produced a non-finite validation loss for this eps.
        print("No hparam produced a finite validation loss for eps =", eps)
    else:
        best_accs = best_hparam_res[eps][:, 1]
        print("Best hparam is: ", best_hparam[eps], "with test accuracy", best_accs.mean(), "(", best_accs.std(), ")")

## Training with the best hyperparameters

In [None]:
# For every privacy budget, rerun the full LDP pipeline `num_trials` times with
# that budget's selected hyperparameters and record the test accuracy (log
# column 2) at each run's best-validation epoch.
res = {}
num_trials = 30

for eps in [0.5, 1, 2, 4, 8]:
    hp = best_hparam[eps]
    accs = res[eps] = np.zeros(num_trials)
    for trial in range(num_trials):
        delta = hp["delta"]
        client = Client(eps=eps, delta=delta, data=graph)
        server = Server(eps=eps, delta=delta, data=linkless_graph)

        priv_adj, priv_deg = client.AddLDP()
        server.receive(priv_adj, priv_deg)
        server.estimate()
        log = server.fit("gcn", d=dataset.num_features, c=dataset.num_classes, hparam=hp)
        accs[trial] = log[np.argmin(log[:, 1]), 2]

In [None]:
# (mean, std) of the LDP GCN test accuracy for each privacy budget.
{eps: (res[eps].mean(), res[eps].std()) for eps in (0.5, 1, 2, 4, 8)}

In [None]:
from matplotlib import pyplot as plt

# Compare the LDP GCN across privacy budgets against the two baselines:
# the non-private GCN (upper bound) and the link-free MLP (lower bound).
eps_values = [0.5, 1, 2, 4, 8]

fig, ax = plt.subplots()
# Raw strings avoid the invalid "\e" / "\i" escape-sequence warnings while
# producing exactly the same LaTeX text.
ax.axhline(y=gcn_res.mean(), color="g", linestyle="--", label=r"$\epsilon=\infty$ (non-private GCN)")
ax.axhline(y=mlp_res.mean(), color="r", linestyle="--", label=r"$\epsilon=0$ (MLP without links)")
# The LDP curve was previously commented out, so the figure never showed the
# results it is titled after. The x positions are categorical labels "0.5" ... "8".
ax.errorbar(
    [str(eps) for eps in eps_values],
    [res[eps].mean() for eps in eps_values],
    yerr=[res[eps].std() for eps in eps_values],
    marker="D",
    color="blue",
    ecolor="k",
    capsize=3,
    label="LDP GCN",
)

ax.set_ylim(0, 1)
ax.set_xlabel(r"$\epsilon$")
ax.set_ylabel("Accuracy")  # axis spans [0, 1], i.e. fractions rather than percent
ax.set_title("GCN on LastFMAsia")
ax.legend(loc="lower right")
plt.show()