Commit

Merge pull request #134 from sberbank-ai-lab/setup_fix
augmentation_chain pickling
ivkireev86 committed Apr 6, 2022
2 parents b3ecbfd + 42a9e7f commit eed22a9
Showing 2 changed files with 36 additions and 4 deletions.
18 changes: 14 additions & 4 deletions dltranz/data_load/__init__.py
@@ -587,12 +587,22 @@ def create_validation_loader(dataset, params):
     return valid_loader
 
 
-def augmentation_chain(*i_filters):
-    def _func(x):
-        for f in i_filters:
+# def augmentation_chain(*i_filters):
+#     def _func(x):
+#         for f in i_filters:
+#             x = f(x)
+#         return x
+#     return _func
+
+
+class augmentation_chain:
+    def __init__(self, *i_filters):
+        self.i_filters = i_filters
+
+    def __call__(self, x):
+        for f in self.i_filters:
             x = f(x)
         return x
-    return _func
 
 
 class IterableAugmentations(IterableProcessingDataset):
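Why the class-based rewrite allows pickling: a function defined inside another function (the old _func closure) is a local object that CPython's pickle cannot serialize by reference, whereas an instance of a module-level class pickles normally as long as the filters it carries are themselves picklable. A minimal standalone sketch of the difference; make_chain, AugmentationChain and _inc are illustrative names for this note, not part of the repository:

import pickle


def make_chain(*i_filters):
    # Closure-based variant, mirroring the removed implementation:
    # the inner _func is a local object, so pickle rejects it.
    def _func(x):
        for f in i_filters:
            x = f(x)
        return x
    return _func


class AugmentationChain:
    # Class-based variant, mirroring the new implementation: the instance
    # stores its filters as an attribute and pickles via a reference to the class.
    def __init__(self, *i_filters):
        self.i_filters = i_filters

    def __call__(self, x):
        for f in self.i_filters:
            x = f(x)
        return x


def _inc(x):
    return x + 1


if __name__ == "__main__":
    try:
        pickle.dumps(make_chain(_inc))
    except AttributeError as err:
        print("closure:", err)  # Can't pickle local object 'make_chain.<locals>._func'

    payload = pickle.dumps(AugmentationChain(_inc))  # works
    print("class instance pickles to", len(payload), "bytes")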
22 changes: 22 additions & 0 deletions tests/dltranz_tests/test_data_load/test__init__.py
@@ -2,6 +2,7 @@
 from torch.utils.data import DataLoader
 
 from dltranz.data_load import padded_collate, ZeroDownSampler, DropoutTrxDataset, TrxDataset, LastKTrxDataset
+from dltranz.data_load import augmentation_chain
 from tests.dltranz_tests.test_trx_encoder import gen_trx_data
 
 
@@ -46,3 +47,24 @@ def test_last_k_trx_dataset():
     assert all(torch.tensor(res) == torch.tensor([50, 50, 50]))
     res = [len(next(iter(x.values()))) for x, _ in LastKTrxDataset(TrxDataset(data), .2)]
     assert all(torch.tensor(res) == torch.tensor([20, 20, 20]))
+
+
+def _inc(x):
+    return x + 1
+
+
+def _double(x):
+    return x * 2
+
+
+def test_augmentation_chain():
+    a = augmentation_chain(_inc, _double, _inc)
+    out = a(2)
+    assert out == 7
+
+
+def test_augmentation_chain_pickle():
+    import pickle
+
+    a = augmentation_chain(_inc, _double)
+    pickle.dumps(a)
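A follow-up usage sketch, assuming the usual reason picklability matters here is handing the chain to worker processes (for example PyTorch DataLoader workers under the spawn start method); that motivation is inferred, not stated in the diff, and the helpers below mirror the ones in the test file:

import pickle
from multiprocessing import get_context

from dltranz.data_load import augmentation_chain


def _inc(x):
    return x + 1


def _double(x):
    return x * 2


if __name__ == "__main__":
    chain = augmentation_chain(_inc, _double)

    # The chain now survives a pickle round-trip ...
    restored = pickle.loads(pickle.dumps(chain))
    assert restored(2) == 6  # _double(_inc(2))

    # ... so it can be shipped to spawned worker processes, which pickle
    # everything they receive; the old closure-based version failed at this point.
    with get_context("spawn").Pool(2) as pool:
        print(pool.map(chain, [1, 2, 3]))  # [4, 6, 8]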
