<h2>Load Data</h2>

In [None]:
import sys
# Make the local pytorch_geometric checkout and the GaaS python package
# importable. Only entries that are not already present are appended, so
# re-running this cell does not keep growing sys.path with duplicates.
sys.path += [
    extra_path
    for extra_path in ['/work/pytorch_geometric', '/work/gaas/python']
    if extra_path not in sys.path
]

In [None]:
import rmm

# rmm documents the pool sizes as integer byte counts, so pass ints rather than
# the floats 5e+9 / 20e+9. NOTE(review): stricter integer conversion in newer
# rmm/Cython builds may reject floats — confirm for the pinned rmm version.
RMM_INITIAL_POOL_BYTES = 5 * 10**9   # 5 GB reserved when the pool is created
RMM_MAXIMUM_POOL_BYTES = 20 * 10**9  # upper bound the pool may grow to

rmm.reinitialize(
    pool_allocator=True,
    initial_pool_size=RMM_INITIAL_POOL_BYTES,
    maximum_pool_size=RMM_MAXIMUM_POOL_BYTES,
)

In [None]:
import cudf
import cugraph
from cugraph.experimental import PropertyGraph
from cugraph.gnn.gaas_extensions import load_reddit
import os

# Location of the Reddit dataset. Set the REDDIT_DATA_DIR environment variable
# to point elsewhere instead of editing this cell; the default is unchanged.
REDDIT_DATA_DIR = os.environ.get('REDDIT_DATA_DIR', '/work/data/reddit')

# NOTE(review): the first argument is passed as None, as before; its meaning is
# defined by load_reddit — confirm.
G = load_reddit(None, REDDIT_DATA_DIR)
G

In [None]:
from cugraph.gnn.pyg_extensions import CuGraphData
# Wrap the loaded graph in CuGraphData so PyG loaders can consume it.
# 'id' and 'type' are passed as reserved_keys (NOTE(review): presumably so they
# are not exposed as ordinary feature attributes — confirm in CuGraphData).
cd = CuGraphData(
    G,
    reserved_keys=['id', 'type'],
)

<h2>Training</h2>

In [None]:
import torch
from torch_geometric.data import Data

TRAINING_ARGS = {
    'batch_size': 1000,  # mini-batch size passed to LinkNeighborLoader
    'fanout': [10, 25],  # passed as num_neighbors: neighbours sampled per hop
    'num_epochs': 1,
    'seed': 42,          # RNG seed so Restart & Run All is reproducible
}

# Seed PyTorch's generators on all devices (model init, dropout, loader shuffling).
# NOTE(review): cuGraph-side neighbour sampling may use its own RNG — confirm.
torch.manual_seed(TRAINING_ARGS['seed'])

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data = cd.to(device)
#data = Data(x=cd.x, edge_index = cd.edge_index, y = cd.y) #uncomment to run w/o cugraph

<h3>Create the Data Loader</h3>

In [None]:
from torch_geometric.loader import LinkNeighborLoader
import numpy as np

# Training mini-batches: LinkNeighborLoader samples subgraphs around seed links
# of `data` (per PyG docs), expanding each hop with the configured fanout.
# The batch order is reshuffled every epoch.
train_loader = LinkNeighborLoader(
    data,
    batch_size=TRAINING_ARGS['batch_size'],
    num_neighbors=TRAINING_ARGS['fanout'],
    shuffle=True,
)

<h3>Define the GraphSAGE Model</h3>

In [None]:
import torch
from torch.nn import LSTM
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import SAGEConv

class SAGE(nn.Module):
    """Stack of ``SAGEConv`` layers producing per-node class log-probabilities.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the intermediate layers (unused when
            ``num_layers == 1``).
        num_classes: Width of the final layer (number of output classes).
        num_layers: Number of ``SAGEConv`` layers; must be >= 1.
        dropout: Dropout probability applied after every non-final layer.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, in_channels, hidden_channels, num_classes, num_layers, dropout=0.3):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        self.num_layers = num_layers
        self.dropout = dropout
        # Layer widths: in -> hidden -> ... -> hidden -> num_classes.
        widths = [in_channels] + [hidden_channels] * (num_layers - 1) + [num_classes]
        self.convs = nn.ModuleList(
            SAGEConv(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x, edge_index):
        """Apply every layer and return ``log_softmax`` over dimension 1.

        Args:
            x: Node feature matrix (assumed ``[num_nodes, in_channels]``).
            edge_index: Graph connectivity in the form ``SAGEConv`` expects.

        Returns:
            Log-probabilities (assumed ``[num_nodes, num_classes]``).
        """
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != last:
                # Non-linearity and dropout only between layers, never on the output.
                x = x.relu()
                x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(x, dim=1)



In [None]:
import pandas as pd
# nll_loss requires every target id to lie in [0, num_classes), so size the
# output layer by the largest label id rather than by the number of distinct
# labels (the two differ whenever some class id is absent from the data).
node_labels = np.array(cd.y, dtype=int).reshape(-1)
assert node_labels.size == 0 or node_labels.min() >= 0, 'class ids must be non-negative for nll_loss'
num_classes = int(node_labels.max()) + 1 if node_labels.size else 0
num_classes

In [None]:
import torch.nn.functional as F
import tqdm
from math import ceil
from torch_geometric.loader import DataLoader
from torch.nn.functional import nll_loss

from datetime import datetime

import cupy
from cuml.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score

#model = SAGE(data.num_node_features, hidden_channels=256, num_classes=num_classes, num_layers=3)
# Single-layer GraphSAGE: with num_layers=1 the only layer maps features
# straight to classes, so hidden_channels is not used.
model = SAGE(
    data.num_node_features,
    hidden_channels=64,
    num_classes=num_classes,
    num_layers=1,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Mini-batches consumed per epoch: one per `batch_size` nodes (math.ceil already yields an int).
num_batches = ceil(data.num_nodes / TRAINING_ARGS['batch_size'])
#num_batches = 5

def train():
    """Run one training epoch over at most ``num_batches`` mini-batches.

    Relies on the module-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``num_batches``.

    Returns:
        float: Mean loss per scored node. Each batch's mean loss is weighted
        by the number of nodes in its sampled subgraph. Returns 0.0 if the
        loader yields no batches.
    """
    model.train()

    total_loss = 0.0
    total_nodes = 0
    for i, sampled_data in enumerate(tqdm.tqdm(train_loader)):
        if i == num_batches:
            break
        sampled_data = sampled_data.to(device)
        print(f'iter: {i}')
        print('# nodes: ', sampled_data.num_nodes)
        print('# edges: ', sampled_data.num_edges)
        out = model(sampled_data.x, sampled_data.edge_index)

        # nll_loss expects class-index targets of shape [N]. reshape(-1)
        # accepts 1-D labels as well as [N, 1] / [1, N] label columns.
        # Tensor.T on a 1-D tensor is a deprecated no-op.
        target = sampled_data.y.reshape(-1).to(torch.long)
        loss = F.nll_loss(out, target)
        print(f'loss: {loss}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * out.size(0)
        total_nodes += out.size(0)

    # Average over the nodes actually scored. Always return a float so the
    # caller's `{loss:.4f}` formatting works even when the loader yields
    # fewer than `num_batches` batches.
    return total_loss / total_nodes if total_nodes else 0.0

@torch.no_grad()
def encode(loader):
    """Predict class ids for up to ``num_batches`` mini-batches from ``loader``.

    Relies on the module-level ``model``, ``device`` and ``num_batches``.

    Args:
        loader: Iterable of sampled subgraphs exposing ``x``, ``edge_index``,
            ``y`` and ``to(device)``.

    Returns:
        tuple[Tensor, Tensor]: ``(predictions, labels)`` concatenated over all
        visited batches. Predictions are argmax class ids; labels are flattened
        to 1-D so they line up with the predictions.
    """
    model.eval()

    preds, labels = [], []
    for i, batch in enumerate(loader):
        if i == num_batches:
            break
        print(f'encode {i}')

        # Same as train(): run each sampled batch on the model's device.
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)
        preds.append(torch.argmax(out, dim=1))
        labels.append(batch.y.reshape(-1))

    return torch.cat(preds, dim=0), torch.cat(labels, dim=0)

@torch.no_grad()
def test():
    model.eval()
    eval_loader = LinkNeighborLoader(data, num_neighbors=TRAINING_ARGS['fanout'], batch_size=TRAINING_ARGS['batch_size'])
    x_out, y_out = encode(eval_loader)

    val_acc = 0
    test_acc = balanced_accuracy_score(cupy.from_dlpack(y_out.__dlpack__()).get(), cupy.from_dlpack(x_out.__dlpack__()).get())

    return val_acc, test_acc, x_out, y_out


# Alternate one training pass and one evaluation pass per epoch, timing each.
# The names below stay module-level so later cells can inspect the last epoch.
for epoch in range(1, TRAINING_ARGS['num_epochs'] + 1):
    # Training pass.
    train_start = datetime.now()
    loss = train()
    train_end = datetime.now()
    print(f'train time: {(train_end - train_start).total_seconds()}')

    # Evaluation pass.
    test_start = datetime.now()
    val_acc, test_acc, x_out, y_out = test()
    test_end = datetime.now()
    print(f'test time: {(test_end - test_start).total_seconds()}')

    print('Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'.format(
        epoch, loss, val_acc, test_acc))