In [16]:
import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
from torch_geometric.datasets import Planetoid

import numpy as np
import time
import random
from tqdm import tqdm
from collections import defaultdict

from graphsage.encoders import Encoder
from graphsage.aggregators import MeanAggregator
from graphsage.spv_graphsage import SupervisedGraphSage

from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support

In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
data_name = 'Pubmed'

### Train GraphSage on node classification

In [18]:
if data_name == 'Cora':
    # load cora
    num_nodes = 2708
    num_feats = 1433
    num_classes = 7
    feat_data = np.zeros((num_nodes, num_feats))
    labels = np.empty((num_nodes,1), dtype=np.int64)
    node_map = {}
    label_map = {}
    with open("./data/cora_gsg/cora.content") as fp:
        for i,line in enumerate(fp):
            info = line.strip().split()
            feat_data[i,:] = [float(x) for x in info[1:-1]]
            node_map[info[0]] = i
            if not info[-1] in label_map:
                label_map[info[-1]] = len(label_map)
            labels[i] = label_map[info[-1]]

    adj_lists = defaultdict(set)
    with open("./data/cora_gsg/cora.cites") as fp:
        for i,line in enumerate(fp):
            info = line.strip().split()
            paper1 = node_map[info[0]]
            paper2 = node_map[info[1]]
            adj_lists[paper1].add(paper2)
            adj_lists[paper2].add(paper1)
    print('feat_data shape:', feat_data.shape)
    print('labels shape:', labels.shape)
    print('adj_lists type:', type(adj_lists))
elif data_name == 'Pubmed':
    

feat_data shape: (2708, 1433)
labels shape: (2708, 1)
adj_lists type: <class 'collections.defaultdict'>


In [19]:
# run cora
np.random.seed(1)
random.seed(1)
features = nn.Embedding(num_nodes, num_feats)
features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)

agg1 = MeanAggregator(features, cuda=True)
enc1 = Encoder(features, num_feats, 128, adj_lists, agg1, gcn=True, cuda=False)
agg2 = MeanAggregator(lambda nodes : enc1(nodes).t(), cuda=False)
enc2 = Encoder(
        lambda nodes : enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2,
        base_model=enc1, gcn=True, cuda=False)
enc1.num_samples = 5
enc2.num_samples = 5

graphsage = SupervisedGraphSage(num_classes, enc2)
rand_indices = np.array(range(num_nodes))
train = list(rand_indices[:140])
val = rand_indices[140:640]
test = rand_indices[-1000:]
batch_size = 32

# optimizer = torch.optim.SGD(filter(lambda p : p.requires_grad, graphsage.parameters()), lr=0.7)
optimizer = torch.optim.Adam(graphsage.parameters(), lr=1e-3)
print_every = 25
epochs = 500
for batch in range(1, epochs + 1):
    batch_nodes = train[:batch_size]
    random.shuffle(train)
    optimizer.zero_grad()
    loss = graphsage.loss(
        batch_nodes, 
        Variable(torch.LongTensor(labels[np.array(batch_nodes)]))
    )
    loss.backward()
    optimizer.step()
    if batch == 1 or batch % print_every == 0:
        print('batch {} | xe: {:.6f}'.format(batch, loss.data.item()))

# tr_output = graphsage.forward(train)
# val_output = graphsage.forward(val) 

# tr_acc = accuracy_score(labels[train], tr_output.data.numpy().argmax(axis=1))
# val_acc = accuracy_score(labels[val], val_output.data.numpy().argmax(axis=1))
# print('tr acc: {:.4f}'.format(tr_acc))
# print('val acc: {:.4f}'.format(val_acc))

batch 1 | xe: 1.930618
batch 25 | xe: 0.939674


  init.xavier_uniform(self.weight)


batch 50 | xe: 0.261618
batch 75 | xe: 0.072758
batch 100 | xe: 0.089928
batch 125 | xe: 0.031390
batch 150 | xe: 0.013784
batch 175 | xe: 0.003537
batch 200 | xe: 0.116970
batch 225 | xe: 0.019461
batch 250 | xe: 0.004817
batch 275 | xe: 0.021583
batch 300 | xe: 0.017895
batch 325 | xe: 0.002445
batch 350 | xe: 0.004432
batch 375 | xe: 0.080515
batch 400 | xe: 0.007993
batch 425 | xe: 0.014591
batch 450 | xe: 0.001856
batch 475 | xe: 0.003758
batch 500 | xe: 0.024429
tr precision: 0.7740; tr recall: 0.7740; tr f1: 0.7740
val precision: 0.7740; val recall: 0.7740; val f1: 0.7740
average batch time: 0.002294099807739258


### Test quality of GraphSage embeddings for node classification via LogisticRegression

In [21]:
X_tr = graphsage.enc(train + val.tolist()).t().detach().numpy()
X_te = graphsage.enc(test).t().detach().numpy()
y_tr = labels[np.array(train + val.tolist())].squeeze()
y_te = labels[np.array(test)].squeeze()

In [31]:
clf = LogisticRegressionCV(class_weight='balanced', max_iter=750).fit(X_tr, y_tr)
clf.score(X_tr, y_tr), clf.score(X_te, y_te)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

(0.9890625, 0.737)

### Train GraphSage on link prediction

In [23]:
dataset = Planetoid('./data', data_name)
data = dataset[0]

feat_data = data.x
adj_lists = data.edge_index
labels = data.y
num_nodes, num_features = feat_data.shape

In [25]:
(data.train_mask + data.val_mask).sum()

tensor(640)

In [26]:
def sample_edges(edge_index, mask, strict=False):
    pos_samples, neg_samples = [], []
    if strict: candidates = mask.nonzero().squeeze()
    for node in tqdm(mask.nonzero()):
        neighbors = edge_index[:, edge_index[0] == node][1]
        not_neighbors = edge_index[:, edge_index[0] != node][1]
        if strict:
            neighbors = [n for n in neighbors if n in candidates]
            not_neighbors = [n for n in not_neighbors if n in candidates]

        num_neighbors = len(neighbors)
        if strict and num_neighbors > 0:
            pos = neighbors[random.sample(range(num_neighbors), 1)[0]].item()
            pos_samples.append([node.item(), pos])
        
        num_not_neighbors = len(not_neighbors)
        if strict and num_not_neighbors > 0:
            neg = not_neighbors[random.sample(range(num_not_neighbors), 1)[0]].item()
            neg_samples.append([node.item(), neg])

    return pos_samples, neg_samples

pos_samples_tr, neg_samples_tr = sample_edges(adj_lists, data.train_mask + data.val_mask, strict=True)
pos_samples_te, neg_samples_te = sample_edges(adj_lists, data.test_mask, strict=True)
len(pos_samples_tr), len(neg_samples_tr), len(pos_samples_te), len(neg_samples_te)

100%|██████████| 640/640 [00:29<00:00, 21.40it/s]
100%|██████████| 1000/1000 [00:50<00:00, 19.73it/s]


(382, 640, 692, 1000)

In [27]:
node_embeddings = graphsage.enc(list(range(num_nodes))).t().detach()
print('node_embeddings shape:', node_embeddings.shape)

node_embeddings shape: torch.Size([2708, 128])


In [28]:
embeddings_pos_hdmd_tr = node_embeddings[pos_samples_tr, :][:, 0, :] * node_embeddings[pos_samples_tr, :][:, 1, :]
embeddings_neg_hdmd_tr = node_embeddings[neg_samples_tr, :][:, 0, :] * node_embeddings[neg_samples_tr, :][:, 1, :]
embeddings_pos_hdmd_te = node_embeddings[pos_samples_te, :][:, 0, :] * node_embeddings[pos_samples_te, :][:, 1, :]
embeddings_neg_hdmd_te = node_embeddings[neg_samples_te, :][:, 0, :] * node_embeddings[neg_samples_te, :][:, 1, :]

In [29]:
embeddings_hdmd_tr = np.concatenate([
    embeddings_pos_hdmd_tr, 
    embeddings_neg_hdmd_tr
])

targets_tr = np.concatenate([
    np.ones(len(embeddings_pos_hdmd_tr)),
    np.zeros(len(embeddings_neg_hdmd_tr)),
])

embeddings_hdmd_te = np.concatenate([
    embeddings_pos_hdmd_te,
    embeddings_neg_hdmd_te
])

targets_te = np.concatenate([
    np.ones(len(embeddings_pos_hdmd_te)),
    np.zeros(len(embeddings_neg_hdmd_te))
])

In [30]:
clf = LogisticRegressionCV(class_weight='balanced', max_iter=500).fit(embeddings_hdmd_tr, targets_tr)
tr_outputs = clf.predict(embeddings_hdmd_tr)
te_outputs = clf.predict(embeddings_hdmd_te)

tr_prec, tr_recall, _, _ = precision_recall_fscore_support(targets_tr, tr_outputs, average="micro")
val_prec, val_recall, _, _ = precision_recall_fscore_support(targets_te, te_outputs, average="micro")

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

(0.6487279843444227, 0.5218676122931442)

In [37]:
clf.n_iter_

array([[[ 13,  22,  47, 151, 303, 447, 500, 500, 500, 500],
        [ 13,  24,  45, 116, 213, 381, 500, 500, 500, 500],
        [ 13,  23,  46, 128, 280, 453, 500, 500, 500, 500],
        [ 13,  23,  52, 121, 241, 445, 500, 500, 500, 500],
        [ 13,  23,  59, 122, 258, 355, 500, 500, 500, 500]]], dtype=int32)