diff --git a/eval.py b/eval.py index 7c3fef1..c9e67e3 100644 --- a/eval.py +++ b/eval.py @@ -48,7 +48,7 @@ def get_args(): torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) -model = models.get_model(args.model, n_cls, half_prec, data.shapes_dict[args.dataset], args.n_filters_cnn, args.n_hidden_fc) +model = models.get_model(args.model, n_cls, half_prec, data.shapes_dict[args.dataset], args.n_filters_cnn) model = model.cuda() model_dict = torch.load('models/{}.pth'.format(args.model_path)) if args.early_stopped_model: