references/classification/train.py (43 changes: 15 additions & 28 deletions)
@@ -12,13 +12,10 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False, model_ema=None):
+def train_one_epoch(
+    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
+):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -29,13 +26,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
-        loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+        if amp:
+            with torch.cuda.amp.autocast():
+                loss = criterion(output, target)
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
+            loss = criterion(output, target)
             loss.backward()
-        optimizer.step()
+            optimizer.step()
 
@@ -156,12 +156,6 @@ def load_data(traindir, valdir, args):
 
 
 def main(args):
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
-
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -228,8 +222,7 @@ def main(args):
     else:
         raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
 
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "steplr":
@@ -292,7 +285,9 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
+        train_one_epoch(
+            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
+        )
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
@@ -385,15 +380,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
     # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
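For readers unfamiliar with the native AMP API this diff switches to, the torch.cuda.amp recipe it wires up (autocast around the loss computation plus a GradScaler driving the backward/step) reduces to the minimal sketch below. The model, optimizer, and dummy batch are illustrative placeholders, not code from train.py:

    import torch

    # Minimal torch.cuda.amp sketch (assumes a CUDA device is available; names are placeholders).
    device = "cuda"
    model = torch.nn.Linear(10, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler()  # created once, like in main()

    images = torch.randn(4, 10, device=device)            # dummy batch
    targets = torch.randint(0, 2, (4,), device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                        # run in mixed precision
        output = model(images)
        loss = criterion(output, targets)
    scaler.scale(loss).backward()                          # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                                 # unscales gradients, then calls optimizer.step()
    scaler.update()                                        # adjusts the scale factor for the next iteration

Note that the hunk in train_one_epoch wraps only the loss computation in autocast, while the sketch above also wraps the forward pass, which is the pattern shown in the PyTorch AMP documentation.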