Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#7 from zhaoyinglia/new_api_valid
Browse files Browse the repository at this point in the history
[Auto Parallel] Add valid after training
  • Loading branch information
aoyulong committed Sep 14, 2022
2 parents 14bce35 + dd26467 commit 28d6af8
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 183 deletions.
3 changes: 3 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(BASE, "all_ranks", False)
set_field_default_config(BASE, "split_data", False)
set_field_default_config(BASE, "seed", None)
set_field_default_config(
BASE, "reinit", False
) # Only for debug, and must be set with seed at the same time when the value is True.

#########################################
# recompute configuration
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(self,
self._backup_serial_main_program_stack = []
self._backup_serial_startup_program_stack = []

# # flag whether scale gradient with dp size
# flag whether scale gradient with dp size
self._gradient_scale = True

# A flag indicates whether the used parallelism is data parallel
Expand Down
110 changes: 56 additions & 54 deletions python/paddle/distributed/auto_parallel/dist_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ def __init__(self):

def save(self, path, serial_program, dist_main_program, dist_context):

def _save_state(program, path, mode="param"):
state = {
k: np.array(v)
for k, v in program.state_dict(mode).items()
}
with open(path, "wb") as f:
pickle.dump(state, f)

dirname, filename = _process_path(path)

rank_id = paddle.distributed.get_rank()
Expand All @@ -76,82 +84,76 @@ def save(self, path, serial_program, dist_main_program, dist_context):
with open(dist_model_path, "wb") as f:
f.write(dist_main_program.desc.serialize_to_string())

# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
dist_param = {
k: np.array(v)
for k, v in dist_main_program.state_dict().items()
}
with open(dist_param_path, "wb") as f:
pickle.dump(dist_param, f)

# save distributed attribute
dist_attr_filename = filename + "_dist" + str(rank_id) + ".pdattr"
dist_attr_path = os.path.join(dirname, dist_attr_filename)
dist_attrs = get_dist_attr(dist_main_program, dist_context)
with open(dist_attr_path, "wb") as f:
pickle.dump(dist_attrs, f)

# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
_save_state(dist_main_program, dist_param_path)

# save distributed opt states
dist_opt_filename = filename + "_dist" + str(rank_id) + ".pdopt"
dist_opt_path = os.path.join(dirname, dist_opt_filename)
_save_state(dist_main_program, dist_opt_path, "opt")

# TODO:save cluster.json

def load(self,
path,
program,
dist_context,
strict=True,
load_optimizer=True):
def load(self, path, load_optimizer=True):
# TODO: if `program` is None, load `path.pdmodel`.
def _load_file(filename, dirname, suffix="pdparams"):
file_list = []
for file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
file):
file_list.append(os.path.join(dirname, file))
file_list.sort()
return file_list

def _load_state(filename, dirname, suffix="pdparams"):
file_list = _load_file(filename, dirname, suffix)
state_dict = {}
for file in file_list:
with open(file, 'rb') as f:
state_dict_info = pickle.load(f, encoding='latin1')
for name, value in state_dict_info.items():
if name in state_dict:
state_dict[name].append(np.array(value))
else:
state_dict[name] = [np.array(value)]
self._logger.info("Load param file: {}".format(file_list))
return state_dict

filename = os.path.basename(path)
if filename == "":
raise ValueError(
"path should be of 'dirname/filename' format, but received filename is empty string"
)
dirname = os.path.dirname(path)
# load path.pdparam
param_file_list = []
for param_file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).pdparams'.format(filename),
param_file):
param_file_list.append(os.path.join(dirname, param_file))
param_file_list.sort()
self._logger.info(
"Load distributed attribute file: {}".format(param_file_list))
param_dict = {}
for param_file in param_file_list:
with open(param_file, 'rb') as f:
state_dict_info = pickle.load(f, encoding='latin1')
for name, value in state_dict_info.items():
if name in param_dict:
param_dict[name].append(np.array(value))
else:
param_dict[name] = [np.array(value)]

# load path.pdparam and path.pdopt
param_state_dict = _load_state(filename, dirname)
opt_state_dict = _load_state(filename, dirname,
"pdopt") if load_optimizer else {}
state_dict = dict(param_state_dict, **opt_state_dict)

# load path.pdattr
dist_attr_file_list = []
for dist_attr_file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).pdattr'.format(filename),
dist_attr_file):
dist_attr_file_list.append(os.path.join(dirname,
dist_attr_file))
dist_attr_file_list.sort()
dist_attr_file_list = _load_file(filename, dirname, "pdattr")
self._logger.info(
"Load distributed attribute file: {}".format(dist_attr_file_list))
pre_dist_attr = {}
dist_attr = {}
for dist_attr_file in dist_attr_file_list:
with open(dist_attr_file, 'rb') as f:
dist_attr = pickle.load(f, encoding='latin1')
for name, attr in dist_attr.items():
if name not in pre_dist_attr:
pre_dist_attr[name] = attr

# get current dist_attr
cur_dist_attr = get_dist_attr(program, dist_context)

# param convert
converter = Converter(param_dict, pre_dist_attr, cur_dist_attr)
param_dict = converter.convert(strict=strict)
program.set_state_dict(param_dict)
dist_attr_info = pickle.load(f, encoding='latin1')
for name, attr in dist_attr_info.items():
if name not in dist_attr:
dist_attr[name] = attr

return state_dict, dist_attr

def save_inference_model(self, path, feed_vars, fetch_vars, exe, **kwargs):

Expand Down
Loading

0 comments on commit 28d6af8

Please sign in to comment.