In [1]:
import torch
import torch.autograd as autograd

# Main idea: compensation with max differentiation

In [2]:
# Two test vectors of length 4:
#   x1 - all entries equal (1.14), so every coordinate ties for the maximum
#   x2 - strictly increasing entries 1..4, so the maximum is unique
x1 = 1.14 * torch.ones(4)
x2 = torch.arange(1.0, 5.0)

def max1(x):
    """Return the largest entry of a non-empty 1-D tensor via a sequential scan.

    The scan keeps the current best and replaces it only on a strict ``>``
    comparison, so when several entries tie for the maximum the *first* one
    is returned. The result is an indexed element of ``x``, so autograd routes
    the gradient to that single entry only.

    Args:
        x: 1-D tensor with at least one element (any length; earlier versions
           of this function only looked at the first four entries).

    Returns:
        0-dim tensor holding the maximum of ``x``.

    Raises:
        IndexError: if ``x`` is empty.
    """
    res = x[0]
    # Scan every remaining entry instead of a hard-coded count of three.
    for i in range(1, x.shape[0]):
        if x[i] > res:
            res = x[i]
    return res

def max2(x):
    """Return the maximum over all entries of ``x`` using PyTorch's built-in reduction.

    Unlike the hand-written scan, how the gradient is assigned when several
    entries tie for the maximum is decided by PyTorch's own backward rule.
    """
    return x.max()

def zero(t):
    """Return ``max1(t * x1) - max2(t * x1)``.

    Both terms are the maximum of the same vector, so the value is zero for
    every ``t``; the two terms only differ in how they are computed.
    """
    z = t * x1 # choose x1 or x2 here
    scan_max = max1(z)
    builtin_max = max2(z)
    return scan_max - builtin_max

# Forward check: evaluate zero at a handful of points and print the values.
print("zero is constantly zero")
forward_values = []
for t in [0.0, 0.1, 1.0, -1e-4, -6.0]:
    forward_values.append(zero(t).item())
print(forward_values)

# Backward check: differentiate zero with respect to t at the same points
# via autograd and print each derivative.
print("\nbut its backward derivative is not")
for t in [0.0, 0.1, 1.0, -1e-4, -6.0]:
    # A fresh leaf tensor per point, so .grad holds only this point's derivative.
    t_tensor = torch.tensor([t], requires_grad=True)
    zero(t_tensor).backward()
    print(t_tensor.grad.item())

zero is constantly zero
[0.0, 0.0, 0.0, 0.0, 0.0]

but its backward derivative is not
5.960464477539063e-08
5.960464477539063e-08
5.960464477539063e-08
5.960464477539063e-08
5.960464477539063e-08


# Pure relu implementation
The only nonlinearity in the following is the ReLU function, starting from PyTorch's native `torch.relu`. From this we build variations of relu with different derivatives, and max functions for four numbers with different derivatives. We then exploit the same mechanism as above.

In [13]:
# Two relu functions
def relu(x):
    """Apply PyTorch's native ReLU to ``x`` elementwise and return the result."""
    rectified = torch.relu(x)
    return rectified

def relu2(x):
    """ReLU written with ``torch.where``: same values as ``relu``, different derivative at 0.

    The condition is ``x >= 0`` (not ``x > 0``), so at exactly 0 the ``x``
    branch is selected and the incoming gradient is passed through unchanged.
    This is what makes its derivative at 0 differ from ``relu`` (PyTorch's
    native ReLU is expected to give derivative 0 there; confirm for your
    PyTorch version).

    The zero branch is created with ``torch.zeros_like(x)`` so the result keeps
    the dtype and device of ``x`` instead of depending on type promotion against
    a default-dtype CPU scalar.

    Args:
        x: tensor of any shape.

    Returns:
        Tensor with the same shape, dtype and device as ``x``, with negative
        entries replaced by zero.
    """
    return torch.where(x >= 0, x, torch.zeros_like(x))

def max01(x):
    """Maximum of ``x[0]`` and ``x[1]`` expressed through ``relu`` only.

    Uses the identity max(a, b) = (a + b)/2 + relu((a - b)/2) + relu((b - a)/2):
    the two relu terms together equal |a - b|/2.
    """
    a, b = x[0], x[1]
    midpoint = (a + b) / 2
    upper_half = relu((a - b) / 2)
    lower_half = relu((b - a) / 2)
    return midpoint + upper_half + lower_half

def max02(x):
    """Maximum of ``x[0]`` and ``x[1]``, like ``max01`` but with ``relu2`` on the first half-gap.

    Same identity max(a, b) = (a + b)/2 + relu((a - b)/2) + relu((b - a)/2);
    only the (a - b)/2 term goes through ``relu2``, the (b - a)/2 term still
    uses ``relu``. The value is the same as ``max01``; the choice of relu
    variant only matters for the derivative.
    """
    a, b = x[0], x[1]
    midpoint = (a + b) / 2
    upper_half = relu2((a - b) / 2)
    lower_half = relu((b - a) / 2)
    return midpoint + upper_half + lower_half

In [14]:
def max1(x):
    """Maximum of the first four entries of ``x``, built from ``max01`` as a two-level tournament.

    Note: this reuses the name ``max1`` and therefore replaces the loop-based
    ``max1`` defined earlier in the notebook.
    """
    left_winner = max01(x[0:2])   # max of entries 0 and 1
    right_winner = max01(x[2:4])  # max of entries 2 and 3
    finalists = torch.stack([left_winner, right_winner])
    return max01(finalists)

def max2(x):
    """Maximum of the first four entries of ``x``, built from ``max02`` as a two-level tournament.

    Note: this reuses the name ``max2`` and therefore replaces the
    ``torch.max``-based ``max2`` defined earlier in the notebook.
    """
    left_winner = max02(x[0:2])   # max of entries 0 and 1
    right_winner = max02(x[2:4])  # max of entries 2 and 3
    finalists = torch.stack([left_winner, right_winner])
    return max02(finalists)

In [3]:
# Difference of the two max implementations applied to the same scaled vector
def zero_2(t):
    """Return ``max1(t * x1) - max2(t * x1)`` using whichever ``max1``/``max2`` are currently defined.

    Both terms are the maximum of the same vector, so the value is zero for
    every ``t``.
    """
    z = t * x1 # choose x1 or x2 here
    first_max = max1(z)
    second_max = max2(z)
    return first_max - second_max

# NOTE(review): this cell carries execution count In [3], while the cells that
# define the relu-based max1/max2 carry In [13] and In [14]. Re-run with
# Restart & Run All to confirm the recorded output reflects those definitions.

# Forward check: evaluate zero_2 at a handful of points.
zero_2_values = []
for t in [0.0, 0.1, 1.0, -1e-4, -6.0]:
    zero_2_values.append(zero_2(torch.tensor(t, requires_grad=True)).item())
print("zero_2 is constantly zero")
print(zero_2_values)

# Backward check: differentiate zero_2 with respect to t at the same points
# via autograd and print each derivative.
print("\nbut its backward derivative is not")
for t in [0.0, 0.1, 1.0, -1e-4, -6.0]:
    # A fresh leaf tensor per point, so .grad holds only this point's derivative.
    t_tensor = torch.tensor([t], requires_grad=True)
    zero_2(t_tensor).backward()
    print(t_tensor.grad.item())

zero_2 is constantly zero
[0.0, 0.0, 0.0, 0.0, 0.0]

but its backward derivative is not
5.960464477539063e-08
5.960464477539063e-08
5.960464477539063e-08
5.960464477539063e-08
5.960464477539063e-08
