In [1]:
import os
# Expose only GPU 0 to this process, with devices enumerated in PCI bus order.
# NOTE(review): these are set ahead of `import torch` below, presumably so the
# CUDA runtime sees them when it is initialised - keep them before that import.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.init as init

from pruning.layers import MaskedLinear, MaskedConv2d 
from pruning.methods import filter_prune
from pruning.utils import to_var, prune_rate

import numpy as np

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class polynom_act(nn.Module):
    """Learnable second-degree polynomial activation.

    Computes ``alpha * x**2 + beta * x + c`` element-wise, where ``alpha``,
    ``beta`` and ``c`` are trainable one-element float parameters.

    Args:
        alpha: Initial quadratic coefficient; ``None`` selects the default 0.1.
        beta: Initial linear coefficient; ``None`` selects the default 1.
        c: Initial constant term; ``None`` selects the default 0.1.
    """

    def __init__(self, alpha=None, beta=None, c=None):
        super(polynom_act, self).__init__()
        # Honour explicitly supplied initial values; passing nothing keeps the
        # historical initialisation (0.1, 1, 0.1).
        self.alpha = self._coefficient(alpha, 0.1)
        self.beta = self._coefficient(beta, 1.0)
        self.c = self._coefficient(c, 0.1)

    @staticmethod
    def _coefficient(value, default):
        """Wrap ``value`` (or ``default`` when ``value`` is None) as a trainable 1-element parameter."""
        initial = default if value is None else value
        return nn.Parameter(
            torch.tensor([float(initial)], dtype=torch.float32),
            requires_grad=True,
        )

    def forward(self, x):
        """Apply the polynomial element-wise to ``x`` and return the result."""
        return self.alpha * (x ** 2) + self.beta * x + self.c

In [19]:
# Output widths of the convolution stages in hefNet:
#   conv1 -> channel_list[0], conv2 -> channel_list[1],
#   conv3 and conv4 -> channel_list[2].
# Alternative ordering kept for reference: [64, 128, 256]
channel_list = [256, 128, 64]

In [20]:
class hefNet(nn.Module):
    """Small CNN for 32x32 RGB images built from maskable (prunable) convolutions.

    Layout: four 3x3 ``MaskedConv2d`` stages, each followed by batch
    normalisation and an activation; stages 2-4 end in 2x2 average pooling.
    A single linear layer maps the flattened features to 10 class scores.

    Channel widths are read from the module-level ``channel_list``. The
    classifier expects a 2x2 spatial map after the last pooling, which is what
    a 32x32 input produces (32 -> 30 -> 28 -> 14 -> 12 -> 6 -> 4 -> 2).
    """

    def __init__(self):
        super(hefNet, self).__init__()

        # One activation instance is shared by all four stages. ReLU holds no
        # state, so sharing is harmless; swapping in a parametrised activation
        # such as polynom_act() would share its coefficients across stages.
        activation = nn.ReLU(inplace=True)

        # Stage 1: full (ungrouped) convolution on the RGB input.
        self.conv1 = MaskedConv2d(3, channel_list[0], kernel_size=3, padding=0, stride=1)
        # xavier_uniform_ is the in-place initialiser; the un-suffixed name is deprecated.
        nn.init.xavier_uniform_(self.conv1.weight)
        self.bn1 = nn.BatchNorm2d(channel_list[0])
        self.relu1 = activation

        # Stages 2-4: grouped convolutions (8 groups) followed by 2x2 average pooling.
        self.conv2 = MaskedConv2d(channel_list[0], channel_list[1], kernel_size=3, stride=1, groups=8)
        nn.init.xavier_uniform_(self.conv2.weight)
        self.bn2 = nn.BatchNorm2d(channel_list[1])
        self.relu2 = activation
        self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv3 = MaskedConv2d(channel_list[1], channel_list[2], kernel_size=3, stride=1, groups=8)
        nn.init.xavier_uniform_(self.conv3.weight)
        self.bn3 = nn.BatchNorm2d(channel_list[2])
        self.relu3 = activation
        self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv4 = MaskedConv2d(channel_list[2], channel_list[2], kernel_size=3, stride=1, groups=8)
        nn.init.xavier_uniform_(self.conv4.weight)
        self.bn4 = nn.BatchNorm2d(channel_list[2])
        self.relu4 = activation
        self.avgpool3 = nn.AvgPool2d(kernel_size=2, stride=2)

        # NOTE(review): this dropout module is registered but never applied in
        # forward(); left untouched so the network's behaviour does not change.
        self.dropout = nn.Dropout2d(0.5)
        self.linear1 = nn.Linear(channel_list[2] * 2 * 2, 10)

    def forward(self, x):
        """Return the (batch, 10) class scores for a batch of images ``x``."""
        # Each stage uses its own activation attribute (they all reference the
        # same module, so this is numerically identical to reusing relu1).
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.avgpool1(self.relu2(self.bn2(self.conv2(x))))
        x = self.avgpool2(self.relu3(self.bn3(self.conv3(x))))
        x = self.avgpool3(self.relu4(self.bn4(self.conv4(x))))

        # Flatten everything except the batch dimension for the classifier.
        x = x.view(x.size(0), -1)
        return self.linear1(x)

    def set_masks(self, masks):
        """Install pruning masks on the maskable layers.

        Args:
            masks: Sequence of NumPy arrays. ``masks[0]``..``masks[3]`` are
                applied to ``conv1``..``conv4``. An optional ``masks[4]`` is
                applied to ``linear1`` when that layer provides ``set_mask``.

        Raises:
            IndexError: If fewer than four masks are supplied.
            AttributeError: If a fifth mask is supplied but ``linear1`` has no
                ``set_mask`` method (it is a plain ``nn.Linear`` here).
        """
        conv_layers = (self.conv1, self.conv2, self.conv3, self.conv4)
        for index, layer in enumerate(conv_layers):
            layer.set_mask(torch.from_numpy(masks[index]))

        # The classifier mask is optional: a list holding only the four
        # convolution masks is accepted as-is.
        if len(masks) > len(conv_layers):
            if not hasattr(self.linear1, 'set_mask'):
                raise AttributeError(
                    "a mask was supplied for linear1, but %s does not support set_mask; "
                    "use a maskable linear layer or pass only the 4 convolution masks"
                    % type(self.linear1).__name__
                )
            self.linear1.set_mask(torch.from_numpy(masks[len(conv_layers)]))

In [21]:
# Hyper-parameters shared by the data-loading, training and evaluation cells.
param = dict(
    pruning_perc=50.,       # not referenced by the cells shown in this notebook section
    batch_size=128,         # training loader batch size
    test_batch_size=100,    # evaluation loader batch size
    num_epochs=100,         # epochs run by train()
    learning_rate=1e-4,     # optimizer step size
    weight_decay=5e-4,      # L2 penalty passed to the optimizer
    momentum=0.9,           # SGD momentum
    amsgrad=True,           # only used by the commented-out Adam alternative
)

In [6]:
# CIFAR-10 input pipelines and data loaders.

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip for augmentation, then tensor conversion and per-channel
# normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

# The training split is downloaded into ../data/ when it is not already there.
train_dataset = datasets.CIFAR10(
    root='../data/',
    train=True,
    download=True,
    transform=transform_train,
)
loader_train = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=param['batch_size'],
    shuffle=True,
)

# The test split is read from the same root; it is iterated in a fixed order.
test_dataset = datasets.CIFAR10(
    root='../data/',
    train=False,
    transform=transform_test,
)
loader_test = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=param['test_batch_size'],
    shuffle=False,
)

Files already downloaded and verified


In [7]:
def train(model, loss_fn, optimizer, param, loader_train, loader_val=None):
    """Train ``model`` in place and checkpoint its weights every 10 epochs.

    Args:
        model: Network to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(scores, targets)`` returning a scalar loss tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        param: Mapping that provides ``'num_epochs'``.
        loader_train: Iterable yielding ``(inputs, labels)`` batches.
        loader_val: Accepted for interface compatibility; currently unused.

    Side effects:
        Prints the loss every 100 batches. After every 10th epoch the model's
        ``state_dict`` is written to ``models/hefNet_pretrained<epoch>.pkl``;
        the ``models`` directory is created when it does not exist.
    """
    model.train()
    num_epochs = param['num_epochs']

    for epoch in range(num_epochs):
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))

        for t, (x, y) in enumerate(loader_train):
            # Labels are cast to int64 before being handed to the loss.
            x_var, y_var = to_var(x), to_var(y.long())

            scores = model(x_var)
            loss = loss_fn(scores, y_var)

            if (t + 1) % 100 == 0:
                print('t = %d, loss = %.8f' % (t + 1, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (epoch + 1) % 10 == 0:
            checkpoint_path = 'models/hefNet_pretrained' + str(epoch + 1) + '.pkl'
            # Create the target directory first so a missing folder cannot
            # abort the run after ten epochs of training.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)

In [8]:
def test(model, loader):

    model.eval()
    num_correct, num_samples = 0, len(loader.dataset)
    for x, y in loader:
        x_var = to_var(x, volatile=True)
        scores = model(x_var)
        _, preds = scores.data.cpu().max(1)
        num_correct += (preds == y).sum()

    acc = float(num_correct) / num_samples

    print('Test accuracy: {:.2f}% ({}/{})'.format(
        100.*acc,
        num_correct,
        num_samples,
        ))
    
    return acc

In [22]:
# Build a fresh network; hefNet.__init__ Xavier-initialises the conv weights.
net=hefNet()

  if __name__ == '__main__':
  from ipykernel import kernelapp as app


In [23]:
# Move the model to the selected device. Left as a bare expression so the
# returned module's layer summary is shown as the cell output.
net.to(device)

hefNet(
  (conv1): MaskedConv2d(3, 256, kernel_size=(3, 3), stride=(1, 1))
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace=True)
  (conv2): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), groups=8)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU(inplace=True)
  (avgpool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv3): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), groups=8)
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU(inplace=True)
  (avgpool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv4): MaskedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), groups=8)
  (bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu4): ReLU(inplace=True)
  (avgpool3): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (dropout): Dropout2d(p=0.5, inp

In [131]:
# Restore weights from a checkpoint written by train() (epoch 30).
# NOTE(review): this cell sits before the training cell, so the checkpoint
# must already exist from an earlier run - it is not produced by running the
# notebook top to bottom on a fresh checkout.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
pretrained_dict = torch.load('./models/hefNet_pretrained30.pkl', map_location=device)
model_dict = net.state_dict()
# Keep only entries whose key exists in the current model; keys absent from
# the checkpoint retain the model's current values.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

<All keys matched successfully>

In [24]:
# Loss function for the 10-way classification task.
criterion = nn.CrossEntropyLoss()

# SGD with momentum; every hyper-parameter is taken from `param`.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=param['learning_rate'],
    weight_decay=param['weight_decay'],
    momentum=param['momentum'],
)

# Alternative optimizer kept for reference:
# optimizer = torch.optim.Adam(net.parameters(), lr=param['learning_rate'], weight_decay=param['weight_decay'], amsgrad=param['amsgrad'])

# Run the full training loop (checkpoints are written by train()).
train(net, criterion, optimizer, param, loader_train)

Starting epoch 1 / 100
t = 100, loss = 2.21485257
t = 200, loss = 2.14664006
t = 300, loss = 2.13221741
Starting epoch 2 / 100
t = 100, loss = 2.04743052
t = 200, loss = 1.98518133
t = 300, loss = 1.97212470
Starting epoch 3 / 100
t = 100, loss = 1.93884635
t = 200, loss = 1.96644759
t = 300, loss = 1.87321281
Starting epoch 4 / 100
t = 100, loss = 1.88484323
t = 200, loss = 1.92369521
t = 300, loss = 1.80840874
Starting epoch 5 / 100
t = 100, loss = 1.85863173
t = 200, loss = 1.75637519
t = 300, loss = 1.71697497
Starting epoch 6 / 100
t = 100, loss = 1.77851784
t = 200, loss = 1.72398472
t = 300, loss = 1.71608949
Starting epoch 7 / 100
t = 100, loss = 1.63726187
t = 200, loss = 1.72326267
t = 300, loss = 1.70693493
Starting epoch 8 / 100
t = 100, loss = 1.58367813
t = 200, loss = 1.61768746
t = 300, loss = 1.58114171
Starting epoch 9 / 100
t = 100, loss = 1.53206801
t = 200, loss = 1.53416073
t = 300, loss = 1.67399013
Starting epoch 10 / 100
t = 100, loss = 1.63632953
t = 200, loss

t = 100, loss = 0.93224657
t = 200, loss = 0.85200053
t = 300, loss = 0.95617455
Starting epoch 80 / 100
t = 100, loss = 0.94676054
t = 200, loss = 0.96045387
t = 300, loss = 1.05673158
Starting epoch 81 / 100
t = 100, loss = 0.91606289
t = 200, loss = 1.03754199
t = 300, loss = 0.85887033
Starting epoch 82 / 100
t = 100, loss = 0.96663922
t = 200, loss = 0.89596587
t = 300, loss = 0.95280510
Starting epoch 83 / 100
t = 100, loss = 0.88796335
t = 200, loss = 0.83399427
t = 300, loss = 0.93300295
Starting epoch 84 / 100
t = 100, loss = 0.88030660
t = 200, loss = 1.01327777
t = 300, loss = 0.96030098
Starting epoch 85 / 100
t = 100, loss = 0.99084955
t = 200, loss = 0.94287717
t = 300, loss = 0.95680016
Starting epoch 86 / 100
t = 100, loss = 0.87775993
t = 200, loss = 0.98804933
t = 300, loss = 0.80514556
Starting epoch 87 / 100
t = 100, loss = 0.85324484
t = 200, loss = 0.96675545
t = 300, loss = 1.01774049
Starting epoch 88 / 100
t = 100, loss = 1.04847884
t = 200, loss = 0.90897244
t

In [25]:
# Evaluate on the CIFAR-10 test loader; test() prints the accuracy and the
# returned fraction is shown as the cell output.
test(net, loader_test)

Test accuracy: 70.36% (7036/10000)


0.7036