# SRCNN pytorch

In [None]:
import torch.utils.data as data
from os import listdir
from os.path import join
from PIL import Image
from os.path import exists, join, basename
from os import remove
from six.moves import urllib
import tarfile
from torchvision.transforms import Compose, CenterCrop, ToTensor,Scale

## createdataset

In [None]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
def load_img(filepath):
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y
class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

        self.input_transform = input_transform
        self.target_transform = target_transform
    def __getitem__(self, index):
        input_image = load_img(self.image_filenames[index])
        target = input_image.copy()
        if self.input_transform:
            input_image = self.input_transform(input_image)
        if self.target_transform:
            target = self.target_transform(target)
        return input_image, target
    def __len__(self):
        return len(self.image_filenames)

In [4]:
def download_bsd300(dest="./dataset"):
    output_image_dir = join(dest, "BSDS300/images")

    if not exists(output_image_dir):
        url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
        print("downloading url ", url)
        data = urllib.request.urlopen(url)
        file_path = join(dest, basename(url))
        with open(file_path, 'wb') as f:
            f.write(data.read())
        print("Extracting data")
        with tarfile.open(file_path) as tar:
            for item in tar:
                tar.extract(item, dest)
        remove(file_path)
    return output_image_dir

def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)

def input_transform(crop_size, upscale_factor):
    return Compose([
        CenterCrop(crop_size),
        Scale(crop_size // upscale_factor),
        ToTensor(),
    ])

def target_transform(crop_size):
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),
    ])


def get_training_set(upscale_factor):
    root_dir = download_bsd300()
    train_dir = join(root_dir, "train")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))

def get_test_set(upscale_factor):
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))

## Defining the Networks

In [5]:
import torch
import torch.nn as nn
class Net(torch.nn.Module):
    def __init__(self, num_channels, base_filter, upscale_factor=2):
        super(Net, self).__init__()
        self.layers = torch.nn.Sequential(
            nn.Conv2d(in_channels=num_channels, out_channels=base_filter, kernel_size=9, stride=1, padding=4,
                      bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=5, stride=1, padding=2,
                      bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=base_filter // 2, out_channels=num_channels * (upscale_factor ** 2), kernel_size=5, stride=1, padding=2,
                      bias=True),
            nn.PixelShuffle(upscale_factor)
        )

    def forward(self, x):
        out = self.layers(x)
        return out

    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

## Training the network

In [6]:
import sys
import time
TOTAL_BAR_LENGTH = 80
LAST_T = time.time()
BEGIN_T = LAST_T
def progress_bar(current, total, msg=None):
    global LAST_T, BEGIN_T
    if current == 0:
        BEGIN_T = time.time()  # Reset for new bar.

    current_len = int(TOTAL_BAR_LENGTH * (current + 1) / total)
    rest_len = int(TOTAL_BAR_LENGTH - current_len) - 1

    sys.stdout.write(' %d/%d' % (current + 1, total))
    sys.stdout.write(' [')
    for i in range(current_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    current_time = time.time()
    step_time = current_time - LAST_T
    LAST_T = current_time
    total_time = current_time - BEGIN_T

    time_used = '  Step: %s' % format_time(step_time)
    time_used += ' | Tot: %s' % format_time(total_time)
    if msg:
        time_used += ' | ' + msg

    msg = time_used
    sys.stdout.write(msg)

    if current < total - 1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
# return the formatted time
def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    seconds_final = int(seconds)
    seconds = seconds - seconds_final
    millis = int(seconds*1000)
    output = ''
    time_index = 1
    if days > 0:
        output += str(days) + 'D'
        time_index += 1
    if hours > 0 and time_index <= 2:
        output += str(hours) + 'h'
        time_index += 1
    if minutes > 0 and time_index <= 2:
        output += str(minutes) + 'm'
        time_index += 1
    if seconds_final > 0 and time_index <= 2:
        output += str(seconds_final) + 's'
        time_index += 1
    if millis > 0 and time_index <= 2:
        output += str(millis) + 'ms'
        time_index += 1
    if output == '':
        output = '0ms'
    return output

In [8]:
from math import log10
import torch
import torch.nn as nn
import torch.optim as optim
#import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.optim import lr_scheduler
#from SRCNN.model import Net
#from misc import progress_bar  记录时间
class SRCNNTrainer(object):
    def __init__(self, training_loader, testing_loader):
        self.model = None
        self.lr = 0.01
        self.nEpochs = 10
        self.criterion = nn.MSELoss()
        self.optimizer = None
        self.scheduler = None
        #self.GPU_IN_USE = torch.cuda.is_available()
        self.seed = 123
        self.upscale_factor = 4
        self.training_loader = training_loader
        self.testing_loader = testing_loader

    def build_model(self):
        self.model = Net(num_channels=1, base_filter=64, upscale_factor=self.upscale_factor)
        self.model.weight_init(mean=0.0, std=0.01)
        self.criterion = nn.MSELoss()
        torch.manual_seed(self.seed)

        #if self.GPU_IN_USE:
            #torch.cuda.manual_seed(self.seed)
            #self.model.cuda()
            #cudnn.benchmark = True
            #self.criterion.cuda()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5)  # lr decay

    def save(self):
        model_out_path = "SRCNN_model_path.pth"
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def train(self):
        """
        data: [torch.cuda.FloatTensor], 4 batches: [64, 64, 64, 8]
        """
        self.model.train()
        train_loss = 0
        for batch_num, (data, target) in enumerate(self.training_loader):
            #if self.GPU_IN_USE:
                #data, target = Variable(data).cuda(), Variable(target).cuda()
            data = Variable(data)
            target = Variable(target)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(data), target)
            train_loss += loss.data[0]
            loss.backward()
            self.optimizer.step()
            progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))

        print("Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))

    def test(self):
        """
        data: [torch.cuda.FloatTensor], 10 batches: [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
        """
        self.model.eval()
        avg_psnr = 0
        for batch_num, (data, target) in enumerate(self.testing_loader):
            #if self.GPU_IN_USE:
                #data, target = Variable(data).cuda(), Variable(target).cuda()
            data = Variable(data)
            target = Variable(target)
            prediction = self.model(data)
            mse = self.criterion(prediction, target)
            psnr = 10 * log10(1 / mse.data[0])
            avg_psnr += psnr
            progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
        print("Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))

    def validate(self):
        self.build_model()
        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            self.train()
            self.test()
            self.scheduler.step(epoch)
            if epoch == self.nEpochs:
                self.save()

## Using a trained network

In [9]:
train_set = get_training_set(4)
test_set = get_test_set(4)
training_data_loader = data.DataLoader(dataset=train_set, batch_size=32, shuffle=True)
testing_data_loader = data.DataLoader(dataset=test_set, batch_size=1, shuffle=False)
model = SRCNNTrainer(training_data_loader, testing_data_loader)
model.validate()


===> Epoch 1 starts:
    Average Loss: 8.0760
    Average PSNR: 6.2543 dB

===> Epoch 2 starts:
    Average Loss: 0.2511
    Average PSNR: 6.2808 dB

===> Epoch 3 starts:
    Average Loss: 0.2435
    Average PSNR: 6.4734 dB

===> Epoch 4 starts:
    Average Loss: 0.2351
    Average PSNR: 6.7729 dB

===> Epoch 5 starts:
    Average Loss: 0.2211
    Average PSNR: 7.1457 dB

===> Epoch 6 starts:
    Average Loss: 0.1995
    Average PSNR: 7.5678 dB

===> Epoch 7 starts:
    Average Loss: 0.1814
    Average PSNR: 8.0235 dB

===> Epoch 8 starts:
    Average Loss: 0.1671
    Average PSNR: 8.4975 dB

===> Epoch 9 starts:
    Average Loss: 0.1492
    Average PSNR: 8.9824 dB

===> Epoch 10 starts:
    Average Loss: 0.1381
    Average PSNR: 9.4639 dB


  "type " + obj.__name__ + ". It won't be checked "


Checkpoint saved to SRCNN_model_path.pth
