54 changes: 19 additions & 35 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
@@ -12,19 +12,13 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
     PM = None
 
 
-def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     for video, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
-        output = model(video)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(video)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
             loss.backward()
-        optimizer.step()
+            optimizer.step()
 
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = video.shape[0]
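For context, the rewritten loop is the standard native-AMP recipe: run the forward pass under `torch.cuda.amp.autocast`, scale the loss before `backward()` so fp16 gradients do not underflow, and let `GradScaler.step()`/`update()` unscale the gradients and skip steps that contain infs/NaNs. Below is a minimal self-contained sketch of that pattern; the toy model, data, and hyperparameters are illustrative stand-ins, not part of this diff.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # autocast/GradScaler only take effect on CUDA

# Placeholder model and data, standing in for the video classifier and batch.
model = torch.nn.Linear(16, 4).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler() if use_amp else None

video = torch.randn(8, 16, device=device)
target = torch.randint(0, 4, (8,), device=device)

# Forward pass runs in mixed precision only when a scaler was created.
with torch.cuda.amp.autocast(enabled=scaler is not None):
    output = model(video)
    loss = criterion(output, target)

optimizer.zero_grad()
if scaler is not None:
    scaler.scale(loss).backward()  # scale up to avoid fp16 underflow
    scaler.step(optimizer)         # unscales grads, skips step on inf/NaN
    scaler.update()                # adjusts the scale for the next iteration
else:
    loss.backward()
    optimizer.step()
```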
@@ -101,11 +98,6 @@ def collate_fn(batch):
 def main(args):
     if args.weights and PM is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -224,9 +216,7 @@ def main(args):
 
     lr = args.lr * args.world_size
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
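A note on the design: `GradScaler` also accepts an `enabled` flag, so an equivalent formulation could construct the scaler unconditionally as `torch.cuda.amp.GradScaler(enabled=args.amp)`; when disabled, `scale()`, `step()`, and `update()` become pass-throughs and the `scaler is not None` branches disappear. A sketch of that alternative, reusing the placeholder names from the example above (this is not what the PR does):

```python
# Hypothetical alternative to the `if args.amp else None` construction above.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

with torch.cuda.amp.autocast(enabled=use_amp):
    output = model(video)
    loss = criterion(output, target)

optimizer.zero_grad()
scaler.scale(loss).backward()  # returns the loss unmodified when disabled
scaler.step(optimizer)         # plain optimizer.step() when disabled
scaler.update()                # no-op when disabled
```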
@@ -267,6 +257,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -277,9 +269,7 @@
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
-        )
+        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
             checkpoint = {
@@ -289,6 +279,8 @@
                 "epoch": epoch,
                 "args": args,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
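This hunk pairs with the resume logic added above: persisting `scaler.state_dict()` keeps the dynamic loss scale and its growth bookkeeping across restarts, so a resumed run does not re-warm the scale from its default. A minimal sketch of the round trip, with the file name and bare-bones checkpoint dict chosen purely for illustration:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

# Save: store the scaler state next to model/optimizer/lr_scheduler state.
torch.save({"scaler": scaler.state_dict()}, "checkpoint.pth")

# Resume: a fresh scaler would restart from the default scale; loading the
# saved state restores the scale factor and inf/NaN tracking counters.
resumed_scaler = torch.cuda.amp.GradScaler()
resumed_scaler.load_state_dict(torch.load("checkpoint.pth")["scaler"])
```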
@@ -363,24 +355,16 @@ def parse_args():
         action="store_true",
     )
 
-    # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
-
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     args = parser.parse_args()
 
     return args
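With this change, mixed-precision runs are enabled by the single `--amp` flag in place of `--apex` and `--apex-opt-level`, e.g. `python train.py --amp` alongside the usual dataset and model arguments (invocation shown for illustration only).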