diff --git a/datadriver/datablocker.py b/datadriver/datablocker.py index dbdf852..48c2d6f 100644 --- a/datadriver/datablocker.py +++ b/datadriver/datablocker.py @@ -91,6 +91,12 @@ def get_blocks(self, records): return [records[indices : indices + block_size] for indices in range(0, num_records, block_size)] +def _sample(dataset, sample_num): + if len(dataset) < sample_num: + return dataset + else: + return random.sample(dataset, sample_num) + class ResamplingBlocker(object): @staticmethod def get_blocks_gamma(records, num_blocks, block_size, gamma): @@ -108,7 +114,7 @@ def get_blocks_gamma(records, num_blocks, block_size, gamma): nonfull_blocks = range(num_blocks) for record in records: - for block_no in random.sample(nonfull_blocks, gamma): + for block_no in _sample(nonfull_blocks, gamma): blocks[block_no].append(record) if len(blocks[block_no]) >= block_size: # Remove block from contenders list if block is