-
Notifications
You must be signed in to change notification settings - Fork 34
/
model.py
224 lines (182 loc) · 8.76 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import torch
from torch import nn
import torch.nn.functional as F
class nconv(nn.Module):
    """Propagate features over a graph by contracting them with an adjacency matrix.

    ``x`` is 4-D (einsum subscripts ``ncvl``); its third axis ``v`` — presumably
    the node axis — is contracted with the first axis of ``A``:

    * a 3-D ``A`` (``nvw``) supplies one matrix per sample of the leading axis;
    * any other ``A`` is treated as a single ``vw`` matrix shared by all samples.

    ``A`` is moved to ``x``'s device first; the result is made contiguous.
    """

    def forward(self, x, A):
        adj = A.to(x.device)
        equation = 'ncvl,nvw->ncwl' if adj.dim() == 3 else 'ncvl,vw->ncwl'
        propagated = torch.einsum(equation, (x, adj))
        return propagated.contiguous()
class linear(nn.Module):
    """Channel-mixing projection: a 1x1 ``Conv2d`` with bias, zero padding and unit stride.

    Args:
        c_in: number of input channels (dim 1 of the ``Conv2d`` input).
        c_out: number of output channels.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        # A 1x1 kernel sees a single spatial position, so this is a dense layer over
        # channels applied independently at every location.
        self.mlp = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        projected = self.mlp(x)
        return projected
class gcn(nn.Module):
    """Multi-support, multi-hop graph convolution.

    For every adjacency matrix in ``support`` the input is propagated ``order``
    times (at least once) with :class:`nconv`. The input and all propagated
    copies are concatenated along dim 1, mixed by a 1x1 :class:`linear`
    projection, and passed through dropout (active only in training mode).

    Args:
        c_in: channels of the input ``x``.
        c_out: channels of the output.
        dropout: dropout probability applied to the output.
        support_len: number of matrices expected in ``support``. The mixer is
            sized for ``(order * support_len + 1) * c_in`` input channels, so this
            must equal ``len(support)`` at call time.
        order: number of propagation steps per support.
    """

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        self.nconv = nconv()
        # the raw input plus `order` hops for each support are stacked on the channel axis
        stacked_channels = (order * support_len + 1) * c_in
        self.mlp = linear(stacked_channels, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        hops = [x]
        for adj in support:
            # first hop from the raw input, then keep diffusing the latest hop
            hop = self.nconv(x, adj)
            hops.append(hop)
            for _ in range(2, self.order + 1):
                hop = self.nconv(hop, adj)
                hops.append(hop)
        mixed = self.mlp(torch.cat(hops, dim=1))
        return F.dropout(mixed, self.dropout, training=self.training)
class GraphWaveNet(nn.Module):
    """
    Paper: Graph WaveNet for Deep Spatial-Temporal Graph Modeling.
    Link: https://arxiv.org/abs/1906.00121
    Ref Official Code: https://github.com/nnzhan/Graph-WaveNet/blob/master/model.py
    """

    def __init__(self, num_nodes, support_len, dropout=0.3, gcn_bool=True, addaptadj=True, aptinit=None,
                 in_dim=2, out_dim=12, residual_channels=32, dilation_channels=32, skip_channels=256,
                 end_channels=512, kernel_size=2, blocks=4, layers=2, his_dim=96, **kwargs):
        """Build the Graph WaveNet backbone.

        No fixed prior graph is stored. On every forward pass the diffusion supports
        are rebuilt from the ``sampled_adj`` argument (its forward and backward
        random-walk matrices). When ``gcn_bool`` and ``addaptadj`` are both set, a
        self-adaptive adjacency learned from node embeddings is appended to them.

        Args:
            num_nodes: number of graph nodes (rows of the adaptive node embeddings).
            support_len: number of externally supplied supports per graph convolution.
                ``forward`` always builds two from ``sampled_adj``, so callers presumably pass 2.
            dropout: dropout probability inside each graph convolution.
            gcn_bool: use graph convolutions; otherwise 1x1 residual convolutions.
            addaptadj: add the learned adaptive adjacency as an extra support.
            aptinit: optional matrix whose top-10 SVD components initialise the node
                embeddings (presumably ``num_nodes x num_nodes``).
            in_dim: input feature channels consumed by the start convolution.
            out_dim: output channels of the final convolution (the prediction length).
            residual_channels, dilation_channels, skip_channels, end_channels: channel widths.
            kernel_size: temporal kernel width of the dilated convolutions.
            blocks, layers: the network stacks ``blocks`` blocks of ``layers`` dilated layers.
            his_dim: feature size of the ``hidden_states`` passed to ``forward`` (default 96).
            **kwargs: accepted and ignored.
        """
        super(GraphWaveNet, self).__init__()
        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers
        self.gcn_bool = gcn_bool
        self.addaptadj = addaptadj

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.gconv = nn.ModuleList()

        # Projects hidden_states [B, N, his_dim] to skip_channels so it can be added to the
        # skip connection (96 -> 512 -> 256 with the defaults).
        self.fc_his = nn.Sequential(nn.Linear(his_dim, 512), nn.ReLU(), nn.Linear(512, skip_channels), nn.ReLU())
        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=residual_channels, kernel_size=(1, 1))

        receptive_field = 1
        self.supports_len = support_len
        if gcn_bool and addaptadj:
            if aptinit is None:
                self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10), requires_grad=True)
                self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes), requires_grad=True)
            else:
                m, p, n = torch.svd(aptinit)
                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True)
                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True)
            # the adaptive matrix is one extra support in every graph convolution
            self.supports_len += 1

        for _ in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for _ in range(layers):
                # All convolutions act on 4-D [B, C, N, L] tensors, so they are Conv2d with
                # (1, k) kernels. The reference code used nn.Conv1d with 2-D kernels, which has
                # identical weight shapes but is rejected on 4-D input by current PyTorch.
                # dilated convolutions
                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels,
                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels,
                                                 kernel_size=(1, kernel_size), dilation=new_dilation))
                # 1x1 convolution for residual connection
                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=residual_channels,
                                                     kernel_size=(1, 1)))
                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=skip_channels,
                                                 kernel_size=(1, 1)))
                self.bn.append(nn.BatchNorm2d(residual_channels))
                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2
                if self.gcn_bool:
                    self.gconv.append(gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len))

        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, out_channels=end_channels, kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=end_channels, out_channels=out_dim, kernel_size=(1, 1), bias=True)
        self.receptive_field = receptive_field

    def _calculate_random_walk_matrix(self, adj_mx):
        """Row-normalise a batch of adjacency matrices after adding self-loops.

        Args:
            adj_mx: tensor of shape [B, N, N].

        Returns:
            ``D^-1 (A + I)`` per batch element, same shape, dtype and device as ``adj_mx``.
            Rows whose degree is zero are set to zero instead of inf/nan.
        """
        _, N, _ = adj_mx.shape
        adj_mx = adj_mx + torch.eye(N, dtype=adj_mx.dtype, device=adj_mx.device).unsqueeze(0)
        d_inv = 1. / adj_mx.sum(dim=2)
        d_inv = torch.where(torch.isinf(d_inv), torch.zeros_like(d_inv), d_inv)
        # left-multiplying by diag(d_inv) is the same as scaling each row by d_inv
        return d_inv.unsqueeze(-1) * adj_mx

    def forward(self, input, hidden_states, sampled_adj):
        """Feed forward of Graph WaveNet.

        Args:
            input (torch.Tensor): input history MTS with shape [B, L, N, C]; only the
                first ``in_dim`` channels are used.
            hidden_states (torch.Tensor): the output of TSFormer of the last patch (segment)
                with shape [B, N, his_dim].
            sampled_adj (torch.Tensor): the learned discrete dependency graph with shape [B, N, N].

        Returns:
            torch.Tensor: prediction with shape [B, N, out_dim]. This assumes the temporal
            length left after the dilated stack is 1 (true when L + 1 <= receptive field,
            e.g. L = 12 with the defaults) — NOTE(review): confirm for other configs.
        """
        # reshape input: [B, L, N, C] -> [B, C, N, L]
        input = input.transpose(1, 3)
        # left-pad the time axis by one step, as the reference pipeline does
        input = nn.functional.pad(input, (1, 0, 0, 0))
        # keep only the channels the start convolution was built for
        input = input[:, :self.start_conv.in_channels, :, :]
        in_len = input.size(3)
        if in_len < self.receptive_field:
            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
        else:
            x = input
        x = self.start_conv(x)

        # ====== rebuild the supports from the sampled graph (forward + backward random walk) ===== #
        self.supports = [self._calculate_random_walk_matrix(sampled_adj),
                         self._calculate_random_walk_matrix(sampled_adj.transpose(-1, -2))]
        supports = self.supports
        if self.gcn_bool and self.addaptadj:
            # self-adaptive adjacency (2-D, shared by the batch) from the node embeddings
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            supports = self.supports + [adp]

        # WaveNet layers: gated dilated conv -> skip branch, then graph (or 1x1) conv,
        # residual add and batch norm.
        skip = None
        for i in range(self.blocks * self.layers):
            residual = x
            # gated dilated convolution
            filt = torch.tanh(self.filter_convs[i](residual))
            gate = torch.sigmoid(self.gate_convs[i](residual))
            x = filt * gate
            # parametrized skip connection; earlier skips are cropped to the current (shorter) length
            s = self.skip_convs[i](x)
            skip = s if skip is None else s + skip[:, :, :, -s.size(3):]
            if self.gcn_bool:
                x = self.gconv[i](x, supports)
            else:
                x = self.residual_convs[i](x)
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)

        # [B, N, his_dim] -> [B, skip_channels, N, 1], broadcast over the time axis of skip
        his = self.fc_his(hidden_states).transpose(1, 2).unsqueeze(-1)
        skip = his if skip is None else skip + his
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        # reshape output: [B, P, N, 1] -> [B, N, P]
        x = x.squeeze(-1).transpose(1, 2)
        return x