
# Learned Image Downscaling Based on Content Adaptive Resampling (CAR)

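CAR is an efficient method for image downscaling and upscaling. It reduces the storage needed for image data and the bandwidth needed to transmit images, while preserving image detail. The algorithm uses a resampler network to generate the low-resolution image and a differentiable super-resolution network to restore it. The parameters of the whole model are updated through a reconstruction loss. The paper reports state-of-the-art super-resolution performance for this approach.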

# Model Overview

![show_images](images/model.jpg)

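As shown in the figure above, CAR uses ResamplerNet, which is built from convolutions and residual blocks, to generate the weights and offsets needed for downscaling. The Downscaling step then uses these weights and offsets to produce the downscaled image; this step is implemented in CUDA. The downscaled image is restored by the super-resolution network. Finally, the restored image is compared with the original image using the L1 norm to obtain the reconstruction loss, which is used to update the network parameters.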

## Data Processing

Before starting the experiment, make sure a Python environment and the MindSpore Vision toolkit are installed locally.

## Data Preparation

Training uses the high-resolution images from DIV2K. The training set contains 800 HR images and the validation set contains 100 HR images.
Training set download: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
Validation set download: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
Place the extracted datasets under ./datasets/DIV2K/ with the following layout:

```text

./datasets/
    └── DIV2K
            ├── DIV2K_train_HR
            |    ├── 0001.png
            |    ├── 0002.png
            |    ├── ...
            ├── DIV2K_valid_HR
            |    ├── 0801.png
            |    ├── 0802.png
            |    ├── ...

```

```python
import os
import argparse
from typing import Optional, Callable

import numpy as np
import mindspore as ms
import mindspore.dataset as ds
from mindspore import context, nn, ops, Model
from mindspore.dataset import vision
from mindvision.io.images import imread
from mindvision.check_param import Validator

class DIV2KHR:
    """DIV2K high-resolution images of one split ("train" or "valid")."""
    def __init__(self, path: str, split: str):
        Validator.check_string(split, ["train", "valid"], "split")
        realpath = os.path.realpath(path)
        self.path = os.path.join(realpath, f"DIV2K_{split}_HR")
        data = sorted(os.listdir(self.path))  # sort so the image order is deterministic
        self.data_list = [os.path.join(self.path, idx) for idx in data]

    def __getitem__(self, index):
        """Read the image at `index` as RGB."""
        return imread(self.data_list[index], 'RGB')

    def __len__(self):
        """Return the number of images."""
        return len(self.data_list)


def default_transform(image):
    image = np.asarray(image)
    height, width, _ = image.shape
    image = image[:height // 8 * 8, :width // 8 * 8, :]  # height and width must be multiples of 8
    image = vision.py_transforms.ToTensor()(image)  # HWC uint8 -> CHW float32 in [0, 1]

    return image

def build_dataset(dataset,
                  batch_size: int = 1,
                  repeat_num: int = 1,
                  shuffle: Optional[bool] = False,
                  num_parallel_workers: Optional[int] = 1,
                  num_shards: Optional[int] = None,
                  shard_id: Optional[int] = None,
                  transform: Optional[Callable] = default_transform):
    """Wrap a random-access source in a GeneratorDataset, apply `transform`, then batch and repeat."""
    dataset = ds.GeneratorDataset(dataset, ['image'],
                                  num_parallel_workers=num_parallel_workers,
                                  shuffle=shuffle,
                                  num_shards=num_shards,
                                  shard_id=shard_id)
    if transform:
        dataset = dataset.map(operations=transform,
                              input_columns='image',
                              num_parallel_workers=num_parallel_workers)

    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(repeat_num)

    return dataset

# Parse arguments
parser = argparse.ArgumentParser(description='Train CAR')
parser.add_argument('--image_path', default='./datasets/DIV2K', type=str)  # dataset root
parser.add_argument('-j', '--workers', default=1, type=int)
# the custom downscaling operator below is registered for GPU only
parser.add_argument('--device_target', default='GPU', choices=['CPU', 'GPU', 'Ascend'], type=str)
parser.add_argument('--device_id', default=0, type=int)
parser.add_argument('--end_epoch', default=500, type=int)
parser.add_argument('--train_batchsize', default=8, type=int)
parser.add_argument('--train_repeat_num', default=1, type=int)
parser.add_argument('--train_resize', default=192, type=int)  # crop size: 192 for 4x, 96 for 2x downscaling
parser.add_argument('--scale', default=4, type=int, help='downscale factor')
parser.add_argument('--eval_proid', default=1, type=int)  # evaluate every N epochs
parser.add_argument('--checkpoint_path', default='./checkpoint', type=str)
parser.add_argument('--output_dir', type=str, default='./exp_res', help='path to store results')
args = parser.parse_known_args()[0]

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
```

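### Data Augmentation

During training, each HR image is randomly cropped to 192×192 (for 4× downscaling) or 96×96 (for 2× downscaling). The only augmentations used are random horizontal flips and random vertical flips.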
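The crop-and-flip augmentation can be sketched in plain NumPy to make each step explicit. This is an illustrative sketch, not the MindSpore `c_transforms` implementation; the function name `random_crop_flip` is hypothetical, and the 0.5 flip probability is assumed to match the defaults of `RandomHorizontalFlip` and `RandomVerticalFlip`.

```python
import numpy as np


def random_crop_flip(image: np.ndarray, crop: int, rng: np.random.Generator) -> np.ndarray:
    """Randomly crop an HxWxC uint8 image to crop x crop, apply random flips, return CHW float32."""
    h, w, _ = image.shape
    if h < crop or w < crop:
        raise ValueError("image is smaller than the crop size")
    top = rng.integers(0, h - crop + 1)
    left = rng.integers(0, w - crop + 1)
    patch = image[top:top + crop, left:left + crop, :]
    if rng.random() < 0.5:  # assumed flip probability 0.5
        patch = patch[:, ::-1, :]  # horizontal flip
    if rng.random() < 0.5:
        patch = patch[::-1, :, :]  # vertical flip
    # HWC uint8 -> CHW float32 in [0, 1], similar to ToTensor
    return np.ascontiguousarray(patch.transpose(2, 0, 1)).astype(np.float32) / 255.0
```

Both crop sizes equal 48 × scale (192 = 48 × 4, 96 = 48 × 2), so the downscaled training patch is 48×48 in either setting.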

```python
# Build the training dataset
train_transform = [vision.c_transforms.RandomCrop(args.train_resize),
                   vision.c_transforms.RandomHorizontalFlip(),
                   vision.c_transforms.RandomVerticalFlip(),
                   vision.py_transforms.ToTensor()]

train_dataloader = build_dataset(DIV2KHR(args.image_path, "train"),
                                 batch_size=args.train_batchsize,
                                 repeat_num=args.train_repeat_num,
                                 shuffle=True,
                                 num_parallel_workers=args.workers,
                                 transform=train_transform)
step_size = train_dataloader.get_dataset_size()  # steps per epoch
```

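## Building the Networks

In the code below, ResamplerNet first applies a 5×5 convolution with LeakyReLU followed by two downsampling blocks, raising the feature dimension to 128. It then extracts features with 5 residual blocks and a further downsampling block that raises the dimension to 256. Finally, two structurally identical branches built from Conv-activation pairs (ReLU in this implementation) compute the sampling weights and the offsets at 256 channels. The super-resolution network is EDSR with 32 residual blocks, each with a feature dimension of 256.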
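As a sanity check on tensor shapes, the helper below computes the shapes the networks should produce for a given input. It assumes that each `DownsampleBlock(2, ...)` halves the spatial size and `UpsampleBlock(f, ...)` multiplies it by `f` (both blocks live in `src.model.block`, which is not shown here), and it uses the tutorial's `kernel_size = 3 * scale + 1`. `car_shapes` is a hypothetical helper, not part of the project.

```python
def car_shapes(batch: int, height: int, width: int, scale: int):
    """Return expected (kernels, offsets, downscaled, reconstructed) shapes in NCHW order."""
    if height % 8 or width % 8:
        raise ValueError("height and width must be multiples of 8")
    if scale <= 0 or 8 % scale:
        raise ValueError("scale must divide 8 (the tutorial uses 2 or 4)")
    k = 3 * scale + 1                     # kernel_size in the tutorial
    h8, w8 = height // 8, width // 8      # after ds_2, ds_4 and ds_8 (assumed /2 each)
    up = 8 // scale                       # UpsampleBlock factor in both branches
    lr_h, lr_w = h8 * up, w8 * up         # equals height // scale, width // scale
    kernels = (batch, k * k, lr_h, lr_w)
    offsets = (batch, k * k, lr_h, lr_w)  # shape of offsets_h and offsets_v
    downscaled = (batch, 3, lr_h, lr_w)   # (batch, image channels, kernel H, kernel W)
    reconstructed = (batch, 3, lr_h * scale, lr_w * scale)  # EDSR upscales by `scale`
    return kernels, offsets, downscaled, reconstructed
```

For 4× training crops, `car_shapes(8, 192, 192, 4)` gives kernels of shape (8, 169, 48, 48) and a reconstruction of shape (8, 3, 192, 192). The divide-by-8 step also explains why `default_transform` crops images to multiples of 8.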

```python
import mindspore.nn as nn
import mindspore.ops as ops

from src.model.block import ResBlock, default_conv, Upsampler
from src.model.block import MeanShift, NormalizeBySum, DownsampleBlock, ReflectionPad2d, UpsampleBlock, ResidualBlock

class DSN(nn.Cell):
    """
    ResamplerNet: predicts per-pixel resampling kernels and offsets.

    Args:
        k_size(int): Side length of the resampling kernel.
        input_channels(int): Number of image channels. Default: 3.
        scale(int): Downscaling rate. Default: 4.

    Outputs:
        Tuple of Tensors (kernels, offsets_h, offsets_v), each with k_size**2 channels.
    """
    def __init__(self, k_size, input_channels=3, scale=4):
        super().__init__()

        self.k_size = k_size
        self.sub_mean = MeanShift(1)
        self.normalize = NormalizeBySum()

        # feature extraction: 5x5 conv, then DownsampleBlocks ds_2 / ds_4 / ds_8
        self.ds_1 = nn.SequentialCell(
            ReflectionPad2d(2),
            nn.Conv2d(input_channels, 64, 5, pad_mode="valid", has_bias=True),
            nn.LeakyReLU(0.2),
        )

        self.ds_2 = DownsampleBlock(2, 64, 128, ksize=1)
        self.ds_4 = DownsampleBlock(2, 128, 128, ksize=1)

        res_4 = []
        for _ in range(5):
            res_4 += [ResidualBlock(128, 128)]
        self.res_4 = nn.SequentialCell(*res_4)

        self.ds_8 = DownsampleBlock(2, 128, 256)

        # kernel branch: UpsampleBlock(8 // scale) brings the features to 1/scale resolution
        self.kernels_trunk = nn.SequentialCell(
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="pad", has_bias=True),
            nn.ReLU(),
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="pad", has_bias=True),
            nn.ReLU(),
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="pad", has_bias=True),
            nn.ReLU(),
            UpsampleBlock(8 // scale, 256, 256, ksize=1),
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="pad", has_bias=True),
            nn.ReLU()
        )

        self.kernels_weight = nn.SequentialCell(
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="pad", has_bias=True),
            nn.ReLU(),
            ReflectionPad2d(1),
            nn.Conv2d(256, k_size**2, 3, pad_mode="pad", has_bias=True),
        )

        # offset branch: same structure as the kernel trunk
        self.offsets_trunk = nn.SequentialCell(
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="valid", has_bias=True),
            nn.ReLU(),
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="valid", has_bias=True),
            nn.ReLU(),
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="valid", has_bias=True),
            nn.ReLU(),
            UpsampleBlock(8 // scale, 256, 256, ksize=1),
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="valid", has_bias=True),
            nn.ReLU(),
        )

        # offsets along the two axes, bounded to [-1, 1] by Tanh
        self.offsets_h_generation = nn.SequentialCell(
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="valid", has_bias=True),
            nn.ReLU(),
            ReflectionPad2d(1),
            nn.Conv2d(256, k_size**2, 3, pad_mode="valid", has_bias=True),
            nn.Tanh(),
        )

        self.offsets_v_generation = nn.SequentialCell(
            ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3, pad_mode="valid", has_bias=True),
            nn.ReLU(),
            ReflectionPad2d(1),
            nn.Conv2d(256, k_size**2, 3, pad_mode="valid", has_bias=True),
            nn.Tanh(),
        )

    def construct(self, img):
        x = self.sub_mean(img)

        x = self.ds_1(x)
        x = self.ds_2(x)
        x = self.ds_4(x)
        x = x + self.res_4(x)
        x = self.ds_8(x)

        kt = self.kernels_trunk(x)

        kt = self.kernels_weight(kt)
        k_weight = ops.clip_by_value(kt, 1e-6, 1)  # keep kernel weights positive
        kernels = self.normalize(k_weight)         # normalize each kernel by its sum

        ot = self.offsets_trunk(x)
        offsets_h = self.offsets_h_generation(ot)
        offsets_v = self.offsets_v_generation(ot)

        return kernels, offsets_h, offsets_v

class EDSR(nn.Cell):
    """
    Upscaling module to guide the training of the proposed CAR model.

    Args:
        n_resblocks(int): The number of residual blocks. Default: 16.
        n_feats(int): The hidden layer feature dimension. Default: 64.
        scale(int): Upscaling rate. Default: 4.
        conv(Cell): The convolution layer. Default: default_conv.

    Inputs:
        - **x** (Tensor) - The downscaled image tensors.

    Outputs:
        Tensor, The super-resolution image tensors.
    """

    def __init__(self, n_resblocks=16, n_feats=64, scale=4, conv=default_conv):
        super(EDSR, self).__init__()

        kernel_size = 3
        act = nn.ReLU()
        self.sub_mean = MeanShift(1)
        self.add_mean = MeanShift(1, sign=1)

        # define head module
        m_head = [conv(3, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock(conv, n_feats, kernel_size, act=act, res_scale=0.1)
            for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size),
        ]
        self.head = nn.SequentialCell(m_head)
        self.body = nn.SequentialCell(*m_body)
        self.tail = nn.SequentialCell(*m_tail)

    def construct(self, x):
        x = self.sub_mean(x)
        x = self.head(x)
        res = self.body(x)
        res += x
        x = self.tail(res)
        x = self.add_mean(x)

        return x

scale = args.scale
kernel_size = 3 * scale + 1  # 13 for 4x, 7 for 2x

# create model
kernel_generation_net = DSN(k_size=kernel_size, scale=scale)  # ResamplerNet
upscale_net = EDSR(32, 256, scale=scale)    # SRNet
```

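## Downscaling

The downscaling step is implemented in CUDA and added to MindSpore as a custom operator in AOT mode. Build it with the following commands:

```sh
cd codebase/course/application_example/CAR/src/plug_in/adaptive_gridsampler
python setup.py
```

Compilation produces a `.so` shared library, after which the downscaling operator can be imported.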
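To illustrate what the CUDA operator computes, here is a slow NumPy reference sketch of content-adaptive resampling: each output pixel is a weighted sum of k×k bilinearly sampled input pixels whose positions are shifted by the predicted offsets. Several details are assumptions, not taken from the CUDA source: the sampling center of each output pixel, the row-major layout of the k×k taps along the channel dimension, mapping `offsets_h` to the horizontal axis, offsets scaled by `offset_unit` (which the tutorial sets to `scale`), and clamping at the border instead of `ReflectionPad2d`. It processes a single image without a batch dimension; `adaptive_resample` and `bilinear` are hypothetical names.

```python
import numpy as np


def bilinear(img: np.ndarray, y: float, x: float) -> np.ndarray:
    """Bilinearly sample a (C, H, W) image at float position (y, x), clamped to the border."""
    _, h, w = img.shape
    y = min(max(float(y), 0.0), h - 1.0)
    x = min(max(float(x), 0.0), w - 1.0)
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    dy, dx = y - y0, x - x0
    top = (1 - dx) * img[:, y0, x0] + dx * img[:, y0, x1]
    bottom = (1 - dx) * img[:, y1, x0] + dx * img[:, y1, x1]
    return (1 - dy) * top + dy * bottom


def adaptive_resample(img, kernels, offsets_h, offsets_v, offset_unit):
    """img: (C, H, W); kernels, offsets_h, offsets_v: (k*k, H_out, W_out).

    Assumes offset_unit equals the downscale factor, i.e. H_out = H // offset_unit.
    """
    c = img.shape[0]
    kk, out_h, out_w = kernels.shape
    k = int(round(np.sqrt(kk)))
    if k * k != kk:
        raise ValueError("kernels must have k*k channels")
    half = k // 2
    s = offset_unit
    out = np.zeros((c, out_h, out_w), dtype=np.float64)
    for y in range(out_h):
        for x in range(out_w):
            # assumed center of output pixel (y, x) in input coordinates
            cy = (y + 0.5) * s - 0.5
            cx = (x + 0.5) * s - 0.5
            for idx in range(kk):
                i, j = divmod(idx, k)  # assumed row-major tap layout
                # assumed: tanh offsets in [-1, 1] are scaled by offset_unit
                sy = cy + (i - half) + offsets_v[idx, y, x] * s
                sx = cx + (j - half) + offsets_h[idx, y, x] * s
                out[:, y, x] += kernels[idx, y, x] * bilinear(img, sy, sx)
    return out
```

Assuming the kernels sum to 1 at every output position, as the name `NormalizeBySum` in `DSN` suggests, a constant image stays constant regardless of the offsets, which is a convenient property to check against the real operator.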

```python
import mindspore.ops as ops
from mindspore.nn import Cell
from mindspore.ops import DataType, CustomRegOp

from src.model.block import ReflectionPad2d

# shared library produced by setup.py above
so_path = "./src/plug_in/adaptive_gridsampler/adaptive_gridsampler_cuda.so"

# forward kernel: img, kernels and offsets are float32; offset_unit and padding have no tensor dtype
sampler_gpu_info = (CustomRegOp()
                    .input(0, "img")
                    .input(1, "kernels")
                    .input(2, "offsets_h")
                    .input(3, "offsets_v")
                    .input(4, "offset_unit")
                    .input(5, "padding")
                    .output(0, "output")
                    .dtype_format(DataType.F32_Default,
                                  DataType.F32_Default,
                                  DataType.F32_Default,
                                  DataType.F32_Default,
                                  DataType.None_Default,
                                  DataType.None_Default,
                                  DataType.F32_Default)
                    .target("GPU")
                    .get_op_info())

# backward kernel: returns gradients w.r.t. kernels, offsets_h and offsets_v
sampler_bprop_gpu_info = (CustomRegOp()
                          .input(0, "img")
                          .input(1, "kernels")
                          .input(2, "offsets_h")
                          .input(3, "offsets_v")
                          .input(4, "offset_unit")
                          .input(5, "padding")
                          .input(6, "grad_output")
                          .output(0, "grad_k")
                          .output(1, "grad_oh")
                          .output(2, "grad_ob")
                          .dtype_format(DataType.F32_Default,
                                        DataType.F32_Default,
                                        DataType.F32_Default,
                                        DataType.F32_Default,
                                        DataType.None_Default,
                                        DataType.None_Default,
                                        DataType.F32_Default,
                                        DataType.F32_Default,
                                        DataType.F32_Default,
                                        DataType.F32_Default)
                          .target("GPU")
                          .get_op_info())


def infer_shape_backward(x1, x2, x3, x4, x5, x6, x7):
    # all three gradients have the shape of kernels (x2)
    _ = (x1, x3, x4, x5, x6, x7)
    return (x2, x2, x2)


def infer_type_backward(x1, x2, x3, x4, x5, x6, x7):
    # all three gradients have the dtype of img (x1)
    _ = (x2, x3, x4, x5, x6, x7)
    return (x1, x1, x1)


aot_bprop = ops.Custom(
    so_path + ":adaptive_gridsampler_cuda_backward",
    infer_shape_backward,
    infer_type_backward,
    "aot",
    reg_info=sampler_bprop_gpu_info,
)


def backward(img, kernels, offsets_h, offsets_v, offset_unit, padding, out, dout):
    _ = out
    # the forward op receives the padded image; strip the padding for the backward kernel
    input_img = img[..., padding:-padding, padding:-padding]
    grad_k, grad_h, grad_v = aot_bprop(
        input_img, kernels, offsets_h, offsets_v, offset_unit, padding, dout
    )

    # no gradient flows back to the image; offset_unit and padding are not differentiable
    return (ops.ZerosLike()(img), grad_k, grad_h, grad_v, None, None)


def infer_shape_forward(x1, x2, x3, x4, x5, x6):
    # output: (batch, image channels, kernel height, kernel width)
    _ = (x3, x4, x5, x6)
    shape = (x1[0], x1[1], x2[2], x2[3])
    return shape


def infer_type_forward(x1, x2, x3, x4, x5, x6):
    _ = (x2, x3, x4, x5, x6)
    return x1


class Downsampler(Cell):
    """
    Content-adaptive downsampler backed by the custom CUDA operator.
    """
    def __init__(self, k_size):
        super().__init__()
        self.k_size = k_size
        self.padding = k_size // 2
        self.pad = ReflectionPad2d(self.padding)  # create the padding cell once, not in construct
        self.ops = ops.Custom(
            so_path + ":adaptive_gridsampler_cuda_forward",
            infer_shape_forward,
            infer_type_forward,
            "aot",
            backward,
            sampler_gpu_info,
        )

    def construct(self, img, kernels, offsets_h, offsets_v, offset_unit):
        img = self.pad(img)

        return self.ops(img, kernels, offsets_h, offsets_v, offset_unit, self.padding)

downsampler_net = Downsampler(kernel_size)
```

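## Quantization

The downscaled output is continuous floating-point data, whereas images usually store pixel values as integers in the range 0-255. Quantizing floats to integers is not differentiable, so to differentiate the whole network end to end, the paper approximates the quantization (rounding) step with a soft round function:

$$
\mathrm{round}_{soft}(x) = x - \alpha \frac{\sin(2\pi x)}{2\pi}
$$

The forward pass still uses ordinary rounding; the soft round function is used only to compute gradients during back-propagation. Its derivative is:

$$
\mathrm{round}_{soft}'(x) = 1 - \alpha \cos(2\pi x)
$$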
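A small NumPy check of the two formulas: `quantize_forward` mirrors the forward pass of the `Quantization` cell below (clip, scale to 255, round), while `soft_round_grad` is the analytic derivative, which a finite-difference estimate of `soft_round` should match. The function names are illustrative, and α defaults to 1 as in the `Quantization` cell.

```python
import numpy as np


def soft_round(x, alpha=1.0):
    """round_soft(x) = x - alpha * sin(2*pi*x) / (2*pi)."""
    return x - alpha * np.sin(2 * np.pi * x) / (2 * np.pi)


def soft_round_grad(x, alpha=1.0):
    """Analytic derivative of soft_round: 1 - alpha * cos(2*pi*x)."""
    return 1.0 - alpha * np.cos(2 * np.pi * x)


def quantize_forward(img):
    """Forward pass of the Quantization cell: clip to [0, 1], round to multiples of 1/255."""
    return np.round(np.clip(img, 0.0, 1.0) * 255.0) / 255.0
```

With α = 1 the surrogate gradient is 0 at integer inputs and reaches its maximum of 2 halfway between them.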

```python
import math

class Quantization(nn.Cell):
    """Forward: round to 8-bit levels. Backward: gradient of the soft round function."""
    def __init__(self):
        super().__init__()
        self.round = ops.Round()
        self.clip = ops.clip_by_value
        self.cos = ops.Cos()
        self.pi = math.pi
        self.alp = 1.  # alpha in the soft round function

    def construct(self, img):
        img = self.clip(img, 0, 1)
        img = img * 255
        img = self.round(img)
        return img / 255

    def bprop(self, img, out, grad_output):
        _ = out
        # derivative of the soft round: 1 - alpha * cos(2 * pi * x)
        grad_input = grad_output * (1 - self.alp * self.cos(2 * self.pi * img))
        return (grad_input,)

quant = Quantization()
```

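## Loss Functions

The paper uses three losses: L1 loss, offset loss and partial TV loss.

### L1 loss

The L1 loss is the mean absolute difference (L1 norm) between the restored image and the original image:

$$
Loss^{L1} = \frac{1}{N} \sum_{\boldsymbol{p} \in \mathbf{I}}|\boldsymbol{p}-\hat{\boldsymbol{p}}|
$$

where $\mathbf{I}$ is the original (ground-truth) image, $\hat{\mathbf{I}}$ is the super-resolution result, $\boldsymbol{p}$ and $\hat{\boldsymbol{p}}$ are corresponding ground-truth and reconstructed pixel values, and $N$ is the number of pixels multiplied by the number of color channels.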
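A worked example of the formula in NumPy, assuming images stored as (C, H, W) arrays; `l1_loss` is an illustrative helper. MindSpore's `nn.L1Loss` with its default `reduction='mean'` averages over all elements in the same way.

```python
import numpy as np


def l1_loss(reconstructed: np.ndarray, original: np.ndarray) -> float:
    """Mean absolute error, with N = number of pixels x number of channels."""
    if reconstructed.shape != original.shape:
        raise ValueError("images must have the same shape")
    n = original.size  # C * H * W for a single (C, H, W) image
    return float(np.abs(original - reconstructed).sum() / n)


# Worked example: 3 channels x 2 x 2 pixels, so N = 12.
# A single absolute error of 0.75 gives 0.75 / 12 = 0.0625.
original = np.zeros((3, 2, 2))
reconstructed = original.copy()
reconstructed[0, 0, 0] = 0.75
loss = l1_loss(reconstructed, original)
```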

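### Offset loss

The offset loss preserves the topological structure of the image after downscaling. For each downscaled pixel, input pixels farther from the sampling center are less correlated with it, so the offset loss penalizes the offsets with a weight that grows with the distance from the kernel center:

$$
Loss^{offset} = \sum_{i=0}^{m-1} \sum_{j=0}^{n-1} \sqrt{\eta + \Delta X_{x, y}(i, j)^{2} + \Delta Y_{x, y}(i, j)^{2}} \cdot w(i, j)
$$

where $w(i, j) = \sqrt{\left(i-\frac{m}{2}\right)^{2}+\left(j-\frac{n}{2}\right)^{2}} \Big/ \sqrt{\left(\frac{m}{2}\right)^{2}+\left(\frac{n}{2}\right)^{2}}$ is the normalized distance from position $(i, j)$ of the $m \times n$ kernel to its center $\left(\frac{m}{2}, \frac{n}{2}\right)$, $\Delta X_{x, y}(i, j)$ and $\Delta Y_{x, y}(i, j)$ are the two offset components at that position for output pixel $(x, y)$, and $\eta$ is a small constant for numerical stability. The `OffsetLoss` implementation below normalizes the distance by the kernel size instead, which only rescales the loss by a constant factor.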
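The sketch below evaluates the offset loss in NumPy for offsets shaped (B, k², H, W), averaging over the batch and output positions as the `OffsetLoss` cell does. It supports both the normalization of $w(i, j)$ from the formula and the division by `kernel_size` used in the code. Placing η inside the square root as a small stabilizer is an assumption, and the function names are illustrative.

```python
import numpy as np


def distance_weights(k: int, normalize: str = "formula") -> np.ndarray:
    """w(i, j) for a k x k kernel: distance from (i, j) to the center (k/2, k/2), normalized."""
    i, j = np.meshgrid(np.arange(k, dtype=np.float64), np.arange(k, dtype=np.float64), indexing="ij")
    dist = np.sqrt((i - k / 2) ** 2 + (j - k / 2) ** 2)
    if normalize == "formula":
        return dist / np.sqrt((k / 2) ** 2 + (k / 2) ** 2)
    if normalize == "kernel_size":
        return dist / k  # the normalization used by the OffsetLoss cell
    raise ValueError("normalize must be 'formula' or 'kernel_size'")


def offset_loss(offsets_h, offsets_v, eta=1e-6, normalize="formula"):
    """offsets_h, offsets_v: (B, k*k, H, W); returns the loss averaged over batch and positions."""
    b, kk, h, w = offsets_h.shape
    k = int(round(np.sqrt(kk)))
    weight = distance_weights(k, normalize).reshape(1, kk, 1, 1)
    per_tap = np.sqrt(eta + offsets_h ** 2 + offsets_v ** 2) * weight
    return float(per_tap.sum() / (b * h * w))
```

For a square k×k kernel the formula's normalizer is $k/\sqrt{2}$, so the two weightings differ only by a factor of $\sqrt{2}$, which can be folded into the loss weight (0.001 in this tutorial).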

```python
class OffsetLoss(nn.Cell):
    def __init__(self, kernel_size=13, offsetloss_weight=1.):
        super(OffsetLoss, self).__init__()
        self.offsetloss_weight = offsetloss_weight  # loss weight
        x = ms.numpy.arange(0, kernel_size, dtype=ms.float32)
        y = ms.numpy.arange(0, kernel_size, dtype=ms.float32)
        x_m, y_m = ops.Meshgrid()((x, y))
        self.sqrt = ops.Sqrt()
        # distance of each kernel position to the kernel center, normalized by kernel_size
        weight = self.sqrt((x_m-kernel_size/2)**2 + (y_m-kernel_size/2)**2)/kernel_size
        self.weight = weight.view(1, kernel_size**2, 1, 1)

    def construct(self, offsets_h, offsets_v):
        b, _, h, w = offsets_h.shape
        loss = self.sqrt(offsets_h * offsets_h + offsets_v * offsets_v)*self.weight
        return self.offsetloss_weight*loss.sum()/(h * w * b)
```

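### Partial TV loss

Inconsistent offsets between neighboring sampling kernels can shift the phase of pixels in the downscaled image, which shows up as jagged artifacts, especially along sharp vertical and horizontal edges. The partial TV loss is therefore introduced to keep the offsets consistent:

$$
Loss^{TV} = \sum_{x, y}\left(\sum_{i, j}\left|\Delta X_{\cdot, y+1}(i, j)-\Delta X_{\cdot, y}(i, j)\right| \cdot \mathbf{K}(i, j) + \sum_{i, j}\left|\Delta Y_{x+1, \cdot}(i, j)-\Delta Y_{x, \cdot}(i, j)\right| \cdot \mathbf{K}(i, j)\right)
$$

where $\mathbf{K}(i, j)$ is the resampling kernel weight at position $(i, j)$.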
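The partial TV loss can be prototyped in NumPy to see how it behaves. This sketch follows the axis convention of the `TvLoss` cell below (`offsets_v` differenced along the width, `offsets_h` along the height, each difference weighted by the kernel value at the left or top position) and divides by the batch size; `partial_tv_loss` is an illustrative name.

```python
import numpy as np


def partial_tv_loss(offsets_h, offsets_v, kernels):
    """offsets_h, offsets_v, kernels: (B, k*k, H, W). NumPy mirror of the TvLoss cell."""
    batch = offsets_h.shape[0]
    # differences of neighboring offsets, weighted by the kernel value at the left/top position
    diff_w = np.abs(offsets_v[..., 1:] - offsets_v[..., :-1]) * kernels[..., :-1]
    diff_h = np.abs(offsets_h[:, :, 1:, :] - offsets_h[:, :, :-1, :]) * kernels[:, :, :-1, :]
    return float((diff_w.sum() + diff_h.sum()) / batch)
```

Spatially constant offsets give zero loss, while a single jump of size d between neighboring positions costs d times the corresponding kernel weight.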


```python
class TvLoss(nn.Cell):
    def __init__(self, tvloss_weight=1):
        super(TvLoss, self).__init__()
        self.tvloss_weight = tvloss_weight  # loss weight
        self.abs = ops.Abs()

    def construct(self, offsets_h, offsets_v, kernel):
        batch, _, _, _ = offsets_h.shape
        # offsets_v differenced along the width, offsets_h along the height
        diff_1 = self.abs(offsets_v[..., 1:] - offsets_v[..., :-1]) * kernel[..., :-1]
        diff_2 = self.abs(offsets_h[:, :, 1:, :] - offsets_h[:, :, :-1, :]) * kernel[:, :, :-1, :]
        tv_loss = diff_1.sum()+diff_2.sum()
        return self.tvloss_weight * tv_loss / batch
```

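## Building the Training Network

Combine ResamplerNet, SRNet, the downscaling step, the quantization step and the losses into the training network. Training starts with a learning rate of $10^{-4}$ and runs for 500 epochs, lowering the learning rate every 100 epochs. The optimizer is Adam with $\beta_1 = 0.9$, $\beta_2 = 0.999$ and $\epsilon = 10^{-6}$.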
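The learning-rate schedule used below can be reproduced in plain Python to check the arithmetic. MindSpore's `piecewise_constant_lr` assigns `learning_rates[j]` to the steps in `[milestone[j-1], milestone[j])`; `piecewise_lr` is an illustrative re-implementation of that rule for milestones at 20%, 40%, 60%, 80% and 100% of `total_steps`.

```python
def piecewise_lr(step: int, total_steps: int) -> float:
    """Learning rate at `step`, following the piecewise_constant_lr rule used in the tutorial."""
    milestones = [int(f * total_steps) for f in (0.2, 0.4, 0.6, 0.8)] + [total_steps]
    values = [1e-4, 5e-5, 1e-5, 5e-6, 1e-6]
    for milestone, value in zip(milestones, values):
        if step < milestone:  # value applies to [previous milestone, milestone)
            return value
    raise ValueError("step must be in [0, total_steps)")


steps_per_epoch = 800 // 8           # 800 DIV2K training images, batch size 8, drop_remainder=True
total_steps = steps_per_epoch * 500  # 50,000 steps
drops = [s for s in range(1, total_steps)
         if piecewise_lr(s, total_steps) != piecewise_lr(s - 1, total_steps)]
print([s // steps_per_epoch for s in drops])  # [100, 200, 300, 400]
```

So with the default settings one epoch is 100 steps, and the learning rate drops at steps 10,000, 20,000, 30,000 and 40,000, that is, every 100 epochs.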


```python
class NetWithLoss(nn.Cell):
    """ResamplerNet -> Downsampler -> Quantization -> SRNet, returning L1 + TV + offset losses."""
    def __init__(self, net1, net2, aux_net1, aux_net2, offset, loss1, loss2):
        super(NetWithLoss, self).__init__()
        self.net1 = net1
        self.net2 = net2
        self.dsn = aux_net1
        self.quant = aux_net2
        self.offset_unit = offset
        self.tv_loss = loss1
        self.offset_loss = loss2
        self.l1_loss = nn.L1Loss()

    def construct(self, image):
        kernels, offsets_h, offsets_v = self.net1(image)
        downscaled_img = self.dsn(image, kernels, offsets_h, offsets_v, self.offset_unit)
        downscaled_img = self.quant(downscaled_img)
        reconstructed_img = self.net2(downscaled_img)
        loss1 = self.l1_loss(reconstructed_img, image)
        loss2 = self.tv_loss(offsets_h, offsets_v, kernels)
        loss3 = self.offset_loss(offsets_h, offsets_v)

        return loss1 + loss2 + loss3


network = NetWithLoss(kernel_generation_net,
                      upscale_net,
                      downsampler_net,
                      quant,
                      scale,  # offset_unit
                      TvLoss(0.005),
                      # kernel_size must match the DSN kernels (13 for 4x, 7 for 2x)
                      OffsetLoss(kernel_size=kernel_size, offsetloss_weight=0.001))

num_epochs = args.end_epoch
total_steps = step_size * num_epochs
# lower the learning rate after every 20% of the total steps (every 100 epochs for 500 epochs)
lr = nn.dynamic_lr.piecewise_constant_lr([int(0.2*total_steps), int(0.4*total_steps),
                                          int(0.6*total_steps), int(0.8*total_steps),
                                          total_steps],
                                         [1e-4, 5e-5, 1e-5, 5e-6, 1e-6])
# optimize the parameters of ResamplerNet and SRNet
opt_para = list(kernel_generation_net.trainable_params())+list(upscale_net.trainable_params())
opt = nn.optim.Adam(opt_para, learning_rate=lr, eps=1e-6)  # beta1=0.9, beta2=0.999 by default
model = Model(network=network, optimizer=opt)
```

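## Building the Evaluation Network

During training, the model is evaluated on 10 images from the DIV2K validation set. Because the images have different sizes, the validation batch size is set to 1. Evaluation runs in a callback, which also saves the network weights whenever the mean PSNR improves.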
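The callback below relies on `cal_psnr` from `src.utils.metric`, which is not shown here. As a sketch, assuming the common definition $\mathrm{PSNR} = 10\log_{10}(255^2/\mathrm{MSE})$ computed over all RGB channels (whether `cal_psnr` uses RGB or only the Y channel is not stated), PSNR with a border crop of `scale` pixels could look like this; `psnr_with_border_crop` is a hypothetical name.

```python
import numpy as np


def psnr_with_border_crop(orig: np.ndarray, recon: np.ndarray, border: int) -> float:
    """PSNR in dB between two uint8 (H, W, C) images, ignoring `border` pixels on each side."""
    if border <= 0:
        raise ValueError("border must be positive")  # a[0:-0] would be empty
    a = orig[border:-border, border:-border, ...].astype(np.float64)
    b = recon[border:-border, border:-border, ...].astype(np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(255.0 ** 2 / mse))
```

For reference, an average squared error of one gray level (MSE = 1) corresponds to about 48.13 dB.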

```python
from collections import OrderedDict

import numpy as np

from mindspore import ops, Tensor, save_checkpoint
from mindspore.train.callback import Callback

from src.utils.metric import cal_psnr

class ValidateCell(nn.Cell):
    """ResamplerNet -> Downsampler -> Quantization -> SRNet; returns both output images."""
    def __init__(self, net1, net2, aux_net1, aux_net2, scale, offset):
        super(ValidateCell, self).__init__()
        self.net1 = net1
        self.net2 = net2
        self.dsn = aux_net1
        self.quant = aux_net2
        self.offset_unit = offset
        self.scale = scale

    def construct(self, image):
        kernels, offsets_h, offsets_v = self.net1(image)
        downscaled_img = self.dsn(image, kernels, offsets_h, offsets_v, self.offset_unit)
        downscaled_img = self.quant(downscaled_img)
        reconstructed_img = self.net2(downscaled_img)

        return downscaled_img, reconstructed_img

class SaveCheckpoint(Callback):
    """Evaluates on the validation set and saves the best ResamplerNet/SRNet weights."""
    def __init__(self, eval_model, ds_eval, scale, save_path, eval_period=1):
        """init"""
        super(SaveCheckpoint, self).__init__()
        self.model = eval_model
        self.ds_eval = ds_eval
        self.m_psnr = 0.
        self.eval_period = eval_period
        path = os.path.realpath(save_path)
        os.makedirs(path, exist_ok=True)
        self.save_path = path
        self.scale = scale

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num  # 1-based epoch index
        scale = self.scale
        psnr_list = []
        if cur_epoch % self.eval_period == 0:
            print("Validating...")
            for i, data in enumerate(self.ds_eval.create_dict_iterator()):
                if i >= 10:  # evaluate on the first 10 validation images only
                    break
                image = data['image']
                _, reconstructed_img = self.model(image)
                image = image.asnumpy().transpose(0, 2, 3, 1)
                orig_img = np.uint8(np.round(image * 255)).squeeze()
                reconstructed_img = ops.clip_by_value(reconstructed_img, 0, 1) * 255
                reconstructed_img = reconstructed_img.asnumpy().transpose(0, 2, 3, 1)
                recon_img = np.uint8(np.round(reconstructed_img)).squeeze()

                # ignore a border of `scale` pixels when computing PSNR
                psnr = cal_psnr(orig_img[scale:-scale, scale:-scale, ...],
                                recon_img[scale:-scale, scale:-scale, ...])
                psnr_list.append(psnr)
            m_psnr = np.mean(psnr_list)
            if m_psnr > self.m_psnr:
                self.m_psnr = m_psnr
                save_path = os.path.join(self.save_path, f"{self.scale}x")
                os.makedirs(save_path, exist_ok=True)
                net = cb_params.train_network
                net.init_parameters_data()
                param_dict = OrderedDict()
                for _, param in net.parameters_and_names():
                    param_dict[param.name] = param
                param_kgn = []
                param_usn = []
                for (key, value) in param_dict.items():
                    if "net1" in key:
                        each_param = {"name": key.replace("net1.", "")}
                        param_data = Tensor(value.data.asnumpy())
                        each_param["data"] = param_data
                        param_kgn.append(each_param)
                    elif "net2" in key:
                        each_param = {"name": key.replace("net2.", "")}
                        param_data = Tensor(value.data.asnumpy())
                        each_param["data"] = param_data
                        param_usn.append(each_param)
                # save the ResamplerNet (kgn) and SRNet (usn) weights separately
                save_checkpoint(param_kgn, os.path.join(save_path, "kgn.ckpt"))
                save_checkpoint(param_usn, os.path.join(save_path, "usn.ckpt"))
                print(f"epoch {cur_epoch}, save model at {save_path}, m_psnr for 10 images: {m_psnr}")
            else:
                print(f"epoch {cur_epoch}, m_psnr for 10 images: {m_psnr}")
            print("Validating Done.")

    def end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        print(f"Finish training, total epochs: {cur_epoch}, best psnr: {self.m_psnr}")

eval_network = ValidateCell(kernel_generation_net, upscale_net, downsampler_net, quant, scale, scale)
val_dataloader = build_dataset(DIV2KHR(args.image_path, "valid"),
                               batch_size=1,
                               repeat_num=1,
                               shuffle=False,
                               num_parallel_workers=args.workers)

cb_savecheckpoint = SaveCheckpoint(eval_network, val_dataloader, scale, args.checkpoint_path, args.eval_proid)
```

```python
from mindspore.train.callback import LossMonitor

training = False  # set to True to launch training (500 epochs by default)
if training:
    print("start training..")
    model.train(args.end_epoch, train_dataloader, callbacks=[LossMonitor(), cb_savecheckpoint], dataset_sink_mode=False)
```


```text
start training..
```


## Model Evaluation

Evaluation uses Set5, BSDS100, Set14, Urban100 and the DIV2K validation set (DIV2KHR). Place the extracted datasets under ./datasets/ with the following layout:

```text

        └── datasets
             ├── Set5
             |    ├── baby.png
             |    ├── bird.png
             |    ├── ...
             ├── Set14
             |    ├── baboon.png
             |    ├── barbara.png
             |    ├── ...
             ├── BSDS100
             |    ├── 101085.png
             |    ├── 101087.png
             |    ├── ...
             ├── Urban100
             |    ├── img_001.png
             |    ├── img_002.png
             |    ├── ...
             └── DIV2K
                    ├── DIV2K_train_HR
                    |    ├── 0001.png
                    |    ├── 0002.png
                    |    ├── ...
                    ├── DIV2K_valid_HR
                    |    ├── 0801.png
                    |    ├── 0802.png
                    |    ├── ...

```

Checkpoint files are stored under ./checkpoint with the following layout:

```text

        └── checkpoint
             ├── 2x
             |    ├── kgn.ckpt
             |    ├── usn.ckpt
             └── 4x
                  ├── kgn.ckpt
                  └── usn.ckpt

```

```python
from tqdm import tqdm
from mindspore import load_checkpoint, load_param_into_net, context
from src.utils.metric import compute_psnr_ssim, ValidateCell
from src.process_dataset.dataset import Set5Test


kernel_generation_net = DSN(k_size=kernel_size, scale=scale)
upscale_net = EDSR(32, 256, scale=scale)

# Load the ResamplerNet (kgn) and SRNet (usn) checkpoints
kgn_dict = load_checkpoint(os.path.join(args.checkpoint_path, f"{scale}x", "kgn.ckpt"))
usn_dict = load_checkpoint(os.path.join(args.checkpoint_path, f"{scale}x", "usn.ckpt"))
load_param_into_net(kernel_generation_net, kgn_dict, strict_load=True)
load_param_into_net(upscale_net, usn_dict, strict_load=True)
kernel_generation_net.set_train(False)
upscale_net.set_train(False)
downsampler_net.set_train(False)
quant.set_train(False)
valid_net = ValidateCell(kernel_generation_net, upscale_net, downsampler_net, quant, scale, scale)

# Prepare the output directory
os.makedirs(args.output_dir, exist_ok=True)

target_dataset = ["Set5", "BSDS100", "Set14", "Urban100", "DIV2KHR"]

for data_type in target_dataset:
    # batch_size is 1 because the images have different sizes
    if data_type == "DIV2KHR":
        val_dataloader = build_dataset(DIV2KHR("./datasets/DIV2K", "valid"), 1, 1, False)
    else:
        val_dataloader = build_dataset(Set5Test("./datasets/", data_type), 1, 1, False)

    psnr_list = list()
    ssim_list = list()
    save_dir = os.path.join(args.output_dir, data_type)
    os.makedirs(save_dir, exist_ok=True)

    print(f"Validating... {data_type}.Downscaling x{scale}")
    for i, data in enumerate(tqdm(val_dataloader.create_dict_iterator(), total=val_dataloader.get_dataset_size())):
        img = data['image']
        downscaled_img, reconstructed_img = valid_net(img)
        psnr, ssim = compute_psnr_ssim(img, downscaled_img, reconstructed_img, i, save_dir, scale, True)
        psnr_list.append(psnr)
        ssim_list.append(ssim)

    print(f"For '{data_type}', save results at {save_dir}")
    print('Mean PSNR: {0:.2f}'.format(np.mean(psnr_list)))
    print('Mean SSIM: {0:.4f}'.format(np.mean(ssim_list)))
    print("="*30)
```
    print("="*30)

```text
Validating... Set5.Downscaling x4
100%|██████████| 5/5 [00:08<00:00,  1.73s/it]
For 'Set5', save results at ./exp_res/Set5
Mean PSNR: 34.17
Mean SSIM: 0.9196
Validating... BSDS100.Downscaling x4
100%|██████████| 100/100 [00:47<00:00,  2.11it/s]
For 'BSDS100', save results at ./exp_res/BSDS100
Mean PSNR: 29.49
Mean SSIM: 0.8092
Validating... Set14.Downscaling x4
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
For 'Set14', save results at ./exp_res/Set14
Mean PSNR: 30.61
Mean SSIM: 0.8427
Validating... Urban100.Downscaling x4
100%|██████████| 100/100 [04:20<00:00,  2.61s/it]
For 'Urban100', save results at ./exp_res/Urban100
Mean PSNR: 29.31
Mean SSIM: 0.8704
Validating... DIV2KHR.Downscaling x4
100%|██████████| 100/100 [14:30<00:00,  8.71s/it]
For 'DIV2KHR', save results at ./exp_res/DIV2KHR
Mean PSNR: 32.68
Mean SSIM: 0.8871
```

# Results Comparison

Compared with the results reported in the paper, all differences are within 3%.
![show_images](images/res.jpg)