In [1]:
import copy
import os.path as osp

import torch
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv

# download and loading the obg dataset
path = osp.join(osp.dirname(osp.realpath('./')), 'data', 'Reddit')
dataset = Reddit(path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device, 'x', 'y')

BS=2048
kwargs = {'batch_size': BS, 'num_workers': 6, 'persistent_workers': True}
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                              num_neighbors=[25, 10], shuffle=True, **kwargs)

subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,
                                 num_neighbors=[-1], shuffle=False, **kwargs)
# No need to maintain these features during evaluation:
del subgraph_loader.data.x, subgraph_loader.data.y
# Add global node index information.
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)
# Already send node features/labels to GPU for faster access during sampling:

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [3]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  549652 KB |  549652 KB |  549652 KB |       0 B  |
|       from large pool |  549652 KB |  549652 KB |  549652 KB |       0 B  |
|       from small pool |       0 KB |       0 KB |       0 KB |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |  549652 KB |  549652 KB |  549652 KB |       0 B  |
|       from large pool |  549652 KB |  549652 KB |  549652 KB |       0 B  |
|       from small pool |       0 KB |       0 KB |       0 KB |       0 B  |
|---------------------------------------------------------------

In [7]:
LR=0.001
HC=8
NC=3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device, 'x', 'y')
class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for i in range(NC-2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = x.relu_()
                x = F.dropout(x, p=0.5, training=self.training)
            if i < len(self.convs) - 1 and i>0:
                x = conv(x, edge_index)
                x = x.relu_()
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        print("testtest")
        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch:
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id.to(x_all.device)].to(device)
                x = conv(x, batch.edge_index.to(device))
                if i < len(self.convs) - 1:
                    x = x.relu_()
                xs.append(x[:batch.batch_size].cpu())
                pbar.update(batch.batch_size)
            x_all = torch.cat(xs, dim=0)
            if i < len(self.convs) - 1 and i>0:
                xs = []
                for batch in subgraph_loader:
                    x = x_all[batch.n_id.to(x_all.device)].to(device)
                    x = conv(x, batch.edge_index.to(device))
                    x = x.relu_()
                    xs.append(x[:batch.batch_size].cpu())
                x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all
    
model = SAGE(dataset.num_features, HC, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

def train(epoch):
    model.train()

    pbar = tqdm(total=int(len(train_loader.dataset)))
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = total_correct = total_examples = 0
    for batch in train_loader:
        optimizer.zero_grad()
        y = batch.y[:batch.batch_size]
        y_hat = model(batch.x, batch.edge_index.to(device))[:batch.batch_size]
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()

        total_loss += float(loss) * batch.batch_size
        total_correct += int((y_hat.argmax(dim=-1) == y).sum())
        total_examples += batch.batch_size
        pbar.update(batch.batch_size)
    
    pbar.close()

    return total_loss / total_examples, total_correct / total_examples
@torch.no_grad()
def test():
    model.eval()
    y_hat = model.inference(data.x, subgraph_loader).argmax(dim=-1)
    y = data.y.to(y_hat.device)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((y_hat[mask] == y[mask]).sum()) / int(mask.sum()))
    return accs

In [8]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 3         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    1073 MB |    4387 MB |    4456 GB |    4455 GB |
|       from large pool |    1073 MB |    4386 MB |    4446 GB |    4445 GB |
|       from small pool |       0 MB |       2 MB |       9 GB |       9 GB |
|---------------------------------------------------------------------------|
| Active memory         |    1073 MB |    4387 MB |    4456 GB |    4455 GB |
|       from large pool |    1073 MB |    4386 MB |    4446 GB |    4445 GB |
|       from small pool |       0 MB |       2 MB |       9 GB |       9 GB |
|---------------------------------------------------------------

In [9]:
for epoch in range(1, 50):
    loss, acc = train(epoch)
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
train_acc, val_acc, test_acc = test()
print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
      f'Test: {test_acc:.4f}')

Epoch 01: 100%|██████████| 153431/153431 [00:04<00:00, 36769.70it/s]


Epoch 01, Loss: 3.5768, Approx. Train: 0.0613


Epoch 02: 100%|██████████| 153431/153431 [00:04<00:00, 37602.46it/s]


Epoch 02, Loss: 2.9223, Approx. Train: 0.1979


Epoch 03: 100%|██████████| 153431/153431 [00:04<00:00, 37794.63it/s]


Epoch 03, Loss: 2.3836, Approx. Train: 0.3421


Epoch 04: 100%|██████████| 153431/153431 [00:03<00:00, 40039.42it/s]


Epoch 04, Loss: 2.0031, Approx. Train: 0.4481


Epoch 05: 100%|██████████| 153431/153431 [00:04<00:00, 36631.83it/s]


Epoch 05, Loss: 1.7077, Approx. Train: 0.5334


Epoch 06: 100%|██████████| 153431/153431 [00:04<00:00, 38254.35it/s]


Epoch 06, Loss: 1.4882, Approx. Train: 0.6014


Epoch 07: 100%|██████████| 153431/153431 [00:03<00:00, 38789.68it/s]


Epoch 07, Loss: 1.3454, Approx. Train: 0.6418


Epoch 08: 100%|██████████| 153431/153431 [00:03<00:00, 39435.37it/s]


Epoch 08, Loss: 1.2443, Approx. Train: 0.6744


Epoch 09: 100%|██████████| 153431/153431 [00:04<00:00, 37223.75it/s]


Epoch 09, Loss: 1.1734, Approx. Train: 0.6974


Epoch 10: 100%|██████████| 153431/153431 [00:04<00:00, 38184.57it/s]


Epoch 10, Loss: 1.1266, Approx. Train: 0.7159


Epoch 11: 100%|██████████| 153431/153431 [00:04<00:00, 34915.52it/s]


Epoch 11, Loss: 1.0827, Approx. Train: 0.7308


Epoch 12: 100%|██████████| 153431/153431 [00:03<00:00, 38640.27it/s]


Epoch 12, Loss: 1.0477, Approx. Train: 0.7422


Epoch 13: 100%|██████████| 153431/153431 [00:04<00:00, 37864.71it/s]


Epoch 13, Loss: 1.0185, Approx. Train: 0.7530


Epoch 14: 100%|██████████| 153431/153431 [00:03<00:00, 38730.63it/s]


Epoch 14, Loss: 0.9929, Approx. Train: 0.7611


Epoch 15: 100%|██████████| 153431/153431 [00:03<00:00, 39139.66it/s]


Epoch 15, Loss: 0.9696, Approx. Train: 0.7706


Epoch 16: 100%|██████████| 153431/153431 [00:04<00:00, 38351.91it/s]


Epoch 16, Loss: 0.9503, Approx. Train: 0.7769


Epoch 17: 100%|██████████| 153431/153431 [00:04<00:00, 37383.66it/s]


Epoch 17, Loss: 0.9364, Approx. Train: 0.7810


Epoch 18: 100%|██████████| 153431/153431 [00:03<00:00, 39491.70it/s]


Epoch 18, Loss: 0.9195, Approx. Train: 0.7865


Epoch 19: 100%|██████████| 153431/153431 [00:03<00:00, 38626.85it/s]


Epoch 19, Loss: 0.9068, Approx. Train: 0.7889


Epoch 20: 100%|██████████| 153431/153431 [00:04<00:00, 35455.23it/s]


Epoch 20, Loss: 0.9022, Approx. Train: 0.7908


Epoch 21: 100%|██████████| 153431/153431 [00:04<00:00, 38172.86it/s]


Epoch 21, Loss: 0.8903, Approx. Train: 0.7943


Epoch 22: 100%|██████████| 153431/153431 [00:04<00:00, 38163.21it/s]


Epoch 22, Loss: 0.8791, Approx. Train: 0.7969


Epoch 23: 100%|██████████| 153431/153431 [00:03<00:00, 38550.35it/s]


Epoch 23, Loss: 0.8734, Approx. Train: 0.7992


Epoch 24: 100%|██████████| 153431/153431 [00:04<00:00, 38282.59it/s]


Epoch 24, Loss: 0.8675, Approx. Train: 0.8007


Epoch 25: 100%|██████████| 153431/153431 [00:03<00:00, 40210.93it/s]


Epoch 25, Loss: 0.8569, Approx. Train: 0.8036


Epoch 26: 100%|██████████| 153431/153431 [00:04<00:00, 37539.36it/s]


Epoch 26, Loss: 0.8543, Approx. Train: 0.8042


Epoch 27: 100%|██████████| 153431/153431 [00:04<00:00, 37404.93it/s]


Epoch 27, Loss: 0.8483, Approx. Train: 0.8062


Epoch 28: 100%|██████████| 153431/153431 [00:03<00:00, 40461.25it/s]


Epoch 28, Loss: 0.8389, Approx. Train: 0.8088


Epoch 29: 100%|██████████| 153431/153431 [00:03<00:00, 38742.18it/s]


Epoch 29, Loss: 0.8382, Approx. Train: 0.8093


Epoch 30: 100%|██████████| 153431/153431 [00:04<00:00, 38010.27it/s]


Epoch 30, Loss: 0.8304, Approx. Train: 0.8115


Epoch 31: 100%|██████████| 153431/153431 [00:03<00:00, 39433.69it/s]


Epoch 31, Loss: 0.8296, Approx. Train: 0.8118


Epoch 32: 100%|██████████| 153431/153431 [00:03<00:00, 39500.74it/s]


Epoch 32, Loss: 0.8197, Approx. Train: 0.8135


Epoch 33: 100%|██████████| 153431/153431 [00:03<00:00, 38933.74it/s]


Epoch 33, Loss: 0.8190, Approx. Train: 0.8139


Epoch 34: 100%|██████████| 153431/153431 [00:04<00:00, 36153.84it/s]


Epoch 34, Loss: 0.8141, Approx. Train: 0.8154


Epoch 35: 100%|██████████| 153431/153431 [00:04<00:00, 37332.24it/s]


Epoch 35, Loss: 0.8120, Approx. Train: 0.8168


Epoch 36: 100%|██████████| 153431/153431 [00:03<00:00, 38392.76it/s]


Epoch 36, Loss: 0.8066, Approx. Train: 0.8170


Epoch 37: 100%|██████████| 153431/153431 [00:04<00:00, 38149.01it/s]


Epoch 37, Loss: 0.8066, Approx. Train: 0.8178


Epoch 38: 100%|██████████| 153431/153431 [00:04<00:00, 38320.70it/s]


Epoch 38, Loss: 0.8013, Approx. Train: 0.8190


Epoch 39: 100%|██████████| 153431/153431 [00:04<00:00, 37519.58it/s]


Epoch 39, Loss: 0.7986, Approx. Train: 0.8193


Epoch 40: 100%|██████████| 153431/153431 [00:03<00:00, 40047.67it/s]


Epoch 40, Loss: 0.7983, Approx. Train: 0.8204


Epoch 41: 100%|██████████| 153431/153431 [00:04<00:00, 37022.69it/s]


Epoch 41, Loss: 0.7912, Approx. Train: 0.8213


Epoch 42: 100%|██████████| 153431/153431 [00:03<00:00, 39020.38it/s]


Epoch 42, Loss: 0.7913, Approx. Train: 0.8213


Epoch 43: 100%|██████████| 153431/153431 [00:04<00:00, 38119.37it/s]


Epoch 43, Loss: 0.7875, Approx. Train: 0.8226


Epoch 44: 100%|██████████| 153431/153431 [00:04<00:00, 36382.42it/s]


Epoch 44, Loss: 0.7837, Approx. Train: 0.8224


Epoch 45: 100%|██████████| 153431/153431 [00:04<00:00, 38205.51it/s]


Epoch 45, Loss: 0.7814, Approx. Train: 0.8242


Epoch 46: 100%|██████████| 153431/153431 [00:03<00:00, 41105.57it/s]


Epoch 46, Loss: 0.7839, Approx. Train: 0.8238


Epoch 47: 100%|██████████| 153431/153431 [00:03<00:00, 39231.45it/s]


Epoch 47, Loss: 0.7807, Approx. Train: 0.8249


Epoch 48: 100%|██████████| 153431/153431 [00:04<00:00, 35896.82it/s]


Epoch 48, Loss: 0.7835, Approx. Train: 0.8258


Epoch 49: 100%|██████████| 153431/153431 [00:03<00:00, 38424.87it/s]


Epoch 49, Loss: 0.7726, Approx. Train: 0.8258
testtest


Evaluating: 100%|██████████| 698895/698895 [00:29<00:00, 23380.11it/s]


Epoch: 49, Train: 0.8395, Val: 0.8536, Test: 0.8526


In [10]:
print(torch.cuda.memory_summary())
for i, conv in enumerate(model.convs):
    print(conv)

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 3         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    1073 MB |    4387 MB |   11728 GB |   11727 GB |
|       from large pool |    1073 MB |    4386 MB |   11703 GB |   11702 GB |
|       from small pool |       0 MB |       2 MB |      25 GB |      25 GB |
|---------------------------------------------------------------------------|
| Active memory         |    1073 MB |    4387 MB |   11728 GB |   11727 GB |
|       from large pool |    1073 MB |    4386 MB |   11703 GB |   11702 GB |
|       from small pool |       0 MB |       2 MB |      25 GB |      25 GB |
|---------------------------------------------------------------

In [16]:
len(model.convs)

3

0
