In [5]:
import os
os.environ['DGLBACKEND'] = 'pytorch'

import dgl
import dgl.data
from dgl.nn import GraphConv
from dgl.dataloading import GraphDataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class GCN(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)
    
    def forward(self, g, h):
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
            logits = self.classify(hg)
        
        return logits

In [6]:
dataset = dgl.data.GINDataset('MUTAG', False)

dataloader = GraphDataLoader(dataset, batch_size=1024, drop_last=False, shuffle=True)

In [7]:
dataloader = list(dataloader)

In [10]:
model = GCN(7, 20, 5)
model

GCN(
  (conv1): GraphConv(in=7, out=20, normalization=both, activation=None)
  (conv2): GraphConv(in=20, out=20, normalization=both, activation=None)
  (classify): Linear(in_features=20, out_features=5, bias=True)
)

In [16]:
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[ 0.4071, -0.3284,  0.1131, -0.1439,  0.1569,  0.4531, -0.2887,  0.4667,
                        0.2787,  0.2067, -0.3740, -0.0412,  0.1453, -0.2694,  0.0724, -0.3103,
                        0.2829,  0.1558, -0.4469,  0.2382],
                      [ 0.2141, -0.4440,  0.1798, -0.2074,  0.2522,  0.1700,  0.3815,  0.1576,
                        0.3682, -0.0556, -0.3880, -0.3067, -0.0109,  0.3489, -0.1416,  0.2301,
                        0.1008, -0.0502,  0.1041,  0.3265],
                      [-0.2747,  0.2322, -0.1554, -0.2519, -0.0464, -0.0629, -0.2695, -0.1921,
                        0.3574, -0.2102,  0.1038,  0.2773,  0.3896, -0.3642, -0.1041,  0.2278,
                       -0.0249, -0.1925, -0.4129, -0.0535],
                      [ 0.3161, -0.1561,  0.4137,  0.4025, -0.3294,  0.0369,  0.0079, -0.3210,
                        0.2718,  0.1824,  0.4682, -0.2233,  0.1201, -0.4672, -0.2559, -0.0445,
                       -0.2120

In [12]:
optim = torch.optim.Adam(model.parameters(), lr=0.01)

In [17]:
for epoch in range(20):
    for batched_graph, labels in dataloader:
        feats = batched_graph.ndata['attr']
        logits = model(batched_graph, feats)
        loss = F.cross_entropy(logits, labels)
        
        optim.zero_grad()
        loss.backward()
        optim.step()

In [19]:
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[ 0.3640, -0.1542,  0.2849, -0.0370,  0.3416,  0.6301, -0.3498,  0.6455,
                        0.4624,  0.3857, -0.3740,  0.1338,  0.2111, -0.3417,  0.2497, -0.3103,
                        0.4562,  0.2516, -0.4369,  0.3827],
                      [ 0.1087, -0.5791,  0.1289, -0.2526,  0.4409,  0.2870,  0.1751,  0.1103,
                        0.5113,  0.1463, -0.3880, -0.1992,  0.1954,  0.5575, -0.2779,  0.3645,
                        0.0164,  0.0149, -0.0574,  0.1705],
                      [-0.3335,  0.3267, -0.0420, -0.2783,  0.0443,  0.0068, -0.3066, -0.2291,
                        0.4723, -0.0059, -0.0280,  0.4410,  0.5805, -0.3642,  0.0265,  0.3609,
                       -0.0444, -0.1481, -0.4233, -0.1572],
                      [ 0.3515, -0.1972,  0.4163,  0.5588, -0.3176,  0.2084, -0.1730, -0.3210,
                        0.3433,  0.3848,  0.6382, -0.1717,  0.3230, -0.4672, -0.4420, -0.1459,
                       -0.2120