From 3fd24e37235a9cc14bab3c669ee92c0074a26949 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 25 Jul 2019 08:03:45 -0700 Subject: [PATCH 1/9] Copy classification scripts for video classification --- references/video_classification/train.py | 299 +++++++++++++++++++++++ references/video_classification/utils.py | 255 +++++++++++++++++++ 2 files changed, 554 insertions(+) create mode 100644 references/video_classification/train.py create mode 100644 references/video_classification/utils.py diff --git a/references/video_classification/train.py b/references/video_classification/train.py new file mode 100644 index 00000000000..c26e4ae290f --- /dev/null +++ b/references/video_classification/train.py @@ -0,0 +1,299 @@ +from __future__ import print_function +import datetime +import os +import time +import sys + +import torch +import torch.utils.data +from torch import nn +import torchvision +from torchvision import transforms + +import utils + +try: + from apex import amp +except ImportError: + amp = None + + +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False): + model.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) + metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}')) + + header = 'Epoch: [{}]'.format(epoch) + for image, target in metric_logger.log_every(data_loader, print_freq, header): + start_time = time.time() + image, target = image.to(device), target.to(device) + output = model(image) + loss = criterion(output, target) + + optimizer.zero_grad() + if apex: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + optimizer.step() + + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) + batch_size = image.shape[0] + metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time)) + + +def evaluate(model, criterion, data_loader, device): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + with torch.no_grad(): + for image, target in metric_logger.log_every(data_loader, 100, header): + image = image.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + output = model(image) + loss = criterion(output, target) + + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) + # FIXME need to take into account that the datasets + # could have been padded in distributed setup + batch_size = image.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + + print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) + return metric_logger.acc1.global_avg + + +def _get_cache_path(filepath): + import hashlib + h = hashlib.sha1(filepath.encode()).hexdigest() + cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") + cache_path = os.path.expanduser(cache_path) + return cache_path + + +def main(args): + if args.apex: + if sys.version_info < (3, 0): + raise RuntimeError("Apex 
currently only supports Python 3. Aborting.") + if amp is None: + raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " + "to enable mixed-precision training.") + + if args.output_dir: + utils.mkdir(args.output_dir) + + utils.init_distributed_mode(args) + print(args) + + device = torch.device(args.device) + + torch.backends.cudnn.benchmark = True + + # Data loading code + print("Loading data") + traindir = os.path.join(args.data_path, 'train') + valdir = os.path.join(args.data_path, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + print("Loading training data") + st = time.time() + cache_path = _get_cache_path(traindir) + if args.cache_dataset and os.path.exists(cache_path): + # Attention, as the transforms are also cached! + print("Loading dataset_train from {}".format(cache_path)) + dataset, _ = torch.load(cache_path) + else: + dataset = torchvision.datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + if args.cache_dataset: + print("Saving dataset_train to {}".format(cache_path)) + utils.mkdir(os.path.dirname(cache_path)) + utils.save_on_master((dataset, traindir), cache_path) + print("Took", time.time() - st) + + print("Loading validation data") + cache_path = _get_cache_path(valdir) + if args.cache_dataset and os.path.exists(cache_path): + # Attention, as the transforms are also cached! + print("Loading dataset_test from {}".format(cache_path)) + dataset_test, _ = torch.load(cache_path) + else: + dataset_test = torchvision.datasets.ImageFolder( + valdir, + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + if args.cache_dataset: + print("Saving dataset_test to {}".format(cache_path)) + utils.mkdir(os.path.dirname(cache_path)) + utils.save_on_master((dataset_test, valdir), cache_path) + + print("Creating data loaders") + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) + else: + train_sampler = torch.utils.data.RandomSampler(dataset) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, + sampler=train_sampler, num_workers=args.workers, pin_memory=True) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, batch_size=args.batch_size, + sampler=test_sampler, num_workers=args.workers, pin_memory=True) + + print("Creating model") + model = torchvision.models.__dict__[args.model](pretrained=args.pretrained) + model.to(device) + if args.distributed and args.sync_bn: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + + criterion = nn.CrossEntropyLoss() + + optimizer = torch.optim.SGD( + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + + if args.apex: + model, optimizer = amp.initialize(model, optimizer, + opt_level=args.apex_opt_level + ) + + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu') + 
model_without_ddp.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 + + if args.test_only: + evaluate(model, criterion, data_loader_test, device=device) + return + + print("Start training") + start_time = time.time() + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex) + lr_scheduler.step() + evaluate(model, criterion, data_loader_test, device=device) + if args.output_dir: + checkpoint = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'args': args} + utils.save_on_master( + checkpoint, + os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + utils.save_on_master( + checkpoint, + os.path.join(args.output_dir, 'checkpoint.pth')) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='PyTorch Classification Training') + + parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset') + parser.add_argument('--model', default='resnet18', help='model') + parser.add_argument('--device', default='cuda', help='device') + parser.add_argument('-b', '--batch-size', default=32, type=int) + parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') + parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', + help='number of data loading workers (default: 16)') + parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') + parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') + parser.add_argument('--print-freq', default=10, type=int, help='print frequency') + parser.add_argument('--output-dir', default='.', help='path where to save') + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument( + "--cache-dataset", + dest="cache_dataset", + help="Cache the datasets for quicker initialization. 
It also serializes the transforms", + action="store_true", + ) + parser.add_argument( + "--sync-bn", + dest="sync_bn", + help="Use sync batch norm", + action="store_true", + ) + parser.add_argument( + "--test-only", + dest="test_only", + help="Only test the model", + action="store_true", + ) + parser.add_argument( + "--pretrained", + dest="pretrained", + help="Use pre-trained models from the modelzoo", + action="store_true", + ) + + # Mixed precision training parameters + parser.add_argument('--apex', action='store_true', + help='Use apex for mixed precision training') + parser.add_argument('--apex-opt-level', default='O1', type=str, + help='For apex mixed precision training' + 'O0 for FP32 training, O1 for mixed precision training.' + 'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet' + ) + + # distributed training parameters + parser.add_argument('--world-size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/video_classification/utils.py b/references/video_classification/utils.py new file mode 100644 index 00000000000..5ea6dfef341 --- /dev/null +++ b/references/video_classification/utils.py @@ -0,0 +1,255 @@ +from __future__ import print_function +from collections import defaultdict, deque +import datetime +import time +import torch +import torch.distributed as dist + +import errno +import os + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {}'.format(header, total_time_str)) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target[None]) + + res = [] + for k in topk: + correct_k = 
correct[:k].flatten().sum(dtype=torch.float32) + res.append(correct_k * (100.0 / batch_size)) + return res + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + elif hasattr(args, "rank"): + pass + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + setup_for_distributed(args.rank == 0) From 6c89d048f2fe6387a9b32df92412c247ce0519c5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 29 Jul 2019 06:06:11 -0700 Subject: [PATCH 2/9] Initial version of video classification --- references/video_classification/sampler.py | 90 +++++++++++++ references/video_classification/scheduler.py | 48 +++++++ references/video_classification/train.py | 95 ++++++++++---- references/video_classification/transforms.py | 122 ++++++++++++++++++ torchvision/datasets/kinetics.py | 9 +- torchvision/datasets/video_utils.py | 21 ++- torchvision/io/video.py | 2 +- 7 files changed, 355 insertions(+), 32 deletions(-) create mode 100644 references/video_classification/sampler.py create mode 100644 references/video_classification/scheduler.py create mode 100644 references/video_classification/transforms.py diff --git a/references/video_classification/sampler.py b/references/video_classification/sampler.py new file mode 100644 index 00000000000..062a2e91057 --- /dev/null +++ b/references/video_classification/sampler.py @@ -0,0 +1,90 @@ +import math +import torch +from torch.utils.data import Sampler +import torch.distributed as dist +import torchvision.datasets.video_utils + + +class DistributedSampler(Sampler): + """ + Extension of DistributedSampler, as discussed in + https://github.com/pytorch/pytorch/issues/23430 + """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise 
RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + if self.shuffle: + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + if isinstance(self.dataset, Sampler): + orig_indices = list(iter(self.dataset)) + indices = [orig_indices[i] for i in indices] + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class SequentialClipSampler(torch.utils.data.Sampler): + """ + Samples at most `max_video_clips_per_video` clips for each video, equally spaced + Arguments: + video_clips (VideoClips): video clips to sample from + max_clips_per_video (int): maximum number of clips to be sampled per video + """ + def __init__(self, video_clips, max_clips_per_video): + if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips): + raise TypeError("Expected video_clips to be an instance of VideoClips, " + "got {}".format(type(video_clips))) + self.video_clips = video_clips + self.max_clips_per_video = max_clips_per_video + + def __iter__(self): + idxs = [] + s = 0 + # select at most max_clips_per_video for each video, uniformly spaced + for c in self.video_clips.clips: + length = len(c) + step = max(length // self.max_clips_per_video, 1) + sampled = torch.arange(length)[::step] + s + s += length + idxs.append(sampled) + idxs = torch.cat(idxs).tolist() + return iter(idxs) + + def __len__(self): + return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips) diff --git a/references/video_classification/scheduler.py b/references/video_classification/scheduler.py new file mode 100644 index 00000000000..e41c3e119c5 --- /dev/null +++ b/references/video_classification/scheduler.py @@ -0,0 +1,48 @@ +import torch +from bisect import bisect_right + + +# TODO: Is there a warmup in the multistep LR code +class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + optimizer, + milestones, + gamma=0.1, + warmup_factor=1.0 / 3, + warmup_iters=5, + warmup_method="linear", + last_epoch=-1, + ): + if not milestones == sorted(milestones): + raise ValueError( + "Milestones should be a list of" " increasing integers. 
Got {}", + milestones, + ) + + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted" + "got {}".format(warmup_method) + ) + self.milestones = milestones + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + warmup_factor = 1 + if self.last_epoch < self.warmup_iters: + if self.warmup_method == "constant": + warmup_factor = self.warmup_factor + elif self.warmup_method == "linear": + alpha = float(self.last_epoch) / self.warmup_iters + warmup_factor = self.warmup_factor * (1 - alpha) + alpha + return [ + base_lr + * warmup_factor + * self.gamma ** bisect_right(self.milestones, self.last_epoch) + for base_lr in self.base_lrs + ] diff --git a/references/video_classification/train.py b/references/video_classification/train.py index c26e4ae290f..3104a4b319b 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -8,9 +8,13 @@ import torch.utils.data from torch import nn import torchvision +import torchvision.datasets.video_utils from torchvision import transforms import utils +from sampler import DistributedSampler, SequentialClipSampler +from scheduler import WarmupMultiStepLR +import transforms as T try: from apex import amp @@ -18,17 +22,17 @@ amp = None -def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False): +def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) - metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}')) + metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}')) header = 'Epoch: [{}]'.format(epoch) - for image, target in metric_logger.log_every(data_loader, print_freq, header): + for video, target in metric_logger.log_every(data_loader, print_freq, header): start_time = time.time() - image, target = image.to(device), target.to(device) - output = model(image) + video, target = video.to(device), target.to(device) + output = model(video) loss = criterion(output, target) optimizer.zero_grad() @@ -40,11 +44,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri optimizer.step() acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) - batch_size = image.shape[0] + batch_size = video.shape[0] metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time)) + lr_scheduler.step() def evaluate(model, criterion, data_loader, device): @@ -101,20 +106,32 @@ def main(args): # Data loading code print("Loading data") - traindir = os.path.join(args.data_path, 'train') - valdir = os.path.join(args.data_path, 'val') + traindir = os.path.join(args.data_path, 'train_avi-480p') + valdir = os.path.join(args.data_path, 'val_avi-480p') normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) + dataset = 
torch.load("/private/home/fmassa/github/vision/kinetics_train.pth") + # dataset = torch.load("/private/home/fmassa/github/vision/kinetics_val.pth") + dataset.video_clips.compute_clips(16, 1, frame_rate=15) + dataset.transform = torchvision.transforms.Compose([ + T.ToFloatTensorInZeroOne(), + T.Resize((128, 171)), + T.RandomHorizontalFlip(), + T.Normalize((0.43216, 0.394666, 0.37645), (0.22803, 0.22145, 0.216989)), + T.RandomCrop((112, 112)) + ]) + + """ if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print("Loading dataset_train from {}".format(cache_path)) dataset, _ = torch.load(cache_path) else: - dataset = torchvision.datasets.ImageFolder( + dataset = torchvision.datasets.Kinetics( traindir, transforms.Compose([ transforms.RandomResizedCrop(224), @@ -126,17 +143,29 @@ def main(args): print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset, traindir), cache_path) + """ print("Took", time.time() - st) print("Loading validation data") cache_path = _get_cache_path(valdir) + dataset_test = torch.load("/private/home/fmassa/github/vision/kinetics_val.pth") + dataset_test.video_clips.compute_clips(16, 1, frame_rate=15) + dataset_test.transform = torchvision.transforms.Compose([ + T.ToFloatTensorInZeroOne(), + T.Resize((128, 171)), + T.Normalize((0.43216, 0.394666, 0.37645), (0.22803, 0.22145, 0.216989)), + T.CenterCrop((112, 112)) + ]) + """ if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print("Loading dataset_test from {}".format(cache_path)) dataset_test, _ = torch.load(cache_path) else: - dataset_test = torchvision.datasets.ImageFolder( + dataset_test = torchvision.datasets.Kinetics( valdir, + args.frames_per_clip, + 1, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), @@ -147,14 +176,13 @@ def main(args): print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) - + """ print("Creating data loaders") + train_sampler = torchvision.datasets.video_utils.RandomClipSampler(dataset.video_clips, args.clips_per_video) + test_sampler = SequentialClipSampler(dataset_test.video_clips, args.clips_per_video) if args.distributed: - train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) - else: - train_sampler = torch.utils.data.RandomSampler(dataset) - test_sampler = torch.utils.data.SequentialSampler(dataset_test) + train_sampler = DistributedSampler(train_sampler) + test_sampler = DistributedSampler(test_sampler) data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, @@ -165,22 +193,30 @@ def main(args): sampler=test_sampler, num_workers=args.workers, pin_memory=True) print("Creating model") - model = torchvision.models.__dict__[args.model](pretrained=args.pretrained) + # model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) + model = torchvision.models.video.__dict__[args.model]() model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) criterion = nn.CrossEntropyLoss() + lr = args.lr * args.world_size optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) if 
args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level ) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) + # per iteration, not per epoch + warmup_epochs = 10 + warmup_iters = warmup_epochs * len(data_loader) + lr_milestones = [len(data_loader) * m for m in args.lr_milestones] + lr_scheduler = WarmupMultiStepLR( + optimizer, milestones=lr_milestones, gamma=args.lr_gamma, + warmup_iters=warmup_iters, warmup_factor=1e-5) model_without_ddp = model if args.distributed: @@ -203,8 +239,7 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex) - lr_scheduler.step() + train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { @@ -229,21 +264,25 @@ def parse_args(): import argparse parser = argparse.ArgumentParser(description='PyTorch Classification Training') - parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset') - parser.add_argument('--model', default='resnet18', help='model') + parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset') + parser.add_argument('--model', default='r2plus1d_18', help='model') parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('-b', '--batch-size', default=32, type=int) - parser.add_argument('--epochs', default=90, type=int, metavar='N', + parser.add_argument('--clip-len', default=8, type=int, metavar='N', + help='number of frames per clip') + parser.add_argument('--clips-per-video', default=5, type=int, metavar='N', + help='maximum number of clips per video to consider') + parser.add_argument('-b', '--batch-size', default=24, type=int) + parser.add_argument('--epochs', default=45, type=int, metavar='N', help='number of total epochs to run') - parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', + parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', help='number of data loading workers (default: 16)') - parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') + parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') - parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') + parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones') parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') parser.add_argument('--print-freq', default=10, type=int, help='print frequency') parser.add_argument('--output-dir', default='.', help='path where to save') diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py new file mode 100644 index 00000000000..8adcd5cdda2 --- /dev/null +++ b/references/video_classification/transforms.py @@ -0,0 +1,122 @@ +import torch +import random + + +def crop(vid, i, j, h, w): + return vid[..., i:(i + h), 
j:(j + w)] + + +def center_crop(vid, output_size): + h, w = vid.shape[-2:] + th, tw = output_size + + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return crop(vid, i, j, th, tw) + + +def hflip(vid): + return vid.flip(dims=(-1,)) + + +# NOTE: for those functions, which generally expect mini-batches, we keep them +# as non-minibatch so that they are applied as if they were 4d (thus image). +# this way, we only apply the transformation in the spatial domain +def resize(vid, size, interpolation='bilinear'): + # NOTE: using bilinear interpolation because we don't work on minibatches + # at this level + scale = None + if isinstance(size, int): + scale = float(size) / min(vid.shape[-2:]) + size = None + return torch.nn.functional.interpolate( + vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False) + + +def pad(vid, padding, fill=0, padding_mode="constant"): + # NOTE: don't want to pad on temporal dimension, so let as non-batch + # (4d) before padding. This works as expected + return torch.nn.functional.pad(vid, padding, value=fill, mode=padding_mode) + + +def to_normalized_float_tensor(vid): + return vid.permute(3, 0, 1, 2).to(torch.float32) / 255 + + +def normalize(vid, mean, std): + shape = (-1,) + (1,) * (vid.dim() - 1) + mean = torch.as_tensor(mean).reshape(shape) + std = torch.as_tensor(std).reshape(shape) + return (vid - mean) / std + + +# Class interface + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + @staticmethod + def get_params(vid, output_size): + """Get parameters for ``crop`` for a random crop. + """ + h, w = vid.shape[-2:] + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, vid): + i, j, h, w = self.get_params(vid, self.size) + return crop(vid, i, j, h, w) + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, vid): + return center_crop(vid, self.size) + + +class Resize(object): + def __init__(self, size): + self.size = size + + def __call__(self, vid): + return resize(vid, self.size) + + +class ToFloatTensorInZeroOne(object): + def __call__(self, vid): + return to_normalized_float_tensor(vid) + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, vid): + return normalize(vid, self.mean, self.std) + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, vid): + if random.random() < self.p: + return hflip(vid) + return vid + + +class Pad(object): + def __init__(self, padding, fill=0): + self.padding = padding + self.fill = fill + + def __call__(self, vid): + return pad(vid, self.padding, self.fill) diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index 06dce5fcfaf..d45f25df6a4 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -5,7 +5,7 @@ class KineticsVideo(VisionDataset): - def __init__(self, root, frames_per_clip, step_between_clips=1): + def __init__(self, root, frames_per_clip, step_between_clips=1, transform=None): super(KineticsVideo, self).__init__(root) extensions = ('avi',) @@ -15,6 +15,7 @@ def __init__(self, root, frames_per_clip, step_between_clips=1): self.classes = classes video_list = [x[0] for x in self.samples] self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips) + self.transform = transform def __len__(self): 
return self.video_clips.num_clips() @@ -23,4 +24,8 @@ def __getitem__(self, idx): video, audio, info, video_idx = self.video_clips.get_clip(idx) label = self.samples[video_idx][1] - return video, audio, label + if self.transform is not None: + video = self.transform(video) + + # return video, audio, label + return video, label diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index cb426b35e69..56981d78f5f 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -60,10 +60,29 @@ def _compute_frame_pts(self): self.video_pts = [] self.video_fps = [] # TODO maybe paralellize this - for video_file in self.video_paths: + from .utils import tqdm + class DS(object): + def __init__(self, x): + self.x = x + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return read_video_timestamps(self.x[idx]) + import torch.utils.data + dl = torch.utils.data.DataLoader(DS(self.video_paths), batch_size=32, num_workers=256, collate_fn=lambda x: x) + for batch in tqdm(dl): + clips, fps = list(zip(*batch)) + clips = [torch.as_tensor(c) for c in clips] + self.video_pts.extend(clips) + self.video_fps.extend(fps) + """ + for video_file in tqdm(self.video_paths): clips, fps = read_video_timestamps(video_file) self.video_pts.append(torch.as_tensor(clips)) self.video_fps.append(fps) + """ def _init_from_metadata(self, metadata): assert len(self.video_paths) == len(metadata["video_pts"]) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 383f539e9f6..86430ae05fa 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -95,7 +95,7 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name): # TODO check if stream needs to always be the video stream here or not container.seek(seek_offset, any_frame=False, backward=True, stream=stream) except av.AVError: - print("Corrupted file?", container.name) + # print("Corrupted file?", container.name) return [] buffer_count = 0 for idx, frame in enumerate(container.decode(**stream_name)): From 824438cc59073d1124097bcd1a8e8849ef33a3e0 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 29 Jul 2019 07:19:16 -0700 Subject: [PATCH 3/9] add version --- references/video_classification/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 3104a4b319b..cb2406c297c 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -99,6 +99,8 @@ def main(args): utils.init_distributed_mode(args) print(args) + print("torch version: ", torch.__version__) + print("torchvision version: ", torchvision.__version__) device = torch.device(args.device) From 5d987089c328758c577d72d0c5cde2e4a68660ae Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 31 Jul 2019 02:11:12 -0700 Subject: [PATCH 4/9] Training of r2plus1d_18 on kinetics work Gives even slightly better results than expected, with 57.336 top1 clip accuracy. 
But we count some clips twice in this evaluation --- torchvision/io/video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 86430ae05fa..5c7e1897913 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -21,7 +21,7 @@ def _check_av_available(): # PyAV has some reference cycles _CALLED_TIMES = 0 -_GC_COLLECTION_INTERVAL = 20 +_GC_COLLECTION_INTERVAL = 10 def write_video(filename, video_array, fps, video_codec='libx264', options=None): From dd919c05c991da61aeef3519628d03e835fd3469 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 31 Jul 2019 02:57:36 -0700 Subject: [PATCH 5/9] Cleanups on training script --- references/video_classification/sampler.py | 2 +- references/video_classification/train.py | 72 ++++++++++------------ 2 files changed, 33 insertions(+), 41 deletions(-) diff --git a/references/video_classification/sampler.py b/references/video_classification/sampler.py index 062a2e91057..277911632d5 100644 --- a/references/video_classification/sampler.py +++ b/references/video_classification/sampler.py @@ -59,7 +59,7 @@ def set_epoch(self, epoch): self.epoch = epoch -class SequentialClipSampler(torch.utils.data.Sampler): +class UniformClipSampler(torch.utils.data.Sampler): """ Samples at most `max_video_clips_per_video` clips for each video, equally spaced Arguments: diff --git a/references/video_classification/train.py b/references/video_classification/train.py index cb2406c297c..f29bf8a22ff 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -12,7 +12,7 @@ from torchvision import transforms import utils -from sampler import DistributedSampler, SequentialClipSampler +from sampler import DistributedSampler, UniformClipSampler from scheduler import WarmupMultiStepLR import transforms as T @@ -73,7 +73,7 @@ def evaluate(model, criterion, data_loader, device): # gather the stats from all processes metric_logger.synchronize_between_processes() - print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}' + print(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}' .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) return metric_logger.acc1.global_avg @@ -81,7 +81,7 @@ def evaluate(model, criterion, data_loader, device): def _get_cache_path(filepath): import hashlib h = hashlib.sha1(filepath.encode()).hexdigest() - cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") + cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt") cache_path = os.path.expanduser(cache_path) return cache_path @@ -110,78 +110,69 @@ def main(args): print("Loading data") traindir = os.path.join(args.data_path, 'train_avi-480p') valdir = os.path.join(args.data_path, 'val_avi-480p') - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) + normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645], + std=[0.22803, 0.22145, 0.216989]) print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) - dataset = torch.load("/private/home/fmassa/github/vision/kinetics_train.pth") - # dataset = torch.load("/private/home/fmassa/github/vision/kinetics_val.pth") - dataset.video_clips.compute_clips(16, 1, frame_rate=15) - dataset.transform = torchvision.transforms.Compose([ + transform_train = torchvision.transforms.Compose([ T.ToFloatTensorInZeroOne(), T.Resize((128, 171)), T.RandomHorizontalFlip(), - 
T.Normalize((0.43216, 0.394666, 0.37645), (0.22803, 0.22145, 0.216989)), + normalize, T.RandomCrop((112, 112)) ]) - """ if args.cache_dataset and os.path.exists(cache_path): - # Attention, as the transforms are also cached! print("Loading dataset_train from {}".format(cache_path)) dataset, _ = torch.load(cache_path) + dataset.transform = transform_train else: - dataset = torchvision.datasets.Kinetics( + dataset = torchvision.datasets.KineticsVideo( traindir, - transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) + frames_per_clip=args.clip_len, + step_between_clips=1, + transform=transform_train + ) if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset, traindir), cache_path) - """ + dataset.video_clips.compute_clips(args.clip_len, 1, frame_rate=15) + print("Took", time.time() - st) print("Loading validation data") cache_path = _get_cache_path(valdir) - dataset_test = torch.load("/private/home/fmassa/github/vision/kinetics_val.pth") - dataset_test.video_clips.compute_clips(16, 1, frame_rate=15) - dataset_test.transform = torchvision.transforms.Compose([ + + transform_test = torchvision.transforms.Compose([ T.ToFloatTensorInZeroOne(), T.Resize((128, 171)), - T.Normalize((0.43216, 0.394666, 0.37645), (0.22803, 0.22145, 0.216989)), + normalize, T.CenterCrop((112, 112)) ]) - """ + if args.cache_dataset and os.path.exists(cache_path): - # Attention, as the transforms are also cached! print("Loading dataset_test from {}".format(cache_path)) dataset_test, _ = torch.load(cache_path) + dataset_test.transform = transform_test else: - dataset_test = torchvision.datasets.Kinetics( + dataset_test = torchvision.datasets.KineticsVideo( valdir, - args.frames_per_clip, - 1, - transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])) + frames_per_clip=args.clip_len, + step_between_clips=1, + transform=transform_test + ) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) - """ + dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15) + print("Creating data loaders") train_sampler = torchvision.datasets.video_utils.RandomClipSampler(dataset.video_clips, args.clips_per_video) - test_sampler = SequentialClipSampler(dataset_test.video_clips, args.clips_per_video) + test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video) if args.distributed: train_sampler = DistributedSampler(train_sampler) test_sampler = DistributedSampler(test_sampler) @@ -212,9 +203,9 @@ def main(args): opt_level=args.apex_opt_level ) - # per iteration, not per epoch - warmup_epochs = 10 - warmup_iters = warmup_epochs * len(data_loader) + # convert scheduler to be per iteration, not per epoch, for warmup that lasts + # between different epochs + warmup_iters = args.lr_warmup_epochs * len(data_loader) lr_milestones = [len(data_loader) * m for m in args.lr_milestones] lr_scheduler = WarmupMultiStepLR( optimizer, milestones=lr_milestones, gamma=args.lr_gamma, @@ -286,6 +277,7 @@ def parse_args(): dest='weight_decay') parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones') parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') + 
parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs') parser.add_argument('--print-freq', default=10, type=int, help='print frequency') parser.add_argument('--output-dir', default='.', help='path where to save') parser.add_argument('--resume', default='', help='resume from checkpoint') From 362bbe085d4152ec7457c90791d404737c425047 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 31 Jul 2019 05:26:45 -0700 Subject: [PATCH 6/9] Lint --- references/video_classification/sampler.py | 1 - references/video_classification/scheduler.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/references/video_classification/sampler.py b/references/video_classification/sampler.py index 277911632d5..c5a879ffa1a 100644 --- a/references/video_classification/sampler.py +++ b/references/video_classification/sampler.py @@ -37,7 +37,6 @@ def __iter__(self): else: indices = list(range(len(self.dataset))) - # add extra samples to make it evenly divisible indices += indices[:(self.total_size - len(indices))] assert len(indices) == self.total_size diff --git a/references/video_classification/scheduler.py b/references/video_classification/scheduler.py index e41c3e119c5..f0f862d41ad 100644 --- a/references/video_classification/scheduler.py +++ b/references/video_classification/scheduler.py @@ -2,7 +2,6 @@ from bisect import bisect_right -# TODO: Is there a warmup in the multistep LR code class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): def __init__( self, @@ -41,8 +40,8 @@ def get_lr(self): alpha = float(self.last_epoch) / self.warmup_iters warmup_factor = self.warmup_factor * (1 - alpha) + alpha return [ - base_lr - * warmup_factor - * self.gamma ** bisect_right(self.milestones, self.last_epoch) + base_lr * + warmup_factor * + self.gamma ** bisect_right(self.milestones, self.last_epoch) for base_lr in self.base_lrs ] From 6456fc9f89ee23664d04b1a3650a961cb5f78359 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 31 Jul 2019 05:26:56 -0700 Subject: [PATCH 7/9] Minor improvements --- references/video_classification/train.py | 13 +++++++--- torchvision/datasets/video_utils.py | 33 ++++++++++++++---------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index f29bf8a22ff..fa7fcc2d533 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -26,7 +26,7 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) - metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}')) + metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}')) header = 'Epoch: [{}]'.format(epoch) for video, target in metric_logger.log_every(data_loader, print_freq, header): @@ -48,7 +48,7 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) - metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time)) + metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time)) lr_scheduler.step() @@ -129,6 +129,9 @@ def main(args): 
dataset, _ = torch.load(cache_path) dataset.transform = transform_train else: + if args.distributed: + print("It is recommended to pre-compute the dataset cache " + "on a single-gpu first, as it will be faster") dataset = torchvision.datasets.KineticsVideo( traindir, frames_per_clip=args.clip_len, @@ -158,6 +161,9 @@ def main(args): dataset_test, _ = torch.load(cache_path) dataset_test.transform = transform_test else: + if args.distributed: + print("It is recommended to pre-compute the dataset cache " + "on a single-gpu first, as it will be faster") dataset_test = torchvision.datasets.KineticsVideo( valdir, frames_per_clip=args.clip_len, @@ -232,7 +238,8 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex) + train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, + device, epoch, args.print_freq, args.apex) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index 56981d78f5f..1ebec7df1e9 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -4,6 +4,8 @@ import torch.utils.data from torchvision.io import read_video_timestamps, read_video +from .utils import tqdm + def unfold(tensor, size, step, dilation=1): """ @@ -59,8 +61,9 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1 def _compute_frame_pts(self): self.video_pts = [] self.video_fps = [] - # TODO maybe paralellize this - from .utils import tqdm + + # strategy: use a DataLoader to parallelize read_video_timestamps + # so need to create a dummy dataset first class DS(object): def __init__(self, x): self.x = x @@ -70,19 +73,21 @@ def __len__(self): def __getitem__(self, idx): return read_video_timestamps(self.x[idx]) + import torch.utils.data - dl = torch.utils.data.DataLoader(DS(self.video_paths), batch_size=32, num_workers=256, collate_fn=lambda x: x) - for batch in tqdm(dl): - clips, fps = list(zip(*batch)) - clips = [torch.as_tensor(c) for c in clips] - self.video_pts.extend(clips) - self.video_fps.extend(fps) - """ - for video_file in tqdm(self.video_paths): - clips, fps = read_video_timestamps(video_file) - self.video_pts.append(torch.as_tensor(clips)) - self.video_fps.append(fps) - """ + dl = torch.utils.data.DataLoader( + DS(self.video_paths), + batch_size=16, + num_workers=torch.get_num_threads(), + collate_fn=lambda x: x) + + with tqdm(total=len(dl)) as pbar: + for batch in dl: + pbar.update(1) + clips, fps = list(zip(*batch)) + clips = [torch.as_tensor(c) for c in clips] + self.video_pts.extend(clips) + self.video_fps.extend(fps) def _init_from_metadata(self, metadata): assert len(self.video_paths) == len(metadata["video_pts"]) From c2ccc6e553351ad638eb397be1a728b0128d77d2 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 31 Jul 2019 05:41:24 -0700 Subject: [PATCH 8/9] Remove some hacks --- references/video_classification/train.py | 23 ++++++++++++++++------- torchvision/datasets/kinetics.py | 3 +-- torchvision/io/video.py | 1 + 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index fa7fcc2d533..a45357d2c43 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -6,6 +6,7 @@ 
import torch import torch.utils.data +from torch.utils.data.dataloader import default_collate from torch import nn import torchvision import torchvision.datasets.video_utils @@ -57,16 +58,16 @@ def evaluate(model, criterion, data_loader, device): metric_logger = utils.MetricLogger(delimiter=" ") header = 'Test:' with torch.no_grad(): - for image, target in metric_logger.log_every(data_loader, 100, header): - image = image.to(device, non_blocking=True) + for video, target in metric_logger.log_every(data_loader, 100, header): + video = video.to(device, non_blocking=True) target = target.to(device, non_blocking=True) - output = model(image) + output = model(video) loss = criterion(output, target) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) # FIXME need to take into account that the datasets # could have been padded in distributed setup - batch_size = image.shape[0] + batch_size = video.shape[0] metric_logger.update(loss=loss.item()) metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) @@ -86,6 +87,12 @@ def _get_cache_path(filepath): return cache_path +def collate_fn(batch): + # remove audio from the batch + batch = [(d[0], d[2]) for d in batch] + return default_collate(batch) + + def main(args): if args.apex: if sys.version_info < (3, 0): @@ -185,11 +192,13 @@ def main(args): data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, pin_memory=True) + sampler=train_sampler, num_workers=args.workers, + pin_memory=True, collate_fn=collate_fn) data_loader_test = torch.utils.data.DataLoader( dataset_test, batch_size=args.batch_size, - sampler=test_sampler, num_workers=args.workers, pin_memory=True) + sampler=test_sampler, num_workers=args.workers, + pin_memory=True, collate_fn=collate_fn) print("Creating model") # model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) @@ -267,7 +276,7 @@ def parse_args(): parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset') parser.add_argument('--model', default='r2plus1d_18', help='model') parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('--clip-len', default=8, type=int, metavar='N', + parser.add_argument('--clip-len', default=16, type=int, metavar='N', help='number of frames per clip') parser.add_argument('--clips-per-video', default=5, type=int, metavar='N', help='maximum number of clips per video to consider') diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index d45f25df6a4..f7d3fbe89d7 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -27,5 +27,4 @@ def __getitem__(self, idx): if self.transform is not None: video = self.transform(video) - # return video, audio, label - return video, label + return video, audio, label diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 5c7e1897913..83afe726e43 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -95,6 +95,7 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name): # TODO check if stream needs to always be the video stream here or not container.seek(seek_offset, any_frame=False, backward=True, stream=stream) except av.AVError: + # TODO add some warnings in this case # print("Corrupted file?", container.name) return [] buffer_count = 0 From 2c9624f42a6e56cf64d3e0eccda7fc4e7e7538c3 Mon Sep 17 00:00:00 2001 From: Francisco 
Massa Date: Wed, 31 Jul 2019 05:45:10 -0700 Subject: [PATCH 9/9] Lint --- references/video_classification/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py index 8adcd5cdda2..9435450c4b3 100644 --- a/references/video_classification/transforms.py +++ b/references/video_classification/transforms.py @@ -30,7 +30,7 @@ def resize(vid, size, interpolation='bilinear'): scale = float(size) / min(vid.shape[-2:]) size = None return torch.nn.functional.interpolate( - vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False) + vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False) def pad(vid, padding, fill=0, padding_mode="constant"):
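
As a quick orientation to the patch series above, here is a minimal, self-contained sketch (not part of any patch; paths and hyper-parameters are illustrative only) of how the pieces introduced here fit together for evaluation-style loading: the KineticsVideo dataset, the clip transforms from references/video_classification/transforms.py, the UniformClipSampler from sampler.py, and the audio-dropping collate_fn from train.py.

import torch
import torchvision
from torch.utils.data.dataloader import default_collate

import transforms as T                  # references/video_classification/transforms.py
from sampler import UniformClipSampler  # references/video_classification/sampler.py

# Clip transforms used for evaluation in train.py: uint8 (T, H, W, C) video ->
# float (C, T, H, W) in [0, 1], resized, normalized, center-cropped to 112x112.
transform_test = torchvision.transforms.Compose([
    T.ToFloatTensorInZeroOne(),
    T.Resize((128, 171)),
    T.Normalize(mean=(0.43216, 0.394666, 0.37645),
                std=(0.22803, 0.22145, 0.216989)),
    T.CenterCrop((112, 112)),
])

# Illustrative dataset root; KineticsVideo indexes every clip of frames_per_clip frames.
dataset = torchvision.datasets.KineticsVideo(
    "/path/to/kinetics/val_avi-480p",
    frames_per_clip=16,
    step_between_clips=1,
    transform=transform_test,
)
dataset.video_clips.compute_clips(16, 1, frame_rate=15)

# At most 5 evenly spaced clips per video, as in the evaluation loop above.
test_sampler = UniformClipSampler(dataset.video_clips, max_clips_per_video=5)


def collate_fn(batch):
    # KineticsVideo returns (video, audio, label); drop the audio before batching,
    # mirroring the collate_fn in the reference train.py.
    return default_collate([(d[0], d[2]) for d in batch])


loader = torch.utils.data.DataLoader(
    dataset, batch_size=24, sampler=test_sampler,
    num_workers=10, collate_fn=collate_fn)

for video, label in loader:
    # video: (N, C, T, H, W) float tensor, ready for e.g. r2plus1d_18
    pass

In the distributed setting, train.py additionally wraps these samplers in the custom DistributedSampler from sampler.py so that each process sees a disjoint subset of clips.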