In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric import edge_index
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

In [None]:
# Load the Cora Planetoid dataset (stored/cached under ~/tmp/Cora) and take its
# first graph object as `data`.
dataset = Planetoid( name = 'Cora', root = '~/tmp/Cora' )
data = dataset[ 0 ]

`negative_sampling`：随机采样图中并不存在的边作为负样本；让模型同时学习真实边（正样本）和不存在的边（负样本），以提高其区分两者的能力。

`total_edge_index[0]` 和 `total_edge_index[1]` 分别是所有边的起点节点索引和终点节点索引
`x[total_edge_index[0]]` 和 `x[total_edge_index[1]]` 分别是每条边两个端点的节点嵌入（GCN 层的输出）
`torch.cat([...], dim=1)`：将每条边两个端点的嵌入按特征维拼接，得到该边的特征向量（长度为 `2 * out_channels`）

In [None]:
class GNN( nn.Module ):
    """Edge classifier: one GCN layer followed by a linear head over node-pair features.

    ``forward`` scores every edge in ``edge_index`` (the positives) followed by a
    set of negative (non-existent) edges, returning two-class logits per edge.
    """

    def __init__( self, in_channels, out_channels ):
        """
        Args:
            in_channels: size of each input node feature vector.
            out_channels: size of the GCN node embedding. An edge feature is the
                concatenation of two such embeddings, hence ``2 * out_channels``.
        """
        super().__init__()
        self.conv = GCNConv( in_channels, out_channels )
        # Edge feature = [h_src || h_dst] -> 2 * out_channels inputs, 2 output classes.
        self.fc = nn.Linear( 2 * out_channels, 2 )

    def forward( self, x, edge_index, neg_edge_index = None ):
        """Return edge logits of shape ``(num_pos + num_neg, 2)``.

        Rows ``[:edge_index.size(1)]`` correspond to the edges in ``edge_index``;
        the remaining rows correspond to the negative edges.

        Args:
            x: node feature matrix, one row per node (assumed ``(N, in_channels)``).
            edge_index: ``(2, E)`` edge index, used both for GCN message passing
                and as the positive edges to score.
            neg_edge_index: optional ``(2, E')`` negative edges to score. When
                ``None`` (the default), negatives are drawn with ``negative_sampling``.

        Note:
            ``negative_sampling`` returns an *approximate* number of samples and
            may return fewer than requested, so callers should size their label
            vector from the returned logits instead of assuming ``2 * E`` rows.
        """
        h = F.relu( self.conv( x, edge_index ) )
        pos_edge_index = edge_index                             # positives -> real edges
        if neg_edge_index is None:
            neg_edge_index = negative_sampling(
                pos_edge_index,
                num_nodes = x.size( 0 ),                        # every node may be an endpoint
                num_neg_samples = pos_edge_index.size( 1 ),
            )
        total_edge_index = torch.cat( [ pos_edge_index, neg_edge_index ], dim = 1 )
        src, dst = total_edge_index
        edge_features = torch.cat( [ h[ src ], h[ dst ] ], dim = 1 )
        return self.fc( edge_features )
    

In [None]:
SEED = 0
HIDDEN_CHANNELS = 64
LEARNING_RATE = 0.01
WEIGHT_DECAY = 5e-4

# Seed the global torch RNG so weight initialisation is reproducible across
# runs (presumably this also makes later torch-based sampling repeatable).
torch.manual_seed( SEED )
net = GNN( dataset.num_features, HIDDEN_CHANNELS )
optimizer = torch.optim.Adam( net.parameters(), lr = LEARNING_RATE, weight_decay = WEIGHT_DECAY )

In [None]:
# Number of edges stored in `edge_index` (its column count).
data.edge_index.shape[1]

In [None]:
NUM_EPOCHS = 1000
TEST_EDGE_RATIO = 0.2   # fraction of real edges held out from the loss for evaluation
SPLIT_SEED = 0
LOG_EVERY = 50

num_pos = data.edge_index.size( 1 )

# `data.train_mask` / `data.test_mask` are *node* masks (one entry per node), but the
# model scores edges (num_pos real edges + sampled negatives), so they cannot index
# `logits`. Build an edge-level train/test split of the real edges instead.
# NOTE(review): held-out edges still take part in GCN message passing, and if
# edge_index stores each undirected edge in both directions the reverse copy may land
# in the training split, so test accuracy is optimistic. Consider
# torch_geometric.transforms.RandomLinkSplit for a leak-free split.
split_gen = torch.Generator().manual_seed( SPLIT_SEED )
perm = torch.randperm( num_pos, generator = split_gen )
edge_test_mask = torch.zeros( num_pos, dtype = torch.bool )
edge_test_mask[ perm[ : int( TEST_EDGE_RATIO * num_pos ) ] ] = True
edge_train_mask = ~edge_test_mask


def edge_labels_and_mask( logits, pos_mask ):
    """Build labels and a row mask aligned with the `logits` returned by `net`.

    The first `num_pos` rows of `logits` are the real edges (label 1); the remaining
    rows are sampled non-edges (label 0). `negative_sampling` may return fewer
    negatives than requested, so their count is read from `logits`. Negatives are
    resampled on every forward pass, so reusing `pos_mask`'s pattern for them simply
    selects a comparable-sized subset of fresh negatives.
    """
    num_neg = logits.size( 0 ) - num_pos
    labels = torch.cat( [ torch.ones( num_pos, dtype = torch.long ),
                          torch.zeros( num_neg, dtype = torch.long ) ], dim = 0 )
    mask = torch.cat( [ pos_mask, pos_mask[ : num_neg ] ], dim = 0 )
    return labels, mask


for epoch in range( NUM_EPOCHS ):
    net.train()
    optimizer.zero_grad()
    logits = net( data.x, data.edge_index )
    labels, train_rows = edge_labels_and_mask( logits, edge_train_mask )
    loss = F.cross_entropy( logits[ train_rows ], labels[ train_rows ] )
    loss.backward()
    optimizer.step()

    net.eval()
    with torch.no_grad():
        logits = net( data.x, data.edge_index )
        labels, test_rows = edge_labels_and_mask( logits, edge_test_mask )
        preds = logits[ test_rows ].argmax( dim = 1 )
        test_acc = ( preds == labels[ test_rows ] ).float().mean().item()

    if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        print( f'epoch {epoch:4d}  loss {loss.item():.4f}  test acc {test_acc:.4f}' )
        

In [None]:
# Inspect the edge-label layout used for training: label 1 for each real edge in
# `data.edge_index`, followed by the same number of label-0 rows for sampled negatives.
# Distinct names are used so this cell does not overwrite `labels` from training.
check_num_pos = data.edge_index.size( 1 )
check_labels = torch.cat( [ torch.ones( check_num_pos, dtype = torch.long ),
                            torch.zeros( check_num_pos, dtype = torch.long ) ], dim = 0 )
check_labels.unique( return_counts = True )