diff --git a/run_networks.py b/run_networks.py index 3510783..2158672 100644 --- a/run_networks.py +++ b/run_networks.py @@ -14,7 +14,7 @@ class model (): def __init__(self, config, data, test=False): - + self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') self.config = config self.training_opt = self.config['training_opt'] @@ -52,7 +52,7 @@ def __init__(self, config, data, test=False): os.remove(self.log_file) def init_models(self, optimizer=True): - + networks_defs = self.config['networks'] self.networks = {} self.model_optim_params_list = [] @@ -79,9 +79,9 @@ def init_models(self, optimizer=True): # Optimizer list optim_params = val['optim_params'] self.model_optim_params_list.append({'params': self.networks[key].parameters(), - 'lr': optim_params['lr'], - 'momentum': optim_params['momentum'], - 'weight_decay': optim_params['weight_decay']}) + 'lr': optim_params['lr'], + 'momentum': optim_params['momentum'], + 'weight_decay': optim_params['weight_decay']}) def init_criterions(self): @@ -188,12 +188,6 @@ def train(self): torch.cuda.empty_cache() - # Set model modes and set scheduler - # In training, step optimizer scheduler and set model to train() - self.model_optimizer_scheduler.step() - if self.criterion_optimizer: - self.criterion_optimizer_scheduler.step() - # Iterate over dataset for step, (inputs, labels, _) in enumerate(self.data['train']): @@ -234,6 +228,12 @@ def train(self): % (minibatch_acc)] print_write(print_str, self.log_file) + # Set model modes and set scheduler + # In training, step optimizer scheduler and set model to train() + self.model_optimizer_scheduler.step() + if self.criterion_optimizer: + self.criterion_optimizer_scheduler.step() + # After every epoch, validation self.eval(phase='val') diff --git a/utils.py b/utils.py index 8cc9d60..613a0aa 100644 --- a/utils.py +++ b/utils.py @@ -33,7 +33,7 @@ def init_weights(model, weights_path, caffe=False, classifier=False): """Initialize weights""" print('Pretrained %s weights path: %s' % ('classifier' if classifier else 'feature model', weights_path)) - weights = torch.load(weights_path) + weights = torch.load(weights_path) if not classifier: if caffe: weights = {k: weights[k] if k in weights else model.state_dict()[k]