This repository has been archived by the owner on Jul 7, 2023. It is now read-only.
/
attention.py
1355 lines (1145 loc) · 50.9 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention Layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import random
import jax
import numpy as onp
from tensor2tensor.trax import backend
from tensor2tensor.trax.backend import numpy as np
from tensor2tensor.trax.layers import base
from tensor2tensor.trax.layers import combinators as cb
from tensor2tensor.trax.layers import core
from tensor2tensor.trax.layers import initializers as init
# Layers are always CamelCase, but functions in general are snake_case
# pylint: disable=invalid-name
@base.layer()
def ShiftRight(x, mode='train', **unused_kwargs):
  """Shifts `x` one position to the right along axis 1, zero-filling the gap.

  In 'predict' mode the input is returned unchanged, because the sequence
  length is then 1.
  """
  if mode == 'predict':
    return x

  # Pad a single zero in front of axis 1 and nothing on any other axis.
  widths = [(0, 0)] * len(x.shape)
  widths[1] = (1, 0)
  zero = x.dtype.type(0)
  shifted = np.pad(x, widths, mode='constant', constant_values=zero)
  # Drop the last position so the length of axis 1 is preserved.
  return shifted[:, :-1]
@base.layer()
def CausalMask(x, params, axis=-1, **kwargs):
  """Builds a lower-triangular boolean mask of shape (1, size, size).

  `size` is the length of `x` along `axis`. `params` and any extra keyword
  arguments are ignored.
  """
  del params, kwargs
  n = x.shape[axis]
  all_true = onp.ones((1, n, n), dtype=onp.bool_)
  # Keep the diagonal and everything below it.
  return onp.tril(all_true, k=0)
@base.layer()
def PaddingMask(x, params, pad=0, **kwargs):
  """Flags entries of `x` that differ from `pad`.

  The comparison result is reshaped to (x.shape[0], 1, 1, x.shape[-1]).
  `params` and any extra keyword arguments are ignored.
  """
  del params, kwargs
  target_shape = (x.shape[0], 1, 1, x.shape[-1])
  not_padding = x != pad
  return np.reshape(not_padding, target_shape)
@base.layer(n_inputs=2)
def EncoderDecoderMask(x, **unused_kwargs):
  """Builds the encoder-decoder mask from a decoder input and a padding mask.

  `x` is the pair (decoder_input, padding_mask). The result has shape
  [batch, 1 for heads, decoder-len, encoder-len].
  """
  decoder_input, padding_mask = x
  batch = padding_mask.shape[0]
  encoder_len = padding_mask.shape[-1]
  reshaped_mask = np.reshape(padding_mask, (batch, 1, 1, encoder_len))
  # Adding zeros broadcasts the mask along the decoder-length axis.
  decoder_axis = np.zeros((1, 1, decoder_input.shape[1], 1))
  return reshaped_mask + decoder_axis
class PositionalEncoding(base.Layer):
  """Adds a sinusoidal positional encoding to the input.

  The encoding table has shape [1, max_len, d_feature] and is stored as the
  layer's parameters. In 'train' and 'eval' modes the first `seqlen` rows are
  added to the input. In 'predict' mode a single row is added per call and
  the row index is carried in the layer state.
  """

  def __init__(self, max_len=2048, mode='train'):
    """Creates the layer.

    Args:
      max_len: int: number of positions in the encoding table.
      mode: str: 'train', 'eval' or 'predict'.
    """
    super(PositionalEncoding, self).__init__()
    self._max_len = max_len
    self._mode = mode

  def forward(self, inputs, params=(), state=(), **kwargs):
    """Adds the positional encoding to `inputs`.

    Args:
      inputs: array whose axis 1 is the sequence axis.
      params: the encoding table built by `new_params_and_state`.
      state: in 'predict' mode, the index of the next table row to use.
      **kwargs: unused.

    Returns:
      A pair (inputs plus encoding, new state).
    """
    if self._mode in ('train', 'eval'):
      x = inputs
      symbol_size = np.shape(x)[1]
      return (x + params[:, :symbol_size, :], state)
    else:
      assert self._mode == 'predict'
      # Fast inference: return consecutive elements of the encoding sequence,
      # storing the index in state.
      return (inputs + np.expand_dims(params[:, state, :], 1), state + 1)

  def new_params_and_state(self, input_shape, input_dtype, rng):
    """Builds the encoding table and the initial state.

    Args:
      input_shape: shape of the input; its last entry is d_feature.
      input_dtype: unused.
      rng: unused.

    Returns:
      A pair (params, state): params is the [1, max_len, d_feature] table and
      state is 0 in 'predict' mode and () otherwise.
    """
    del input_dtype, rng
    d_feature = input_shape[-1]
    pe = onp.zeros((self._max_len, d_feature), dtype=onp.float32)
    position = onp.arange(0, self._max_len)[:, onp.newaxis]
    div_term = onp.exp(
        onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
    angles = position * div_term
    pe[:, 0::2] = onp.sin(angles)
    # With an odd d_feature there is one fewer odd column than even columns,
    # so trim the angles to fit; for an even d_feature the slice is a no-op.
    pe[:, 1::2] = onp.cos(angles[:, :d_feature // 2])
    pe = pe[onp.newaxis, :, :]  # [1, self._max_len, d_feature]
    params = np.array(pe)  # These are trainable parameters, initialized above.
    state = 0 if self._mode == 'predict' else ()
    return params, state
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention; positions where it is false get a
      logit of -1e9. Pass None to skip masking.
    dropout: float: dropout rate; None or 0.0 disables dropout
    mode: 'eval' or 'train': whether to use dropout (applied in 'train' only)
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.

  Raises:
    ValueError: if dropout is 1.0 or higher.
  """
  # Validate before doing any work. The None guard matters: the dropout
  # branch below treats None as "no dropout", and comparing None with a float
  # raises TypeError on Python 3.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    # Inverted dropout: kept weights are rescaled by 1 / keep-probability.
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
@base.layer(n_inputs=4, n_outputs=2)
def PureAttention(x, params, n_heads=1, dropout=0.0, mode='train', **kwargs):
  """Pure transformer-style multi-headed attention.

  Args:
    x: inputs (q, k, v, mask)
    params: parameters (none)
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'
    **kwargs: other arguments including the rng

  Returns:
    Pure Multi-headed attention result, and the mask.
  """
  del params
  rng = kwargs.get('rng', None)
  q, k, v, mask = x
  d_feature = q.shape[-1]
  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads
  nbatch = np.shape(q)[0]

  def _to_heads(t):
    # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
    per_head = np.reshape(t, (nbatch, -1, n_heads, d_head))
    return np.transpose(per_head, (0, 2, 1, 3))

  def _from_heads(t):
    # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
    seq_major = np.transpose(t, (0, 2, 1, 3))
    return np.reshape(seq_major, (nbatch, -1, n_heads * d_head))

  # Split heads, dot-product attention, rejoin heads.
  q_heads = _to_heads(q)
  k_heads = _to_heads(k)
  v_heads = _to_heads(v)
  attended = DotProductAttention(
      q_heads, k_heads, v_heads, mask,
      dropout=dropout, mode=mode, rng=rng)
  # The mask is returned alongside the result.
  return _from_heads(attended), mask
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  # One Dense projection each for q, k and v.
  qkv_projections = cb.Parallel(
      core.Dense(d_feature),
      core.Dense(d_feature),
      core.Dense(d_feature),
  )
  attend = PureAttention(  # pylint: disable=no-value-for-parameter
      n_heads=n_heads, dropout=dropout, mode=mode)
  output_projection = core.Dense(d_feature)
  return [qkv_projections, attend, output_projection]
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form (x, mask) and constructs (q, k, v) from x.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  # Two duplications, followed by the q/k/v attention stack.
  duplicate_twice = [cb.Dup(), cb.Dup()]
  qkv_attention = AttentionQKV(
      d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
  return duplicate_twice + [qkv_attention]
def BasicCausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed causal attention.

  This implementation is less configurable than the CausalAttention layer
  defined below, but it shares code with the non-causal attention.

  Accepts inputs of the form x and constructs (q, k, v) and causal mask from x.

  Args:
    d_feature: int: dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result.
  """
  # TODO(jonni,lukaszkaiser): standardize and improve layer comments.
  layers = []
  layers.append(cb.Dup())
  # Turn the duplicated input into a causal mask over axis -2.
  layers.append(
      cb.Parallel([], CausalMask(axis=-2)))  # pylint: disable=no-value-for-parameter
  layers.append(
      Attention(d_feature, n_heads=n_heads, dropout=dropout, mode=mode))
  # Drop the mask, keeping only x.
  layers.append(cb.Parallel([], cb.Drop()))
  return layers
class ShiftRightLearned(base.Layer):
  """Shifts the input right along axis 1, filling the gap with a learned vector.

  The layer's only parameter is a vector with one entry per feature. It is
  placed at the first position of every sequence and the last position of the
  input is dropped, so the output has the same shape as the input.
  """

  def __init__(self, initializer=init.RandomNormalInitializer(0.01)):
    """Creates the layer.

    Args:
      initializer: callable taking (shape, rng) that creates the learned
        vector.
    """
    super(ShiftRightLearned, self).__init__()
    self._initializer = initializer

  def forward(self, x, params=(), state=(), **kwargs):
    """Prepends the learned vector and drops the last position.

    Args:
      x: array indexed as (batch, position, feature).
      params: the learned vector.
      state: passed through unchanged.
      **kwargs: unused.

    Returns:
      A pair (shifted array, state).
    """
    del kwargs
    c = backend.numpy.reshape(params, [1, 1, -1])
    # Broadcast to (batch, 1, feature) with an out-of-place add. An in-place
    # `+=` would write into a reshape of the parameters and cannot grow a
    # (1, 1, d) operand to (batch, 1, d) on the NumPy backend.
    c = c + backend.numpy.zeros((x.shape[0], 1, x.shape[2]), dtype=x.dtype)
    return backend.numpy.concatenate([c, x], axis=1)[:, :-1, :], state

  def new_params_and_state(self, input_shape, input_dtype, rng):
    """Creates the learned vector, one entry per input feature.

    Args:
      input_shape: shape of the input; its last entry is the feature size.
      input_dtype: unused.
      rng: random key passed to the initializer.

    Returns:
      A pair (learned vector, empty state).
    """
    del input_dtype
    b = self._initializer((input_shape[-1],), rng)
    return b, ()
class ComputeAttentionHeads(base.Layer):
  """Linear projection that produces per-head queries, keys or values.

  The result has shape (n_batch * n_heads, seqlen, d_head): batch and head
  dimensions are fused to allow for more efficient memory layouts.
  """

  def __init__(self, n_heads=1, d_head=64,
               kernel_initializer=init.GlorotUniformInitializer()):
    super(ComputeAttentionHeads, self).__init__()
    self._kernel_initializer = kernel_initializer
    self._d_head = d_head
    self._n_heads = n_heads
    # There is deliberately no bias term: this matches the tensor2tensor
    # implementation and shouldn't have an effect on modeling quality.
    # AttentionQKV above differs in that it does use a bias term.

  def forward(self, x, params=(), state=(), **kwargs):
    """Projects `x` with the kernel and fuses the batch and head axes."""
    del kwargs
    n_batch, seqlen = x.shape[0], x.shape[1]
    projected = np.dot(x, params)
    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    per_head = np.reshape(
        projected, (n_batch, seqlen, self._n_heads, self._d_head))
    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
    heads_first = np.transpose(per_head, (0, 2, 1, 3))
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    fused = np.reshape(heads_first, (-1, seqlen, self._d_head))
    return fused, state

  def new_params_and_state(self, input_shape, input_dtype, rng):
    """Creates the (d_in, n_heads * d_head) kernel; the state is empty."""
    del input_dtype
    kernel_shape = (input_shape[-1], self._n_heads * self._d_head)
    return self._kernel_initializer(kernel_shape, rng), ()
class ComputeAttentionOutput(base.Layer):
  """Merges the outputs of all heads and projects them linearly to d_model."""

  def __init__(self, n_heads=1, d_model=1024,
               kernel_initializer=init.GlorotUniformInitializer()):
    super(ComputeAttentionOutput, self).__init__()
    self._kernel_initializer = kernel_initializer
    self._d_model = d_model
    self._n_heads = n_heads
    # There is deliberately no bias term: this matches the tensor2tensor
    # implementation and shouldn't have an effect on modeling quality.
    # AttentionQKV above differs in that it does use a bias term.

  def forward(self, x, params=(), state=(), **kwargs):
    """Un-fuses batch and heads, concatenates the heads, applies the kernel."""
    del kwargs
    seqlen, d_head = x.shape[1], x.shape[2]
    n_heads = self._n_heads
    unfused = np.reshape(x, (-1, n_heads, seqlen, d_head))
    # -> n_batch, seqlen, n_heads, d_head
    seq_major = np.transpose(unfused, (0, 2, 1, 3))
    merged = np.reshape(seq_major, (-1, seqlen, n_heads * d_head))
    return np.dot(merged, params), state

  def new_params_and_state(self, input_shape, input_dtype, rng):
    """Creates the (d_head * n_heads, d_model) kernel; the state is empty."""
    del input_dtype
    kernel_shape = (input_shape[-1] * self._n_heads, self._d_model)
    return self._kernel_initializer(kernel_shape, rng), ()
class BaseCausalAttention(base.Layer):
  """Abstract parent for the causal self-attention variants in this file."""

  def __init__(self, mode='train'):
    # The layer takes three inputs; `mode` is accepted but not used here.
    super(BaseCausalAttention, self).__init__(n_inputs=3)
    del mode

  def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
    """Forward pass for the attention layer; subclasses must override."""
    raise NotImplementedError

  def forward_and_backward(self, inputs, grad, **kwargs):
    """Runs the forward and the backward pass of the attention layer together.

    Reversible models need this: during their backward pass they must compute
    the forward direction (to recover the previous layer's activations) and
    the backward direction at the same time. Some of the computation can be
    shared between the two directions, so implementing them jointly is more
    efficient.

    This method assumes that the layer is stateless and has no parameters.

    Args:
      inputs: A tuple (q, k, v), where each element has shape
          n_batch*n_heads, seqlen, d_head
      grad: gradient signal for the layer output.
      **kwargs: kwargs for the layer

    Returns:
      A nested-tuple structure (output, (q_grad, k_grad, v_grad)) that contains
      the output of the forward pass and the gradient signal for each input.
    """
    raise NotImplementedError
def _fast_inference_init_state(input_shapes, input_dtypes, buffer_length):
  """Initializes state of a causal attention layer for fast inference.

  Args:
    input_shapes: shapes of the three inputs (q, k, v); each one is unpacked
      as a 3-tuple whose last entry is the feature depth. The batch size is
      read from the first entry of the q shape.
    input_dtypes: dtypes of the three inputs, in the same order.
    buffer_length: int: number of positions held by each buffer.

  Returns:
    A tuple (k, v, mask, index): zero-filled key and value buffers of shape
    (batch_size, buffer_length, depth), a zero-filled mask of shape
    (batch_size, 1, buffer_length), and the integer write index 0.
  """
  ((batch_size, _, _), k_shape, v_shape) = input_shapes
  k_dtype, v_dtype = tuple(input_dtypes)[1:3]

  def init_buffer(shape, dtype):
    (_, _, depth) = shape
    return np.zeros((batch_size, buffer_length, depth), dtype=dtype)

  # Only keys and values are buffered; no buffer is allocated for the query.
  k = init_buffer(k_shape, k_dtype)
  v = init_buffer(v_shape, v_dtype)
  mask = np.zeros((batch_size, 1, buffer_length))
  index = 0
  return (k, v, mask, index)
def _fast_inference_update_state(inputs, state):
  """Writes one new key/value pair into the fast-inference buffers.

  Fast inference runs with a single query per step and keeps the keys and
  values computed so far in `state`.

  Args:
    inputs: tuple (q, k, v); each element must have length 1 along axis 1.
    state: tuple (ks, vs, mask, index) as built for fast inference.

  Returns:
    The updated (ks, vs, mask, index) with the new key and value stored at
    `index`, the mask set to 1 there, and the index advanced by one.
  """
  assert backend.get_name() == 'jax', (
      'JAX backend is required to use the predict mode.')
  for x in inputs:
    assert x.shape[1] == 1, (
        'In predict mode the input sequence must be of length 1.')

  (_, new_k, new_v) = inputs
  (ks, vs, mask, index) = state
  slot = jax.ops.index[:, index, :]
  ks = jax.ops.index_update(ks, slot, new_k[:, 0, :])
  vs = jax.ops.index_update(vs, slot, new_v[:, 0, :])
  # Mark the freshly written position as attendable.
  mask = jax.ops.index_update(mask, jax.ops.index[:, :, index], 1)
  next_index = index + 1
  return (ks, vs, mask, next_index)
class DotProductCausalAttention(BaseCausalAttention):
  """Plain dot product causal attention, without any memory savings."""

  def __init__(self, dropout=0.0, mode='train'):
    super(DotProductCausalAttention, self).__init__()
    self._mode = mode
    self._dropout = dropout

  def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
    """Applies causal attention to the (q, k, v) inputs.

    In 'train' and 'eval' modes a lower-triangular mask is built from the
    query length. In 'predict' mode the keys, values and mask come from the
    fast-inference state, which is updated with the new inputs first.
    """
    del params
    q, k, v = inputs
    if self._mode not in ('train', 'eval'):
      assert self._mode == 'predict'
      state = _fast_inference_update_state(inputs, state)
      (k, v, mask, _) = state
    else:
      mask_size = q.shape[-2]
      # Not all backends define np.tril. However, using onp.tril is inefficient
      # in that it creates a large global constant. TODO(kitaev): try to find an
      # alternative that works across all backends.
      array_lib = np if backend.get_name() == 'jax' else onp
      all_true = array_lib.ones((1, mask_size, mask_size), dtype=onp.bool_)
      mask = array_lib.tril(all_true, k=0)

    attended = DotProductAttention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
    return attended, state

  def forward_and_backward(self, inputs, ct, **kwargs):
    """Runs the forward pass and backprop through it in one go via jax.vjp."""
    assert backend.get_name() == 'jax', (
        'JAX backend is required to use forward_and_backward.')

    def _output_only(x):  # pylint: disable=invalid-name
      # jax.vjp must differentiate the output alone, so drop the state.
      result, _ = self.forward(x, **kwargs)
      return result

    output, pullback = jax.vjp(_output_only, inputs)
    inputs_ct = pullback(ct)[0]
    return output, inputs_ct

  def new_params_and_state(self, input_shapes, input_dtype, rng):
    """Returns no params; the state is a fast-inference buffer in 'predict'."""
    if self._mode not in ('train', 'eval'):
      assert self._mode == 'predict'
      # Buffer length is hardcoded for now. TODO(pkozakowski): Pass it from
      # the model.
      buffer_length = 2048
      predict_state = _fast_inference_init_state(
          input_shapes, input_dtype, buffer_length)
      return (), predict_state
    return (), ()
class MemoryEfficientCausalAttention(BaseCausalAttention):
  """Memory-efficient dot product attention.

  This layer performs causal attention on long sequences without running out
  of memory. Instead of computing dot products for all query-key pairs at once,
  it uses a loop to compute attention for a small set of query positions at a
  time. The "loop_stride" parameter controls how many query positions are
  considered at each iteration of the loop.

  Note that this class does not slice along the batch/head dimension. Looping
  over batch elements and heads instead of query positions is also a viable
  option. We haven't implemented it, but it may perform well, too.
  """

  def __init__(self, loop_stride, dropout, mode, share_qk=False, hard_k=0):
    """Creates the layer.

    Args:
      loop_stride: number of query positions processed per loop iteration;
        it must evenly divide the number of query positions.
      dropout: dropout rate; it is only kept when mode is 'train'.
      mode: str: dropout is disabled unless this is 'train'.
      share_qk: if True, keys are normalized with make_unit_length and
        attention to self is penalized.
      hard_k: if positive, only the weights above the hard_k-th largest one
        are kept and then re-normalized.

    Raises:
      ValueError: if dropout is 1.0 or higher.
    """
    assert backend.get_name() == 'jax', (
        'JAX backend is required to use MemoryEfficientCausalAttention.')
    super(MemoryEfficientCausalAttention, self).__init__()
    self._loop_stride = loop_stride
    if dropout >= 1.0:
      raise ValueError('Dropout rates must be lower than 1.')
    if mode == 'train':
      self.dropout = dropout
    else:
      # None disables dropout outside of training.
      self.dropout = None
    self._share_qk = share_qk
    self._hard_k = hard_k

  def forward(self, inputs, params=(), state=(), **kwargs):
    """Forward pass: forward_and_backward with no cotangent; state unchanged."""
    del params
    output, _ = self.forward_and_backward(inputs, None, **kwargs)
    return output, state

  def has_backward(self):
    """Reports that this layer provides its own backward method."""
    return True

  def backward(self, inputs, output, ct, params=(), state=(), **kwargs):
    """Returns the input cotangents and an empty params cotangent.

    The forward pass is recomputed inside forward_and_backward; the `output`
    argument is not used.
    """
    del output, params, state
    _, inputs_ct = self.forward_and_backward(inputs, ct, **kwargs)
    return inputs_ct, ()

  def make_unit_length(self, x, epsilon=1e-6):
    """Divides x by the root of its mean square (plus epsilon) on the last axis."""
    variance = np.mean(x**2, axis=-1, keepdims=True)
    norm_inputs = x / np.sqrt(variance + epsilon)
    return norm_inputs

  def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
    """Computes the attention output and, if ct is given, the input cotangents.

    Args:
      inputs: tuple (query, key, value).
      ct: cotangent for the layer output, or None to run the forward pass
        only.
      rng: random key used for dropout; it is folded with the loop index for
        each slice.
      **kwargs: unused.

    Returns:
      A pair (output, inputs_ct). inputs_ct is None when ct is None, otherwise
      it holds the (query, key, value) cotangents.
    """
    del kwargs
    query, key, value = inputs
    depth = np.shape(query)[-1]
    do_backprop = ct is not None
    # jax uses the term cotangent (ct) to refer to gradient signals, and
    # vector-Jacobian product (vjp) for back-propagation through a layer.

    def make_mask(N, M, k):  # pylint: disable=invalid-name
      """Constructs a slice of the causal attention mask.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
      x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
      y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
      # Attention is disallowed where the query position (x + k) is smaller
      # than the key position.
      mask = jax.lax.lt(
          (jax.lax.broadcast_in_dim(
              x, shape=(N, M), broadcast_dimensions=(0,)) + k),
          jax.lax.broadcast(y, [N]))
      mask = jax.lax.convert_element_type(mask, np.float32)
      return mask

    def make_self_mask(N, M, k):  # pylint: disable=invalid-name
      """Masks out elements attending to self.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
      x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
      y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
      # Flags the entries where the query position (x + k) equals the key
      # position.
      mask = jax.lax.eq(
          (jax.lax.broadcast_in_dim(
              x, shape=(N, M), broadcast_dimensions=(0,)) + k),
          jax.lax.broadcast(y, [N]))
      mask = jax.lax.convert_element_type(mask, np.float32)
      return mask

    def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
      """Forward pass for a subset of the query vectors."""
      if self._share_qk:
        key = self.make_unit_length(key)

      dots = np.matmul(
          query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

      # Causal masking
      mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
      dots = dots - 1e9 * mask

      # Mask out attention to self except when no other targets are available.
      # The self penalty (1e5) is smaller than the causal one (1e9).
      if self._share_qk:
        self_mask = make_self_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
        dots = dots - 1e5 * self_mask

      # Softmax.
      dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

      if self.dropout is not None and self.dropout > 0.0:
        # Dropout is broadcast across the batch+head dimension
        dropout_shape = (1, dots.shape[-2], dots.shape[-1])
        # Fold the loop index into the key so each slice gets its own key.
        slice_rng = jax.random.fold_in(rng, q_loop_idx)
        keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
        keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
        dots = dots * multiplier

      if self._hard_k > 0:
        top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
        # No gradient flows through the threshold itself.
        top_k = jax.lax.stop_gradient(top_k)
        dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
        dots = np.maximum(dots, 0)
        dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Re-normalize.
        dots /= dots_sum  # Re-normalize.

      out_slice = np.matmul(dots, value)
      return out_slice

    def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):  # pylint: disable=invalid-name
      """Runs forward_slice and returns its output and input cotangents."""
      # Capture q_loop_idx to avoid calculated gradients wrt. it.
      def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
        return forward_slice(query_slice, q_loop_idx, key, value)

      output_slice, vjpfun = jax.vjp(
          forward_slice_with_q_loop_idx, query_slice, key, value)
      return output_slice, vjpfun(ct_slice)

    q_loop_idx = np.zeros((), dtype=np.int32)
    q_loop_max = query.shape[-2]
    q_loop_stride = self._loop_stride
    assert q_loop_max % q_loop_stride == 0, (
        'Stride must evenly divide the number of query elements.')

    # Loop-carried values: the index and the output accumulator, plus the
    # three cotangent accumulators when backprop is requested.
    out_accum = np.zeros_like(query)
    if do_backprop:
      query_ct_accum = np.zeros_like(query)
      key_ct_accum = np.zeros_like(key)
      value_ct_accum = np.zeros_like(value)
      init_vals = (
          q_loop_idx, out_accum,
          query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      init_vals = (q_loop_idx, out_accum)

    def cond_fun(vals):  # pylint: disable=invalid-name
      """Continues while the loop index is below the number of queries."""
      q_loop_idx = vals[0]
      return jax.lax.lt(q_loop_idx, q_loop_max)

    def body_fun(vals):  # pylint: disable=invalid-name
      """Compute a slice of the attention mechanism."""
      if do_backprop:
        (q_loop_idx, out_accum,
         query_ct_accum, key_ct_accum, value_ct_accum) = vals
      else:
        q_loop_idx, out_accum = vals

      query_slice = jax.lax.dynamic_slice_in_dim(
          query, q_loop_idx, q_loop_stride, axis=-2)

      if do_backprop:
        ct_slice = jax.lax.dynamic_slice_in_dim(
            ct, q_loop_idx, q_loop_stride, axis=-2)
        out_slice, partial_ct = forward_and_vjp_slice(
            query_slice, q_loop_idx, key, value, ct_slice)
        # The query cotangent is written into its slice; the key and value
        # cotangents are summed over all query slices.
        query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
            query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
        key_ct_accum = key_ct_accum + partial_ct[1]
        value_ct_accum = value_ct_accum + partial_ct[2]
      else:
        out_slice = forward_slice(query_slice, q_loop_idx, key, value)

      out_accum = jax.lax.dynamic_update_slice_in_dim(
          out_accum, out_slice, q_loop_idx, axis=-2)
      q_loop_idx = q_loop_idx + q_loop_stride

      if do_backprop:
        return (q_loop_idx, out_accum,
                query_ct_accum, key_ct_accum, value_ct_accum)
      else:
        return (q_loop_idx, out_accum)

    final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

    if not do_backprop:
      return final_vals[1], None
    else:
      return final_vals[1], final_vals[2:]
class TimeBinCausalAttention(BaseCausalAttention):
  """Causal attention where only nearby chunks of items attend to each other.

  The sequence is split into consecutive bins. Each position attends to the
  positions of its own bin and of the previous bin, subject to causal masking.
  Exactly one of `bin_length` and `n_bins` must be given.
  """

  def __init__(self, mode, dropout=0.0, bin_length=None, n_bins=None,
               share_qk=False):
    """Creates the layer.

    Args:
      mode: str: 'train', 'eval' or 'predict'.
      dropout: float: dropout rate; only used when mode is 'train'.
      bin_length: int or None: number of positions per bin.
      n_bins: int or None: number of bins.
      share_qk: bool: if True, keys are derived from the queries and
        attention to self is penalized.

    Raises:
      ValueError: if both or neither of bin_length and n_bins are set, or if
        dropout is 1.0 or higher.
    """
    super(TimeBinCausalAttention, self).__init__()
    if (bin_length is None) == (n_bins is None):
      raise ValueError('Exactly one of {bin_length, n_bins} must be set.')
    self.bin_length = bin_length
    self.n_bins = n_bins
    self._share_qk = share_qk
    if dropout >= 1.0:
      raise ValueError('Dropout rates must be lower than 1.')
    if mode == 'train':
      self.dropout = dropout
    else:
      self.dropout = 0.0
    self._mode = mode

  def forward_and_backward(self, inputs, ct, **kwargs):
    """Runs the forward pass and backprop through it in one go via jax.vjp."""
    assert backend.get_name() == 'jax', (
        'JAX backend is required to use forward_and_backward.')
    # Simultaneous forward pass and backprop through the attention mechanism.
    def _do_forward(x):  # pylint: disable=invalid-name
      res, _ = self.forward(x, **kwargs)
      return res
    output, vjpfun = jax.vjp(_do_forward, inputs)
    return output, vjpfun(ct)[0]

  def make_unit_length(self, x, epsilon=1e-6):
    """Divides x by the root of its mean square (plus epsilon) on the last axis."""
    variance = np.mean(x**2, axis=-1, keepdims=True)
    norm_inputs = x / np.sqrt(variance + epsilon)
    return norm_inputs

  def _pad_inputs(self, inputs):
    """Zero-pads the inputs on axis -2 so the length is n_bins * bin_length.

    Returns:
      A tuple (padded_inputs, original sequence length, number of bins).
    """
    seq_len = inputs[0].shape[-2]
    n_bins = self.n_bins
    bin_length = self.bin_length
    # Derive whichever of the two bin settings was not given.
    if n_bins is None:
      n_bins = int(math.ceil(seq_len / bin_length))
    else:
      bin_length = int(math.ceil(seq_len / n_bins))
    pad_len = n_bins * bin_length - seq_len

    def pad_input(x):
      pad_widths = [(0, 0)] * len(x.shape)
      pad_widths[-2] = (0, pad_len)  # Padding on axis=-2
      return np.pad(x, pad_widths, mode='constant',
                    constant_values=x.dtype.type(0))

    padded_inputs = tuple(map(pad_input, inputs))
    return (padded_inputs, seq_len, n_bins)

  def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
    """Dispatches to the train/eval or the predict implementation."""
    del params, kwargs
    if self._mode in ('train', 'eval'):
      output = self._forward_train_eval(inputs, rng)
      return (output, state)
    else:
      assert self._mode == 'predict'
      return self._forward_predict(inputs, state, rng)

  def _forward_train_eval(self, inputs, rng):
    """Binned causal attention over whole sequences."""
    (inputs, original_len, n_bins) = self._pad_inputs(inputs)
    q, k, v = inputs
    seqlen = q.shape[-2]
    # q/k/v are n_batch*n_heads, seqlen, d_head

    # Time indices for causal masking.
    t = jax.lax.tie_in(q, np.arange(seqlen))

    # Split off a "bin" axis for chunks of consecutive items.
    bq_t = np.reshape(t, (n_bins, -1))
    bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1]))
    if self._share_qk:
      bk = self.make_unit_length(bq)
    else:
      bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1]))
    bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1]))

    # Allow each chunk to attend within itself, and also one chunk back.
    def look_one_back(x):
      # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
      # The first bin is paired with the last one; those entries have larger
      # time indices and are removed by the causal mask below.
      if len(x.shape) == 2:
        x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
        return np.concatenate([x, x_extra], axis=1)
      else:
        assert len(x.shape) == 4
        x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1)
        return np.concatenate([x, x_extra], axis=2)

    bkv_t = look_one_back(bq_t)
    bk = look_one_back(bk)
    bv = look_one_back(bv)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking based on the time indices.
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[None, :, :, None], bkv_t[None, :, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    if self._share_qk:
      self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
      self_mask = jax.lax.tie_in(dots, self_mask)
      dots = dots - 1e5 * self_mask

    # Softmax. This must come before dropout: zeroing a masked logit (-1e9)
    # would turn it into 0 and let future positions receive attention.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

    if self.dropout > 0.0:
      # Dropout is broadcast across the batch+head dimension
      dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1])
      keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
      keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier

    bo = np.matmul(dots, bv)

    output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1]))
    assert output.shape == v.shape
    # Strip the positions added by padding.
    return output[..., :original_len, :]

  def _forward_predict(self, inputs, state, rng):
    """One step of fast inference using the buffered keys and values."""
    state = _fast_inference_update_state(inputs, state)

    (q, _, _) = inputs
    (ks, vs, mask, index) = state
    output = DotProductAttention(
        q, ks, vs, mask, dropout=self.dropout, mode=self._mode, rng=rng
    )

    def roll_state(state):
      """Rolls the buffers backward to make space for new data."""
      (ks, vs, mask, index) = state

      # Move the second bin into the first one's place in both buffers.
      def roll_buffer(buf):
        return jax.ops.index_update(
            buf,
            jax.ops.index[:, :self.bin_length, :],
            buf[:, self.bin_length:, :],
        )
      (ks, vs) = map(roll_buffer, (ks, vs))

      # Zero out the second bin in the mask.
      mask = jax.ops.index_update(
          mask, jax.ops.index[:, :, self.bin_length:], 0
      )

      # Update the index to match the rolled buffers.
      index -= self.bin_length

      return (ks, vs, mask, index)

    # Once we get to the end of the buffer, move the second bin back to make
    # space for new data: [ bin_i bin_{i+1} | ] -> [ bin_{i+1} | bin_{i+1} ],
    # where | is where index points at in the buffer.
    state = jax.lax.cond(
        pred=(index == 2 * self.bin_length),
        true_operand=state,
        true_fun=roll_state,
        false_operand=state,
        false_fun=(lambda x: x),
    )
    return (output, state)

  def new_params_and_state(self, input_shapes, input_dtype, rng):
    """Returns no params; in 'predict' mode the state holds two bins."""
    if self._mode in ('train', 'eval'):
      return (), ()

    assert self._mode == 'predict'
    assert self.bin_length is not None, (
        'For fast inference, TimeBinCausalAttention must be parameterized by '
        'bin_length.'
    )
    params = ()
    state = _fast_inference_init_state(
        input_shapes, input_dtype, 2 * self.bin_length
    )
    return params, state
class LSHCausalAttention(BaseCausalAttention):
"""Causal attention based on locality-sensitive hashing."""
def __init__(self, dropout, mode, n_bins=64, n_hashes=1, n_buckets=64,
             one_rng=False, allow_duplicate_attention=False,
             attend_across_buckets=False, hard_k=0, factorize_hash=False,
             rehash_each_round=True, drop_for_hash_rate=0.0):
  """Validates and stores the LSH attention configuration.

  Args:
    dropout: Ignored; it is deleted immediately and never stored.
    mode: Stored as `self._mode`.
    n_bins: Stored as `self.n_bins`; must not exceed `n_buckets`.
    n_hashes: Stored as `self.n_hashes`.
    n_buckets: Stored as `self.n_buckets`; must be at least `n_bins`.
    one_rng: If true, a PRNG key is created once here from a seed drawn with
      Python's `random.randint` and kept in `self._prng`; otherwise
      `self._prng` is None.
    allow_duplicate_attention: Stored as `self._allow_duplicate_attention`.
    attend_across_buckets: Stored as `self._attend_across_buckets`.
    hard_k: Stored as `self._hard_k`.
    factorize_hash: Stored as `self._factorize_hash`.
    rehash_each_round: Stored as `self._rehash_each_round`. Must be true
      unless `allow_duplicate_attention` is true.
    drop_for_hash_rate: Stored as `self._drop_for_hash_rate`.

  Raises:
    AssertionError: If `n_buckets < n_bins`, or if both `rehash_each_round`
      and `allow_duplicate_attention` are false.
  """
  del dropout  # Unused.
  self._mode = mode
  super(LSHCausalAttention, self).__init__()

  # The check fails when there are fewer buckets than bins, so the message
  # names the buckets (it previously said "too few bins").
  assert n_buckets >= n_bins, (
      'This setting is not recommended: too few buckets.')
  assert rehash_each_round or allow_duplicate_attention, (
      'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
      ' is not implemented.')

  self.n_bins = n_bins
  self.n_hashes = n_hashes
  self.n_buckets = n_buckets
  self._drop_for_hash_rate = drop_for_hash_rate
  self._one_rng = one_rng
  self._factorize_hash = factorize_hash

  # With one_rng, a single key is fixed at construction time.
  self._prng = None
  if one_rng:
    seed = random.randint(0, 2**31 - 1)
    self._prng = backend.random.get_prng(seed)

  self._allow_duplicate_attention = allow_duplicate_attention
  self._attend_across_buckets = attend_across_buckets
  self._hard_k = hard_k
  self._rehash_each_round = rehash_each_round
def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
  """Computes the attention output for `inputs`.

  Args:
    inputs: Indexable sequence; element 0 and element 2 are passed to
      `batch_call_and_or_grad`. Element 1 is not read.
    params: Unused.
    state: Returned unchanged.
    rng: Forwarded to `batch_call_and_or_grad`.
    **kwargs: Unused.

  Returns:
    A pair `(output, state)`.
  """
  del params, kwargs  # Unused.
  qk = inputs[0]
  v = inputs[2]
  attended, _ = self.batch_call_and_or_grad(qk, v, rng=rng)
  return attended, state
def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
  """Computes the output and the input cotangents in a single call.

  Args:
    inputs: Indexable sequence; elements 0 and 2 are passed to
      `batch_call_and_or_grad`, element 1 only determines the shape of the
      all-zero cotangent returned for it.
    ct: Output cotangent, forwarded to `batch_call_and_or_grad`.
    rng: Forwarded to `batch_call_and_or_grad`.
    **kwargs: Unused.

  Returns:
    A pair `(output, input_cts)` where `input_cts` has one entry per element
    of `inputs`; the middle entry is zeros shaped like `inputs[1]`.
  """
  del kwargs  # Unused.
  result = self.batch_call_and_or_grad(inputs[0], inputs[2], ct=ct, rng=rng)
  output, (qk_ct, v_ct) = result
  # inputs[1] never reaches the attention computation; its cotangent is zero.
  unused_input_ct = np.zeros_like(inputs[1])
  return output, (qk_ct, unused_input_ct, v_ct)
def has_backward(self):
  """Returns True unconditionally; this class defines its own `backward`."""
  return True
def backward(self, inputs, output, ct, params=(), state=(), rng=None,
             **kwargs):
  """Computes the input cotangents without returning the forward output.

  Args:
    inputs: Indexable sequence; elements 0 and 2 are passed to
      `batch_call_and_or_grad`, element 1 only determines the shape of the
      all-zero cotangent returned for it.
    output: Unused.
    ct: Output cotangent, forwarded to `batch_call_and_or_grad`.
    params: Unused.
    state: Unused.
    rng: Forwarded to `batch_call_and_or_grad`.
    **kwargs: Not read.

  Returns:
    A pair `(inputs_ct, ())`: one cotangent per element of `inputs`, and an
    empty tuple as the second element.
  """
  del output, params, state  # Unused.
  grads = self.batch_call_and_or_grad(
      inputs[0], inputs[2], return_output=False, ct=ct, rng=rng)
  _, (qk_ct, v_ct) = grads
  # inputs[1] never reaches the attention computation; its cotangent is zero.
  unused_input_ct = np.zeros_like(inputs[1])
  return (qk_ct, unused_input_ct, v_ct), ()
def batch_call_and_or_grad(self, qk, v, ct=None, return_output=True,
rng=None):
assert return_output or ct is not None, 'No work to perform!'
# pylint: disable=protected-access
stash_buckets = (return_output and ct is None
and base.Layer._STASH_IN is not None)
if return_output and ct is not None and base.Layer._STASH_OUT is not None:
buckets = base.Layer._STASH_OUT.pop(self)
else:
buckets = None
# pylint: enable=protected-access
# The approach here is to perform attention for one batch element and head
# at a time. Note that there is absolutely no interaction across examples or
# heads: this layer has no parameters, and hashing patterns are also
# different across examples/heads. As a result, batching doesn't give any
# performance gains except in the case of accelerator under-utilization. We
# assume that hash-based attention will be applied primarily to long
# sequences, where unbatched attention for a single head has sufficient
# computation to fill up the accelerator.
batch_loop_idx = np.zeros((), dtype=np.int32)
batch_loop_max = qk.shape[0]
init_vals = (batch_loop_idx,)
if return_output:
out_accum = np.zeros_like(qk)
init_vals = init_vals + (out_accum,)
if stash_buckets:
buckets_accum = np.zeros(
[qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
init_vals = init_vals + (buckets_accum,)
if ct is not None:
qk_ct_accum = np.zeros_like(qk)
v_ct_accum = np.zeros_like(v)
init_vals = init_vals + (qk_ct_accum, v_ct_accum)
def cond_fun(vals):
batch_loop_idx = vals[0]
return jax.lax.lt(batch_loop_idx, batch_loop_max)
def body_fun(vals):
"""Performs attention for a single batch element and head."""
batch_loop_idx = vals[0]
if self._prng is None:
hash_rng = jax.random.fold_in(rng, batch_loop_idx)
else:
# TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
qk_slice = jax.lax.dynamic_index_in_dim(
qk, batch_loop_idx, axis=0, keepdims=False)
v_slice = jax.lax.dynamic_index_in_dim(
v, batch_loop_idx, axis=0, keepdims=False)
if buckets is None:
buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
else:
buckets_slice = jax.lax.dynamic_index_in_dim(
buckets, batch_loop_idx, axis=0, keepdims=False)
if ct is None:
out_slice = self.single_call(
qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
else:
def _do_single_call(qk_slice, v_slice):
return self.single_call(
qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
ct_slice = jax.lax.dynamic_index_in_dim(
ct, batch_loop_idx, axis=0, keepdims=False)
out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
qk_ct_slice, v_ct_slice = vjpfun(ct_slice)
new_vals = (batch_loop_idx + 1,)
if return_output:
out_accum = vals[1]
out_accum = jax.lax.dynamic_update_index_in_dim(
out_accum, out_slice, batch_loop_idx, axis=0)
new_vals = new_vals + (out_accum,)
if stash_buckets:
buckets_accum = vals[2]
buckets_accum = jax.lax.dynamic_update_index_in_dim(
buckets_accum, buckets_slice, batch_loop_idx, axis=0)
new_vals = new_vals + (buckets_accum,)
if ct is not None:
qk_ct_accum, v_ct_accum = vals[-2:]
qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
v_ct_accum = jax.lax.dynamic_update_index_in_dim(
v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
new_vals = new_vals + (qk_ct_accum, v_ct_accum)
return new_vals
final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)
if return_output:
out = final_vals[1]