diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index f5b6bfff790..1cd186a2d82 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -11,7 +11,7 @@ from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "VisionTransformer", @@ -111,6 +111,9 @@ def _vision_transformer( ) -> VisionTransformer: image_size = kwargs.pop("image_size", 224) + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = VisionTransformer( image_size=image_size, patch_size=patch_size,