In [7]:
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Reddit
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import SAGEConv, GATConv

In [2]:
# Root directory handed to the Reddit dataset class; the recorded output of
# this cell shows the archive being downloaded and extracted under its raw/
# subfolder.
DATA_ROOT = '~/data'
dataset = Reddit(DATA_ROOT)

Downloading https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/reddit.zip
Extracting /home/ygx/data/raw/reddit.zip


In [6]:
# Dimensions of the dataset's feature tensor.
# NOTE(review): the saved output for this cell carries execution count 6 while
# the dataset cell is count 2 and the import cell is count 7 -- re-run from a
# fresh kernel to confirm the displayed size still belongs to this dataset.
dataset.data.x.size()

torch.Size([2708, 1433])

In [9]:
# Take the first (index 0) graph object out of the dataset.
data = dataset[0]

# Options for the mini-batch neighbour sampler, gathered in one place so they
# are easy to tune.
# NOTE(review): presumably `size` gives the number of neighbours sampled per
# hop (25, then 10) -- confirm against the NeighborSampler documentation of
# the installed torch_geometric version, since this keyword set belongs to the
# DataFlow-style sampler API.
sampler_options = dict(
    size=[25, 10],
    num_hops=2,
    batch_size=1000,
    shuffle=True,
    add_self_loops=True,
)
loader = NeighborSampler(data, **sampler_options)


class SAGENet(torch.nn.Module):
    """Two-layer network built from ``SAGEConv`` layers for sampled batches.

    The first layer maps ``in_channels`` to 16 hidden channels and the second
    maps those 16 channels to ``out_channels``. Both layers are constructed
    with ``normalize=False``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 16
        self.conv1 = SAGEConv(in_channels, hidden_channels, normalize=False)
        self.conv2 = SAGEConv(hidden_channels, out_channels, normalize=False)

    def forward(self, x, data_flow):
        """Compute log-probabilities for one sampled batch.

        Args:
            x: Feature tensor; it is indexed with ``data_flow[0].n_id`` to
                obtain the inputs of the first layer.
            data_flow: Indexable container whose items 0 and 1 each expose
                ``edge_index`` and ``size`` (item 0 additionally ``n_id``).

        Returns:
            ``log_softmax`` along dimension 1 of the second layer's output.
        """
        first_hop, second_hop = data_flow[0], data_flow[1]

        # Layer 1: the target-side features are passed as None.
        hidden = self.conv1((x[first_hop.n_id], None), first_hop.edge_index,
                            size=first_hop.size)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(F.relu(hidden), p=0.5, training=self.training)

        # Layer 2 consumes the activations produced by layer 1.
        logits = self.conv2((hidden, None), second_hop.edge_index,
                            size=second_hop.size)
        return F.log_softmax(logits, dim=1)


class GATNet(torch.nn.Module):
    """Two-layer network built from ``GATConv`` layers for sampled batches.

    The first layer uses 8 heads with 8 output channels each; the second layer
    takes the resulting ``8 * 8`` channels and produces ``out_channels`` with
    a single head. Both layers are constructed with ``dropout=0.6``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        heads, channels_per_head = 8, 8
        self.conv1 = GATConv(in_channels, channels_per_head, heads=heads,
                             dropout=0.6)
        self.conv2 = GATConv(heads * channels_per_head, out_channels, heads=1,
                             concat=True, dropout=0.6)

    def forward(self, x, data_flow):
        """Compute log-probabilities for one sampled batch.

        Args:
            x: Feature tensor; it is indexed with ``data_flow[0].n_id`` to
                obtain the inputs of the first layer.
            data_flow: Indexable container whose items 0 and 1 each expose
                ``res_n_id``, ``edge_index`` and ``size`` (item 0 additionally
                ``n_id``).

        Returns:
            ``log_softmax`` along dimension 1 of the second layer's output.
        """
        outer, inner = data_flow[0], data_flow[1]

        # Layer 1: the second tuple entry is the subset of the gathered rows
        # selected by ``res_n_id``.
        source = x[outer.n_id]
        hidden = self.conv1((source, source[outer.res_n_id]),
                            outer.edge_index, size=outer.size)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(F.elu(hidden), p=0.6, training=self.training)

        # Layer 2 repeats the same pairing on the layer-1 activations.
        logits = self.conv2((hidden, hidden[inner.res_n_id]),
                            inner.edge_index, size=inner.size)
        return F.log_softmax(logits, dim=1)


# Seed PyTorch's global RNG before the model is created so that weight
# initialisation and dropout start from the same state on every run.
# NOTE(review): the neighbour sampler may draw from a separate RNG -- confirm
# before relying on bit-identical results across runs.
SEED = 0
torch.manual_seed(SEED)

# Training is pinned to the CPU. Set FORCE_CPU to False to use CUDA whenever
# it is available; this replaces the earlier pair of assignments in which the
# CUDA-aware choice was immediately overwritten by a hard-coded CPU device.
FORCE_CPU = True
use_cuda = (not FORCE_CPU) and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# GATNet is the model being trained. SAGENet takes the same constructor
# arguments and can be substituted on the next line.
model = GATNet(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def train():
    """Run one training epoch over the nodes selected by ``data.train_mask``.

    Relies on the module-level ``model``, ``optimizer``, ``loader``, ``data``
    and ``device``.

    Returns:
        float: The per-batch losses weighted by ``data_flow.batch_size`` and
        summed, divided by the number of true entries in ``data.train_mask``.
    """
    model.train()

    # The feature tensor is the same for every batch, so move it to the target
    # device once instead of repeating the transfer on each iteration.
    x = data.x.to(device)

    total_loss = 0
    for data_flow in loader(data.train_mask):
        optimizer.zero_grad()
        out = model(x, data_flow.to(device))
        # Labels are looked up with the batch's ``n_id`` before being moved.
        loss = F.nll_loss(out, data.y[data_flow.n_id].to(device))
        loss.backward()
        optimizer.step()
        # Weight by batch size so the final division yields a per-node mean
        # even when the last batch is smaller than the others.
        total_loss += loss.item() * data_flow.batch_size
    return total_loss / data.train_mask.sum().item()


def test(mask):
    """Measure accuracy on the nodes selected by ``mask``.

    Relies on the module-level ``model``, ``loader``, ``data`` and ``device``.

    Args:
        mask: Node mask passed to ``loader`` to choose the nodes to evaluate;
            its sum is used as the denominator of the accuracy.

    Returns:
        float: Number of correct predictions divided by ``mask.sum()``.
    """
    model.eval()

    correct = 0
    # Evaluation needs no gradients; disabling autograd avoids building a
    # computation graph for every batch.
    with torch.no_grad():
        # The feature tensor is identical for every batch, so transfer it once.
        x = data.x.to(device)
        for data_flow in loader(mask):
            # Index of the largest log-probability along dimension 1.
            pred = model(x, data_flow.to(device)).max(1)[1]
            correct += pred.eq(data.y[data_flow.n_id].to(device)).sum().item()
    return correct / mask.sum().item()


# Train for a fixed number of epochs, printing the value returned by train()
# and the accuracy on the test-mask nodes after each one.
NUM_EPOCHS = 30
LOG_LINE = 'Epoch: {:02d}, Loss: {:.4f}, Test: {:.4f}'

for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    test_acc = test(data.test_mask)
    print(LOG_LINE.format(epoch, loss, test_acc))

Epoch: 01, Loss: 1.9271, Test: 0.5220
Epoch: 02, Loss: 1.6878, Test: 0.6620
Epoch: 03, Loss: 1.4352, Test: 0.7020
Epoch: 04, Loss: 1.2650, Test: 0.7220
Epoch: 05, Loss: 1.0983, Test: 0.7350
Epoch: 06, Loss: 0.9570, Test: 0.7590
Epoch: 07, Loss: 0.8296, Test: 0.7760
Epoch: 08, Loss: 0.8989, Test: 0.7740
Epoch: 09, Loss: 0.7911, Test: 0.7870
Epoch: 10, Loss: 0.6484, Test: 0.7940
Epoch: 11, Loss: 0.7406, Test: 0.7990
Epoch: 12, Loss: 0.6380, Test: 0.8050
Epoch: 13, Loss: 0.5891, Test: 0.8030
Epoch: 14, Loss: 0.5450, Test: 0.7950
Epoch: 15, Loss: 0.6128, Test: 0.7970
Epoch: 16, Loss: 0.5073, Test: 0.7940
Epoch: 17, Loss: 0.5181, Test: 0.7880
Epoch: 18, Loss: 0.5492, Test: 0.7860
Epoch: 19, Loss: 0.4379, Test: 0.7840
Epoch: 20, Loss: 0.4601, Test: 0.7840
Epoch: 21, Loss: 0.4202, Test: 0.7840
Epoch: 22, Loss: 0.4937, Test: 0.7790
Epoch: 23, Loss: 0.4239, Test: 0.7800
Epoch: 24, Loss: 0.4327, Test: 0.7860
Epoch: 25, Loss: 0.3609, Test: 0.7850
Epoch: 26, Loss: 0.4742, Test: 0.7860
Epoch: 27, L