In [31]:
from __future__ import print_function
import argparse,random
from math import log10
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from utils.data import get_training_set, get_test_set
from torch.nn.modules.module import _addindent
from pandas import DataFrame
import pandas as pd

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


`%matplotlib` prevents importing * from pylab and numpy
  "\n`%matplotlib` prevents importing * from pylab and numpy"


In [2]:
# model 1을 사용할시에
from net.model import Net
# model 2을 사용할시에
#from net.model_dw import Net

In [3]:
# cuda import
cuda = True
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")
torch.manual_seed(random.randint(1,1000))
if cuda:
    torch.cuda.manual_seed(random.randint(1,1000))

In [4]:
# dataset import
train_set = get_training_set(2,"BSDS300")
test_set = get_test_set(2,"BSDS300")
training_data_loader = DataLoader(dataset=train_set, num_workers=6, batch_size=16, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=6, batch_size=100, shuffle=False)

In [5]:

upscale_factor=2
train=Net(upscale_factor)
weight_name='weight1'

In [6]:
# load weigh
model=torch.load(weight_name)
keys=model.keys()

# Architecture
___

```
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2=nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3=nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
```

<br>
<br>
__upscale_factor는 2 이다__

# Pruning method
___

1. Pruning을 할 수 있는 spot을 선택 
1. 1개를 선택하여 prunging 한 후에 Net의 아키텍쳐를 바꾼다.
1. L1 norm 을 기준으로 pruning 한다.
1. PSNR 값을 구한다.
1. retraining을 진행.
1. 반복(일정 PSNR 이하 로 내려가기전까지)

In [7]:
criterion = nn.MSELoss()
def modify(net,alpha=64,beta=64,gamma=32):
    """
    네트워크와 하이퍼 파라미터를 입력받아
    네트워크를 하이퍼 파라미터에 맞게 변형한다.
    """
    net.conv1 = nn.Conv2d(1, alpha, (5, 5), (1, 1), (2, 2))
    net.conv2=nn.Conv2d(alpha, beta, (3, 3), (1, 1), (1, 1))
    net.conv3=nn.Conv2d(beta, gamma, (3, 3), (1, 1), (1, 1))
    net.conv4 = nn.Conv2d(gamma, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
    
def PSNR(net):
    '''
    네트워크 를 받아서 psnr 값을 구하여 반환한다.
    테스트 배치 만큼 수행 평균
    '''
    avg_psnr = 0
    for batch in testing_data_loader:
        input, target = Variable(batch[0]), Variable(batch[1])
        if cuda:
            input = input.cuda()
            target=target.cuda()
        prediction = net(input)
        mse = criterion(prediction, target)
        psnr = 10 * log10(1 / mse.data[0])
        avg_psnr += psnr
    return avg_psnr/len(testing_data_loader)

def _train(net,epoch=50):
    '''
    네트워크와 에폭,학습률을 입력 받아 그에 맞게 학습시킨다.
    '''
    #net=nn.DataParallel(net) 이렇게 하는게 더 느리다 % time으로 확인
    optimizer = optim.Adam(net.parameters(), lr=0.005)
    epoch_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch[0]), Variable(batch[1])
        if cuda:
            input = input.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        loss = criterion(net(input), target)
        epoch_loss += loss.data[0]
        loss.backward()
        optimizer.step()
        #print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.data[0]))
    if epoch%10 is 0:
        print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))

def rank(_list):
    '''
    _list를 입력받아 계산한뒤 pruning 할 (conv,weight)의 위치를 tuple로 반환한다.
    pruning 할 layer를 고른다.
    기준은 L2와 PSNR 을 기준으로 rank를 만든 뒤 합하여 결정한다.
    '''
    df=DataFrame(_list,columns=['conv','weight','psnr','L2'])
    df['rank']=df['psnr'].rank( ascending=False,method='max')+df['L2'].rank(method='max')
    for idx,i in enumerate(df["L2"]):
        if i == min(df["L2"]):
            print(df['conv'][idx],"th conv's ",df['weight'][idx],"'s layer will pruning")
            return(df['conv'][idx],df['weight'][idx])

In [8]:
def cal_pruning(net,alpha=64,beta=64,gamma=32,retrain=True,epoch=200):
    global weight_name
    # train을 넣으면 뎀
    # 네트워크의 파라미터 값과 pruning 하고자하는 네트워크의 instance를 넣는다
    # 그러면 weight를 load하여 모든 weight를 pruning 하여 psnr 값과 weight의 abs sum을 구해서 반환한다.
    # pruning 후 retraing 을 한다.
    # 만약 weight를 1개 이상 pruning 시에 'weight_name'을 반드시 신경써야한다.
    print("===> Starting Calculate Pruning")
    _model=torch.load(weight_name)
    keys_list=list(_model.keys())
    #print(keys_list)
    # ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias']
    psnr_list=[]
    for i in range(0,len(keys_list)-2,2):
        if i is 0:
            alpha-=1
            modify(net,alpha,beta,gamma)
        elif i is 2:
            alpha+=1
            beta-=1
            modify(net,alpha,beta,gamma)
        elif i is 4:
            beta+=1
            gamma-=1
            modify(net,alpha,beta,gamma)
            _model=torch.load(weight_name)
        for j in range(len(_model[keys_list[i]])):
            _model=torch.load(weight_name)
            weight_matrix=_model[keys_list[i]]
            bias_matrix=_model[keys_list[i+1]]
            temp_weight=0
            if j is 0:
                temp_weight=weight_matrix[0].abs().sum()
                _model[keys_list[i]]=weight_matrix[1:len(_model[keys_list[i]])]
                _model[keys_list[i+1]]=bias_matrix[1:len(_model[keys_list[i]])+1]        
            elif j is len(_model[keys_list[i]])-1:
                temp_weight=weight_matrix[len(model[keys_list[i]])-1].abs().sum()
                _model[keys_list[i]]=weight_matrix[0:len(_model[keys_list[i]])-1]                    
                _model[keys_list[i+1]]=bias_matrix[0:len(_model[keys_list[i]])]
            else:
                temp_weight=weight_matrix[j].abs().sum()
                _model[keys_list[i]]=torch.cat((weight_matrix[0:j],weight_matrix[j+1:len(_model[keys_list[i]])]))
                _model[keys_list[i+1]]=torch.cat((bias_matrix[0:j],bias_matrix[j+1:len(_model[keys_list[i]])+1]))
            if i is 0:
                _model[keys_list[i+2]].resize_(alpha,beta,3,3)
            elif i is 2:
                _model[keys_list[i+2]].resize_(beta,gamma,3,3)
            elif i is 4:
                _model[keys_list[i+2]].resize_(gamma,upscale_factor ** 2,3,3)
            net.load_state_dict(_model)
            net=net.cuda()
            if retrain is True:
                for k in range(0,epoch):
                    _train(net,k)
            prnr=PSNR(net)
            print('conv:',i,' num:',j,' psnr:',prnr,' size:',temp_weight)
            psnr_list.append(tuple([i,j,prnr,temp_weight]))
    return psnr_list

In [9]:
def _pruning(_model,_weight,index,epoch=50):
    """
    _model: 모델을 받는다.
    _weight:weight값을 받는다.
    index: pruning을 시행할 convolution 의 index 값을 받는다.
    epoch: pruning후 retrain 할 횟수
    """
    keys_list=list(_weight.keys())
    temp=[]
    for i in range(0,len(keys_list)-2,2):
        temp.append(len(_weight[keys_list[i]]))
    alpha,beta,gamma=temp
    
    print('temp',temp,'alpha',alpha,'beta',beta,'gamma',gamma)
    i=0
    j=index[1]
    
    if index[0] is 1:
        i=0
        alpha-=1
        modify(_train,alpha,beta,gamma)
    elif index[0] is 2:
        i=2
        beta-=1
        modify(_train,alpha,beta,gamma)
    elif index[0] is 3:
        i=4
        gamma-=1
        modify(_train,alpha,beta,gamma)
    if j>len(_weight[keys_list[i]]):
        print("illegal pruning")
        return
    weight_matrix=_weight[keys_list[i]]
    bias_matrix=_weight[keys_list[i+1]]
    if index[1] is 0:
        _weight[keys_list[i]]=weight_matrix[1:len(_weight[keys_list[i]])]
        _weight[keys_list[i+1]]=bias_matrix[1:len(_weight[keys_list[i]])+1]
    elif j is len(model[keys_list[i]])-1:
        _weight[keys_list[i]]=weight_matrix[0:len(_weight[keys_list[i]])-1]
        _weight[keys_list[i+1]]=bias_matrix[0:len(_weight[keys_list[i]])]
    else:
        _weight[keys_list[i]]=torch.cat((weight_matrix[0:j],weight_matrix[j+1:len(_weight[keys_list[i]])]))
        _weight[keys_list[i+1]]=torch.cat((bias_matrix[0:j],bias_matrix[j+1:len(_weight[keys_list[i]])+1]))
    if i is 0:
        _weight[keys_list[i+2]].resize_(alpha,beta,3,3)
    elif i is 2:
        _weight[keys_list[i+2]].resize_(beta,gamma,3,3)
    elif i is 4:
        _weight[keys_list[i+2]].resize_(gamma,upscale_factor ** 2,3,3)
    _model=_model.cuda()
    for k in range(0,epoch):
        _train(_model,k)
    prnr=PSNR(_model)
    print('conv:',i,' num:',j,'is pruning psnr:',prnr)    

In [26]:
_list=[(0, 0, 20.77661644506897, 2.926562786102295), (0, 1, 20.41768770453681, 3.5600385665893555), (0, 2, 21.549112205585175, 2.7003331184387207), (0, 3, 18.773605118373368, 3.595670461654663), (0, 4, 19.113337161462688, 3.4622962474823), (0, 5, 18.17286976993348, 3.3184256553649902), (0, 6, 20.490117540226827, 4.114063739776611), (0, 7, 18.82515003003567, 3.536667823791504), (0, 8, 18.47267363819492, 3.7133820056915283), (0, 9, 20.531406996545527, 3.8943541049957275), (0, 10, 20.932796348766146, 3.2882728576660156), (0, 11, 18.503472247654425, 3.313321113586426), (0, 12, 21.339368770946724, 3.788851499557495), (0, 13, 17.689891602408903, 3.192880392074585), (0, 14, 16.23228693744893, 3.399325370788574), (0, 15, 20.157074926646633, 2.5385217666625977), (0, 16, 20.765971907352263, 3.666475534439087), (0, 17, 19.462767336023347, 3.028347969055176), (0, 18, 19.791597683117182, 3.0369980335235596), (0, 19, 19.639318127883687, 3.623494863510132), (0, 20, 16.311553520967813, 4.507632732391357), (0, 21, 14.523354704676274, 4.4181718826293945), (0, 22, 20.913047971014787, 3.20487642288208), (0, 23, 20.447936083034165, 3.7541511058807373), (0, 24, 18.19479218971695, 3.6289429664611816), (0, 25, 17.954477342464198, 3.1502187252044678), (0, 26, 14.195130846601486, 4.05465841293335), (0, 27, 20.563746834062464, 3.0270135402679443), (0, 28, 20.027483126400934, 3.27731990814209), (0, 29, 14.113506024629734, 3.9907984733581543), (0, 30, 20.42573416450304, 4.119686126708984), (0, 31, 14.689463874664684, 3.4438977241516113), (0, 32, 14.790337829936288, 3.987121343612671), (0, 33, 13.333680565835406, 4.550487518310547), (0, 34, 16.63787669518929, 3.6898210048675537), (0, 35, 19.58542680623781, 3.15950345993042), (0, 36, 16.918825026736112, 3.5656967163085938), (0, 37, 20.032690201173565, 3.4216361045837402), (0, 38, 19.108164090114713, 2.4204752445220947), (0, 39, 19.860290665880022, 7.463708400726318), (0, 40, 14.384238523570883, 3.291673183441162), (0, 41, 19.720079122526922, 3.303051710128784), (0, 42, 17.515932842870395, 2.8272957801818848), (0, 43, 18.972669597411997, 2.9348325729370117), (0, 44, 17.770815466189287, 3.683912992477417), (0, 45, 19.246577581747363, 2.943221092224121), (0, 46, 18.94378585033919, 3.127673387527466), (0, 47, 20.54100007820095, 3.9300012588500977), (0, 48, 19.04955246197035, 4.009859085083008), (0, 49, 20.558149001425868, 3.2004384994506836), (0, 50, 20.127893556941288, 3.247309446334839), (0, 51, 20.0509792227718, 3.6704821586608887), (0, 52, 20.963362288971254, 3.3864877223968506), (0, 53, 18.97732953534056, 2.581829309463501), (0, 54, 19.089153273514494, 4.3378520011901855), (0, 55, 18.680672036490517, 3.335524797439575), (0, 56, 12.75876394909985, 3.4207117557525635), (0, 57, 21.33518581889275, 3.569721221923828), (0, 58, 20.90141303412857, 3.6195387840270996), (0, 59, 18.048461014275837, 2.8329243659973145), (0, 60, 19.39379577409051, 3.5531005859375), (0, 61, 16.606880597875342, 3.13732647895813), (0, 62, 19.651043312346324, 4.168920040130615), (0, 63, 20.355940304883667, 3.7978355884552), (2, 0, 19.84801097224956, 61.58285140991211), (2, 1, 18.29354919437869, 43.41734313964844), (2, 2, 21.437423943726458, 27.323036193847656), (2, 3, 20.455409031104086, 71.69295501708984), (2, 4, 17.0305104164365, 49.75033950805664), (2, 5, 18.762330702209844, 27.425106048583984), (2, 6, 19.274372895716123, 57.3931884765625), (2, 7, 20.52870080690264, 51.70585632324219), (2, 8, 20.41645884894605, 49.07394027709961), (2, 9, 18.321255392820845, 27.537050247192383), (2, 10, 19.501449172037564, 70.44744110107422), (2, 11, 19.637076367838212, 40.71329879760742), (2, 12, 17.66924119483204, 48.57335662841797), (2, 13, 14.66911976284413, 32.46393585205078), (2, 14, 17.54124475907022, 54.75020980834961), (2, 15, 21.149967980767848, 62.5579948425293), (2, 16, 16.344787927436997, 29.935394287109375), (2, 17, 15.017591931291461, 51.999637603759766), (2, 18, 20.049288555520235, 52.67927551269531), (2, 19, 19.332250259487747, 32.18766403198242), (2, 20, 19.9910874876099, 32.61397933959961), (2, 21, 17.611382815795302, 47.3238639831543), (2, 22, 20.311692160041332, 27.193157196044922), (2, 23, 18.006809300320548, 59.52326965332031), (2, 24, 21.93169867182915, 28.06719207763672), (2, 25, 14.827178323718995, 58.91081237792969), (2, 26, 20.344084682327743, 28.05398178100586), (2, 27, 16.80306685896461, 44.767005920410156), (2, 28, 20.697745290284615, 30.19063377380371), (2, 29, 20.843417953631764, 67.73876953125), (2, 30, 18.038150113648967, 48.499996185302734), (2, 31, 19.151107185862816, 45.57038497924805), (2, 32, 18.862774723055463, 58.91701889038086), (2, 33, 21.320642643533866, 27.56694221496582), (2, 34, 17.616544494252498, 28.709367752075195), (2, 35, 19.42758640620747, 29.082334518432617), (2, 36, 17.5640640377351, 58.85453414916992), (2, 37, 19.60980443249504, 40.0869140625), (2, 38, 21.054781870494644, 29.688751220703125), (2, 39, 19.261350628241235, 51.473243713378906), (2, 40, 21.998663574800965, 29.55811882019043), (2, 41, 16.764795041539777, 30.073591232299805), (2, 42, 20.6674441687555, 29.600261688232422), (2, 43, 19.863221692115342, 26.521034240722656), (2, 44, 21.933867138144386, 54.325965881347656), (2, 45, 17.282180968392808, 70.47332000732422), (2, 46, 20.582060193683805, 62.675682067871094), (2, 47, 21.628978277272655, 29.943920135498047), (2, 48, 19.827257177355197, 60.13360595703125), (2, 49, 20.052616916919852, 27.96942710876465), (2, 50, 21.70877665301356, 27.981142044067383), (2, 51, 19.845638980271808, 31.30326271057129), (2, 52, 20.37056211052075, 58.34745407104492), (2, 53, 22.286414140046993, 49.973365783691406), (2, 54, 17.391261775481485, 26.96883201599121), (2, 55, 21.350673412371826, 60.6932487487793), (2, 56, 18.340536303548944, 49.263763427734375), (2, 57, 19.053739083452392, 30.553173065185547), (2, 58, 21.728673967604077, 37.83761978149414), (2, 59, 21.020756192923464, 56.79644775390625), (2, 60, 21.588592981387563, 56.2254638671875), (2, 61, 19.50795896634498, 27.578920364379883), (2, 62, 19.159659867256867, 28.392845153808594), (4, 0, 19.871072444584467, 60.41963577270508), (4, 1, 21.5413794288299, 27.628808975219727), (4, 2, 18.668293080360947, 70.2166748046875), (4, 3, 20.818288208370692, 31.705013275146484), (4, 4, 17.38612512760871, 30.10431671142578), (4, 5, 18.67889838336065, 27.371559143066406), (4, 6, 19.324197187854637, 38.51417922973633), (4, 7, 13.7493529393428, 32.97629165649414), (4, 8, 16.615573060918674, 27.24710464477539), (4, 9, 17.565734177644178, 27.652854919433594), (4, 10, 19.858207513791243, 35.70361328125), (4, 11, 14.947060847668794, 53.870113372802734), (4, 12, 21.754126969475884, 51.15950393676758), (4, 13, 20.052216134386995, 41.39530944824219), (4, 14, 19.952340357880406, 56.67424011230469), (4, 15, 17.543991877899423, 58.58237075805664), (4, 16, 20.105585205197983, 68.0105209350586), (4, 17, 19.44102192173052, 42.417171478271484), (4, 18, 19.576666243010493, 47.270389556884766), (4, 19, 19.636329371562972, 29.60125160217285), (4, 20, 15.602501271329185, 47.81952667236328), (4, 21, 14.678765775037535, 27.569072723388672), (4, 22, 18.199114800571447, 81.61869812011719), (4, 23, 19.282585689567494, 34.710174560546875), (4, 24, 18.693415685947162, 29.281709671020508), (4, 25, 20.005635567936373, 75.50933074951172), (4, 26, 19.75584278749047, 27.864791870117188), (4, 27, 16.921454887362163, 30.664382934570312), (4, 28, 20.42240778240071, 28.4443359375), (4, 29, 22.525673032530854, 34.72040557861328), (4, 30, 20.55029897855336, 29.055706024169922), (4, 31, 17.023319499242422, 34.85955047607422)]

In [16]:
_pruning(train,model,(1,3),epoch=50)

temp [64, 64, 32] alpha 64 beta 64 gamma 32
===> Epoch 0 Complete: Avg. Loss: 0.1044
===> Epoch 10 Complete: Avg. Loss: 0.0107
===> Epoch 20 Complete: Avg. Loss: 0.0135
===> Epoch 30 Complete: Avg. Loss: 0.0080
===> Epoch 40 Complete: Avg. Loss: 0.0123
conv: 0  num: 3 is pruning psnr: 23.842403480265723
