-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
65 lines (54 loc) · 2.17 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_scatter import scatter_sum
class GATNet(nn.Module):
    """Four-layer GAT encoder followed by a scoring head.

    Each score is computed from two halves concatenated side by side: a
    projected sum of the node embeddings listed in ``function_idx`` and a
    projected embedding of one decision-variable node. ``forward`` produces
    one score per graph of a batch; ``inference`` scores several decision
    variables of a single graph against the same summed embedding.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=8, heads=8, out_heads=4):
        """Build the layers.

        Args:
            in_channels: Node feature size of the input graphs.
            out_channels: Embedding size produced by the last GAT layer.
            hidden_channels: Per-head width of the three hidden GAT layers.
            heads: Attention heads of the hidden GAT layers. Their outputs are
                concatenated, so each hidden layer emits
                ``hidden_channels * heads`` features (64 with the defaults).
            out_heads: Attention heads of the last layer, which uses
                ``concat=False`` so it emits ``out_channels`` features.
        """
        super().__init__()
        hidden_dim = hidden_channels * heads  # width after head concatenation
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_dim, hidden_channels, heads=heads)
        self.conv3 = GATConv(hidden_dim, hidden_channels, heads=heads)
        self.conv4 = GATConv(hidden_dim, out_channels, heads=out_heads, concat=False)
        self.agg = nn.Linear(out_channels, out_channels, bias=False)
        self.transform = nn.Linear(out_channels, out_channels, bias=False)
        self.out = nn.Linear(out_channels * 2, 1)

    def _encode(self, x, edge_index):
        """Run the four GAT layers, each followed by an ELU."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.elu(conv(x, edge_index))
        return x

    @staticmethod
    def _offset_indices(function_idx_per_graph, decision_var_idxes, num_nodes_per_graph):
        """Shift graph-local node indices into the node numbering of a whole batch.

        Graph ``g``'s nodes start at the sum of the node counts of graphs
        ``0..g-1``; that offset is added to its indices.

        Args:
            function_idx_per_graph: For each graph, a sequence of graph-local
                node indices (ints or 0-d tensors).
            decision_var_idxes: One graph-local node index per graph, indexed
                by graph position. It is only read, never modified.
            num_nodes_per_graph: Node count of each graph.

        Returns:
            ``(function_idxes, graph_ids, dv_idxes)`` as lists of Python ints:
            the shifted function indices, the graph position of each of those
            entries, and one shifted decision-variable index per graph.
        """
        function_idxes, graph_ids, dv_idxes = [], [], []
        offset = 0
        for g, (f_idx, n_nodes) in enumerate(zip(function_idx_per_graph, num_nodes_per_graph)):
            function_idxes.extend(int(j) + offset for j in f_idx)
            graph_ids.extend([g] * len(f_idx))
            dv_idxes.append(int(decision_var_idxes[g]) + offset)
            offset += int(n_nodes)
        return function_idxes, graph_ids, dv_idxes

    def forward(self, batch, decision_var_idxes):
        """Score one decision variable per graph of a batch.

        Args:
            batch: Batched graphs exposing ``x``, ``edge_index``, ``num_graphs``
                and ``get_example``. Every example must carry a
                ``function_idx`` sequence of graph-local node indices.
            decision_var_idxes: One graph-local node index per graph. It is
                not modified in place, so callers can safely reuse it.

        Returns:
            Tensor with one row per graph and a single score column.
        """
        device = batch.x.device
        function_idx_per_graph, num_nodes_per_graph = [], []
        for i in range(batch.num_graphs):
            data = batch.get_example(i)
            function_idx_per_graph.append(data.function_idx)
            num_nodes_per_graph.append(data.x.shape[0])
        function_idxes, graph_ids, dv_idxes = self._offset_indices(
            function_idx_per_graph, decision_var_idxes, num_nodes_per_graph
        )
        # Explicit long tensors. An empty list would otherwise become a float
        # tensor, and a list of 0-d tensors can be read as a multi-dim index.
        function_idxes = torch.tensor(function_idxes, dtype=torch.long, device=device)
        graph_ids = torch.tensor(graph_ids, dtype=torch.long, device=device)
        dv_idxes = torch.tensor(dv_idxes, dtype=torch.long, device=device)

        x = self._encode(batch.x, batch.edge_index)
        # dim_size keeps one row per graph even when trailing graphs have no
        # function nodes; otherwise the concatenation below would fail.
        agg = scatter_sum(x[function_idxes], graph_ids, dim=0, dim_size=batch.num_graphs)
        agg = self.agg(agg)
        trans = self.transform(x[dv_idxes, :])
        return self.out(F.elu(torch.cat([agg, trans], dim=1)))

    @torch.no_grad()
    def inference(self, x, edge_index, decision_var_idxes, function_idx, functions=None, dv=None):
        """Score several decision variables of a single graph, without gradients.

        Args:
            x: Node features of one graph.
            edge_index: Edge index of that graph.
            decision_var_idxes: Node indices of the decision variables to score.
            function_idx: Node indices whose embeddings are summed into the
                shared aggregate.
            functions: Unused; kept for call-site compatibility.
            dv: Unused; kept for call-site compatibility.

        Returns:
            Tensor with one row per entry of ``decision_var_idxes`` and a
            single score column.
        """
        x = self._encode(x, edge_index)
        agg = self.agg(x[function_idx].sum(dim=0))
        # The same aggregate is paired with every decision variable.
        agg = agg.repeat((len(decision_var_idxes), 1))
        trans = self.transform(x[decision_var_idxes, :])
        return self.out(F.elu(torch.cat([agg, trans], dim=1)))