In [None]:
#|default_exp data.transforms

In [None]:
#|export
from __future__ import annotations

import pandas as pd
from pathlib import Path

from fastcore.foundation import mask2idxs

from fastai.data.transforms import IndexSplitter

from fastxtend.imports import *

# Processing data and basic transforms

> Additional functions for getting, splitting, and labeling data, as well as generic transforms

In [None]:
#|hide
from nbdev.showdoc import *
from fastxtend.test_utils import *

In [None]:
#|export
def KFoldColSplitter(fold:listified[int]=0, col:int|str='folds'):
    """Split `items` (supposed to be a dataframe) by `fold` in `col`.

    `fold`: a single fold value, or any list-like of fold values (list, tuple, set,
    numpy array, ...). Rows whose `col` value matches go to the validation set.
    `col`: a column name, or a positional column index when given as an `int`.
    Returns a function that takes the DataFrame and returns `(train_idxs, valid_idxs)`.
    """
    def _inner(o):
        assert isinstance(o, pd.DataFrame), "KFoldColSplitter only works when your items are a pandas DataFrame"
        valid_col = o.iloc[:,col] if isinstance(col, int) else o[col]
        # Route every case through `isin`: it accepts any list-like `fold` (not only
        # lists/tuples), and wrapping a scalar avoids an elementwise `==` between
        # possibly mismatched dtypes.
        folds = fold if pd.api.types.is_list_like(fold) else [fold]
        valid_idx = valid_col.isin(folds)
        return IndexSplitter(mask2idxs(valid_idx))(o)
    return _inner

In [None]:
#|hide
df = pd.DataFrame({'a': list(range(10)), 'b': [0,1,2,3,4]*2})
# `col` may be given either as a column name or as a positional column index
for c in ('b', 1):
    splits = KFoldColSplitter(col=c)(df)
    test_eq(splits, [[1,2,3,4,6,7,8,9], [0,5]])

# `fold` may be a single value (the default, `0`) or a list of values
df = pd.DataFrame({'a': list(range(10)), 'folds': [0,1,2,3,4]*2})
splits = KFoldColSplitter(fold=[0,1], col='folds')(df)
test_eq(splits, [[2,3,4,7,8,9], [0,1,5,6]])

In [None]:
#|hide
from fastcore.basics import ifnone, range_of

def _test_splitter(f, items=None):
    "A basic set of conditions a splitter `f` must pass on `items` (defaults to `range_of(30)`)"
    items = range_of(30) if items is None else items
    trn,val = f(items)
    # the train set is neither empty nor all of `items`
    assert 0 < len(trn) < len(items)
    # no train item leaks into the validation set
    assert not any(o in val for o in trn)
    # train and validation sizes add up to the number of items
    test_eq(len(trn), len(items)-len(val))
    # calling the splitter again yields the same train split (random seed consistency)
    test_eq(f(items)[0], trn)
    return trn, val

In [None]:
#|exporti
def _parent_idxs(items, name):
    "Indices of `items` whose parent folder is `name` (or one of the names in `name`), grouped by name"
    res = []
    for n in L(name):
        res += [i for i,o in enumerate(items) if Path(o).parent.name == n]
    return res

In [None]:
#|export
def ParentSplitter(train_name:str='train', valid_name:str='valid'):
    "Split `items` by the name of each item's parent folder: `train_name` for training, `valid_name` for validation."
    def _inner(o):
        trn_idxs = _parent_idxs(o, train_name)
        val_idxs = _parent_idxs(o, valid_name)
        return trn_idxs, val_idxs
    return _inner

In [None]:
#|hide
splitter = ParentSplitter()
fnames = [
    'dir/train/9932.png', 'dir/valid/7189.png', 'dir/valid/7320.png', 'dir/train/9833.png',
    'dir/train/7666.png', 'dir/valid/925.png',  'dir/train/724.png',  'dir/valid/93055.png',
]

# generic splitter checks first, then the exact split decided by the parent folder
_test_splitter(splitter, items=fnames)
test_eq(splitter(fnames), [[0,3,4,6], [1,2,5,7]])

In [None]:
#|exporti
def _greatgrandparent_idxs(items, name):
    "Indices of `items` whose great-grandparent folder is `name` (or one of the names in `name`), grouped by name"
    res = []
    for n in L(name):
        res += [i for i,o in enumerate(items) if Path(o).parent.parent.parent.name == n]
    return res

In [None]:
#|export
def GreatGrandparentSplitter(train_name:str='train', valid_name:str='valid'):
    "Split `items` by the name of each item's great-grandparent folder: `train_name` for training, `valid_name` for validation."
    def _inner(o):
        trn_idxs = _greatgrandparent_idxs(o, train_name)
        val_idxs = _greatgrandparent_idxs(o, valid_name)
        return trn_idxs, val_idxs
    return _inner

In [None]:
#|hide
splitter = GreatGrandparentSplitter()
fnames = [
    'dir/train/9/9/9932.png', 'dir/valid/7/1/7189.png', 'dir/valid/7/3/7320.png', 'dir/train/9/8/9833.png',
    'dir/train/7/6/7666.png', 'dir/valid/9/2/925.png',  'dir/train/7/2/724.png',  'dir/valid/9/3/93055.png',
]

# generic splitter checks first, then the exact split decided by the great-grandparent folder
_test_splitter(splitter, items=fnames)
test_eq(splitter(fnames), [[0,3,4,6], [1,2,5,7]])