forked from asappresearch/sru
/
ops.py
298 lines (258 loc) · 10.1 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
from typing import List, Optional
import os
import warnings
import torch
from torch import Tensor
from torch.utils.cpp_extension import load
# JIT-compile and load the C++ implementation of the elementwise forward
# operator (CPU version). The source file sits in the "csrc" directory next to
# this module. The extension is loaded with is_python_module=False; the code in
# this file reaches its functions through the `torch.ops.sru_cpu` namespace.
cpu_source = os.path.join(
    os.path.dirname(__file__),
    "csrc",
    "sru_cpu_impl.cpp",
)
load(
    "sru_cpu",
    [cpu_source],
    extra_cflags=['-O3'],
    is_python_module=False,
    verbose=False,
)
def apex_amp_sru_compute_gpu_fp16(*args, **kwargs):
    """Forward every argument unchanged to ``_apex_amp_sru_compute_gpu``.

    This module-level name is the one handed to ``amp.register_half_function``
    when apex can be imported, so it has to stay a plain attribute of this
    module under exactly this name.
    """
    outputs = _apex_amp_sru_compute_gpu(*args, **kwargs)
    return outputs
def apex_amp_sru_compute_gpu_fp32(*args, **kwargs):
    """Forward every argument unchanged to ``_apex_amp_sru_compute_gpu``.

    This module-level name is the one handed to ``amp.register_float_function``
    when apex can be imported, so it has to stay a plain attribute of this
    module under exactly this name.
    """
    outputs = _apex_amp_sru_compute_gpu(*args, **kwargs)
    return outputs
def _apex_amp_sru_compute_gpu(*args, **kwargs):
    """Call ``SRU_Compute_GPU.apply`` with the given arguments and return its result."""
    # Deferred import: the module is expected to have been imported (and thus
    # cached) already by the time this helper is reached.
    from .cuda_functional import SRU_Compute_GPU as compute_fn
    return compute_fn.apply(*args, **kwargs)
# Optional apex integration: APEX_AMP_AVAILABLE records whether `apex.amp`
# could be imported. When it can, the two precision-specific wrappers defined
# in this module are registered with amp by name.
try:
    from apex import amp
    APEX_AMP_AVAILABLE = True
    import sys
    # The registration calls below take the module object plus the attribute name.
    current_module = sys.modules[__name__]
    # TODO : remove debug statement
    print("SRU: OPS: APEX AMP available, registering apex_amp_sru_compute_gpu for different Tensor precision types ...")
    amp.register_half_function(current_module, "apex_amp_sru_compute_gpu_fp16") # Will cast input arguments to FP16
    amp.register_float_function(current_module, "apex_amp_sru_compute_gpu_fp32") # Will cast input arguments to FP32
except ImportError:
    # apex is not installed: the GPU path checks this flag before using the wrappers.
    APEX_AMP_AVAILABLE = False
@torch.jit.script
def elementwise_recurrence_cpu(U: Tensor,
                               x: Tensor,
                               weight_c: Tensor,
                               bias: Tensor,
                               c_init: Tensor,
                               activation_type: int,
                               hidden_size: int,
                               bidirectional: bool,
                               has_skip_term: bool,
                               scale_x: Optional[Tensor] = None,
                               dropout_mask_c: Optional[Tensor] = None,
                               mask_pad: Optional[Tensor] = None) -> List[Tensor]:
    """Elementwise forward operation of SRU on CPU (inference only).

    Derives the sizes the C++ kernel needs from the inputs and dispatches to
    ``torch.ops.sru_cpu.cpu_forward`` or, when ``bidirectional`` is set, to
    ``torch.ops.sru_cpu.cpu_bi_forward``.

    Size conventions used here:
      * sequence length is ``x.size(0)`` when ``x`` is 3-dimensional, else 1;
      * batch size is ``x.size(-2)``;
      * ``k`` is ``U.size(-1) // hidden_size // num_directions``;
      * ``weight_c`` with more than one dimension is flagged to the kernel
        as "custom".

    ``dropout_mask_c`` must be ``None``; anything else fails the assertion.
    ``scale_x`` is passed to the kernel as a Python scalar (1.0 when absent).
    Returns whatever the selected kernel returns.
    """
    assert dropout_mask_c is None, "Dropout mask cannot be set during inference"

    num_directions = 2 if bidirectional else 1
    if x.dim() == 3:
        length = x.size(0)
    else:
        length = 1
    batch = x.size(-2)
    k = U.size(-1) // hidden_size // num_directions
    is_custom = weight_c.dim() > 1

    # The kernel receives the padding mask as a contiguous float tensor (or None).
    pad_mask: Optional[Tensor] = None
    if mask_pad is not None:
        pad_mask = mask_pad.float().contiguous()

    # Arguments shared by both kernels, prepared once.
    U_c = U.contiguous()
    x_c = x.contiguous()
    weight_c_c = weight_c.contiguous()
    bias_c = bias.contiguous()
    c_init_c = c_init.contiguous()
    scale = scale_x.item() if scale_x is not None else 1.0

    if bidirectional:
        return torch.ops.sru_cpu.cpu_bi_forward(
            U_c,
            x_c,
            weight_c_c,
            bias_c,
            c_init_c,
            pad_mask,
            length,
            batch,
            hidden_size,
            k,
            activation_type,
            has_skip_term,
            scale,
            is_custom
        )

    return torch.ops.sru_cpu.cpu_forward(
        U_c,
        x_c,
        weight_c_c,
        bias_c,
        c_init_c,
        pad_mask,
        length,
        batch,
        hidden_size,
        k,
        activation_type,
        has_skip_term,
        scale,
        is_custom
    )
@torch.jit.unused
def elementwise_recurrence_gpu(U: Tensor,
                               x: Tensor,
                               weight_c: Tensor,
                               bias: Tensor,
                               c_init: Tensor,
                               activation_type: int,
                               hidden_size: int,
                               bidirectional: bool,
                               has_skip_term: bool,
                               scale_x: Optional[Tensor] = None,
                               dropout_mask_c: Optional[Tensor] = None,
                               mask_pad: Optional[Tensor] = None,
                               amp_recurrence_fp16: bool = False) -> List[Tensor]:
    """Elementwise forward operation of SRU on GPU.

    Three execution paths, chosen in this order:
      1. torch autocast is active: autocast is disabled locally, the tensor
         inputs (except ``mask_pad``) are cast to half or float depending on
         ``amp_recurrence_fp16``, and ``SRU_Compute_GPU.apply`` is called.
      2. apex amp is available: the call goes through the fp16 or fp32
         module-level wrapper, again selected by ``amp_recurrence_fp16``.
      3. otherwise ``SRU_Compute_GPU.apply`` is called directly.

    Returns the result of the selected call.
    """
    # Imported and cached here for the first time, retrieved from cache in
    # _apex_amp_sru_compute_gpu
    from .cuda_functional import SRU_Compute_GPU

    # Older torch builds have no is_autocast_enabled; treat them as "not in autocast".
    autocast_on = getattr(torch, 'is_autocast_enabled', lambda: False)()

    if not autocast_on:
        if APEX_AMP_AVAILABLE:
            # Resolve the wrappers through the module globals at call time.
            if amp_recurrence_fp16:
                compute = apex_amp_sru_compute_gpu_fp16
            else:
                compute = apex_amp_sru_compute_gpu_fp32
        else:
            compute = SRU_Compute_GPU.apply
        return compute(
            U, x, weight_c, bias, c_init,
            activation_type, hidden_size, bidirectional, has_skip_term,
            scale_x, dropout_mask_c, mask_pad
        )

    with torch.cuda.amp.autocast(enabled=False):
        to_precision = torch.Tensor.half if amp_recurrence_fp16 else torch.Tensor.float
        U = to_precision(U)
        x = to_precision(x)
        weight_c = to_precision(weight_c)
        bias = to_precision(bias)
        c_init = to_precision(c_init)
        if scale_x is not None:
            scale_x = to_precision(scale_x)
        if dropout_mask_c is not None:
            dropout_mask_c = to_precision(dropout_mask_c)
        return SRU_Compute_GPU.apply(
            U, x, weight_c, bias, c_init,
            activation_type, hidden_size, bidirectional, has_skip_term,
            scale_x, dropout_mask_c, mask_pad
        )
@torch.jit.unused
def elementwise_recurrence_naive(U: Tensor,
                                 x: Tensor,
                                 weight_c: Tensor,
                                 bias: Tensor,
                                 c_init: Tensor,
                                 activation_type: int,
                                 hidden_size: int,
                                 bidirectional: bool,
                                 has_skip_term: bool,
                                 scale_x: Optional[Tensor] = None,
                                 dropout_mask_c: Optional[Tensor] = None,
                                 mask_pad: Optional[Tensor] = None) -> List[Tensor]:
    """Elementwise forward operation of SRU in pure Python.

    When gradients are disabled the call is delegated to
    ``elementwise_recurrence_cpu``. Otherwise a warning is issued and the
    recurrence is evaluated step by step with tensor operations.

    Internal layout (established by the views below):
      * ``U`` is viewed as (length, batch, num_directions, hidden_size, k);
      * a ``weight_c`` with more than one dimension is viewed as
        (length, batch, num_directions, hidden_size, 2), otherwise as
        (2, num_directions, hidden_size);
      * ``bias`` is viewed as (2, num_directions, hidden_size).

    Returns a pair ``(h, c_final)`` where ``h`` has shape
    (length, batch, num_directions * hidden_size) and ``c_final`` has shape
    (batch, num_directions * hidden_size).

    Raises ``ValueError`` if ``activation_type`` is neither 0 nor 1.
    """
    if not torch.is_grad_enabled():
        return elementwise_recurrence_cpu(U, x, weight_c, bias, c_init,
                                          activation_type, hidden_size,
                                          bidirectional, has_skip_term,
                                          scale_x, dropout_mask_c, mask_pad)
    warnings.warn("Running SRU on CPU with grad_enabled=True. Are you sure?")

    num_dir = 2 if bidirectional else 1
    seq_len = x.size(0) if x.dim() == 3 else 1
    batch = x.size(-2)
    d = hidden_size
    k = U.size(-1) // hidden_size // num_dir
    per_step_weights = weight_c.dim() > 1

    U = U.contiguous().view(seq_len, batch, num_dir, d, k)

    if per_step_weights:
        weight_c = weight_c.view(seq_len, batch, num_dir, d, 2)
        forget_wc, reset_wc = weight_c[..., 0], weight_c[..., 1]
    else:
        forget_wc, reset_wc = weight_c.view(2, num_dir, d)
    forget_bias, reset_bias = bias.view(2, num_dir, d)

    # Skip-connection input: the (optionally scaled) x when U carries three
    # slices per unit, otherwise the fourth slice of U.
    x_prime = None
    if has_skip_term:
        if k == 3:
            x_prime = x.view(seq_len, batch, num_dir, d)
            if scale_x is not None:
                x_prime = x_prime * scale_x
        else:
            x_prime = U[..., 3]

    if c_init is None:
        c_init = x.new_zeros(size=(batch, num_dir, d))
    else:
        c_init = c_init.view(batch, num_dir, d)

    if mask_pad is not None:
        mask_pad = mask_pad.view(seq_len, batch, 1).float()
    mask_c = None if dropout_mask_c is None else dropout_mask_c.view(batch, num_dir, d)

    h = x.new_zeros(seq_len, batch, num_dir, d)
    final_states = []

    for di in range(num_dir):
        # Direction 0 walks forward in time, direction 1 walks backward.
        steps = range(seq_len) if di == 0 else range(seq_len - 1, -1, -1)
        drop = 1 if mask_c is None else mask_c[:, di, :]
        c_prev = c_init[:, di, :]

        if per_step_weights:
            fw = forget_wc[:, :, di, :].chunk(seq_len)
            rw = reset_wc[:, :, di, :].chunk(seq_len)
        else:
            fw = forget_wc[di].expand(batch, d)  # type: ignore
            rw = reset_wc[di].expand(batch, d)  # type: ignore

        # Per-timestep slices; biases are folded into the gate inputs up front.
        candidate = U[:, :, di, :, 0].chunk(seq_len)
        forget_in = (U[:, :, di, :, 1] + forget_bias[di]).chunk(seq_len)
        reset_in = (U[:, :, di, :, 2] + reset_bias[di]).chunk(seq_len)
        if x_prime is not None:
            skip = x_prime[:, :, di, :].chunk(seq_len)

        for t in steps:
            fw_t = fw[t] if per_step_weights else fw
            rw_t = rw[t] if per_step_weights else rw
            forget_t = (forget_in[t] + c_prev*fw_t).sigmoid()
            reset_t = (reset_in[t] + c_prev*rw_t).sigmoid()

            c_t = candidate[t] + (c_prev - candidate[t]) * forget_t
            if mask_pad is not None:
                # Padded positions carry the previous state through unchanged.
                c_t = c_t * (1-mask_pad[t]) + c_prev * mask_pad[t]
            c_prev = c_t

            if activation_type == 0:
                g_c_t = c_t
            elif activation_type == 1:
                g_c_t = c_t.tanh()
            else:
                raise ValueError('Activation type must be 0 or 1, not {}'.format(activation_type))

            if x_prime is not None:
                h_t = skip[t] + (g_c_t - skip[t]) * drop * reset_t
            else:
                h_t = g_c_t * drop * reset_t
            if mask_pad is not None:
                # Outputs at padded positions are zeroed.
                h_t = h_t * (1-mask_pad[t])
            h[t, :, di, :] = h_t

        final_states.append(c_t.view(batch, d))

    h_out = h.view(seq_len, batch, -1)
    c_out = torch.stack(final_states, dim=1).view(batch, -1)
    return h_out, c_out  # type: ignore