From 74482f010e3a248a65581c5c36fca881ad40025b Mon Sep 17 00:00:00 2001 From: rsomani95 Date: Thu, 31 Oct 2019 17:38:19 +0530 Subject: [PATCH 1/3] Generalised for custom dataset --- references/video_classification/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 74852c2f721..44595e72313 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -116,8 +116,8 @@ def main(args): # Data loading code print("Loading data") - traindir = os.path.join(args.data_path, 'train_avi-480p') - valdir = os.path.join(args.data_path, 'val_avi-480p') + traindir = os.path.join(args.data_path, args.train_dir) + valdir = os.path.join(args.data_path, args.valid_dir) normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]) @@ -203,6 +203,8 @@ def main(args): print("Creating model") model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) + if args.output_classes is not None: + model.fc.out_features = args.output_classes model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -274,7 +276,10 @@ def parse_args(): parser = argparse.ArgumentParser(description='PyTorch Classification Training') parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset') + parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir') + parser.add_argument('--val-dir', default='val_avi-480p', help='name of val dir') parser.add_argument('--model', default='r2plus1d_18', help='model') + parser.add_argument('--output-classes', default=None, help='no. of output classes (if finetuning)') parser.add_argument('--device', default='cuda', help='device') parser.add_argument('--clip-len', default=16, type=int, metavar='N', help='number of frames per clip') From 85e3f8df2bc335ffe11c08aaeac205df06599850 Mon Sep 17 00:00:00 2001 From: rsomani95 Date: Mon, 4 Nov 2019 17:47:00 +0530 Subject: [PATCH 2/3] Typo, redundant code, sensible default --- references/video_classification/train.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 44595e72313..1a577f72255 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -117,7 +117,7 @@ def main(args): # Data loading code print("Loading data") traindir = os.path.join(args.data_path, args.train_dir) - valdir = os.path.join(args.data_path, args.valid_dir) + valdir = os.path.join(args.data_path, args.val_dir) normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]) @@ -203,8 +203,7 @@ def main(args): print("Creating model") model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) - if args.output_classes is not None: - model.fc.out_features = args.output_classes + model.fc.out_features = args.output_classes model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -279,7 +278,7 @@ def parse_args(): parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir') parser.add_argument('--val-dir', default='val_avi-480p', help='name of val dir') parser.add_argument('--model', default='r2plus1d_18', help='model') - parser.add_argument('--output-classes', default=None, help='no. of output classes (if finetuning)') + parser.add_argument('--output-classes', default=400, help='no. of output classes (if finetuning)') parser.add_argument('--device', default='cuda', help='device') parser.add_argument('--clip-len', default=16, type=int, metavar='N', help='number of frames per clip') From fdcbb5254c3b49a27d9500730a191a09bdf2f126 Mon Sep 17 00:00:00 2001 From: rsomani95 Date: Tue, 26 Nov 2019 22:39:59 +0530 Subject: [PATCH 3/3] Args for name of train and val dir --- references/video_classification/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 1a577f72255..b51b424c629 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -203,7 +203,6 @@ def main(args): print("Creating model") model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) - model.fc.out_features = args.output_classes model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -278,7 +277,6 @@ def parse_args(): parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir') parser.add_argument('--val-dir', default='val_avi-480p', help='name of val dir') parser.add_argument('--model', default='r2plus1d_18', help='model') - parser.add_argument('--output-classes', default=400, help='no. of output classes (if finetuning)') parser.add_argument('--device', default='cuda', help='device') parser.add_argument('--clip-len', default=16, type=int, metavar='N', help='number of frames per clip')