# Node classification on the Reddit dataset with a GraphSAGE or GAT model,
# trained on mini-batches produced by neighbor sampling.
#
# Contributors: @rusty1s, @kkonevets
import os.path as osp
import argparse

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Reddit
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import SAGEConv, GATConv

# --model selects the architecture. argparse rejects unknown values itself;
# an ``assert`` would be silently skipped when Python runs with -O.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='SAGE',
                    choices=['SAGE', 'GAT'])
args = parser.parse_args()

# The dataset lives in ../data/Reddit relative to this script's location.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)
data = dataset[0]  # the script works on the single graph at index 0

# Two-hop neighbor sampling (per-hop sample sizes 25 and 10) over batches of
# 1000 target nodes, with self-loops added to the sampled edges.
loader = NeighborSampler(data, size=[25, 10], num_hops=2, batch_size=1000,
                         shuffle=True, add_self_loops=True)
class SAGENet(torch.nn.Module):
    """Two-layer GraphSAGE classifier that consumes a sampled ``DataFlow``.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Number of output classes.
        concat: Forwarded unchanged to both ``SAGEConv`` layers.
    """

    def __init__(self, in_channels, out_channels, concat=False):
        super(SAGENet, self).__init__()
        self.conv1 = SAGEConv(in_channels, 16, normalize=False, concat=concat)
        self.conv2 = SAGEConv(16, out_channels, normalize=False, concat=concat)

    def forward(self, x, data_flow):
        """Return per-node log-probabilities for the flow's target nodes.

        Args:
            x: Full node-feature matrix; the rows needed by the first block
                are gathered with ``data_flow[0].n_id``.
            data_flow: Sampled flow whose element ``i`` is the bipartite
                block consumed by layer ``i``.

        Returns:
            ``log_softmax`` over classes (dim 1) of the second layer output.
        """
        block = data_flow[0]
        x = x[block.n_id]
        x = F.relu(
            self.conv1((x, None), block.edge_index, size=block.size,
                       res_n_id=block.res_n_id))
        # Tie dropout to the module mode: F.dropout defaults to training=True,
        # which would keep dropping activations during evaluation.
        x = F.dropout(x, p=0.5, training=self.training)
        block = data_flow[1]
        x = self.conv2((x, None), block.edge_index, size=block.size,
                       res_n_id=block.res_n_id)
        return F.log_softmax(x, dim=1)
class GATNet(torch.nn.Module):
    """Two-layer graph attention classifier that consumes a sampled ``DataFlow``.

    The first layer uses 8 heads of width 8 (outputs concatenated to 64
    features); the second layer uses a single head producing the class scores.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Number of output classes.
    """

    def __init__(self, in_channels, out_channels):
        super(GATNet, self).__init__()
        self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=True,
                             dropout=0.6)

    def forward(self, x, data_flow):
        """Return per-node log-probabilities for the flow's target nodes.

        Args:
            x: Full node-feature matrix; the rows needed by the first block
                are gathered with ``data_flow[0].n_id``.
            data_flow: Sampled flow whose element ``i`` is the bipartite
                block consumed by layer ``i``.

        Returns:
            ``log_softmax`` over classes (dim 1) of the second layer output.
        """
        block = data_flow[0]
        x = x[block.n_id]
        # Each layer receives (source features, target features); the target
        # rows are selected from the sources via ``res_n_id``.
        x = F.elu(
            self.conv1((x, x[block.res_n_id]), block.edge_index,
                       size=block.size))
        # F.dropout defaults to training=True; follow the module mode instead.
        x = F.dropout(x, p=0.6, training=self.training)
        block = data_flow[1]
        x = self.conv2((x, x[block.res_n_id]), block.edge_index,
                       size=block.size)
        return F.log_softmax(x, dim=1)
# Use a CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Pick the architecture named by --model; both classes take
# (in_channels, out_channels).
if args.model == 'SAGE':
    Net = SAGENet
else:
    Net = GATNet
model = Net(dataset.num_features, dataset.num_classes)
model = model.to(device)  # Module.to moves parameters and returns the module

# Adam with L2 weight decay over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
    """Run one epoch of mini-batch training over the nodes in ``train_mask``.

    Returns:
        The training loss for the epoch, averaged over the number of nodes
        selected by ``data.train_mask``.
    """
    model.train()  # enable dropout layers keyed on ``self.training``

    total_loss = 0
    for data_flow in loader(data.train_mask):
        optimizer.zero_grad()  # gradients accumulate across steps otherwise
        out = model(data.x.to(device), data_flow.to(device))
        loss = F.nll_loss(out, data.y[data_flow.n_id].to(device))
        loss.backward()
        optimizer.step()
        # F.nll_loss uses mean reduction by default; multiply by the batch
        # size so the final division yields a per-node average.
        total_loss += loss.item() * data_flow.batch_size
    return total_loss / data.train_mask.sum().item()
def test(mask):
    """Return the model's classification accuracy on the nodes in *mask*.

    Args:
        mask: Node mask handed to ``loader`` to choose the evaluated nodes.

    Returns:
        Number of correctly predicted nodes divided by ``mask.sum()``.
    """
    model.eval()  # disable dropout layers keyed on ``self.training``

    correct = 0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time without changing the outputs.
    with torch.no_grad():
        for data_flow in loader(mask):
            out = model(data.x.to(device), data_flow.to(device))
            pred = out.max(1)[1]  # index of the highest log-probability
            target = data.y[data_flow.n_id].to(device)
            correct += pred.eq(target).sum().item()
    return correct / mask.sum().item()
# Train for a fixed number of epochs, printing the epoch's training loss and
# the test accuracy after each one.
NUM_EPOCHS = 30
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    test_acc = test(data.test_mask)
    report = 'Epoch: {:02d}, Loss: {:.4f}, Test: {:.4f}'.format(
        epoch, loss, test_acc)
    print(report)
You can’t perform that action at this time.