Skip to content

Commit

Permalink
record parameter names in each param group
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#4954

For each parameter group, also record the parameter names. Users of `reduce_param_groups()` can use them such as printing out parameter names as in D2GO in D45855436

Differential Revision: D45855434

fbshipit-source-id: 676113f6b1e44cb7d75b57ecbc0629e9a0d4ac45
  • Loading branch information
stephenyan1231 authored and facebook-github-bot committed May 19, 2023
1 parent 2c6c380 commit b4ddb71
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions detectron2/solver/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,13 @@ def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
ret = defaultdict(dict)
for item in params:
assert "params" in item
cur_params = {x: y for x, y in item.items() if x != "params"}
for param in item["params"]:
ret[param].update({"params": [param], **cur_params})
cur_params = {x: y for x, y in item.items() if x != "params" and x != "param_names"}
if "param_names" in item:
for param_name, param in zip(item["param_names"], item["params"]):
ret[param].update({"param_names": [param_name], "params": [param], **cur_params})
else:
for param in item["params"]:
ret[param].update({"params": [param], **cur_params})
return list(ret.values())


Expand All @@ -257,12 +261,21 @@ def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
params = _expand_param_groups(params)
groups = defaultdict(list) # re-group all parameter groups by their hyperparams
for item in params:
cur_params = tuple((x, y) for x, y in item.items() if x != "params")
groups[cur_params].extend(item["params"])
cur_params = tuple((x, y) for x, y in item.items() if x != "params" and x != "param_names")
groups[cur_params].append({"params": item["params"]})
if "param_names" in item:
groups[cur_params][-1]["param_names"] = item["param_names"]

ret = []
for param_keys, param_values in groups.items():
cur = {kv[0]: kv[1] for kv in param_keys}
cur["params"] = param_values
cur["params"] = list(
itertools.chain.from_iterable([params["params"] for params in param_values])
)
if len(param_values) > 0 and "param_names" in param_values[0]:
cur["param_names"] = list(
itertools.chain.from_iterable([params["param_names"] for params in param_values])
)
ret.append(cur)
return ret

Expand Down

0 comments on commit b4ddb71

Please sign in to comment.