references/segmentation/train.py: 27 changes (21 additions, 6 deletions)
@@ -72,19 +72,25 @@ def evaluate(model, data_loader, device, num_classes):
     return confmat


-def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
+def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     header = f"Epoch: [{epoch}]"
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         image, target = image.to(device), target.to(device)
-        output = model(image)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(image)
+            loss = criterion(output, target)

         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()

         lr_scheduler.step()

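For context, the new code follows the standard torch.cuda.amp recipe: the forward pass and the loss are computed under autocast, and GradScaler scales the loss before backward so float16 gradients do not underflow, unscales them for the optimizer step, and then updates its scale factor. A self-contained sketch of that pattern, with a toy model and random data standing in for the real training objects (nothing below comes from train.py itself):

```python
import torch
from torch import nn

# Standalone sketch of the AMP pattern used in train_one_epoch above.
# The tiny model and the random data are placeholders for illustration only.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # plays the role of the new --amp flag

model = nn.Conv2d(3, 21, kernel_size=1).to(device)  # stand-in for the segmentation model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scaler = torch.cuda.amp.GradScaler() if use_amp else None

image = torch.randn(2, 3, 64, 64, device=device)
target = torch.randint(0, 21, (2, 64, 64), device=device)

# Forward pass and loss run in mixed precision only when AMP is enabled.
with torch.cuda.amp.autocast(enabled=scaler is not None):
    output = model(image)
    loss = criterion(output, target)

optimizer.zero_grad()
if scaler is not None:
    # Scale the loss so small float16 gradients do not underflow, let the
    # scaler unscale them before the optimizer step, then update the scale.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    loss.backward()
    optimizer.step()
```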
@@ -153,6 +159,8 @@ def main(args):
         params_to_optimize.append({"params": params, "lr": args.lr * 10})
     optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     iters_per_epoch = len(data_loader)
     main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
@@ -186,6 +194,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])

     if args.test_only:
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
@@ -196,7 +206,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
         print(confmat)
         checkpoint = {
@@ -206,6 +216,8 @@ def main(args):
             "epoch": epoch,
             "args": args,
         }
+        if args.amp:
+            checkpoint["scaler"] = scaler.state_dict()
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

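Because GradScaler is stateful (it tracks the running loss scale and when it last had to back off), the diff also stores scaler.state_dict() in the checkpoint here and restores it in the resume hunk above, so an interrupted --amp run continues with the same loss-scale schedule. A minimal sketch of that round trip; the surrounding training loop is omitted, and GradScaler assumes a CUDA-capable machine (it disables itself otherwise):

```python
import torch

# Sketch of saving and restoring GradScaler state, mirroring the "scaler"
# checkpoint entry added in this diff.
scaler = torch.cuda.amp.GradScaler()
# ... training steps would update the scaler's internal loss scale here ...

checkpoint = {"scaler": scaler.state_dict()}
torch.save(checkpoint, "checkpoint.pth")

# On resume: build a fresh scaler first, then load the saved state so the
# loss-scale schedule picks up where the previous run stopped.
restored = torch.load("checkpoint.pth", map_location="cpu")
scaler = torch.cuda.amp.GradScaler()
scaler.load_state_dict(restored["scaler"])
```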
@@ -269,6 +281,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser


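With the new flag in place, mixed precision is opt-in from the command line: passing --amp to the reference training script creates the GradScaler and sends the training loop through the autocast path above, while leaving the flag off keeps the previous full-precision behaviour unchanged.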