In [None]:
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install torch-geometric

Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html


In [None]:
import os

import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from torch_geometric.datasets import Planetoid
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import SAGEConv

In [None]:
dataset = Planetoid(root="data/Planetoid", name="Cora", transform=NormalizeFeatures())
data = dataset[0]

# Requesting more DataLoader workers than the CPUs this process may use triggers a
# worker-count warning (presumably the truncated "cpuset_checked" output below), so
# cap the requested 12 workers at what is actually available.
try:
    available_cpus = len(os.sched_getaffinity(0))
except AttributeError:  # os.sched_getaffinity does not exist on macOS/Windows
    available_cpus = os.cpu_count() or 1
NUM_WORKERS = min(12, available_cpus)

# Cluster-GCN: split the graph into 200 parts; each training batch combines 5 of them.
cluster_data = ClusterData(data, num_parts=200, recursive=False,
                           save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data, batch_size=5, shuffle=True,
                             num_workers=NUM_WORKERS)

# sizes=[-1] keeps every neighbor, used by Net.inference for layer-wise full-graph evaluation.
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=128,
                                  shuffle=False, num_workers=NUM_WORKERS)

cluster_data, train_loader, subgraph_loader

  cpuset_checked))


(ClusterData(
   data=Data(adj=[2708, 2708, nnz=10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708]),
   num_parts=200
 ),
 <torch_geometric.data.cluster.ClusterLoader at 0x7feb68ac0f10>,
 NeighborSampler(sizes=[-1]))

In [None]:
class Net(torch.nn.Module):
    """Two-layer GraphSAGE node classifier.

    ``forward`` runs on a (sub)graph given as node features plus ``edge_index``
    (e.g. a Cluster-GCN batch). ``inference`` evaluates all nodes layer by layer
    using a loader that yields ``(batch_size, n_id, adj)`` bipartite subgraphs.

    Args:
        in_channels: Number of input node features.
        out_channels: Number of output classes.
        hidden_channels: Width of the hidden SAGEConv layer (default 128).
        dropout: Dropout probability after the hidden layer while training
            (default 0.5).
    """

    def __init__(self, in_channels, out_channels, hidden_channels=128, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.convs = ModuleList(
            [SAGEConv(in_channels, hidden_channels),
             SAGEConv(hidden_channels, out_channels)]
        )

    def forward(self, x, edge_index):
        """Return log-softmax class scores for every node in ``x``."""
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(x, dim=-1)

    @torch.no_grad()
    def inference(self, x_all, loader=None, device=None):
        """Compute log-softmax class scores for all nodes, one layer at a time.

        Each layer is applied to every node before the next layer starts, so
        intermediate embeddings are kept on the CPU and only one mini-batch at
        a time is moved to ``device``.

        Args:
            x_all: Node feature matrix for the whole graph (indexed by ``n_id``).
            loader: Iterable of ``(batch_size, n_id, adj)`` tuples; defaults to
                the notebook-level ``subgraph_loader``.
            device: Device to run the convolutions on; defaults to the device
                of this model's parameters.

        Returns:
            CPU tensor of log-probabilities, matching ``forward``'s output.
        """
        if loader is None:
            loader = subgraph_loader
        if device is None:
            device = next(self.parameters()).device

        last = len(self.convs) - 1
        with tqdm(total=x_all.size(0) * len(self.convs), desc="Evaluating...") as pbar:
            for i, conv in enumerate(self.convs):
                xs = []
                for batch_size, n_id, adj in loader:
                    edge_index, _, size = adj.to(device)
                    x = x_all[n_id].to(device)
                    # The first size[1] entries of n_id are the target nodes of this batch.
                    x_target = x[:size[1]]
                    x = conv((x, x_target), edge_index)
                    if i != last:
                        x = F.relu(x)
                    xs.append(x.cpu())
                    pbar.update(batch_size)
                x_all = torch.cat(xs, dim=0)

        # Same normalization as forward(); argmax-based accuracy is unaffected.
        return F.log_softmax(x_all, dim=-1)

In [None]:
# Seed torch's global RNG (CPU and all CUDA devices) so weight initialization,
# dropout and the shuffled ClusterLoader order are reproducible across runs.
SEED = 12345
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [None]:
def train():
    """Run one training epoch over the Cluster-GCN ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. Batches without any training-mask node are skipped.

    Returns:
        Mean NLL loss per training node for the epoch, or ``nan`` when no
        batch contained a training node (instead of dividing by zero).
    """
    model.train()

    total_loss = 0.0
    total_nodes = 0
    for batch in train_loader:
        # Counted before the device transfer; batches with no labeled training nodes contribute nothing.
        num_train_nodes = batch.train_mask.sum().item()
        if num_train_nodes == 0:
            continue

        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()

        # nll_loss averages over the batch; re-weight by node count for an epoch-level mean.
        total_loss += loss.item() * num_train_nodes
        total_nodes += num_train_nodes

    if total_nodes == 0:
        return float("nan")
    return total_loss / total_nodes

In [None]:
@torch.no_grad()
def test():
    """Evaluate the model on the full graph.

    Runs ``model.inference`` over all node features and compares the
    arg-max predictions against ``data.y``.

    Returns:
        List ``[train_acc, val_acc, test_acc]`` of accuracies for the
        train, validation and test masks, in that order.
    """
    model.eval()

    predictions = model.inference(data.x).argmax(dim=-1)
    splits = (data.train_mask, data.val_mask, data.test_mask)

    return [
        predictions[split].eq(data.y[split]).sum().item() / split.sum().item()
        for split in splits
    ]

In [None]:
NUM_EPOCHS = 10
EVAL_EVERY = 2  # run full-graph evaluation every EVAL_EVERY epochs

for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    evaluate_now = epoch % EVAL_EVERY == 0
    if evaluate_now:
        train_acc, val_acc, test_acc = test()
        print()  # blank line separating the epoch report from the progress-bar output
    print(f"Epoch {epoch} Loss {loss:.6f}")
    if evaluate_now:
        print(f"Train Acc {train_acc:.4f} Val Acc {val_acc:.4f} Test {test_acc:.4f}")

  cpuset_checked))


Epoch 1 Loss 0.003247


Evaluating...: 100%|██████████| 5416/5416 [00:01<00:00, 3687.02it/s]


Epoch 2 Loss 0.002976
Train Acc 1.0000 Val Acc 0.7400 Test 0.7430





Epoch 3 Loss 0.004575


Evaluating...: 100%|██████████| 5416/5416 [00:01<00:00, 3693.93it/s]


Epoch 4 Loss 0.002680
Train Acc 1.0000 Val Acc 0.7320 Test 0.7510





Epoch 5 Loss 0.001561


Evaluating...: 100%|██████████| 5416/5416 [00:01<00:00, 3707.57it/s]


Epoch 6 Loss 0.002849
Train Acc 1.0000 Val Acc 0.7320 Test 0.7500





Epoch 7 Loss 0.001750


Evaluating...: 100%|██████████| 5416/5416 [00:01<00:00, 3647.15it/s]


Epoch 8 Loss 0.002605
Train Acc 1.0000 Val Acc 0.7340 Test 0.7450





Epoch 9 Loss 0.003195


Evaluating...: 100%|██████████| 5416/5416 [00:01<00:00, 3720.09it/s]


Epoch 10 Loss 0.007478
Train Acc 1.0000 Val Acc 0.7460 Test 0.7540



