In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict

In [2]:
class BinarizeLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input

In [31]:
class BinarizedNN(nn.Module):
    def __init__(self):
        super(BinarizedNN, self).__init__()
        self.fc1 = nn.Linear(4, 4, bias=False)

    def forward(self, x):
#         x = x.view(x.size(0), -1)  # Flatten the input
        x = BinarizeLayer.apply(self.fc1(x))
        return x

In [15]:
H4 = torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, -1.0, 1.0, -1.0], [1.0, 1.0, -1.0, -1.0], [1.0, -1.0, -1.0, 1.0]])
perm = torch.randperm(pow(2, 4))
print(perm)
inputs = [torch.tensor([0.0]*(4-len(bin(num)[2:]))+[float(i) for i in bin(num)[2:]]) for num in perm]
print(inputs)
outputs = [np.matmul(H4, inp) for inp in inputs]
print(outputs)

tensor([12,  7, 10, 11, 15,  6,  4, 14,  2,  3,  5, 13,  1,  8,  0,  9])
[tensor([1., 1., 0., 0.]), tensor([0., 1., 1., 1.]), tensor([1., 0., 1., 0.]), tensor([1., 0., 1., 1.]), tensor([1., 1., 1., 1.]), tensor([0., 1., 1., 0.]), tensor([0., 1., 0., 0.]), tensor([1., 1., 1., 0.]), tensor([0., 0., 1., 0.]), tensor([0., 0., 1., 1.]), tensor([0., 1., 0., 1.]), tensor([1., 1., 0., 1.]), tensor([0., 0., 0., 1.]), tensor([1., 0., 0., 0.]), tensor([0., 0., 0., 0.]), tensor([1., 0., 0., 1.])]
[tensor([2., 0., 2., 0.]), tensor([ 3., -1., -1., -1.]), tensor([2., 2., 0., 0.]), tensor([ 3.,  1., -1.,  1.]), tensor([4., 0., 0., 0.]), tensor([ 2.,  0.,  0., -2.]), tensor([ 1., -1.,  1., -1.]), tensor([ 3.,  1.,  1., -1.]), tensor([ 1.,  1., -1., -1.]), tensor([ 2.,  0., -2.,  0.]), tensor([ 2., -2.,  0.,  0.]), tensor([ 3., -1.,  1.,  1.]), tensor([ 1., -1., -1.,  1.]), tensor([1., 1., 1., 1.]), tensor([0., 0., 0., 0.]), tensor([2., 0., 0., 2.])]


In [16]:
print(outputs[0])

tensor([2., 0., 2., 0.])


In [33]:
learning_rate = 0.01
num_epochs = 100

# Initialize the BNN model and optimizer
model = BinarizedNN()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [34]:
for epoch in range(num_epochs):
    total_loss = 0.0
    for i in range(16):
        input_bf = inputs[i]
        optimizer.zero_grad()

        # Forward pass
        output = model(input_bf)

        # Compute loss
        loss = F.mse_loss(output, outputs[i])

        # Backpropagation
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    if(epoch==0 or epoch%5==4):
        print(f'Epoch {epoch + 1}, Loss: {total_loss / 16}')

Epoch 1, Loss: 2.5
Epoch 5, Loss: 0.9375
Epoch 10, Loss: 0.8125
Epoch 15, Loss: 0.8125
Epoch 20, Loss: 0.8125
Epoch 25, Loss: 0.8125
Epoch 30, Loss: 0.8125
Epoch 35, Loss: 0.8125
Epoch 40, Loss: 0.8125
Epoch 45, Loss: 0.8125
Epoch 50, Loss: 0.8125
Epoch 55, Loss: 0.8125
Epoch 60, Loss: 0.8125
Epoch 65, Loss: 0.8125
Epoch 70, Loss: 0.8125
Epoch 75, Loss: 0.8125
Epoch 80, Loss: 0.8125
Epoch 85, Loss: 0.8125
Epoch 90, Loss: 0.8125
Epoch 95, Loss: 0.8125
Epoch 100, Loss: 0.8125


In [64]:
weight_dict = OrderedDict(model.named_parameters())
# weightFunction = weight_dict['fc1.weight'].T
# print("Final weight function: \n", weightFunction)
print(torch.sign(torch.tensor(weight_dict['fc1.weight'])))

tensor([[ 1.,  1.,  1.,  1.],
        [ 1., -1.,  1., -1.],
        [ 1.,  1., -1., -1.],
        [ 1., -1., -1.,  1.]])


  print(torch.sign(torch.tensor(weight_dict['fc1.weight'])))


In [67]:
model.eval()
for i in range(16):
    if(not torch.eq(outputs[i], model(inputs[i]))):
        print(i, ": ", outputs[i], inputs[i])
# ws = model(torch.tensor([-2.0, 0, 2, 2]))
# print(ws)

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [75]:
# torch.eq(torch.tensor([4, 3]), torch.tensor([4, 3]))
i = 0
outputs[i]#, model(inputs[i])

tensor([2., 0., 2., 0.])