T2T batching #786
base: master
@@ -95,6 +95,84 @@ def __init__(self,
# pylint: enable=too-few-public-methods

def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
    """Create a default set of length-bucket boundaries."""
    assert length_bucket_step > 1.0
    x = min_length
    boundaries = []
    while x < max_length:
        boundaries.append(x)
        x = max(x + 1, int(x * length_bucket_step))
    return boundaries

Review comment: I would add an example of the input and output; I do not really understand why length_bucket_step is a float.
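
To answer the question above with an example (my own illustration, not part of the patch): the float step makes the boundaries grow geometrically once the multiplicative step overtakes the guaranteed +1 increment.

_bucket_boundaries(32, min_length=8, length_bucket_step=1.1)
# -> [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28, 30]
# Up to 20 the +1 increment wins (int(8 * 1.1) == 8, int(19 * 1.1) == 20);
# from there the 10% multiplicative step takes over: 20 -> 22 -> 24 -> 26 -> 28 -> 30.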
def get_batching_scheme(batch_size: int,
                        max_length: int = None,
                        min_length_bucket: int = 8,
                        length_bucket_step: float = 1.1,
                        shard_multiplier: int = 1,
                        length_multiplier: int = 1,
                        min_length: int = 0) -> BatchingScheme:
    """Create a batching scheme based on model hyperparameters.

    Every batch contains a number of sequences divisible by `shard_multiplier`.

    Args:
        batch_size: int, total number of tokens in a batch.
        max_length: int, sequences longer than this will be skipped. Defaults
            to batch_size.
        min_length_bucket: int
        length_bucket_step: float greater than 1.0
        shard_multiplier: an integer increasing the batch_size to suit
            splitting across datashards.
        length_multiplier: an integer multiplier that is used to increase the
            batch sizes and sequence length tolerance.
        min_length: int, sequences shorter than this will be skipped.
    Return:
        A dictionary with parameters that can be passed to input_pipeline:

Review comment (on the line above): this is not true; the function returns a BatchingScheme, not a dictionary.

        * boundaries: list of bucket boundaries
        * batch_sizes: list of batch sizes for each length bucket
        * max_length: int, maximum length of an example
    Raises:
        ValueError: If min_length > max_length
    """
    max_length = max_length or batch_size
    if max_length < min_length:
        raise ValueError("max_length must be greater or equal to min_length")

Review comment: the check that length_bucket_step > 1.0 belongs here, raising a ValueError with a message, instead of being left to the assert in _bucket_boundaries.
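
A minimal sketch of the suggested check (my wording, not code from the patch):

    if length_bucket_step <= 1.0:
        raise ValueError("length_bucket_step must be greater than 1.0")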
    boundaries = _bucket_boundaries(max_length, min_length_bucket,
                                    length_bucket_step)
    boundaries = [boundary * length_multiplier for boundary in boundaries]
    max_length *= length_multiplier

    batch_sizes = [
        max(1, batch_size // length) for length in boundaries + [max_length]
    ]
    max_batch_size = max(batch_sizes)
    # Since the Datasets API only allows a single constant for window_size,
    # and it needs to divide all bucket_batch_sizes, we pick a highly-composite
    # window size and then round down all batch sizes to divisors of that
    # window size, so that a window can always be divided evenly into batches.
    highly_composite_numbers = [
        1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260,
        1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360,
        50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960,
        554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 3603600,
        4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 17297280,
        21621600, 32432400, 36756720, 43243200, 61261200, 73513440, 110270160
    ]
    window_size = max(
        [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
    divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
    batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
    window_size *= shard_multiplier
    batch_sizes = [bs * shard_multiplier for bs in batch_sizes]

    ret = BatchingScheme(bucket_boundaries=boundaries,
                         bucket_batch_sizes=batch_sizes)
    return ret
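
A worked example of the whole scheme (my own invocation, checked by hand against the code above; not part of the patch):

scheme = get_batching_scheme(batch_size=128, max_length=64,
                             min_length_bucket=8, length_bucket_step=2.0)
# boundaries == [8, 16, 32]; the raw batch sizes are
# [128 // 8, 128 // 16, 128 // 32, 128 // 64] == [16, 8, 4, 2].
# max_batch_size == 16, so window_size becomes the largest highly composite
# number <= 48, which is 48; all four batch sizes already divide 48, so
# nothing is rounded down.
# scheme.bucket_boundaries == [8, 16, 32]
# scheme.bucket_batch_sizes == [16, 8, 4, 2]
# Sequences of up to 8 tokens are batched 16 at a time, lengths 9-16 eight
# at a time, and so on, keeping every batch near 128 tokens.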

# The protected functions below are designed to convert the ambiguous spec
# structures to a normalized form.

@@ -13,7 +13,7 @@
from termcolor import colored

from neuralmonkey.logging import log, log_print, warn
-from neuralmonkey.dataset import Dataset
+from neuralmonkey.dataset import Dataset, BatchingScheme
from neuralmonkey.tf_manager import TensorFlowManager
from neuralmonkey.runners.base_runner import (
    BaseRunner, ExecutionResult, GraphExecutor, OutputSeries)

@@ -85,6 +85,9 @@ def training_loop(cfg: Namespace) -> None:
        trainer_result = cfg.tf_manager.execute(
            batch, feedables, cfg.trainers, train=True,
            summaries=True)
+        # workaround: we need to use validation batching scheme
+        # during evaluation
+        batch.batching = BatchingScheme(batch_size=cfg.batch_size)

Review comment: this is not the validation batching scheme. Drop this change; in my refactor this already works correctly, and keeping it would just introduce a needless conflict.

        train_results, train_outputs, f_batch = run_on_dataset(
            cfg.tf_manager, cfg.runners, cfg.dataset_runner, batch,
            cfg.postprocess, write_out=False)

@@ -13,7 +13,7 @@ def process_line(line: str, lineno: int, path: str) -> np.ndarray:

    return np.array(numbers, dtype=dtype)

-def reader(files: List[str])-> Iterable[List[np.ndarray]]:
+def reader(files: List[str]) -> Iterable[List[np.ndarray]]:

Review comment: this is unrelated to the change at hand; it will just introduce a conflict into the branch with the tf dataset.

Review comment (follow-up): but if it does not pass Travis otherwise, leave it in.

    for path in files:
        current_line = 0

@@ -4,6 +4,7 @@ tf_manager=<tf_manager>
output="tests/outputs/hier-multiattention"
overwrite_output_dir=True
epochs=1
+batch_size=1

Review comment: batch size should not be mandatory just because there is a workaround somewhere.

train_dataset=<train_data>
val_dataset=<val_data>
trainer=<trainer>

Review comment: type annotations are missing.