# import 

In [None]:
import cv2

In [None]:
import random

In [None]:
import os

In [None]:
from fastprogress.fastprogress import progress_bar

In [None]:
import numpy as np

In [None]:
from scipy.ndimage.interpolation import map_coordinates

In [None]:
from scipy.ndimage.filters import gaussian_filter

In [None]:
from PIL import Image

In [None]:
from matplotlib import pyplot as plt

In [None]:
import pandas as pd

In [None]:
from fastprogress.fastprogress import progress_bar

# functions

In [None]:
# export
# Remove PIL's limit on the number of pixels per image (None disables the
# check) so that very large source images can be opened.
# OpenCV would need its source code modified to lift its equivalent limit,
# which is why images are loaded with PIL instead.
Image.MAX_IMAGE_PIXELS = None

In [None]:
# export
# Margin in pixels. gen_datasets keeps the anchor target point at least this
# far inside a crop's border, and re-draws a crop when another target point
# falls within this distance of the border (on either side).
EDGE = 30

In [None]:
# export
def _load_sources(src):
    '''
    Load every (image_path, label_path) pair listed in `src`.

    Returns:
        imgs: list of HxWx3 uint8 RGB arrays, one per pair.
        masks: list of HxW bool arrays, True where the label image is > 0.
        img_pts: list of (k, 2) int64 arrays; row j is the (x, y) of the j-th
            labelled pixel of the matching image, in row-major order.
        allpts: flat list of (image_index, (x, y)) over all images, with
            plain Python ints.

    Raises:
        ValueError: if a label's size differs from its image's size.
    '''
    imgs, masks, img_pts, allpts = [], [], [], []
    for img_fn, lbl_fn in src:
        with Image.open(img_fn) as im:
            # Normalise grayscale / RGBA / palette images to 3 channels so the
            # RGB -> BGR channel swap before cv2.imwrite always works.
            if im.mode != 'RGB':
                im = im.convert('RGB')
            img = np.asarray(im)
        with Image.open(lbl_fn) as im:
            lbl = np.asarray(im)
        mask = lbl > 0
        if mask.ndim == 3:
            # Multi-channel label: a pixel is a target if any channel is set
            # (avoids reporting the same pixel once per channel).
            mask = mask.any(axis=-1)
        if mask.shape != img.shape[:2]:
            raise ValueError('label %s has size %s but image %s has size %s'
                             % (lbl_fn, mask.shape, img_fn, img.shape[:2]))
        ys, xs = np.nonzero(mask)
        pts = np.stack((xs, ys), axis=1).astype(np.int64)
        # Plain ints so coordinates serialise as "(x, y)" in the CSV rather
        # than as numpy scalar reprs.
        allpts += [(len(imgs), (int(x), int(y))) for x, y in pts]
        imgs.append(img)
        masks.append(mask)
        img_pts.append(pts)
    return imgs, masks, img_pts, allpts


def _pick_target_crop(px, py, img_shape, pts, w, h, edge, max_retry):
    '''
    Choose a w x h window that contains target (px, py).

    The target lands at a random in-crop offset (dx, dy) with
    edge <= dx <= w - edge and edge <= dy <= h - edge, and the window always
    lies fully inside the image. Other targets of the same image that fall
    well inside the window are reported too; if any other target falls in
    the ambiguous band of +/- `edge` pixels around the window border, the
    window is re-drawn.

    Returns (sx, sy, pos), where (sx, sy) is the window's top-left corner in
    the source image and pos is a list of in-crop (x, y) target coordinates
    (the anchor first). Returns None after max_retry + 1 failed draws.

    Raises:
        ValueError: if the target is too close to the image border for any
            valid window to exist.
    '''
    img_h, img_w = img_shape[:2]
    # Offset ranges that honour the margin AND keep the window inside the image.
    dx_lo, dx_hi = max(edge, px + w - img_w), min(w - edge, px)
    dy_lo, dy_hi = max(edge, py + h - img_h), min(h - edge, py)
    if dx_lo > dx_hi or dy_lo > dy_hi:
        raise ValueError('target (%d, %d) is too close to the border of a '
                         '%dx%d image for a %dx%d crop with margin %d'
                         % (px, py, img_w, img_h, w, h, edge))
    # The anchor itself (and any duplicate entry of it) is never a neighbour.
    others = pts[(pts[:, 0] != px) | (pts[:, 1] != py)]
    half_w, half_h = w // 2, h // 2
    for _ in range(max_retry + 1):
        dx = random.randint(dx_lo, dx_hi)
        dy = random.randint(dy_lo, dy_hi)
        sx, sy = px - dx, py - dy
        ox = np.abs(others[:, 0] - (sx + half_w))
        oy = np.abs(others[:, 1] - (sy + half_h))
        inner = (ox < half_w - edge) & (oy < half_h - edge)
        near = (ox < half_w + edge) & (oy < half_h + edge)
        if (near & ~inner).any():
            # A neighbour sits on the margin band: its label would be ambiguous.
            continue
        pos = [(dx, dy)] + [(int(x) - sx, int(y) - sy) for x, y in others[inner]]
        return sx, sy, pos
    return None


def _pick_blank_crop(img_shape, pts, w, h, edge, max_retry):
    '''
    Choose a random w x h window inside the image that keeps every target in
    `pts` at least `edge` pixels outside it.

    Returns the window's top-left (sx, sy), or None after max_retry + 1
    failed draws.
    '''
    img_h, img_w = img_shape[:2]
    half_w, half_h = w // 2, h // 2
    for _ in range(max_retry + 1):
        sx = random.randint(0, img_w - w)
        sy = random.randint(0, img_h - h)
        ox = np.abs(pts[:, 0] - (sx + half_w))
        oy = np.abs(pts[:, 1] - (sy + half_h))
        if not ((ox < half_w + edge) & (oy < half_h + edge)).any():
            return sx, sy
    return None


def _save_crop(img, sx, sy, w, h, fn):
    '''Write the w x h window of RGB `img` with top-left (sx, sy) to `fn`.'''
    crop = img[sy:sy + h, sx:sx + w]
    # Images are loaded as RGB by PIL; cv2.imwrite expects BGR.
    if not cv2.imwrite(fn, crop[..., [2, 1, 0]]):
        raise OSError('failed to write %s' % fn)


def gen_datasets(src, output_path='../data/dataset_20200818',
                 output_img_path='image',
                 output_ds_csv_fn='gends.csv', DSSIZE=10000,
                 BLANKP=0.2, DSIMGW=512, DSIMGH=512):
    '''
    Generate a dataset of fixed-size crops from large labelled images.

    Sample i is a "target" crop while i < DSSIZE * (1 - BLANKP). Target crops
    cycle through all labelled points in order, and each contains its
    anchor point at least EDGE pixels from the border. The remaining samples
    are "blank" crops that contain no target point. Crops are written as
    JPEG files, and a summary CSV is written once at the end.

    Parameters:
        src: list of (image_path, label_path) tuples. Every non-zero pixel of
            a label image is a target point.
        output_path: directory where the dataset is saved.
        output_img_path: sub-directory of output_path for the crop images.
        output_ds_csv_fn: CSV file name, saved under output_path.
        DSSIZE: total number of crops to generate.
        BLANKP: fraction of crops that contain no target.
        DSIMGW, DSIMGH: crop width and height in pixels.

    Returns:
        dict (also written to the CSV, one row per crop) with keys:
            image: list of generated file paths.
            pos: aligned with image; each entry is a list of in-crop (x, y)
                target coordinates, or [(-1, -1)] for a blank crop.
            tag: aligned with image; each entry is a one-element list holding
                "img_idx sx sy width height", i.e. where the crop was taken
                from in source image img_idx.

    Raises:
        ValueError: src is empty; target crops are requested but no label has
            a target; a label and its image differ in size; or a target is
            too close to the image border to fit in a crop with the margin.
        RuntimeError: no valid crop was found after repeated random draws.
    '''
    img_dir = os.path.join(output_path, output_img_path)
    os.makedirs(img_dir, exist_ok=True)

    imgs, masks, img_pts, allpts = _load_sources(src)
    n_target = DSSIZE * (1 - BLANKP)
    if DSSIZE > 0 and not imgs:
        raise ValueError('src is empty')
    if n_target > 0 and not allpts:
        raise ValueError('no target points found in any label image')

    max_retry = 1000
    rets = {'image': [], 'pos': [], 'tag': []}
    pbar = progress_bar(range(DSSIZE))
    pbar.comment = '生成中'  # progress-bar caption ("generating")
    for i in pbar:
        if i < n_target:
            idx = i % len(allpts)
            imgidx, (px, py) = allpts[idx]
            picked = _pick_target_crop(px, py, imgs[imgidx].shape,
                                       img_pts[imgidx], DSIMGW, DSIMGH,
                                       EDGE, max_retry)
            if picked is None:
                # Could not avoid neighbouring points on the margin band.
                raise RuntimeError('retry fail:%d %d' % (i, idx))
            sx, sy, pos = picked
        else:
            imgidx = i % len(imgs)
            picked = _pick_blank_crop(imgs[imgidx].shape, img_pts[imgidx],
                                      DSIMGW, DSIMGH, EDGE, max_retry)
            if picked is None:
                raise RuntimeError('no blank crop found:%d %d' % (i, imgidx))
            sx, sy = picked
            pos = [(-1, -1)]
            # Internal sanity check: the crop really has no labelled pixel.
            assert not masks[imgidx][sy:sy + DSIMGH, sx:sx + DSIMGW].any()

        fn = os.path.join(img_dir, '%05d.jpg' % i)
        _save_crop(imgs[imgidx], sx, sy, DSIMGW, DSIMGH, fn)
        rets['image'].append(fn)
        rets['pos'].append(pos)
        rets['tag'].append(['%d %d %d %d %d' % (imgidx, sx, sy, DSIMGW, DSIMGH)])

    # Written once after the loop (previously rewritten every iteration).
    pd.DataFrame(rets).to_csv(os.path.join(output_path, output_ds_csv_fn))
    return rets
        
        

# test

In [None]:
# Smoke run: build a 10000-crop dataset from a single image/label pair.
gen_datasets(
    src=[('../data/src/1.jpg', '../data/src/label_1.png')],
    output_path='./ds_20200818',
    DSSIZE=10000,
)

# export