From 1b0b680efc8684e027681e32cb7af5b6d99370ff Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Jan 2022 21:10:21 +0000 Subject: [PATCH] Adding missing named param check on ViT --- torchvision/prototype/models/vision_transformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index f5b6bfff790..1cd186a2d82 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -11,7 +11,7 @@ from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "VisionTransformer", @@ -111,6 +111,9 @@ def _vision_transformer( ) -> VisionTransformer: image_size = kwargs.pop("image_size", 224) + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = VisionTransformer( image_size=image_size, patch_size=patch_size,