# imports

In [None]:
%matplotlib inline

In [None]:
from IPython.core import debugger as idb

In [None]:
# export
import pandas as pd

In [None]:
# export
import re

In [None]:
# export
import numpy as np

In [None]:
# export
from fastai.vision import *

In [None]:
# export
import os

In [None]:
# export
import cv2

In [None]:
# export
from matplotlib import pyplot as plt

In [None]:
# export
import random

In [None]:
# export
#from FLAI.detect_symbol.exp import databunch as databunch_detsym

# functions

In [None]:
#export
pat_coord = re.compile(r'\d+')
# Images live in the "image" directory under data_root and have numeric file names.
pat_imgName = re.compile(r'(\w+/\d+\.jpg)$')
pat_clas = re.compile(r'\w+')
pat_num = re.compile(r'\d+')

def get_label_from_df(fn, df, pat_imgName, coord_col, cat_col, asbbox = False):
    '''
    Look up the coordinates and class names of one image in `df`.

    fn:
        file path of the image.
    df:
        a dataframe storing all the label information, indexed by image name.
    pat_imgName:
        a compiled regular expression; its first match in `str(fn)` is the key
        used to index `df`.
    coord_col:
        name of the column holding the coordinates as text. Every run of digits
        is read as one number and consecutive numbers are paired into points.
    cat_col:
        name of the column holding the class names as text (one class per run
        of word characters).
    asbbox:
        if True, every point [a, b] is expanded to the box [a, b, a + 1, b + 1]
        (the earlier bbox-style label format).

    Returns a tuple `(coords, cats)`: a list of float points (or boxes) and the
    list of class-name strings.
    Raises IndexError when `pat_imgName` does not match `fn`.
    '''
    key = pat_imgName.findall(str(fn))[0]

    # Builtin int instead of np.long: that alias was removed in NumPy 1.24.
    numbers = [int(tok) for tok in pat_num.findall(df.loc[key, coord_col])]
    coords = np.array(numbers, dtype=float).reshape(-1, 2).tolist()

    if asbbox:
        # Still using the earlier bbox representation for now.
        coords = [[a, b, a + 1, b + 1] for a, b in coords]

    cats = pat_clas.findall(df.loc[key, cat_col])
    return (coords, cats)

In [None]:
# export
# Convenience builder so that other modules can quickly construct a DataBunch while being designed.
def get_databunch(data_root='./ds_20200818'
        , csv_name='gends.csv', valid_pct=0.2, bs=64, device=torch.device('cpu'), cache=False):
    '''
    Build an object-detection DataBunch from a labelled csv file.
    --------------------------------
    Parameters:
    -- data_root: root directory of the dataset
    -- csv_name: name of the csv file (inside data_root) holding the annotations;
       it must meet the "csv requirements" below
    -- valid_pct: fraction of the items randomly assigned to the validation set
    -- bs: batch size
    -- device: device handed to the DataBunch (per the original note, batches are
       loaded onto it before the batch transforms run)
    -- cache: if True, call `data.cache_ds_img()` to pre-cache all images in memory
    --------------------------------
    Returns:
    -- a DataBunch object
    --------------------------------
    csv requirements:
    1. it has an index column
    2. the column holding the image names is called "image"
    3. the column holding the positions is called "coord"
    4. the column holding the classes is called "clas"
    --------------------------------
    '''
    data_root = Path(data_root)
    # Honour the caller's csv_name; it used to be overwritten with a hard-coded 'gends.csv'.
    csv_name = Path(csv_name)

    # Read the csv and index it by image name so the label function can look rows up.
    labels_df = pd.read_csv(data_root/csv_name, index_col=0).set_index('image')

    # ItemList
    data = ObjectItemList.from_csv(path=data_root, csv_name=csv_name
                                   , cols='image')

    # split ItemList to get ItemLists
    data = data.split_by_rand_pct(valid_pct=valid_pct)

    # label ItemLists to get LabelLists (points are expanded to boxes via asbbox=True)
    func = partial(get_label_from_df, df=labels_df, pat_imgName=pat_imgName
                   , coord_col='coord', cat_col='clas', asbbox = True)
    data = data.label_from_func(func=func)

    # No transforms are applied at the moment.

    # create DataBunch from LabelLists
    data = data.databunch(bs=bs, device=device, collate_fn=bb_pad_collate, num_workers=0)

    # normalize
    data = data.normalize()

    # Pre-cache the images.
    # NOTE(review): cache_ds_img is not defined in this file - presumably added to
    # DataBunch elsewhere; confirm before relying on cache=True.
    if cache:
        data.cache_ds_img()

    return data

In [None]:
# export
def get_label_from_df_points(fn, df, pat_imgName, coord_col, cat_col):
    '''
    Return the first annotated point of one image as a tensor.

    fn:
        file path of the image.
    df:
        a dataframe storing all the label information, indexed by image name.
    pat_imgName:
        a compiled regular expression; its first match in `str(fn)` is the key
        used to index `df`.
    coord_col:
        name of the column holding the coordinates as text. Every run of digits
        is read as one number and consecutive numbers are paired into points.
    cat_col:
        name of the column of categories. Kept for signature parity with
        `get_label_from_df`; the categories are not part of the returned label.

    Returns `Tensor(point)` for the first point only.
    Raises IndexError when `pat_imgName` does not match `fn` or no point is found.
    '''
    key = pat_imgName.findall(str(fn))[0]

    # Builtin int instead of np.long: that alias was removed in NumPy 1.24.
    numbers = [int(tok) for tok in pat_num.findall(df.loc[key, coord_col])]
    points = np.array(numbers, dtype=float).reshape(-1, 2).tolist()

    # Only the first point is used as the label.
    return Tensor(points[0])

In [None]:
# export
# Limitation: every image is labelled with a single point only.
def get_databunch_points(data_root='./ds_20200818'
        , csv_name='gends.csv', valid_pct=0.2, bs=64, device=torch.device('cpu'), cache=False):
    '''
    Build a points DataBunch from the images under data_root and a labelled csv file.
    --------------------------------
    Parameters:
    -- data_root: root directory of the dataset; images are collected from it
       with `PointsItemList.from_folder`
    -- csv_name: name of the csv file (inside data_root) holding the annotations;
       it must meet the "csv requirements" below
    -- valid_pct: fraction of the items randomly assigned to the validation set
    -- bs: batch size
    -- device: device handed to the DataBunch (per the original note, batches are
       loaded onto it before the batch transforms run)
    -- cache: if True, call `data.cache_ds_img()` to pre-cache all images in memory
    --------------------------------
    Returns:
    -- a DataBunch object
    --------------------------------
    csv requirements:
    1. it has an index column
    2. the column holding the image names is called "image"
    3. the column holding the positions is called "coord"
    4. the column holding the classes is called "clas"
    --------------------------------
    '''
    data_root = Path(data_root)
    # Honour the caller's csv_name; it used to be overwritten with a hard-coded 'gends.csv'.
    csv_name = Path(csv_name)

    # Read the csv and index it by image name so the label function can look rows up.
    labels_df = pd.read_csv(data_root/csv_name, index_col=0).set_index('image')

    data = PointsItemList.from_folder(data_root).split_by_rand_pct(valid_pct = valid_pct)
    func = partial(get_label_from_df_points, df=labels_df, pat_imgName=pat_imgName
                   , coord_col='coord', cat_col='clas')
    data = data.label_from_func(func=func)

    data = data.databunch(bs=bs, device=device)
    data = data.normalize(imagenet_stats)

    # Pre-cache the images. An early `return` used to make this branch unreachable,
    # so `cache` was silently ignored.
    # NOTE(review): cache_ds_img is not defined in this file - presumably added to
    # DataBunch elsewhere; confirm before relying on cache=True.
    if cache:
        data.cache_ds_img()

    return data

In [None]:
from fastai.vision.data import ObjectCategoryProcessor, _get_size

In [None]:
from fastai.vision.image import _draw_rect

In [None]:
# Label-format flag later passed as `asbbox` to get_label_from_df: True means 4-value boxes.
ASBBOX = True
# NOTE: a later cell redefines both ASBBOX and ImageBBox_T (point-based variant);
# whichever cell ran last wins.
class ImageBBox_T(ImagePoints):
    "Support applying transforms to a `flow` of bounding boxes."
    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True, labels:Collection=None,
                 classes:dict=None, pad_idx:int=0):
        super().__init__(flow, scale, y_first)
        self.pad_idx = pad_idx
        # Wrap raw labels into `Category` objects, using `classes` to map label -> index.
        if labels is not None and len(labels)>0 and not isinstance(labels[0],Category):
            labels = array([Category(l,classes[l]) for l in labels])
        self.labels = labels

    def clone(self) -> 'ImageBBox_T':
        "Mimic the behavior of torch.clone for `Image` objects."
        flow = FlowField(self.size, self.flow.flow.clone())
        return self.__class__(flow, scale=False, y_first=False, labels=self.labels, pad_idx=self.pad_idx)

    @classmethod
    def create(cls, h:int, w:int, bboxes:Collection[Collection[int]], labels:Collection=None, classes:dict=None,
               pad_idx:int=0, scale:bool=True)->'ImageBBox_T':
        "Create an ImageBBox_T object from `bboxes` (4 values per box) on an image of size `h` x `w`."
        # Builtin `object` instead of np.object: that alias was removed in NumPy 1.24.
        if isinstance(bboxes, np.ndarray) and bboxes.dtype == object: bboxes = np.array([bb for bb in bboxes])
        bboxes = tensor(bboxes).float()
        # Expand every box into its four corner points (columns 0-1, 0&3, 2&1, 2-3).
        tr_corners = torch.cat([bboxes[:,0][:,None], bboxes[:,3][:,None]], 1)
        bl_corners = bboxes[:,1:3].flip(1)
        bboxes = torch.cat([bboxes[:,:2], tr_corners, bl_corners, bboxes[:,2:]], 1)
        flow = FlowField((h,w), bboxes.view(-1,2))
        return cls(flow, labels=labels, classes=classes, pad_idx=pad_idx, y_first=True, scale=scale)

    def _compute_boxes(self) -> Tuple[LongTensor, LongTensor]:
        "Rebuild boxes from the corner points; boxes without positive extent are dropped."
        bboxes = self.flow.flow.flip(1).view(-1, 4, 2).contiguous().clamp(min=-1, max=1)
        mins, maxes = bboxes.min(dim=1)[0], bboxes.max(dim=1)[0]
        bboxes = torch.cat([mins, maxes], 1)
        mask = (bboxes[:,2]-bboxes[:,0] > 0) * (bboxes[:,3]-bboxes[:,1] > 0)
        # No boxes at all: return a single padding box and padding label.
        if len(mask) == 0: return tensor([self.pad_idx] * 4), tensor([self.pad_idx])
        res = bboxes[mask]
        if self.labels is None: return res,None
        return res, self.labels[to_np(mask).astype(bool)]

    @property
    def data(self)->Union[FloatTensor, Tuple[FloatTensor,LongTensor]]:
        "Boxes alone when there are no labels, otherwise `(boxes, label data array)`."
        bboxes,lbls = self._compute_boxes()
        lbls = np.array([o.data for o in lbls]) if lbls is not None else None
        return bboxes if lbls is None else (bboxes, lbls)

    def show(self, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
        color:str='white', **kwargs):
        "Show the `ImageBBox_T` on `ax`."
        if ax is None: _,ax = plt.subplots(figsize=figsize)
        bboxes, lbls = self._compute_boxes()
        h,w = self.flow.size
        # Rescale in place from [-1, 1] to pixel units. The original trailing `.long()`
        # discarded its result, so it has been dropped without changing the output.
        bboxes.add_(1).mul_(torch.tensor([h/2, w/2, h/2, w/2]))
        for i, bbox in enumerate(bboxes):
            if lbls is not None: text = str(lbls[i])
            else: text=None
            _draw_rect(ax, bb2hw(bbox), text=text, color=color)

In [None]:
# Label-format flag later passed as `asbbox` to get_label_from_df: False means plain (2-value) points.
ASBBOX = False
class ImageBBox_T(ImagePoints):
    "Support applying transforms to a `flow` of labelled points."
    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True, labels:Collection=None,
                 classes:dict=None, pad_idx:int=0):
        super().__init__(flow, scale, y_first)
        self.pad_idx = pad_idx
        # Wrap raw labels into `Category` objects, using `classes` to map label -> index.
        if labels is not None and len(labels)>0 and not isinstance(labels[0],Category):
            labels = array([Category(l,classes[l]) for l in labels])
        self.labels = labels

    def clone(self) -> 'ImageBBox_T':
        "Mimic the behavior of torch.clone for `Image` objects."
        flow = FlowField(self.size, self.flow.flow.clone())
        return self.__class__(flow, scale=False, y_first=False, labels=self.labels, pad_idx=self.pad_idx)

    @classmethod
    def create(cls, h:int, w:int, pts:Collection[Collection[int]], labels:Collection=None, classes:dict=None,
               pad_idx:int=0, scale:bool=True)->'ImageBBox_T':
        "Create an ImageBBox_T object from `pts` (2 values per point) on an image of size `h` x `w`."
        # Builtin `object` instead of np.object: that alias was removed in NumPy 1.24.
        if isinstance(pts, np.ndarray) and pts.dtype == object: pts = np.array([bb for bb in pts])
        pts = tensor(pts).float()
        flow = FlowField((h,w), pts.view(-1,2))
        return cls(flow, labels=labels, classes=classes, pad_idx=pad_idx, y_first=True, scale=scale)

    @property
    def data(self)->Union[FloatTensor, Tuple[FloatTensor,LongTensor]]:
        "Clamped points alone when there are no labels, otherwise `(points, label data array)`."
        pts = self.flow.flow.flip(1).contiguous().clamp(min=-1, max=1)

        lbls = np.array([o.data for o in self.labels]) if self.labels is not None else None
        return pts if lbls is None else (pts, lbls)

    def show(self, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
        color:str='white', **kwargs):
        "Show the points on `ax`, each drawn as a small 5x5 rectangle without text."
        if ax is None: _,ax = plt.subplots(figsize=figsize)
        pts = self.flow.flow.flip(1).contiguous().clamp(min=-1, max=1)
        h,w = self.flow.size
        # Rescale in place from [-1, 1] to pixel units. The original trailing `.long()`
        # discarded its result, so it has been dropped without changing the output.
        pts.add_(1).mul_(torch.tensor([h/2, w/2]))
        for pt in pts:
            # No label text is drawn: there is nothing else to show, so the unused
            # label lookup that used to run here was removed.
            _draw_rect(ax, np.array([pt[1],pt[0], 5, 5]), text=None, color=color)
            
def ptslbl_pad_collate(samples:BatchSamples, pad_idx:int=0) -> Tuple[FloatTensor, Tuple[FloatTensor, LongTensor]]:
    '''
    Collate `samples` of labelled points into one batch, padding with `pad_idx`.

    samples:
        sequence of `(image, target)` pairs. `image.data` is the image tensor and
        `target.data` is a `(points, labels)` pair. When the first target is a
        plain int the batch is handed to `data_collate` unchanged.
    pad_idx:
        label value used for the padding slots.

    Returns `(images, (points, labels))`. `points` has shape (batch, max_len, 2)
    and `labels` has shape (batch, max_len). Each sample is right-aligned, so the
    padding (zero points, `pad_idx` labels) comes first.
    '''
    if isinstance(samples[0][1], int): return data_collate(samples)

    # `.data` is a computed property, so evaluate it once per sample.
    targets = [s[1].data for s in samples]
    max_len = max(len(lbls) for _, lbls in targets)
    pts = torch.zeros(len(samples), max_len, 2)
    labels = torch.zeros(len(samples), max_len).long() + pad_idx
    imgs = []
    for i, (s, (tpts, lbls)) in enumerate(zip(samples, targets)):
        imgs.append(s[0].data[None])
        count = len(lbls)
        # Skip samples without points: `-0:` would select the whole row and the
        # assignment would fail. (The old guard tested the batch tensor instead.)
        if count == 0: continue
        pts[i, -count:] = tpts
        labels[i, -count:] = torch.as_tensor(lbls)
    return torch.cat(imgs, 0), (pts, labels)

In [None]:
class ObjectCategoryList_PtLbl(MultiCategoryList):
    "`ItemList` of labelled points, materialised as `ImageBBox_T` objects."
    _processor = ObjectCategoryProcessor

    def get(self, i):
        "Build the `ImageBBox_T` for item `i`, sized after the matching input image."
        size = _get_size(self.x, i)
        item = self.items[i]
        return ImageBBox_T.create(*size, *item, classes=self.classes, pad_idx=self.pad_idx)

    def analyze_pred(self, pred):
        "Predictions are passed through unchanged."
        return pred

    def reconstruct(self, t, x):
        "Rebuild an `ImageBBox_T` from a padded `(points, labels)` pair; None if everything is padding."
        pts, labels = t
        real = (labels - self.pad_idx).nonzero()
        if len(real) == 0:
            return None
        # Drop everything before the first non-padding label.
        first = real.min()
        return ImageBBox_T.create(*x.size, pts[first:], labels=labels[first:], classes=self.classes, scale=False)


class ObjectItemList_PtLbl(ImageList):
    "`ImageList` whose labels are handled by `ObjectCategoryList_PtLbl`."
    _label_cls = ObjectCategoryList_PtLbl
    _square_show_res = False
    
# Demo pipeline: ItemList -> random split -> labels -> DataBunch -> normalize -> show.
# NOTE: this cell reads data_root, csv_name, df, bs and device, which are assigned
# in the settings cells of the "test" section further down; run those first.
# No transforms are applied.
func = partial(get_label_from_df, df=df, pat_imgName=pat_imgName
               , coord_col='coord', cat_col='clas', asbbox = ASBBOX)
data = (
    ObjectItemList_PtLbl.from_csv(path=data_root, csv_name=csv_name, cols='image')
    .split_by_rand_pct(valid_pct=0.2)
    .label_from_func(func=func)
    # ptslbl_pad_collate is used here in place of bb_pad_collate.
    .databunch(bs=bs, device=device, collate_fn=ptslbl_pad_collate, num_workers=0)
    .normalize()
)
data.show_batch(rows = 3)

In [None]:
data.show_batch(rows = 3)

# test

In [None]:
# Settings used by the cells that read data_root, csv_path, bs and device.
data_root = Path('./ds_20200818/')

csv_name = 'gends.csv'
csv_path = data_root/csv_name

img_subpath = 'image'
img_path = data_root/img_subpath

bs = 64

# Prefer the GPU, but fall back to the CPU when CUDA is unavailable instead of
# failing later on first use. (The dead `device = 'cpu'` assignment was removed.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Read the csv and lightly reshape it (index by the "image" column) so the get_label functions can use it.
df = pd.read_csv(csv_path,index_col=0)
df = df.set_index('image')
df.head()

In [None]:
# Smoke test: build the points DataBunch with every argument left at its default.
data = get_databunch_points()

In [None]:
data.show_batch(rows=3)

In [None]:
# Smoke test: build the object-detection DataBunch with every argument left at its default.
data = get_databunch()

In [None]:
data.show_batch()

In [None]:
URLs.BIWI_HEAD_POSE

In [None]:
# # add transforms
# trn_tfms = [*zoom_crop(scale=(0.9,1.1),do_rand=True,p=1),
#             rot90_affine(use_on_y=True)]
# val_tfms = []

# data = data.transform(tfms=[trn_tfms,val_tfms], tfm_y=True, remove_out=True)

In [None]:
# Inspect the dataset statistics
#databunch_detsym.databunch_statistics(data);

In [None]:
# Experiment modelled on lesson3-head-pose.ipynb: a points pipeline over the image folder.
# NOTE(review): this labels a PointsItemList with get_label_from_df (which returns a
# (coords, cats) tuple) and collates with bb_pad_collate - confirm that is intended
# rather than get_label_from_df_points.
func = partial(get_label_from_df, df=df, pat_imgName=pat_imgName
                   , coord_col='coord', cat_col='clas')
data = (
    PointsItemList.from_folder('./ds_20200818/image')
    .split_by_rand_pct(valid_pct = 0.2)
    .label_from_func(func=func)
    .databunch(bs=8, device=device, collate_fn=bb_pad_collate)
    .normalize(imagenet_stats)
)

In [None]:
data

# export

In [None]:
!python ../notebook2script.py --fname 'databunch.ipynb' --outputDir '../exp/'