# import

In [None]:
#export
from fastai.vision import *

In [None]:
#export
import cv2

In [None]:
#export
import os

In [None]:
import numpy as np

# functions

In [None]:
#export
def get_y(x, ds_rootdir, imgdir, maskdir):
    '''
    Build the path of the mask file that belongs to image path `x`.

    The result is `<ds_rootdir>/<maskdir>/<stem>_mask<suffix>`, where stem and
    suffix are taken from `x` (a pathlib-style object exposing `.stem` and
    `.suffix`). `imgdir` is accepted for signature symmetry but is not used.
    '''
    mask_name = f'{x.stem}_mask{x.suffix}'
    return os.path.join(ds_rootdir, maskdir, mask_name)

In [None]:
#export
def imgp_CLAHE(pil_img, clip_limit=4.0, tile_grid_size=(8, 8)):
    '''
    Apply contrast-limited adaptive histogram equalization (CLAHE) to an image.

    The image is converted to HSV, CLAHE is applied to the V channel only
    (channel index 2), and the result is converted back to RGB.

    Parameters:
        pil_img: image convertible with np.asarray; its channels are
            interpreted in RGB order.
        clip_limit: forwarded to cv2.createCLAHE as clipLimit (default 4.0).
        tile_grid_size: forwarded to cv2.createCLAHE as tileGridSize
            (default (8, 8)).
    Returns:
        A new PIL image created from the equalized array.
    '''
    rgb = np.asarray(pil_img)
    # Convert RGB <-> HSV directly. Detouring through BGR in both directions
    # yields the same HSV values but costs two extra full-image conversions.
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    # Equalize brightness only; hue and saturation are left untouched.
    hsv[:, :, 2] = clahe.apply(hsv[:, :, 2])
    return PIL.Image.fromarray(cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB))
    

In [None]:
#export
def get_databunch(ds_root_dir = 'dataset_20200708', ds_imgdir = 'image'
                  , ds_maskdir = 'mask', bs = 16, valid_pct = 0.2
                  , device = torch.device('cuda')
                  , transforms = get_transforms(max_zoom = 1.)
                  , img_processor = None):
    '''
    Build a segmentation databunch from an image folder and a mask folder.

    Parameters:
        ds_root_dir: root directory of the dataset.
        ds_imgdir: sub-directory (under the root) holding the input images.
        ds_maskdir: sub-directory (under the root) holding the mask images;
            the mask of `<name><ext>` is looked up as `<name>_mask<ext>`.
        bs: batch size.
        valid_pct: fraction of the items randomly assigned to the validation set.
        device: device passed to the databunch.
        transforms: transforms applied to both image and mask (tfm_y = True).
            The default is get_transforms(max_zoom = 1.), i.e. no zoom, and is
            created once when the function is defined. Pass None to skip
            transforms entirely.
        img_processor: processing steps applied to every image right after it
            is opened, in the given order. Either None / empty (no processing),
            a single name, or an iterable of names. Supported names:
                'CLAHE': contrast-limited adaptive histogram equalization.
    Returns:
        The databunch, normalized with imagenet_stats.
    Raises:
        ValueError: if img_processor contains an unsupported name.
    '''
    # Registry of the supported post-open image processors.
    processors = {'CLAHE': imgp_CLAHE}

    # Normalize the argument to a list; a bare string means a single step
    # (iterating it would otherwise yield individual characters).
    if img_processor is None:
        img_processor = []
    elif isinstance(img_processor, str):
        img_processor = [img_processor]
    else:
        img_processor = list(img_processor)

    # Reject unsupported names here instead of when the first image is opened,
    # and with a real exception so the check survives `python -O`.
    unknown = [name for name in img_processor if name not in processors]
    if unknown:
        raise ValueError('没有实现: %r' % (unknown,))

    def imgp_afteropen(pil_img, steps):
        # Run each processing step in order on the freshly opened image.
        for step in steps:
            pil_img = step(pil_img)
        return pil_img

    img_processor_func = None
    if img_processor:
        steps = [processors[name] for name in img_processor]
        img_processor_func = partial(imgp_afteropen, steps = steps)

    data = SegmentationItemList.from_folder(os.path.join(ds_root_dir, ds_imgdir)
                , after_open = img_processor_func)
    data = data.split_by_rand_pct(valid_pct)
    data = data.label_from_func(
            partial(get_y, ds_rootdir = ds_root_dir, imgdir = ds_imgdir, maskdir = ds_maskdir)
            , classes = ['bg', 'water'])

    if transforms is not None:
        data = data.transform(transforms, tfm_y = True)
    # num_workers = 0 keeps data loading in the main process.
    data = data.databunch(bs = bs, num_workers = 0, device = device)
    data = data.normalize(imagenet_stats)
    return data

#data = get_databunch(bs = 4)
#data.show_batch()

# test

In [None]:
# Prefer the GPU, but fall back to the CPU so the test cells below can still
# run on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Smoke test: default transforms, no extra image processing.
data = get_databunch(
    ds_root_dir = 'label/dataset_20200713',
    bs = 4,
    device = device,
)

In [None]:
# Smoke test: contrast-limited adaptive histogram equalization (CLAHE)
# applied after opening each image, with transforms disabled.
data = get_databunch(
    ds_root_dir = 'label/dataset_20200713',
    bs = 4,
    device = device,
    transforms = None,
    img_processor = ['CLAHE'],
)

# export

In [None]:
!python notebook2script.py --fname 'databunch.ipynb' --outputDir './exp/'