- Original Paper: https://arxiv.org/abs/1609.02907
- Original Code: https://github.com/rusty1s/pytorch_geometric/

In [18]:
import torch
from torch_geometric.data import Data
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv

In [19]:
# Load the Cora dataset
from torch_geometric.datasets import Planetoid

# Store/read the dataset under '/data/Cora' and apply T.NormalizeFeatures() as the transform.
dataset = Planetoid(root='/data/Cora', name='Cora', transform=T.NormalizeFeatures())
# The whole dataset is a single graph
# The three prints below report, in order: the number of graphs (len(dataset)),
# the number of classes (dataset.num_classes), and the number of node features
# (dataset.num_node_features).
print(f'그래프의 개수 : {len(dataset)} -> 데이터셋 전체가 하나의 그래프임')
print(f'그래프의 클래스 종류 : {dataset.num_classes} -> 노드의 클래스는 레이블 말하는듯?')
print(f'노드의 특징 수 : {dataset.num_node_features}')

그래프의 개수 : 1 -> 데이터셋 전체가 하나의 그래프임
그래프의 클래스 종류 : 7 -> 노드의 클래스는 레이블 말하는듯?
노드의 특징 수 : 1433


In [20]:
# Take the first (and, per len(dataset) above, only) graph and report its dimensions.
data = dataset[0]
print(data)
# data.x is a 2-D feature matrix: rows are nodes, columns are per-node features.
num_node, num_node_features = data.x.shape
print(f'\nnode 수 : {num_node}, node features 수 : {num_node_features}')
print(f"Edge Index Shape: {data.edge_index.shape}")
# edge_attr is printed as the edge weights; the recorded output shows it is None here.
print(f"Edge Weight: {data.edge_attr}")


Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

node 수 : 2708, node features 수 : 1433
Edge Index Shape: torch.Size([2, 10556])
Edge Weight: None


In [21]:
# List the attribute names stored on the graph object.
# `Data.keys` is a property in the PyG version that produced the recorded output,
# but newer PyG releases expose it as a method; calling it only when it is
# callable prints the list of keys in both cases instead of a bound-method repr.
print(data.keys() if callable(data.keys) else data.keys)
# Number of nodes selected by the training mask (count of True entries).
print('학습할 노드 개수 : ',data.train_mask.sum().item())

['x', 'test_mask', 'edge_index', 'val_mask', 'y', 'train_mask']
학습할 노드 개수 :  140


In [22]:
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for node classification.

    The input width and the number of output classes are read from the
    module-level ``dataset``; the graph itself is read from the module-level
    ``data`` inside :meth:`forward`, which therefore takes no arguments.

    Args:
        hidden_channels: Width of the hidden representation produced by the
            first convolution. Defaults to 16, the value that used to be
            hard-coded (a user-chosen hyperparameter).
        dropout: Dropout probability applied between the two convolutions.
            Defaults to 0.5, which is the ``F.dropout`` default the original
            code relied on implicitly.
    """

    def __init__(self, hidden_channels: int = 16, dropout: float = 0.5) -> None:
        super().__init__()
        self.dropout = dropout
        # cached=True is requested on both layers; per the GCNConv docs this is
        # only appropriate when the same graph is fed on every call, which is
        # the case here because forward() always uses the global `data`.
        self.conv1 = GCNConv(
            in_channels=dataset.num_features,
            out_channels=hidden_channels,
            cached=True,
            normalize=True,
        )
        self.conv2 = GCNConv(
            in_channels=hidden_channels,
            out_channels=dataset.num_classes,
            cached=True,
            normalize=True,
        )

    def forward(self):
        """Run both convolutions on the global ``data`` graph.

        Returns:
            Per-node log-probabilities over the classes (log-softmax taken
            along dim 1).
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_weight))
        # Dropout is active only in training mode (self.training).
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)

In [23]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)

# Two parameter groups so each layer gets its own weight decay:
# 5e-4 on conv1, none on conv2. Both share the same learning rate.
optimizer = torch.optim.Adam(
    [
        {'params': model.conv1.parameters(), 'weight_decay': 5e-4},
        {'params': model.conv2.parameters(), 'weight_decay': 0},
    ],
    lr=0.01,
)

In [24]:
def train():
    """Run one full-graph optimisation step on the training nodes.

    Uses the module-level ``model``, ``optimizer`` and ``data``; returns
    nothing.
    """
    model.train()
    optimizer.zero_grad()
    log_probs = model()
    # Only nodes selected by train_mask contribute to the loss.
    train_nodes = data.train_mask
    loss = F.nll_loss(input=log_probs[train_nodes], target=data.y[train_nodes])
    loss.backward()
    optimizer.step()
    
@torch.no_grad()
def test():
    """Evaluate the model on the train, validation and test node masks.

    Uses the module-level ``model`` and ``data``. The model is switched to
    eval mode and run once; the same output is then scored under each mask.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` of accuracies as floats.
    """
    model.eval()
    logits = model()

    def _accuracy(mask):
        # The predicted class is the index of the largest log-probability.
        _, predicted = logits[mask].max(1)
        n_correct = predicted.eq(data.y[mask]).sum().item()
        return n_correct / mask.sum().item()

    # data(...) yields (key, value) pairs for the requested attributes, in order.
    return [
        _accuracy(mask)
        for _key, mask in data('train_mask', 'val_mask', 'test_mask')
    ]

In [25]:
# Train for 200 epochs. `test_acc` tracks the test accuracy observed at the
# epoch with the best validation accuracy so far (model selection on the
# validation split).
best_val_acc = 0
test_acc = 0
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc, test_acc = val_acc, tmp_test_acc
    # The log line shows the current epoch's test accuracy, not the selected one.
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {tmp_test_acc:.4f}')

Epoch: 001, Train: 0.2857, Val: 0.2000, Test: 0.2220
Epoch: 002, Train: 0.4786, Val: 0.2840, Test: 0.2920
Epoch: 003, Train: 0.6357, Val: 0.4860, Test: 0.5030
Epoch: 004, Train: 0.6786, Val: 0.5000, Test: 0.5430
Epoch: 005, Train: 0.6857, Val: 0.5260, Test: 0.5530
Epoch: 006, Train: 0.7429, Val: 0.5740, Test: 0.5770
Epoch: 007, Train: 0.8143, Val: 0.6080, Test: 0.6260
Epoch: 008, Train: 0.8500, Val: 0.6460, Test: 0.6700
Epoch: 009, Train: 0.8929, Val: 0.6780, Test: 0.7010
Epoch: 010, Train: 0.9143, Val: 0.7220, Test: 0.7370
Epoch: 011, Train: 0.9357, Val: 0.7420, Test: 0.7550
Epoch: 012, Train: 0.9429, Val: 0.7580, Test: 0.7550
Epoch: 013, Train: 0.9571, Val: 0.7480, Test: 0.7440
Epoch: 014, Train: 0.9714, Val: 0.7400, Test: 0.7420
Epoch: 015, Train: 0.9714, Val: 0.7420, Test: 0.7420
Epoch: 016, Train: 0.9714, Val: 0.7380, Test: 0.7350
Epoch: 017, Train: 0.9714, Val: 0.7360, Test: 0.7370
Epoch: 018, Train: 0.9786, Val: 0.7420, Test: 0.7400
Epoch: 019, Train: 0.9714, Val: 0.7520, Test: 