In [1]:
from __future__ import print_function
import argparse,random
from math import log10
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from utils.data import get_training_set, get_test_set
from torch.nn.modules.module import _addindent
from pandas import DataFrame
import pandas as pd
from collections import OrderedDict
from copy import deepcopy
import quant

In [2]:
# When using model 1:
#from net.model import Net
# When using model 2:
from net.model_dw import Net

In [3]:
# Device selection and RNG seeding.
# Set `cuda` to False to run everything on the CPU.
cuda = True
if cuda and not torch.cuda.is_available():
    # This is a notebook: there is no --cuda command-line flag to drop,
    # the switch to flip is the `cuda` variable above.
    raise Exception("No GPU found, please set cuda = False")
# Draw the seed once and keep it in `seed`, so the value used for a run can be
# inspected afterwards and the run reproduced by assigning it here.
seed = random.randint(1,1000)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

In [4]:
# Build the BSDS300 train/test datasets and wrap them in loaders.
# NOTE(review): the leading 2 is presumably the upscale factor -- confirm
# against utils.data; it is not tied to `upscale_factor` defined later.
train_set = get_training_set(2, "BSDS300")
test_set = get_test_set(2, "BSDS300")
training_data_loader = DataLoader(
    dataset=train_set,
    batch_size=16,
    shuffle=True,
    num_workers=11,
)
testing_data_loader = DataLoader(
    dataset=test_set,
    batch_size=100,
    shuffle=False,
    num_workers=10,
)

In [5]:
# Checkpoint file that the following cells hand to torch.load().
weight_name = 'weight2'
# Upscale factor passed to the Net constructor.
upscale_factor = 2
# `train` is the network instance that every later cell loads weights into,
# retrains and evaluates.
train = Net(upscale_factor)

In [6]:
# Load the saved weights; later cells use the result as a state dict
# (it is quantized and passed to train.load_state_dict()).
model = torch.load(weight_name)
keys = model.keys()

# Architecture
___

```
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2=nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3=nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
```

<br>
<br>
__upscale_factor는 2 이다__

<br><br>
####  weight1 초기 PSNR

```
===> Avg. PSNR: 27.7404 dB
```

# Pruning method
___

1. Pruning을 할 수 있는 spot을 선택한다.
1. 1개를 선택하여 pruning 한 후에 Net의 아키텍처를 바꾼다.
1. L1 norm을 기준으로 pruning 한다.
1. PSNR 값을 구한다.
1. Retraining을 진행한다.
1. 반복한다(PSNR이 일정 값 이하로 내려가기 전까지).





In [7]:
def bit_truncation(model,bits=8,quant_method='linear',overflow_rate=0.0):
    '''
    Quantize every entry of a state dict to ``bits`` bits.

    Args:
        model: mapping of name -> value (e.g. the result of ``torch.load`` or
            ``net.state_dict()``). It is deep-copied before quantization, so
            the quantizers only ever see the copy.
        bits: target bit width. For ``bits >= 32`` nothing is quantized and the
            copied values are returned unchanged.
        quant_method: 'linear' (default), 'log' or 'minmax'; any other value
            selects the tanh quantizer.
        overflow_rate: forwarded to ``quant.compute_integral_part`` when the
            linear method is used.

    Returns:
        OrderedDict with the same keys, in the same order, holding the
        quantized values.
    '''
    state_dict_quant = OrderedDict()
    for k, v in deepcopy(model).items():
        if bits >= 32:
            print("Ignoring {}".format(k))
            state_dict_quant[k] = v
            continue
        if quant_method == 'linear':
            # sf = bits left over after the sign bit and the integral part
            # reported by quant.compute_integral_part.
            sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=overflow_rate)
            v_quant = quant.linear_quantize(v, sf, bits=bits)
        elif quant_method == 'log':
            v_quant = quant.log_minmax_quantize(v, bits=bits)
        elif quant_method == 'minmax':
            v_quant = quant.min_max_quantize(v, bits=bits)
        else:
            v_quant = quant.tanh_quantize(v, bits=bits)
        state_dict_quant[k] = v_quant
    return state_dict_quant
# Mean-squared-error loss shared by PSNR(), _train() and test() below.
criterion = nn.MSELoss()
def PSNR(net):
    '''
    Return the average PSNR (in dB) of ``net`` over ``testing_data_loader``.

    One PSNR value is computed per test batch from the batch MSE as
    ``10 * log10(1 / mse)`` and the per-batch values are averaged.
    NOTE(review): the ``1 / mse`` form assumes a peak signal value of 1,
    i.e. pixel values scaled to [0, 1] -- confirm against utils.data.

    Args:
        net: the network to evaluate. When ``cuda`` is set it is moved to the
            GPU, because the inputs are moved there as well.

    Returns:
        Mean of the per-batch PSNR values as a float.
    '''
    if cuda:
        net = net.cuda()
    avg_psnr = 0
    for batch in testing_data_loader:
        input, target = Variable(batch[0]), Variable(batch[1])
        if cuda:
            input = input.cuda()
            target = target.cuda()
        prediction = net(input)
        mse = criterion(prediction, target)
        # Indexing a 0-dim loss with [0] raises IndexError on newer PyTorch;
        # use .item() where it exists and keep the old indexing as a fallback.
        mse_value = mse.item() if hasattr(mse, 'item') else mse.data[0]
        avg_psnr += 10 * log10(1 / mse_value)
    return avg_psnr/len(testing_data_loader)
def _train(net, lr=0.005):
    '''
    Train ``net`` for one pass over ``training_data_loader``.

    A fresh Adam optimizer is created on every call, so optimizer state does
    not carry over between calls.

    Args:
        net: the network to train (updated in place). When ``cuda`` is set it
            is moved to the GPU before the optimizer is built.
        lr: Adam learning rate.

    Returns:
        Average training loss over the batches of this pass
        (0.0 if the loader yields no batches).
    '''
    # nn.DataParallel(net) was tried here and turned out slower (checked with %time).
    if cuda:
        net = net.cuda()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    epoch_loss = 0
    n_batches = 0
    for n_batches, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch[0]), Variable(batch[1])
        if cuda:
            input = input.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        loss = criterion(net(input), target)
        # Indexing a 0-dim loss with [0] raises IndexError on newer PyTorch;
        # use .item() where it exists and keep the old indexing as a fallback.
        epoch_loss += loss.item() if hasattr(loss, 'item') else loss.data[0]
        loss.backward()
        optimizer.step()
    return epoch_loss / n_batches if n_batches else 0.0


def test(net):
    '''
    Print the average PSNR (in dB) of ``net`` over ``testing_data_loader``.

    One PSNR value is computed per test batch as ``10 * log10(1 / mse)`` and
    the mean over all batches is printed once, after the loop. Nothing is
    printed when the loader yields no batches.

    Args:
        net: the network to evaluate. When ``cuda`` is set it is moved to the
            GPU once, before the first batch.
    '''
    if cuda:
        net = net.cuda()
    n_batches = len(testing_data_loader)
    if n_batches == 0:
        return
    avg_psnr = 0
    for batch in testing_data_loader:
        input, target = Variable(batch[0]), Variable(batch[1])
        if cuda:
            input = input.cuda()
            target = target.cuda()
        prediction = net(input)
        mse = criterion(prediction, target)
        # Indexing a 0-dim loss with [0] raises IndexError on newer PyTorch;
        # use .item() where it exists and keep the old indexing as a fallback.
        mse_value = mse.item() if hasattr(mse, 'item') else mse.data[0]
        avg_psnr += 10 * log10(1 / mse_value)
    # Report once, after all batches: printing inside the loop showed partial
    # sums divided by the total batch count.
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / n_batches))

In [8]:
def bit_truncation(model,bits=8,quant_method='linear',overflow_rate=0.0):
    '''
    Quantize every entry of a state dict to ``bits`` bits.

    Args:
        model: mapping of name -> value (e.g. the result of ``torch.load`` or
            ``net.state_dict()``). It is deep-copied before quantization, so
            the quantizers only ever see the copy.
        bits: target bit width. For ``bits >= 32`` nothing is quantized and the
            copied values are returned unchanged.
        quant_method: 'linear' (default), 'log' or 'minmax'; any other value
            selects the tanh quantizer.
        overflow_rate: forwarded to ``quant.compute_integral_part`` when the
            linear method is used.

    Returns:
        OrderedDict with the same keys, in the same order, holding the
        quantized values.
    '''
    state_dict_quant = OrderedDict()
    for k, v in deepcopy(model).items():
        if bits >= 32:
            print("Ignoring {}".format(k))
            state_dict_quant[k] = v
            continue
        if quant_method == 'linear':
            # sf = bits left over after the sign bit and the integral part
            # reported by quant.compute_integral_part.
            sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=overflow_rate)
            v_quant = quant.linear_quantize(v, sf, bits=bits)
        elif quant_method == 'log':
            v_quant = quant.log_minmax_quantize(v, bits=bits)
        elif quant_method == 'minmax':
            v_quant = quant.min_max_quantize(v, bits=bits)
        else:
            v_quant = quant.tanh_quantize(v, bits=bits)
        state_dict_quant[k] = v_quant
    return state_dict_quant

In [9]:
# Experiment: quantize the saved checkpoint to 12 bits, then retrain for 100
# passes over the training set, re-quantizing after every pass.
# The checkpoint is reloaded here (as in the other experiment cells) so the
# cell does not depend on `model` left behind by an earlier cell and gives the
# same starting point when it is re-run.
model=torch.load(weight_name)
bit=12
model=bit_truncation(model,bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(100):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model=train.state_dict()
    model=bit_truncation(model,bit)
    train.load_state_dict(model)
print('학습후')
test(train)


학습전
===> Avg. PSNR: 27.5622 dB
학습후
===> Avg. PSNR: 25.0758 dB


In [17]:
# Experiment: restart from the saved checkpoint, quantize it to 8 bits, then
# retrain for 350 passes over the training set, re-quantizing after each pass.
bit = 8
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(350):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

학습전
===> Avg. PSNR: 26.9096 dB
학습후
===> Avg. PSNR: 26.5956 dB


In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 400 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(400):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

학습전
===> Avg. PSNR: 27.5622 dB


In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 500 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(500):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 600 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(600):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 700 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(700):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 800 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(800):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 900 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(900):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1000 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1000):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1100 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1100):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1200 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1200):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1300 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1300):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1400 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1400):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1500 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1500):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1600 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1600):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1700 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1700):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1800 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1800):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 1900 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(1900):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)

In [None]:
# Experiment: restart from the saved checkpoint, quantize it to 12 bits, then
# retrain for 2000 passes over the training set, re-quantizing after each pass.
bit = 12
model = bit_truncation(torch.load(weight_name), bit)
train.load_state_dict(model)
print('학습전')
test(train)
for i in range(2000):
    _train(train)
    # Re-quantize the freshly trained weights and load them back.
    model = bit_truncation(train.state_dict(), bit)
    train.load_state_dict(model)
print('학습후')
test(train)