In [74]:
import torch
import math
from torch import nn, optim
from torchtools.vq import VectorQuantize, Binarize # https://github.com/pabloppp/pytorch-tools
from sklearn.cluster import KMeans

In [72]:
X = torch.tensor([[0, 0],[0, 1],[1, 0],[1, 1]]).float()
y = torch.tensor([[0],[1],[1],[0]]).float()

In [81]:
l1 = nn.Sequential(
    nn.Linear(2, 8),
    nn.Tanh()
)

l2 = nn.Sequential(
    nn.Linear(8, 1),
    nn.Sigmoid()
)

binarize = Binarize(threshold=0.5)
vq = VectorQuantize(8, 4)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        try:
            nn.init.xavier_uniform_(m.weight.data)
            m.bias.data.fill_(0)
        except AttributeError:
            print("Skipping initialization of ", classname)
            
def initialize_vq(samples, kbits=32):
    kmeans = KMeans(n_clusters=kbits, random_state=0).fit(samples.numpy())
    vectors = torch.from_numpy(kmeans.cluster_centers_)
    return vectors
        
l1.apply(weights_init)
l2.apply(weights_init)

Sequential(
  (0): Linear(in_features=8, out_features=1, bias=True)
  (1): Sigmoid()
)

In [214]:
# Regular MLP
optimizer = optim.Adam(list(l1.parameters())+list(l2.parameters()), lr=0.001)

for i in range(3000):
    pred = l2(l1(X))
    loss = (pred - y).pow(2).sum()
    
    # first backpropagate l2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (i+1) % 300 == 0:
        print(i+1, loss.item(), pred.detach().numpy().tolist())

300 0.7257276773452759 [[0.40045949816703796], [0.5635364651679993], [0.5805628895759583], [0.44601789116859436]]
600 0.2603875994682312 [[0.23179921507835388], [0.7389718294143677], [0.741919755935669], [0.26817086338996887]]
900 0.07400226593017578 [[0.12879358232021332], [0.8632490634918213], [0.8614370226860046], [0.13969239592552185]]
1200 0.030479181557893753 [[0.08482348173856735], [0.9130933880805969], [0.9108487963676453], [0.08822393417358398]]
1500 0.016016945242881775 [[0.06253042817115784], [0.9373887777328491], [0.9352784752845764], [0.0632285475730896]]
1800 0.009681724943220615 [[0.04920986294746399], [0.9515368342399597], [0.9496293663978577], [0.048726122826337814]]
2100 0.006383006926625967 [[0.04033684730529785], [0.9607846736907959], [0.9590693116188049], [0.039278268814086914]]
2400 0.004459144547581673 [[0.03397722542285919], [0.9673147797584534], [0.9657679200172424], [0.03262719884514809]]
2700 0.0032450323924422264 [[0.029176900163292885], [0.9721834063529968]

In [5]:
optimizer = optim.Adam(list(l1.parameters())+list(l2.parameters()), lr=0.001)

for i in range(3000):
    discrete = l1(X)
    binarized = binarize(discrete)
    pred = l2(binarized)
    loss = (pred - y).pow(2).sum()
    
    optimizer.zero_grad()  
    loss.backward()    
    optimizer.step()
    
    if (i+1) % 300 == 0:
        print(i+1, loss.item(), pred.detach().numpy().tolist())

300 0.9425754547119141 [[0.4221487045288086], [0.8250942826271057], [0.5558110475540161], [0.7324411869049072]]
600 0.6348578333854675 [[0.3588734567165375], [0.7965839505195618], [0.5666934847831726], [0.526246190071106]]
900 0.4277668595314026 [[0.31116601824760437], [0.8168681263923645], [0.6434742212295532], [0.41266772150993347]]
1200 0.29615694284439087 [[0.2713545262813568], [0.8374587893486023], [0.7085552215576172], [0.33341261744499207]]
1500 0.2119666188955307 [[0.23809990286827087], [0.8556093573570251], [0.7584514617919922], [0.2758272886276245]]
1800 0.15645280480384827 [[0.2102527767419815], [0.8712595701217651], [0.7964065670967102], [0.23285658657550812]]
2100 0.11852261424064636 [[0.1868014633655548], [0.8846961259841919], [0.8257099390029907], [0.19988951086997986]]
2400 0.09172669053077698 [[0.16690438985824585], [0.8962589502334595], [0.8488091230392456], [0.1739216148853302]]
2700 0.07223724573850632 [[0.1498841941356659], [0.9062612652778625], [0.8674039244651794

In [82]:
optimizer = optim.Adam(list(l1.parameters())+list(l2.parameters())+list(vq.parameters()), lr=0.001)
# optimizer = optim.Adam(l2.parameters(), lr=0.001)

samples = torch.cat([l1(X).detach()] * 100, dim=0)
initial_vq = initialize_vq(samples, kbits=8)
vq.codebook.weight.data = initial_vq

for i in range(3000):
    e = l1(X)
    q, q_grad = vq(e)
    pred = l2(q)
    loss_recon = (pred - y).pow(2).sum()   
    loss_vq = (q_grad - e.detach()).pow(2).sum()
    loss_commit = (e - q_grad.detach()).pow(2).sum() * 0.25
    
    loss = loss_recon + loss_vq + loss_commit
    
    optimizer.zero_grad()    
    loss.backward()    
    optimizer.step()
    
    if (i+1) % 300 == 0:
        print(i+1, loss.item(), pred.detach().numpy().tolist())

  return_n_iter=True)


300 0.867835521697998 [[0.41707509756088257], [0.5408511161804199], [0.5144703984260559], [0.490619033575058]]
600 0.3806043565273285 [[0.26242169737815857], [0.6847149729728699], [0.6580486297607422], [0.30656322836875916]]
900 0.1052703708410263 [[0.14642642438411713], [0.8368092775344849], [0.8183486461639404], [0.15527400374412537]]
1200 0.041897937655448914 [[0.09607701748609543], [0.8978744745254517], [0.8850385546684265], [0.09492387622594833]]
1500 0.021697988733649254 [[0.07081101089715958], [0.9268664121627808], [0.9171368479728699], [0.0668327808380127]]
1800 0.013036184944212437 [[0.05582065135240555], [0.9434993267059326], [0.9356985688209534], [0.050916824489831924]]
2100 0.008575026877224445 [[0.045869067311286926], [0.954285204410553], [0.9478015899658203], [0.04069727286696434]]
2400 0.005987984593957663 [[0.03874439746141434], [0.961869478225708], [0.9563471078872681], [0.033574171364307404]]
2700 0.0043601044453680515 [[0.033365387469530106], [0.9675118327140808], [0

In [83]:
with torch.no_grad():
    pred = l2(l1(X))
    pred_q = l2(binarize(l1(X)))
    pred_q2 = l2(vq(l1(X))[0])
    
    print(pred.numpy().tolist())
    print(pred_q.numpy().tolist())
    print(pred_q2.numpy().tolist())

[[0.029129508882761], [0.9719014763832092], [0.9677055478096008], [0.02425062283873558]]
[[0.05560747906565666], [0.976898193359375], [0.962981641292572], [0.3143444359302521]]
[[0.02913060039281845], [0.9718986749649048], [0.9676977396011353], [0.024263445287942886]]
