diff --git a/CUB/train.py b/CUB/train.py index 1f88b05..843e1b8 100644 --- a/CUB/train.py +++ b/CUB/train.py @@ -163,7 +163,7 @@ def train(model, args): attr_criterion = None if args.optimizer == 'Adam': - optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=arg.lr, weight_decay=args.weight_decay) + optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay) elif args.optimizer == 'RMSprop': optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) else: