In [None]:
import time

%env DGLBACKEND=mxnet
import dgl
import dgl.function as fn
import matplotlib.pylab as plt
%matplotlib notebook
import mxnet as mx
from mxnet import gluon
import networkx as nx
import numpy as np

In [None]:
alpha = 0.1  # weight read by update_embeddings when blending old and new embeddings
batch_size = 16  # passed to the NeighborSampler in the last cell

N = 2 # the number of chains
L = 100 # the length of a chain (also the iteration count of the propagation demo and of test())
r_train = 0.1 # the ratio of training nodes (used to compute n_train for the train/test split)

In [None]:
# Build one long directed path over N * L nodes, then cut it at every chain
# boundary so that N disjoint chains of L nodes remain.
path_graph = nx.path_graph(N * L).to_directed()
for chain in range(1, N):
    tail, head = chain * L - 1, chain * L  # last node of one chain, first node of the next
    path_graph.remove_edge(tail, head)
    path_graph.remove_edge(head, tail)

# Give every node a self-loop so that it also receives its own message.
path_graph.add_edges_from((v, v) for v in list(path_graph.nodes))

g = dgl.DGLGraph(path_graph, readonly=True)

In [None]:
def T(g):
    """Run one round of max-propagation of the node feature ``'y'`` on ``g``.

    Every node's ``'y'`` is sent along its outgoing edges as message ``'m'``,
    and each receiving node replaces its ``'y'`` with the maximum of the
    messages it received. The graph is updated in place; nothing is returned.
    """
    g.update_all(fn.copy_src('y', 'm'), fn.max('m', 'y'))

In [None]:
# Create the interactive figure that the propagation loop in the next cell redraws.
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111)  # not passed explicitly to the draw calls below
fig.show()
fig.canvas.draw()

# Random layout positions for all nodes; `pos` is reused by every later draw call.
pos = nx.random_layout(path_graph)
# NOTE(review): these two calls colour node ids 0..N-1 red and N..2N-1 green
# (size 100), i.e. the first 2N nodes of the first chain rather than one node
# per chain — confirm this is the intended highlighting.
nx.draw_networkx_nodes(path_graph, pos, range(N), 100, 'r')
nx.draw_networkx_nodes(path_graph, pos, range(N, 2 * N), 100, 'g')

In [None]:
# Demo: mark a single source node with y = 1 (all other nodes 0) and apply T
# repeatedly, repainting the nodes reached so far in blue after each step.
s = 0  # id of the source node
y = mx.nd.zeros([N * L, 1])
y[s] = 1
g.ndata['y'] = y

for i in range(L):
    T(g)  # each node takes the max 'y' over its in-neighbours (self-loop included)
    # ids of the nodes whose 'y' equals 1 after this step
    x = (g.ndata['y'] == 1).asnumpy().nonzero()[0].tolist()
    nx.draw_networkx_nodes(path_graph, pos, x, 100, 'b')
    fig.canvas.draw()
    time.sleep(1e-3)  # brief pause between redraws

In [None]:
class SteadyStateOperator(gluon.Block):
    """Two-layer network that recomputes the node embedding ``'h'`` by message passing.

    Calling the block on a graph overwrites ``g.ndata['h']`` in place and
    returns nothing. The graph must carry the node features ``'x'``, ``'h'``
    and ``'deg'``.

    Parameters
    ----------
    n_hidden : int
        Number of units of both dense layers (and thus of the new ``'h'``).
    activation : str
        Activation applied by the first dense layer only.
    **kwargs
        Forwarded to ``gluon.Block``.
    """

    def __init__(self, n_hidden, activation, **kwargs):
        super(SteadyStateOperator, self).__init__(**kwargs)
        with self.name_scope():
            self.dense1 = gluon.nn.Dense(n_hidden, activation=activation)
            self.dense2 = gluon.nn.Dense(n_hidden)

    def forward(self, g):
        """Run one message-passing round on ``g`` and store the result in ``g.ndata['h']``."""

        def send_source_state(edges):
            # Each edge carries its source node's features and current embedding.
            source = edges.src
            return {'m': mx.nd.concat(source['x'], source['h'], dim=1)}

        def recompute_embedding(nodes):
            # Sum the incoming messages (axis 1 of the mailbox) and scale by 'deg'.
            scaled_sum = mx.nd.sum(nodes.mailbox['m'], axis=1) / nodes.data['deg']
            # Prepend the node's own features before applying the two dense layers.
            layer_input = mx.nd.concat(nodes.data['x'], scaled_sum, dim=1)
            hidden = self.dense1(layer_input)
            return {'h': self.dense2(hidden)}

        g.update_all(send_source_state, recompute_embedding)

class Predictor(gluon.Block):
    """Two-layer classifier that maps the node embedding ``'h'`` to class scores ``'z'``.

    Calling the block on a graph writes ``g.ndata['z']`` in place and returns
    nothing.

    Parameters
    ----------
    n_hidden : int
        Number of units of the first dense layer.
    n_classes : int
        Number of units of the output layer (one score per class).
    activation : str
        Activation applied by the first dense layer only.
    **kwargs
        Forwarded to ``gluon.Block``.
    """

    def __init__(self, n_hidden, n_classes, activation, **kwargs):
        super(Predictor, self).__init__(**kwargs)
        with self.name_scope():
            self.dense1 = gluon.nn.Dense(n_hidden, activation=activation)
            self.dense2 = gluon.nn.Dense(n_classes)

    def forward(self, g):
        """Score every node of ``g`` and store the scores in ``g.ndata['z']``."""

        def score_nodes(nodes):
            hidden = self.dense1(nodes.data['h'])
            return {'z': self.dense2(hidden)}

        g.apply_nodes(score_nodes)

In [None]:
def update_embeddings(g, steady_state_operator, mixing_weight=None):
    """Move the embeddings ``g.ndata['h']`` one damped step towards the operator output.

    The operator is applied to ``g`` (which overwrites ``'h'``), and the stored
    embedding becomes the convex combination
    ``(1 - w) * previous + w * new``, where ``w`` is the mixing weight.
    The original code omitted the ``w`` factor on the new embedding, so the
    result was not a weighted average and grew with every call.

    Parameters
    ----------
    g : graph
        Graph whose ``'h'`` node feature is updated in place.
    steady_state_operator : callable
        Called as ``steady_state_operator(g)``; must write the new ``'h'``.
    mixing_weight : float, optional
        Weight ``w`` of the new embedding. Defaults to the notebook-level
        ``alpha`` when omitted.
    """
    if mixing_weight is None:
        mixing_weight = alpha
    prev = g.ndata['h']  # keep the old embedding; the operator overwrites 'h'
    steady_state_operator(g)
    g.ndata['h'] = (1 - mixing_weight) * prev + mixing_weight * g.ndata['h']

def update_parameters(g, steady_state_operator, predictor, trainer):
    """Take one optimizer step on the cross-entropy loss over all nodes of ``g``.

    The forward pass (operator, then predictor) is recorded for autograd, the
    loss between the scores ``'z'`` and the labels ``'y'`` is back-propagated,
    and ``trainer`` is stepped with the node count as batch size. The
    embeddings ``'h'`` that the forward pass overwrote are restored afterwards,
    so only the network parameters change.

    Returns
    -------
    The first element of the loss converted to a NumPy array.
    """
    num_nodes = g.number_of_nodes()
    saved_h = g.ndata['h']  # snapshot; put back once the step is done

    with mx.autograd.record():
        steady_state_operator(g)
        predictor(g)
        scores = g.ndata['z']
        labels = g.ndata['y'].reshape(num_nodes)
        loss = mx.nd.softmax_cross_entropy(scores, labels)

    loss.backward()
    trainer.step(num_nodes)

    g.ndata['h'] = saved_h
    return loss.asnumpy()[0]

In [None]:
def train(g, steady_state_operator, predictor, trainer):
    """Run one training epoch on ``g`` and return the last parameter-update loss.

    An epoch is ``n_embedding_updates`` embedding updates followed by
    ``n_parameter_updates`` optimizer steps; both counts are read from the
    notebook-level variables of the same name. ``n_parameter_updates`` must be
    at least 1, otherwise ``loss`` is never assigned.
    """
    # Phase 1: refine the embeddings with the current parameters.
    for _ in range(n_embedding_updates):
        update_embeddings(g, steady_state_operator)

    # Phase 2: fit the parameters against the current embeddings.
    for _ in range(n_parameter_updates):
        loss = update_parameters(g, steady_state_operator, predictor, trainer)

    return loss

def test(g, steady_state_operator, predictor):
    """Return the fraction of nodes of ``g`` whose predicted class equals ``'y'``.

    The embeddings are first updated ``L`` times (notebook-level chain length),
    then the predictor scores every node and the arg-max class is compared
    with the label. Note that this modifies ``g.ndata['h']`` and
    ``g.ndata['z']``.
    """
    for _ in range(L):
        update_embeddings(g, steady_state_operator)
    predictor(g)

    num_nodes = g.number_of_nodes()
    predicted = mx.nd.argmax(g.ndata['z'], axis=1)
    target = g.ndata['y'].reshape(num_nodes)
    hit_rate = mx.nd.sum(predicted == target) / num_nodes
    return hit_rate.asnumpy()[0]

In [None]:
n_feats = N  # width of the node feature 'x'
n_hidden = 16  # width of the node embedding 'h' and of the hidden layers
activation = 'relu'
lr = 1e-3  # learning rate handed to the Adam trainer below

steady_state_operator = SteadyStateOperator(n_hidden, activation)
predictor = Predictor(n_hidden, N, activation)  # N output classes, one per chain
steady_state_operator.initialize()
predictor.initialize()
# Merge the parameters of both networks so a single trainer updates them together.
params = steady_state_operator.collect_params()
params.update(predictor.collect_params())
trainer = gluon.Trainer(params, 'adam', {'learning_rate' : lr})

In [None]:
# Start from a clean slate: drop every node feature left on `g` by earlier cells.
# Iterate over a snapshot so that popping cannot disturb the iteration.
for scheme in list(g.node_attr_schemes()):
    g.pop_n_repr(scheme)

n = g.number_of_nodes()
# Input features are all zeros; the per-chain one-hot marker below is disabled.
g.ndata['x'] = mx.nd.zeros([n, n_feats])
# for i in range(N):
#     g.ndata['x'][i * L][i] = 1
# The label of a node is the index of the chain it belongs to.
g.ndata['y'] = mx.nd.concat(*[i * mx.nd.ones([L, 1], dtype='float32') for i in range(N)], dim=0)
# Random initial embeddings.
g.ndata['h'] = mx.nd.random_normal(shape=[n, n_hidden])
# In-degree of every node as a float column, used to scale the summed messages.
g.ndata['deg'] = mx.nd.cast(g.in_degrees(range(n)).reshape(n, 1), 'float32')

# Number of training nodes taken from the head of EACH chain. The split below
# slices every chain separately, so the count must be derived from the chain
# length L. The original `int(r_train * N * L)` used the total node count,
# which made the training share N * r_train and overran the chain once
# N >= 1 / r_train.
n_train = int(r_train * L)

nodes_train = sum([list(range(i * L, i * L + n_train)) for i in range(N)], [])
g_train = g.subgraph(nodes_train) # subgraph for training
g_train.copy_from_parent()

nodes_test = sum([list(range(i * L + n_train, (i + 1) * L)) for i in range(N)], [])
g_test = g.subgraph(nodes_test) # subgraph for test
g_test.copy_from_parent()

In [None]:
n_epochs = 100
n_embedding_updates = 10  # embedding updates per epoch, read as a global by train()
n_parameter_updates = 10  # optimizer steps per epoch, read as a global by train()
alpha = 0.1  # re-assigns the value already set in the configuration cell

for i in range(n_epochs):
    loss = train(g_train, steady_state_operator, predictor, trainer)
    # NOTE(review): accuracy is measured on g_train, the same subgraph that was
    # just trained on; g_test is built earlier but never used — confirm whether
    # held-out accuracy was intended here.
    accuracy = test(g_train, steady_state_operator, predictor)
    print('[epoch %d]loss: %.3f, accuracy: %.3f' % (i, loss, accuracy))

In [None]:
# NOTE(review): the original trailing comment on this loop was the unfinished
# fragment "because". The loop only iterates over the sampler and discards every
# subgraph it yields. The third positional argument (3) is presumably the
# neighbour expand factor — TODO confirm against the installed DGL version.
for sub_g in dgl.contrib.sampling.NeighborSampler(g, batch_size, 3):
    pass

def train_on_subgraphs(g, steady_state_operator, predictor, trainer=None):
    """Run one training epoch on ``g`` and return the last parameter-update loss.

    Performs ``n_embedding_updates`` embedding updates followed by
    ``n_parameter_updates`` optimizer steps (both read from the notebook-level
    variables of the same name).

    The original version called ``update_parameters`` without its required
    ``trainer`` argument, which raised ``TypeError`` on the first step.

    Parameters
    ----------
    g : graph
        Graph (or subgraph) to train on; its node features are updated in place.
    steady_state_operator, predictor : gluon.Block
        The two networks being trained.
    trainer : gluon.Trainer, optional
        Optimizer to step. When omitted, the notebook-level ``trainer`` is used
        so that existing three-argument calls keep working.

    Returns
    -------
    The loss of the last parameter update, or ``None`` when
    ``n_parameter_updates`` is 0.
    """
    if trainer is None:
        # The parameter shadows the notebook-level name, so look it up explicitly.
        trainer = globals()['trainer']

    for _ in range(n_embedding_updates):
        update_embeddings(g, steady_state_operator)

    loss = None
    for _ in range(n_parameter_updates):
        loss = update_parameters(g, steady_state_operator, predictor, trainer)
    return loss