In [42]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import function as fn
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx
import matplotlib.pyplot as plt


is_cuda = th.cuda.is_available()

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = th.device("cuda" if is_cuda else "cpu")
print("GPU is available" if is_cuda else "GPU not available, CPU used")

GPU not available, CPU used


In [4]:
# Monitor a growing prefix subgraph and expose a sliding window of its newest nodes.
class JGraph():
    """Grow a prefix subgraph of a DGL graph one node at a time.

    ``G_curr`` is the subgraph of ``G_orig`` induced by node ids
    ``0 .. count-1``. The "last k nodes" window is ids ``count-k .. count-1``;
    ``lastKNodes``, ``pull`` and ``apply_nodes`` all operate on that window.
    """

    def __init__(self, g, k, message_func=None, reduce_func=None):
        """Start tracking ``g`` with the first ``k`` nodes as the initial subgraph.

        Args:
            g: the full graph to take prefix subgraphs from.
            k: window size; also the number of nodes in the initial subgraph.
            message_func: message function registered on every subgraph.
                ``None`` (default) uses the notebook-level ``m_func``.
            reduce_func: reduce function registered on every subgraph.
                ``None`` (default) uses the notebook-level ``m_reduce_func``.
        """
        self._message_func = message_func
        self._reduce_func = reduce_func
        # Bug fix: the original called the bare name ``reset`` (NameError).
        self.reset(g, k)

    def _window(self):
        """Return the node ids of the last ``k`` added nodes as a list."""
        return list(range(self.count - self.k, self.count))

    def _rebuild(self):
        """Re-induce ``G_curr`` from the first ``count`` nodes and register the UDFs."""
        # Defaults are resolved here (not in __init__) so the globals may be
        # defined in a later cell, exactly as with the original code.
        msg = m_func if self._message_func is None else self._message_func
        red = m_reduce_func if self._reduce_func is None else self._reduce_func
        self.G_curr = self.G_orig.subgraph(list(range(self.count)))
        # The subgraph is a new object, so the functions are registered on it
        # explicitly every time rather than relying on them being inherited.
        self.G_curr.register_message_func(msg)
        self.G_curr.register_reduce_func(red)

    def addNode(self):
        """Add the next node id to the subgraph and return the new ``G_curr``.

        Raises:
            IndexError: if every node of ``G_orig`` has already been added.
        """
        if self.count >= self.G_orig.number_of_nodes():
            raise IndexError(
                "all %d nodes of the original graph are already added" % self.count)
        self.count += 1
        self._rebuild()
        return self.G_curr

    def reset(self, g=None, k=None):
        """Optionally swap the graph and/or window size, then shrink back to ``k`` nodes.

        Args:
            g: new full graph, or ``None`` to keep the current one.
            k: new window size, or ``None`` to keep the current one.
        """
        if g is not None:
            self.G_orig = g
        if k is not None:
            self.k = k
        self.count = self.k  # number of nodes added so far
        self._rebuild()

    def lastKNodes(self):
        """Return a node view over the last ``k`` nodes of ``G_curr``.

        Bug fix: the original indexed ``edges`` and used a partial slice.
        An explicit id list is used instead, matching ``pull``/``apply_nodes``.
        NOTE(review): DGL node/edge views may reject partial slices.
        """
        return self.G_curr.nodes[self._window()]

    def pull(self):
        """Run message passing (registered UDFs) into the last ``k`` nodes."""
        self.G_curr.pull(self._window())

    def apply_nodes(self, func):
        """Apply ``func`` to the last ``k`` nodes of ``G_curr``."""
        self.G_curr.apply_nodes(func, v=self._window())

In [40]:
# Per-node update applied after neighbour aggregation.
class NodeApplyModule(nn.Module):
    """Linear transform of a node's ``'h'`` feature, followed by an optional activation.

    Passed to ``apply_nodes`` as the node function: it reads ``node.data['h']``
    and returns the updated value under the same key.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation  # callable, or None for no activation

    def forward(self, node):
        transformed = self.linear(node.data['h'])
        activate = self.activation
        return {'h': transformed if activate is None else activate(transformed)}
    
# GCN layer in the network.
class GCN(nn.Module):
    """One graph-convolution step on the last ``k`` nodes of a graph wrapper.

    ``g`` must provide ``lastKNodes()``, ``pull()`` and ``apply_nodes(func=...)``
    (as the notebook's ``JGraph`` does).
    """

    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        """Store ``feature`` on the window, aggregate, transform, and return the result.

        Args:
            g: graph wrapper (see class docstring).
            feature: features for the window's nodes, stored under key ``'h'``.
                Presumably shape (k, in_feats) — TODO confirm against caller.

        Returns:
            The transformed ``'h'`` features of the window, popped from the view.
        """
        # Node views expose their features through ``.data`` (there is no ``.ndata``).
        g.lastKNodes().data['h'] = feature
        g.pull()  # aggregate neighbour values via the registered message/reduce UDFs
        g.apply_nodes(func=self.apply_mod)
        # Bug fix: the original called the bare name ``lastKNodes`` (NameError).
        # NOTE(review): some DGL versions refuse to delete features on a node
        # *subset*, and apply_mod changes the width of 'h' relative to the stored
        # column — confirm both against the DGL version in use.
        return g.lastKNodes().data.pop('h')
    
# LSTM layer in the network.
class G_LSTM(nn.Module):
    """Single-layer LSTM over the concatenated features of a window of ``k`` nodes.

    The ``k`` node vectors are flattened into one input step of width
    ``k * in_feats``. The LSTM output (width ``out_feats * k``) is reshaped back
    to ``(g.k, out_feats)``.
    NOTE(review): a 2-D input is treated as unbatched only by PyTorch >= 1.11;
    older versions require a 3-D ``(seq, batch, feat)`` input.
    """

    def __init__(self, in_feats, out_feats, k):
        super().__init__()
        self.out_feats = out_feats
        input_size = k * in_feats
        hidden_size = out_feats * k
        self.lstm = nn.LSTM(input_size, hidden_size, 1)

    def forward(self, g, feature, hidden):
        flat = feature.view(1, -1)
        lstm_out, new_hidden = self.lstm(flat, hidden)
        per_node = lstm_out.view(g.k, self.out_feats)
        return per_node, new_hidden
    
# Full network: GCN -> GCN -> LSTM -> GCN.
class RGCN_L2(nn.Module):
    """Recurrent GCN over a sliding window of ``k`` nodes.

    Args:
        in_feats: input feature width.
        out_feats: output width of the final GCN layer.
        k: window size (number of nodes the LSTM sees at once).
        hidden1: width after the first GCN layer (default 800, as before).
        hidden2: width after the second GCN layer and of the LSTM (default 100).
    """

    def __init__(self, in_feats, out_feats, k, hidden1=800, hidden2=100):
        super(RGCN_L2, self).__init__()
        self.gcn1 = GCN(in_feats, hidden1, F.relu)
        self.gcn2 = GCN(hidden1, hidden2, F.relu)
        self.lstm1 = G_LSTM(hidden2, hidden2, k)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values, no warning).
        self.gcn3 = GCN(hidden2, out_feats, th.sigmoid)

    def forward(self, g, features, hidden):
        """Run the four layers on ``g``'s window; return ``(output, lstm_hidden)``."""
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        x, hidden = self.lstm1(g, x, hidden)
        x = self.gcn3(g, x)
        return x, hidden

In [14]:
def my_loss(output, target):
    """Mean squared error between ``output`` and ``target``.

    Intended objective (from the original note): close distance for outputs
    with the same target label, far otherwise.
    NOTE(review): that contrastive behaviour is NOT implemented; this is plain
    MSE — TODO implement if still wanted.
    """
    # Bug fixes: the original was missing the '-' operator (SyntaxError) and
    # referenced ``torch``, which this notebook imports as ``th``.
    return th.mean((output - target) ** 2)

In [38]:
#constant parameters

# Message/reduce UDFs registered on every graph: copy each source node's 'h'
# into message 'm', then sum the incoming messages into 'h'.
m_func = fn.copy_src(src='h', out='m')
m_reduce_func = fn.sum(msg='m', out='h')

TRACK_K = 50  # number of nodes to track at once (sliding-window size)
DIST_VEC_SIZE = 10  # out_feats of the model, i.e. output width per node

#load dataset
data = citegrh.load_cora()
# Convert the dataset arrays to PyTorch tensors.
ds_features = th.FloatTensor(data.features)
ds_labels = th.LongTensor(data.labels)
# Boolean masks: indexing with uint8 (ByteTensor) masks is deprecated in PyTorch.
ds_train = th.as_tensor(data.train_mask).bool()
ds_test = th.as_tensor(data.test_mask).bool()
ds_g = data.graph

# Normalise self loops: drop any existing ones, then add exactly one per node so
# the neighbour sum also includes each node's own features.
# list() materialises the self-loop edges before the graph is modified.
ds_g.remove_edges_from(list(nx.selfloop_edges(ds_g)))
ds_g = DGLGraph(ds_g)
ds_g.add_edges(ds_g.nodes(), ds_g.nodes())
ds_g.ndata['label'] = ds_labels

ds_g.register_message_func(m_func)
ds_g.register_reduce_func(m_reduce_func)


model_g = JGraph(ds_g, TRACK_K)  # model's graph

model = RGCN_L2(ds_features.size()[1], DIST_VEC_SIZE, model_g.k)

In [39]:
# Display the node-feature dict currently attached to ds_g.
ds_g.ndata

{'label': tensor([2, 5, 4,  ..., 1, 0, 2])}