In [32]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class BitLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=False, a=-1, b=1, epsilon = 1e-5):
        super(BitLinear, self).__init__(in_features, out_features, bias)
        self.epsilon = epsilon
        self.a = a
        self.b = b
        self.in_features = in_features
        self.out_features = out_features
        self.weight = self.binarize(self.weight)
        self.gamma_forward = nn.Parameter(torch.ones(in_features))
        self.beta_forward = nn.Parameter(torch.ones(out_features))
    
    def apply_binarization(self):
        self.weight = self.binarize(self.weight)
        
    def round_clip(self, W, a=-1, b=1):
        # make sure we broadcast a and b to the same shape as W
        a = a * torch.ones_like(W)
        b = b * torch.ones_like(W)
        W = torch.max(a, torch.min(b, W.round()))
        return W

    def binarize(self, W):  
        gamma = torch.sum(torch.abs(W)) / (W.shape[0] * W.shape[1])  
        W = W / (gamma + self.epsilon)  
        W_bin = self.round_clip(W, self.a, self.b)  
        W = W + (W_bin - W).detach()  # STE for the rounding operation  
        return torch.nn.Parameter(W, requires_grad=True)  
    
def forward(self, input):
    # Ensure input is at least 2D
    if input.dim() == 1:
        input = input.unsqueeze(1)

    input_norm = F.layer_norm(input, (self.in_features,))

    # Absmax Quantization
    quant_scale = torch.max(torch.abs(input_norm), dim=1, keepdim=True).values
    input_quant = torch.sign(input_norm) * (quant_scale / self.gamma_forward)

    # Calculate the positive and negative parts of the weight
    weight_pos = torch.clamp(self.weight, min=0)
    weight_neg = torch.clamp(self.weight, max=0)

    # Calculate the output as the sum of the positive and negative parts
    output_pos = torch.sum(input * weight_pos, dim=1)
    output_neg = torch.sum(input * weight_neg, dim=1)
    output = output_pos + output_neg

    if self.bias is not None:
        output += self.bias

    return output



In [33]:
binary_linear = BitLinear(10, 1)
regular_linear = nn.Linear(10, 1)

In [34]:
# pass dummy data
x = torch.rand(10)
print(binary_linear(x))

# check the weights
print(binary_linear.weight)

tensor([-0.8660], grad_fn=<SqueezeBackward4>)
Parameter containing:
tensor([[ 1.,  1., -1., -1., -1., -1.,  1.,  0., -1.,  0.]],
       requires_grad=True)


In [36]:
#now lets try with an optimizer
import torch.optim as optim
import torch.nn.functional as F

binary_linear = BitLinear(10, 10)
optimizer = optim.SGD(binary_linear.parameters(), lr=0.01)

# pass dummy data
x = torch.rand(10)

# check the weights
print(binary_linear.weight)

# now lets update the weights
# first we need to zero the gradients
optimizer.zero_grad()
# then we can update the weights
output = binary_linear(x)
loss = F.mse_loss(output, torch.rand(10))
loss.backward()
optimizer.step()

binary_linear.apply_binarization()
# check the weights
print(binary_linear.weight)


Parameter containing:
tensor([[ 1.,  1.,  1.,  0.,  0.,  1., -1.,  0.,  1., -1.],
        [ 1., -1., -1.,  1.,  1., -1.,  1.,  1.,  1., -1.],
        [ 1., -1., -1.,  0.,  1., -1., -1.,  1.,  1.,  0.],
        [-1., -1.,  1.,  1.,  0., -1.,  1.,  1., -1.,  1.],
        [-1.,  1.,  1.,  1.,  1.,  1.,  0.,  1., -1.,  1.],
        [-1., -1.,  1., -1.,  0., -1., -1.,  1.,  0.,  0.],
        [ 0.,  0.,  0., -1.,  0.,  1., -1., -1.,  0.,  0.],
        [-1., -1.,  1., -1.,  1.,  1., -1., -1.,  0., -1.],
        [ 1., -1., -1.,  1.,  1.,  1.,  0.,  1., -1.,  1.],
        [ 1., -1.,  0.,  1.,  1., -1., -1.,  0., -1.,  1.]],
       requires_grad=True)
Parameter containing:
tensor([[ 1.,  1.,  1.,  0.,  0.,  1., -1.,  0.,  1., -1.],
        [ 1., -1., -1.,  1.,  1., -1.,  1.,  1.,  1., -1.],
        [ 1., -1., -1.,  0.,  1., -1., -1.,  1.,  1.,  0.],
        [-1., -1.,  1.,  1.,  0., -1.,  1.,  1., -1.,  1.],
        [-1.,  1.,  1.,  1.,  1.,  1.,  0.,  1., -1.,  1.],
        [-1., -1.,  1., -1.,

In [41]:
#now let's optimize the binary layer on outputting all ones

# create a binary layer
binary_linear = BitLinear(10, 1)
# create an optimizer
optimizer = optim.SGD(binary_linear.parameters(), lr=0.01)
# create a loss function
criterion = nn.BCEWithLogitsLoss()
# pass dummy data
x = torch.rand(10)

# now lets update the weights
for i in range(1000):
    # first we need to zero the gradients
    optimizer.zero_grad()
    # then we can update the weights
    output = binary_linear(x)
    loss = criterion(output, torch.ones(1))
    loss.backward()
    optimizer.step()
    binary_linear.apply_binarization()
    if loss.item() < 0.01:
        print(loss.item())
        break
