In [1]:
import dgl
import dgl.graphbolt as gb 
import torch

In [2]:
dataset = gb.BuiltinDataset("cora").load()

The dataset is already preprocessed.


In [11]:
device="cpu"
graph = dataset.graph.to(device)
feature = dataset.feature.to(device).pin_memory_()
assert feature.is_pinned()
train_set = dataset.tasks[1].train_set
test_set = dataset.tasks[1].test_set
task_name = dataset.tasks[1].metadata["name"]
print(f"Task: {task_name}.")
print(f"Train_set: {train_set}")
print(f"Graph: {graph}")

Task: link_prediction.
Train_set: ItemSet(
    items=(tensor([[1408,  370],
        [1216, 2446],
        [ 887, 1623],
        ...,
        [ 462, 1048],
        [2602, 2603],
        [ 805,  963]], dtype=torch.int32),),
    names=('seeds',),
)
Graph: FusedCSCSamplingGraph(csc_indptr=tensor([    0,     3,     6,  ..., 10548, 10552, 10556], dtype=torch.int32),
                      indices=tensor([ 633, 1862, 2582,  ...,  598, 1473, 2706], dtype=torch.int32),
                      total_num_nodes=2708, num_edges=10556,)


In [5]:
from functools import partial
def create_train_dataloader():
    datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [5, 5])
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
    datapipe = datapipe.copy_to("cuda:0")
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)

In [6]:
feature.keys()

[(<OnDiskFeatureDataDomain.NODE: 'node'>, None, 'feat')]

In [7]:
feature._features[('node',None,'feat')] = gb.GPUCachedFeature(feature._features[('node',None,'feat')], 512)

In [8]:
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F


class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x

In [9]:
in_size = feature.size("node", None, "feat")[0]
model = SAGE(in_size, 128).to("cuda:0")
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [10]:
from tqdm.auto import tqdm
import time
t_start = time.time()
for epoch in range(3):
    model.train()
    total_loss = 0
    for step, data in tqdm(enumerate(create_train_dataloader())):
        # Get node pairs with labels for loss calculation.
        compacted_pairs, labels = data.node_pairs_with_labels
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks

        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[compacted_pairs[0]] * y[compacted_pairs[1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}")
print(f"elapsed time {time.time() - t_start}")

0it [00:00, ?it/s]

RuntimeError: Keys should be on a CUDA device.
This exception is thrown by __iter__ of FeatureFetcher(datapipe=MultiprocessingWrapper, edge_feature_keys=None, feature_store=TorchBasedFeatureStore(
    {(<OnDiskFeatureDataDomain.NODE: 'node'>, None, 'feat'): <dgl.graphbolt.impl.gpu_cached_feature.GPUCachedFeature object at 0x7f313032d780>}
), node_feature_keys=['feat'])