## Reference : cs224w GCN code

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

import networkx as nx
import numpy as np
import torch.optim as optim

from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.transforms as T

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [2]:
class GNNStack(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNStack, self).__init__()
        self.convs = nn.ModuleList()
        self.convs.append(CustomConv(input_dim, hidden_dim))
        self.lns = nn.ModuleList()
        self.lns.append(nn.LayerNorm(hidden_dim))
        
        for i in range(2):
            self.convs.append(CustomConv(hidden_dim, hidden_dim))
        
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Dropout(0.25),
            nn.Linear(hidden_dim, output_dim))
        self.dropout = 0.25
        self.num_layers = 3
   
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if data.num_node_features == 0:
            x = torch.ones(data.num_nodes, 1)
    
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            emb = x
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        x = self.post_mp(x)
    
        return emb, F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)

In [3]:
class CustomConv(pyg_nn.MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(CustomConv, self).__init__(aggr='add')  # "Add" aggregation. # mean, max 등등
        self.lin = nn.Linear(in_channels, out_channels)
        #self.lin_self = nn.Linear(in_channels, out_channels) # convolution

    def forward(self, x, edge_index):
        """
        Convolution을 위해서는 2가지가 필수적임.
        x has shape [N, in_channels] # feature matrix
        edge_index has shape [2, E] ==> connectivity ==> 2: (u, v)
        
        """


        # Add self-loops to the adjacency matrix. (A+I)
        # pyg_utils.add_self_loops(edge_index, num_nodes = x.size(0))  
        # neighbor 정보뿐만 아니라, 내 정보까지 add해야하므로 self-loops 추가! 
        
        # 지울수도 있다 !
        edge_index, _ = pyg_utils.add_self_loops(edge_index, num_nodes=x.size(0))

        # Transform node feature matrix.
        #self_x = self.lin_self(x) # B
        x = self.lin(x) # W
        
        
        # self_x: skip connection #compute message for all the nodes
        return self.propagate(edge_index, 
                                    size=(x.size(0), x.size(0)), x=x)

    def message(self, x_i, x_j, edge_index, size):
        # Compute messages
        # x_i is self-embedding
        # x_j has shape [E, out_channels] neighbor embedding

        row, col = edge_index
        deg = pyg_utils.degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return x_j
    
    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
        F.normalize(aggr_out, p=2, dim=-1) # dim: 상황에 따라 알맞게 조정할 
        return aggr_out

In [4]:
def train(dataset):
    test_loader = loader = DataLoader(dataset, batch_size=64, shuffle=True)

    # build model
    model = GNNStack(max(dataset.num_node_features, 1), 32, dataset.num_classes)
    opt = optim.Adam(model.parameters(), lr=0.01)
    
    # train
    for epoch in range(200):
        total_loss = 0
        model.train()
        for batch in loader:
            opt.zero_grad()

            emb, pred = model(batch)
            label = batch.y 
            pred = pred[batch.train_mask]
            label = label[batch.train_mask]
            
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item()
        total_loss /= len(loader.dataset)
        
        if epoch % 10 == 0:
            test_acc = test(test_loader, model)
            print("Epoch {}. Loss: {:.4f}. Test accuracy: {:.4f}".format(
                epoch, total_loss, test_acc))

    return model

In [5]:
def test(loader, model, is_validation = False):
    model.eval()
    
    correct = 0
    for data in loader:
        with torch.no_grad():
            emb, pred = model(data)
            pred = pred.argmax(dim=1)
            label = data.y
            
        mask = data.val_mask if is_validation else data.test_mask
        pred = pred[mask]
        label = data.y[mask]
        correct += pred.eq(label).sum().item()
        
    total = 0
    for data in loader.dataset:
        total += torch.sum(data.test_mask).item()
    
    return correct/total


In [6]:
dataset = Planetoid(root='/tmp/cora', name='cora')

model = train(dataset)
print(model)

Epoch 0. Loss: 2.6875. Test accuracy: 0.1580
Epoch 10. Loss: 1.3669. Test accuracy: 0.5860
Epoch 20. Loss: 0.9079. Test accuracy: 0.6340
Epoch 30. Loss: 0.6527. Test accuracy: 0.7010
Epoch 40. Loss: 0.3722. Test accuracy: 0.6870
Epoch 50. Loss: 0.2494. Test accuracy: 0.6670
Epoch 60. Loss: 0.4162. Test accuracy: 0.6620
Epoch 70. Loss: 0.4672. Test accuracy: 0.6830
Epoch 80. Loss: 0.1396. Test accuracy: 0.6890
Epoch 90. Loss: 0.1288. Test accuracy: 0.6750
Epoch 100. Loss: 0.0905. Test accuracy: 0.6840
Epoch 110. Loss: 0.1545. Test accuracy: 0.6850
Epoch 120. Loss: 0.0507. Test accuracy: 0.7070
Epoch 130. Loss: 0.0668. Test accuracy: 0.6710
Epoch 140. Loss: 0.0416. Test accuracy: 0.7100
Epoch 150. Loss: 0.0260. Test accuracy: 0.7150
Epoch 160. Loss: 0.0555. Test accuracy: 0.7070
Epoch 170. Loss: 0.0211. Test accuracy: 0.6610
Epoch 180. Loss: 0.1674. Test accuracy: 0.6990
Epoch 190. Loss: 0.0153. Test accuracy: 0.6830
GNNStack(
  (convs): ModuleList(
    (0): CustomConv(
      (lin): Line