Skip to content

Commit

Permalink
Merge pull request #762 from ufal/bucketed_batching
Browse files Browse the repository at this point in the history
Bucketed batching
  • Loading branch information
jindrahelcl committed Nov 7, 2018
2 parents 405c0dd + 4256d61 commit 150ca10
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 38 deletions.
96 changes: 83 additions & 13 deletions neuralmonkey/dataset.py
Expand Up @@ -23,6 +23,7 @@
# pylint: disable=invalid-name
DataType = TypeVar("DataType")
DataSeries = Iterator[DataType]
DataExample = Dict[str, DataType]

# Reader: function that gets list of files and yields data
Reader = Callable[[List[str]], Any]
Expand All @@ -49,6 +50,40 @@
SERIES_OUTPUT = re.compile("s_(.*)_out")


# pylint: disable=too-few-public-methods
# After migrating to py3.7, make this dataclass or namedtuple with defaults
class BatchingScheme:

def __init__(self,
batch_size: int,
batch_bucket_span: int = None,
token_level_batching: bool = False,
bucketing_ignore_series: List[str] = None,
use_leftover_buckets: bool = True) -> None:
"""Construct the baching scheme.
Attributes:
batch_size: Number of examples in one mini-batch.
batch_bucket_span: The span of the bucket for bucketed batching.
token_level_batching: Count the batch_size per individual tokens
in the batch instead of examples.
bucketing_ignore_series: Series to ignore during bucketing.
use_leftover_buckets: Whether to throw out bucket contents at the
end of the epoch or to use them.
"""
check_argument_types()

self.batch_size = batch_size
self.batch_bucket_span = batch_bucket_span
self.token_level_batching = token_level_batching
self.use_leftover_buckets = use_leftover_buckets

self.bucketing_ignore_series = [] # type: List[str]
if bucketing_ignore_series is not None:
self.bucketing_ignore_series = bucketing_ignore_series
# pylint: enable=too-few-public-methods


# The protected functions below are designed to convert the ambiguous spec
# structures to a normalized form.

Expand Down Expand Up @@ -474,21 +509,21 @@ def maybe_get_series(self, name: str) -> Optional[Iterator]:
return self.get_series(name)
return None

def batches(self, batch_size: int) -> Iterator["Dataset"]:
# pylint: disable=too-many-locals,too-many-branches
def batches(self,
scheme: BatchingScheme) -> Iterator["Dataset"]:
"""Split the dataset into batches.
Arguments:
batch_size: The size of a batch. In case of lazy datasets, this
should be lower than the dataset buffer size. Otherwise, the
batch size will be equal to the size of the buffer.
scheme: `BatchingScheme` configuration object.
Returns:
Generator yielding the batches.
"""
if self.lazy and self.buffer_min_size < batch_size:
if self.lazy and self.buffer_min_size < scheme.batch_size:
warn("Minimum buffer size ({}) lower than batch size ({}). "
"It is recommended to use large buffer size."
.format(self.buffer_min_size, batch_size))
.format(self.buffer_min_size, scheme.batch_size))

# Initialize iterators
iterators = {s: it() for s, it in self.iterators.items()}
Expand Down Expand Up @@ -516,14 +551,38 @@ def itergen():

# Iterate over the rest of the data until buffer is empty
batch_index = 0
buckets = {} \
# type: Dict[int, List[DataExample]]
while buf:
# Create the batch
name = "{}.batch.{}".format(self.name, batch_index)
rows = [buf.popleft() for _ in range(batch_size) if buf]
data = {key: _make_datagen(rows, key) for key in rows[0]}

yield Dataset(name=name, iterators=data)
batch_index += 1
row = buf.popleft()

if scheme.batch_bucket_span is None:
bucket_id = 0
else:
# TODO: use only specific series to determine the bucket number
bucket_id = (max(len(row[key]) for key in row)
// scheme.batch_bucket_span)

if bucket_id not in buckets:
buckets[bucket_id] = []
buckets[bucket_id].append(row)

is_full = (len(buckets[bucket_id]) >= scheme.batch_size)
if scheme.token_level_batching:
bucket_width = max(max(len(row[key]) for key in row)
for row in buckets[bucket_id])
is_full = (bucket_width * len(buckets[bucket_id])
>= scheme.batch_size)

if is_full:
# Create the batch
name = "{}.batch.{}".format(self.name, batch_index)
data = {key: _make_datagen(buckets[bucket_id], key)
for key in buckets[bucket_id][0]}

yield Dataset(name=name, iterators=data)
batch_index += 1
buckets[bucket_id] = []

# If lazy, refill buffer & shuffle if needed
# Otherwise, all of the data is already loaded in the buffer.
Expand All @@ -539,6 +598,17 @@ def itergen():
random.shuffle(lbuf)
buf = deque(lbuf)

if scheme.use_leftover_buckets:
for bucket_id in buckets:
if buckets[bucket_id]:
name = "{}.batch.{}".format(self.name, batch_index)
data = {key: _make_datagen(buckets[bucket_id], key)
for key in buckets[bucket_id][0]}

yield Dataset(name=name, iterators=data)
batch_index += 1
# pylint: enable=too-many-locals,too-many-branches

def subset(self, start: int, length: int) -> "Dataset":
"""Create a subset of the dataset.
Expand Down
15 changes: 12 additions & 3 deletions neuralmonkey/experiment.py
Expand Up @@ -20,7 +20,7 @@
from neuralmonkey.learning_utils import (training_loop, evaluation,
run_on_dataset,
print_final_evaluation)
from neuralmonkey.dataset import Dataset
from neuralmonkey.dataset import Dataset, BatchingScheme
from neuralmonkey.model.sequence import EmbeddedFactorSequence
from neuralmonkey.runners.base_runner import ExecutionResult
from neuralmonkey.tf_manager import get_default_tf_manager
Expand Down Expand Up @@ -171,6 +171,7 @@ def train(self) -> None:
epochs=self.model.epochs,
trainer=self.model.trainer,
batch_size=self.model.batch_size,
batching_scheme=self.model.batching_scheme,
log_directory=self.model.output,
evaluators=self.model.evaluation,
runners=self.model.runners,
Expand Down Expand Up @@ -241,13 +242,19 @@ def run_model(self,
if not self._vars_loaded:
self.load_variables()

batching_scheme = BatchingScheme(
batch_size=batch_size or self.model.runners_batch_size,
batch_bucket_span=None,
token_level_batching=False,
bucketing_ignore_series=[])

with self.graph.as_default():
# TODO: check_dataset_and_coders(dataset, self.model.runners)
return run_on_dataset(
self.model.tf_manager, self.model.runners, dataset,
self.model.postprocess,
write_out=write_out, log_progress=log_progress,
batch_size=batch_size or self.model.runners_batch_size)
batching_scheme=batching_scheme)

def evaluate(self,
dataset: Dataset,
Expand Down Expand Up @@ -329,7 +336,9 @@ def get_current(cls) -> "Experiment":
def create_config(train_mode: bool = True) -> Configuration:
config = Configuration()
config.add_argument("tf_manager", required=False, default=None)
config.add_argument("batch_size", cond=lambda x: x > 0)
config.add_argument("batch_size", required=False, default=None,
cond=lambda x: x is None or x > 0)
config.add_argument("batching_scheme", required=False, default=None)
config.add_argument("output")
config.add_argument("postprocess", required=False, default=None)
config.add_argument("runners")
Expand Down
53 changes: 37 additions & 16 deletions neuralmonkey/learning_utils.py
Expand Up @@ -15,7 +15,7 @@
from typeguard import check_argument_types

from neuralmonkey.logging import log, log_print, warn, notice
from neuralmonkey.dataset import Dataset
from neuralmonkey.dataset import Dataset, BatchingScheme
from neuralmonkey.tf_manager import TensorFlowManager
from neuralmonkey.runners.base_runner import (
BaseRunner, ExecutionResult, reduce_execution_results)
Expand All @@ -37,7 +37,6 @@
def training_loop(tf_manager: TensorFlowManager,
epochs: int,
trainer: Union[Trainer, List[Trainer]],
batch_size: int,
log_directory: str,
evaluators: EvalConfiguration,
runners: List[BaseRunner],
Expand All @@ -51,6 +50,8 @@ def training_loop(tf_manager: TensorFlowManager,
val_preview_output_series: List[str] = None,
val_preview_num_examples: int = 15,
train_start_offset: int = 0,
batch_size: int = None,
batching_scheme: BatchingScheme = None,
runners_batch_size: int = None,
initial_variables: Union[str, List[str]] = None,
postprocess: Postprocess = None) -> None:
Expand All @@ -61,7 +62,9 @@ def training_loop(tf_manager: TensorFlowManager,
epochs: Number of epochs for which the algoritm will learn.
trainer: The trainer object containg the TensorFlow code for computing
the loss and optimization operation.
batch_size: number of examples in one mini-batch
batch_size: Number of examples in one mini-batch.
batching_scheme: Batching scheme specification. Cannot be provided when
batch_size is specified.
log_directory: Directory where the TensordBoard log will be generated.
If None, nothing will be done.
evaluators: List of evaluators. The last evaluator is used as the main.
Expand All @@ -88,8 +91,8 @@ def training_loop(tf_manager: TensorFlowManager,
validation
train_start_offset: how many lines from the training dataset should be
skipped. The training starts from the next batch.
runners_batch_size: batch size of runners. It is the same as batch_size
if not specified
runners_batch_size: batch size of runners. Reuses the training batching
scheme with bucketing turned off.
initial_variables: variables used for initialization, for example for
continuation of training. Provide it with a path to your model
directory and its checkpoint file group common prefix, e.g.
Expand All @@ -100,6 +103,24 @@ def training_loop(tf_manager: TensorFlowManager,
"""
check_argument_types()

if (batch_size is None) == (batching_scheme is None):
raise ValueError("You must specify either batch_size or "
"batching_scheme (not both).")

if batch_size is not None:
assert batching_scheme is None
batching_scheme = BatchingScheme(batch_size=batch_size)

assert batching_scheme is not None

if runners_batch_size is None:
runners_batch_size = batching_scheme.batch_size

runners_batching_scheme = BatchingScheme(
batch_size=runners_batch_size,
token_level_batching=batching_scheme.token_level_batching,
use_leftover_buckets=True)

if isinstance(val_dataset, List):
val_datasets = val_dataset
else:
Expand All @@ -118,9 +139,6 @@ def training_loop(tf_manager: TensorFlowManager,
_log_model_variables(
var_list=list(set().union(*[t.var_list for t in trainers])))

if runners_batch_size is None:
runners_batch_size = batch_size

evaluators = [(e[0], e[0], e[1]) if len(e) == 2 else e
for e in evaluators]

Expand Down Expand Up @@ -166,7 +184,7 @@ def training_loop(tf_manager: TensorFlowManager,
log_print("")
log("Epoch {} begins".format(epoch_n), color="red")

train_batches = train_dataset.batches(batch_size)
train_batches = train_dataset.batches(batching_scheme)

if epoch_n == 1 and train_start_offset:
if train_dataset.shuffled and not train_dataset.lazy:
Expand All @@ -185,7 +203,8 @@ def training_loop(tf_manager: TensorFlowManager,
batch, trainers, train=True, summaries=True)
train_results, train_outputs = run_on_dataset(
tf_manager, runners, batch, postprocess,
write_out=False, batch_size=len(batch))
write_out=False,
batching_scheme=runners_batching_scheme)
# ensure train outputs are iterable more than once
train_outputs = {
k: list(v) for k, v in train_outputs.items()}
Expand Down Expand Up @@ -213,7 +232,7 @@ def training_loop(tf_manager: TensorFlowManager,
val_results, val_outputs = run_on_dataset(
tf_manager, runners, valset,
postprocess, write_out=False,
batch_size=runners_batch_size)
batching_scheme=runners_batching_scheme)
# ensure val outputs are iterable more than once
val_outputs = {k: list(v)
for k, v in val_outputs.items()}
Expand Down Expand Up @@ -304,7 +323,7 @@ def training_loop(tf_manager: TensorFlowManager,
for dataset in test_datasets:
test_results, test_outputs = run_on_dataset(
tf_manager, runners, dataset, postprocess,
write_out=True, batch_size=runners_batch_size)
write_out=True, batching_scheme=runners_batching_scheme)
# ensure test outputs are iterable more than once
test_outputs = {k: list(v) for k, v in test_outputs.items()}
eval_result = evaluation(evaluators, dataset, runners,
Expand Down Expand Up @@ -377,7 +396,7 @@ def run_on_dataset(tf_manager: TensorFlowManager,
runners: List[BaseRunner],
dataset: Dataset,
postprocess: Postprocess,
batch_size: int,
batching_scheme: BatchingScheme,
write_out: bool = False,
log_progress: int = 0) -> Tuple[
List[ExecutionResult], Dict[str, List[Any]]]:
Expand All @@ -395,7 +414,7 @@ def run_on_dataset(tf_manager: TensorFlowManager,
postprocess: Dataset-level postprocessors
write_out: Flag whether the outputs should be printed to a file defined
in the dataset object.
batch_size: size of the minibatch
batching_scheme: Scheme used for batching.
log_progress: log progress every X seconds
extra_fetches: Extra tensors to evaluate for each batch.
Expand All @@ -413,13 +432,15 @@ def run_on_dataset(tf_manager: TensorFlowManager,
last_log_time = time.process_time()
batch_results = [[] for _ in runners] # type: List[List[ExecutionResult]]

for i, batch in enumerate(dataset.batches(batch_size)):
processed_examples = 0
for batch in dataset.batches(batching_scheme):
if 0 < log_progress < time.process_time() - last_log_time:
log("Processed {} examples.".format(i * batch_size))
log("Processed {} examples.".format(processed_examples))
last_log_time = time.process_time()

execution_results = tf_manager.execute(
batch, runners, compute_losses=contains_targets)
processed_examples += len(batch)

for script_list, ex_result in zip(batch_results, execution_results):
script_list.append(ex_result)
Expand Down

0 comments on commit 150ca10

Please sign in to comment.