Skip to content

Commit

Permalink
Make Optimizer.load_state_dict use __setstate__
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke committed Feb 26, 2017
1 parent 1f6f82d commit bd7a5ad
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions torch/optim/optimizer.py
Expand Up @@ -66,6 +66,9 @@ def __getstate__(self):
'param_groups': self.param_groups,
}

def __setstate__(self, state):
    # Restore optimizer state produced by __getstate__ (a dict holding
    # 'state' and 'param_groups'); merging into __dict__ keeps any other
    # instance attributes intact rather than replacing them wholesale.
    """Sets the optimizer state.

    Args:
        state (dict): optimizer state as returned by :meth:`__getstate__`,
            typically containing the ``state`` and ``param_groups`` entries.
    """
    self.__dict__.update(state)

def state_dict(self):
"""Returns the state of the optimizer as a :class:`dict`.
Expand Down Expand Up @@ -115,14 +118,15 @@ def load_state_dict(self, state_dict):
id_map = {old_id: p for old_id, p in
zip(chain(*(g['params'] for g in saved_groups)),
chain(*(g['params'] for g in groups)))}
self.state = {id_map.get(k, k): v for k, v in state_dict['state'].items()}
state = {id_map.get(k, k): v for k, v in state_dict['state'].items()}

# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
self.param_groups = [
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})

def zero_grad(self):
"""Clears the gradients of all optimized :class:`Variable` s."""
Expand Down

0 comments on commit bd7a5ad

Please sign in to comment.