-
-
Notifications
You must be signed in to change notification settings - Fork 166
/
model.py
290 lines (245 loc) · 11.4 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import pickle
import random
from time import time
from typing import Union
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.nn.functional import mse_loss, l1_loss, binary_cross_entropy, cross_entropy
from torch.optim import Optimizer
class NBeatsNet(nn.Module):
    """N-BEATS forecasting network built from stacks of blocks.

    Every block receives the current backcast residual and returns a
    ``(backcast, forecast)`` pair. The block's backcast is subtracted from the
    residual and its forecast is added to the running forecast (see
    ``forward``).

    The class also offers a small Keras-like training API: ``compile``,
    ``fit`` and ``predict``.
    """

    SEASONALITY_BLOCK = 'seasonality'
    TREND_BLOCK = 'trend'
    GENERIC_BLOCK = 'generic'

    def __init__(self,
                 device=torch.device('cpu'),
                 stack_types=(TREND_BLOCK, SEASONALITY_BLOCK),
                 nb_blocks_per_stack=3,
                 forecast_length=5,
                 backcast_length=10,
                 thetas_dim=(4, 8),
                 share_weights_in_stack=False,
                 hidden_layer_units=256,
                 nb_harmonics=None):
        """Build the stacks and move the parameters to ``device``.

        :param device: torch device the parameters and computations live on.
        :param stack_types: one block-type name per stack (see the class constants).
        :param nb_blocks_per_stack: number of blocks in each stack.
        :param forecast_length: length of the produced forecast.
        :param backcast_length: length of the input window.
        :param thetas_dim: one theta dimension per stack, indexed like ``stack_types``.
        :param share_weights_in_stack: reuse the same block object for a whole stack.
        :param hidden_layer_units: width of the fully connected layers of each block.
        :param nb_harmonics: forwarded to every block constructor.
        """
        super(NBeatsNet, self).__init__()
        self.forecast_length = forecast_length
        self.backcast_length = backcast_length
        self.hidden_layer_units = hidden_layer_units
        self.nb_blocks_per_stack = nb_blocks_per_stack
        self.share_weights_in_stack = share_weights_in_stack
        self.nb_harmonics = nb_harmonics
        self.stack_types = stack_types
        self.stacks = []
        self.thetas_dim = thetas_dim
        # NOTE: while this is a plain list it shadows nn.Module.parameters on the
        # instance. Once it is replaced by an nn.ParameterList below, it is
        # registered as a sub-module named 'parameters', and `self.parameters()`
        # resolves to the regular nn.Module method again. The blocks live in a
        # plain Python list, so this ParameterList is what exposes their weights
        # to the optimizer, `.to()` and `state_dict()`. Kept as-is so that
        # existing checkpoints keep their keys.
        self.parameters = []
        self.device = device
        print('| N-Beats')
        for stack_id in range(len(self.stack_types)):
            self.stacks.append(self.create_stack(stack_id))
        self.parameters = nn.ParameterList(self.parameters)
        self.to(self.device)
        self._loss = None
        self._opt = None

    def create_stack(self, stack_id):
        """Create the list of blocks of stack ``stack_id`` and collect their parameters."""
        stack_type = self.stack_types[stack_id]
        print(f'| -- Stack {stack_type.title()} (#{stack_id}) (share_weights_in_stack={self.share_weights_in_stack})')
        block_init = NBeatsNet.select_block(stack_type)
        blocks = []
        for block_id in range(self.nb_blocks_per_stack):
            if self.share_weights_in_stack and block_id != 0:
                block = blocks[-1]  # pick up the last one when we share weights.
            else:
                block = block_init(self.hidden_layer_units, self.thetas_dim[stack_id],
                                   self.device, self.backcast_length, self.forecast_length, self.nb_harmonics)
                self.parameters.extend(block.parameters())
            print(f' | -- {block}')
            blocks.append(block)
        return blocks

    def save(self, filename: str):
        """Serialize the whole model object to ``filename`` with ``torch.save``."""
        torch.save(self, filename)

    @staticmethod
    def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
        """Load an object previously written by ``save``.

        NOTE(security): this unpickles ``f``; only load files from a trusted source.
        """
        return torch.load(f, map_location, pickle_module, **pickle_load_args)

    @staticmethod
    def select_block(block_type):
        """Return the block class for ``block_type``; unknown names fall back to ``GenericBlock``."""
        if block_type == NBeatsNet.SEASONALITY_BLOCK:
            return SeasonalityBlock
        elif block_type == NBeatsNet.TREND_BLOCK:
            return TrendBlock
        else:
            return GenericBlock

    def compile(self, loss: str, optimizer: Union[str, Optimizer]):
        """Select the loss function and optimizer used by ``fit``.

        :param loss: one of 'mae', 'mse', 'cross_entropy', 'binary_crossentropy'.
        :param optimizer: 'adam', 'sgd' or 'rmsprop' (built with lr=1e-4 over this
            model's parameters), or any other object, which is used as the optimizer as-is.
        :raises ValueError: on an unknown loss or optimizer name.
        """
        losses = {
            'mae': l1_loss,
            'mse': mse_loss,
            'cross_entropy': cross_entropy,
            'binary_crossentropy': binary_cross_entropy,
        }
        optimizers = {
            'adam': optim.Adam,
            'sgd': optim.SGD,
            'rmsprop': optim.RMSprop,
        }
        loss_ = losses.get(loss)
        if loss_ is None:
            raise ValueError(f'Unknown loss name: {loss}.')
        # noinspection PyArgumentList
        if isinstance(optimizer, str):
            opt_class = optimizers.get(optimizer)
            if opt_class is None:
                raise ValueError(f'Unknown opt name: {optimizer}.')
            opt_ = opt_class(lr=1e-4, params=self.parameters())
        else:
            opt_ = optimizer
        self._opt = opt_
        self._loss = loss_

    def fit(self, x_train, y_train, validation_data=None, epochs=10, batch_size=32):
        """Train the model and print a progress line after every epoch.

        :param x_train: training inputs, sliceable along the first axis.
        :param y_train: training targets, sliceable along the first axis.
        :param validation_data: optional ``(x_test, y_test)`` pair evaluated after each epoch.
        :param epochs: number of passes over the training data.
        :param batch_size: number of samples per optimisation step.
        :raises RuntimeError: if ``compile`` has not been called first.
        """
        if self._opt is None or self._loss is None:
            raise RuntimeError('The model must be compiled with compile() before calling fit().')

        def split(arr, size):
            # Consecutive chunks of `size` items; an empty input still yields one (empty) chunk.
            return [arr[start:start + size] for start in range(0, max(len(arr), 1), size)]

        # The batches do not change between epochs, only their order does.
        x_train_list = split(x_train, batch_size)
        y_train_list = split(y_train, batch_size)
        assert len(x_train_list) == len(y_train_list)
        num_samples = len(x_train_list)

        for epoch in range(epochs):
            shuffled_indices = list(range(num_samples))
            random.shuffle(shuffled_indices)
            self.train()
            train_loss = []
            timer = time()
            for batch_id in shuffled_indices:
                batch_x = torch.tensor(x_train_list[batch_id], dtype=torch.float).to(self.device)
                batch_y = torch.tensor(y_train_list[batch_id], dtype=torch.float).to(self.device)
                self._opt.zero_grad()
                _, forecast = self(batch_x)
                loss = self._loss(forecast, squeeze_last_dim(batch_y))
                train_loss.append(loss.item())
                loss.backward()
                self._opt.step()
            elapsed_time = time() - timer
            train_loss = np.mean(train_loss)

            # Kept as text so the progress line can also be printed without validation data.
            test_loss = '[undefined]'
            if validation_data is not None:
                x_test, y_test = validation_data
                self.eval()
                with torch.no_grad():  # no gradients are needed for evaluation
                    _, forecast = self(torch.tensor(x_test, dtype=torch.float).to(self.device))
                    # The target must live on the same device as the forecast.
                    target = squeeze_last_dim(torch.tensor(y_test, dtype=torch.float).to(self.device))
                    test_loss = f'{self._loss(forecast, target).item():.4f}'

            time_per_step = int(elapsed_time / num_samples * 1000)
            print(f'Epoch {str(epoch + 1).zfill(len(str(epochs)))}/{epochs}')
            print(f'{num_samples}/{num_samples} [==============================] - '
                  f'{int(elapsed_time)}s {time_per_step}ms/step - '
                  f'loss: {train_loss:.4f} - val_loss: {test_loss}')

    def predict(self, x, return_backcast=False):
        """Run the model on ``x`` and return the forecast (or the backcast) as a numpy array.

        When ``x`` has three dimensions, a trailing axis of size 1 is appended to the result.
        """
        self.eval()
        with torch.no_grad():
            b, f = self(torch.tensor(x, dtype=torch.float).to(self.device))
        # .cpu() is required before .numpy() when the model runs on an accelerator.
        b, f = b.detach().cpu().numpy(), f.detach().cpu().numpy()
        if len(x.shape) == 3:
            b = np.expand_dims(b, axis=-1)
            f = np.expand_dims(f, axis=-1)
        if return_backcast:
            return b
        return f

    def forward(self, backcast):
        """Apply every block in order and return ``(residual_backcast, forecast)``."""
        backcast = squeeze_last_dim(backcast).to(self.device)
        # One forecast row per input row, allocated directly on the target device.
        forecast = torch.zeros(size=(backcast.size()[0], self.forecast_length), device=self.device)
        for stack in self.stacks:
            for block in stack:
                b, f = block(backcast)
                backcast = backcast - b
                forecast = forecast + f
        return backcast, forecast
def squeeze_last_dim(tensor):
    """Drop a trailing axis of size 1 from a 3-D input, e.g. (128, 10, 1) => (128, 10).

    Any other input is returned unchanged.
    """
    shape = tensor.shape
    has_trailing_singleton = len(shape) == 3 and shape[-1] == 1
    return tensor[..., 0] if has_trailing_singleton else tensor
def seasonality_model(thetas, t, device):
    """Project ``thetas`` onto a cosine/sine basis evaluated at the time points ``t``.

    With ``p = thetas.size()[-1]``, the basis stacks ``p // 2`` cosine rows
    ``cos(2*pi*i*t)`` followed by ``p - p // 2`` sine rows ``sin(2*pi*i*t)``,
    with ``i`` starting at 0 in both groups.

    :param thetas: 2-D tensor whose last dimension is ``p``.
    :param t: 1-D sequence of time points.
    :param device: device the basis is moved to before the matrix product.
    :return: ``thetas.mm(S)`` where ``S`` has ``p`` rows and ``len(t)`` columns.
    """
    p = thetas.size()[-1]
    p1 = p // 2
    p2 = p - p1  # the sine group gets the extra row when p is odd
    t = np.asarray(t)
    # Build each group as a single (rows, len(t)) ndarray: handing torch.tensor a
    # Python list of ndarrays is slow and makes PyTorch emit a UserWarning.
    cos_freqs = (2 * np.pi * np.arange(p1))[:, None]
    sin_freqs = (2 * np.pi * np.arange(p2))[:, None]
    s1 = torch.tensor(np.cos(cos_freqs * t)).float()  # H/2-1
    s2 = torch.tensor(np.sin(sin_freqs * t)).float()
    S = torch.cat([s1, s2])
    return thetas.mm(S.to(device))
def trend_model(thetas, t, device):
    """Project ``thetas`` onto the polynomial basis ``t**0 .. t**(p-1)``.

    :param thetas: 2-D tensor whose last dimension ``p`` is at most 4.
    :param t: 1-D sequence of time points.
    :param device: device the basis is moved to before the matrix product.
    :return: ``thetas.mm(T)`` where row ``i`` of ``T`` is ``t ** i``.
    :raises AssertionError: if ``p`` is larger than 4.
    """
    p = thetas.size()[-1]
    assert p <= 4, 'thetas_dim is too big.'
    t = np.asarray(t)
    # Stack the rows into one ndarray first: handing torch.tensor a Python list
    # of ndarrays is slow and makes PyTorch emit a UserWarning.
    T = torch.tensor(np.array([t ** i for i in range(p)])).float()
    return thetas.mm(T.to(device))
def linear_space(backcast_length, forecast_length):
    """Return the backcast and forecast time grids.

    The integers from ``-backcast_length`` up to ``forecast_length - 1`` are
    divided by ``forecast_length``; the first ``backcast_length`` values form
    the backcast grid and the remainder the forecast grid.
    """
    grid = np.arange(-backcast_length, forecast_length, 1) / forecast_length
    backcast_grid, forecast_grid = grid[:backcast_length], grid[backcast_length:]
    return backcast_grid, forecast_grid
class Block(nn.Module):
    """Common trunk of the N-BEATS blocks.

    Holds four fully connected layers applied with ReLU activations, the
    backcast/forecast time grids and the two theta projection layers (a single
    shared layer when ``share_thetas`` is true).
    """

    def __init__(self, units, thetas_dim, device, backcast_length=10, forecast_length=5, share_thetas=False,
                 nb_harmonics=None):
        super().__init__()
        self.units = units
        self.thetas_dim = thetas_dim
        self.backcast_length = backcast_length
        self.forecast_length = forecast_length
        self.share_thetas = share_thetas
        # Layers are created in a fixed order: fc1..fc4 first, theta projections last.
        self.fc1 = nn.Linear(backcast_length, units)
        for layer_name in ('fc2', 'fc3', 'fc4'):
            setattr(self, layer_name, nn.Linear(units, units))
        self.device = device
        grids = linear_space(backcast_length, forecast_length)
        self.backcast_linspace, self.forecast_linspace = grids
        if not share_thetas:
            self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)
            self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False)
        else:
            # One projection serves both the backcast and the forecast.
            shared_fc = nn.Linear(units, thetas_dim, bias=False)
            self.theta_f_fc = shared_fc
            self.theta_b_fc = shared_fc

    def forward(self, x):
        """Run the input through the four ReLU-activated fully connected layers."""
        hidden = squeeze_last_dim(x).to(self.device)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = F.relu(layer(hidden))
        return hidden

    def __str__(self):
        """One-line summary of the block's configuration and identity."""
        settings = (f'units={self.units}, thetas_dim={self.thetas_dim}, '
                    f'backcast_length={self.backcast_length}, forecast_length={self.forecast_length}, '
                    f'share_thetas={self.share_thetas}')
        return f'{type(self).__name__}({settings}) at @{id(self)}'
class SeasonalityBlock(Block):
    """Block whose backcast and forecast come from ``seasonality_model``.

    The theta dimension handed to ``Block`` is ``nb_harmonics`` when it is
    truthy and ``forecast_length`` otherwise; the ``thetas_dim`` argument is
    not used. The theta projection is shared between backcast and forecast.
    """

    def __init__(self, units, thetas_dim, device, backcast_length=10, forecast_length=5, nb_harmonics=None):
        effective_thetas_dim = nb_harmonics if nb_harmonics else forecast_length
        super().__init__(units, effective_thetas_dim, device, backcast_length,
                         forecast_length, share_thetas=True)

    def forward(self, x):
        """Return the ``(backcast, forecast)`` pair for input ``x``."""
        hidden = super().forward(x)
        theta_b = self.theta_b_fc(hidden)
        backcast = seasonality_model(theta_b, self.backcast_linspace, self.device)
        theta_f = self.theta_f_fc(hidden)
        forecast = seasonality_model(theta_f, self.forecast_linspace, self.device)
        return backcast, forecast
class TrendBlock(Block):
    """Block whose backcast and forecast come from ``trend_model``.

    The theta projection is shared between backcast and forecast;
    ``nb_harmonics`` is accepted but not used.
    """

    def __init__(self, units, thetas_dim, device, backcast_length=10, forecast_length=5, nb_harmonics=None):
        super().__init__(units, thetas_dim, device, backcast_length,
                         forecast_length, share_thetas=True)

    def forward(self, x):
        """Return the ``(backcast, forecast)`` pair for input ``x``."""
        hidden = super().forward(x)
        theta_b = self.theta_b_fc(hidden)
        backcast = trend_model(theta_b, self.backcast_linspace, self.device)
        theta_f = self.theta_f_fc(hidden)
        forecast = trend_model(theta_f, self.forecast_linspace, self.device)
        return backcast, forecast
class GenericBlock(Block):
    """Block with learned output layers instead of a fixed basis.

    The thetas are passed through ReLU and then through dedicated linear
    layers that produce the backcast and the forecast. ``nb_harmonics`` is
    accepted but not used.
    """

    def __init__(self, units, thetas_dim, device, backcast_length=10, forecast_length=5, nb_harmonics=None):
        super().__init__(units, thetas_dim, device, backcast_length, forecast_length)
        self.backcast_fc = nn.Linear(thetas_dim, backcast_length)
        self.forecast_fc = nn.Linear(thetas_dim, forecast_length)

    def forward(self, x):
        """Return the ``(backcast, forecast)`` pair for input ``x``."""
        # no constraint for generic arch.
        hidden = super().forward(x)
        backcast = self.backcast_fc(F.relu(self.theta_b_fc(hidden)))  # generic. 3.3.
        forecast = self.forecast_fc(F.relu(self.theta_f_fc(hidden)))  # generic. 3.3.
        return backcast, forecast