In [222]:
import torch
from torch import nn, optim
from torchtools.vq import VectorQuantize # https://github.com/pabloppp/pytorch-tools

In [223]:
# XOR truth table: four 2-bit inputs and their targets, stored as float32.
X = torch.tensor(
    [[0, 0],
     [0, 1],
     [1, 0],
     [1, 1]],
    dtype=torch.float32,
)
y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

In [236]:
# First stage: 2 input features -> 8 units squashed by tanh.
l1 = nn.Sequential(nn.Linear(2, 8), nn.Tanh())

# Second stage: 8 features -> a single sigmoid output.
l2 = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())

def quantize(x):
    """Binarise `x` with a hard threshold at 0.5.

    Entries strictly greater than 0.5 become 1.0, all others 0.0. The
    comparison carries no gradient, so the result is a fresh leaf tensor;
    it is flagged as requiring grad so that a caller can read its `.grad`
    after a backward pass and forward that gradient to `x` manually.
    """
    q = torch.gt(x, 0.5).float()
    q.requires_grad_(True)
    return q

# Vector-quantisation module from torchtools, constructed with arguments (8, 8).
# NOTE(review): presumably (embedding size, number of codebook entries); the first 8
# would then match the 8 features produced by l1 — confirm against the torchtools API.
vq = VectorQuantize(8, 8)

def weights_init(m):
    """Xavier-initialise the weight and zero the bias of Linear-like modules.

    Intended for use with `Module.apply`. A module is treated as Linear-like
    when its class name contains "Linear"; every other module is ignored.

    Args:
        m: the module to (possibly) re-initialise in place.

    Behaviour:
        * no usable `weight` attribute -> prints a skip notice, changes nothing;
        * `bias` missing or None (e.g. `bias=False`) -> only the weight is
          initialised, silently, instead of reporting a bogus "skip" after
          the weight has already been overwritten.
    """
    classname = m.__class__.__name__
    if 'Linear' not in classname:
        return

    weight = getattr(m, 'weight', None)
    if weight is None:
        print("Skipping initialization of ", classname)
        return

    # nn.init functions run under no_grad themselves, so `.data` is not needed.
    nn.init.xavier_uniform_(weight)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.zeros_(bias)
        
# Re-initialise both sub-networks: Module.apply calls weights_init on every
# submodule and on the container itself.
# The second call is the cell's last expression, so its return value (the l2
# Sequential) is what the notebook echoes as this cell's output.
l1.apply(weights_init)
l2.apply(weights_init)

Sequential(
  (0): Linear(in_features=8, out_features=1, bias=True)
  (1): Sigmoid()
)

In [214]:
# Baseline: plain MLP (l1 -> l2) trained end to end, nothing between the stages.
optimizer = optim.Adam([*l1.parameters(), *l2.parameters()], lr=1e-3)

for step in range(1, 3001):
    optimizer.zero_grad()

    # Forward pass and summed squared error over all rows of X.
    pred = l2(l1(X))
    loss = ((pred - y) ** 2).sum()

    # A single backward pass fills the gradients of both l1 and l2.
    loss.backward()
    optimizer.step()

    # Report every 300 steps; loss and pred were computed before this step's update.
    if step % 300 == 0:
        print(step, loss.item(), pred.detach().numpy().tolist())

300 0.7257276773452759 [[0.40045949816703796], [0.5635364651679993], [0.5805628895759583], [0.44601789116859436]]
600 0.2603875994682312 [[0.23179921507835388], [0.7389718294143677], [0.741919755935669], [0.26817086338996887]]
900 0.07400226593017578 [[0.12879358232021332], [0.8632490634918213], [0.8614370226860046], [0.13969239592552185]]
1200 0.030479181557893753 [[0.08482348173856735], [0.9130933880805969], [0.9108487963676453], [0.08822393417358398]]
1500 0.016016945242881775 [[0.06253042817115784], [0.9373887777328491], [0.9352784752845764], [0.0632285475730896]]
1800 0.009681724943220615 [[0.04920986294746399], [0.9515368342399597], [0.9496293663978577], [0.048726122826337814]]
2100 0.006383006926625967 [[0.04033684730529785], [0.9607846736907959], [0.9590693116188049], [0.039278268814086914]]
2400 0.004459144547581673 [[0.03397722542285919], [0.9673147797584534], [0.9657679200172424], [0.03262719884514809]]
2700 0.0032450323924422264 [[0.029176900163292885], [0.9721834063529968]

In [218]:
# Hard-threshold bottleneck between l1 and l2, trained by passing the gradient
# straight through the threshold by hand.
optimizer = optim.Adam([*l1.parameters(), *l2.parameters()], lr=1e-3)

for step in range(1, 3001):
    optimizer.zero_grad()

    hidden = l1(X)            # continuous activations from the first stage
    codes = quantize(hidden)  # 0/1 leaf tensor, disconnected from l1's graph
    pred = l2(codes)
    loss = ((pred - y) ** 2).sum()

    # Stage 1: gradients for l2, and d(loss)/d(codes) lands in codes.grad.
    loss.backward()
    # Stage 2: reuse that gradient unchanged as the gradient of `hidden`,
    # which is what lets l1 receive an update despite the threshold.
    hidden.backward(codes.grad)

    optimizer.step()

    if step % 300 == 0:
        print(step, loss.item(), pred.detach().numpy().tolist())

300 0.8163377046585083 [[0.43204426765441895], [0.4940449297428131], [0.6399938464164734], [0.4940449297428131]]
600 0.7620685696601868 [[0.37148717045783997], [0.5079894065856934], [0.6479514241218567], [0.5079894065856934]]
900 0.5148895382881165 [[0.19342926144599915], [0.6210345029830933], [0.70762038230896], [0.49837133288383484]]
1200 0.4038856029510498 [[0.17204031348228455], [0.679912269115448], [0.7246944904327393], [0.4427622854709625]]
1500 0.24194872379302979 [[0.11756282299757004], [0.8646786212921143], [0.7200552821159363], [0.3625558912754059]]
1800 0.18765972554683685 [[0.08906986564397812], [0.8989024758338928], [0.7289723753929138], [0.30991870164871216]]
2100 0.1528913378715515 [[0.06801759451627731], [0.9119315147399902], [0.7439187169075012], [0.27373576164245605]]
2400 0.125566765666008 [[0.0544358491897583], [0.9334274530410767], [0.7608671188354492], [0.24695558845996857]]
2700 0.10639818012714386 [[0.044791653752326965], [0.9356807470321655], [0.777659416198730

In [237]:
# VectorQuantize bottleneck: l1 -> vq -> l2, with all three optimised jointly.
optimizer = optim.Adam(
    [*l1.parameters(), *l2.parameters(), *vq.parameters()],
    lr=1e-3,
)
# Alternative left in from the original cell (optimise l2 only):
# optimizer = optim.Adam(l2.parameters(), lr=0.001)

commit_weight = 0.25  # scale applied to the commitment term below

for step in range(1, 3001):
    optimizer.zero_grad()

    e = l1(X)
    # NOTE(review): vq is unpacked into two tensors; `q` is fed to l2 and `q_grad`
    # is used only in the two auxiliary terms — confirm the exact return contract
    # against the torchtools version in use.
    q, q_grad = vq(e)
    pred = l2(q)

    recon_term = ((pred - y) ** 2).sum()
    # Pulls q_grad towards the (detached) l1 output.
    codebook_term = ((q_grad - e.detach()) ** 2).sum()
    # Pulls the l1 output towards the (detached) q_grad.
    commit_term = ((e - q_grad.detach()) ** 2).sum() * commit_weight

    loss = recon_term + codebook_term + commit_term

    loss.backward()
    optimizer.step()

    if step % 300 == 0:
        print(step, loss.item(), pred.detach().numpy().tolist())

300 1.1997461318969727 [[0.49139320850372314], [0.5096949934959412], [0.5008904933929443], [0.49806979298591614]]
600 0.9896625280380249 [[0.4808160960674286], [0.5077378153800964], [0.5042719841003418], [0.5078343749046326]]
900 0.9828028678894043 [[0.4890400767326355], [0.5129714608192444], [0.4973489046096802], [0.500495433807373]]
1200 0.961938202381134 [[0.48719459772109985], [0.5423017740249634], [0.48038920760154724], [0.4864928424358368]]
1500 0.6794148087501526 [[0.39802515506744385], [0.6427853107452393], [0.5492122769355774], [0.40654125809669495]]
1800 0.37637051939964294 [[0.28365492820739746], [0.7145309448242188], [0.6965494155883789], [0.30399253964424133]]
2100 0.1721193790435791 [[0.19136583805084229], [0.793925940990448], [0.8009101748466492], [0.21235743165016174]]
2400 0.08564083278179169 [[0.13530048727989197], [0.8496260046958923], [0.859707236289978], [0.15446625649929047]]
2700 0.04983270913362503 [[0.10270783305168152], [0.88374263048172], [0.8937110900878906]

In [239]:
# Compare the three forward paths on the same inputs with autograd disabled.
with torch.no_grad():
    hidden = l1(X)

    pred = l2(hidden)              # no bottleneck
    pred_q = l2(quantize(hidden))  # hard 0.5 threshold
    pred_q2 = l2(vq(hidden)[0])    # first element returned by vq

    for result in (pred, pred_q, pred_q2):
        print(result.numpy().tolist())

[[0.08205469697713852], [0.9063223600387573], [0.9156545400619507], [0.09569525718688965]]
[[0.9026834964752197], [0.9993128776550293], [0.9990249872207642], [0.9834163188934326]]
[[0.08216940611600876], [0.9059074521064758], [0.9152330756187439], [0.09622524678707123]]
