In [69]:
import torch
from torch import nn as nn
from torch.nn import functional as F
import math
NUM_SminusR = 8
NUM_freeT = NUM_SminusR - 4
NUM_freeA = NUM_SminusR - 2

In [70]:
def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[...,-1] += eps
    return torch.sum(
        inputs[..., None] >= bin_locations,
        dim=-1
    ) - 1


def cbrt(x, eps=0):
    ans = torch.sign(x)*torch.exp(torch.log(torch.abs(x))/3.0)
    return ans


def sqrt(x, eps=1e-9):
    ans = torch.exp((torch.log(torch.abs(x))) / 2.0)
    return ans

In [71]:
min_bin_width = 0.01
min_bin_height = 0.01
eps=1e-4
quadratic_threshold = 1e-7
linear_threshold = 1e-7
B = 10

inputs = torch.rand(B,1,1,1)
unnormalized_dt = torch.rand(B,1,1,1,NUM_freeT)
unnormalized_dalpha = torch.rand(B,1,1,1,NUM_freeA)

In [72]:
# generate t
dt = torch.softmax(unnormalized_dt, dim=-1)
dt = dt * (1 - 4 * min_bin_width)
dt = min_bin_width + (1 - dt.shape[-1] * min_bin_width / (1 - 4*min_bin_width)) * dt
dt = F.pad(dt, pad=(4,4), mode='constant', value=min_bin_width)

t = torch.cumsum(dt, dim=-1)
t = F.pad(t, pad=(1,0), mode='constant', value=0.0)
t = t - 2*min_bin_width



# generate alpha
dalpha = torch.softmax(unnormalized_dalpha, dim=-1)
dalpha = dalpha * (1 - 2 * min_bin_height)
dalpha = min_bin_height + (1 - dalpha.shape[-1] * min_bin_height / (1 - 2*min_bin_height)) * dalpha
dalpha = F.pad(dalpha, pad=(2,2), mode='constant', value=min_bin_height)


knots3 = torch.cumsum(dalpha, dim=-1)
knots3 = F.pad(knots3, pad=(1,0), mode='constant', value=0.0)
knots3 = knots3 - min_bin_height

#t = torch.roll(t, shifts=-2, dims=-1)
#alpha = torch.roll(alpha, shifts=-3, dims=-1)

widths2 = dt
num_bins = NUM_SminusR

cumwidths = t[..., 2:num_bins+3]

cumheights = knots3[..., 0:num_bins + 1] * (torch.square(widths2[..., 2:num_bins + 3]) / (
        (widths2[..., 0:num_bins + 1] + widths2[..., 1:num_bins + 2] + widths2[..., 2:num_bins + 3])
        * (widths2[..., 1:num_bins + 2] + widths2[..., 2:num_bins + 3])
)
                                            ) \
             + knots3[..., 1:num_bins + 2] * (
                     (widths2[..., 2:num_bins + 3] * (widths2[..., 0:num_bins + 1] + widths2[..., 1:num_bins + 2]))
                     / ((widths2[..., 1:num_bins + 2] + widths2[..., 2:num_bins + 3]) * (
                     widths2[..., 0:num_bins + 1] + widths2[..., 1:num_bins + 2] + widths2[..., 2:num_bins + 3]))
                     + (widths2[..., 1:num_bins + 2] * (
                         widths2[..., 2:num_bins + 3] + widths2[..., 3:num_bins + 4]))
                     / ((widths2[..., 1:num_bins + 2] + widths2[..., 2:num_bins + 3]) * (
                     widths2[..., 1:num_bins + 2] + widths2[..., 2:num_bins + 3] + widths2[..., 3:num_bins + 4]))
             ) \
             + knots3[..., 2:num_bins + 3] * (
                     torch.square(widths2[..., 1:num_bins + 2]) / (
                     (widths2[..., 1:num_bins + 2] + widths2[..., 2:num_bins + 3] + widths2[..., 3:num_bins + 4])
                     * (widths2[..., 1:num_bins + 2] + widths2[..., 2:num_bins + 3])
             )
             )

bin_idx = searchsorted(cumheights, inputs)[..., None]

t = torch.roll(t, shifts=-2, dims=-1)
knots = torch.roll(knots3, shifts=-3, dims=-1)

i0 = bin_idx
im1 = torch.remainder(bin_idx - 1, num_bins + 3)
i0 = bin_idx
im1 = torch.remainder(bin_idx - 1, num_bins + 3)
im2 = torch.remainder(bin_idx - 2, num_bins + 3)
im3 = torch.remainder(bin_idx - 3, num_bins + 3)

j3 = bin_idx + 3
j2 = bin_idx + 2
j1 = bin_idx + 1
j0 = bin_idx
jm1 = torch.remainder(bin_idx - 1, num_bins + 5)
jm2 = torch.remainder(bin_idx - 2, num_bins + 5)

km0 = knots.gather(-1, i0)[..., 0]
km1 = knots.gather(-1, im1)[..., 0]
km2 = knots.gather(-1, im2)[..., 0]
km3 = knots.gather(-1, im3)[..., 0]

t3 = t.gather(-1, j3)[..., 0]
t2 = t.gather(-1, j2)[..., 0]
t1 = t.gather(-1, j1)[..., 0]
t0 = t.gather(-1, j0)[..., 0]
tm1 = t.gather(-1, jm1)[..., 0]
tm2 = t.gather(-1, jm2)[..., 0]

input_left_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_right_cumwidths = cumwidths.gather(-1, bin_idx + 1)[..., 0]

inputs_a1 = km0 * (
        1 / ((t3 - t0) * (t2 - t0) * (t1 - t0))
) + km1 * (
                    - 1 / ((t2 - tm1) * (t1 - tm1) * (t1 - t0))
                    - 1 / ((t2 - tm1) * (t2 - t0) * (t1 - t0))
                    - 1 / ((t3 - t0) * (t2 - t0) * (t1 - t0))
            ) + km2 * (
                    1 / ((t1 - t0) * (t1 - tm2) * (t1 - tm1))
                    + 1 / ((t1 - t0) * (t2 - t0) * (t2 - tm1))
                    + 1 / ((t1 - t0) * (t1 - tm1) * (t2 - tm1))
            ) + km3 * (
                    -1 / ((t1 - tm2) * (t1 - tm1) * (t1 - t0))
            )

inputs_b1 = km0 * (
        (-3 * t0) / ((t3 - t0) * (t2 - t0) * (t1 - t0))
) + km1 * (
                    (2 * tm1 + t1) / ((t2 - tm1) * (t1 - tm1) * (t1 - t0))
                    + (tm1 + t2 + t0) / ((t2 - tm1) * (t2 - t0) * (t1 - t0))
                    + (t3 + 2 * t0) / ((t3 - t0) * (t2 - t0) * (t1 - t0))
            ) + km2 * (
                    (-2 * t1 - tm2) / ((t1 - t0) * (t1 - tm2) * (t1 - tm1))
                    + (-2 * t2 - t0) / ((t1 - t0) * (t2 - t0) * (t2 - tm1))
                    + (-t2 - t1 - tm1) / ((t1 - t0) * (t1 - tm1) * (t2 - tm1))
            ) + km3 * (
                    (3 * t1) / ((t1 - tm2) * (t1 - tm1) * (t1 - t0))
            )

inputs_c1 = km0 * (
        (3 * t0 * t0) / ((t3 - t0) * (t2 - t0) * (t1 - t0))
) + km1 * (
                    (- tm1 * tm1 - 2 * tm1 * t1) / ((t2 - tm1) * (t1 - tm1) * (t1 - t0))
                    + (- tm1 * t2 - tm1 * t0 - t2 * t0) / ((t2 - tm1) * (t2 - t0) * (t1 - t0))
                    + (- t0 * t0 - 2 * t3 * t0) / ((t3 - t0) * (t2 - t0) * (t1 - t0))
            ) + km2 * (
                    (t1 * t1 + 2 * t1 * tm2) / ((t1 - t0) * (t1 - tm2) * (t1 - tm1))
                    + (t2 * t2 + 2 * t0 * t2) / ((t1 - t0) * (t2 - t0) * (t2 - tm1))
                    + (t2 * t1 + tm1 * t1 + t2 * tm1) / ((t1 - t0) * (t1 - tm1) * (t2 - tm1))
            ) + km3 * (
                    (-3 * t1 * t1) / ((t1 - tm2) * (t1 - tm1) * (t1 - t0))
            )

inputs_d1 = km0 * (
        (- t0 * t0 * t0) / ((t3 - t0) * (t2 - t0) * (t1 - t0))
) + km1 * (
                    (tm1 * tm1 * t1) / ((t2 - tm1) * (t1 - tm1) * (t1 - t0))
                    + (tm1 * t2 * t0) / ((t2 - tm1) * (t2 - t0) * (t1 - t0))
                    + (t3 * t0 * t0) / ((t3 - t0) * (t2 - t0) * (t1 - t0))
            ) + km2 * (
                    - (t1 * t1 * tm2) / ((t1 - t0) * (t1 - tm2) * (t1 - tm1))
                    - (t0 * t2 * t2) / ((t1 - t0) * (t2 - t0) * (t2 - tm1))
                    - (t2 * tm1 * t1) / ((t1 - t0) * (t1 - tm1) * (t2 - tm1))
            ) + km3 * (
                    (t1 * t1 * t1) / ((t1 - tm2) * (t1 - tm1) * (t1 - t0))
            )

outputs = torch.zeros_like(inputs)
inputs_b_ = inputs_b1 / inputs_a1 / 3.
inputs_c_ = inputs_c1 / inputs_a1 / 3.
inputs_d_ = (inputs_d1 - inputs) / inputs_a1
delta_1 = -inputs_b_.pow(2) + inputs_c_
delta_2 = -inputs_c_ * inputs_b_ + inputs_d_
delta_3 = inputs_b_ * inputs_d_ - inputs_c_.pow(2)

discriminant = 4. * delta_1 * delta_3 - delta_2.pow(2)

depressed_1 = -2. * inputs_b_ * delta_1 + delta_2
depressed_2 = delta_1

three_roots_mask = discriminant >= 0  # Discriminant == 0 might be a problem in practice.
one_root_mask = discriminant < 0

# Deal with one root cases.
p_ = torch.zeros_like(inputs)
p_[one_root_mask] = cbrt((-depressed_1[one_root_mask] + sqrt(-discriminant[one_root_mask])) / 2.)

p = p_[one_root_mask]
q = cbrt((-depressed_1[one_root_mask] - sqrt(-discriminant[one_root_mask])) / 2.)

outputs_one_root = ((p + q) - inputs_b_[one_root_mask])

outputs[one_root_mask] = outputs_one_root

# Deal with three root cases.

theta = torch.atan2(sqrt(discriminant[three_roots_mask]), -depressed_1[three_roots_mask])
theta /= 3.

cubic_root_1 = torch.cos(theta)
cubic_root_2 = torch.sin(theta)

root_1 = cubic_root_1
root_2 = -0.5 * cubic_root_1 - 0.5 * math.sqrt(3) * cubic_root_2
root_3 = -0.5 * cubic_root_1 + 0.5 * math.sqrt(3) * cubic_root_2

root_scale = 2 * sqrt(-depressed_2[three_roots_mask])
root_shift = -inputs_b_[three_roots_mask]

root_1 = root_1 * root_scale + root_shift
root_2 = root_2 * root_scale + root_shift
root_3 = root_3 * root_scale + root_shift

root1_mask = ((input_left_cumwidths[three_roots_mask] - eps) < root_1).float()
root1_mask *= (root_1 < (input_right_cumwidths[three_roots_mask] + eps)).float()

root2_mask = ((input_left_cumwidths[three_roots_mask] - eps) < root_2).float()
root2_mask *= (root_2 < (input_right_cumwidths[three_roots_mask] + eps)).float()

root3_mask = ((input_left_cumwidths[three_roots_mask] - eps) < root_3).float()
root3_mask *= (root_3 < (input_right_cumwidths[three_roots_mask] + eps)).float()

roots = torch.stack([root_1, root_2, root_3], dim=-1)

masks = torch.stack([root1_mask, root2_mask, root3_mask], dim=-1)
mask_index = torch.argsort(masks, dim=-1, descending=True)[..., 0][..., None]
output_three_roots = torch.gather(roots, dim=-1, index=mask_index).view(-1)
outputs[three_roots_mask] = output_three_roots

# Deal with a -> 0 (almost quadratic) cases.

quadratic_mask = inputs_a1.abs() < quadratic_threshold
a = inputs_b1[quadratic_mask]
b = inputs_c1[quadratic_mask]
c = (inputs_d1[quadratic_mask] - inputs[quadratic_mask])
alpha = (-b + sqrt(b.pow(2) - 4 * a * c)) / (2 * a)
outputs[quadratic_mask] = alpha
# + input_left_cumwidths[quadratic_mask]

# Deal with b-> 0 (almost linear) cases.
linear_mask = inputs_b1.abs() < linear_threshold
linear_mask = linear_mask * quadratic_mask
b = inputs_c1[linear_mask]
c = (inputs_d1[linear_mask] - inputs[linear_mask])
alpha = c / b
outputs[linear_mask] = alpha
outputs = torch.clamp(outputs, input_left_cumwidths, input_right_cumwidths)
logabsdet = -torch.log(
    (torch.abs(
        (3 * inputs_a1 * outputs.pow(2)
         + 2 * inputs_b1 * outputs
         + inputs_c1))
    )
)

In [73]:
print(torch.squeeze(inputs))
print(torch.squeeze(outputs))
print(torch.squeeze(logabsdet))

tensor([0.9105, 0.9741, 0.0593, 0.2728, 0.1440, 0.3274, 0.8865, 0.9599, 0.2562,
        0.2804])
tensor([0.9593, 0.9800, 0.0426, 0.2146, 0.0599, 0.2886, 0.9149, 0.9761, 0.1622,
        0.1529])
tensor([-0.8344, -0.8055, -0.3708, -0.0959, -0.6799,  0.3054, -0.2164, -0.7469,
        -0.1144, -0.2871])


In [74]:
print(t[0,...])

tensor([[[[ 0.0000,  0.0100,  0.0200,  0.2035,  0.5046,  0.7957,  0.9800,
            0.9900,  1.0000,  1.0100,  1.0200, -0.0200, -0.0100]]]])
