In [1]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import function as fn
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx
import matplotlib.pyplot as plt

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not used by the cells below (model/data stay on CPU).
device = th.device("cuda" if th.cuda.is_available() else "cpu")
print("GPU is available" if device.type == "cuda" else "GPU not available, CPU used")

GPU is available


In [2]:
#per-node update applied to the neighbor-aggregated features
class NodeApplyModule(nn.Module):
    """Node-wise update: linear projection of the 'h' feature, then an optional activation.

    Args:
        in_feats: width of the incoming ``node.data['h']``.
        out_feats: width of the projected feature.
        activation: callable applied to the projection, or ``None`` to skip it.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        """Return ``{'h': activation(linear(node.data['h']))}`` (activation skipped if ``None``)."""
        projected = self.linear(node.data['h'])
        act = self.activation
        return {'h': projected if act is None else act(projected)}
    
#gcn layer in network
class GCN(nn.Module):
    """Graph-convolution layer that only updates the last ``k`` nodes of the graph.

    Each call writes ``feature`` into the ``'h'`` field of the last ``k`` nodes
    (in ``g.nodes()`` order). It then pulls messages into those nodes using the
    message/reduce functions registered on ``g``, and applies
    ``NodeApplyModule`` (linear projection + optional activation) to them.

    Args:
        in_feats: width of the features passed to ``forward``.
        out_feats: width of the features returned by ``forward``.
        activation: callable applied after the projection, or ``None``.
        k: number of trailing node ids the layer tracks.
    """

    def __init__(self, in_feats, out_feats, activation, k):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)
        self.k = k

    def forward(self, g, feature):
        """Run one propagation step on the tracked nodes and return their new features.

        ``feature`` presumably has one row per tracked node, i.e. shape
        ``(k, in_feats)`` -- TODO confirm against callers.
        """
        tracked = g.nodes()[-self.k:]
        g.nodes[tracked].data['h'] = feature

        # Aggregate neighbour 'h' into the tracked nodes only, using the graph's
        # registered message/reduce functions.
        g.pull(tracked)
        g.apply_nodes(self.apply_mod, v=tracked)

        out = g.nodes[tracked].data['h']
        # `g.nodes[...]` has no `pop`; its features are reached via `.data`, and
        # DGL does not support deleting a feature on only a subset of nodes.
        # Drop the whole 'h' column instead. This also lets the next layer
        # store an 'h' of a different width.
        g.ndata.pop('h')
        return out
    
#lstm layer in network
class G_LSTM(nn.Module):
    """One-layer LSTM step over the concatenated features of the ``k`` tracked nodes.

    The rows of ``feature`` are flattened into a single vector of length
    ``k * in_feats`` and fed to an ``nn.LSTM`` with hidden size ``k * out_feats``.
    The LSTM output is reshaped back to ``(k, out_feats)``.
    """

    def __init__(self, in_feats, out_feats, k):
        super(G_LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=k * in_feats, hidden_size=k * out_feats, num_layers=1)
        self.out_feats = out_feats
        self.k = k

    def forward(self, g, feature, hidden):
        """Advance the LSTM one step; returns ``(features of shape (k, out_feats), hidden)``.

        ``g`` is unused; it is accepted so all layers share a call signature.
        NOTE(review): the input is 2-D ``(1, k * in_feats)``. ``nn.LSTM`` treats that
        as an unbatched sequence only in newer PyTorch releases; older ones
        require 3-D input -- confirm the target version.
        """
        flat = feature.view(1, -1)
        lstm_out, new_hidden = self.lstm(flat, hidden)
        return lstm_out.view(self.k, self.out_feats), new_hidden
    
#network
class RGCN_L2(nn.Module):
    """GCN -> GCN -> LSTM -> GCN stack operating on the last ``k`` nodes of the graph.

    Args:
        in_feats: width of the input node features.
        out_feats: width of the output features (final GCN layer, tanh activation).
        k: number of trailing nodes every layer tracks.
        hidden1: output width of the first GCN layer (default 800, as before).
        hidden2: output width of the second GCN layer and of the LSTM (default 100).
    """

    def __init__(self, in_feats, out_feats, k, hidden1=800, hidden2=100):
        super(RGCN_L2, self).__init__()
        self.gcn1 = GCN(in_feats, hidden1, F.relu, k)
        self.gcn2 = GCN(hidden1, hidden2, F.relu, k)
        self.lstm1 = G_LSTM(hidden2, hidden2, k)
        # th.tanh rather than F.tanh, which PyTorch marks deprecated; same values.
        self.gcn3 = GCN(hidden2, out_feats, th.tanh, k)

        self.k = k

    def forward(self, g, features, hidden):
        """Run the stack once; returns ``(gcn3 output, LSTM hidden state)``.

        ``hidden`` is passed straight to the LSTM layer (``None`` lets ``nn.LSTM``
        start from zeros).
        """
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        x, hidden = self.lstm1(g, x, hidden)

        return self.gcn3(g, x), hidden


In [62]:
#loss function
def similarity_matrix(mat):
    """Pairwise Euclidean distance matrix between the rows of ``mat``.

    Despite the name, the result is a distance (larger = less similar):
    ``D[i, j] = ||mat[i] - mat[j]||_2``. It is computed via the expansion
    ``||a||^2 + ||b||^2 - 2 a.b``.

    Args:
        mat: 2-D tensor, one row per item.

    Returns:
        Tensor of shape ``(mat.size(0), mat.size(0))`` with a zero diagonal.
    """
    gram = th.mm(mat, mat.t())
    # sq_norms[i, j] == ||mat[j]||^2; its transpose gives ||mat[i]||^2.
    sq_norms = gram.diag().unsqueeze(0).expand_as(gram)
    sq_dist = sq_norms + sq_norms.t() - 2 * gram

    # Round-off can make sq_dist slightly negative (sqrt -> NaN). Also, the
    # derivative of sqrt at 0 is infinite, which would make the zero diagonal
    # produce NaN gradients. So take the sqrt only where sq_dist > 0 and emit
    # an exact 0 elsewhere. The inner `where` keeps the untaken branch finite,
    # so no inf/NaN reaches the backward pass.
    positive = sq_dist > 0
    safe = th.where(positive, sq_dist, th.ones_like(sq_dist))
    return th.where(positive, safe.sqrt(), th.zeros_like(sq_dist))

def same_label(y):
    """Boolean matrix whose ``[i, j]`` entry says whether ``y[i] == y[j]``.

    Args:
        y: 1-D tensor of labels.
    """
    count = y.size(0)
    grid = y.unsqueeze(0).expand(count, count)  # grid[i, j] == y[j]
    return grid.t().eq(grid)                    # [i, j] -> y[i] == y[j]

def my_loss(output, labels):
    """Label-aware embedding loss built from pairwise distances between rows of ``output``.

    With ``D = similarity_matrix(output)`` (pairwise Euclidean distances):

        loss = (sum of D[i, j] over pairs with labels[i] == labels[j]) ** 2
             + 1 / (sum of D[i, j] over pairs with labels[i] != labels[j])

    This pulls same-label rows together and pushes different-label rows apart.
    Both sums run over ordered pairs, so every unordered pair counts twice.
    The same-label sum also includes the zero self-distances.

    NOTE(review): a per-pair reading (sum of D**2 over same-label pairs, sum of
    1/D over different-label pairs) may have been intended. The code squares
    and inverts the aggregate sums instead -- confirm which is wanted.
    NOTE: if no pair has differing labels, the second sum is 0 and the loss is inf.

    Args:
        output: 2-D tensor, one embedding row per item.
        labels: 1-D tensor of labels, presumably aligned with the rows of ``output``.

    Returns:
        0-dim tensor.
    """
    sim = similarity_matrix(output)
    # Float masks avoid bool/uint8 arithmetic (`mask * -1 + 1`), whose result
    # dtype differs between PyTorch versions. They also multiply `sim` directly.
    same_mask = same_label(labels).to(sim.dtype)
    diff_mask = 1 - same_mask

    same_term = th.sum(sim * same_mask) ** 2
    diff_term = 1 / th.sum(sim * diff_mask)
    return same_term + diff_term

In [4]:
#load dataset
# Load the Cora citation dataset and convert its arrays to PyTorch tensors.
data = citegrh.load_cora()
ds_features = th.FloatTensor(data.features)
ds_labels = th.LongTensor(data.labels)
# NOTE(review): uint8 masks; newer PyTorch prefers bool masks for indexing.
ds_train = th.ByteTensor(data.train_mask)
ds_test = th.ByteTensor(data.test_mask)
ds_g = data.graph

# Normalize self-loops: drop any existing ones, then give every node exactly
# one, so the neighbor sum in each GCN layer also includes the node's own 'h'.
# list() materializes the self-loop edges before the graph is mutated, instead
# of deleting edges while a lazy generator over the same graph is running.
ds_g.remove_edges_from(list(nx.selfloop_edges(ds_g)))
ds_g = DGLGraph(ds_g)
ds_g.add_edges(ds_g.nodes(), ds_g.nodes())
ds_g.ndata['label'] = ds_labels #used to filter

# Message passing used by g.pull() in GCN.forward: copy the source node's 'h'
# into message 'm', then sum the incoming messages into the destination's 'h'.
m_func = fn.copy_src(src='h', out='m')
m_reduce_func = fn.sum(msg='m', out='h')

ds_g.register_message_func(m_func)
ds_g.register_reduce_func(m_reduce_func)

In [5]:
########### Create Model ############

#constant parameters
SEED = 0            # seeds PyTorch's RNG so weight initialization is reproducible
TRACK_K = 50        # number of nodes to track at once
DIST_VEC_SIZE = 10  # output width of the network (out_feats of the final GCN layer)

th.manual_seed(SEED)
model = RGCN_L2(ds_features.shape[1], DIST_VEC_SIZE, TRACK_K)


In [None]:
# Training: the optimizer is set up, but the epoch loop is still a placeholder.
opt = th.optim.Adam(params=model.parameters(), lr=0.001)

for epoch in range(0, 50):
    # TODO: forward pass, my_loss, opt.zero_grad(), backward() and opt.step().
    continue

In [8]:
# First ten entries of the training mask.
ds_train[0:10]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8)

In [9]:
# First ten entries of the test mask.
ds_test[0:10]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8)