In [None]:
# default_exp dataloaders

# Dataloaders

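> Factory for building fastai `DataLoaders` for object detection (bounding boxes, optionally with instance masks) from a pandas DataFrame.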

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

In [None]:
#export

from fastai.vision.all import *
from fastai_object_detection.core import *
#from fastai.torch_core import merge

# temp bug fix
# https://github.com/fastai/fastai/issues/3384
#TensorMultiCategory.register_func(Tensor.__getitem__, TensorMultiCategory, TensorBBox)

In [None]:
#export

class ObjectDetectionDataLoaders(DataLoaders):
    "Basic wrapper around `DataLoader`s with factory method for object detection problems"
    # Configuration read by the static getters below at call time. `from_df` overwrites these
    # class attributes, so dataloaders built earlier will look items up in the most recently
    # registered DataFrame (assuming fastai applies the getters lazily when items are accessed).
    df = pd.DataFrame()
    img_id_col, img_path_col, class_col = "","",""
    bbox_cols = []
    mask_path_col,object_id_col = "",""

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_df(cls, df, valid_pct=0.2, img_id_col="image_id", img_path_col="image_path",
                bbox_cols=["x_min", "y_min", "x_max", "y_max"], class_col="class_name",
                mask_path_col="mask_path", object_id_col="object_id",
                seed=None, vocab=None, add_na=True, item_tfms=None, batch_tfms=None, debug=False, **kwargs):
        """Create `DataLoaders` from a DataFrame holding one row per annotated object.

        Args:
            df: DataFrame with an image-path column, four bounding-box columns and a class column;
                if it also has `mask_path_col`, a mask-aware pipeline is built.
            valid_pct: fraction of images put in the validation set (`RandomSplitter`).
            img_id_col, object_id_col: stored on the class; not used by the getters shown here.
            img_path_col: column whose unique values are the items (images) of the DataBlock.
            bbox_cols: the x_min, y_min, x_max, y_max column names, in that order.
            class_col: column with the class label of each object.
            mask_path_col: column with one mask path per object (only used if present in `df`).
            seed: seed for the random train/validation split.
            vocab: class vocabulary; defaults to the unique values of `class_col`.
            add_na: forwarded to `BBoxLblBlock`.
            item_tfms, batch_tfms: item / batch transforms for the DataBlock.
            debug: if True, run `DataBlock.summary` on `df` before building the loaders.
            **kwargs: forwarded to `DataLoaders.from_dblock`.

        Returns:
            The result of `cls.from_dblock(...)`.
        """
        if vocab is None:
            # pandas `unique` keeps the order of first appearance
            vocab = list(df[class_col].unique())

        # The static getters read `ObjectDetectionDataLoaders.*`, so the configuration is written
        # to that class explicitly (not `cls`); otherwise a subclass calling `from_df` would leave
        # the getters reading stale data.
        state = ObjectDetectionDataLoaders
        state.df = df
        state.img_id_col, state.img_path_col, state.class_col = img_id_col, img_path_col, class_col
        state.bbox_cols = list(bbox_cols)  # copy so the shared default list is never aliased
        state.mask_path_col, state.object_id_col = mask_path_col, object_id_col

        # e.g. item_tfms=[Resize(800, method="pad", pad_mode="zeros")]
        if mask_path_col in df.columns:
            blocks = (ImageBlock(cls=PILImage), BinaryMasksBlock,
                      BBoxBlock, BBoxLblBlock(vocab=vocab, add_na=add_na))
            get_y = [cls._get_masks, cls._get_bboxes, cls._get_labels]
            # `L(None)` is empty, so the default `batch_tfms=None` no longer fails to unpack
            batch_tfms = [TensorBinMasks2TensorMask(), *L(batch_tfms)]
            # NOTE(review): `_bin_mask_stack_and_padding` is underscore-prefixed; star imports skip
            # such names unless the module lists them in `__all__` — confirm it is in scope.
            before_batch = [_bin_mask_stack_and_padding]
        else:
            blocks = (ImageBlock(cls=PILImage), BBoxBlock, BBoxLblBlock(vocab=vocab, add_na=add_na))
            get_y = [cls._get_bboxes, cls._get_labels]
            before_batch = [bb_pad]

        dblock = DataBlock(
            blocks=blocks,
            n_inp=1,
            splitter=RandomSplitter(valid_pct, seed=seed),
            get_items=cls._get_images,
            get_y=get_y,
            item_tfms=item_tfms,
            batch_tfms=batch_tfms)
        # `DataBlock.summary` prints its own report and returns None; wrapping it in `print`
        # would only add a trailing "None".
        if debug: dblock.summary(df)
        return cls.from_dblock(dblock, df, path=".", before_batch=before_batch, **kwargs)

    @staticmethod
    def _get_images(df):
        "Unique values of the image-path column of `df`, used as the DataBlock items."
        return L(list(df[ObjectDetectionDataLoaders.img_path_col].unique()))

    @staticmethod
    def _rows_for(fn):
        "Rows of the registered DataFrame whose image-path column equals `fn`."
        df = ObjectDetectionDataLoaders.df
        return df.loc[df[ObjectDetectionDataLoaders.img_path_col] == fn]

    @staticmethod
    def _get_bboxes(fn):
        "One `[x_min, y_min, x_max, y_max]` list per object row of image `fn` (columns from `bbox_cols`)."
        x_min_col, y_min_col, x_max_col, y_max_col = ObjectDetectionDataLoaders.bbox_cols
        rows = ObjectDetectionDataLoaders._rows_for(fn)
        return [list(box) for box in zip(rows[x_min_col], rows[y_min_col],
                                         rows[x_max_col], rows[y_max_col])]

    @staticmethod
    def _get_labels(fn):
        "Class-column values of the object rows of image `fn`."
        rows = ObjectDetectionDataLoaders._rows_for(fn)
        return list(rows[ObjectDetectionDataLoaders.class_col])

    @staticmethod
    def _get_masks(fn):
        "Mask-path-column values of the object rows of image `fn`."
        rows = ObjectDetectionDataLoaders._rows_for(fn)
        return list(rows[ObjectDetectionDataLoaders.mask_path_col])