In [1]:
from __future__ import print_function
import argparse,random
from math import log10
%pylab
%matplotlib inline
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from utils.data import get_training_set, get_test_set
from torch.nn.modules.module import _addindent

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


`%matplotlib` prevents importing * from pylab and numpy
  "\n`%matplotlib` prevents importing * from pylab and numpy"


In [2]:
# model 1을 사용할시에
from net.model import Net
# model 2을 사용할시에
#from net.model_dw import Net

In [3]:
# cuda import
cuda = True
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")
torch.manual_seed(random.randint(1,1000))
if cuda:
    torch.cuda.manual_seed(random.randint(1,1000))

In [4]:
# dataset import
train_set = get_training_set(2,"BSDS300")
test_set = get_test_set(2,"BSDS300")
training_data_loader = DataLoader(dataset=train_set, num_workers=6, batch_size=16, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=6, batch_size=100, shuffle=False)

In [5]:
alpla=64
beta=64
gamma=32
pruning=0
upscale_factor=2
train=Net(upscale_factor)
weight_name='weight1'

In [6]:
# load weigh
model=torch.load(weight_name)
keys=model.keys()

# Architecture
___

```
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2=nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3=nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
```

<br>
<br>
__upscale_factor는 2 이다__

# Pruning method
___

1. Pruning을 할 수 있는 spot을 선택 
1. 1개를 선택하여 prunging 한 후에 Net의 아키텍쳐를 바꾼다.
1. L1 norm 을 기준으로 pruning 한다.
1. PSNR 값을 구한다.
1. retraining을 진행.
1. 반복(일정 PSNR 이하 로 내려가기전까지)

In [7]:
criterion = nn.MSELoss()
def modify(net,alpla=64,beta=64,gamma=32):
    #global alpla
    #global gamma
    #global beta
    net.conv1 = nn.Conv2d(1, alpla, (5, 5), (1, 1), (2, 2))
    net.conv2=nn.Conv2d(alpla, beta, (3, 3), (1, 1), (1, 1))
    net.conv3=nn.Conv2d(beta, gamma, (3, 3), (1, 1), (1, 1))
    net.conv4 = nn.Conv2d(gamma, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
    
def PSNR(net):
    avg_psnr = 0
    for batch in testing_data_loader:
        input, target = Variable(batch[0]), Variable(batch[1])
        if cuda:
            input = input.cuda()
            target=target.cuda()
        prediction = net(input)
        mse = criterion(prediction, target)
        psnr = 10 * log10(1 / mse.data[0])
        avg_psnr += psnr
    return avg_psnr/len(testing_data_loader)

def _train(net,epoch=50):
    #net=nn.DataParallel(net) 이렇게 하는게 더 느리다 % time으로 확인
    optimizer = optim.Adam(net.parameters(), lr=0.005)
    epoch_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch[0]), Variable(batch[1])
        if cuda:
            input = input.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        loss = criterion(net(input), target)
        epoch_loss += loss.data[0]
        loss.backward()
        optimizer.step()
        #print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.data[0]))
    if epoch%10 is 0:
        print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))

In [8]:
def cal_pruning(net,alpla=64,beta=64,gamma=32,retrain=True):
    global weight_name
    # train을 넣으면 뎀
    # 네트워크의 파라미터 값과 pruning 하고자하는 네트워크의 instance를 넣는다
    # 그러면 weight를 load하여 모든 weight를 pruning 하여 psnr 값과 weight의 abs sum을 구해서 반환한다.
    # pruning 후 retraing 을 한다.
    # 만약 weight를 1개 이상 pruning 시에 'weight_name'을 반드시 신경써야한다.
    print("===> Starting Calculate Pruning")
    _model=torch.load(weight_name)
    keys_list=list(_model.keys())
    #print(keys_list)
    # ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias']
    psnr_list=[]
    for i in range(0,len(keys_list)-2,2):
        if i is 0:
            alpla-=1
            modify(net,alpla,beta,gamma)
        elif i is 2:
            alpla+=1
            beta-=1
            modify(net,alpla,beta,gamma)
        elif i is 4:
            beta+=1
            gamma-=1
            modify(net,alpla,beta,gamma)
            pruning+=1
            _model=torch.load(weight_name)
        for j in range(len(_model[keys_list[i]])):
            _model=torch.load(weight_name)
            weight_matrix=_model[keys_list[i]]
            bias_matrix=_model[keys_list[i+1]]
            temp_weight=0
            if j is 0:
                temp_weight=weight_matrix[0].abs().sum()
                _model[keys_list[i]]=weight_matrix[1:len(_model[keys_list[i]])]
                _model[keys_list[i+1]]=bias_matrix[1:len(_model[keys_list[i]])+1]        
            elif j is len(_model[keys_list[i]])-1:
                temp_weight=weight_matrix[len(model[keys_list[i]])-1].abs().sum()
                _model[keys_list[i]]=weight_matrix[0:len(_model[keys_list[i]])-1]                    
                _model[keys_list[i+1]]=bias_matrix[0:len(_model[keys_list[i]])]
            else:
                temp_weight=weight_matrix[j].abs().sum()
                _model[keys_list[i]]=torch.cat((weight_matrix[0:j],weight_matrix[j+1:len(_model[keys_list[i]])]))
                _model[keys_list[i+1]]=torch.cat((bias_matrix[0:j],bias_matrix[j+1:len(_model[keys_list[i]])+1]))
            if i is 0:
                _model[keys_list[i+2]].resize_(alpla,beta,3,3)
            elif i is 2:
                _model[keys_list[i+2]].resize_(beta,gamma,3,3)
            elif i is 4:
                _model[keys_list[i+2]].resize_(gamma,upscale_factor ** 2,3,3)
            net.load_state_dict(_model)
            net=net.cuda()
            if retrain is True:
                for k in range(0,200):
                    _train(net,k)
            prnr=PSNR(net)
            print('conv:',i,' num:',j,' psnr:',prnr,' size:',temp_weight)
            psnr_list.append(tuple([i,j,prnr,temp_weight]))
    return psnr_list

In [9]:
def _pruning(net):
    keys_list=list(keys)
    #print(keys_list)
    # ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias']
    psnr_list=[]
    pruning+=1
    for i in range(0,len(keys_list)-1,2):
        if i is 0:
            alpla-=pruning
            modify(train)
        elif i is 2:
            beta-=pruning
            modify(train)
        elif i is 4:
            gamma-=pruning
            modify(train)
            pruning+=1
        for j in range(len(model[keys_list[i]])):
            model=torch.load('weight1')
            weight_matrix=model[keys_list[i]]
            temp_weight=0
            bias_matrix=model[keys_list[i+1]]
            if j is 0:
                
                model[keys_list[i]]=weight_matrix[1:len(model[keys_list[i]])]
                temp_weight=weight_matrix[0].abs().sum()
                model[keys_list[i+1]]=bias_matrix[1:len(model[keys_list[i]])+1]        
            elif j is len(model[keys_list[i]])-1:
                temp_weight=weight_matrix[len(model[keys_list[i]])-1].abs().sum()
                model[keys_list[i]]=weight_matrix[0:len(model[keys_list[i]])-1]                    
                model[keys_list[i+1]]=bias_matrix[0:len(model[keys_list[i]])]
            else:
                temp_weight=weight_matrix[j].abs().sum()
                model[keys_list[i]]=torch.cat((weight_matrix[0:j],weight_matrix[j+1:len(model[keys_list[i]])]))
                model[keys_list[i+1]]=torch.cat((bias_matrix[0:j],bias_matrix[j+1:len(model[keys_list[i]])+1]))
            if i is 0:
                model[keys_list[i+2]].resize_(alpla,beta,3,3)
            elif i is 2:
                model[keys_list[i+2]].resize_(beta,gamma,3,3)
            elif i is 4:
                model[keys_list[i+2]].resize_(gamma,upscale_factor ** 2,3,3)
            train.load_state_dict(model)
            train=train.cuda()
            for k in range(0,100):
                _train(train,k)
            prnr=PSNR(train)
            print('conv:',i,' num:',j,' psnr:',prnr,' size:',temp_weight)
            psnr_list.append(tuple([i,j,prnr,temp_weight]))

In [10]:
def calculate(net):
    pass

In [11]:
def retrain(net):
    pass

In [12]:
def L2_distance(net):
    pass

In [None]:
cal_pruning(train)

777


  own_state[name].copy_(param)


===> Epoch 0 Complete: Avg. Loss: 0.2138
===> Epoch 10 Complete: Avg. Loss: 0.0311
===> Epoch 20 Complete: Avg. Loss: 0.0145
===> Epoch 30 Complete: Avg. Loss: 0.0218
===> Epoch 40 Complete: Avg. Loss: 0.0077
===> Epoch 50 Complete: Avg. Loss: 0.0104
===> Epoch 60 Complete: Avg. Loss: 0.0062
===> Epoch 70 Complete: Avg. Loss: 0.0062
===> Epoch 80 Complete: Avg. Loss: 0.0045
===> Epoch 90 Complete: Avg. Loss: 0.0054
===> Epoch 100 Complete: Avg. Loss: 0.0037
===> Epoch 110 Complete: Avg. Loss: 0.0064
===> Epoch 120 Complete: Avg. Loss: 0.0071
===> Epoch 130 Complete: Avg. Loss: 0.0035
===> Epoch 140 Complete: Avg. Loss: 0.0043
===> Epoch 150 Complete: Avg. Loss: 0.0045
===> Epoch 160 Complete: Avg. Loss: 0.0035
===> Epoch 170 Complete: Avg. Loss: 0.0040
===> Epoch 180 Complete: Avg. Loss: 0.0063
===> Epoch 190 Complete: Avg. Loss: 0.0062
conv: 0  num: 0  psnr: 26.19996456487369  size: 2.926562786102295
===> Epoch 0 Complete: Avg. Loss: 0.2235
===> Epoch 10 Complete: Avg. Loss: 0.0116
==