In [1]:
from stellargraph.mapper import CorruptedGraphSAGENodeGenerator, GraphSAGENodeGenerator
import stellargraph as sg
import networkx as nx
from stellargraph import StellarGraph, StellarDiGraph
from stellargraph.layer import GraphSAGEInfoMax
from stellargraph import datasets
import random
import numpy as np
from sklearn import preprocessing, feature_extraction, model_selection
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dot
from tensorflow.keras.activations import sigmoid, relu
from tensorflow.keras.optimizers import Adam
from IPython.display import display, HTML

from tensorflow.keras.layers import Input, Dense
import tensorflow as tf
from tensorflow.keras import Model

In [2]:
def info_loss(y_true, y_pred):
    
    return -tf.math.reduce_mean(tf.math.log(y_pred))

In [3]:
dataset = datasets.Cora()
display(HTML(dataset.description))
G, node_subjects = dataset.load()

In [4]:
train_subjects, test_subjects = model_selection.train_test_split(
    node_subjects, train_size=0.1, test_size=None, stratify=node_subjects
)

target_encoding = preprocessing.LabelBinarizer()

train_targets = target_encoding.fit_transform(train_subjects)
test_targets = target_encoding.transform(test_subjects)

In [5]:
batch_size = 50
num_samples = [10, 5]

generator = CorruptedGraphSAGENodeGenerator(G, batch_size, num_samples)
gen = generator.flow(G.nodes(), targets=np.ones(len(G.nodes())))

graph_sage_infomax = GraphSAGEInfoMax([64, 64,], generator=generator, normalize="l2")

x_in, x_out = graph_sage_infomax.unsupervised_node_model()

model = Model(inputs=x_in, outputs=x_out)
model.compile(loss=info_loss, optimizer=Adam(lr=1e-3))

In [6]:
model.fit(gen, epochs=20)

  ...
    to  
  ['...']
Train for 55 steps
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x1a364b9c50>

In [7]:
x_emb_in = model.inputs[:len(graph_sage_infomax.neighbourhood_sizes)]
x_emb_out = model.get_layer("GRAPH_SAGE_NORM").output
emb_model = Model(inputs=x_emb_in, outputs=x_emb_out)

In [8]:
for x in gen:
    for y in x[0]:
        print(y.shape)
    break

(50, 1, 1433)
(50, 10, 1433)
(50, 50, 1433)
(50, 1, 1433)
(50, 10, 1433)
(50, 50, 1433)
(50, 0, 1433)


In [9]:
gsgenerator = GraphSAGENodeGenerator(G, batch_size, num_samples)

gstrain_gen = gsgenerator.flow(train_subjects.index,)
gstest_gen = gsgenerator.flow(test_subjects.index,)

In [10]:
test_embeddings = emb_model.predict(gstest_gen)
train_embeddings = emb_model.predict(gstrain_gen)

In [11]:
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression()
lr.fit(train_embeddings, train_subjects)

y_pred = lr.predict(test_embeddings)
(y_pred == test_subjects).mean()



0.32895816242821985

# END OF DEMO

In [12]:
lr = LogisticRegression()
lr.fit(G.node_features(train_subjects.index), train_subjects)

y_pred = lr.predict(G.node_features(test_subjects.index))
(y_pred == test_subjects).mean()



0.6755537325676785

In [13]:
G.node_features(train_subjects.index).shape, train_embeddings.shape

((270, 1433), (270, 64))

In [14]:
# Create tensor inputs for neighbourhood sampling
x_inp = [
    Input(shape=(s, graph_sage_infomax.input_feature_size)) for s in graph_sage_infomax.neighbourhood_sizes
]

x_inp_corrupted = [
    Input(shape=(s, graph_sage_infomax.input_feature_size)) for s in graph_sage_infomax.neighbourhood_sizes
]
# Output from GraphSAGE model
node_feats = graph_sage_infomax(x_inp)

node_feats_corrupted = graph_sage_infomax(x_inp_corrupted)

summary = sigmoid(tf.math.reduce_mean(node_feats, axis=0))
D = Dense(summary.shape[0], use_bias=False)

scores = sigmoid(tf.linalg.matvec(D(node_feats), summary))
scores_corrupted = 1 - sigmoid(tf.linalg.matvec(D(node_feats_corrupted), summary))

lscores = tf.math.log(scores)
lscores_corrupted = tf.math.log(scores_corrupted)

pm_out = tf.stack([scores, scores_corrupted], axis=1)
x_out = tf.stack([lscores, lscores_corrupted], axis=1)

In [15]:
vm = Model(inputs=x_inp + x_inp_corrupted, outputs=node_feats)
pm = Model(inputs=x_inp + x_inp_corrupted, outputs=pm_out)
model2 = Model(inputs=x_inp + x_inp_corrupted, outputs=x_out)
pm.predict(x[0])

array([[0.45425794, 0.41555315],
       [0.56886566, 0.6810629 ],
       [0.41740674, 0.64435476],
       [0.53161174, 0.76607174],
       [0.5371841 , 0.510937  ],
       [0.26571882, 0.54920506],
       [0.37447605, 0.53327376],
       [0.5255777 , 0.6897084 ],
       [0.5617466 , 0.57636213],
       [0.32220683, 0.48119074],
       [0.7011592 , 0.6251826 ],
       [0.37174746, 0.7186746 ],
       [0.35635832, 0.7043456 ],
       [0.5556387 , 0.63467777],
       [0.45765993, 0.6928941 ],
       [0.35241443, 0.4847107 ],
       [0.17910449, 0.78752965],
       [0.55537915, 0.20882374],
       [0.34006974, 0.47681653],
       [0.58879447, 0.5290415 ],
       [0.4058775 , 0.41675746],
       [0.4520637 , 0.6473123 ],
       [0.56054014, 0.55829203],
       [0.5313442 , 0.74214965],
       [0.48606178, 0.6824586 ],
       [0.42257836, 0.45931977],
       [0.5874176 , 0.66912305],
       [0.37184227, 0.6455971 ],
       [0.6507781 , 0.7001918 ],
       [0.6163117 , 0.6672619 ],
       [0.

In [16]:
model2.compile(loss=info_loss, optimizer='adam')
model2.fit(gen, epochs=10, verbose=0)

  ...
    to  
  ['...']


<tensorflow.python.keras.callbacks.History at 0x1a388fe128>

In [17]:
pm.predict(x[0])

array([[nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan],
       [nan, nan]], dtype=float32)