Generalising Video Classification Training Script (references/video_classification/train.py) #1540

@rsomani95

Description

The references/video_classification/train.py script can be generalised for finetuning on any dataset with just a few changes.

Three additional arguments would be needed:

parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir')
parser.add_argument('--val-dir', default='val_avi-480p', help='name of val dir')
parser.add_argument('--output-classes', default=None, type=int, help='no. of output classes (if finetuning)')

Some minimal changes to the script:

Modifying lines 119-120

traindir = os.path.join(args.data_path, args.train_dir)
valdir = os.path.join(args.data_path, args.val_dir)
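
To make the argument and path changes concrete, here is a minimal standalone sketch of how the new flags would feed into the two directory paths. The `--data-path` default and the example values passed to `parse_args` are placeholders of mine, not taken from the script; `type=int` is added so that `--output-classes` arrives as an integer rather than a string.

```python
import argparse
import os


def build_parser():
    parser = argparse.ArgumentParser(description="Video classification training (sketch)")
    # The script already reads args.data_path; this default is a placeholder.
    parser.add_argument("--data-path", default="/path/to/dataset", help="dataset root")
    parser.add_argument("--train-dir", default="train_avi-480p", help="name of train dir")
    parser.add_argument("--val-dir", default="val_avi-480p", help="name of val dir")
    # type=int so the value can be used directly as a layer size.
    parser.add_argument("--output-classes", default=None, type=int,
                        help="no. of output classes (if finetuning)")
    return parser


# Example invocation with made-up values for a custom dataset.
args = build_parser().parse_args(
    ["--data-path", "/data/my_dataset", "--train-dir", "train",
     "--val-dir", "val", "--output-classes", "10"]
)

# argparse maps '--train-dir' to args.train_dir and '--val-dir' to args.val_dir.
traindir = os.path.join(args.data_path, args.train_dir)
valdir = os.path.join(args.data_path, args.val_dir)
```

With no flags given, the defaults reproduce the existing `train_avi-480p` and `val_avi-480p` directory names, so current usage should be unaffected.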

One additional chunk after line 205:

#line 205
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)

#additional chunk for fine-tuning
if args.output_classes is not None:
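    # assigning to out_features would not resize the weights, so swap in a new layer
    model.fc = torch.nn.Linear(model.fc.in_features, args.output_classes)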

#line 206
model.to(device)
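
A note on the fine-tuning chunk: as far as I can tell, assigning to `fc.out_features` only changes an attribute on the layer and leaves the weight tensor at its original size, so the head has to be replaced with a new `nn.Linear`. The sketch below illustrates this with a small stand-in model; `TinyVideoNet` and `replace_head` are hypothetical names for illustration, and it assumes the real model exposes its classifier as `fc`, as the snippet above does.

```python
import torch
from torch import nn


class TinyVideoNet(nn.Module):
    """Hypothetical stand-in for a video model whose classifier is `fc`."""

    def __init__(self, num_classes=400):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(3, num_classes)

    def forward(self, x):
        # x has shape (batch, channels, frames, height, width)
        return self.fc(self.pool(x).flatten(1))


def replace_head(model, output_classes):
    """Swap the final layer for a freshly initialised one of the requested size."""
    model.fc = nn.Linear(model.fc.in_features, output_classes)
    return model


model = TinyVideoNet()
clip = torch.randn(2, 3, 4, 8, 8)

# Editing the attribute alone leaves the weights, and so the output size, unchanged.
model.fc.out_features = 10
shape_after_attribute_edit = tuple(model(clip).shape)

# Replacing the layer gives the requested number of outputs.
model = replace_head(model, 10)
shape_after_replacement = tuple(model(clip).shape)
```

The new layer starts from random weights, which is what you want when the target dataset has a different label set from the pretraining data.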

If this makes sense to you guys, I'll be happy to put in a PR for it.
