<a href="https://colab.research.google.com/github/seonhe/PyTorch/blob/master/Untitled11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'

#for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

# CIFAR10
def get_loaders(batch_size):
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10('/content/CIFAR', train=True, download = True,
                                                              transform = transforms.Compose([
                                                                  transforms.ToTensor(),
                                                                  transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))  # R, G, B
                                                ])), batch_size = params['batch_size'], shuffle = True)
    test_loader = torch.utils.data.DataLoader(datasets.CIFAR10('/content/CIFAR', train=False, download = True,
                                                              transform = transforms.Compose([
                                                                  transforms.ToTensor(),
                                                                  transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                                                ])), batch_size = params['batch_size'], shuffle = True)
    return train_loader, test_loader


class XNORModel(nn.Module):
    def __init__(self):
        super().__init__()
        #self.BinConv2d_1 = BinConv2d(3, 128,  kernel_size=1).to('cuda')
        self.Conv2d_1 = nn.Sequential(nn.Conv2d(3,128,kernel_size=3,stride=1,padding=0).to(device), nn.BatchNorm2d(128,affine=False).to(device), nn.ReLU().to(device))
        self.BinConv2d_2 = BinConv2d(128, 128, kernel_size=3).to(device)
        self.BinConv2d_3 = BinConv2d(128, 256, kernel_size=3).to(device)
        self.BinConv2d_4 = BinConv2d(256, 256, kernel_size=3).to(device)
        self.BinConv2d_5 = BinConv2d(256, 512, kernel_size=3).to(device)
        self.BinConv2d_6 = BinConv2d(512, 512, kernel_size=3).to(device)
        
        self.pool        = nn.MaxPool2d(kernel_size=2, stride=2).to(device)
        
        self.fc1         = BinLinear(51200, 1024).to(device)
        self.fc2         = BinLinear(1024, 512).to(device)
        self.fc3         = BinLinear(512, 10).to(device)

    def forward(self, I):
        I = I.to('cuda')
        I = self.Conv2d_1(I) #30
        #I = self.BinConv2d_1(I) #32
        I = self.BinConv2d_2(I) #28
        I = self.BinConv2d_3(I) #26
        #I = self.pool(I) #14
        
        I = self.BinConv2d_4(I) #24
        I = self.BinConv2d_5(I) #22
        I = self.BinConv2d_6(I) #20
        I = self.pool(I) #10

        I = I.view(-1, 51200)
        I = self.fc1(I)
        I = F.relu(I)
        I = self.fc2(I) 
        I = F.relu(I)
        I = self.fc3(I)
        
        return I


class BinConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
        super().__init__()
        self.in_channels  = in_channels
        self.out_channels = out_channels
        self.kernel_size  = kernel_size
        self.stride       = stride
        self.padding      = padding
        self.bn           = nn.BatchNorm2d(in_channels) # default eps = 1e-5, momentum = 0.1, affine = True
        self.conv         = nn.Conv2d(in_channels, out_channels,kernel_size=kernel_size)
        self.relu         = nn.ReLU()
        self.pool         = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, I):
        I = self.bn(I)
        A = BinActiv().Mean(I)
        I = BinActive(I)
        k = torch.ones(1,1,self.kernel_size,self.kernel_size).mul(1/(self.kernel_size**2)).to('cuda') # 4d - batch,channel,height,width
        K = F.conv2d(A,k).to(device) # default stride=1, padding=0
        I = self.conv(I)
        I = torch.mul(I, K)
        I = self.relu(I)
        #I = self.pool(I)

        return I



class BinActiv(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        input = torch.sign(input)

        return input

    def Mean(self, input):
        return torch.mean(input.abs(), 1, keepdim=True)  # 1: channel // batch[0], channel[1], height[2], width[3]


    @staticmethod
    def backward(ctx, grad_output):
        input,  = ctx.saved_tensors

        # STE (Straight Through Estimator)
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0    # ge: greater or equal
        grad_input[input.le(-1)] = 0   # le: less or equal
        return grad_input

BinActive = BinActiv.apply

class BinLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_feature  = in_features
        self.out_feature = out_features
        self.bn          = nn.BatchNorm1d(in_features)
        self.linear      = nn.Linear(in_features, out_features)

    def forward(self, I):
        I = self.bn(I)
        beta = BinActiv().Mean(I).expand_as(I)
        I = BinActive(I)
        I = torch.mul(I, beta)
        I = self.linear(I)
        return I



class WeightOperation:
    def __init__(self,model):

        self.count_group_weights = 0
        self.weight = []
        self.saved_weight = []


        for m in model.modules():
            if isinstance(m,nn.Conv2d) or isinstance(m,nn.Linear):

                self.count_group_weights += 1
                self.weight.append(m.weight)
                self.saved_weight.append(m.weight.data)


    def WeightSave(self):
        for index in range(self.count_group_weights):
            self.saved_weight[index].copy_(self.weight[index].data)


    def WeightBinarize(self):
        for index in range(self.count_group_weights):

            n                 = self.weight[index].data[0].nelement()
            dim_group_weights = self.weight[index].data.size()

            if len(dim_group_weights) == 4:
                alpha = self.weight[index].data.norm(1, 3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(dim_group_weights)

            elif len(dim_group_weights) == 2:
                alpha = self.weight[index].data.norm(1, 1, keepdim=True).div(n).expand(dim_group_weights)

            self.weight[index].data = self.weight[index].data.sign()* alpha


    def WeightRestore(self):
        for index in range(self.count_group_weights):
            self.weight[index].data.copy_(self.saved_weight[index])


    def WeightGradient(self):
        for index in range(self.count_group_weights):
            n = self.weight[index].data[0].nelement()
            dim_group_weights = self.weight[index].data.size()

            if len(dim_group_weights) == 4:
                alpha = self.weight[index].data.norm(1, 3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(dim_group_weights)

            elif len(dim_group_weights) == 2:
                alpha = self.weight[index].data.norm(1, 1, keepdim=True).div(n).expand(dim_group_weights)

            alpha[self.weight[index].data.le(-1.0)] = 0
            alpha[self.weight[index].data.ge( 1.0)] = 0

            self.weight[index].grad = self.weight[index].grad / n + self.weight[index].grad * alpha

In [17]:
# model = XNORModel()
model = XNORModel().to(device)

WeightOperation = WeightOperation(model)

optimizer = optim.Adam(model.parameters(),lr=0.0003)
lr_sche=optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.8)
params = {'epochs':30, 'batch_size':100}
loss_fn   = torch.nn.CrossEntropyLoss()

train_loader, test_loader = get_loaders(batch_size=params['batch_size'])

for epoch in range(10):
    lr_sche.step()
    # training
    for batch_idx, (train_inputs, train_labels) in enumerate(train_loader): # train_inputs size:[32,1,28,28], labels size: [32]
        train_inputs = train_inputs.to(device)
        train_labels = train_labels.to(device)
        
        optimizer.zero_grad()

        WeightOperation.WeightSave()
        WeightOperation.WeightBinarize()

        predicted = model(train_inputs)
        loss = loss_fn(predicted, train_labels)
        loss.backward()     # gradient

        WeightOperation.WeightRestore()
        WeightOperation.WeightGradient()

        optimizer.step()    # update
      

        if((batch_idx*len(train_inputs))%1000==0):
           print('[%d, %5d] loss: %.3f' %(epoch, batch_idx*len(train_inputs), loss.item()))  # loss: loss tensor(2.3027, grad_fn=<NllLossBackward>)

  


Files already downloaded and verified
Files already downloaded and verified




[0,     0] loss: 2.413
[0,  1000] loss: 1.965
[0,  2000] loss: 1.951
[0,  3000] loss: 1.763
[0,  4000] loss: 1.805
[0,  5000] loss: 1.699
[0,  6000] loss: 1.958
[0,  7000] loss: 1.821
[0,  8000] loss: 1.509
[0,  9000] loss: 1.746
[0, 10000] loss: 1.591
[0, 11000] loss: 1.389
[0, 12000] loss: 1.622
[0, 13000] loss: 1.509
[0, 14000] loss: 1.735
[0, 15000] loss: 1.555
[0, 16000] loss: 1.515
[0, 17000] loss: 1.400
[0, 18000] loss: 1.555
[0, 19000] loss: 1.510
[0, 20000] loss: 1.442
[0, 21000] loss: 1.257
[0, 22000] loss: 1.421
[0, 23000] loss: 1.559
[0, 24000] loss: 1.679
[0, 25000] loss: 1.300
[0, 26000] loss: 1.484
[0, 27000] loss: 1.412
[0, 28000] loss: 1.430
[0, 29000] loss: 1.485
[0, 30000] loss: 1.535
[0, 31000] loss: 1.257
[0, 32000] loss: 1.665
[0, 33000] loss: 1.311
[0, 34000] loss: 1.318
[0, 35000] loss: 1.240
[0, 36000] loss: 1.389
[0, 37000] loss: 1.364
[0, 38000] loss: 1.378
[0, 39000] loss: 1.479
[0, 40000] loss: 1.112
[0, 41000] loss: 1.428
[0, 42000] loss: 1.410
[0, 43000] 

KeyboardInterrupt: ignored

In [20]:
for epoch in range(2):
    lr_sche.step()
    # training
    for batch_idx, (train_inputs, train_labels) in enumerate(train_loader): # train_inputs size:[32,1,28,28], labels size: [32]
        train_inputs = train_inputs.to(device)
        train_labels = train_labels.to(device)
        
        optimizer.zero_grad()

        WeightOperation.WeightSave()
        WeightOperation.WeightBinarize()

        predicted = model(train_inputs)
        loss = loss_fn(predicted, train_labels)
        loss.backward()     # gradient

        WeightOperation.WeightRestore()
        WeightOperation.WeightGradient()

        optimizer.step()    # update
      

        if((batch_idx*len(train_inputs))%1000==0):
           print('[%d, %5d] loss: %.3f' %(epoch, batch_idx*len(train_inputs), loss.item()))  # loss: loss tensor(2.3027, grad_fn=<NllLossBackward>)




[0,     0] loss: 0.879
[0,  1000] loss: 1.092
[0,  2000] loss: 0.917
[0,  3000] loss: 0.767
[0,  4000] loss: 0.925
[0,  5000] loss: 0.985
[0,  6000] loss: 0.851
[0,  7000] loss: 0.850
[0,  8000] loss: 0.983
[0,  9000] loss: 0.970
[0, 10000] loss: 0.799
[0, 11000] loss: 0.871
[0, 12000] loss: 0.932
[0, 13000] loss: 0.805
[0, 14000] loss: 0.990
[0, 15000] loss: 0.902
[0, 16000] loss: 0.891
[0, 17000] loss: 0.770
[0, 18000] loss: 0.948
[0, 19000] loss: 0.965
[0, 20000] loss: 0.873
[0, 21000] loss: 0.801
[0, 22000] loss: 1.074
[0, 23000] loss: 0.990
[0, 24000] loss: 0.918
[0, 25000] loss: 0.936
[0, 26000] loss: 0.801
[0, 27000] loss: 0.966
[0, 28000] loss: 0.849
[0, 29000] loss: 0.751
[0, 30000] loss: 0.900
[0, 31000] loss: 0.780
[0, 32000] loss: 0.914
[0, 33000] loss: 0.988
[0, 34000] loss: 0.909
[0, 35000] loss: 0.800
[0, 36000] loss: 0.795
[0, 37000] loss: 0.958
[0, 38000] loss: 0.678
[0, 39000] loss: 0.829
[0, 40000] loss: 0.923
[0, 41000] loss: 0.744
[0, 42000] loss: 0.857
[0, 43000] 

In [27]:
optimizer = optim.Adam(model.parameters(),lr=0.000003)
lr_sche=optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

for epoch in range(2):
    lr_sche.step()
    # training
    for batch_idx, (train_inputs, train_labels) in enumerate(train_loader): # train_inputs size:[32,1,28,28], labels size: [32]
        train_inputs = train_inputs.to(device)
        train_labels = train_labels.to(device)
        
        optimizer.zero_grad()

        WeightOperation.WeightSave()
        WeightOperation.WeightBinarize()

        predicted = model(train_inputs)
        loss = loss_fn(predicted, train_labels)
        loss.backward()     # gradient

        WeightOperation.WeightRestore()
        WeightOperation.WeightGradient()

        optimizer.step()    # update
      

        if((batch_idx*len(train_inputs))%1000==0):
           print('[%d, %5d] loss: %.3f' %(epoch, batch_idx*len(train_inputs), loss.item()))  # loss: loss tensor(2.3027, grad_fn=<NllLossBackward>)



[0,     0] loss: 0.184
[0,  1000] loss: 0.239
[0,  2000] loss: 0.270
[0,  3000] loss: 0.270
[0,  4000] loss: 0.229
[0,  5000] loss: 0.244
[0,  6000] loss: 0.191
[0,  7000] loss: 0.291
[0,  8000] loss: 0.226
[0,  9000] loss: 0.174
[0, 10000] loss: 0.164
[0, 11000] loss: 0.247
[0, 12000] loss: 0.220
[0, 13000] loss: 0.252
[0, 14000] loss: 0.223
[0, 15000] loss: 0.306
[0, 16000] loss: 0.142
[0, 17000] loss: 0.200
[0, 18000] loss: 0.195
[0, 19000] loss: 0.240
[0, 20000] loss: 0.189
[0, 21000] loss: 0.260
[0, 22000] loss: 0.225
[0, 23000] loss: 0.253
[0, 24000] loss: 0.234
[0, 25000] loss: 0.185
[0, 26000] loss: 0.322
[0, 27000] loss: 0.202
[0, 28000] loss: 0.174
[0, 29000] loss: 0.184
[0, 30000] loss: 0.265
[0, 31000] loss: 0.278
[0, 32000] loss: 0.276
[0, 33000] loss: 0.206
[0, 34000] loss: 0.210
[0, 35000] loss: 0.277
[0, 36000] loss: 0.192
[0, 37000] loss: 0.196
[0, 38000] loss: 0.258
[0, 39000] loss: 0.212
[0, 40000] loss: 0.261
[0, 41000] loss: 0.170
[0, 42000] loss: 0.172
[0, 43000] 

In [22]:
torch.save(model,'/content/xnor_con6/model1' )

In [26]:

    # test
correct = 0
WeightOperation.WeightSave()
WeightOperation.WeightBinarize()

for (test_inputs, test_labels) in test_loader:
    test_input = test_inputs.to('cuda')
    test_labels = test_labels.to('cuda')
    predicted = model(test_inputs)
    pred = predicted.data.max(1, keepdim = False)[1] # max(0):column-wise, max(1):row-wise, [0]:values [1]:index
    correct += pred.eq(test_labels.data).sum()

    acc = 100. * correct / len(test_loader.dataset)

    WeightOperation.WeightRestore()

    print('Accuracy:', acc.item())

Accuracy: 0.699999988079071
Accuracy: 1.4199999570846558
Accuracy: 2.109999895095825
Accuracy: 2.7799999713897705
Accuracy: 3.4599997997283936
Accuracy: 4.069999694824219
Accuracy: 4.769999980926514
Accuracy: 5.46999979019165
Accuracy: 6.199999809265137
Accuracy: 6.87999963760376
Accuracy: 7.449999809265137
Accuracy: 8.109999656677246
Accuracy: 8.729999542236328
Accuracy: 9.429999351501465
Accuracy: 10.109999656677246
Accuracy: 10.829999923706055
Accuracy: 11.460000038146973
Accuracy: 12.1899995803833
Accuracy: 12.769999504089355
Accuracy: 13.399999618530273
Accuracy: 14.089999198913574
Accuracy: 14.789999961853027
Accuracy: 15.420000076293945
Accuracy: 16.149999618530273
Accuracy: 16.84000015258789
Accuracy: 17.56999969482422
Accuracy: 18.170000076293945
Accuracy: 18.799999237060547
Accuracy: 19.44999885559082
Accuracy: 20.1299991607666
Accuracy: 20.85999870300293
Accuracy: 21.5
Accuracy: 22.170000076293945
Accuracy: 22.849998474121094
Accuracy: 23.5
Accuracy: 24.189998626708984
Accur