In [5]:
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader, NeighborSampler
from torch_geometric.nn import SAGEConv

In [6]:
class SAGE(torch.nn.Module):
    """Two-layer GraphSAGE network: SAGEConv -> ReLU -> SAGEConv.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the intermediate representation.
        out_channels: Number of output values (logits) produced per node.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Attribute names stay conv1/conv2 so state_dict keys are unchanged.
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch_size):
        """Return one row of logits for every node in ``x``.

        Both convolutions run over the same ``edge_index``.

        NOTE(review): ``batch_size`` is accepted but not used here; the output
        covers all nodes, so callers must slice out the rows they train on
        (e.g. the seed nodes of a sampled mini-batch).
        """
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

In [7]:
def run(rank, world_size, dataset):
    """Train GraphSAGE on one GPU as one worker of a DistributedDataParallel job.

    Meant to be launched once per GPU, e.g.
    ``mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)``.
    NOTE(review): ``mp.spawn`` uses the 'spawn' start method, which cannot pickle
    functions defined in a notebook's ``__main__``; move ``run``/``SAGE`` into a
    ``.py`` module (e.g. via ``%%writefile``) and import them before spawning.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes (one per GPU).
        dataset: PyG dataset whose first graph has ``x``, ``y``, ``edge_index``
            and ``train_mask`` (as the Reddit graph does).
    """
    # The default env:// rendezvous of init_process_group needs these; setdefault
    # keeps any values an external launcher already exported.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    try:
        data = dataset[0]

        # Give each rank an equally sized, disjoint slice of the training nodes.
        # Equal slices give every rank the same number of batches, which DDP needs
        # to avoid deadlocking in the gradient all-reduce.
        train_idx = data.train_mask.nonzero().view(-1)
        per_rank = train_idx.numel() // world_size
        train_idx = train_idx[rank * per_rank:(rank + 1) * per_rank]

        # NeighborLoader yields Data mini-batches whose first `batch_size` nodes are
        # the seed nodes. One fan-out entry per SAGEConv layer (the model has two).
        train_loader = NeighborLoader(
            data,
            input_nodes=train_idx,
            num_neighbors=[15, 10],
            batch_size=64,
            shuffle=True,
            num_workers=4,
        )

        # Data objects carry no `num_classes`; that lives on the dataset.
        model = SAGE(data.num_features, 256, dataset.num_classes).to(rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        criterion = torch.nn.CrossEntropyLoss()

        model.train()
        for _ in range(10):
            for batch in train_loader:
                batch = batch.to(rank)
                optimizer.zero_grad()
                out = model(batch.x, batch.edge_index, batch.batch_size)
                # Only seed nodes contribute to the loss; sampled neighbours are
                # context only (and may include validation/test nodes).
                seeds = batch.batch_size
                loss = criterion(out[:seeds], batch.y[:seeds])
                loss.backward()
                optimizer.step()
    finally:
        dist.destroy_process_group()

In [8]:
# Count visible GPUs (one DDP worker per device) and load the Reddit dataset,
# which is downloaded and processed into ./data/Reddit on first use.
world_size = torch.cuda.device_count()
dataset = Reddit(root='./data/Reddit')

Downloading https://data.dgl.ai/dataset/reddit.zip
Extracting data/Reddit/raw/reddit.zip
Processing...
Done!


In [17]:
dataset[0].y.unique().numel()  # number of distinct node labels

41

In [18]:
# Display the first graph's summary: tensor shapes of x, edge_index, y and the split masks.
dataset[0]

Data(x=[232965, 602], edge_index=[2, 114615892], y=[232965], train_mask=[232965], val_mask=[232965], test_mask=[232965])