# Implementation of the paper "Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1" (Courbariaux et al., 2016) — https://arxiv.org/abs/1602.02830

In [1]:
import torch
import torch.nn as nn

def binarize(tensor, quant_mode='det'):
    """Binarize ``tensor`` to +1/-1.

    Args:
        tensor: input tensor. It is never modified; a new tensor is returned.
        quant_mode: ``'det'`` selects deterministic binarization
            (``x >= 0 -> +1``, otherwise ``-1``). Any other value selects the
            stochastic variant, which outputs +1 with probability
            ``clamp((x + 1) / 2, 0, 1)`` (a "hard sigmoid") and -1 otherwise.

    Returns:
        A tensor with the same shape, dtype and device as ``tensor`` whose
        entries are all +1 or -1.
    """
    if quant_mode == 'det':
        # Tensor.sign() maps 0 to 0, which is not a binary value; the paper
        # defines Sign(0) = +1, so zeros are treated as positive.
        positive = tensor >= 0
    else:
        prob = tensor.add(1).div(2).clamp(0, 1)
        # rand_like draws on the input's own device/dtype and, unlike the old
        # add_/div_/clamp_ chain, leaves the argument untouched.
        positive = torch.rand_like(tensor) < prob
    return positive.to(tensor.dtype).mul(2).sub(1)

class BinarizeLinear(nn.Linear):
    """``nn.Linear`` whose weights (and, optionally, inputs) are binarized to +1/-1.

    The full-precision weights are stored in ``self.weight.org``, created on
    the first forward pass. Every forward overwrites ``self.weight.data`` with
    ``binarize(self.weight.org)``. Nothing in this class updates ``.org``
    afterwards, so the training loop must keep it in sync with the optimizer.
    For example, it can copy ``.org`` into ``.data`` before ``step()`` and
    the updated ``.data`` back into ``.org`` afterwards. Otherwise weight
    updates are discarded by the next forward pass.

    Args:
        *kargs, **kwargs: forwarded unchanged to ``nn.Linear``.
        binarize_input: whether to binarize the layer input. ``None`` (the
            default) keeps the original rule: binarize unless the input has
            784 features, i.e. unless it is the raw flattened 28x28 image fed
            to the first layer.
    """

    class _BinarizeSTE(torch.autograd.Function):
        """Binarize in the forward pass; pass the gradient through unchanged."""

        @staticmethod
        def forward(ctx, x):
            return binarize(x)

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    def __init__(self, *kargs, binarize_input=None, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)
        self.binarize_input = binarize_input

    def forward(self, input):
        """Apply the binarized linear map (assumes 2-D ``(batch, features)`` input)."""
        binarize_input = self.binarize_input
        if binarize_input is None:
            # Legacy heuristic: keep the raw 784-pixel image at full precision.
            binarize_input = input.size(1) != 784
        if binarize_input:
            # Straight-through estimator: same values/gradients as the old
            # ``input.data = binarize(input.data)`` trick, but the caller's
            # tensor is no longer overwritten in place.
            input = self._BinarizeSTE.apply(input)

        weight = self.weight
        if not hasattr(weight, 'org'):
            weight.org = weight.data.clone()
        elif weight.org.device != weight.device or weight.org.dtype != weight.dtype:
            # .org is a plain attribute, so Module.to()/.cuda() does not move
            # it; realign it so the binarized weight stays on the right device.
            weight.org = weight.org.to(weight)
        weight.data = binarize(weight.org)

        out = nn.functional.linear(input, weight)
        if self.bias is not None:
            # The bias itself stays full precision. Its .org snapshot is
            # refreshed every call; presumably the training loop's .org
            # handling relies on it — TODO confirm against the training code.
            self.bias.org = self.bias.data.clone()
            out = out + self.bias
        return out

In [4]:
class Net(nn.Module):
    """Binarized MLP for 28x28 single-channel images with 10 output classes.

    The network has three BinarizeLinear hidden layers of width
    ``2048 * infl_ratio``. Each is followed by BatchNorm1d and Hardtanh. A
    full-precision ``nn.Linear`` classifier and a log-softmax over the 10
    classes come last. The first BinarizeLinear receives the raw 784-feature
    image, which it does not binarize.

    Args:
        infl_ratio: width multiplier for the hidden layers (default 3).
    """

    def __init__(self, infl_ratio=3):
        super(Net, self).__init__()
        self.infl_ratio = infl_ratio
        hidden = 2048 * self.infl_ratio
        self.model = nn.Sequential(
            BinarizeLinear(784, hidden),
            # BatchNorm comes *before* Hardtanh. The next BinarizeLinear
            # passes gradients through its sign() unchanged, so Hardtanh's
            # zero gradient outside [-1, 1] is what cancels gradients for
            # |x| > 1 (the paper's straight-through estimator).
            nn.BatchNorm1d(hidden),
            nn.Hardtanh(),
            BinarizeLinear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.Hardtanh(),
            BinarizeLinear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.Hardtanh(),
            # Dropout regularizes hidden features. Placed after LogSoftmax, it
            # would zero and rescale the log-probabilities themselves.
            nn.Dropout(0.5),
            nn.Linear(hidden, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Flatten ``x`` to ``(-1, 784)`` and return per-class log-probabilities."""
        x = x.view(-1, 28 * 28)
        return self.model(x)
    
# Instantiate the binarized MLP. With infl_ratio = 3 the hidden layers are
# 2048 * 3 = 6144 units wide, i.e. roughly 80M parameters (~320 MB in
# float32). Each BinarizeLinear also clones its weight into ``.org`` on the
# first forward pass, adding to that.
model = Net()

In [5]:
from torchvision import datasets, transforms

# Shared preprocessing for both splits: convert images to tensors and
# standardize them. 0.1307 / 0.3081 are presumably the usual MNIST
# training-set pixel mean and std.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=mnist_transform),
    batch_size=64, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps test runs reproducible.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=mnist_transform),
    batch_size=64, shuffle=False)


  warn(


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


1.3%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


68.4%





KeyboardInterrupt: 