# PyG+cuGraph Heterogeneous MAG Example with cuGraph-Service
# Skip notebook test

### Requires installation of PyG & cuGraph-Service
#### A cuGraph-Service Server must be running

## Setup

In [None]:
import pathlib
import os
from cugraph_service_client.client import CugraphServiceClient
# Create a client using the default constructor arguments.
client = CugraphServiceClient()

# Locate the graph-creation extensions directory. Note that
# Path('__file__') wraps the literal string '__file__' (not the __file__
# variable), so its parent is '.' and resolve() yields the absolute
# current working directory.
extensions_root = pathlib.Path('__file__').parent.resolve()
ext_path = str(extensions_root / 'cgs_creation_extensions')
print(f'loading extensions from {ext_path}')
client.load_graph_creation_extensions(ext_path)

In [None]:
from cugraph_service_client.client import RemoteGraph

# Invoke the 'create_mag' graph-creation extension through the client and
# keep the returned graph id.
# This line may take a while if the data has not yet been downloaded.
graph_id = client.call_graph_creation_extension('create_mag')

# Wrap the created graph in a RemoteGraph handle bound to this client;
# later cells use pG to build the PyG stores and to read the 'y' column.
pG = RemoteGraph(client, graph_id)

### Construct a Graph Store, Feature Store, and Loaders

In [None]:
from cugraph.experimental.pyg_extensions import to_pyg

# to_pyg returns a (feature store, graph store) pair built from the remote
# graph; both are passed together as `data` to the sampler and loaders below.
feature_store, graph_store = to_pyg(pG)

In [None]:
from cugraph.experimental.pyg_extensions import CuGraphSampler
# Sampler shared by both loaders below.
# NOTE(review): num_neighbors presumably gives the fan-out per hop
# (10, then 25) -- confirm against the CuGraphSampler documentation.
sampler = CuGraphSampler(
    data=(feature_store, graph_store),
    num_neighbors=[10, 25],
    batch_size=50,
    shuffle=True,
)

In [None]:
from torch_geometric.loader import NodeLoader
# Training loader: seeds batches from the 'author' vertices and delegates
# neighborhood sampling to the shared cuGraph sampler.
loader = NodeLoader(
    data=(feature_store, graph_store),
    node_sampler=sampler,
    input_nodes=('author', graph_store.get_vertex_index('author')),
    batch_size=50,
    shuffle=True,
)

# Evaluation loader: configured identically to the training loader
# (same seed vertices, same sampler, shuffling enabled).
test_loader = NodeLoader(
    data=(feature_store, graph_store),
    node_sampler=sampler,
    input_nodes=('author', graph_store.get_vertex_index('author')),
    batch_size=50,
    shuffle=True,
)


### Create the Network

In [None]:
# Collect the edge type of every edge attribute the graph store reports;
# the model builds one convolution per entry.
edge_types = []
for edge_attr in graph_store.get_all_edge_attrs():
    edge_types.append(edge_attr.edge_type)
edge_types

In [None]:
# Labels are stored in the 'y' vertex column; the class count is taken as
# the largest label value plus one.
vertex_labels = pG.get_vertex_data(columns=['y'])['y']
num_classes = vertex_labels.max() + 1
num_classes

In [None]:
import torch
import torch.nn.functional as F

from torch_geometric.nn import HeteroConv, Linear, SAGEConv

class HeteroGNN(torch.nn.Module):
    """Heterogeneous GraphSAGE network producing outputs for 'paper' nodes.

    Every layer is a ``HeteroConv`` holding one ``SAGEConv`` per edge type.
    A final linear layer maps the hidden 'paper' features to
    ``out_channels`` values per node.
    """

    def __init__(self, edge_types, hidden_channels, out_channels, num_layers):
        """Build the convolution stack and the output layer.

        Args:
            edge_types: Iterable of edge types; each gets its own SAGEConv
                in every layer.
            hidden_channels: Output width of every SAGEConv.
            out_channels: Output width of the final linear layer.
            num_layers: Number of HeteroConv layers to stack.
        """
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # (-1, -1) leaves the input sizes unspecified; they are filled
            # in lazily on the first forward pass.
            conv = HeteroConv({
                edge_type: SAGEConv((-1, -1), hidden_channels)
                for edge_type in edge_types
            })
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Apply every conv layer with leaky-ReLU, then project 'paper' nodes.

        Args:
            x_dict: Mapping from node type to its feature tensor.
            edge_index_dict: Mapping from edge type to its edge index.

        Returns:
            The linear layer's output for the 'paper' node features.
        """
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
        return self.lin(x_dict['paper'])


# Instantiate a two-layer network and move it to the GPU.
model = HeteroGNN(
    edge_types,
    hidden_channels=64,
    out_channels=num_classes,
    num_layers=2,
)
model = model.cuda()

# Initialize lazy modules: one gradient-free forward pass on the first
# batch, done before the optimizer collects model.parameters().
with torch.no_grad():
    data = next(iter(loader))
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=0.001,
)

# Mini-batches consumed per call to train(); test() uses twice as many.
num_batches = 5
def train():
    """Run one training pass over at most ``num_batches`` mini-batches.

    Uses the module-level ``model``, ``optimizer``, ``loader`` and
    ``num_batches``.

    Returns:
        float: Mean cross-entropy loss over the batches actually processed,
        or 0.0 if the loader yielded no batches.
    """
    model.train()

    total_loss = 0.0
    batches_seen = 0
    for b_i, data in enumerate(loader):
        if b_i == num_batches:
            break

        # Reset gradients for every batch; otherwise each optimizer step
        # would also include the gradients of all earlier batches.
        optimizer.zero_grad()
        out = model(data.x_dict, data.edge_index_dict)
        loss = F.cross_entropy(out, data.y_dict['paper'])
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        batches_seen += 1

    # Average over the batches that ran, not just the last one.
    return total_loss / batches_seen if batches_seen else 0.0


@torch.no_grad()
def test():
    """Evaluate accuracy on at most ``2 * num_batches`` batches of test_loader.

    Uses the module-level ``model``, ``test_loader`` and ``num_batches``.

    Returns:
        float: Mean per-batch accuracy on 'paper' nodes over the batches
        actually evaluated, or 0.0 if none could be evaluated.
    """
    model.eval()

    max_batches = 2 * num_batches
    total_acc = 0.0
    batches_seen = 0
    # Iterating with enumerate (instead of bare next()) ends cleanly when
    # the loader has fewer than max_batches batches.
    for b_i, data in enumerate(test_loader):
        if b_i == max_batches:
            break

        pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)
        if pred.numel() == 0:
            # No 'paper' predictions in this batch; nothing to score.
            continue

        # Normalize by the number of predictions. len(data['paper']) is the
        # length of the node storage, not the node count.
        correct = (pred == data.y_dict['paper']).sum()
        total_acc += float(correct) / pred.numel()
        batches_seen += 1

    return total_acc / batches_seen if batches_seen else 0.0


# 100 epochs: one train() pass followed by one test() pass, then a report.
# NOTE(review): the same loop appears again in the "Train the Network"
# cell, so running both cells trains for 200 epochs in total.
for epoch in range(1, 101):
    loss, train_acc = train(), test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}')


### Train the Network

In [None]:
# 100 epochs: one train() pass followed by one test() pass, then a report.
# The value printed as "Train" is whatever test() returns.
for epoch in range(1, 101):
    loss, train_acc = train(), test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}')