diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 1775a823c57b6..871228820414c 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -1037,7 +1037,7 @@ def _get_op_by_id(ops, id): grad_op_dist_attr.set_output_dims_mapping( output_name, ref_fwd_dims_mapping) - elif grad_op.type == 'fill_zeros_like': + elif grad_op.type == 'fill_any_like': ref_var_name = grad_op.input_arg_names[0] ref_var = vars[ref_var_name] ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( @@ -1274,7 +1274,7 @@ def _get_op_by_id(ops, id): grad_op_dist_attr.impl_type = "default" grad_op_dist_attr.impl_idx = 0 - elif grad_op.type == 'fill_zeros_like': + elif grad_op.type == 'fill_any_like': ref_var_name = grad_op.input_arg_names[0] ref_var = vars[ref_var_name] ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index cce4e63b63c52..24fcb10a78919 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -41,28 +41,35 @@ def set_field_default_config(category, field, default_value): ######################################### BASE = "base" set_field_default_config(BASE, "auto_mode", "semi") +set_field_default_config(BASE, "gradient_scale", True) +set_field_default_config(BASE, "use_cache", True) +set_field_default_config(BASE, "return_numpy", True) +set_field_default_config(BASE, "all_ranks", False) +set_field_default_config(BASE, "split_data", False) +set_field_default_config(BASE, "seed", None) ######################################### # recompute configuration ######################################### RECOMPUTE = "recompute" -set_field_default_config(RECOMPUTE, "enabled", False) +set_field_default_config(RECOMPUTE, "enable", False) set_field_default_config(RECOMPUTE, "checkpoints", None) +set_field_default_config(RECOMPUTE, "enable_tuning", False) ######################################### # AMP configuration ######################################### AMP = "amp" -set_field_default_config(AMP, "enabled", False) +set_field_default_config(AMP, "enable", False) set_field_default_config(AMP, "init_loss_scaling", 32768.0) set_field_default_config(AMP, "incr_every_n_steps", 1000) set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2) set_field_default_config(AMP, "incr_ratio", 2.0) set_field_default_config(AMP, "decr_ratio", 0.8) set_field_default_config(AMP, "use_dynamic_loss_scaling", True) -set_field_default_config(AMP, "custom_white_list", None) -set_field_default_config(AMP, "custom_black_list", None) -set_field_default_config(AMP, "custom_black_varnames", None) +set_field_default_config(AMP, "custom_white_list", []) +set_field_default_config(AMP, "custom_black_list", []) +set_field_default_config(AMP, "custom_black_varnames", []) set_field_default_config(AMP, "use_pure_fp16", False) set_field_default_config(AMP, "use_fp16_guard", True) set_field_default_config(AMP, "use_optimizer_fp16", False) @@ -71,16 +78,40 @@ def set_field_default_config(category, field, default_value): # sharding configuration ######################################### SHARDING = "sharding" -set_field_default_config(SHARDING, "enabled", False) +set_field_default_config(SHARDING, "enable", False) set_field_default_config(SHARDING, "stage", 1) set_field_default_config(SHARDING, 
"sharding_degree", 8) set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0) set_field_default_config(SHARDING, "enable_tuning", False) +set_field_default_config(SHARDING, "tuning_range", []) ######################################### # gradient merge configuration ######################################### GRADIENT_MERGE = "gradient_merge" -set_field_default_config(GRADIENT_MERGE, "enabled", False) +set_field_default_config(GRADIENT_MERGE, "enable", False) set_field_default_config(GRADIENT_MERGE, "k_steps", 1) set_field_default_config(GRADIENT_MERGE, "avg", True) + +######################################### +# quantization configuration +######################################### +QAT = "qat" +set_field_default_config(QAT, "enable", False) +set_field_default_config(QAT, "channel_wise_abs_max", True) +set_field_default_config(QAT, "weight_bits", 8) +set_field_default_config(QAT, "activation_bits", 8) +set_field_default_config(QAT, "not_quant_pattern", ['skip_quant']) +set_field_default_config(QAT, "algo", None) + +# ######################################### +# auto tuning configuration +# ######################################### +TUNING = "tuning" +set_field_default_config(TUNING, "enable", False) +set_field_default_config(TUNING, "batch_size", 1) +set_field_default_config(TUNING, "dataset", None) +set_field_default_config(TUNING, "profile_start_step", 1) +set_field_default_config(TUNING, "profile_end_step", 1) +set_field_default_config(TUNING, "run_after_tuning", True) +set_field_default_config(TUNING, "verbose", True) diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 92a503659041e..29dd084b9e485 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -120,15 +120,12 @@ def __init__(self, self._backup_serial_main_program_stack = [] self._backup_serial_startup_program_stack = [] - # flag whether scale gradient with dp size + # # flag whether scale gradient with dp size self._gradient_scale = True # A flag indicates whether the used parallelism is data parallel self._data_parallel = False - # flag whether using `to_static` - self._dygraph_mode = False - @property def serial_main_program(self): return self._serial_main_program diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 3947e03e4d937..287550b300172 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -16,6 +16,7 @@ import time import copy import logging +import random import numpy as np from collections import defaultdict @@ -91,19 +92,43 @@ def __init__(self, cluster=None, strategy=None): self.model = model - self.loss = loss - self.optimizer = optimizer - self.metrics = metrics + + if loss and not isinstance(loss, + paddle.nn.Layer) and not callable(loss): + raise TypeError( + "'loss' must be sub classes of `paddle.nn.Layer` or any callable function." 
+ ) + self._loss = loss + + if optimizer and not isinstance( + optimizer, + (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): + raise TypeError( + "'optimizer' must be object of class `paddle.optimizer.Optimizer`" + " or `paddle.fluid.optimizer.Optimizer`.") + self._optimizer = self._validate_opt(optimizer) + + metrics = metrics or [] + for metric in to_list(metrics): + assert isinstance(metric, Metric), \ + "{} is not sub class of Metric".format( + metric.__class__.__name__) + self._metrics = to_list(metrics) + self.cluster = cluster if self.cluster is None: self.cluster = get_default_cluster() - self.strategy = strategy - if self.strategy is None: - self.strategy = fleet.DistributedStrategy() + + if strategy and not isinstance(strategy, Strategy): + raise TypeError( + "'strategy' must be object of class 'paddle.distributed.auto_parallel.strategy'" + ) + self.strategy = strategy or Strategy() + if os.getenv("POD_NAME"): print("Distribute training by paddle.distributed.launch", flush=True) - fleet.init(is_collective=True, strategy=self.strategy) + fleet.init(is_collective=True) self._executor = None self._cur_rank = paddle.distributed.get_rank() @@ -128,54 +153,18 @@ def __init__(self, "eval": False, "predict": False } - self._dygraph_mode = False - - # TODO: move the following configuration to the strategy - self._gradient_scale = True - self._user_tuning_config = None - self._all_ranks = False - self._use_cache = False - self._return_numpy = True - - def _prepare(self): - if self.optimizer and not isinstance( - self.optimizer, - (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): - raise TypeError( - "'optimizer' must be object of class `paddle.optimizer.Optimizer`" - " or `paddle.fluid.optimizer.Optimizer`.") - self._optimizer = self.optimizer - if self.loss and not isinstance( - self.loss, paddle.nn.Layer) and not callable(self.loss): - raise TypeError( - "'loss' must be sub classes of `paddle.nn.Layer` or any callable function." 
- ) - self._loss = self.loss - - metrics = self.metrics or [] - for metric in to_list(metrics): - assert isinstance(metric, Metric), \ - "{} is not sub class of Metric".format( - metric.__class__.__name__) - self._metrics = to_list(metrics) self._planned_mode = None - self._all_ranks = all_ranks - self._prepare_single_mode("train") + self._dygraph_mode = False + self._tuning = self.strategy.tuning def _prepare_single_mode(self, mode): - + # Do the build process self._build(mode) # Do the planning process self._plan(mode) - - # Do the Optimization tuning - if self._user_tuning_config and mode == "train": - self._optimization_tuning(mode) - # Do the parallel process self._parallel(mode) - # Init comm and startup program self._initialize(mode) self._mode_init_states[mode] = True @@ -250,21 +239,22 @@ def _build(self, mode): self._dist_contexts[mode] = DistributedContext( serial_main_prog, serial_startup_prog, self._optimizer, losses, feed_vars, fetch_vars, self.cluster, self.strategy) - self._dist_contexts[mode].gradient_scale = self._gradient_scale - self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode + self._dist_contexts[mode].gradient_scale = self.strategy.gradient_scale - def _optimization_tuning(self, mode): + def _optimization_tuning(self, mode, dataset, batch_size): + if not self.strategy.tuning.enable or mode != "train": + return + + # Do the build process + self._build(mode) + # Do the planning process + self._plan(mode) - self.mode = mode - assert "batch_size" in self._user_tuning_config, "Optimization Tuning should provide with batch size." - assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset." - batch_size = self._user_tuning_config["batch_size"] - dataset = self._user_tuning_config["dataset"] dataset.dp_world_size = self.dp_world_sizes dataset.dp_rank = self.dp_ranks from .tuner.optimization_tuner import OptimizationTuner - self._optimization_tuner = OptimizationTuner(self._user_tuning_config, + self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(), self._dist_contexts[mode], dataset, self.inputs_spec, @@ -274,7 +264,7 @@ def _optimization_tuning(self, mode): self._optimization_tuner.tune() - if self._user_tuning_config["run_after_tuning"]: + if self._tuning.run_after_tuning: # update the strategy self._dist_contexts[ mode]._strategy = self._optimization_tuner.get_best_config() @@ -365,6 +355,11 @@ def _initialize(self, mode): if isinstance(place, fluid.CUDAPlace): place = fluid.CUDAPlace(ParallelEnv().dev_id) + if self.strategy.seed: + paddle.seed(self.strategy.seed + self.dp_ranks[0]) + np.random.seed(self.strategy.seed + self.dp_ranks[0]) + random.seed(self.strategy.seed + self.dp_ranks[0]) + if self._dygraph_mode: dist_context = self._dist_contexts[mode] dist_main_program = self._dist_main_progs[mode][self._cur_rank] @@ -383,59 +378,31 @@ def _initialize(self, mode): prune_startup_prog = dist_startup_prog._prune(uninitialized) self._executor.run(prune_startup_prog) - if self.strategy.amp and self.strategy.amp_configs['use_pure_fp16']: - # from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16 - def cast_parameters_to_fp16(place, - program, - scope=None, - to_fp16_var_names=None): - """ - Traverse all parameters in the whole model and set them to the FP16 data type. - Whereas, this function will keep parameters of batchnorms in FP32. - Args: - place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors. - program (Program): The used program. 
- scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values. - Default is None. - to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names` - will be set to FP16. Usually, it is the returned - value of `cast_model_to_fp16` API. - """ - from paddle.framework import core - import numpy as np - all_parameters = [] - for block in program.blocks: - all_parameters.extend(block.all_parameters()) - - var_scope = scope if scope else paddle.static.global_scope() - for param in all_parameters: - if param.dtype == core.VarDesc.VarType.FP16: - param_t = var_scope.find_var( - param.name).get_tensor() - data = np.array(param_t) - param_t.set(np.float16(data), place) - - cast_parameters_to_fp16(place, prune_startup_prog) + else: + self._logger.info("NOTE: parameters wiil be re-initialized.") + dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] + self._executor.run(dist_startup_prog) def _infer_sample_spec(self, data, batch_size, split): - if isinstance(data, paddle.io.Dataset): + if isinstance(data, paddle.io.IterableDataset): if split is None: - input, label = data[0] + input, label = next(iter(data)) else: - sample = data[0] + sample = next(iter(data)) input = sample[:split] label = sample[split:] - elif isinstance(data, paddle.io.IterableDataset): + elif isinstance(data, paddle.io.Dataset): if split is None: - input, label = next(iter(data)) + input, label = data[0] else: - sample = next(iter(data)) + sample = data[0] input = sample[:split] label = sample[split:] else: raise ValueError( "Data should be a Dataset or IterableDatset, but received {}.". format(type(data).__name__)) + self.inputs_spec = [] self.labels_spec = [] input_list = to_list(input) @@ -463,11 +430,14 @@ def _infer_item_spec(item, name, batch_size, specs): name = "input" + str(i) _infer_item_spec(item, name, batch_size, self.inputs_spec) if label_list is not None: - for item in label_list: + for i, item in enumerate(label_list): assert item is not None, "Receive None input." name = "label" + str(i) _infer_item_spec(item, name, batch_size, self.labels_spec) + self.inputs_spec = self._validate_spec(self.inputs_spec) + self.labels_spec = self._validate_spec(self.labels_spec) + def fit(self, train_data, train_sample_split=None, @@ -524,8 +494,8 @@ def fit(self, assert valid_data is None, "No support for validation for now" self.mode = 'train' self._infer_sample_spec(train_data, batch_size, train_sample_split) - if not self._mode_init_states['train']: - self._prepare() + if not self._mode_init_states[self.mode]: + self._prepare_single_mode(self.mode) assert self.mode in self._dist_main_progs, \ "train model is not ready, please call `engine.prepare()` first." 
@@ -537,31 +507,29 @@ def fit(self, usr_fetch = self._validate_fetches(fetches) fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch) - lr_scheduler = self.get_lr_scheduler(self.main_program) + lr_scheduler = self._get_lr_scheduler(self.main_program) + outputs = [] for epoch in range(epochs): train_logs = {"epoch: {:d} ": epoch} for step, _ in enumerate(train_dataloader): try: - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) + outs = self._executor.run( + self.main_program, + fetch_list=fetch_list, + use_program_cache=self.strategy.use_cache, + return_numpy=self.strategy.return_numpy) except fluid.core.EOFException: break - + # update lr train_logs["step: {:d} "] = step - if lr_scheduler is not None and step % self.k_steps == 0: + if lr_scheduler is not None and step % self._k_steps == 0: lr_scheduler.step() - try: - train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr() - except: - train_logs[ - "lr: {:5e} "] = self._lr_optimizer._learning_rate.get_lr( - ) + train_logs["lr: {:5e} "] = self._get_lr(self._lr_optimizer) # inner fetches if fetch_loss: train_logs["loss: {:9f} "] = outs[0][0] + outputs.append(outs[:len(fetch_loss)]) # user fetches user_outs = outs[len(fetch_loss):] user_fetch_list = fetch_list[len(fetch_loss):] @@ -571,6 +539,8 @@ def fit(self, string = '[train] ' + ''.join(list(train_logs.keys())) self._logger.info(string.format(*list(train_logs.values()))) + return outputs + def evaluate(self, eval_data, eval_sample_split=None, @@ -618,18 +588,21 @@ def evaluate(self, inner_fetch = dict(fetch_loss, **fetch_metrics) fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch) + outputs = [] for step, _ in enumerate(eval_dataloader): eval_logs = {"step: {:d} ": step} try: - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) + outs = self._executor.run( + self.main_program, + fetch_list=fetch_list, + use_program_cache=self.strategy.use_cache, + return_numpy=self.strategy.return_numpy) except fluid.core.EOFException: break # inner fetches if fetch_loss: eval_logs["loss: {:9f} "] = outs[0][0] + outputs.append(outs[:len(fetch_loss)]) # Metric if fetch_metrics: metric_out = outs[len(fetch_loss):len(inner_fetch)] @@ -696,10 +669,11 @@ def predict(self, for step, _ in enumerate(test_dataloader): predict_logs = {"step: {:d} ": step} try: - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) + outs = self._executor.run( + self.main_program, + fetch_list=fetch_list, + use_program_cache=self.strategy.use_cache, + return_numpy=self.strategy.return_numpy) except fluid.core.EOFException: break outputs.append(outs[:len(fetch_outputs)]) @@ -711,6 +685,11 @@ def predict(self, return outputs + def _tune(self, tune_data, tune_sample_split=None, batch_size=1): + self.mode = 'train' + self._infer_sample_spec(tune_data, batch_size, tune_sample_split) + self._optimization_tuning(self.mode, tune_data, batch_size) + def _create_dataloader(self, dataset, batch_size, @@ -719,9 +698,9 @@ def _create_dataloader(self, collate_fn=None): if self.strategy.gradient_merge and batch_size is not None: - assert batch_size % self.k_steps == 0, \ - "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self.k_steps) - batch_size //= self.k_steps + assert batch_size % self._k_steps == 0, 
\ + "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) + batch_size //= self._k_steps dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank] @@ -784,9 +763,7 @@ def _create_dataloader(self, def _validate_spec(self, specs): specs = to_list(specs) - self.k_steps = 1 - if self.strategy.gradient_merge: - self.k_steps = self.strategy.gradient_merge_configs['k_steps'] + self._k_steps = self.strategy.gradient_merge.k_steps if specs is not None: for i, spec in enumerate(specs): assert isinstance(spec, InputSpec) @@ -794,11 +771,11 @@ def _validate_spec(self, specs): raise ValueError( "Requires Input[{}].name != None, but receive `None` with {}." .format(i, spec)) - if self.k_steps > 1: + if self._k_steps > 1: shape = list(spec.shape) - assert shape[0] % self.k_steps == 0, \ - "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self.k_steps) - shape[0] //= self.k_steps + assert shape[0] % self._k_steps == 0, \ + "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps) + shape[0] //= self._k_steps spec.shape = shape return specs @@ -859,7 +836,7 @@ def _set_recompute_ckpts(self): # NOTE hack to enable recompute in engine api for GPT-3 # TODO support more PaddleNLP/CV models here - config = self.strategy.recompute_configs + recompute = self.strategy.recompute # extract ckpts by specific model if isinstance(self.model, paddle.nn.Layer): @@ -868,23 +845,28 @@ def _set_recompute_ckpts(self): ) and self.model.__class__.__name__ == 'GPTForPretraining': exact_ckpts = self.model.gpt.checkpoints else: - exact_ckpts = config["checkpoints"] + exact_ckpts = recompute.checkpoints else: - exact_ckpts = config["checkpoints"] + exact_ckpts = recompute.checkpoints # modify strategy - if self.strategy.recompute: - config["checkpoints"] = exact_ckpts[:] - self.strategy.recompute_configs = config + if recompute.enable: + recompute.checkpoints = exact_ckpts[:] logs = { 'Model Class': self.model.__class__.__name__, 'Applied Recompute ckpts': exact_ckpts } self._logger.info(logs) + def _validate_opt(self, optimizer): + optimizer._parameter_list = None + optimizer._param_groups = None + return optimizer + def save(self, path, training=True): """ Saves the model, parameters, optimizer state to path. + If `training` is set to False, only inference model will be saved. Args: path (str): The file prefix to save model. The format @@ -957,7 +939,7 @@ def load(self, path, strict=True, load_optimizer=True): load_optimizer) @staticmethod - def get_lr_scheduler(program): + def _get_lr_scheduler(program): lr_sheduler = None if hasattr(program, 'lr_sheduler'): from paddle.optimizer.lr import LRScheduler @@ -965,6 +947,17 @@ def get_lr_scheduler(program): assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler" return lr_sheduler + def _get_lr(self, optimizer): + if isinstance(optimizer, paddle.optimizer.Optimizer): + return optimizer.get_lr() + elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer): + return optimizer._learning_rate.get_lr() + else: + raise TypeError( + "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ + " or `paddle.fluid.optimizer.Optimizer`." 
+ ) + @property def mode(self): return self._mode diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 0851d8f80c9c9..b83a19b512ef8 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -160,8 +160,8 @@ def _apply_pre_optimization(self, main_program, startup_program, loss, # apply quantization pass # The pass can be applied when mode must be 'train' - if self._mode == 'train' and self._strategy.qat: - config = copy.deepcopy(self._strategy.qat_configs) + if self._mode == 'train' and self._strategy.qat.enable: + config = copy.deepcopy(self._strategy.qat.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads auto_parallel_quantization_pass = new_pass( @@ -176,8 +176,8 @@ def _apply_pre_optimization(self, main_program, startup_program, loss, # apply amp pass # FIXME we disenable amp for eval since it has a little bug with # eval program and which will be fixed in future - if self._mode == 'train' and self._strategy.amp: - config = copy.deepcopy(self._strategy.amp_configs) + if self._mode == 'train' and self._strategy.amp.enable: + config = copy.deepcopy(self._strategy.amp.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["loss"] = loss @@ -195,8 +195,8 @@ def _apply_pre_optimization(self, main_program, startup_program, loss, # apply recompute pass # recompute is then train-only optimization - if self._mode == "train" and self._strategy.recompute: - config = copy.deepcopy(self._strategy.recompute_configs) + if self._mode == "train" and self._strategy.recompute.enable: + config = copy.deepcopy(self._strategy.recompute.to_dict()) config["dist_context"] = self._dist_context config["no_grad_set"] = None config["loss"] = loss @@ -217,12 +217,12 @@ def _apply_post_optimization(self, main_program, startup_program, rank, config = {} config["dist_context"] = self._dist_context config["global_rank"] = rank - config["use_sharding"] = self._strategy.sharding + config["use_sharding"] = self._strategy.sharding.enable dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) dp_pass.apply([main_program], [startup_program], self._pass_context) - if self._strategy.sharding: - config = copy.deepcopy(self._strategy.sharding_configs) + if self._strategy.sharding.enable: + config = copy.deepcopy(self._strategy.sharding.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["global_rank"] = rank @@ -234,7 +234,7 @@ def _apply_post_optimization(self, main_program, startup_program, rank, # GradClip is train-only optimization if self._mode == "train": - config = copy.deepcopy(self._strategy.sharding_configs) + config = copy.deepcopy(self._strategy.sharding.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["rank_id"] = rank @@ -244,8 +244,8 @@ def _apply_post_optimization(self, main_program, startup_program, rank, self._pass_context) # gradient_merge is then train-only optimization - if self._mode == "train" and self._strategy.gradient_merge: - config = copy.deepcopy(self._strategy.gradient_merge_configs) + if self._mode == "train" and self._strategy.gradient_merge.enable: + config = copy.deepcopy(self._strategy.gradient_merge.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads auto_parallel_gradient_merge_pass = 
new_pass( diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 06d921e84fc4c..e0e260907f263 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -65,6 +65,14 @@ def __repr__(self): sort_keys=True, indent=4) + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + setattr(result, k, copy.deepcopy(v, memo)) + return result + class RecomputeConfig(BaseConfig): @@ -94,6 +102,20 @@ def __init__(self, config_dict=None): super(GradientMergeConfig, self).__init__(category, config_dict) +class QATConfig(BaseConfig): + + def __init__(self, config_dict=None): + category = constants.QAT + super(QATConfig, self).__init__(category, config_dict) + + +class TuningConfig(BaseConfig): + + def __init__(self, config_dict=None): + category = constants.TUNING + super(TuningConfig, self).__init__(category, config_dict) + + class Strategy(BaseConfig): """ The `Strategy` object is used to configure the paralleization and optimization beheviors. @@ -104,21 +126,24 @@ class Strategy(BaseConfig): configurations while other default configurations are left unchanged. If this is a string, it is interpreted as the path to a YAML configuration and will be loaded to override the corresponding default configurations. - + Examples: .. code-block:: python import paddle import paddle.distributed.auto_parallel as auto - + strategy = auto.Strategy() sharding = strategy.sharding self.assertEqual(sharding.enabled, False) self.assertEqual(sharding.stage, 1) + self.assertEqual(sharding.sharding_degree, 8) sharding.enabled = True - sharding.state = 2 + sharding.stage = 2 + sharding.sharding_degree = 2 self.assertEqual(sharding.enabled, True) self.assertEqual(sharding.stage, 2) + self.assertEqual(sharding.sharding_degree, 2) """ @@ -135,6 +160,7 @@ def __init__(self, config=None): .format(config)) else: self._config_dict = {} + category = constants.BASE super(Strategy, self).__init__(category, self._config_dict) @@ -150,5 +176,8 @@ def __init__(self, config=None): config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None) self.gradient_merge = GradientMergeConfig(config_dict) - config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None) - self.gradient_merge = GradientMergeConfig(config_dict) + config_dict = self._config_dict.get(constants.QAT, None) + self.qat = QATConfig(config_dict) + + config_dict = self._config_dict.get(constants.TUNING, None) + self.tuning = TuningConfig(config_dict) diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index 6657e1ae96e32..16b0cea342dfb 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -110,13 +110,13 @@ class ShardingStageAlgorithm(AlgorithmBase): # TODO import trial class & copy strategy def __init__(self, config): super().__init__(config) - self._changed_configs = ["sharding_configs"] + self._changed_configs = ["sharding"] def _init_spaces(self): self._max_stage = 3 self._trial_idx = 0 - stage_range = self._config.sharding_configs.get("stage_range", None) + stage_range = self._config.sharding.to_dict().get("tuning_range", None) if stage_range: assert set(stage_range).issubset( set([0, 1, 2, 3]) @@ -136,9 +136,8 @@ def next_trial(self): stage = 
self._stage_range[self._trial_idx] new_strategy = copy.deepcopy(self._config.dist_strategy) - config_dict = new_strategy.sharding_configs - config_dict["stage"] = stage - new_strategy.sharding_configs = config_dict + sharding = new_strategy.sharding + sharding.stage = stage name = "trial-sharding-stage{}".format(stage) trial = Trial(new_strategy, name, self.changed_configs) diff --git a/python/paddle/distributed/auto_parallel/tuner/config.py b/python/paddle/distributed/auto_parallel/tuner/config.py index 19818a3a65570..b1eedbe04f0eb 100644 --- a/python/paddle/distributed/auto_parallel/tuner/config.py +++ b/python/paddle/distributed/auto_parallel/tuner/config.py @@ -17,15 +17,13 @@ import pathlib import paddle -from paddle.distributed import fleet +from ..strategy import Strategy _tuning_supported_passes = ["sharding", "recompute"] -_strategy_config_suffiex = "_configs" def _get_pass_config(strategy, pass_name): - config_name = pass_name + _strategy_config_suffiex - config = getattr(strategy, config_name) + config = getattr(strategy, pass_name) return config @@ -38,10 +36,8 @@ class TuningConfig(object): def __init__(self, user_config, strategy): - if not isinstance(strategy, fleet.DistributedStrategy): - raise TypeError( - "'strategy' must be object of class `fleet.DistributedStrategy`." - ) + if not isinstance(strategy, Strategy): + raise TypeError("'strategy' must be object of class `Strategy`.") if not user_config: user_config = {} @@ -116,11 +112,11 @@ def _initialize(self, user_config): for p in _tuning_supported_passes: if getattr(self._dist_strategy, p) and _get_pass_config( - self._dist_strategy, p)["enable_tuning"]: + self._dist_strategy, p).enable_tuning: # TODO distinguish different args of each passes self._tuning_passes_name.add(p) - config_name = p + _strategy_config_suffiex + config_name = p p_dict = getattr(self._dist_strategy, config_name) self.__dict__[config_name] = p_dict diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index bec74e315becd..aa807641ee1e2 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import yaml import os import sys import copy @@ -256,8 +257,8 @@ def _apply_optimization(self, trial): startup_program = dist_context.serial_startup_program # applying optimization pass - if new_strategy.amp: - config = copy.deepcopy(new_strategy.amp_configs) + if new_strategy.amp.enable: + config = copy.deepcopy(new_strategy.amp.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_context._params_grads @@ -275,8 +276,8 @@ def _apply_optimization(self, trial): auto_parallel_amp_pass.apply([main_program], [startup_program], pass_context) - if new_strategy.recompute: - config = copy.deepcopy(new_strategy.recompute_configs) + if new_strategy.recompute.enable: + config = copy.deepcopy(new_strategy.recompute.to_dict()) config["dist_context"] = dist_context config["no_grad_set"] = None config["loss"] = dist_context.serial_loss @@ -303,8 +304,8 @@ def _apply_optimization(self, trial): dist_context, dist_params_grads) resharder.reshard() - if new_strategy.sharding: - config = copy.deepcopy(new_strategy.sharding_configs) + if new_strategy.sharding.enable: + config = copy.deepcopy(new_strategy.sharding.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_params_grads config["global_rank"] = self.rank @@ -313,8 +314,8 @@ def _apply_optimization(self, trial): auto_parallel_sharding_pass.apply([dist_main_prog], [dist_startup_prog], pass_context) - if new_strategy.gradient_merge: - config = copy.deepcopy(new_strategy.gradient_merge_configs) + if new_strategy.gradient_merge.enable: + config = copy.deepcopy(new_strategy.gradient_merge.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_params_grads auto_parallel_gradient_merge_pass = new_pass( @@ -493,8 +494,9 @@ def summary(self): fw.write(line + "\n") full_strategy = self.get_best_config() - full_strategy.save_to_prototxt( - os.path.join(self.project_dir, "tuned_dist_strategy.prototxt")) + path = os.path.join(self.project_dir, "tuned_dist_strategy.yaml") + with open(path, 'w') as outfile: + yaml.dump(full_strategy, outfile, default_flow_style=False) def clear(self): """ diff --git a/python/paddle/distributed/auto_parallel/tuner/trial.py b/python/paddle/distributed/auto_parallel/tuner/trial.py index 3937ca9865181..edc588b4c70fe 100644 --- a/python/paddle/distributed/auto_parallel/tuner/trial.py +++ b/python/paddle/distributed/auto_parallel/tuner/trial.py @@ -156,9 +156,10 @@ def summary(self): draws += h1_format.format("{} auto=True <-> {}".format(name, name)) draws += line + "\n" my_configs = getattr(self.space, name) - keys = my_configs.keys() + keys = my_configs.to_dict().keys() for key in keys: - draws += h2_format.format(key, str(my_configs.get(key, None))) + draws += h2_format.format( + key, str(my_configs.to_dict().get(key, None))) result_res = draws + border return result_res diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 370316767ece1..724f98eb89af9 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1097,7 +1097,7 @@ def set_grad_var_shape(program, dist_context): if op.type in [ "c_allreduce_sum", "c_identity", "scale", "cast", - "fill_zeros_like" + "fill_any_like" ]: forward_var_name = op.input_arg_names[0] elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad": diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 
3f3448b5008e6..458cb26ccd481 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -314,7 +314,9 @@ def _keep_fp32_output(op, out_name): consume_op_attr.set_input_dist_attr( cast_name, in_var_dist_attr) else: - assert in_var.dtype == dst_dtype + assert in_var.dtype == dst_dtype, "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format( + grad_op.type, in_name, dst_dtype, in_var.dtype, + str(grad_op)) for out_name in grad_op.output_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output( diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 586ad235fd15a..cd668d175cf5c 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -154,7 +154,7 @@ def _analyze_program(self): def _could_be_prune(self): - return self.dist_context._gradient_scale and ( + return self.dist_context.gradient_scale and ( self._support_rescale_grad or self._all_dp_groups_same_degree()) def _all_dp_groups_same_degree(self): diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 7702de7c01edd..4f735984ae370 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -379,6 +379,10 @@ def _insert_backward_cast_ops(self, op, idx, block, src_dtype, dst_dtype, # create cast grad grad_slot_name = slot_name + "@GRAD" assert grad_slot_name in op.output_names + if len(op.output(grad_slot_name)) == 0: + var = block.var(src_name) + assert var.stop_gradient is True + continue assert len(op.output(grad_slot_name)) == 1 grad_name = op.output(grad_slot_name)[0] grad = block.var(grad_name) @@ -536,6 +540,39 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): return output_var +def cast_startup_program(): + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + + param_to_dtype = {} + for block in main_program.blocks: + for p in block.all_parameters(): + param_to_dtype[p.name] = p.dtype + + def is_initialization_op(op): + comm_op_prefix = "c_" + op_type = op.type + if op_type.startswith(comm_op_prefix): + return False + + if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0: + return False + + return True + + for op in startup_program.global_block().ops: + if is_initialization_op(op): + output_name = op.output_arg_names[0] + if param_to_dtype.get(output_name, + None) == core.VarDesc.VarType.FP16: + assert op.has_attr( + 'dtype' + ), "initialization op is supported to has dtype attribute but got {}.".format( + str(op)) + if op.attr('dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.FP16) + + @register_pass("auto_parallel_fp16") class FP16Pass(AMPPass): @@ -563,6 +600,8 @@ def _apply_single_impl(self, main_program, startup_program, context): input_data_var_names) is_train = fp16_state._build_state() + cast_startup_program() + if is_train: with paddle.static.program_guard(main_program, startup_program): # TODO (JZ-LIANG)support cast forward program only when inference diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 
2d005e2ab1bc4..604db0cf3cf4c 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -40,6 +40,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_grad_clip MODULES test_grad_clip ENVS ${dist_ENVS}) set_tests_properties(test_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS + ${dist_ENVS}) + set_tests_properties(test_iterable_dataset + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) @@ -76,8 +80,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_process_mesh MODULES test_process_mesh) py_test_modules(test_interface MODULES test_interface) py_test_modules(test_stategy MODULES test_strategy) - py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS - ${dist_ENVS}) - set_tests_properties(test_iterable_dataset - PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) + endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py new file mode 100644 index 0000000000000..66fac22eac96c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py @@ -0,0 +1,122 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto + +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed.auto_parallel.strategy import Strategy +from paddle.distributed.auto_parallel.engine import Engine +from get_gpt_model import generate_model, create_data_holder, FakeDataset + + +def apply_pass(use_amp=False, level=None): + strategy = Strategy() + strategy.auto_mode = "semi" + if use_amp: + amp = strategy.amp + amp.enable = True + amp.custom_white_list = ['softmax', 'layer_norm', 'gelu'] + amp.custom_black_list = [ + 'c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum' + ] + amp.init_loss_scaling = 32768 + amp.use_fp16_guard = False + amp.use_pure_fp16 = level in ["o2", "o3"] + amp.use_optimizer_fp16 = level == "o3" + print("amp level: ", level) + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestAMPPass(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 1 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_amp=False, level=None): + reset_prog() + + strategy = apply_pass(use_amp, level) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("mp") + + engine = Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=self.rtol, + atol=self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses)) + + def test_amp_pass(self): + # mp2 training + mp_engine = self.get_engine() + mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + mp_losses = np.array(mp_losses) + + # mp2 amp-o1 training + amp_o1_engine = self.get_engine(True, "o1") + amp_o1_losses = amp_o1_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + amp_o1_losses = np.array(amp_o1_losses) + self.check_results(mp_losses, amp_o1_losses) + + # mp2 amp-o2 training + amp_o2_engine = self.get_engine(True, "o2") + amp_o2_losses = amp_o2_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + amp_o2_losses = np.array(amp_o2_losses) + self.check_results(mp_losses, amp_o2_losses) + + # mp2 amp-o3 training + amp_o3_engine = self.get_engine(True, "o3") + amp_o3_losses = amp_o3_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + amp_o3_losses = np.array(amp_o3_losses) + self.check_results(mp_losses, amp_o3_losses) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_model.py index 2a6e01eac745d..4639abf32554e 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_model.py @@ -32,7 +32,7 @@ 
paddle.enable_static() _global_parallel_strategy = None -_global_process_mesh = None +_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"]) batch_size = 4 hidden_size = 1024 sequence_len = 512 @@ -122,9 +122,6 @@ def mlp_pretrain_forward(train_program, start_program): def train(): - global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"]) - dist_strategy = fleet.DistributedStrategy() dist_strategy.amp = False dist_strategy.pipeline = False diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py index 60a915c53cddf..576596922807f 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py @@ -18,9 +18,10 @@ import numpy as np import paddle -import paddle.distributed.fleet as fleet import paddle.distributed.auto_parallel as auto +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed.auto_parallel.strategy import Strategy from paddle.distributed.auto_parallel.engine import Engine from get_gpt_model import generate_model, create_data_holder, FakeDataset @@ -28,14 +29,12 @@ def apply_pass(use_sharding=False): - strategy = fleet.DistributedStrategy() - strategy.semi_auto = True + strategy = Strategy() + strategy.auto_mode = "semi" if use_sharding: - strategy.sharding = True - strategy.sharding_configs = { - "sharding_degree": 2, - "stage": 2, - } + sharding = strategy.sharding + sharding.sharding_degree = 2 + sharding.stage = 2 return strategy @@ -76,34 +75,17 @@ def init(self, engine): paddle.seed(2022) np.random.seed(2022) random.seed(2022) - engine.mode = "train" - engine._executor.run(engine.startup_program) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) - def get_dp2_engine(self): + def get_engine(self, use_sharding=False): reset_prog() - strategy = apply_pass() + strategy = apply_pass(use_sharding) clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) model, loss = generate_model("dp") - inputs_spec, labels_spec = create_data_holder(self.batch_size) - - engine = Engine(model, inputs_spec, labels_spec, strategy=strategy) - engine.prepare(optimizer=opt, loss=loss) - self.init(engine) - return engine - - def get_dp2sharding2_engine(self): - reset_prog() - - strategy = apply_pass(True) - clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) - opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) - model, loss = generate_model("dp") - inputs_spec, labels_spec = create_data_holder(self.batch_size) - - engine = Engine(model, inputs_spec, labels_spec, strategy=strategy) - engine.prepare(optimizer=opt, loss=loss) + engine = Engine(model, loss, opt, strategy=strategy) self.init(engine) return engine @@ -121,15 +103,13 @@ def check_result(self, dp_params, sharding_params): def test_grad_clip(self): # dp2 training - dp_engine = self.get_dp2_engine() - dp_engine.fit(self.dataset, batch_size=self.batch_size, use_cache=True) + dp_engine = self.get_engine() + dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) dp_param_values = get_parameter_value(dp_engine.main_program) # dp2sharding2 training - sharding_engine = self.get_dp2sharding2_engine() - sharding_engine.fit(self.dataset, - batch_size=self.batch_size, - use_cache=True) + sharding_engine = 
self.get_engine(True) + sharding_engine.fit(self.dataset, 3, batch_size=self.batch_size) sharding_param_values = get_parameter_value( sharding_engine.main_program) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index ea42641aa7c36..93edb1d78931f 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -27,9 +27,9 @@ import paddle.utils as utils from paddle.fluid import layers from paddle.io import Dataset, IterableDataset, DataLoader -from paddle.static import InputSpec -from paddle.distributed import fleet + import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.strategy import Strategy from paddle.distributed.auto_parallel.engine import Engine from paddle.optimizer.lr import CosineAnnealingDecay from paddle.fluid.dataloader.collate import default_collate_fn @@ -118,10 +118,10 @@ def train(fetch): grad_clip=None) metric = paddle.metric.Accuracy() - strategy = fleet.DistributedStrategy() - strategy.semi_auto = True + strategy = Strategy() + strategy.auto_mode = "semi" - engine = Engine(mlp, loss, optimizer, metric, strategy) + engine = Engine(mlp, loss, optimizer, metric, strategy=strategy) # train train_dataset = MyDataset(batch_num * batch_size) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py index 92a5813f2d1ac..de9eab53ce193 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py @@ -26,10 +26,10 @@ import paddle.nn.functional as F import paddle.utils as utils from paddle.fluid import layers -from paddle.io import Dataset, IterableDataset, DataLoader -from paddle.static import InputSpec -from paddle.distributed import fleet +from paddle.io import Dataset, DataLoader + import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.strategy import Strategy from paddle.distributed.auto_parallel.engine import Engine paddle.enable_static() @@ -108,12 +108,8 @@ def train(fetch): epsilon=1e-08, grad_clip=None) - dist_strategy = fleet.DistributedStrategy() - dist_strategy.amp = False - dist_strategy.pipeline = False - dist_strategy.recompute = False - # init parallel optimizer - dist_strategy.semi_auto = True + dist_strategy = Strategy() + dist_strategy.auto_mode = "semi" # init engine engine = Engine(mlp, @@ -124,9 +120,7 @@ def train(fetch): # train train_dataset = MyDataset(batch_num * batch_size) - engine.fit(train_dataset, - batch_size=batch_size, - steps_per_epoch=batch_num * batch_size) + engine.fit(train_dataset, batch_size=batch_size) # eval eval_dataset = MyDataset(batch_size) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index 2884a03a023e5..f5071cb469400 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -16,6 +16,7 @@ import numpy as np import paddle +import paddle.distributed.auto_parallel as auto sys.path.append("..") import auto_parallel_gpt_model as modeling @@ -25,7 +26,7 @@ vocab_size = 1000 -class FakeDataset: +class FakeDataset(paddle.io.Dataset): def __init__(self, num_samples): self.num_samples = num_samples @@ -67,8 +68,9 @@ def 
create_data_holder(batch_size): def generate_model(strategy): modeling.init_global() - modeling._global_process_mesh = list( - range(paddle.distributed.get_world_size())) + ranks = list(range(paddle.distributed.get_world_size())) + modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks, + dim_names=["x"]) if strategy == "serial": modeling._global_parallel_strategy = "serial" elif strategy == "mp": diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py new file mode 100644 index 0000000000000..f21722e05cf46 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto + +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed.auto_parallel.strategy import Strategy +from paddle.distributed.auto_parallel.engine import Engine +from get_gpt_model import generate_model, create_data_holder, FakeDataset + +paddle.enable_static() + + +def apply_pass(use_gradient_merge=False): + strategy = Strategy() + strategy.auto_mode = "semi" + if use_gradient_merge: + gradient_merge = strategy.gradient_merge + gradient_merge.enable = True + gradient_merge.k_steps = 4 + gradient_merge.avg = True + + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestGradientMergePass(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 8 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_gradient_merge=False): + reset_prog() + + strategy = apply_pass(use_gradient_merge) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("dp") + + engine = Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=self.rtol, + atol=self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses)) + + def test_gradient_merge_pass(self): + # dp2 training + dp_engine = self.get_engine() + dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + dp_losses = np.array(dp_losses) + + # dp2 gradient merge 
training + gm_engine = self.get_engine(True) + gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size) + gm_losses = np.array(gm_losses) + + avg_loss = 0 + pass_avg_ret_list = [] + for i, pass_ret in enumerate(gm_losses): + if (i + 1) % 4 == 0: + avg_loss += pass_ret[0] + pass_avg_ret_list.append([avg_loss / 4]) + avg_loss = 0 + else: + avg_loss += pass_ret[0] + + self.check_results(dp_losses, np.array(pass_avg_ret_list)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py index 56ef762cba5db..404d02e622df2 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py @@ -18,10 +18,9 @@ import numpy as np import paddle.distributed.auto_parallel as auto -from paddle.static import InputSpec -from paddle.distributed import fleet from paddle.incubate.autograd import Hessian from paddle.distributed.auto_parallel.engine import Engine +from paddle.distributed.auto_parallel.strategy import Strategy np.random.seed(1234) paddle.seed(1234) @@ -129,8 +128,8 @@ def main(): # model laplace = LaplaceModel() - dist_strategy = fleet.DistributedStrategy() - dist_strategy.semi_auto = True + dist_strategy = Strategy() + dist_strategy.auto_mode = "semi" engine = Engine(laplace, loss=loss_func, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py b/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py index 4ca3d14f7165a..ecf6eeb5526a1 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py @@ -28,8 +28,9 @@ from paddle.fluid import layers from paddle.io import Dataset, IterableDataset, DataLoader from paddle.static import InputSpec -from paddle.distributed import fleet + import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.strategy import Strategy from paddle.distributed.auto_parallel.engine import Engine from paddle.optimizer.lr import CosineAnnealingDecay from paddle.fluid.dataloader.collate import default_collate_fn @@ -48,10 +49,9 @@ paddle.seed(44) -class MyDataset(IterableDataset): +class MyDataset(paddle.io.IterableDataset): def __init__(self, num_samples): - super(MyDataset, self).__init__() self.num_samples = num_samples def __iter__(self): @@ -61,10 +61,9 @@ def __iter__(self): yield input, label -class MyDataset1(Dataset): +class MyDataset1(paddle.io.Dataset): def __init__(self, num_samples): - super(MyDataset1, self).__init__() self.num_samples = num_samples self.data = [] for i in range(self.num_samples): @@ -112,12 +111,10 @@ def __init__(self, self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): - out = auto.shard_op(self.norm, dist_attr={"process_mesh": - PP_MESH_0})(input) + out = auto.shard_op(self.norm, PP_MESH_0)(input) out = self.linear0(out) out = F.gelu(out, approximate=True) - out = auto.shard_op(self.linear1, dist_attr={"process_mesh": - PP_MESH_1})(out) + out = auto.shard_op(self.linear1, PP_MESH_1)(out) out = self.dropout(out) out = self.linear2(out) self.out = out @@ -136,54 +133,36 @@ def train(fetch): epsilon=1e-08, grad_clip=None) - inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') - labels_spec = InputSpec([batch_size], 'int64', 'label') - - dist_strategy = fleet.DistributedStrategy() - 
dist_strategy.semi_auto = True + dist_strategy = Strategy() + dist_strategy.auto_mode = "semi" dist_strategy.split_data = True - fleet.init(is_collective=True, strategy=dist_strategy) # init engine engine = Engine(mlp, - inputs_spec=inputs_spec, - labels_spec=labels_spec, + loss, + optimizer, + paddle.metric.Accuracy(), strategy=dist_strategy) - engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) - - # fetch - if fetch: - fetches = {'out': mlp.out} - else: - fetches = None # train train_dataset = MyDataset(batch_num * batch_size) - train_dataset1 = MyDataset1(batch_num) - engine.fit(train_dataset, - epochs=2, - batch_size=batch_size, - steps_per_epoch=batch_num, - fetches=fetches) - - engine.fit(train_dataset1, - epochs=2, - batch_size=None, - steps_per_epoch=batch_num, - fetches=fetches) + engine.fit(train_dataset, epochs=2, batch_size=batch_size) + + train_dataset1 = MyDataset1(batch_size * batch_num) + engine.fit(train_dataset1, epochs=2, batch_size=None) # eval eval_dataset = MyDataset(batch_size) - engine.evaluate(eval_dataset, batch_size, fetches=fetches) + engine.evaluate(eval_dataset, batch_size=batch_size) # predict test_dataset = MyDataset(batch_size) - engine.predict(test_dataset, batch_size, fetches=fetches) + engine.predict(test_dataset, batch_size=batch_size) # save temp_dir = tempfile.TemporaryDirectory() model_filename = os.path.join(temp_dir.name, 'mlp_inf') - engine.save(model_filename, training=False, mode='predict') + engine.save(model_filename, training=False) temp_dir.cleanup() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py index 8e058d16b87b3..77e916c4a232c 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py @@ -27,10 +27,10 @@ import paddle.utils as utils from paddle.fluid import layers from paddle.io import Dataset, IterableDataset, DataLoader -from paddle.static import InputSpec -from paddle.distributed import fleet + import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.engine import Engine +from paddle.distributed.auto_parallel.strategy import Strategy from engine_api_dp import MyDataset paddle.enable_static() @@ -43,20 +43,6 @@ paddle.seed(44) -# class MyDataset(Dataset): - -# def __init__(self, num_samples): -# super(MyDataset, self).__init__() -# self.num_samples = num_samples - -# def __getitem__(self, index): -# input = np.random.uniform(size=image_size).astype("float32") -# label = np.random.randint(0, class_num - 1, dtype="int64") -# return input, label - -# def __len__(self): -# return self.num_samples - class MLPLayer(nn.Layer): @@ -107,50 +93,33 @@ def train(fetch): epsilon=1e-08, grad_clip=None) - inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') - labels_spec = InputSpec([batch_size], 'int64', 'label') - - dist_strategy = fleet.DistributedStrategy() - dist_strategy.amp = False - dist_strategy.pipeline = False - dist_strategy.recompute = False - # init parallel optimizer - dist_strategy.semi_auto = True - dist_strategy.sharding = True - dist_strategy.sharding_configs = { - "sharding_degree": 2, - "stage": 3, - "enable_tuning": True, - } - fleet.init(is_collective=True, strategy=dist_strategy) - - # init engine - import tempfile - tmp_dir = tempfile.TemporaryDirectory() - dataset = MyDataset(batch_num * batch_size) - + dist_strategy = Strategy() + 
dist_strategy.auto_mode = "semi" + # sharding config + sharding = dist_strategy.sharding + sharding.enable = True + sharding.sharding_degree = 2 + sharding.stage = 3 + sharding.enable_tuning = True + sharding.tuning_range = [0, 1, 2, 3] # Tuning configuration - tuning_config = { - "batch_size": batch_size, - "dataset": dataset, - "profile_start_step": 1, - "profile_end_step": 5, - "run_after_tuning": True, - "sharding": { - "stage_range": [0, 1, 2, 3] - }, - "verbose": True, - } + tuning = dist_strategy.tuning + tuning.enable = True + tuning.profile_start_step = 1 + tuning.profile_end_step = 5 + tuning.run_after_tuning = True + tuning.verbose = True + + dataset = MyDataset(batch_num * batch_size) engine = Engine(mlp, - inputs_spec=inputs_spec, - labels_spec=labels_spec, - strategy=dist_strategy, - user_tuning_config=tuning_config) - engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) + loss, + optimizer, + paddle.metric.Accuracy(), + strategy=dist_strategy) + engine._tune(dataset, batch_size=batch_size) # check tuned - assert (engine._dist_contexts['train'].strategy.sharding_configs['stage'] != - 3) + assert (engine._dist_contexts['train'].strategy.sharding.stage != 3) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py new file mode 100644 index 0000000000000..462bf4f4a179c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto + +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed.auto_parallel.strategy import Strategy +from paddle.distributed.auto_parallel.engine import Engine +from get_gpt_model import generate_model, create_data_holder, FakeDataset + + +def apply_pass(use_recompute=False): + strategy = Strategy() + strategy.auto_mode = "semi" + if use_recompute: + recompute = strategy.recompute + recompute.enable = True + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestRecomputePass(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-6 + self.atol = 1e-8 + self.batch_size = 1 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2022) + np.random.seed(2022) + random.seed(2022) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_recompute=False): + reset_prog() + + strategy = apply_pass(use_recompute) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("mp") + + engine = Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=self.rtol, + atol=self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses)) + + def test_recompute_pass(self): + # mp2 training + mp_engine = self.get_engine() + mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + mp_losses = np.array(mp_losses) + + # mp2 recompute training + rc_engine = self.get_engine(True) + rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size) + rc_losses = np.array(rc_losses) + self.check_results(mp_losses, rc_losses) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py new file mode 100644 index 0000000000000..dbb5f2fc2c79a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto + +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed.auto_parallel.strategy import Strategy +from paddle.distributed.auto_parallel.engine import Engine +from get_gpt_model import generate_model, create_data_holder, FakeDataset + +paddle.enable_static() + + +def apply_pass(use_sharding=False, stage=None): + strategy = Strategy() + strategy.auto_mode = "semi" + if use_sharding: + sharding = strategy.sharding + sharding.enable = True + sharding.sharding_degree = 2 + sharding.stage = 1 + + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestShardingPass(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-6 + self.atol = 1e-8 + self.batch_size = 2 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2022) + np.random.seed(2022) + random.seed(2022) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_sharding=False, stage=None): + reset_prog() + + strategy = apply_pass(use_sharding, stage) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("dp") + + engine = Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=self.rtol, + atol=self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses)) + + def test_sharding_pass(self): + # dp2 training + dp_engine = self.get_engine() + dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + dp_losses = np.array(dp_losses) + + # sharding2 stage1 training + sharding1_engine = self.get_engine(True, 1) + sharding1_losses = sharding1_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + sharding1_losses = np.array(sharding1_losses) + self.check_results(dp_losses, sharding1_losses) + + # sharding2 stage2 training + sharding2_engine = self.get_engine(True, 2) + sharding2_losses = sharding2_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + sharding2_losses = np.array(sharding2_losses) + self.check_results(dp_losses, sharding2_losses) + + # sharding2 stage3 training + sharding3_engine = self.get_engine(True, 3) + sharding3_losses = sharding3_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + sharding3_losses = np.array(sharding3_losses) + self.check_results(dp_losses, sharding3_losses) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py index 2632ea96e01c9..bd26e90a2b825 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py @@ -78,9 +78,9 @@ def init_optimizer(self): def test_lr_scheduler(self): self.init_engine() - lr = self.engine.optimizer._learning_rate - assert isinstance(lr, paddle.optimizer.lr.LRScheduler) self.engine.fit(self.dataset, 
batch_size=self.batch_size) + lr = self.engine._lr_optimizer._learning_rate + assert isinstance(lr, paddle.optimizer.lr.LRScheduler) class TestGradClipByGlobalNorm(TestEngineBase): diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_amp.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_amp.py new file mode 100644 index 0000000000000..ed2cf0328e85c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_amp.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestAMPPass(unittest.TestCase): + + def test_mp2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "amp_pass_unittest.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir", + tmp_dir.name, launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_grad_clip.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_grad_clip.py similarity index 100% rename from python/paddle/fluid/tests/unittests/auto_parallel/test_grad_clip.py rename to python/paddle/fluid/tests/unittests/auto_parallel/test_pass_grad_clip.py diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_gradient_merge.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_gradient_merge.py new file mode 100644 index 0000000000000..e55ddbea58336 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_gradient_merge.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import tempfile
+import unittest
+import os
+import sys
+import shutil
+import subprocess
+from paddle.distributed.fleet.launch_utils import run_with_coverage
+
+
+class TestGradientMergePass(unittest.TestCase):
+
+    def test_dp2(self):
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        launch_model_path = os.path.join(file_dir,
+                                         "gradient_merge_pass_unittest.py")
+
+        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
+            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
+        else:
+            coverage_args = []
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        cmd = [sys.executable, "-u"] + coverage_args + [
+            "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
+            tmp_dir.name, launch_model_path
+        ]
+
+        process = subprocess.Popen(cmd)
+        process.wait()
+        self.assertEqual(process.returncode, 0)
+
+        tmp_dir.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_quantization.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_quantization.py
new file mode 100644
index 0000000000000..1ab751200be15
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_quantization.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto + +from paddle.distributed.auto_parallel.strategy import Strategy +from paddle.distributed.auto_parallel.engine import Engine +from get_gpt_model import generate_model, create_data_holder, FakeDataset + +paddle.enable_static() + + +def apply_pass(): + dist_strategy = Strategy() + dist_strategy.auto_mode = "semi" + qat = dist_strategy.qat + qat.enable = True + qat.channel_wise_abs_max = True + qat.weight_bits = 8 + qat.activation_bits = 8 + qat.not_quant_pattern = ['skip_quant'] + return dist_strategy + + +class TestQuantizationPass(unittest.TestCase): + + def test_qat_pass(self): + + batch_size = 8 + batch_num = 10 + + strategy = apply_pass() + model, loss = generate_model("serial") + opt = paddle.optimizer.AdamW(learning_rate=0.00001) + engine = Engine(model, loss, opt, strategy=strategy) + dataset = FakeDataset(batch_size * batch_num) + engine.fit(dataset, 3, batch_size=batch_size) + + self.check_program(engine.main_program) + + def check_program(self, program): + + quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']} + quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']} + + quantized_ops = set() + for block in program.blocks: + for op in block.ops: + is_quntized = False + if op.type in quantizable_op_and_inputs: + for arg_name in op.input_arg_names: + if ".quantized" in arg_name: + is_quntized = True + + if not is_quntized: + continue + + # check forward + if op.type in quantizable_op_and_inputs: + for arg_name in op.input_arg_names: + assert arg_name.endswith('.quantized.dequantized') + quantized_ops.add(arg_name) + + for op in block.ops: + is_quntized = False + if op.type in quantizable_grad_op_inputs: + for pname in quantizable_grad_op_inputs[op.type]: + arg_name = op.input(pname)[0] + if ".quantized" in arg_name: + is_quntized = True + + if not is_quntized: + continue + + # check backward + if op.type in quantizable_grad_op_inputs: + for pname in quantizable_grad_op_inputs[op.type]: + arg_name = op.input(pname)[0] + assert arg_name.endswith('.quantized.dequantized') + assert arg_name in quantized_ops + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_recompute.py new file mode 100644 index 0000000000000..e7eb7ddd2a604 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_recompute.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import tempfile
+import unittest
+import os
+import sys
+import shutil
+import subprocess
+from paddle.distributed.fleet.launch_utils import run_with_coverage
+
+
+class TestRecomputePass(unittest.TestCase):
+
+    def test_mp2(self):
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        launch_model_path = os.path.join(file_dir, "recompute_pass_unittest.py")
+
+        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
+            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
+        else:
+            coverage_args = []
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        cmd = [sys.executable, "-u"] + coverage_args + [
+            "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
+            tmp_dir.name, launch_model_path
+        ]
+
+        process = subprocess.Popen(cmd)
+        process.wait()
+        self.assertEqual(process.returncode, 0)
+
+        tmp_dir.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_sharding.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_sharding.py
new file mode 100644
index 0000000000000..77e969c83bf81
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_sharding.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+import os
+import sys
+import shutil
+import subprocess
+from paddle.distributed.fleet.launch_utils import run_with_coverage
+
+
+class TestShardingPass(unittest.TestCase):
+
+    def test_dp2sharding2(self):
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        launch_model_path = os.path.join(file_dir, "sharding_pass_unittest.py")
+
+        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
+            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
+        else:
+            coverage_args = []
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        cmd = [sys.executable, "-u"] + coverage_args + [
+            "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
+            tmp_dir.name, launch_model_path
+        ]
+
+        process = subprocess.Popen(cmd)
+        process.wait()
+        self.assertEqual(process.returncode, 0)
+
+        tmp_dir.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py
deleted file mode 100644
index f84ee03e0c940..0000000000000
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import sys -import numpy as np -import paddle - -import paddle.distributed.fleet as fleet -import paddle.distributed.auto_parallel as auto - -from paddle.distributed.auto_parallel.engine import Engine -from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr - -sys.path.append("..") -import auto_parallel_gpt_model as modeling -from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion - -paddle.enable_static() - - -class FakeDataset: - - def __init__(self, num_samples, sequence_len, vocab_size): - self.num_samples = num_samples - self.sequence_len = sequence_len - self.vocab_size = vocab_size - - def __getitem__(self, idx): - tokens = np.random.randint(self.vocab_size, size=self.sequence_len) - position_ids = np.arange(self.sequence_len) - attention_mask = np.tril(np.ones(self.sequence_len)).reshape( - (1, self.sequence_len, self.sequence_len)).astype(np.float32) - labels = np.random.randint(self.vocab_size, size=self.sequence_len) - loss_mask = np.ones(self.sequence_len).astype(np.float32) - return tokens, position_ids, attention_mask, labels, loss_mask - - def __len__(self): - return self.num_samples - - -def apply_pass(): - dist_strategy = fleet.DistributedStrategy() - dist_strategy.semi_auto = True - dist_strategy.qat = True - dist_strategy.qat_configs = { - 'channel_wise_abs_max': True, - 'weight_bits': 8, - 'activation_bits': 8, - 'not_quant_pattern': ['skip_quant'], - } - return dist_strategy - - -def create_data_holder(batch_size, sequence_len): - tokens = paddle.static.InputSpec(name="tokens", - shape=[batch_size, sequence_len], - dtype='int64') - position_ids = paddle.static.InputSpec(name="position_ids", - shape=[batch_size, sequence_len], - dtype='int64') - attention_mask = paddle.static.InputSpec( - name="attention_mask", - shape=[batch_size, 1, sequence_len, sequence_len], - dtype='float32') - labels = paddle.static.InputSpec(name="labels", - shape=[batch_size, sequence_len], - dtype='int64') - loss_mask = paddle.static.InputSpec(name="loss_mask", - shape=[batch_size, sequence_len], - dtype='float32') - return [tokens, position_ids, attention_mask], [labels, loss_mask] - - -def get_gpt_model(): - modeling.init_global() - modeling._global_parallel_strategy = "serial" - modeling._global_process_mesh = auto.ProcessMesh(mesh=[0]) - - gpt = GPTModel(vocab_size=1000, - hidden_size=64, - num_hidden_layers=2, - num_attention_heads=8, - intermediate_size=256, - hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - max_position_embeddings=1024, - type_vocab_size=1, - initializer_range=0.02, - pad_token_id=0, - eos_token_id=7, - bos_token_id=0, - eol_token_id=3) - model = GPTForPretraining(gpt, - vocab_size=1000, - hidden_size=64, - initializer_range=0.02) - criterion = GPTPretrainingCriterion() - return model, criterion - - -class TestQuantizationPass(unittest.TestCase): - - def test_qat_pass(self): - - batch_size = 8 - batch_num = 10 - sequence_len = 512 - vocab_size = 1000 - - strategy = apply_pass() - model, loss = get_gpt_model() - opt = paddle.optimizer.AdamW(learning_rate=0.00001) - inputs_spec, labels_spec = create_data_holder(batch_size=batch_size, - sequence_len=sequence_len) - - engine = Engine(model, inputs_spec, labels_spec, strategy=strategy) - engine.prepare(optimizer=opt, loss=loss) - - dataset = FakeDataset(batch_size * batch_num, sequence_len, vocab_size) - 
engine.fit(train_data=dataset, batch_size=batch_size) - - self.check_program(engine.main_program) - - def check_program(self, program): - - quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']} - quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']} - - quantized_ops = set() - for block in program.blocks: - for op in block.ops: - is_quntized = False - if op.type in quantizable_op_and_inputs: - for arg_name in op.input_arg_names: - if ".quantized" in arg_name: - is_quntized = True - - if not is_quntized: - continue - - # check forward - if op.type in quantizable_op_and_inputs: - for arg_name in op.input_arg_names: - assert arg_name.endswith('.quantized.dequantized') - quantized_ops.add(arg_name) - - for op in block.ops: - is_quntized = False - if op.type in quantizable_grad_op_inputs: - for pname in quantizable_grad_op_inputs[op.type]: - arg_name = op.input(pname)[0] - if ".quantized" in arg_name: - is_quntized = True - - if not is_quntized: - continue - - # check backward - if op.type in quantizable_grad_op_inputs: - for pname in quantizable_grad_op_inputs[op.type]: - arg_name = op.input(pname)[0] - assert arg_name.endswith('.quantized.dequantized') - assert arg_name in quantized_ops - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index 5434bbe6f90e1..030fd095946cc 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -24,47 +24,65 @@ def test_default_config(self): strategy = Strategy() recompute = strategy.recompute - self.assertEqual(recompute.enabled, False) + self.assertEqual(recompute.enable, False) self.assertEqual(recompute.checkpoints, None) amp = strategy.amp - self.assertEqual(amp.enabled, False) + self.assertEqual(amp.enable, False) self.assertAlmostEqual(amp.init_loss_scaling, 32768.0) self.assertEqual(amp.incr_every_n_steps, 1000) self.assertEqual(amp.decr_every_n_nan_or_inf, 2) self.assertAlmostEqual(amp.incr_ratio, 2.0) self.assertAlmostEqual(amp.decr_ratio, 0.8) self.assertEqual(amp.use_dynamic_loss_scaling, True) - self.assertEqual(amp.custom_black_list, None) - self.assertEqual(amp.custom_white_list, None) - self.assertEqual(amp.custom_black_varnames, None) + self.assertEqual(amp.custom_black_list, []) + self.assertEqual(amp.custom_white_list, []) + self.assertEqual(amp.custom_black_varnames, []) self.assertEqual(amp.use_pure_fp16, False) self.assertEqual(amp.use_fp16_guard, True) self.assertEqual(amp.use_optimizer_fp16, False) sharding = strategy.sharding - self.assertEqual(sharding.enabled, False) + self.assertEqual(sharding.enable, False) self.assertEqual(sharding.stage, 1) self.assertEqual(sharding.sharding_degree, 8) self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0) self.assertEqual(sharding.enable_tuning, False) + self.assertEqual(sharding.tuning_range, []) gradient_merge = strategy.gradient_merge - self.assertEqual(gradient_merge.enabled, False) + self.assertEqual(gradient_merge.enable, False) self.assertEqual(gradient_merge.k_steps, 1) self.assertEqual(gradient_merge.avg, True) + qat = strategy.qat + self.assertEqual(qat.enable, False) + self.assertEqual(qat.channel_wise_abs_max, True) + self.assertEqual(qat.weight_bits, 8) + self.assertEqual(qat.activation_bits, 8) + self.assertEqual(qat.not_quant_pattern, ['skip_quant']) + self.assertEqual(qat.algo, None) + + tuning = strategy.tuning + 
self.assertEqual(tuning.enable, False) + self.assertEqual(tuning.batch_size, 1) + self.assertEqual(tuning.dataset, None) + self.assertEqual(tuning.profile_start_step, 1) + self.assertEqual(tuning.profile_end_step, 1) + self.assertEqual(tuning.run_after_tuning, True) + self.assertEqual(tuning.verbose, True) + def test_modify_config(self): strategy = Strategy() recompute = strategy.recompute - recompute.enabled = True - recompute.checkpoinits = ["x"] - self.assertEqual(recompute.enabled, True) - self.assertEqual(recompute.checkpoinits, ["x"]) + recompute.enable = True + recompute.checkpoints = ["x"] + self.assertEqual(recompute.enable, True) + self.assertEqual(recompute.checkpoints, ["x"]) amp = strategy.amp - amp.enabled = True + amp.enable = True amp.init_loss_scaling = 16384.0 amp.incr_every_n_steps = 2000 amp.decr_every_n_nan_or_inf = 4 @@ -77,7 +95,7 @@ def test_modify_config(self): amp.use_pure_fp16 = True amp.use_fp16_guard = False amp.use_optimizer_fp16 = True - self.assertEqual(amp.enabled, True) + self.assertEqual(amp.enable, True) self.assertAlmostEqual(amp.init_loss_scaling, 16384.0) self.assertEqual(amp.incr_every_n_steps, 2000) self.assertEqual(amp.decr_every_n_nan_or_inf, 4) @@ -92,27 +110,30 @@ def test_modify_config(self): self.assertEqual(amp.use_optimizer_fp16, True) sharding = strategy.sharding - sharding.enabled = True + sharding.enable = True sharding.stage = 2 sharding.sharding_degree = 2 sharding.segment_broadcast_MB = 64.0 sharding.enable_tuning = True - self.assertEqual(sharding.enabled, True) + sharding.tuning_range = [1, 2, 3] + self.assertEqual(sharding.enable, True) self.assertEqual(sharding.stage, 2) self.assertEqual(sharding.sharding_degree, 2) self.assertAlmostEqual(sharding.segment_broadcast_MB, 64.0) self.assertEqual(sharding.enable_tuning, True) + self.assertEqual(sharding.tuning_range, [1, 2, 3]) gradient_merge = strategy.gradient_merge - gradient_merge.enabled = True + gradient_merge.enable = True gradient_merge.k_steps = 4 gradient_merge.avg = False - self.assertEqual(gradient_merge.enabled, True) + self.assertEqual(gradient_merge.enable, True) self.assertEqual(gradient_merge.k_steps, 4) self.assertEqual(gradient_merge.avg, False) def test_file_config(self): yaml_data = """ + all_ranks: false amp: custom_black_list: - y @@ -122,7 +143,7 @@ def test_file_config(self): - x decr_every_n_nan_or_inf: 4 decr_ratio: 0.4 - enabled: true + enable: false incr_every_n_steps: 2000 incr_ratio: 4.0 init_loss_scaling: 16384.0 @@ -133,17 +154,40 @@ def test_file_config(self): auto_mode: semi gradient_merge: avg: false - enabled: true + enable: false k_steps: 4 + gradient_scale: true + qat: + activation_bits: 8 + algo: null + channel_wise_abs_max: true + enable: false + not_quant_pattern: + - skip_quant + weight_bits: 8 recompute: checkpoints: null - enabled: true + enable: false + enable_tuning: false + return_numpy: true + seed: null sharding: + enable: false enable_tuning: true - enabled: true segment_broadcast_MB: 64.0 - sharding_degree: 2 + sharding_degree: 8 stage: 2 + tuning_range: None + split_data: false + tuning: + batch_size: 1 + dataset: null + enable: false + profile_end_step: 1 + profile_start_step: 1 + run_after_tuning: true + verbose: true + use_cache: true """ yaml_path = "./strategy.yml" yaml_dict = yaml.load(yaml_data, Loader=yaml.Loader) @@ -151,7 +195,7 @@ def test_file_config(self): yaml.dump(yaml_dict, outfile, default_flow_style=False) strategy = Strategy(yaml_path) - + print(strategy) self.assertEqual(yaml_dict, strategy.to_dict()) # Remove the 
created file diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py index 3dabe38ff6e1d..1c869813d319b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py @@ -36,7 +36,7 @@ epoch_num = 10 hidden_size = 1024 sequence_len = 512 -_g_process_mesh = [[0, 1], [2, 3]] +_g_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y']) def get_random_inputs_and_labels(input_shape, label_shape): @@ -84,18 +84,12 @@ def __init__(self, def forward(self, input): out = self.norm(input) - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.linear0.weight, _g_process_mesh[:, 0], + [None, 'x']) out = self.linear0(out) out = F.gelu(out, approximate=True) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _g_process_mesh[1], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.linear1.weight, _g_process_mesh[:, 1], + ['x', None]) out = self.linear1(out) return out @@ -155,16 +149,8 @@ def get_program(): dataloader.set_batch_generator(batch_generator_creator(), places=paddle.static.cuda_places()) # data dist_attr - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [-1, -1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(input, _g_process_mesh[:, 0], [None, None, None]) + auto.shard_tensor(label, _g_process_mesh[:, 0], [None, None, None]) mlp_start = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py index cf2d605035a4c..444e0df454d96 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py @@ -37,7 +37,7 @@ epoch_num = 10 hidden_size = 1024 sequence_len = 512 -_g_process_mesh = auto.ProcessMesh([0, 1]) +_g_process_mesh = auto.ProcessMesh([0, 1], dim_names=['x']) def get_random_inputs_and_labels(input_shape, label_shape): @@ -85,61 +85,21 @@ def __init__(self, def forward(self, input): - auto.shard_tensor(self.norm.weight, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) - auto.shard_tensor(self.norm.bias, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.linear0.bias, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [0] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(self.linear1.bias, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(self.norm.weight, _g_process_mesh, [None]) + auto.shard_tensor(self.norm.bias, _g_process_mesh, [None]) + auto.shard_tensor(self.linear0.weight, _g_process_mesh, [None, 'x']) + auto.shard_tensor(self.linear0.bias, _g_process_mesh, ['x']) + auto.shard_tensor(self.linear1.weight, _g_process_mesh, ['x', None]) + 
auto.shard_tensor(self.linear1.bias, _g_process_mesh, [None]) out = self.norm(input) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(out, _g_process_mesh, [None, None, None]) out = self.linear0(out) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, 0] - }) + auto.shard_tensor(out, _g_process_mesh, [None, None, 'x']) out = F.gelu(out, approximate=True) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, 0] - }) + auto.shard_tensor(out, _g_process_mesh, [None, None, 'x']) out = self.linear1(out) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(out, _g_process_mesh, [None, None, None]) return out @@ -155,21 +115,13 @@ def get_program(): # 循环计数器 i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) - auto.shard_tensor(i, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(i, _g_process_mesh, [None]) # 循环次数 loop_len = fluid.layers.fill_constant(shape=[1], dtype='int64', value=epoch_num) - auto.shard_tensor(loop_len, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(loop_len, _g_process_mesh, [None]) # input input = static.data(name="input", @@ -188,25 +140,13 @@ def get_program(): dataloader.set_batch_generator(batch_generator_creator(), places=paddle.static.cuda_places()) # data dist_attr - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(input, _g_process_mesh, [None, None, None]) + auto.shard_tensor(label, _g_process_mesh, [None, None, None]) # fill constant bsz like tmp = paddle.fluid.layers.fill_constant_batch_size_like( input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0) - auto.shard_tensor(tmp, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, 0, -1, -1] - }) + auto.shard_tensor(tmp, _g_process_mesh, [None, 'x', None, None]) # model mlp_start = MLPLayer(hidden_size=hidden_size, @@ -224,21 +164,13 @@ def get_program(): # }) cond = fluid.layers.less_than(x=i, y=loop_len) - auto.shard_tensor(cond, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(cond, _g_process_mesh, [None]) while_op = fluid.layers.While(cond=cond) with while_op.block(): pre_input = fluid.layers.array_read(array=input_array, i=i) - auto.shard_tensor(pre_input, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(pre_input, _g_process_mesh, [None, None, None]) mlp_while = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -252,11 +184,7 @@ def get_program(): fluid.layers.less_than(x=i, y=loop_len, cond=cond) end_pred = fluid.layers.array_read(array=input_array, i=i) - auto.shard_tensor(end_pred, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(end_pred, _g_process_mesh, [None, None, None]) mlp_end = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -265,18 +193,10 @@ def get_program(): pred = mlp_end(end_pred) error_cost = paddle.nn.functional.square_error_cost(pred, label) - auto.shard_tensor(error_cost, - dist_attr={ - "process_mesh": 
_g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(error_cost, _g_process_mesh, [None, None, None]) loss = paddle.mean(error_cost) - auto.shard_tensor(loss, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(loss, _g_process_mesh, [None]) return train_program, start_program, dataloader, i, loss diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py index 39862130e0b6d..e7f721dd422cf 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py @@ -726,6 +726,10 @@ def forward(self, masked_positions=None, use_cache=False, cache=None): + input_ids.stop_gradient = True + position_ids.stop_gradient = True + attention_mask.stop_gradient = True + outputs = self.gpt(input_ids, position_ids=position_ids, attention_mask=attention_mask, @@ -739,40 +743,42 @@ def forward(self, x = encoder_outputs w = self.gpt.embeddings.word_embeddings.weight - mesh = _global_process_mesh - x_dims_mapping = [-1 for i in range(len(x.shape))] - w_dims_mapping = [-1 for i in range(len(w.shape))] + mesh = None if _global_parallel_strategy == "pp": mesh = PP_MESH_LIST[-1] + x_dims_mapping = [None for i in range(len(x.shape))] + w_dims_mapping = [None for i in range(len(w.shape))] elif _global_parallel_strategy == "dp": - x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] + mesh = _global_process_mesh + x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)] + w_dims_mapping = [None for i in range(len(w.shape))] elif _global_parallel_strategy == "mp": - w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)] + mesh = _global_process_mesh + x_dims_mapping = [None for i in range(len(x.shape))] + w_dims_mapping = ["x"] + [None for i in range(len(w.shape) - 1)] elif _global_parallel_strategy == "dp_mp": - x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] - w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)] + mesh = _global_process_mesh + x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)] + w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)] elif _global_parallel_strategy == "dp_pp": mesh = DPPP_MESH_LIST[-1] - x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] + x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)] + w_dims_mapping = [None for i in range(len(w.shape))] elif _global_parallel_strategy == "mp_pp": mesh = MPPP_MESH_LIST[-1] - w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)] + x_dims_mapping = [None for i in range(len(x.shape))] + w_dims_mapping = ["x"] + [-1 for i in range(len(w.shape) - 1)] elif _global_parallel_strategy == "dp_mp_pp": mesh = DPMPPP_MESH_LIST[-1] - x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] - w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)] - - matmul = auto.shard_op(paddle.matmul, - dist_attr={ - 'process_mesh': mesh, - x: { - "dims_mapping": x_dims_mapping - }, - w: { - "dims_mapping": w_dims_mapping - } - }) - logits = matmul(x, w, transpose_y=True) + x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)] + w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)] + + if mesh: + matmul = auto.shard_op(paddle.matmul, mesh, + [x_dims_mapping, w_dims_mapping, None]) + logits = matmul(x, w, transpose_y=True) + else: + logits = paddle.matmul(x, w, transpose_y=True) if use_cache: return logits, cached_kvs @@ -791,25 
+797,29 @@ def __init__(self): self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") def forward(self, prediction_scores, masked_lm_labels, loss_mask): + masked_lm_labels.stop_gradient = True + loss_mask.stop_gradient = True - mesh = _global_process_mesh - dims_mapping = [-1 for i in range(len(loss_mask.shape))] + mesh = None if _global_parallel_strategy == "dp": - dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + mesh = _global_process_mesh + dims_mapping = ["x" + ] + [None for i in range(len(loss_mask.shape) - 1)] elif _global_parallel_strategy == "dp_mp": - dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + mesh = _global_process_mesh + dims_mapping = ["x" + ] + [None for i in range(len(loss_mask.shape) - 1)] elif _global_parallel_strategy == "dp_pp": mesh = DPPP_MESH_LIST[-1] - dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + dims_mapping = ["x" + ] + [None for i in range(len(loss_mask.shape) - 1)] elif _global_parallel_strategy == "dp_mp_pp": mesh = DPMPPP_MESH_LIST[-1] - dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + dims_mapping = ["x" + ] + [None for i in range(len(loss_mask.shape) - 1)] - auto.shard_tensor(loss_mask, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": dims_mapping - }) + if mesh: + auto.shard_tensor(loss_mask, mesh, dims_mapping) masked_lm_loss = self.loss_func(prediction_scores, masked_lm_labels.unsqueeze(2)) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py index f4a6c66274ec6..3091a927a8224 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py @@ -86,7 +86,7 @@ def _run_gpu_main(self, model, apply_pass, dump_file, **kwargs): paddle.static.Program()): with paddle.static.scope_guard(scope): with paddle.fluid.unique_name.guard(): - main_prog, startup_prog, inputs, outputs, reader = self.get_model( + main_prog, startup_prog, inputs, outputs, data_loader = self.get_model( place, **kwargs) inputs = self._to_var_names(inputs) outputs = self._to_var_names(outputs) @@ -95,20 +95,48 @@ def _run_gpu_main(self, model, apply_pass, dump_file, **kwargs): exe = paddle.static.Executor(place) with paddle.static.scope_guard(scope): exe.run(startup_prog) - for batch_id, input_data in enumerate(reader()): - assert len(input_data) == len(inputs), "{} vs {}".format( - len(input_data), len(inputs)) - feed = dict(zip(inputs, input_data)) - fetch_values = exe.run(main_prog, feed=feed, fetch_list=outputs) - if paddle.distributed.get_rank() == 0: - output_dict = OrderedDict(zip(outputs, fetch_values)) - print('batch {}, outputs {}'.format(batch_id, output_dict)) - all_fetch_values.append(fetch_values) + data_loader.start() + batch_id = 0 + while True: + try: + fetch_values = exe.run(main_prog, fetch_list=outputs) + if paddle.distributed.get_rank() == 0: + output_dict = OrderedDict(zip(outputs, fetch_values)) + print('batch {}, outputs {}'.format( + batch_id, output_dict)) + all_fetch_values.append(fetch_values) + batch_id += 1 + except paddle.fluid.core.EOFException: + data_loader.reset() + break with open(dump_file, "wb") as f: pickle.dump(all_fetch_values, f) def get_gpt_model(self, strategy, place, batch_size, sequence_len, vocab_size, **kwargs): + + def gen_data(): + np.random.seed(2021) + for _ in range(10): + tokens = [] 
+ position_ids = [] + attention_mask = [] + labels = [] + loss_mask = [] + for _ in range(batch_size): + tokens.append( + np.random.randint(vocab_size, + size=sequence_len).astype("int64")) + position_ids.append(np.arange(sequence_len).astype("int64")) + attention_mask.append( + [np.tril(np.ones(sequence_len)).astype("float32")]) + labels.append( + np.random.randint(vocab_size, + size=sequence_len).astype("int64")) + loss_mask.append(np.ones(sequence_len).astype("float32")) + + yield tokens, position_ids, attention_mask, labels, loss_mask + modeling.init_global() if strategy == "dp": modeling._global_parallel_strategy = "dp" @@ -139,6 +167,10 @@ def get_gpt_model(self, strategy, place, batch_size, sequence_len, dtype='float32') data_holder = [tokens, position_ids, attention_mask, labels, loss_mask] + data_loader = paddle.fluid.io.DataLoader.from_generator( + feed_list=data_holder, capacity=70, iterable=False) + data_loader.set_batch_generator(gen_data, paddle.static.cuda_places()) + if modeling._global_parallel_strategy == "dp": auto.shard_tensor(tokens, modeling._global_process_mesh, ["x", None]) @@ -170,40 +202,21 @@ def get_gpt_model(self, strategy, place, batch_size, sequence_len, preds = model(tokens, position_ids, attention_mask) criterion = GPTPretrainingCriterion() loss = criterion(preds, labels, loss_mask) - clip = paddle.nn.ClipGradByNorm(clip_norm=1.0) + clip = paddle.nn.ClipGradByNorm(clip_norm=1.0) if kwargs.get('optimizer', None) == "LarsMomentum": optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer( learning_rate=0.001, momentum=0.9) else: - optimizer = paddle.fluid.optimizer.AdamOptimizer( - learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=clip) + optimizer = paddle.optimizer.Adam(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=clip) optimizer = fleet.distributed_optimizer(optimizer) startup_program = paddle.static.default_startup_program() _, _, dist_startup_prog, dist_main_prog = optimizer.minimize( loss, startup_program) - def gen_data(): - np.random.seed(2021) - for _ in range(10): - tokens = [] - position_ids = [] - attention_mask = [] - labels = [] - loss_mask = [] - for _ in range(batch_size): - tokens.append( - np.random.randint(vocab_size, size=sequence_len)) - position_ids.append(np.arange(sequence_len)) - attention_mask.append([np.tril(np.ones(sequence_len))]) - labels.append( - np.random.randint(vocab_size, size=sequence_len)) - loss_mask.append(np.ones(sequence_len)) - - yield tokens, position_ids, attention_mask, labels, loss_mask - - return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data + return dist_main_prog, dist_startup_prog, data_holder, [loss + ], data_loader diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py index 5ac78cc5fec4d..4c20153ccbfd9 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py @@ -20,10 +20,19 @@ import paddle import paddle.distributed.fleet as fleet from auto_parallel_pass_test_base import AutoPallelPassTestBase -from test_auto_parallel_amp_pass import TestAMPPass -class TestPF16Pass(TestAMPPass): +class TestPF16Pass(AutoPallelPassTestBase): + + def init(self): + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + self.rtol = 1e-5 + 
self.atol = 1e-8 + + paddle.seed(2021) + random.seed(2021) + np.random.seed(2021) def apply_passes(self): dist_strategy = fleet.DistributedStrategy() @@ -34,14 +43,30 @@ def apply_passes(self): 'layer_norm', 'gelu', ], - "custom_black_list": ['c_softmax_with_cross_entropy'], - "init_loss_scaling": 32768, - "use_dynamic_loss_scaling": True, - "use_pure_fp16": True + "custom_black_list": + ['c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'], + "init_loss_scaling": + 32768, + "use_dynamic_loss_scaling": + True, + "use_pure_fp16": + True, + "use_fp16_guard": + False } dist_strategy.semi_auto = True fleet.init(is_collective=True, strategy=dist_strategy) + def test_bs_8(self): + self.check_main(gpus=[0, 1], + batch_size=8, + sequence_len=512, + vocab_size=1000) + + def get_model(self, place, batch_size, sequence_len, vocab_size): + return self.get_gpt_model("mp", place, batch_size, sequence_len, + vocab_size) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py index 7b9f587c906b0..8f45b67090e93 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py @@ -157,6 +157,12 @@ def test_result(self): def get_model(self, place, batch_size, hidden_size, max_step): + def gen_data(): + for i in range(max_step): + x_data = input_data[i * batch_size:(i + 1) * batch_size, :] + y_data = label_data[i * batch_size:(i + 1) * batch_size, :] + yield x_data, y_data + train_program = static.Program() startup_program = static.Program() with static.program_guard(train_program, startup_program), \ @@ -168,6 +174,12 @@ def get_model(self, place, batch_size, hidden_size, max_step): shape=[batch_size, 1], dtype='float32') input.stop_gradient = False + data_holder = [input, label] + data_loader = paddle.fluid.io.DataLoader.from_generator( + feed_list=data_holder, capacity=70, iterable=False) + data_loader.set_batch_generator(gen_data, + paddle.static.cuda_places()) + loss = mlp_forward(input, label, hidden_size) optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.01) @@ -178,13 +190,8 @@ def get_model(self, place, batch_size, hidden_size, max_step): input_data = np.random.random(size=(128, hidden_size)).astype('float32') label_data = np.random.random(size=(128, 1)).astype('float32') - def reader(): - for i in range(max_step): - x_data = input_data[i * batch_size:(i + 1) * batch_size, :] - y_data = label_data[i * batch_size:(i + 1) * batch_size, :] - yield x_data, y_data - - return dist_main_prog, dist_startup_prog, [input, label], [loss], reader + return dist_main_prog, dist_startup_prog, [input, + label], [loss], data_loader if __name__ == "__main__":