diff --git a/imagenet/main.py b/imagenet/main.py index 86abea3739..f0f196a1e5 100644 --- a/imagenet/main.py +++ b/imagenet/main.py @@ -112,7 +112,7 @@ def main(): batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) - # define loss function (criterion) and pptimizer + # define loss function (criterion) and optimizer criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD(model.parameters(), args.lr,