In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.init as init
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import skimage.transform
import scipy.misc
import skimage.measure

In [None]:
import Ipynb_importer
from utils import get_train_data
from utils import get_test_data
from utils import preproccess
from utils import readfile

In [None]:
class MyDataset(data.Dataset):
    """Map-style dataset that pairs ``images[k]`` with ``labels[k]``.

    Items are returned exactly as stored (no conversion, no transforms);
    the dataset length is taken from ``images``.
    """

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        sample = self.images[index]
        target = self.labels[index]
        return sample, target
    
def load_train(images, labels, batch_size=64, shuffle=False, num_workers=0):
    """Wrap paired training inputs/targets in a ``DataLoader``.

    Args:
        images: indexable collection of network inputs.
        labels: indexable collection of targets, aligned index-by-index with
            ``images``.
        batch_size: samples per batch (default 64).
        shuffle: reshuffle the samples every epoch. Defaults to False, which
            keeps the original sequential order; SGD training usually
            benefits from ``shuffle=True``.
        num_workers: DataLoader worker processes (default 0, i.e. load in
            the main process).

    Returns:
        A ``torch.utils.data.DataLoader`` over ``MyDataset(images, labels)``.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.
            Without this check, extra labels would be silently ignored, and
            missing labels would raise an ``IndexError`` partway through an epoch.
    """
    if len(images) != len(labels):
        raise ValueError(
            'images and labels must have the same length, got %d and %d'
            % (len(images), len(labels)))
    dataset = MyDataset(images, labels)
    return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                           num_workers=num_workers)

In [None]:
class SRCNN(nn.Module):
    """Three-layer SRCNN network operating on single-channel images.

    Layers: conv1 (9x9, 1 -> 64) + ReLU, conv2 (1x1, 64 -> 32) + ReLU,
    conv3 (5x5, 32 -> 1) + ReLU. Each layer is an ``nn.Sequential`` whose
    convolution sits at index 0. The ``state_dict`` keys are therefore
    ``conv{1,2,3}.0.weight`` / ``conv{1,2,3}.0.bias`` whatever the padding.

    Args:
        same_padding: if False (default), no padding is used and every layer
            shrinks the spatial size by ``kernel_size - 1``: a 33x33 input
            gives a 21x21 output. If True, each convolution is zero-padded by
            ``kernel_size // 2`` so the output keeps the input's spatial size,
            which is convenient for whole-image inference.
    """
    def __init__(self, same_padding=False):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(                       # input shape (1, 33, 33) for a 33x33 patch
                in_channels=1,               # single-channel input
                out_channels=64,             # number of filters
                kernel_size=9,               # filter size
                stride=1,                    # filter step
                padding=4 if same_padding else 0,
            ),                               # output shape (64, 25, 25) when unpadded
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(                       # input shape (64, 25, 25) when unpadded
                in_channels=64,              # input channels
                out_channels=32,             # number of filters
                kernel_size=1,               # 1x1 filters: spatial size unchanged
                stride=1,                    # filter step
                padding=0,
            ),                               # output shape (32, 25, 25) when unpadded
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,              # input channels
                out_channels=1,              # single-channel output
                kernel_size=5,               # filter size
                stride=1,                    # filter step
                padding=2 if same_padding else 0,
            ),                               # output shape (1, 21, 21) when unpadded
            # NOTE(review): the SRCNN paper uses a linear reconstruction layer;
            # this ReLU clamps every output to >= 0. Kept for compatibility with
            # existing trained weights and callers that assume non-negative output.
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply conv1 -> conv2 -> conv3 to an (N, 1, H, W) input."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x

    def initParams(self):
        """Initialize every Conv2d: weights ~ Gaussian(mean 0, std 0.001), biases 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.normal_(m.weight, 0, 0.001)  # Gaussian distribution
                init.constant_(m.bias, 0)
        

In [None]:
def run_train(path, image_size, label_size, stride, scale, epoch, cuda,
              params_path='./result/params.pkl'):
    """Train an SRCNN on patches produced by ``get_train_data`` and save it.

    Args:
        path, image_size, label_size, stride, scale: forwarded unchanged to
            ``get_train_data`` to build the training inputs/labels.
        epoch: number of passes over the training set.
        cuda: if True, train on the default CUDA device.
        params_path: file the trained ``state_dict`` is written to. Its
            parent directory is created if missing. Defaults to
            './result/params.pkl'.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    import os

    images, labels = get_train_data(path, image_size, label_size, stride, scale)
    train_loader = load_train(images, labels)
    srcnn = SRCNN()
    srcnn.initParams()
    if cuda:
        srcnn.cuda()

    # Learning rates: 1e-4 for the first two layers, 1e-5 for the
    # reconstruction layer (conv3).
    lr = 1e-4
    conv3_ids = {id(p) for p in srcnn.conv3.parameters()}
    base_params = [p for p in srcnn.parameters() if id(p) not in conv3_ids]
    optimizer = torch.optim.SGD([
                {'params': base_params},
                {'params': srcnn.conv3.parameters(), 'lr': lr * 0.1}],
                lr=lr, momentum=0.9)
    loss_func = nn.MSELoss()

    # Create the output directory before training, so a missing folder
    # cannot crash torch.save after the whole training run has finished.
    params_dir = os.path.dirname(params_path)
    if params_dir:
        os.makedirs(params_dir, exist_ok=True)

    # training
    srcnn.train()
    for e in range(epoch):
        epoch_loss = 0.0
        n_batches = 0
        for image, label in train_loader:
            image = image.float()
            label = label.float()
            if cuda:
                image, label = image.cuda(), label.cuda()
            output = srcnn(image)
            loss = loss_func(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate on-device to avoid a host sync on every batch.
            epoch_loss = epoch_loss + loss.detach()
            n_batches += 1
        if n_batches == 0:
            raise ValueError('training loader is empty; nothing to train on')
        # Report the mean loss over the epoch rather than the last batch only.
        print('epoch: ', e, '|loss: ', float(epoch_loss) / n_batches)
    torch.save(srcnn.state_dict(), params_path)  # save parameters
    print('Train step done!')
    
def _to_uint8(img):
    """Convert a float image expected in [0, 1] to a uint8 image in [0, 255].

    Values are clipped to [0, 1] first, because a float -> uint8 cast of an
    out-of-range value (e.g. a bicubic overshoot of 1.02 -> 260.1) does not
    saturate. The result is then rounded to the nearest integer.
    """
    return np.round(np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)


def run_test(read_path, save_path, scale, cuda, params_path='./result/params.pkl'):
    """Run a trained SRCNN over the test set, save images and print PSNR.

    For every test image ``i``, three 8-bit grayscale BMP files are written.
    Their names are built by plain string concatenation onto ``save_path``,
    so include a trailing separator when it is a directory:

    - ``<i>_target.bmp``: the SRCNN output.
    - ``<i>_ref.bmp``: the label image.
    - ``<i>_bicubic.bmp``: the network input.

    The PSNR of the input and of the output against the label is printed for
    each image.

    Args:
        read_path: forwarded to ``get_test_data``.
        save_path: filename prefix for the saved images.
        scale: forwarded to ``get_test_data``.
        cuda: if True, run inference on the default CUDA device.
        params_path: weights file to load. Defaults to './result/params.pkl'.
    """
    # compare_psnr was removed in scikit-image 0.18 in favour of
    # skimage.metrics.peak_signal_noise_ratio (same result for uint8 inputs).
    try:
        from skimage.metrics import peak_signal_noise_ratio as psnr
    except ImportError:  # older scikit-image
        psnr = skimage.measure.compare_psnr

    srcnn = SRCNN()
    # map_location='cpu' lets weights saved from a GPU run load on a CPU-only
    # machine; load_state_dict copies them onto the model's device afterwards.
    srcnn.load_state_dict(torch.load(params_path, map_location='cpu'))  # load model
    if cuda:
        srcnn.cuda()
    srcnn.eval()

    # The model is built without padding. Pad each conv by kernel_size // 2
    # (9 -> 4, 1 -> 0, 5 -> 2) so the output has the same size as the input.
    convs = [m for m in srcnn.modules() if isinstance(m, nn.Conv2d)]
    for conv, pad in zip(convs, (4, 0, 2)):
        conv.padding = (pad, pad)

    # load data
    images, labels = get_test_data(read_path, scale=scale)
    tensor_type = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    for i in range(len(images)):
        ref = _to_uint8(labels[i])
        bicubic = _to_uint8(images[i])
        # assumes images[i] is a 2-D (H, W) array: add batch and channel axes
        inp = torch.from_numpy(images[i][np.newaxis, np.newaxis, :]).type(tensor_type)
        with torch.no_grad():
            out = srcnn(inp)
        target = _to_uint8(np.squeeze(out.cpu().numpy()))
        # save images
        picture = Image.fromarray(target, mode='L')
        picture.save(save_path+str(i)+'_target.bmp')
        picture = Image.fromarray(ref, mode='L')
        picture.save(save_path+str(i)+'_ref.bmp')
        picture = Image.fromarray(bicubic, mode='L')
        picture.save(save_path+str(i)+'_bicubic.bmp')
        # compute PSNR values
        print('num:', i)
        print('ref--bicubic:', psnr(ref, bicubic))
        print('ref--target:', psnr(ref, target))

    print('Test step done!')