In [5]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class BitLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=False, a=-1, b=1, epsilon = 1e-5):
        super(BitLinear, self).__init__(in_features, out_features, bias)
        self.epsilon = epsilon
        self.a = a
        self.b = b
        self.in_features = in_features
        self.out_features = out_features
        self.gamma_forward = nn.Parameter(torch.ones(in_features))
        self.beta_forward = nn.Parameter(torch.ones(out_features))
    
    def get_binary_weight(self):
        Wb = self.binarize(self.weight)
        return Wb
        
    def round_clip(self, W, a=-1, b=1):
        # make sure we broadcast a and b to the same shape as W
        a = a * torch.ones_like(W)
        b = b * torch.ones_like(W)
        W = torch.max(a, torch.min(b, W.round()))
        return W

    def binarize(self, W):  
        gamma = torch.sum(torch.abs(W)) / (W.shape[0] * W.shape[1])  
        W = W / (gamma + self.epsilon)  
        W_bin = self.round_clip(W, self.a, self.b)  
        W = W + (W_bin - W).detach()  # STE for the rounding operation  
        return torch.nn.Parameter(W, requires_grad=True)  
    
    def forward(self, input):
        # Ensure input is at least 2D
        if input.dim() == 1:
            input = input.unsqueeze(1)

        input_norm = F.layer_norm(input, (self.in_features,))

        # Absmax Quantization
        quant_scale = torch.max(torch.abs(input_norm), dim=1, keepdim=True).values
        input_quant = torch.sign(input_norm) * (quant_scale / self.gamma_forward)
        
        weights_bin = self.get_binary_weight()

        # Calculate the positive and negative parts of the weight
        weight_pos = torch.clamp(weights_bin, min=0)
        weight_neg = torch.clamp(weights_bin, max=0)

        # Calculate the output as the sum of the positive and negative parts
        output_pos = torch.sum(input_quant * weight_pos, dim=1)
        output_neg = torch.sum(input_quant * weight_neg, dim=1)
        output = output_pos + output_neg


        return output



In [6]:
binary_linear = BitLinear(10, 1)
regular_linear = nn.Linear(10, 1)

In [7]:
# pass dummy data
x = torch.rand(10)
print(binary_linear(x))

# check the weights
print(binary_linear.weight)

#get the binary weights
print(binary_linear.get_binary_weight())

tensor([0.1226], grad_fn=<SqueezeBackward4>)
Parameter containing:
tensor([[-0.1646, -0.1849,  0.2831, -0.0491, -0.2510,  0.3143,  0.3045, -0.0498,
          0.0652,  0.1303]], requires_grad=True)
Parameter containing:
tensor([[-1., -1.,  1.,  0., -1.,  1.,  1.,  0.,  0.,  1.]],
       requires_grad=True)


In [8]:
from torch import optim
#now let's optimize the binary layer on outputting all ones

# create a binary layer
binary_linear = BitLinear(10, 1)
# create an optimizer
optimizer = optim.SGD(binary_linear.parameters(), lr=0.01)
# create a loss function
criterion = nn.BCEWithLogitsLoss()
# pass dummy data
x = torch.rand(10)

# now lets update the weights
for i in range(10000):
    # first we need to zero the gradients
    optimizer.zero_grad()
    # then we can update the weights
    output = binary_linear(x)
    loss = criterion(output, torch.ones(1))
    loss.backward()
    optimizer.step()
    print(loss.item())
    if loss.item() < 0.01:
        print(loss.item())
        break


0.6521852612495422
0.6492019295692444
0.646237850189209
0.6432930827140808
0.640367329120636
0.6374606490135193
0.6345727443695068
0.6317034959793091
0.628852903842926
0.6260206699371338
0.6232067346572876
0.6204110383987427
0.6176334619522095
0.6148737668991089
0.6121319532394409
0.6094077825546265
0.6067011952400208
0.6040120720863342
0.6013401746749878
0.5986855626106262
0.5960480570793152
0.5934275388717651
0.5908238291740417
0.5882368087768555
0.5856664180755615
0.5831125378608704
0.5805750489234924
0.5780538320541382
0.5755486488342285
0.5730595588684082
0.5705863833427429
0.5681290626525879
0.565687358379364
0.5632612705230713
0.5608507394790649
0.5584554672241211
0.556075394153595
0.5537105202674866
0.5513606667518616
0.5490257740020752
0.5467056035995483
0.5444002151489258
0.5421094298362732
0.5398330688476562
0.537571132183075
0.5353235006332397
0.5330900549888611
0.5308706760406494
0.52866530418396
0.5264737606048584
0.5242959856987
0.5221319198608398
0.5199815034866333
0.51

In [9]:
#now let's check that both the binary and regular layers are outputting the same thing

x = torch.rand(10)

#first regular pass

print(binary_linear(x)) 


#now pass the regular weights
non_binary_weights = binary_linear.weight

print(F.linear(x, non_binary_weights, binary_linear.bias))

tensor([4.8344], grad_fn=<SqueezeBackward4>)
tensor([4.8344], grad_fn=<SqueezeBackward4>)


In [10]:
print("Floating point weights")
print(non_binary_weights)
print("Binary weights")
print(binary_linear.get_binary_weight())

Floating point weights
Parameter containing:
tensor([[ 0.1133,  1.4978,  0.5285,  3.1385,  0.6349,  0.2923,  1.5121,  0.9594,
          0.5896, -0.0079]], requires_grad=True)
Binary weights
Parameter containing:
tensor([[0., 1., 1., 1., 1., 0., 1., 1., 1., 0.]], requires_grad=True)
