In [1]:
import initialize
from mistify import _functional as F
import torch

This notebook is to check if the straight through estimators succeed in optimizing.

In [2]:

def optim_op(f, x1: torch.Tensor, t, n=400, **kwargs):
    x1.requires_grad_()
    optim = torch.optim.Adam([x1], lr=1e-2)
    for i in range(n):
        y = f(x1, **kwargs)
        optim.zero_grad()
        (y - t).pow(2).mean().backward()
        optim.step()

In [3]:

# Test union
torch.manual_seed(1)
x1_b = torch.rand(4, 4)
x2 = torch.rand(4, 4)
t = F.union(x1_b, x2).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.union, x1, t, n=800, x2=x2, g=F.ClipG(0.1))
torch.sqrt((x1 - x1_b).pow(2).mean()).item()

print(F.union(x1, x2)) 
print(F.union(x1_b, x2))


tensor([[0.7576, 0.4980, 0.9371, 0.7347],
        [0.3138, 0.7999, 0.4162, 0.7544],
        [0.5695, 0.5239, 0.7981, 0.7718],
        [0.6826, 0.8100, 0.6397, 0.9743]], grad_fn=<MaximumBackward0>)
tensor([[0.7576, 0.4980, 0.9371, 0.7347],
        [0.3138, 0.7999, 0.4162, 0.7544],
        [0.5695, 0.5239, 0.7981, 0.7718],
        [0.6826, 0.8100, 0.6397, 0.9743]])


In [4]:
# Test union
torch.manual_seed(1)
x1_b = torch.rand(4, 4)
x2 = torch.rand(4, 4)
t = F.inter(x1_b, x2).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.inter, x1, t, x2=x2, g=F.ClipG(0.1))
torch.sqrt((x1 - x1_b).pow(2).mean()).item()


print(F.inter(x1, x2)) 
print(F.inter(x1_b, x2))

tensor([[0.5725, 0.2793, 0.4031, 0.6556],
        [0.3138, 0.1980, 0.4162, 0.2843],
        [0.3398, 0.4388, 0.7981, 0.5247],
        [0.0112, 0.3051, 0.4635, 0.4550]], grad_fn=<MinimumBackward0>)
tensor([[0.5725, 0.2793, 0.4031, 0.6556],
        [0.0293, 0.1980, 0.3971, 0.2843],
        [0.3398, 0.4388, 0.6387, 0.5247],
        [0.0112, 0.3051, 0.4635, 0.4550]])


In [5]:
# Test inter on
torch.manual_seed(1)
x1_b = torch.rand(4, 4)
# x2 = torch.rand(4, 4)
t = F.inter_on(x1_b, -1, False).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.inter_on, x1, t, dim=-1, keepdim=False)
torch.sqrt((x1 - x1_b).pow(2).mean()).item()
print(F.inter_on(x1, dim=-1)) 
print(F.inter_on(x1_b, dim=-1))

tensor([0.2793, 0.0293, 0.4388, 0.3051], grad_fn=<MinBackward0>)
tensor([0.2793, 0.0293, 0.4388, 0.3051])


In [6]:
# Test union on
torch.manual_seed(1)
x1_b = torch.rand(4, 4)
# x2 = torch.rand(4, 4)
t = F.union_on(x1_b, -1, False).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.union_on, x1, t, dim=-1, keepdim=False, g=F.ClipG(0.1))
torch.sqrt((x1 - x1_b).pow(2).mean()).item()


print(F.union_on(x1, dim=-1)) 
print(F.union_on(x1_b, dim=-1))

tensor([0.7576, 0.7999, 0.6387, 0.6826], grad_fn=<MaxBackward0>)
tensor([0.7576, 0.7999, 0.6387, 0.6826])


In [7]:
# Test binary
torch.manual_seed(2)
x1_b = torch.rand(4, 4)
# x2 = torch.rand(4, 4)
t = F.binarize(x1_b).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.binarize, x1, t, n=1000, g=True, clip=0.1)
print(torch.sqrt((x1 - x1_b).pow(2).mean()).item())
print(F.binarize(x1), F.binarize(x1_b))

0.1590016782283783
tensor([[1., 0., 1., 0.],
        [1., 1., 0., 0.],
        [1., 0., 1., 1.],
        [0., 1., 1., 1.]], grad_fn=<BinaryGBackward>) tensor([[1., 0., 1., 0.],
        [1., 1., 0., 0.],
        [1., 0., 1., 1.],
        [0., 1., 1., 1.]])


In [8]:
# Test binary
torch.manual_seed(2)
x1_b = torch.randn(4, 4)
# x2 = torch.rand(4, 4)
t = F.signify(x1_b).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.signify, x1, t, n=1000, g=True, clip=0.1)
print(torch.sqrt((x1 - x1_b).pow(2).mean()).item())
print(F.signify(x1)) 
print(F.signify(x1_b))

0.7361576557159424
tensor([[-1.,  1., -1., -1.],
        [-1.,  1., -1., -1.],
        [-1.,  1., -1., -1.],
        [ 1., -1., -1., -1.]], grad_fn=<SignGBackward>)
tensor([[-1.,  1., -1., -1.],
        [-1.,  1., -1., -1.],
        [-1.,  1., -1., -1.],
        [ 1., -1., -1., -1.]])


In [9]:
# Test clamp
torch.manual_seed(2)
x1_b = torch.rand(4, 4) * 3 - 1
# x2 = torch.rand(4, 4)
t = F.clamp(x1_b).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.clamp, x1, t, n=2000, g=True, clip=0.1)
print(torch.sqrt((x1 - x1_b).pow(2).mean()).item())
print(F.clamp(x1)) 
print(F.clamp(x1_b))

0.3310493528842926
tensor([[0.8441, 0.1430, 0.9113, 0.4234],
        [1.0000, 0.8571, 0.3276, 0.0000],
        [0.8425, 0.0000, 0.6971, 0.5997],
        [0.1702, 1.0000, 0.6001, 1.0000]], grad_fn=<ClampGBackward>)
tensor([[0.8441, 0.1430, 0.9113, 0.4234],
        [1.0000, 0.8571, 0.3276, 0.0000],
        [0.8425, 0.0000, 0.6971, 0.5997],
        [0.1702, 1.0000, 0.6001, 1.0000]])


In [7]:
# Test bounded union on
torch.manual_seed(1)
x1_b = torch.rand(4, 4) ** 4
# x2 = torch.rand(4, 4)
t = F.bounded_union_on(x1_b, -1, False).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.bounded_union_on, x1, t, dim=-1, keepdim=False, g=True, clip=0.1)
torch.sqrt((x1 - x1_b).pow(2).mean()).item()
print(F.bounded_union_on(x1, dim=-1, keepdim=False)) 
print(F.bounded_union_on(x1_b, dim=-1, keepdim=False))

tensor([0.6533, 0.7580, 0.3844, 0.3148], grad_fn=<ClampGBackward>)
tensor([0.6533, 0.7580, 0.3844, 0.3148])


In [6]:
# Test bounded inter on
torch.manual_seed(1)
x1_b = torch.rand(4, 4) ** 0.25
# x2 = torch.rand(4, 4)
t = F.bounded_inter_on(x1_b, -1, False).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.bounded_inter_on, x1, t, dim=-1, keepdim=False, g=True, clip=0.1)
torch.sqrt((x1 - x1_b).pow(2).mean()).item()
print(F.bounded_inter_on(x1, dim=-1, keepdim=False)) 
print(F.bounded_inter_on(x1_b, dim=-1, keepdim=False))

tensor([0.3826, 0.0852, 0.4276, 0.2986], grad_fn=<ReluBackward0>)
tensor([0.3826, 0.0852, 0.4276, 0.2986])


In [4]:

# Test bounded union
torch.manual_seed(1)
x1_b = torch.rand(4, 4) ** 4
x2 = torch.rand(4, 4)
t = F.bounded_union(x1_b, x2).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.bounded_union, x1, t, x2=x2, g=True)
torch.sqrt((x1 - x1_b).pow(2).mean()).item()
print(F.bounded_union(x1, x2)) 
print(F.bounded_union(x1_b, x2))

tensor([[0.9020, 0.5041, 0.9635, 0.9469],
        [0.3138, 0.6073, 0.4411, 0.6082],
        [0.4450, 0.5610, 0.9645, 0.8475],
        [0.2283, 0.8186, 0.6859, 1.0000]], grad_fn=<ClampBackward1>)
tensor([[0.9020, 0.5041, 0.9635, 0.9469],
        [0.3138, 0.6073, 0.4411, 0.6082],
        [0.4450, 0.5610, 0.9645, 0.8475],
        [0.2283, 0.8186, 0.6859, 1.0000]])


In [5]:

# Test bounded inter
torch.manual_seed(1)
x1_b = torch.rand(4, 4) ** 0.5
x2 = torch.rand(4, 4) ** 0.5
t = F.bounded_inter(x1_b, x2).detach()

x1 = torch.rand(4, 4, requires_grad=True)
optim_op(F.bounded_inter, x1, t, x2=x2, g=True)
torch.sqrt((x1 - x1_b).pow(2).mean()).item()
print(F.bounded_inter(x1, x2)) 
print(F.bounded_inter(x1_b, x2))

tensor([[0.6270, 0.2342, 0.6029, 0.6668],
        [0.0000, 0.3393, 0.2753, 0.4018],
        [0.3376, 0.3862, 0.6925, 0.6028],
        [0.0000, 0.4524, 0.4806, 0.6616]], grad_fn=<ReluBackward0>)
tensor([[0.6270, 0.2342, 0.6029, 0.6668],
        [0.0000, 0.3393, 0.2753, 0.4018],
        [0.3376, 0.3862, 0.6925, 0.6028],
        [0.0000, 0.4524, 0.4806, 0.6616]])
