## 超分辨率训练和测试范例

这里提供一个超分辨率模型的训练和测试范例。

首先，我们假定您的训练数据和测试数据存在下面的位置：

In [5]:
# Directory that is scanned for entries whose name contains 'down4x' when
# building the training patches.
TRAIN_RAW_DATA = "./dataset/game1/train_png/"
# Archive with the raw test data.
# NOTE(review): not referenced by any code visible in this notebook - confirm it is still needed.
TEST_RAW_DATA = "./dataset/game1/test.tar"

### 处理训练数据

我们已将数据集解压到 `./dataset/game1/train_png/`，因此您无需再手动解压数据集。

假定您要将训练数据存储在下面位置：

In [6]:
# Directory where the extracted training patches (.npy files) are written and
# later read back by the training cell.
TRAIN_DATA_STORAGE = "./workspace/train_patches"

则用下面的代码处理您的训练数据：

In [7]:
import numpy as np
import subprocess
import random
import os
import cv2
import tarfile
import io
from tqdm import tqdm

#TRAIN_RAW_DATA='./train_png/'
#TRAIN_DATA_STORAGE='./train_patches/'

# Fixed seed so the same frames and crop offsets are sampled on every run.
random.seed(100)
os.makedirs(TRAIN_DATA_STORAGE, exist_ok=True)

# Every entry whose name contains 'down4x' is treated as a directory of
# 4x-downsampled frames.
tasks = sorted([os.path.join(TRAIN_RAW_DATA,i) for i in os.listdir(TRAIN_RAW_DATA) if 'down4x' in i])

# Side length (in pixels) of the square training patches.
patch_size = 256

count = 0
for task in tqdm(tasks):
    # The full-resolution frames are looked up in the directory with the same
    # name minus the '_down4x.mp4' suffix.
    task_origin = task.replace('_down4x.mp4','')
    frames_origin = sorted([os.path.join(task_origin,i) for i in os.listdir(task_origin)])
    frames_down4x = sorted([os.path.join(task,i) for i in os.listdir(task)])

    # Frames are paired by sorted position.
    for frame_down4x, frame_origin in zip(frames_down4x, frames_origin):
        # Keep roughly 10% of the frames.
        if random.random() >= 0.1:
            continue

        img_origin = cv2.imread(frame_origin)
        # cv2.imread returns None (instead of raising) for unreadable or
        # non-image files; skip those as well as frames smaller than a patch.
        if img_origin is None:
            continue
        if img_origin.shape[0] < patch_size or img_origin.shape[1] < patch_size:
            continue

        img_down4x = cv2.imread(frame_down4x)
        if img_down4x is None:
            continue
        # Upscale the low-resolution frame back to the original size so the
        # network input and target share the same resolution.
        img_down4x = cv2.resize(
            img_down4x, (img_origin.shape[1], img_origin.shape[0]))

        # Random top-left corner of the crop (row offset first, then column).
        top = random.randrange(img_origin.shape[0] - patch_size + 1)
        left = random.randrange(img_origin.shape[1] - patch_size + 1)

        patch_down4x = img_down4x[top:top + patch_size, left:left + patch_size]
        patch_origin = img_origin[top:top + patch_size, left:left + patch_size]

        # Stored layout: index 0 = upscaled input, index 1 = ground truth,
        # each converted from HWC to CHW and scaled by 1/256.
        img_show = np.float32(
            np.stack((patch_down4x.transpose((2, 0, 1)),
                      patch_origin.transpose((2, 0, 1))))) / 256
        np.save(os.path.join(TRAIN_DATA_STORAGE, '%04d.npy' % count), img_show)
        count += 1


100%|██████████| 90/90 [05:07<00:00,  3.42s/it]


## 构建网络

作为示例，我们构建一个简单的网络。

In [1]:
import megengine as mge
import megengine.module as M
import megengine.functional as F

def addLeakyRelu(x):
    """Return a Sequential that applies module ``x`` followed by ``M.LeakyReLU(0.2)``."""
    activation = M.LeakyReLU(0.2)
    return M.Sequential(x, activation)

def addPadding(x):
    """Copy ``x`` into a zero tensor whose every dimension is rounded up to an even size.

    The result of ``set_subtensor`` over the region covered by ``x`` is
    returned. Indexing uses four axes, so ``x`` must be 4-dimensional.
    """
    dims = x.shape
    # Round each dimension up to the next even number.
    # NOTE(review): this covers all four axes, including the first two, not
    # only the last two - confirm that is intended.
    even_dims = [((d + 1) // 2) * 2 for d in dims]
    padded = mge.zeros(even_dims, dtype=x.dtype)
    n0, n1, n2, n3 = dims[0], dims[1], dims[2], dims[3]
    return padded.set_subtensor(x)[:n0, :n1, :n2, :n3]

class SimpleUNet(M.Module):
    """Small U-Net style encoder/decoder mapping a 3-channel image to a 3-channel image.

    Six stride-2 convolutions form the encoder and six stride-2 transposed
    convolutions form the decoder. Every decoder stage except the first
    consumes the previous decoder output concatenated (along axis 1) with the
    encoder feature map of matching resolution.
    """

    def __init__(self):
        """Create the encoder (conv0..conv5) and decoder (deconv5..deconv0) layers."""
        super().__init__()

        # Encoder: each layer uses stride 2 and doubles the channel count
        # after the first one (3 -> 32 -> 64 -> ... -> 1024).
        self.conv0 = addLeakyRelu(M.Conv2d(3, 32, 4, padding=1, stride=2))
        self.conv1 = addLeakyRelu(M.Conv2d(32, 64, 4, padding=1, stride=2))
        self.conv2 = addLeakyRelu(M.Conv2d(64, 128, 4, padding=1, stride=2))
        self.conv3 = addLeakyRelu(M.Conv2d(128, 256, 4, padding=1, stride=2))
        self.conv4 = addLeakyRelu(M.Conv2d(256, 512, 4, padding=1, stride=2))
        self.conv5 = addLeakyRelu(M.Conv2d(512, 1024, 4, padding=1, stride=2))
        # Decoder: from deconv4 on, the input channel count is twice the
        # previous output because of the skip-connection concatenation.
        self.deconv5 = addLeakyRelu(M.ConvTranspose2d(1024, 512, 4, stride=2, padding=1))
        self.deconv4 = addLeakyRelu(M.ConvTranspose2d(1024, 256, 4, stride=2, padding=1))
        self.deconv3 = addLeakyRelu(M.ConvTranspose2d(512, 128, 4, stride=2, padding=1))
        self.deconv2 = addLeakyRelu(M.ConvTranspose2d(256, 64, 4, stride=2, padding=1))
        self.deconv1 = addLeakyRelu(M.ConvTranspose2d(128, 32, 4, stride=2, padding=1))
        self.deconv0 = addLeakyRelu(M.ConvTranspose2d(64, 3, 4, stride=2, padding=1))

    def forward(self, x):
        """Run the network on ``x`` and return a tensor cropped to ``x``'s last two dimensions.

        ``x`` is indexed with four axes and fed to a 3-input-channel
        convolution, so it is expected to be a 4-D tensor with 3 channels on
        axis 1.
        """
        # Encoder. Each intermediate result is passed through addPadding so
        # its dimensions are even before the next stride-2 stage.
        conv0 = addPadding(self.conv0(x))
        conv1 = addPadding(self.conv1(conv0))
        conv2 = addPadding(self.conv2(conv1))
        conv3 = addPadding(self.conv3(conv2))
        conv4 = addPadding(self.conv4(conv3))
        conv5 = self.conv5(conv4)

        # Decoder. Each output is cropped to the spatial size of the encoder
        # feature map it is concatenated with next. The variable names are
        # reused: after each line ``convN`` holds the decoder output at that level.
        conv5 = self.deconv5(conv5)[:, :, :conv4.shape[2], :conv4.shape[3]]  # 1/32 scale, 512 channels
        conv4 = self.deconv4(F.concat([conv5, conv4], 1))[:, :, :conv3.shape[2], :conv3.shape[3]]  # 1/16 scale, 256 channels
        conv3 = self.deconv3(F.concat([conv4, conv3], 1))[:, :, :conv2.shape[2], :conv2.shape[3]]  # 1/8 scale, 128 channels
        conv2 = self.deconv2(F.concat([conv3, conv2], 1))[:, :, :conv1.shape[2], :conv1.shape[3]]  # 1/4 scale, 64 channels
        conv1 = self.deconv1(F.concat([conv2, conv1], 1))[:, :, :conv0.shape[2], :conv0.shape[3]]  # 1/2 scale, 32 channels
        conv0 = self.deconv0(F.concat([conv1, conv0], 1))[:, :, :x.shape[2], :x.shape[3]]  # full scale, 3 channels

        return conv0

## 训练网络

接下来我们开始训练网络。在此之前，假定您想要把网络存储在下面位置：

In [2]:
# File the trained network and optimizer state are saved to, and later loaded
# from for inference.
MODEL_PATH = "./workspace/model.mge.state"

使用下面的代码训练您的网络：

In [10]:
import time
from functools import lru_cache
from megengine.optimizer import Adam

# Training hyper-parameters.
train_steps = 1000
batch_size = 8
# NOTE(review): input_h / input_w are not referenced by any code visible in
# this notebook - confirm whether they are still needed.
input_h = 256
input_w = 256

net = SimpleUNet()
# The lr passed here is only a placeholder: the training loop overwrites the
# learning rate of every param group at the start of each iteration.
optimizer = Adam(net.parameters(), lr=1e-4)

# Fixed seed so the random batch sampling is repeatable.
random.seed(100)

@lru_cache(maxsize=None)
def load_image(path):
    """Open the ``.npy`` file at ``path`` as a read-only memory map.

    Results are memoized per path without any size limit, so repeated
    requests for the same file return the same mapped array object.
    """
    mapped = np.load(path, mmap_mode="r")
    return mapped

# Only .npy patch files are sampled. Other directory entries (for example the
# .ipynb_checkpoints folder Jupyter may create) would make np.load fail.
train_patches = sorted(
    os.path.join(TRAIN_DATA_STORAGE, f)
    for f in os.listdir(TRAIN_DATA_STORAGE)
    if f.endswith('.npy')
)

def load_batch():
    """Sample ``batch_size`` patches uniformly at random, with replacement.

    Returns:
        A tuple ``(batch_train, batch_gt)`` of NumPy arrays. ``batch_train``
        stacks element 0 of every sampled patch file (network input) and
        ``batch_gt`` stacks element 1 (ground truth).
    """
    samples = [load_image(random.choice(train_patches)) for _ in range(batch_size)]
    # np.array copies the memory-mapped slices into regular in-memory arrays.
    batch_train = np.array([sample[0] for sample in samples])
    batch_gt = np.array([sample[1] for sample in samples])
    return batch_train, batch_gt

@mge.jit.trace
def train_iter(batch_train, batch_gt):
    """Run one forward/backward pass and return ``(loss, pred)``.

    The loss is the mean absolute difference between ``batch_gt`` and the
    network prediction. Only ``optimizer.backward`` is called here; zeroing
    gradients and stepping the optimizer are left to the caller.
    """
    pred = net(batch_train)
    abs_err = F.abs(batch_gt - pred)
    loss = abs_err.mean()
    optimizer.backward(loss)
    return loss, pred

# Exponentially weighted running sums: loss_acc accumulates the losses and
# loss_acc0 accumulates the matching weights, so loss_acc / loss_acc0 is a
# moving average of the loss that is already normalised in the first iterations.
loss_acc = 0
loss_acc0 = 0

# Runs train_steps + 1 iterations (it = 0 .. train_steps).
for it in range(train_steps + 1):
    # Linear learning-rate decay from 2e-4 at it == 0 down to 0 at
    # it == train_steps, so the very last update is applied with lr == 0.
    for g in optimizer.param_groups:
        g['lr'] = 2e-4 * (train_steps - it) / train_steps

    begin = time.time()
    (batch_train, batch_gt) = load_batch()
    data_load_end = time.time()

    optimizer.zero_grad()
    loss, pred = train_iter(batch_train, batch_gt)
    optimizer.step()
    # loss is whatever train_iter returns, so the running sum takes that type
    # once it is added.
    loss_acc = loss_acc * 0.99 + loss
    loss_acc0 = loss_acc0 * 0.99 + 1
    end = time.time()
    
    # Wall-clock time of the whole iteration and of the data loading part only.
    total_time = end - begin
    data_load_time = data_load_end - begin
    # Report progress every 100 iterations.
    if it % 100 == 0:
        print(
            "{}: loss: {}, speed: {:.2f}it/sec, tot: {:.4f}s, data: {:.4f}s, data/tot: {:.4f}"
            .format(it, loss_acc / loss_acc0, 1 / total_time, total_time,
                    data_load_time, data_load_time / total_time))

# Save the model: network weights plus optimizer state in a single file.
state = {
    'net': net.state_dict(),
    'opt': optimizer.state_dict(),
}
# Make sure the target directory exists so a finished training run is not
# lost to a missing-directory error when the file is opened.
model_dir = os.path.dirname(MODEL_PATH)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)
with open(MODEL_PATH, 'wb') as fout:
    mge.save(state, fout)
    

0: loss: Tensor([0.3767]), speed: 1.89it/sec, tot: 0.5298s, data: 0.2936s, data/tot: 0.5542
100: loss: Tensor([0.0736]), speed: 3.07it/sec, tot: 0.3257s, data: 0.2507s, data/tot: 0.7698
200: loss: Tensor([0.0468]), speed: 4.70it/sec, tot: 0.2126s, data: 0.1381s, data/tot: 0.6499
300: loss: Tensor([0.0366]), speed: 3.39it/sec, tot: 0.2951s, data: 0.2243s, data/tot: 0.7602
400: loss: Tensor([0.0323]), speed: 6.59it/sec, tot: 0.1518s, data: 0.0815s, data/tot: 0.5369
500: loss: Tensor([0.0297]), speed: 8.87it/sec, tot: 0.1128s, data: 0.0421s, data/tot: 0.3730
600: loss: Tensor([0.0276]), speed: 10.21it/sec, tot: 0.0980s, data: 0.0278s, data/tot: 0.2839
700: loss: Tensor([0.0266]), speed: 13.66it/sec, tot: 0.0732s, data: 0.0039s, data/tot: 0.0531
800: loss: Tensor([0.0267]), speed: 8.50it/sec, tot: 0.1177s, data: 0.0469s, data/tot: 0.3985
900: loss: Tensor([0.0255]), speed: 13.56it/sec, tot: 0.0737s, data: 0.0036s, data/tot: 0.0495
1000: loss: Tensor([0.0253]), speed: 8.66it/sec, tot: 0.115

## 加载网络并推理

训练完成后，就可以加载网络并进行推理：

首先，假定测试数据（PNG 帧）存放在下面的位置：

In [3]:
# Root directory of the test frames. The inference loop lists the entries in
# it and reads the files whose name contains "png" from each one.
TEST_PNG_PATH="./workspace/test_png/test"

In [30]:
from IPython.display import Image, display
from PIL import Image
import os
import cv2
import io
import numpy as np
# Build a fresh network and restore the weights stored under the 'net' key of
# the saved state file.
net = SimpleUNet()

with open(MODEL_PATH, 'rb') as f:
    checkpoint = mge.load(f)
net.load_state_dict(checkpoint['net'])
    
@mge.jit.trace
def inference(inp):
    """Run ``net`` on ``inp`` and return its output."""
    prediction = net(inp)
    return prediction
	
#test_pngs = sorted([os.path.join(TEST_PNG_PATH, f) for f in os.listdir(TEST_PNG_PATH)])

def _result_path(src_path):
    """Return the location under ./workspace/result/ that mirrors ``src_path``."""
    # os.path.join discards all earlier components once it meets an absolute
    # one, which would make the "result" path equal to the input path and
    # overwrite the test data. Strip the root so absolute inputs stay inside
    # the result tree; relative inputs are joined unchanged.
    if os.path.isabs(src_path):
        src_path = os.path.splitdrive(src_path)[1].lstrip(os.sep)
    return os.path.join("./workspace/result/", src_path)


# Sorted so that the processing order (and the progress output) is deterministic.
for test_png_num in sorted(os.listdir(TEST_PNG_PATH)):
    test_png_num = os.path.join(TEST_PNG_PATH, test_png_num) + "/"
    # Skip stray files that sit next to the per-clip directories.
    if not os.path.isdir(test_png_num):
        continue

    png_names = sorted(name for name in os.listdir(test_png_num) if "png" in name)
    if not png_names:
        continue

    # Create the output directory once per clip instead of once per frame.
    dirpath = _result_path(test_png_num)
    os.makedirs(dirpath, exist_ok=True)

    count = 0
    for test_png in png_names:
        test_png_path = os.path.join(test_png_num, test_png)
        filePath = _result_path(test_png_path)

        # cv2.imread returns a 3-dimensional array, or None when the file
        # cannot be read or decoded.
        img = cv2.imread(test_png_path, 1)
        if img is None:
            # Every test frame needs a prediction, so fail loudly here rather
            # than deep inside cv2.resize.
            raise IOError("failed to read image: %s" % test_png_path)

        # Upscale 4x first (mirrors how the training inputs were prepared),
        # then convert HWC -> 1xCxHxW and scale by 1/256.
        img = cv2.resize(img, (0, 0), fx=4, fy=4)
        img = (np.float32(img) / 256).transpose((2, 0, 1))[None, :, :, :]
        img_out = inference(img)
        # Undo the scaling, clamp to the 8-bit range and go back to HWC.
        img_out = (img_out.numpy() * 256).clip(0, 255)[0].transpose((1, 2, 0)).copy()

        # Store the result image.
        cv2.imwrite(filePath, img_out)
        count += 1
        if count % 100 == 0:
            print(filePath)

./workspace/result/./workspace/test_png/test/90/0100.png
./workspace/result/./workspace/test_png/test/90/0200.png
./workspace/result/./workspace/test_png/test/90/0300.png
./workspace/result/./workspace/test_png/test/90/0400.png
./workspace/result/./workspace/test_png/test/91/0100.png
./workspace/result/./workspace/test_png/test/91/0200.png
./workspace/result/./workspace/test_png/test/91/0300.png
./workspace/result/./workspace/test_png/test/91/0400.png
./workspace/result/./workspace/test_png/test/92/0100.png
./workspace/result/./workspace/test_png/test/92/0200.png
./workspace/result/./workspace/test_png/test/92/0300.png
./workspace/result/./workspace/test_png/test/92/0400.png
./workspace/result/./workspace/test_png/test/93/0100.png
./workspace/result/./workspace/test_png/test/93/0200.png
./workspace/result/./workspace/test_png/test/93/0300.png
./workspace/result/./workspace/test_png/test/93/0400.png
./workspace/result/./workspace/test_png/test/93/0500.png
./workspace/result/./workspace/

In [None]:
## 打成 tar 包后，进行打分

In [None]:
import numpy as np
import sys
import cv2
from brainpp.oss import OSSPath
import tarfile
import io

# Ground-truth archive. It must be opened here because the scoring loop reads
# from `fin`; with both variants commented out the loop fails with NameError.
# The OSSPath variant is enabled to match how the prediction archive is opened
# below; switch to the local-file line when scoring offline.
fin=tarfile.open(fileobj=OSSPath('s3://emc-share/work/topaz_release/gt.tar').open('rb'))
#fin=tarfile.open(fileobj=open('./gt.tar','rb'))

# Prediction archive to be scored (despite the `fout` name it is opened for reading).
out_path='s3://fhq-dataproc/work/topaz/release/prediction_bicubic.tar'
#out_path='s3://fhq-dataproc/work/topaz/release/prediction_unet.tar'
fout_s3=OSSPath(out_path).open('rb')
#fout_s3=open('./prediction_bicubic.tar','rb')
fout=tarfile.open(fileobj=fout_s3)

# Walk both archives in lockstep: the n-th regular file of the ground-truth
# archive is compared with the n-th regular file of the prediction archive.
# On any mismatch the script prints a score of '0' followed by the reason and
# exits with status 0.
cnt = 0   # number of image pairs scored so far
rms = 0   # sum of the per-image mean squared errors
for gt_info in fin:
    oldname = gt_info.name
    print(gt_info.name, file=sys.stderr)
    if not gt_info.isfile():
        continue
    content = fin.extractfile(gt_info).read()
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    img = cv2.imdecode(np.frombuffer(content, dtype='uint8'), 1)
    if img is None:
        raise ValueError('cannot decode ground-truth image: %s' % oldname)

    # Advance the prediction archive to its next regular file.
    tinfo = fout.next()
    while tinfo is not None and not tinfo.isfile():
        tinfo = fout.next()
    if tinfo is None:
        print('0')
        print('number of files mismatch', cnt)
        sys.exit(0)
    # Prediction names must equal the ground-truth names with 'gt' replaced by 'test'.
    if tinfo.name != oldname.replace('gt', 'test'):
        print('0')
        print('filename mismatch')
        sys.exit(0)
    content = fout.extractfile(tinfo).read()
    img_out = cv2.imdecode(np.frombuffer(content, dtype='uint8'), 1)
    if img_out is None:
        # An undecodable prediction is a failed submission, not a crash.
        print('0')
        print('image decode failed', tinfo.name)
        sys.exit(0)
    if img.shape != img_out.shape or img.dtype != img_out.dtype:
        print('0')
        print('image size mismatch', img.shape, img_out.shape)
        sys.exit(0)

    # Per-image mean squared error, computed in float32 to avoid uint8 wrap-around.
    r = np.square(np.float32(img) - img_out).mean()
    rms += r
    cnt += 1
    # Running PSNR over all images so far (peak value 256, MSE floored at 1e-10).
    psnr = np.log10(256 * 256 / max((rms / cnt), 1e-10)) * 10

    print(tinfo.name, img.shape, 'rms', rms / cnt, 'psnr', psnr, 'r', r, file=sys.stderr)

if cnt == 0:
    # Without this guard the final PSNR would divide by zero.
    print('0')
    print('no images found in ground-truth archive')
    sys.exit(0)

psnr = np.log10(256 * 256 / max((rms / cnt), 1e-10)) * 10
print(psnr)
print('looks good')

#    unet: 28.03
# bicubic: 28.61
