Skip to content

Commit

Permalink
Move scheduler step position after optimizer step following pytorch c…
Browse files Browse the repository at this point in the history
…onvention. And other minor formatting issue.
  • Loading branch information
zhmiao committed Feb 11, 2020
1 parent d904d35 commit 01e52ed
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
22 changes: 11 additions & 11 deletions run_networks.py
Expand Up @@ -14,7 +14,7 @@
class model ():

def __init__(self, config, data, test=False):

self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.config = config
self.training_opt = self.config['training_opt']
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(self, config, data, test=False):
os.remove(self.log_file)

def init_models(self, optimizer=True):

networks_defs = self.config['networks']
self.networks = {}
self.model_optim_params_list = []
Expand All @@ -79,9 +79,9 @@ def init_models(self, optimizer=True):
# Optimizer list
optim_params = val['optim_params']
self.model_optim_params_list.append({'params': self.networks[key].parameters(),
'lr': optim_params['lr'],
'momentum': optim_params['momentum'],
'weight_decay': optim_params['weight_decay']})
'lr': optim_params['lr'],
'momentum': optim_params['momentum'],
'weight_decay': optim_params['weight_decay']})

def init_criterions(self):

Expand Down Expand Up @@ -188,12 +188,6 @@ def train(self):

torch.cuda.empty_cache()

# Set model modes and set scheduler
# In training, step optimizer scheduler and set model to train()
self.model_optimizer_scheduler.step()
if self.criterion_optimizer:
self.criterion_optimizer_scheduler.step()

# Iterate over dataset
for step, (inputs, labels, _) in enumerate(self.data['train']):

Expand Down Expand Up @@ -234,6 +228,12 @@ def train(self):
% (minibatch_acc)]
print_write(print_str, self.log_file)

# Set model modes and set scheduler
# In training, step optimizer scheduler and set model to train()
self.model_optimizer_scheduler.step()
if self.criterion_optimizer:
self.criterion_optimizer_scheduler.step()

# After every epoch, validation
self.eval(phase='val')

Expand Down
2 changes: 1 addition & 1 deletion utils.py
Expand Up @@ -33,7 +33,7 @@ def init_weights(model, weights_path, caffe=False, classifier=False):
"""Initialize weights"""
print('Pretrained %s weights path: %s' % ('classifier' if classifier else 'feature model',
weights_path))
weights = torch.load(weights_path)
weights = torch.load(weights_path)
if not classifier:
if caffe:
weights = {k: weights[k] if k in weights else model.state_dict()[k]
Expand Down

0 comments on commit 01e52ed

Please sign in to comment.