# Using GraphSAGE to Predict Product Classifications
## TigerGraph ML Team

In this notebook, we use the `ogbn-products` dataset from the Open Graph Benchmark (OGB) to train a GraphSAGE model that predicts product categories.
The data is stored in a TigerGraph instance, and we use the `tgml` library to load it into Python.

## Setup Connection
In this section, we set up the connection to the TigerGraph instance and verify that the data is loaded.

In [None]:
from tgml.data import TigerGraph

tgraph = TigerGraph(
    host="http://18.222.126.26",  # Replace with your instance's IP address
    graph="OGBNProducts",
    username="tigergraph",
    password="tigergraph",
    token_auth=False
)

In [None]:
tgraph.info()

In [None]:
tgraph.number_of_vertices()

In [None]:
tgraph.number_of_edges()

In [None]:
print(
    "Number of vertices in training set:",
    tgraph.number_of_vertices(filter_by="train_mask"),
)
print(
    "Number of vertices in validation set:",
    tgraph.number_of_vertices(filter_by="val_mask"),
)
print(
    "Number of vertices in test set:", tgraph.number_of_vertices(filter_by="test_mask")
)

## Define Hyperparameters
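Here, we define the hyperparameters for neighbor sampling (`batch_size`, `num_neighbors`, `num_hops`), the model (`hidden_dim`, `num_layers`, `dropout`), and the optimizer (`lr`, `l2_penalty`).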

In [None]:
hp = {
    "batch_size": 1024,
    "num_neighbors": 20,
    "num_hops": 2,
    "hidden_dim": 128,
    "num_layers": 2,
    "dropout": 0.1,
    "lr": 0.01,
    "l2_penalty": 0,
}
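
To get a feel for how `batch_size`, `num_neighbors`, and `num_hops` interact, the sketch below computes a rough upper bound on the number of vertices in one sampled batch. It assumes that every vertex fans out to at most `num_neighbors` new vertices per hop, which is an assumption about the sampler rather than documented `tgml` behavior. The helper name `max_sampled_vertices` is hypothetical. Real batches are usually smaller, because neighborhoods overlap and many vertices have fewer neighbors than the cap.

```python
def max_sampled_vertices(batch_size: int, num_neighbors: int, num_hops: int) -> int:
    """Upper bound on vertices in one batch: seeds plus a full fan-out at every hop."""
    total = 0
    frontier = batch_size  # hop 0: the seed vertices themselves
    for _ in range(num_hops + 1):
        total += frontier
        # Each frontier vertex contributes at most num_neighbors vertices to the next hop.
        frontier *= num_neighbors
    return total


# Same sampling values as the hyperparameters defined above.
bound = max_sampled_vertices(batch_size=1024, num_neighbors=20, num_hops=2)
print(bound)
```

With these values, the bound is 1024 × (1 + 20 + 400) vertices, so lowering `num_neighbors` or `num_hops` is the most direct way to shrink a batch.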

## Setup Dataloaders
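We use `NeighborLoader` from the `tgml` library to load the data into Python.
`NeighborLoader` creates batches of seed vertices together with their sampled neighbors, following the neighbor sampling scheme described in the original GraphSAGE paper. We create one loader per split (training, validation, and test), and only the training loader shuffles its vertices.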
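
As an illustration of the sampling idea only, the following sketch samples a multi-hop neighborhood from a toy adjacency list in plain Python. It is not how `NeighborLoader` is implemented. The function `sample_neighborhood` and the toy graph are hypothetical, and expanding only newly discovered vertices at each hop is a simplifying assumption.

```python
import random


def sample_neighborhood(adj, seeds, num_neighbors, num_hops, rng):
    """Return the visited vertices and the sampled (neighbor, vertex) edges."""
    visited = set(seeds)
    edges = []
    frontier = list(seeds)
    for _ in range(num_hops):
        next_frontier = []
        for v in frontier:
            nbrs = adj.get(v, [])
            # Sample without replacement, capped at the neighbors actually available.
            picked = rng.sample(nbrs, min(num_neighbors, len(nbrs)))
            for u in picked:
                edges.append((u, v))  # messages flow from neighbor u to vertex v
                if u not in visited:
                    visited.add(u)
                    next_frontier.append(u)
        frontier = next_frontier
    return visited, edges


toy_adj = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 5], 4: [1], 5: [3]}
vertices, edges = sample_neighborhood(
    toy_adj, seeds=[0], num_neighbors=2, num_hops=2, rng=random.Random(0)
)
print(sorted(vertices), edges)
```

The model then only sees the sampled vertices and edges, which keeps each training step bounded in size regardless of how large the full graph is.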

In [None]:
from tgml.dataloaders import NeighborLoader

In [None]:
train_loader = NeighborLoader(
    graph=tgraph,
    tmp_id="tmp_id",
    v_in_feats="x",
    v_out_labels="y:int",
    v_extra_feats="train_mask:bool,val_mask:bool,test_mask:bool",
    output_format="PyG",
    batch_size=hp["batch_size"],
    num_neighbors=hp["num_neighbors"],
    num_hops=hp["num_hops"],
    shuffle=True,
    filter_by="train_mask",
)

In [None]:
valid_loader = NeighborLoader(
    graph=tgraph,
    tmp_id="tmp_id2",
    v_in_feats="x",
    v_out_labels="y:int",
    v_extra_feats="train_mask:bool,val_mask:bool,test_mask:bool",
    output_format="PyG",
    batch_size=hp["batch_size"],
    num_neighbors=hp["num_neighbors"],
    num_hops=hp["num_hops"],
    shuffle=False,
    filter_by="val_mask",
)

In [None]:
test_loader = NeighborLoader(
    graph=tgraph,
    tmp_id="tmp_id3",
    v_in_feats="x",
    v_out_labels="y:int",
    v_extra_feats="train_mask:bool,val_mask:bool,test_mask:bool",
    output_format="PyG",
    batch_size=hp["batch_size"],
    num_neighbors=hp["num_neighbors"],
    num_hops=hp["num_hops"],
    shuffle=False,
    filter_by="test_mask",
)

## Define Model
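We use the `GraphSAGE` model from PyTorch Geometric, configured with the hyperparameters defined above. The input dimension (100) is the length of the `x` feature vectors, and the output dimension (47) is the number of product classes.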

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GraphSAGE

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = GraphSAGE(
    in_channels=100, # dimension of x feature vectors
    hidden_channels=hp["hidden_dim"],
    num_layers=hp["num_layers"],
    out_channels=47, # number of product classes
    dropout=hp["dropout"],
).to(device)

optimizer = torch.optim.Adam(
    model.parameters(), lr=hp["lr"], weight_decay=hp["l2_penalty"]
)
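
To make the model less of a black box, here is a minimal sketch of what one GraphSAGE layer with mean aggregation computes: each vertex combines a linear transform of its own features with a linear transform of the mean of its neighbors' features. This is plain Python on lists for illustration, with the bias and nonlinearity left out; the function `sage_mean_layer` and the toy data are hypothetical and are not the PyTorch Geometric implementation.

```python
def sage_mean_layer(x, edges, w_self, w_neigh):
    """One mean-aggregation layer. x: feature vector per vertex; edges: (source, target)."""

    def matvec(w, v):
        return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

    dim = len(x[0])
    sums = [[0.0] * dim for _ in x]
    counts = [0] * len(x)
    # Accumulate incoming neighbor features for every target vertex.
    for src, dst in edges:
        counts[dst] += 1
        for k in range(dim):
            sums[dst][k] += x[src][k]

    out = []
    for v, feats in enumerate(x):
        # Vertices without incoming edges get a zero neighbor vector.
        mean = [s / counts[v] if counts[v] else 0.0 for s in sums[v]]
        self_part = matvec(w_self, feats)
        neigh_part = matvec(w_neigh, mean)
        out.append([a + b for a, b in zip(self_part, neigh_part)])
    return out


toy_x = [[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]]
toy_edges = [(1, 0), (2, 0)]  # vertices 1 and 2 send messages to vertex 0
identity = [[1.0, 0.0], [0.0, 1.0]]
h = sage_mean_layer(toy_x, toy_edges, identity, identity)
print(h)
```

Stacking `num_layers` such layers lets information travel that many hops, which is why `num_layers` and `num_hops` are both set to 2 in this notebook.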

## Train the Model
We train the model for 10 epochs, logging loss and accuracy to TensorBoard after every training batch and after each validation pass.

In [None]:
from datetime import datetime

from tgml.metrics import Accumulator, Accuracy
from torch.utils.tensorboard import SummaryWriter

In [None]:
log_dir = "logs/products/graphsage/subgraph/" + datetime.now().strftime("%Y%m%d-%H%M%S")
train_log = SummaryWriter(log_dir+"/train")
valid_log = SummaryWriter(log_dir+"/valid")
global_steps = 0
logs = {}
for epoch in range(10):
    # Train
    model.train()
    epoch_train_loss = Accumulator()
    epoch_train_acc = Accuracy()
    for bid, batch in enumerate(train_loader):
        # Weight the batch loss by the number of training vertices it was averaged over
        batchsize = int(batch.train_mask.sum())
        batch.to(device)
        # Forward pass
        out = model(batch.x, batch.edge_index)
        # Calculate loss
        loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_train_loss.update(loss.item() * batchsize, batchsize)
        # Predict on training data
        with torch.no_grad():
            pred = out.argmax(dim=1)
            epoch_train_acc.update(pred[batch.train_mask], batch.y[batch.train_mask])
        # Log training status after each batch
        logs["loss"] = epoch_train_loss.mean
        logs["acc"] = epoch_train_acc.value
        print(
            "Epoch {}, Train Batch {}, Loss {:.4f}, Accuracy {:.4f}".format(
                epoch, bid, logs["loss"], logs["acc"]
            )
        )
        train_log.add_scalar("Loss", logs["loss"], global_steps)
        train_log.add_scalar("Accuracy", logs["acc"], global_steps)
        train_log.flush()
        global_steps += 1
    # Evaluate
    model.eval()
    epoch_val_loss = Accumulator()
    epoch_val_acc = Accuracy()
    for batch in valid_loader:
        # Weight the batch loss by the number of validation vertices it was averaged over
        batchsize = int(batch.val_mask.sum())
        batch.to(device)
        with torch.no_grad():
            # Forward pass
            out = model(batch.x, batch.edge_index)
            # Calculate loss
            valid_loss = F.cross_entropy(out[batch.val_mask], batch.y[batch.val_mask])
            epoch_val_loss.update(valid_loss.item() * batchsize, batchsize)
            # Prediction
            pred = out.argmax(dim=1)
            epoch_val_acc.update(pred[batch.val_mask], batch.y[batch.val_mask])
    # Log validation result after each epoch
    logs["val_loss"] = epoch_val_loss.mean
    logs["val_acc"] = epoch_val_acc.value
    print(
        "Epoch {}, Valid Loss {:.4f}, Valid Accuracy {:.4f}".format(
            epoch, logs["val_loss"], logs["val_acc"]
        )
    )
    valid_log.add_scalar("Loss", logs["val_loss"], global_steps)
    valid_log.add_scalar("Accuracy", logs["val_acc"], global_steps)
    valid_log.flush()
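
The loop passes `loss.item() * batchsize` and `batchsize` to the accumulator, which suggests a weighted running mean: per-batch mean losses are scaled back up to sums, and the epoch loss is the total sum divided by the total weight. The class below is a hypothetical stand-in written to show that arithmetic; the actual `tgml.metrics.Accumulator` may be implemented differently.

```python
class RunningMean:
    """Weighted running mean: accumulates sums and counts across batches."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, weighted_sum, count):
        # weighted_sum is expected to be (batch mean) * (batch weight).
        self.total += weighted_sum
        self.count += count

    @property
    def mean(self):
        return self.total / self.count if self.count else 0.0


epoch_loss = RunningMean()
epoch_loss.update(0.5 * 4, 4)    # batch of weight 4 with mean loss 0.5
epoch_loss.update(1.0 * 12, 12)  # batch of weight 12 with mean loss 1.0
print(epoch_loss.mean)
```

Weighting this way keeps a small final batch from counting as much as a full one when the per-batch means are combined.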



## Evaluate Model
Finally, we evaluate the trained model on the test set and print its accuracy.

In [None]:
model.eval()
acc = Accuracy()
for batch in test_loader:
    batch.to(device)
    with torch.no_grad():
        pred = model(batch.x, batch.edge_index).argmax(dim=1)
        acc.update(pred[batch.test_mask], batch.y[batch.test_mask])
print("Accuracy: {:.4f}".format(acc.value))
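
Each batch also contains sampled neighbor vertices that may not belong to the test split, which is why predictions and labels are indexed with `batch.test_mask` before scoring. The sketch below shows the same argmax-then-mask logic in plain Python. The function `masked_accuracy` and the toy values are hypothetical and only illustrate the computation.

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked-in vertices whose argmax prediction equals the label."""
    correct = 0
    total = 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue  # skip vertices that were sampled only as neighbors
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(pred == label)
        total += 1
    return correct / total if total else 0.0


toy_logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
toy_labels = [1, 1, 1, 0]
toy_mask = [True, True, False, True]
print(masked_accuracy(toy_logits, toy_labels, toy_mask))
```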