# disc_utils.py
"""
author: trentbrick and yannadani
Utils for the discrete layers. Taken from https://github.com/google/edward2/blob/2077d67ab8a5c73c39b8d43ccc8cd036dc0a8566/edward2/tensorflow/layers/utils.py
Which is introduced and explained in the paper: https://arxiv.org/abs/1905.10347
And modified for PyTorch.
"""

import numpy as np
import torch
import torch.nn.functional as F

def one_hot(inputs, vocab_size=None):
    """Returns a one-hot encoding over the last axis of `inputs`."""
    if vocab_size is None:
        vocab_size = int(inputs.max()) + 1
    input_shape = inputs.shape
    inputs = inputs.flatten().unsqueeze(1).long()
    # Allocate on the same device as the inputs so the op also works on GPU.
    z = torch.zeros(len(inputs), vocab_size, device=inputs.device)
    z.scatter_(1, inputs, 1.)
    return z.view(*input_shape, vocab_size)
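
# Illustrative sanity check (added for exposition; not part of the original
# module). It shows the one-hot layout produced over the trailing dimension.
def _demo_one_hot():
    ids = torch.tensor([0, 2, 1])
    enc = one_hot(ids, vocab_size=3)
    # enc == [[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 1, 0]]
    assert torch.equal(torch.argmax(enc, dim=-1), ids)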

def one_hot_argmax(inputs, temperature, axis=-1):
    """Returns one-hot of argmax with backward pass set to softmax-temperature."""
    vocab_size = inputs.shape[-1]
    z = one_hot(torch.argmax(inputs, dim=axis), vocab_size)
    soft = F.softmax(inputs / temperature, dim=axis)
    # Straight-through estimator: the forward pass emits the hard one-hot `z`,
    # while the backward pass sees only the tempered softmax.
    outputs = soft + (z - soft).detach()
    return outputs
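
# Illustrative sketch (not in the original module): the straight-through trick
# keeps the forward output hard while gradients flow through the softmax.
def _demo_one_hot_argmax():
    logits = torch.tensor([[0.1, 2.0, -1.0]], requires_grad=True)
    hard = one_hot_argmax(logits, temperature=0.5)
    # Forward pass is (numerically) the hard one-hot of the argmax.
    assert torch.argmax(hard, dim=-1).item() == 1
    # Backward pass differentiates through softmax(logits / temperature).
    (hard * torch.tensor([[1.0, 2.0, 3.0]])).sum().backward()
    assert logits.grad is not None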

def multiplicative_inverse(a, n):
    """Multiplicative inverse of a modulo n, computed in the one-hot space.

    Args:
        a: Tensor of shape [..., vocab_size]. It denotes an integer in the
            one-hot space.
        n: int Tensor of shape [...].

    Returns:
        Tensor of same shape and dtype as a.
    """
    vocab_size = a.shape[-1]
    a_dtype = a.dtype
    sparse_a = torch.argmax(a, dim=-1)
    # The extended-Euclidean inverse runs in plain Python/NumPy, so move the
    # indices to the CPU first, then re-encode the result as one-hot.
    sparse_outputs = torch.as_tensor(
        py_multiplicative_inverse(sparse_a.cpu().numpy(), n))
    z = one_hot(sparse_outputs, vocab_size)
    return z.to(dtype=a_dtype, device=a.device)
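
# Illustrative check (added here; not part of the original module): the
# one-hot encoding of 3 maps to the one-hot encoding of 5, since
# 3 * 5 = 15 = 1 (mod 7).
def _demo_multiplicative_inverse():
    a = one_hot(torch.tensor([3]), vocab_size=7)
    inv = multiplicative_inverse(a, 7)
    assert torch.argmax(inv, dim=-1).item() == 5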

def py_multiplicative_inverse(a, n):
    """Multiplicative inverse of a modulo n (in Python).

    Implements the extended Euclidean algorithm.

    Args:
        a: int-like np.ndarray.
        n: int.

    Returns:
        Multiplicative inverse as an int32 np.ndarray with same shape as a.
    """
    batched_a = np.asarray(a, dtype=np.int32)
    n = np.asarray(n, dtype=np.int32)
    batched_inverse = []
    for a in np.nditer(batched_a):
        inverse = 0
        new_inverse = 1
        remainder = n
        new_remainder = a
        while new_remainder != 0:
            quotient = remainder // new_remainder
            (inverse, new_inverse) = (new_inverse,
                                      inverse - quotient * new_inverse)
            (remainder, new_remainder) = (new_remainder,
                                          remainder - quotient * new_remainder)
        if remainder > 1:
            raise ValueError(
                'Inverse for {} modulo {} does not exist.'.format(a, n))
        if inverse < 0:
            inverse += n
        batched_inverse.append(inverse)
    return np.asarray(batched_inverse, dtype=np.int32).reshape(batched_a.shape)
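
# Quick numeric example (illustrative only): 7 is the inverse of 8 mod 11,
# because 8 * 7 = 56 = 5 * 11 + 1.
def _demo_py_multiplicative_inverse():
    assert int(py_multiplicative_inverse(8, 11)) == 7
    assert (8 * 7) % 11 == 1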

def one_hot_minus(inputs, shift):
    """Performs (inputs - shift) % vocab_size in the one-hot space.

    Args:
        inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard
            one-hot Tensor.
        shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard
            one-hot Tensor specifying how much to shift the corresponding
            one-hot vector in inputs. Soft values perform a "weighted shift":
            for example, shift=[0.2, 0.3, 0.5] performs a linear combination
            of 0.2 * shifting by zero; 0.3 * shifting by one; and
            0.5 * shifting by two.

    Returns:
        Tensor of same shape and dtype as inputs.
    """
    # TODO: Implement with circular conv1d.
    shift = shift.type(inputs.dtype)
    vocab_size = inputs.shape[-1]
    # Form a [..., vocab_size, vocab_size] matrix. Each batch element of
    # inputs will vector-matrix multiply the vocab_size x vocab_size matrix.
    # This "shifts" the inputs batch element by the corresponding shift batch
    # element.
    shift_matrix = torch.stack([torch.roll(shift, i, dims=-1)
                                for i in range(vocab_size)], dim=-2)
    outputs = torch.einsum('...v,...uv->...u', inputs, shift_matrix)
    return outputs
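
# Illustrative check (not in the original module): with hard one-hot inputs,
# one_hot_minus reduces to ordinary modular subtraction of the indices.
def _demo_one_hot_minus():
    x = one_hot(torch.tensor([2]), vocab_size=5)      # encodes 2
    s = one_hot(torch.tensor([4]), vocab_size=5)      # shift of 4
    y = one_hot_minus(x, s)                           # encodes (2 - 4) % 5
    assert torch.argmax(y, dim=-1).item() == 3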

def one_hot_add(inputs, shift):
    """Performs (inputs + shift) % vocab_size in the one-hot space.

    Args:
        inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard
            one-hot Tensor.
        shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard
            one-hot Tensor specifying how much to shift the corresponding
            one-hot vector in inputs. Soft values perform a "weighted shift":
            for example, shift=[0.2, 0.3, 0.5] performs a linear combination
            of 0.2 * shifting by zero; 0.3 * shifting by one; and
            0.5 * shifting by two.

    Returns:
        Tensor of same shape and dtype as inputs.
    """
    # A modular shift in the one-hot space is a circular convolution, computed
    # here as an element-wise product in the Fourier domain. Note: this uses
    # the torch.fft module (PyTorch >= 1.8); the original code used the old
    # torch.fft(tensor, signal_ndim) API, which has since been removed.
    inputs_fft = torch.fft.fft(inputs, dim=-1)
    shift_fft = torch.fft.fft(shift, dim=-1)
    result_fft = inputs_fft * shift_fft
    # The imaginary part is zero up to numerical error; return the real part.
    return torch.fft.ifft(result_fft, dim=-1).real
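
# Illustrative check (not in the original module): with hard one-hot inputs,
# one_hot_add reduces to ordinary modular addition of the indices.
def _demo_one_hot_add():
    x = one_hot(torch.tensor([3]), vocab_size=5)      # encodes 3
    s = one_hot(torch.tensor([4]), vocab_size=5)      # shift of 4
    y = one_hot_add(x, s)                             # encodes (3 + 4) % 5
    assert torch.argmax(y, dim=-1).item() == 2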

def one_hot_multiply(inputs, scale):
    """Performs (inputs * scale) % vocab_size in the one-hot space.

    Args:
        inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard
            one-hot Tensor.
        scale: Tensor of shape `[..., vocab_size]`. Typically a soft/hard
            one-hot Tensor specifying how much to scale the corresponding
            one-hot vector in inputs. Soft values perform a "weighted scale":
            for example, scale=[0.2, 0.3, 0.5] performs a linear combination
            of 0.2 * scaling by zero; 0.3 * scaling by one; and
            0.5 * scaling by two.

    Returns:
        Tensor of same shape and dtype as inputs.
    """
    # TODO: Implement with circular conv1d.
    scale = scale.type(inputs.dtype)
    vocab_size = inputs.shape[-1]
    # Form a [vocab_size, vocab_size, vocab_size] tensor. The ith
    # vocab_size x vocab_size matrix represents scaling inputs by i.
    arange = torch.arange(vocab_size, device=inputs.device)
    to_perm = arange.unsqueeze(1) * arange.unsqueeze(0)
    permutation_matrix = one_hot(torch.fmod(to_perm, vocab_size), vocab_size)
    # Scale the inputs according to the permutation matrices of all possible
    # scales.
    scaled_inputs = torch.einsum('...v,avu->...au', inputs, permutation_matrix)
    # Replace the scale-by-zero row with zeros (scaling by zero is not an
    # invertible operation).
    scaled_inputs = torch.cat((torch.zeros_like(scaled_inputs[..., :1, :]),
                               scaled_inputs[..., 1:, :]), dim=-2)
    # Reduce rows of the scaled inputs by the scale values. This forms a
    # weighted linear combination of scaling by zero, scaling by one, and so
    # on.
    outputs = torch.einsum('...v,...vu->...u', scale, scaled_inputs)
    return outputs
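
# Illustrative check (not in the original module): with hard one-hot inputs,
# one_hot_multiply reduces to ordinary modular multiplication of the indices.
def _demo_one_hot_multiply():
    x = one_hot(torch.tensor([3]), vocab_size=5)      # encodes 3
    s = one_hot(torch.tensor([4]), vocab_size=5)      # scale of 4
    y = one_hot_multiply(x, s)                        # encodes (3 * 4) % 5
    assert torch.argmax(y, dim=-1).item() == 2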