<a href="https://colab.research.google.com/github/seonhe/PyTorch/blob/master/Untitled8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torch.nn as nn
from torch.autograd import Function,Variable
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.optim as optim

In [31]:
PATH_DATASETS='/content/CIFAR'
BATCH_SIZE=10
epochs=600
LEARNING_RATE=0.005
MOMENTUM=0.9

device = 'cuda' if torch.cuda.is_available() else 'cpu'

#for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

In [13]:
class BinActiv(Function):
    '''
    Binarize the input activations and calculate the mean across channel dimension
    '''
    @staticmethod
    def forward(ctx,input):
        ctx.save_for_backward(input)
        input = input.sign()
        return input #tensor.Forward should has only one output, or there will be another grad
    
    @classmethod
    def Mean(cls,input):
        return torch.mean(input.abs(),1,keepdim=True) #the shape of mnist data is (N,C,W,H)

    @staticmethod
    def backward(ctx,grad_output): #grad_output is a Variable
        input,=ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input #Variable

BinActive = BinActiv.apply

In [14]:

class BinConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=False):
        super(BinConv2d,self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.layer_type = 'BinConv2d'

        self.bn = nn.BatchNorm2d(in_channels,eps=1e-4,momentum=0.1,affine=True)
        self.conv = nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation,
                            groups=groups,bias=bias)
        self.relu = nn.ReLU()

    def forward(self,x):
        #block structure is BatchNorm -> BinActiv -> BinConv -> Relu
        x = self.bn(x)
        A = BinActiv().Mean(x)
        x = BinActive(x)
        k = torch.ones(1,1,self.kernel_size,self.kernel_size).mul(1/(self.kernel_size**2)) #out_channels and in_channels are both 1.constrain kernel as square
        k = Variable(k)
        K = F.conv2d(A,k,bias=None,stride=self.stride,padding=self.padding,dilation=self.dilation)
        x = self.conv(x)
        x = torch.mul(x,K)
        x = self.relu(x)
        return x

In [15]:
class BinLinear(nn.Module):
    def __init__(self,in_features,out_features):
        super(BinLinear,self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bn = nn.BatchNorm1d(in_features,eps=1e-4,momentum=0.1,affine=True)
        self.linear = nn.Linear(in_features,out_features,bias=False)

    def forward(self,x):
        x = self.bn(x)
        beta = BinActiv().Mean(x).expand_as(x)
        x = BinActive(x)
        x = torch.mul(x,beta)
        x = self.linear(x)
        return x

In [16]:
class XNOR_NET(nn.Module):
    def __init__(self):
        super(XNOR_NET,self).__init__()
        self.conv1 = BinConv2d(3,128,kernel_size = 3, padding=1, stride=1)
        self.conv2 = BinConv2d(128,128,kernel_size = 3, padding=1, stride=1)
        self.conv3 = BinConv2d(128,256,kernel_size = 3, padding=1, stride=1)
        self.conv4 = BinConv2d(256,256,kernel_size = 3, padding=1, stride=1)
        self.conv5 = BinConv2d(256,512,kernel_size = 3, padding=1, stride=1)
        self.conv6 = BinConv2d(512,512,kernel_size = 3, padding=1, stride=1)

        self.fc1 = BinLinear(8192,1024)
        self.fc2 = BinLinear(1024,1024)
        self.fc3 = BinLinear(1024,10)

    def forward(self,x):
        x = self.conv1(x)

        x = self.conv2(x)
        x = F.max_pool2d(x,2)

        x = self.conv3(x)

        x = self.conv4(x)
        x = F.max_pool2d(x,2)

        x = self.conv5(x)

        x = self.conv6(x)
        x = F.max_pool2d(x,2)

        x = x.view(-1,8192)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        x = F.max_pool1d(x,1)

        return x

In [17]:
#x=torch.randn(10,3,32,32)
#print(x.shape)

xnor_net=XNOR_NET()
#y=xnor(x)
#print(y.shape)

In [18]:
criterion=nn.CrossEntropyLoss().to(device)
optimizer=torch.optim.SGD(xnor_net.parameters(),lr=LEARNING_RATE, momentum=MOMENTUM)
lr_sche=optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)



In [19]:
#CIFAR DATASET
cifar_train=dsets.CIFAR10(root=PATH_DATASETS, train=True, transform=transforms.ToTensor(),download=True)
cifar_test=dsets.CIFAR10(root=PATH_DATASETS, train=False, transform=transforms.ToTensor(),download=True)

train_loader=torch.utils.data.DataLoader(dataset=cifar_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader=torch.utils.data.DataLoader(dataset=cifar_test, batch_size=BATCH_SIZE, shuffle=False)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

Files already downloaded and verified
Files already downloaded and verified


In [23]:
#TRAINING

for epoch in range(10):
    #running_loss=0.0
    lr_sche.step()
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs=inputs.to(device)
        labels=labels.to(device)

        #zero the parmeter gradients
        optimizer.zero_grad()

        #forward+backward+optimize
        outputs=xnor_net(inputs)
        loss=criterion(outputs, labels)
        loss.backward()
        optimizer.step()

KeyboardInterrupt: ignored

In [33]:
total=0
correct=0
num=0

with torch.no_grad():
    for data in test_loader:
        num+=1
        images, labels= data
        #images=images.to(device)
        #labels=labels.to(device)
        outputs=xnor_net(images)

        _, predicted = torch.max(outputs.data, 1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()

        print('Accuracy of the network on the 10000 test images: %d %%'%(100*correct/total))

Accuracy of the network on the 10000 test images: 40 %
Accuracy of the network on the 10000 test images: 39 %
Accuracy of the network on the 10000 test images: 38 %
Accuracy of the network on the 10000 test images: 36 %
Accuracy of the network on the 10000 test images: 36 %
Accuracy of the network on the 10000 test images: 37 %
Accuracy of the network on the 10000 test images: 38 %
Accuracy of the network on the 10000 test images: 38 %
Accuracy of the network on the 10000 test images: 39 %
Accuracy of the network on the 10000 test images: 40 %
Accuracy of the network on the 10000 test images: 41 %
Accuracy of the network on the 10000 test images: 41 %
Accuracy of the network on the 10000 test images: 40 %


KeyboardInterrupt: ignored