-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
autoencoder.py
109 lines (83 loc) · 3.61 KB
/
autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import os.path as osp
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GAE, VGAE, GCNConv
# Command-line options. `--variational` and `--linear` are independent
# store_true switches; together they choose one of four encoder/model pairs.
parser = argparse.ArgumentParser()
parser.add_argument('--variational', action='store_true',
                    help='train a variational graph auto-encoder (VGAE) '
                         'instead of a plain GAE')
parser.add_argument('--linear', action='store_true',
                    help='use a single GCN layer as the encoder instead of '
                         'a two-layer GCN')
parser.add_argument('--dataset', type=str, default='Cora',
                    choices=['Cora', 'CiteSeer', 'PubMed'],
                    help='Planetoid dataset to train on (default: Cora)')
parser.add_argument('--epochs', type=int, default=400,
                    help='number of training epochs (default: 400)')
args = parser.parse_args()
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Edge split for link prediction: 5% of edges go to validation and 10% to test.
# With split_labels=True each split exposes pos_edge_label_index /
# neg_edge_label_index, which train() and test() read. No negative edges are
# pre-sampled for the training split (add_negative_train_samples=False).
# NOTE(review): presumably GAE.recon_loss draws its own negatives; confirm.
link_split = T.RandomLinkSplit(
    num_val=0.05,
    num_test=0.1,
    is_undirected=True,
    split_labels=True,
    add_negative_train_samples=False,
)
# Normalise node features, move the graph to `device`, then split its edges.
transform = T.Compose([T.NormalizeFeatures(), T.ToDevice(device), link_split])

# Planetoid data is stored under ../data/Planetoid relative to this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, args.dataset, transform=transform)

# Indexing the single graph applies the transform, which yields three splits.
train_data, val_data, test_data = dataset[0]
class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder for the (non-variational) graph auto-encoder.

    Args:
        in_channels: Number of input features per node.
        out_channels: Size of the produced node embedding. The hidden layer
            is twice this width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        # Submodules are created in the same order as before, so parameter
        # initialisation and the state_dict keys are unchanged.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Embed every node: GCN -> ReLU -> GCN."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
class VariationalGCNEncoder(torch.nn.Module):
    """GCN encoder for the VGAE with separate mean and log-std heads.

    A shared GCN layer (followed by ReLU) feeds two parallel GCN layers.

    Args:
        in_channels: Number of input features per node.
        out_channels: Size of each of the mu / logstd outputs. The shared
            hidden layer is twice this width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        # Construction order (conv1, conv_mu, conv_logstd) is preserved.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, out_channels)
        self.conv_logstd = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return a ``(mu, logstd)`` pair of per-node tensors."""
        shared = torch.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(shared, edge_index)
        logstd = self.conv_logstd(shared, edge_index)
        return mu, logstd
class LinearEncoder(torch.nn.Module):
    """Single-layer GCN encoder; no activation is applied to its output.

    Args:
        in_channels: Number of input features per node.
        out_channels: Size of the produced node embedding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Embed every node with one graph convolution."""
        embedding = self.conv(x, edge_index)
        return embedding
class VariationalLinearEncoder(torch.nn.Module):
    """Single-layer variational encoder: two parallel GCN heads on the input.

    Args:
        in_channels: Number of input features per node.
        out_channels: Size of each of the mu / logstd outputs.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Construction order (conv_mu, then conv_logstd) is preserved.
        self.conv_mu = GCNConv(in_channels, out_channels)
        self.conv_logstd = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Return a ``(mu, logstd)`` pair computed directly from ``x``."""
        mu = self.conv_mu(x, edge_index)
        logstd = self.conv_logstd(x, edge_index)
        return mu, logstd
# Node-feature width comes from the dataset; embeddings are 16-dimensional.
in_channels, out_channels = dataset.num_features, 16

# --variational selects VGAE over GAE; --linear selects the single-layer
# encoder over the two-layer GCN. Both flags are booleans (store_true), so
# this nested if/else covers the same four cases and always binds `model`.
if args.variational:
    if args.linear:
        model = VGAE(VariationalLinearEncoder(in_channels, out_channels))
    else:
        model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
else:
    if args.linear:
        model = GAE(LinearEncoder(in_channels, out_channels))
    else:
        model = GAE(GCNEncoder(in_channels, out_channels))

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
    """Run one full-batch optimisation step on the training split.

    The loss is the reconstruction loss over the positive training edges.
    With ``--variational``, the model's KL term weighted by
    ``1 / train_data.num_nodes`` is added.

    Returns:
        float: The loss value for this step.
    """
    model.train()
    optimizer.zero_grad()

    embeddings = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(embeddings, train_data.pos_edge_label_index)
    if args.variational:
        # kl_loss() is called after encode(); the KL term is down-weighted by
        # the number of nodes.
        kl_weight = 1 / train_data.num_nodes
        loss = loss + kl_weight * model.kl_loss()

    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()
def test(data):
    """Evaluate the model on one link-prediction split without tracking gradients.

    Args:
        data: A split produced by the dataset transform. It must provide
            ``x``, ``edge_index``, ``pos_edge_label_index`` and
            ``neg_edge_label_index``.

    Returns:
        Whatever ``model.test`` returns. The training loop unpacks it as
        ``(auc, ap)``.
    """
    model.eval()
    embeddings = model.encode(data.x, data.edge_index)
    positives = data.pos_edge_label_index
    negatives = data.neg_edge_label_index
    return model.test(embeddings, positives, negatives)
# Train for --epochs epochs and track the validation split every epoch. The
# test split is scored only when validation AUC improves, and the test metrics
# from the best-validation epoch are reported at the end. This way the
# held-out test edges are never used to choose a model.
best_val_auc = final_test_auc = final_test_ap = 0.0
for epoch in range(1, args.epochs + 1):
    loss = train()
    val_auc, val_ap = test(val_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc, final_test_ap = test(test_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}')

print(f'Final Test AUC: {final_test_auc:.4f}, AP: {final_test_ap:.4f}')