diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 74852c2f721..b51b424c629 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -116,8 +116,8 @@ def main(args): # Data loading code print("Loading data") - traindir = os.path.join(args.data_path, 'train_avi-480p') - valdir = os.path.join(args.data_path, 'val_avi-480p') + traindir = os.path.join(args.data_path, args.train_dir) + valdir = os.path.join(args.data_path, args.val_dir) normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]) @@ -274,6 +274,8 @@ def parse_args(): parser = argparse.ArgumentParser(description='PyTorch Classification Training') parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset') + parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir') + parser.add_argument('--val-dir', default='val_avi-480p', help='name of val dir') parser.add_argument('--model', default='r2plus1d_18', help='model') parser.add_argument('--device', default='cuda', help='device') parser.add_argument('--clip-len', default=16, type=int, metavar='N',