references/detection/engine.py (18 changes: 11 additions & 7 deletions)

@@ -9,7 +9,7 @@
 from coco_utils import get_coco_api_from_dataset


-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
@@ -27,10 +27,9 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-        loss_dict = model(images, targets)
-
-        losses = sum(loss for loss in loss_dict.values())
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())

         # reduce losses over all GPUs for logging purposes
         loss_dict_reduced = utils.reduce_dict(loss_dict)
@@ -44,8 +43,13 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
             sys.exit(1)

         optimizer.zero_grad()
-        losses.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()

         if lr_scheduler is not None:
             lr_scheduler.step()
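
For reference, the pattern the new scaler argument enables is standard torch.cuda.amp usage: run the forward pass and loss under autocast, and route backward/step through a GradScaler. Below is a minimal, self-contained sketch with a toy model and random data (all names are illustrative, not taken from this PR); it uses GradScaler(enabled=...) instead of the explicit None check above, which has the same effect:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = torch.cuda.is_available()  # stands in for the new --amp flag

model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(10):
    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 1, device=device)

    optimizer.zero_grad()
    # Forward pass and loss run in reduced precision when AMP is enabled.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)

    # GradScaler scales the loss so small fp16 gradients do not underflow,
    # unscales before the optimizer update, and is a no-op when disabled.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
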
references/detection/train.py (11 changes: 10 additions & 1 deletion)

@@ -144,6 +144,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser


@@ -209,6 +212,8 @@ def main(args):
     params = [p for p in model.parameters() if p.requires_grad]
     optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "multisteplr":
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
@@ -225,6 +230,8 @@
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])

     if args.test_only:
         evaluate(model, data_loader_test, device=device)
@@ -235,7 +242,7 @@
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
         lr_scheduler.step()
         if args.output_dir:
             checkpoint = {
@@ -245,6 +252,8 @@
                 "args": args,
                 "epoch": epoch,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
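
With these changes, mixed precision becomes opt-in from the command line. A hypothetical single-GPU invocation (the dataset path is a placeholder, and --data-path and --model are assumed to be the reference script's existing options; only --amp is introduced by this diff):

python train.py --data-path /path/to/coco --model fasterrcnn_resnet50_fpn --amp

When --amp is omitted, scaler stays None, so train_one_epoch falls back to the unchanged full-precision path; when it is set, the GradScaler state is saved with each checkpoint and restored when resuming from one.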