diff --git a/references/video_classification/train.py b/references/video_classification/train.py index a746470be9b..016e6024886 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -98,10 +98,11 @@ def evaluate(model, criterion, data_loader, device): return metric_logger.acc1.global_avg -def _get_cache_path(filepath): +def _get_cache_path(filepath, args): import hashlib - h = hashlib.sha1(filepath.encode()).hexdigest() + value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}" + h = hashlib.sha1(value.encode()).hexdigest() cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt") cache_path = os.path.expanduser(cache_path) return cache_path @@ -135,7 +136,7 @@ def main(args): print("Loading training data") st = time.time() - cache_path = _get_cache_path(traindir) + cache_path = _get_cache_path(traindir, args) transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): @@ -167,7 +168,7 @@ def main(args): print("Took", time.time() - st) print("Loading validation data") - cache_path = _get_cache_path(valdir) + cache_path = _get_cache_path(valdir, args) if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights)