diff --git a/models/dist_model.py b/models/dist_model.py index e92d3f9..abe9696 100755 --- a/models/dist_model.py +++ b/models/dist_model.py @@ -56,7 +56,10 @@ class DistModel(BaseModel): if not use_gpu: kw['map_location'] = 'cpu' if(model_path is None): - model_path = './weights/%s.pth'%net + #model_path = './weights/%s.pth'%net + import inspect + model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), \ + '..', '..', 'weights/%s.pth' % net)) self.net.load_state_dict(torch.load(model_path, **kw)) elif(self.model=='net'): # pretrained network