From 11a076cb242c169d05cd151180a70b79f1250f51 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 4 Jul 2022 17:02:04 +0100 Subject: [PATCH 1/2] Update the dataset cache to factor in parameters from the args. --- references/video_classification/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index a746470be9b..abfd64afc79 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -98,10 +98,10 @@ def evaluate(model, criterion, data_loader, device): return metric_logger.acc1.global_avg -def _get_cache_path(filepath): +def _get_cache_path(filepath, args): import hashlib - - h = hashlib.sha1(filepath.encode()).hexdigest() + value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}" + h = hashlib.sha1(value.encode()).hexdigest() cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt") cache_path = os.path.expanduser(cache_path) return cache_path @@ -135,7 +135,7 @@ def main(args): print("Loading training data") st = time.time() - cache_path = _get_cache_path(traindir) + cache_path = _get_cache_path(traindir, args) transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): @@ -167,7 +167,7 @@ def main(args): print("Took", time.time() - st) print("Loading validation data") - cache_path = _get_cache_path(valdir) + cache_path = _get_cache_path(valdir, args) if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) From cb1f462c64f2fea76a2e89971edd80298f2615fe Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 4 Jul 2022 17:38:09 +0100 Subject: [PATCH 2/2] Fix linter --- references/video_classification/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index abfd64afc79..016e6024886 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -100,6 +100,7 @@ def evaluate(model, criterion, data_loader, device): def _get_cache_path(filepath, args): import hashlib + value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}" h = hashlib.sha1(value.encode()).hexdigest() cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")