def _is_logging_time(period_batch: Optional[int],
                     period_time: Optional[float],
                     last_time: float) -> bool:
    """Tell whether a periodic action (logging/validation) is due now.

    This is a closure: ``step`` and ``trainer`` are read from the enclosing
    scope, they are not parameters.

    Arguments:
        period_batch: Period expressed in number of batches, or ``None``
            when the period is expressed in time.
        period_time: Period expressed in process-time seconds; must not be
            ``None`` when ``period_batch`` is ``None``.
        last_time: ``time.process_time()`` value recorded when the action
            last ran.

    Returns:
        ``True`` when the action should run at the current ``step``.
    """
    # Nothing is ever due before the first batch has been processed.
    if step == 0:
        return False

    if period_batch is None:
        assert period_time is not None

        # A delayed-update trainer may only be interrupted on a step that
        # is a multiple of ``batches_per_update``.
        delayed = isinstance(trainer, DelayedUpdateTrainer)
        if delayed and step % trainer.batches_per_update != 0:
            return False

        deadline = last_time + period_time
        return time.process_time() > deadline

    # Batch-based period: due on every multiple of the period.
    return step % period_batch == 0
class DelayedUpdateTrainer(GenericTrainer):
    """Trainer that averages gradients accumulated over several batches.

    Gradients, objective losses and the differentiable loss are summed into
    non-trainable buffer variables by ``accumulate_ops``. ``raw_gradients``
    is overridden to return those buffers divided by the number of
    accumulated batches, and ``reset_ops`` zeroes the buffers again.
    ``get_executable`` returns a ``DelayedTrainExecutable`` that runs these
    ops.

    Members decorated with ``@tensor`` are read as attributes, not called.
    """

    # pylint: disable=too-many-arguments
    def __init__(self,
                 batches_per_update: int,
                 objectives: List[Objective],
                 l1_weight: float = 0.0,
                 l2_weight: float = 0.0,
                 clip_norm: float = None,
                 optimizer: tf.train.Optimizer = None,
                 var_scopes: List[str] = None,
                 var_collection: str = None) -> None:
        """Create the trainer.

        Arguments:
            batches_per_update: Number of batches whose gradients are
                accumulated before one optimizer update. Must be >= 1.
            objectives: Objectives to optimize.
            l1_weight: Forwarded to ``GenericTrainer``.
            l2_weight: Forwarded to ``GenericTrainer``.
            clip_norm: Forwarded to ``GenericTrainer``.
            optimizer: Forwarded to ``GenericTrainer``.
            var_scopes: Forwarded to ``GenericTrainer``.
            var_collection: Forwarded to ``GenericTrainer``.

        Raises:
            ValueError: If ``batches_per_update`` is smaller than one.
        """
        check_argument_types()

        # Fail early: a zero period would only surface later as a
        # ZeroDivisionError in the modulo checks of the training loop, and
        # a negative one would never trigger an update.
        if batches_per_update < 1:
            raise ValueError(
                "batches_per_update must be a positive integer, got {}."
                .format(batches_per_update))

        GenericTrainer.__init__(self, objectives, l1_weight, l2_weight,
                                clip_norm, optimizer, var_scopes,
                                var_collection)

        self.batches_per_update = batches_per_update
    # pylint: enable=too-many-arguments

    @tensor
    def existing_grads_and_vars(self) -> Tuple[
            List[tf.Tensor], List[tf.Variable]]:
        """Split the parent's gradients into aligned grad and var lists.

        Pairs whose gradient is ``None`` are dropped. Two empty lists are
        returned when no variable has a gradient.
        """
        orig_grads = super().raw_gradients

        # pylint: disable=not-an-iterable
        # Pylint does not understand @tensor annotations
        pairs = [(grad, var) for grad, var in orig_grads if grad is not None]
        # pylint: enable=not-an-iterable

        # zip(*[]) yields nothing, so unpacking it would fail; handle the
        # "no gradients at all" case explicitly.
        if not pairs:
            return [], []

        grads, variables = zip(*pairs)
        return list(grads), list(variables)

    @tensor
    def gradient_buffers(self) -> List[tf.Variable]:
        """Create one zero-initialized buffer per existing gradient."""
        # pylint: disable=unpacking-non-sequence
        existing_gradients, _ = self.existing_grads_and_vars
        # pylint: enable=unpacking-non-sequence

        with tf.variable_scope("gradient_buffer"):
            return [tf.Variable(initial_value=tf.zeros_like(grad),
                                trainable=False)
                    for grad in existing_gradients]

    @tensor
    def objective_buffers(self) -> List[tf.Variable]:
        """Create one scalar buffer per objective for summing its loss."""
        with tf.variable_scope("loss_buffers"):
            return [tf.Variable(0.0, trainable=False) for _ in self.objectives]

    # pylint: disable=no-self-use
    @tensor
    def diff_buffer(self) -> tf.Variable:
        """Create the scalar buffer summing the differentiable loss."""
        return tf.Variable(0.0, trainable=False)

    @tensor
    def cumulator_counter(self) -> tf.Variable:
        """Create the integer counter of accumulated batches."""
        return tf.Variable(0, trainable=False, name="self.cumulator_counter")
    # pylint: enable=no-self-use

    @tensor
    def accumulate_ops(self) -> List[tf.Operation]:
        """Build the ops that add the current batch to all buffers.

        The list holds, in this order: one ``assign_add`` per gradient
        buffer, one per objective buffer, one for the differentiable loss
        buffer, and finally the increment of ``cumulator_counter``.
        """
        # pylint: disable=unpacking-non-sequence
        existing_gradients, _ = self.existing_grads_and_vars
        # pylint: enable=unpacking-non-sequence

        # pylint: disable=not-an-iterable
        # Pylint does not understand @tensor annotations
        accumulate_ops = [
            tf.assign_add(gradbuf, grad)
            for gradbuf, grad in zip(
                self.gradient_buffers, existing_gradients)]

        accumulate_ops.extend(
            tf.assign_add(objbuf, obj.loss)
            for objbuf, obj in zip(self.objective_buffers, self.objectives))
        # pylint: enable=not-an-iterable

        accumulate_ops.append(
            tf.assign_add(self.diff_buffer, self.differentiable_loss_sum))
        accumulate_ops.append(
            tf.assign_add(self.cumulator_counter, 1))

        return accumulate_ops

    @tensor
    def reset_ops(self) -> List[tf.Operation]:
        """Build the ops that set every buffer and the counter to zero."""
        # pylint: disable=not-an-iterable
        # Pylint does not understand @tensor annotations
        reset_ops = [tf.assign(gradbuf, tf.zeros_like(gradbuf))
                     for gradbuf in self.gradient_buffers]
        reset_ops.extend(
            tf.assign(objbuf, 0.0) for objbuf in self.objective_buffers)
        # pylint: enable=not-an-iterable

        reset_ops.append(tf.assign(self.diff_buffer, 0.0))
        reset_ops.append(tf.assign(self.cumulator_counter, 0))
        return reset_ops

    @tensor
    def raw_gradients(self) -> Gradients:
        """Return averaged gradients over buffers.

        As a side effect, scalar summaries of the averaged differentiable
        loss (``train_opt_cost``) and of every averaged objective loss are
        added to the ``summary_train`` collection.
        """
        # One shared cast of the counter instead of one per buffer.
        n_batches = tf.to_float(self.cumulator_counter)

        # pylint: disable=not-an-iterable
        # Pylint does not understand @tensor annotations
        averaged_grads = [gradbuf / n_batches
                          for gradbuf in self.gradient_buffers]
        # pylint: enable=not-an-iterable

        tf.summary.scalar(
            "train_opt_cost",
            self.diff_buffer / n_batches,
            collections=["summary_train"])

        # log all objectives
        for obj, objbuf in zip(self.objectives, self.objective_buffers):
            tf.summary.scalar(
                obj.name, objbuf / n_batches,
                collections=["summary_train"])

        # now, zip averaged grads with associated vars to a Gradients struct.
        # pylint: disable=unpacking-non-sequence
        _, existing_vars = self.existing_grads_and_vars
        # pylint: enable=unpacking-non-sequence
        return list(zip(averaged_grads, existing_vars))

    @tensor
    def summaries(self) -> Dict[str, tf.Tensor]:
        """Build the merged scalar and histogram summaries.

        Registers the learning rate (only when the optimizer's ``_lr`` is a
        tensor), the L1/L2 norms and a histogram per non-``None`` gradient.
        The objective and ``train_opt_cost`` scalars are registered by
        ``raw_gradients``, not here.

        Returns:
            A dict with keys ``"scalar_summaries"`` and
            ``"histogram_summaries"``.
        """
        # pylint: disable=protected-access
        if isinstance(self.optimizer._lr, tf.Tensor):
            tf.summary.scalar("learning_rate", self.optimizer._lr,
                              collections=["summary_train"])
        # pylint: enable=protected-access

        # pylint: disable=unpacking-non-sequence
        l1_norm, l2_norm = self.regularization_losses
        # pylint: enable=unpacking-non-sequence

        tf.summary.scalar("train_l1", l1_norm, collections=["summary_train"])
        tf.summary.scalar("train_l2", l2_norm, collections=["summary_train"])

        # pylint: disable=not-an-iterable
        # Pylint does not understand @tensor annotations
        for grad, var in self.gradients:
            if grad is not None:
                summary_name = "gr_{}".format(var.name)
                tf.summary.histogram(
                    summary_name, grad, collections=["summary_gradients"])
        # pylint: enable=not-an-iterable

        return {
            "scalar_summaries": tf.summary.merge(
                tf.get_collection("summary_train")),
            "histogram_summaries": tf.summary.merge(
                tf.get_collection("summary_gradients"))}

    def get_executable(self,
                       compute_losses: bool = True,
                       summaries: bool = True,
                       num_sessions: int = 1) -> Executable:
        """Create the executable that runs this trainer on one batch.

        Arguments:
            compute_losses: Must be true.
            summaries: Whether the executable should fetch summaries.
            num_sessions: Must be 1.

        Raises:
            ValueError: If ``num_sessions`` is not 1.
        """
        assert compute_losses
        if num_sessions != 1:
            raise ValueError(
                "Trainer only supports execution in a single session")

        return DelayedTrainExecutable(self, summaries)
class DelayedTrainExecutable(Executable):
    """Executable running a ``DelayedUpdateTrainer`` on one batch.

    It moves through up to three phases, tracked in ``self.state``:

    * accumulating: runs the trainer's ``accumulate_ops`` and fetches the
      losses and the batch counter;
    * updating: entered only when the counter has reached
      ``batches_per_update``; runs ``train_op`` (and the summaries when
      requested);
    * resetting: runs the trainer's ``reset_ops``.

    ``self.result`` is set when the accumulating phase does not lead to an
    update, or after the resetting phase.
    """

    # Values of ``self.state`` (kept as the plain integers 0, 1, 2).
    _ACCUMULATING = 0
    _UPDATING = 1
    _RESETTING = 2

    def __init__(self, trainer: DelayedUpdateTrainer, summaries: bool) -> None:
        """Store the trainer and start in the accumulating phase.

        Arguments:
            trainer: Trainer whose ops are executed.
            summaries: Whether summaries are fetched in the updating phase.
        """
        self.trainer = trainer
        self.summaries = summaries
        self.result = None  # type: Optional[ExecutionResult]

        self.state = self._ACCUMULATING
        # Values remembered across phases until the result is assembled.
        self.res_hist_sums = None
        self.res_scal_sums = None
        self.res_losses = None

    def next_to_execute(self) -> NextExecute:
        """Return the coders, fetches and feed dicts for the current phase."""
        if self.state == self._ACCUMULATING:
            # NOTE(review): the counter variable is fetched in the same run
            # as its assign_add; nothing here pins whether the fetched value
            # is read before or after the increment -- confirm.
            fetches = {"accumulators": self.trainer.accumulate_ops,
                       "counter": self.trainer.cumulator_counter,
                       "losses": self.trainer.objective_values}
            coders = self.trainer.all_coders

        elif self.state == self._UPDATING:
            fetches = {
                "train_op": self.trainer.train_op,
                "_update_ops": tf.get_collection(tf.GraphKeys.UPDATE_OPS)}

            if self.summaries:
                fetches.update(self.trainer.summaries)

            coders = self.trainer.all_coders

        else:  # resetting
            fetches = {"resets": self.trainer.reset_ops}
            coders = set()

        return coders, fetches, [{}]

    def collect_results(self, results: List[Dict]) -> None:
        """Consume the fetched values and advance to the next phase.

        Arguments:
            results: Fetched values; exactly one session result is expected.
        """
        assert len(results) == 1
        result = results[0]

        if self.state == self._ACCUMULATING:
            self.res_losses = result["losses"]

            # Are we updating? ``>=`` rather than ``==`` so that a counter
            # that ever skipped past the threshold (e.g. a reset that did
            # not run) still triggers an update instead of stalling forever.
            counter = result["counter"]

            if counter >= self.trainer.batches_per_update:
                self.state = self._UPDATING
                return
        elif self.state == self._UPDATING:
            if self.summaries:
                self.res_scal_sums = result["scalar_summaries"]
                self.res_hist_sums = result["histogram_summaries"]

            self.state = self._RESETTING
            return

        # Reached either without an update or after the reset: the result
        # can be assembled. Summaries stay None if no update happened.
        assert self.res_losses is not None
        self.result = ExecutionResult(
            [], losses=self.res_losses,
            scalar_summaries=self.res_scal_sums,
            histogram_summaries=self.res_hist_sums,
            image_summaries=None)
# pylint: disable=too-few-public-methods,too-many-locals,too-many-arguments
class GenericTrainer:
    """Trainer that optimizes a weighted sum of objectives.

    The graph pieces (losses, gradients, train op, summaries) are exposed
    as ``@tensor`` members, which are read as attributes, so subclasses can
    override individual pieces such as ``raw_gradients``.
    """

    @staticmethod
    def default_optimizer():
        """Return the optimizer used when none is supplied (Adam, 1e-4)."""
        return tf.train.AdamOptimizer(learning_rate=1e-4)

    def __init__(self,
                 objectives: List[Objective],
                 l1_weight: float = 0.0,
                 l2_weight: float = 0.0,
                 clip_norm: float = None,
                 optimizer: tf.train.Optimizer = None,
                 var_scopes: List[str] = None,
                 var_collection: str = None) -> None:
        """Store the configuration and build the training op.

        Arguments:
            objectives: Objectives to optimize.
            l1_weight: Weight of the L1 norm in the differentiable loss.
            l2_weight: Weight of the L2 norm in the differentiable loss.
            clip_norm: If truthy, every gradient is clipped to this norm.
            optimizer: Optimizer to use; ``default_optimizer()`` if ``None``.
            var_scopes: Scopes to take variables from; the whole collection
                is used if ``None``.
            var_collection: Collection the variables are taken from;
                ``tf.GraphKeys.TRAINABLE_VARIABLES`` if ``None``.
        """
        check_argument_types()

        self.objectives = objectives
        self.l1_weight = l1_weight
        self.l2_weight = l2_weight
        self.clip_norm = clip_norm
        self.var_scopes = var_scopes
        self.var_collection = var_collection
        if self.var_collection is None:
            self.var_collection = tf.GraphKeys.TRAINABLE_VARIABLES

        self.optimizer = (
            optimizer if optimizer is not None else self.default_optimizer())

        # ``set().union(...)`` instead of ``set.union(...)``: the unbound
        # form raises an opaque TypeError when there are no objectives or
        # when the first dependency collection is not exactly a ``set``.
        self.all_coders = set().union(*(obj.decoder.get_dependencies()
                                        for obj in objectives))

        # Reading ``train_op`` here triggers construction of the train op.
        log("Train op: {}".format(str(self.train_op)))

    # pylint: disable=no-self-use
    @tensor
    def regularization_losses(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """Compute the regularization losses, e.g. L1 and L2.

        Trainable variables whose name matches ``BIAS_REGEX`` or starts
        with ``vgg``, ``Inception`` or ``resnet`` are left out.

        Returns:
            The pair ``(l1_norm, l2_norm)``.
        """
        excluded_prefixes = ("vgg", "Inception", "resnet")
        regularizable = [v for v in tf.trainable_variables()
                         if not BIAS_REGEX.findall(v.name)
                         and not v.name.startswith(excluded_prefixes)]

        with tf.name_scope("regularization"):
            l1_norm = sum(tf.reduce_sum(abs(v)) for v in regularizable)
            l2_norm = sum(tf.reduce_sum(v ** 2) for v in regularizable)

        return l1_norm, l2_norm
    # pylint: enable=no-self-use

    @tensor
    def objective_values(self) -> List[tf.Tensor]:
        """Compute unweighted losses for fetching.

        Returns:
            The loss of every objective, followed by the L1 and L2 norms.
        """
        # pylint: disable=unpacking-non-sequence
        l1_norm, l2_norm = self.regularization_losses
        # pylint: enable=unpacking-non-sequence

        return [o.loss for o in self.objectives] + [l1_norm, l2_norm]

    @tensor
    def differentiable_loss_sum(self) -> tf.Tensor:
        """Compute the differentiable loss (including regularization).

        Objectives that bring their own gradients are left out of the sum;
        objectives without an explicit weight count with weight 1.0.
        """
        obj_weights = []  # type: List[Optional[float]]
        for obj in self.objectives:
            if obj.gradients is not None:
                # Handled separately in ``raw_gradients``.
                obj_weights.append(None)
            elif obj.weight is None:
                obj_weights.append(1.0)
            else:
                obj_weights.append(obj.weight)

        # ``objective_values`` ends with the L1 and L2 norms.
        obj_weights += [self.l1_weight, self.l2_weight]
        diff_loss = sum(
            o * w for o, w in zip(self.objective_values, obj_weights)
            if w is not None)

        return diff_loss

    @tensor
    def raw_gradients(self) -> Gradients:
        """Compute the gradients.

        Gradients of the differentiable loss are computed by the optimizer;
        when some objectives carry explicit gradients, those are scaled by
        the objective weight and summed in per variable.
        """
        with tf.name_scope("gradient_collection"):
            gradients = self.optimizer.compute_gradients(
                self.differentiable_loss_sum, self.var_list)

            def scale_grads(gradients: Gradients,
                            weight: ObjectiveWeight) -> Gradients:
                """Multiply every non-None gradient by ``weight``."""
                result = []  # type: Gradients
                for grad, var in gradients:
                    if weight is not None and grad is not None:
                        result.append((weight * grad, var))
                    else:
                        result.append((grad, var))
                return result

            # objectives that have their gradients explicitly computed
            other_gradients = [
                scale_grads(o.gradients, o.weight)
                for o in self.objectives if o.gradients is not None]

            def sum_grads(gradients_list: List[Gradients]) -> Gradients:
                """Sum non-None gradients that belong to the same variable."""
                summed_dict = {}  # type: Dict[tf.Variable, tf.Tensor]
                for gradients in gradients_list:
                    for grad, var in gradients:
                        if grad is not None:
                            if var not in summed_dict:
                                summed_dict[var] = grad
                            else:
                                summed_dict[var] += grad

                return [(grad, var) for var, grad in summed_dict.items()]

            if other_gradients:
                gradients = sum_grads([gradients] + other_gradients)

        return gradients

    @tensor
    def gradients(self) -> Gradients:
        """Return the gradients to apply, clipped if ``clip_norm`` is set.

        When clipping is active, pairs with a ``None`` gradient are dropped.
        """
        gradients = self.raw_gradients

        if self.clip_norm:
            assert self.clip_norm > 0.0
            # pylint: disable=not-an-iterable
            # Pylint does not understand @tensor annotations
            gradients = [
                (tf.clip_by_norm(grad, self.clip_norm), var)
                for grad, var in gradients if grad is not None]
            # pylint: enable=not-an-iterable

        return gradients

    @tensor
    def train_op(self) -> tf.Operation:
        """Construct the training op."""
        with tf.name_scope("trainer"):
            step = tf.train.get_or_create_global_step()
            return self.optimizer.apply_gradients(self.gradients, step)

    @property
    def var_list(self) -> List[tf.Variable]:
        """Return the variables being optimized.

        They are taken from ``var_collection``, restricted to
        ``var_scopes`` when scopes were given.
        """
        if self.var_scopes is None:
            vlists = [tf.get_collection(self.var_collection)]
        else:
            vlists = [tf.get_collection(self.var_collection, scope)
                      for scope in self.var_scopes]

        # Flatten the list of lists
        return [var for var_list in vlists for var in var_list]

    @tensor
    def summaries(self) -> Dict[str, tf.Tensor]:
        """Build the merged scalar and histogram summaries.

        Registers the learning rate (only when the optimizer's ``_lr`` is a
        tensor), the L1/L2 norms, every objective loss, the differentiable
        loss and a histogram per non-``None`` gradient.

        Returns:
            A dict with keys ``"scalar_summaries"`` and
            ``"histogram_summaries"``.
        """
        # pylint: disable=protected-access
        if isinstance(self.optimizer._lr, tf.Tensor):
            tf.summary.scalar("learning_rate", self.optimizer._lr,
                              collections=["summary_train"])
        # pylint: enable=protected-access

        # pylint: disable=unpacking-non-sequence
        l1_norm, l2_norm = self.regularization_losses
        # pylint: enable=unpacking-non-sequence

        tf.summary.scalar("train_l1", l1_norm, collections=["summary_train"])
        tf.summary.scalar("train_l2", l2_norm, collections=["summary_train"])

        for obj in self.objectives:
            tf.summary.scalar(obj.name, obj.loss,
                              collections=["summary_train"])

        tf.summary.scalar("train_opt_cost", self.differentiable_loss_sum,
                          collections=["summary_train"])

        # pylint: disable=not-an-iterable
        # Pylint does not understand @tensor annotations
        for grad, var in self.gradients:
            if grad is not None:
                summary_name = "gr_{}".format(var.name)
                tf.summary.histogram(
                    summary_name, grad, collections=["summary_gradients"])
        # pylint: enable=not-an-iterable

        return {
            "scalar_summaries": tf.summary.merge(
                tf.get_collection("summary_train")),
            "histogram_summaries": tf.summary.merge(
                tf.get_collection("summary_gradients"))}

    def get_executable(self,
                       compute_losses: bool = True,
                       summaries: bool = True,
                       num_sessions: int = 1) -> Executable:
        """Create the executable that runs this trainer on one batch.

        Arguments:
            compute_losses: Must be true.
            summaries: Whether the executable should fetch summaries.
            num_sessions: Must be 1.

        Raises:
            ValueError: If ``num_sessions`` is not 1.
        """
        assert compute_losses
        if num_sessions != 1:
            raise ValueError(
                "Trainer only supports execution in a single session")

        return TrainExecutable(self, summaries)
class TrainExecutable(Executable):
    """Executable that runs one update of a ``GenericTrainer``."""

    def __init__(self, trainer: GenericTrainer, summaries: bool) -> None:
        """Remember the trainer and whether summaries should be fetched."""
        self.trainer = trainer
        self.summaries = summaries
        self.result = None  # type: Optional[ExecutionResult]

    def next_to_execute(self) -> NextExecute:
        """Return the coders, the fetches and a single empty feed dict."""
        trainer = self.trainer

        to_fetch = {"train_op": trainer.train_op}

        # Summaries contribute the "scalar_summaries" and
        # "histogram_summaries" entries.
        if self.summaries:
            for key, summary in trainer.summaries.items():
                to_fetch[key] = summary

        to_fetch["losses"] = trainer.objective_values
        to_fetch["_update_ops"] = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        feed_dicts = [{}]
        return trainer.all_coders, to_fetch, feed_dicts

    def collect_results(self, results: List[Dict]) -> None:
        """Turn the fetched values of the single session into the result."""
        assert len(results) == 1
        fetched = results[0]

        if self.summaries:
            scalars = fetched["scalar_summaries"]
            histograms = fetched["histogram_summaries"]
        else:
            scalars = None
            histograms = None

        self.result = ExecutionResult(
            [],
            losses=fetched["losses"],
            scalar_summaries=scalars,
            histogram_summaries=histograms,
            image_summaries=None)