In [None]:
#| default_exp data.split
#| default_cls_lvl 3

In [None]:
#| export
from fastai.data.all import *
from tsfast.data.core import CreateDict

## 4. Split into Training and Validation
Das Splitting kann anhand vorher bekannter Indizes, des Dateipfads oder einer beliebigen anderen Funktion durchgeführt werden.

Ein Splitting innerhalb einer Sequenz sollte in der Praxis nur dann erfolgen, wenn nur eine einzige Sequenz vorhanden ist. Diese kann dann vorab manuell geteilt werden.


### 4.1 Splitting mit vorgegebenem Index

In [None]:
from nbdev.config import get_config

In [None]:
# Collect the Wiener-Hammerstein HDF5 test files below the project root.
# Sorting keeps the indices used by the tests below deterministic.
project_root = get_config().config_file.parent
f_path = project_root / 'test_data' / 'WienerHammerstein'
hdf_files = get_files(f_path, extensions='.hdf5', recurse=True).sorted()

In [None]:
# Positions 1 and 2 go to validation; the remaining position (0) goes to training
splitter = IndexSplitter([1, 2])
test_eq(splitter(hdf_files), [[0], [1, 2]])

In [None]:
# Wrap each file path in a dict ({'path': ...}), the item format used with the DataBlock API
list_dict = CreateDict()(hdf_files)
list_dict

[{'path': '/Users/daniel/Development/tsfast/test_data/WienerHammerstein/test/WienerHammerstein_test.hdf5'},
 {'path': '/Users/daniel/Development/tsfast/test_data/WienerHammerstein/train/WienerHammerstein_train.hdf5'},
 {'path': '/Users/daniel/Development/tsfast/test_data/WienerHammerstein/valid/WienerHammerstein_valid.hdf5'}]

In [None]:
# IndexSplitter splits by position only, so dict items split exactly like the raw paths
test_eq(splitter(list_dict),splitter(hdf_files))

### 4.2 Splitting mit allgemeiner Funktion
Items, für die die definierte Funktion `True` zurückgibt, werden dem Validierungsdatensatz zugeordnet, der Rest dem Trainingsdatensatz. In diesem Beispiel wird geprüft, ob der übergeordnete Ordner `valid` heißt.

In [None]:
# Validation = files whose parent folder is named 'valid'; all other files are training
splitter = FuncSplitter(lambda o: Path(o).parent.name == 'valid')
test_eq(splitter(hdf_files), [[0, 1], [2]])

### 4.3 Splitting anhand des Parent-Folders
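Splitter, der die Datensätze explizit anhand eines Trainings- und eines Validierungsordners zuordnet. Dateien in anderen Ordnern (z. B. `test`) werden keinem der beiden Sets zugeordnet.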

In [None]:
#| export
def _parent_idxs(items, name):
    "Indices of `items` whose parent folder is called `name`."
    return mask2idxs(Path(o).parent.name == name for o in items)

def ParentSplitter(train_name='train', valid_name='valid'):
    "Split `items` from the parent folder names (`train_name` and `valid_name`)."
    def _inner(o, **kwargs):
        # Items may be plain paths or dicts (as produced by `CreateDict`); for dicts use their 'path' entry.
        # The `len(o)` guard keeps an empty item list from raising IndexError on `o[0]`.
        if len(o) and isinstance(o[0], dict):
            o = [d['path'] for d in o]
        # Items located in any other folder (e.g. 'test') end up in neither split.
        return _parent_idxs(o, train_name), _parent_idxs(o, valid_name)
    return _inner

In [None]:
# Default folder names: 'train' -> training, 'valid' -> validation; the file in 'test' (index 0) lands in neither
splitter = ParentSplitter()
test_eq(splitter(hdf_files),[[1],[2]])

In [None]:
# Dict items are unwrapped via their 'path' entry, so they split identically to the raw paths
test_eq(splitter(list_dict),splitter(hdf_files))

### 4.4 Percentage Splitter

In [None]:
#| export
def PercentageSplitter(pct=0.8):
    "Split `items` in order: the first `pct` fraction goes to training, the remainder to validation."
    # Out-of-range values would otherwise yield indices past the end (pct > 1) or negative indices (pct < 0).
    if not 0 <= pct <= 1:
        raise ValueError(f"`pct` must be between 0 and 1, got {pct}")
    def _inner(o, **kwargs):
        n = len(o)
        # Round away floating-point noise before truncating,
        # e.g. 100*0.57 == 56.99999999999999 would otherwise give 56 instead of 57.
        split_idx = int(round(n*pct, 9))
        return L(range(split_idx)), L(range(split_idx, n))
    return _inner

In [None]:
# 3 files * 0.7 -> the first 2 (in order) go to training, the last one to validation
splitter = PercentageSplitter(0.7)
test_eq(splitter(hdf_files), [[0, 1], [2]])

### 4.5 Apply To Dictionary
When using the DataBlock API, your items are a list of dictionaries. To apply a splitter to the path stored in each dictionary, you need a wrapper function.

In [None]:
#| export
def ApplyToDict(fn, key='path'):
    "Wrap splitter `fn` so it receives `item[key]` for each dict item instead of the dicts themselves."
    def _inner(items, **kwargs):
        # Extra keyword arguments are accepted (and ignored), matching the `_inner(o, **kwargs)`
        # signature of `ParentSplitter`/`PercentageSplitter`; they are not forwarded to `fn`.
        return fn([item[key] for item in items])
    return _inner

In [None]:
# FuncSplitter hands each whole item to the function; Path() of a dict raises, so dict items make it fail
splitter = FuncSplitter(lambda o: Path(o).parent.name == 'valid')
test_fail(lambda: splitter(list_dict))

In [None]:
# ApplyToDict extracts the 'path' entries first, so the result matches splitting the raw paths
dict_splitter = ApplyToDict(splitter)
test_eq(dict_splitter(list_dict),splitter(hdf_files))
dict_splitter(list_dict)

((#2) [np.int64(0),np.int64(1)], (#1) [2])

### 4.6 Valid Column
Split using the `'valid'` entry of each item dictionary, which is created beforehand by a transformation (here `ValidClmContains`).

In [None]:
#| export
def _valid_clm(o):
    "Return the 'valid' flag stored in dict item `o`."
    return o['valid']

# Items whose 'valid' entry is truthy go to the validation set, all others to training
valid_clm_splitter = FuncSplitter(_valid_clm)

In [None]:
from tsfast.data.core import CreateDict, ValidClmContains,DfHDFCreateWindows

In [None]:
# Build dict items from the HDF5 files; ValidClmContains(['valid']) sets the 'valid' entry read by
# valid_clm_splitter (presumably True when the path contains 'valid' -- TODO confirm).
# NOTE(review): win_sz/stp_sz are presumably window length and step in samples, and clm='u' the
# column used to determine the sequence length -- confirm against DfHDFCreateWindows.
tfm_src = CreateDict([ValidClmContains(['valid']),DfHDFCreateWindows(win_sz=100+1,stp_sz=10,clm='u')])
src_dicts = tfm_src(hdf_files)

In [None]:
train_idxs, valid_idxs = valid_clm_splitter(src_dicts)
# Every window lands in exactly one split, and the split agrees with each item's 'valid' flag
test_eq(len(train_idxs) + len(valid_idxs), len(src_dicts))
assert all(src_dicts[i]['valid'] for i in valid_idxs)
assert not any(src_dicts[i]['valid'] for i in train_idxs)
train_idxs, valid_idxs

((#16780) [np.int64(0),np.int64(1),np.int64(2),np.int64(3),np.int64(4),np.int64(5),np.int64(6),np.int64(7),np.int64(8),np.int64(9),np.int64(10),np.int64(11),np.int64(12),np.int64(13),np.int64(14),np.int64(15),np.int64(16),np.int64(17),np.int64(18),np.int64(19)...],
 (#1990) [16780,16781,16782,16783,16784,16785,16786,16787,16788,16789,16790,16791,16792,16793,16794,16795,16796,16797,16798,16799...])

In [None]:
#| include: false
import nbdev
# Write the `#| export` cells to the library's Python module(s)
nbdev.nbdev_export()