In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import scipy
import scipy.ndimage
import torch as t

In [3]:
# Collect the full path of every entry in a directory.
def readfile(path):
    """Return the paths of all entries directly inside ``path``.

    Args:
        path: Directory to list. A trailing path separator is optional.

    Returns:
        A list with one ``os.path.join(path, name)`` string per name
        reported by ``os.listdir(path)``, in the order it reports them.
    """
    # os.path.join only inserts a separator when ``path`` lacks one, so both
    # "data" and "data/" work; plain concatenation turned "data" + "a.png"
    # into the non-existent "dataa.png".
    return [os.path.join(path, name) for name in os.listdir(path)]

# Build a (degraded, original) image pair: shrink by `scale`, then enlarge
# back to the original size.
def preproccess(path, scale, is_gray = True):
    """Load an image and return a degraded copy together with the original.

    The image is converted to YCbCr, cropped from the top-left corner so that
    both sides are multiples of ``scale``, and divided by 255.  The degraded
    copy is obtained by resampling with ``scipy.ndimage.zoom`` by ``1/scale``
    and then by ``scale`` (prefilter disabled), so it has the same shape as
    the original.

    Args:
        path: Path of the image file to open.
        scale: Down/up-sampling factor (expected to be a positive integer).
        is_gray: When true, keep only the first (Y) channel; otherwise keep
            all three YCbCr channels.

    Returns:
        ``(result, label)`` where ``label`` is the normalised high-resolution
        array and ``result`` is its degraded (low-resolution content) version.
    """
    # Use a context manager so the underlying file handle is released as soon
    # as the pixel data has been converted.
    with Image.open(path) as raw:
        img = raw.convert('YCbCr')
    if is_gray:
        img = img.split()[0]

    # Trim the right/bottom edges so both sides divide evenly by `scale`.
    width, height = img.size
    img = img.crop((0, 0, width - width % scale, height - height % scale))

    # Normalise; np.float was removed from NumPy, np.float64 is the same type.
    label = np.array(img).astype(np.float64) / 255

    # Resample the two spatial axes only.  A single scalar factor would also
    # be applied to the channel axis of a colour image and blend Y/Cb/Cr.
    channel_axes = (1,) * (label.ndim - 2)
    shrink = (1. / scale, 1. / scale) + channel_axes
    enlarge = (scale / 1., scale / 1.) + channel_axes

    # Shrink first, then enlarge back to the original size.
    result = scipy.ndimage.zoom(label, shrink, prefilter=False)
    result = scipy.ndimage.zoom(result, enlarge, prefilter=False)

    return result,label  # label : high-resolution      result : low-resolution

# Extract (possibly overlapping) square patches.
def get_sub_images(img, sub_size, stride):
    """Lazily yield ``sub_size`` x ``sub_size`` crops of ``img``.

    ``img`` must expose ``.size`` (indexable, two entries) and
    ``.crop(box)`` -- e.g. a PIL image.  The window start advances by
    ``stride`` along ``size[0]`` in the outer loop and along ``size[1]`` in
    the inner loop; only windows that fit entirely inside ``size`` are
    produced.  Each crop box is ``[x, y, x + sub_size, y + sub_size]``.
    """
    first_axis_starts = range(0, img.size[0] - sub_size + 1, stride)
    for x0 in first_axis_starts:
        second_axis_starts = range(0, img.size[1] - sub_size + 1, stride)
        yield from (
            img.crop([x0, y0, x0 + sub_size, y0 + sub_size])
            for y0 in second_axis_starts
        )
            
            
def get_train_data(path, image_size=33, label_size=21, stride=14, scale=3, is_gray=True):
    """Cut every image under ``path`` into training patch pairs.

    Each file returned by ``readfile(path)`` is run through ``preproccess``
    to obtain a degraded image and its high-resolution reference.  A window
    of ``image_size`` x ``image_size`` slides over the degraded image with
    step ``stride``; for each position the matching reference patch is the
    ``label_size`` x ``label_size`` region offset by
    ``abs(image_size - label_size) / 2`` (truncated) from the window's
    top-left corner, i.e. its central region.

    Returns:
        ``(sub_images, sub_labels)`` as NumPy arrays built from the collected
        patches; every patch carries a leading axis of length 1.
    """
    files = readfile(path)
    # Distance from the input window's edge to the label window's edge.
    offset = int(abs(image_size - label_size) / 2)
    patches = []
    targets = []
    for file_path in files:
        # low_res: network input (degraded); high_res: reference image.
        low_res, high_res = preproccess(file_path, scale, is_gray)
        rows, cols = low_res.shape[:2] if len(low_res.shape) == 3 else low_res.shape

        # Slide the window; the centre of each window is the reference patch.
        for top in range(0, rows - image_size + 1, stride):
            input_rows = slice(top, top + image_size)
            label_rows = slice(top + offset, top + offset + label_size)
            for left in range(0, cols - image_size + 1, stride):
                input_cols = slice(left, left + image_size)
                label_cols = slice(left + offset, left + offset + label_size)

                # Prepend a length-1 axis to each patch before collecting it.
                patches.append(low_res[input_rows, input_cols][np.newaxis, ...])
                targets.append(high_res[label_rows, label_cols][np.newaxis, ...])
    return np.array(patches), np.array(targets)
    
def get_test_data(path, scale, is_gray=True):
    """Return whole-image test pairs for every file under ``path``.

    Each file returned by ``readfile(path)`` is passed through
    ``preproccess(file, scale, is_gray)``; no patch extraction is done.

    Returns:
        ``(images, labels)``: two lists of equal length holding, per file,
        the first and second value returned by ``preproccess`` respectively.
    """
    pairs = [preproccess(file_path, scale, is_gray) for file_path in readfile(path)]
    images = [degraded for degraded, _ in pairs]
    labels = [reference for _, reference in pairs]
    return images, labels
