In [1]:
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import train_test_split_edges

In [7]:
dataset = Planetoid("data", "CiteSeer", transform=T.NormalizeFeatures())
dataset.data

Data(edge_index=[2, 9104], test_mask=[3327], train_mask=[3327], val_mask=[3327], x=[3327, 3703], y=[3327])

In [8]:
data = dataset[0]

In [10]:
data.train_mask = data.test_mask = data.val_mask = None
data

Data(edge_index=[2, 9104], x=[3327, 3703], y=[3327])

In [11]:
data = train_test_split_edges(data)
data

Data(test_neg_edge_index=[2, 455], test_pos_edge_index=[2, 455], train_neg_adj_mask=[3327, 3327], train_pos_edge_index=[2, 7740], val_neg_edge_index=[2, 227], val_pos_edge_index=[2, 227], x=[3327, 3703], y=[3327])

In [15]:
class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2*out_channels, cached=True)
        self.conv2 = GCNConv(2*out_channels, out_channels, cached=True)
    
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

In [16]:
from torch_geometric.nn import GAE

In [18]:
# parameters
out_channels = 2
num_features = dataset.num_features
epochs = 100

# model
model = GAE(GCNEncoder(num_features, out_channels))

# move to GPU 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
x = data.x.to(device)
#train_pos_edge_index = data[data.train_pos_edge_index].to(device)
train_pos_edge_index = data.train_pos_edge_index.to(device)

# initialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [19]:
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(x, train_pos_edge_index)
    loss = model.recon_loss(z, train_pos_edge_index)
    loss.backward()
    optimizer.step()
    return float(loss)

def test(pos_edge_index, neg_edge_index):
    model.eval()
    with torch.no_grad():
        z = model.encode(x, train_pos_edge_index)
    return model.test(z, pos_edge_index, neg_edge_index)

In [22]:
for epoch in range(1, epochs + 1):
    loss = train()
    auc, ap = test(data.test_pos_edge_index, data.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

Epoch: 001, AUC: 0.6488, AP: 0.6882
Epoch: 002, AUC: 0.6572, AP: 0.6965
Epoch: 003, AUC: 0.6671, AP: 0.7054
Epoch: 004, AUC: 0.6712, AP: 0.7102
Epoch: 005, AUC: 0.6725, AP: 0.7123
Epoch: 006, AUC: 0.6730, AP: 0.7141
Epoch: 007, AUC: 0.6729, AP: 0.7158
Epoch: 008, AUC: 0.6719, AP: 0.7162
Epoch: 009, AUC: 0.6717, AP: 0.7167
Epoch: 010, AUC: 0.6710, AP: 0.7167
Epoch: 011, AUC: 0.6703, AP: 0.7170
Epoch: 012, AUC: 0.6698, AP: 0.7177
Epoch: 013, AUC: 0.6693, AP: 0.7181
Epoch: 014, AUC: 0.6685, AP: 0.7188
Epoch: 015, AUC: 0.6681, AP: 0.7196
Epoch: 016, AUC: 0.6676, AP: 0.7202
Epoch: 017, AUC: 0.6673, AP: 0.7208
Epoch: 018, AUC: 0.6671, AP: 0.7216
Epoch: 019, AUC: 0.6666, AP: 0.7217
Epoch: 020, AUC: 0.6660, AP: 0.7218
Epoch: 021, AUC: 0.6655, AP: 0.7218
Epoch: 022, AUC: 0.6647, AP: 0.7216
Epoch: 023, AUC: 0.6646, AP: 0.7218
Epoch: 024, AUC: 0.6642, AP: 0.7216
Epoch: 025, AUC: 0.6642, AP: 0.7218
Epoch: 026, AUC: 0.6644, AP: 0.7221
Epoch: 027, AUC: 0.6651, AP: 0.7224
Epoch: 028, AUC: 0.6659, AP:

In [23]:
Z = model.encode(x, train_pos_edge_index)
Z

tensor([[ 0.4190,  0.4041],
        [-0.5143,  1.1552],
        [ 0.3302, -0.9620],
        ...,
        [-0.0927,  1.1178],
        [ 1.4214,  0.1656],
        [ 1.3833, -0.8801]], grad_fn=<AddBackward0>)

In [27]:
from torch.utils.tensorboard import SummaryWriter

In [28]:
writer = SummaryWriter('runs/GAE_experiment_'+'2d_100_epochs')

In [29]:
for epoch in range(1, epochs + 1):
    loss = train()
    auc, ap = test(data.test_pos_edge_index, data.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('auc train', auc, epoch)
    writer.add_scalar('ap train', ap, epoch)

Epoch: 001, AUC: 0.8516, AP: 0.8474
Epoch: 002, AUC: 0.8518, AP: 0.8475
Epoch: 003, AUC: 0.8517, AP: 0.8475
Epoch: 004, AUC: 0.8519, AP: 0.8479
Epoch: 005, AUC: 0.8521, AP: 0.8482
Epoch: 006, AUC: 0.8522, AP: 0.8485
Epoch: 007, AUC: 0.8527, AP: 0.8489
Epoch: 008, AUC: 0.8529, AP: 0.8492
Epoch: 009, AUC: 0.8530, AP: 0.8494
Epoch: 010, AUC: 0.8529, AP: 0.8494
Epoch: 011, AUC: 0.8526, AP: 0.8495
Epoch: 012, AUC: 0.8521, AP: 0.8492
Epoch: 013, AUC: 0.8522, AP: 0.8494
Epoch: 014, AUC: 0.8526, AP: 0.8497
Epoch: 015, AUC: 0.8529, AP: 0.8498
Epoch: 016, AUC: 0.8530, AP: 0.8500
Epoch: 017, AUC: 0.8527, AP: 0.8499
Epoch: 018, AUC: 0.8526, AP: 0.8498
Epoch: 019, AUC: 0.8522, AP: 0.8497
Epoch: 020, AUC: 0.8522, AP: 0.8499
Epoch: 021, AUC: 0.8523, AP: 0.8501
Epoch: 022, AUC: 0.8523, AP: 0.8502
Epoch: 023, AUC: 0.8523, AP: 0.8501
Epoch: 024, AUC: 0.8524, AP: 0.8503
Epoch: 025, AUC: 0.8524, AP: 0.8504
Epoch: 026, AUC: 0.8521, AP: 0.8503
Epoch: 027, AUC: 0.8519, AP: 0.8503
Epoch: 028, AUC: 0.8519, AP: