In [1]:
import dgl
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import torch


# Define a Heterograph Conv model

class RGCN(nn.Module):
    """Two-layer heterogeneous graph convolution network with learned node embeddings.

    The graph carries no input features here: every node of every type owns a
    trainable embedding row, and those embeddings are what the first
    convolution layer consumes.

    Args:
        emb_types: mapping from node-type name to the number of nodes of that
            type (one embedding row is allocated per node).
        emb_size: width of each node embedding.
        hid_feats: output width of the first convolution layer.
        out_feats: output width of the second convolution layer.
        rel_names: iterable of relation (edge-type) names; each one gets its
            own ``GraphConv`` in both layers. It is iterated twice.
    """

    def __init__(self, emb_types, emb_size, hid_feats, out_feats, rel_names):
        super().__init__()
        # Embedding-per-node-type approach follows
        # https://www.jianshu.com/p/767950b560c4
        node_embeddings = {}
        for ntype in emb_types.keys():
            # Allocated uninitialised; filled by the Xavier pass below.
            node_embeddings[ntype] = nn.Parameter(
                torch.Tensor(emb_types[ntype], emb_size)
            )
        for table in node_embeddings.values():
            nn.init.xavier_uniform_(table)
        self.embed = nn.ParameterDict(node_embeddings)

        # Layer 1: embedding width -> hidden width, one GraphConv per relation,
        # per-relation results summed for each destination node type.
        first_layer_convs = {
            rel: dglnn.GraphConv(emb_size, hid_feats) for rel in rel_names
        }
        self.conv1 = dglnn.HeteroGraphConv(first_layer_convs, aggregate='sum')

        # Layer 2: hidden width -> output width, same per-relation layout.
        second_layer_convs = {
            rel: dglnn.GraphConv(hid_feats, out_feats) for rel in rel_names
        }
        self.conv2 = dglnn.HeteroGraphConv(second_layer_convs, aggregate='sum')

    def forward(self, graph):
        """Run both convolution layers over ``graph``.

        The learned embeddings in ``self.embed`` act as the input node
        features; ReLU is applied to every node type's output between the two
        layers. Returns whatever mapping the second ``HeteroGraphConv``
        produces (the notebook indexes it by node-type name).
        """
        hidden = self.conv1(graph, self.embed)
        activated = {ntype: F.relu(feat) for ntype, feat in hidden.items()}
        return self.conv2(graph, activated)

In [2]:
# dgl.load_graphs returns a (graphs, labels) pair; only the first stored graph is used.
graph_list, _graph_labels = dgl.load_graphs("./graphs/industrial_and_scientific_5_core.dgl")
g = graph_list[0]
g

Graph(num_nodes={'Brand': 1900, 'Customer': 11041, 'Product': 5334, 'Review': 77071},
      num_edges={('Brand', 'rev_SOLD_BY', 'Product'): 5555, ('Customer', 'WROTE', 'Review'): 77071, ('Product', 'SOLD_BY', 'Brand'): 5555, ('Product', 'rev_REVIEW_OF', 'Review'): 77071, ('Review', 'REVIEW_OF', 'Product'): 77071, ('Review', 'rev_WROTE', 'Customer'): 77071},
      metagraph=[('Brand', 'Product', 'rev_SOLD_BY'), ('Product', 'Brand', 'SOLD_BY'), ('Product', 'Review', 'rev_REVIEW_OF'), ('Customer', 'Review', 'WROTE'), ('Review', 'Product', 'REVIEW_OF'), ('Review', 'Customer', 'rev_WROTE')])

In [3]:
# Seed PyTorch's global RNG before any parameter is initialised so that a
# Restart & Run All reproduces the same initial weights (and, on CPU, the same
# loss curve and metrics).
SEED = 0
torch.manual_seed(SEED)

# Model hyper-parameters.
EMB_SIZE = 512   # width of the learned per-node embeddings
HID_FEATS = 256  # width of the hidden layer
OUT_FEATS = 2    # number of output logits per node

# One embedding row per node of every node type; one GraphConv per edge type.
nodes_per_type = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
model = RGCN(nodes_per_type, EMB_SIZE, HID_FEATS, OUT_FEATS, g.etypes)

# Targets and train/test split indicators are stored as node data on "Review" nodes.
labels = g.nodes["Review"].data["Positive"]
train_mask = g.nodes["Review"].data["train_mask"]
test_mask = g.nodes["Review"].data["test_mask"]

In [4]:
# Adam with PyTorch's default hyper-parameters.
opt = torch.optim.Adam(model.parameters())

# The training rows and their targets never change between epochs, so build
# the boolean selector and the int64 class targets once instead of 250 times.
train_rows = train_mask == 1
train_targets = labels[train_rows].type(torch.long)

# Nothing inside the loop leaves training mode, so a single call is enough.
model.train()

for epoch in range(250):
    # Full-graph forward pass; only the "Review" node logits are supervised.
    logits = model(g)["Review"]
    loss = F.cross_entropy(logits[train_rows], train_targets)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print("Epoch:", epoch, "Loss:", loss.item())



Epoch: 0 Loss: 0.6928020119667053
Epoch: 1 Loss: 0.671084463596344
Epoch: 2 Loss: 0.6490806937217712
Epoch: 3 Loss: 0.6235765814781189
Epoch: 4 Loss: 0.5933032631874084
Epoch: 5 Loss: 0.5576277375221252
Epoch: 6 Loss: 0.5167205333709717
Epoch: 7 Loss: 0.47172102332115173
Epoch: 8 Loss: 0.4247688353061676
Epoch: 9 Loss: 0.37867745757102966
Epoch: 10 Loss: 0.33642256259918213
Epoch: 11 Loss: 0.30055058002471924
Epoch: 12 Loss: 0.2725561559200287
Epoch: 13 Loss: 0.252458393573761
Epoch: 14 Loss: 0.23884126543998718
Epoch: 15 Loss: 0.22939857840538025
Epoch: 16 Loss: 0.22171899676322937
Epoch: 17 Loss: 0.21393801271915436
Epoch: 18 Loss: 0.20503370463848114
Epoch: 19 Loss: 0.19480569660663605
Epoch: 20 Loss: 0.18366236984729767
Epoch: 21 Loss: 0.1723494976758957
Epoch: 22 Loss: 0.16169394552707672
Epoch: 23 Loss: 0.15243180096149445
Epoch: 24 Loss: 0.14512765407562256
Epoch: 25 Loss: 0.14007924497127533
Epoch: 26 Loss: 0.13707202672958374
Epoch: 27 Loss: 0.13542404770851135
Epoch: 28 Loss:

In [5]:
# Inference only: put the model in eval mode and disable autograd so the
# forward pass does not build (and hold memory for) a gradient graph that
# would be thrown away immediately.
model.eval()
with torch.no_grad():
    # Logits for the held-out "Review" nodes, as a NumPy array for scikit-learn.
    preds = model(g)["Review"][test_mask == 1].numpy()
# Ground-truth labels for the same held-out rows.
y_test = labels[test_mask == 1]

In [6]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test rows; the predicted class is the arg-max logit.
print(classification_report(y_true=y_test, y_pred=preds.argmax(axis=1)))

              precision    recall  f1-score   support

         0.0       0.24      0.23      0.23      4750
         1.0       0.89      0.90      0.90     33786

    accuracy                           0.82     38536
   macro avg       0.57      0.56      0.56     38536
weighted avg       0.81      0.82      0.81     38536



In [7]:
from sklearn.metrics import f1_score
# F1 with scikit-learn's default settings; predicted class is the arg-max logit.
f1_score(y_true=y_test, y_pred=preds.argmax(axis=1))

0.8960650569395542

In [8]:
from sklearn.metrics import recall_score
# Recall with scikit-learn's default settings; predicted class is the arg-max logit.
recall_score(y_true=y_test, y_pred=preds.argmax(axis=1))

0.9001361510684899

In [9]:
from sklearn.metrics import precision_score
# Precision with scikit-learn's default settings; predicted class is the arg-max logit.
precision_score(y_true=y_test, y_pred=preds.argmax(axis=1))

0.8920306221218431