diff --git a/docs/source/models/regnet.rst b/docs/source/models/regnet.rst new file mode 100644 index 00000000000..aef4abd2544 --- /dev/null +++ b/docs/source/models/regnet.rst @@ -0,0 +1,37 @@ +RegNet +====== + +.. currentmodule:: torchvision.models + +The RegNet model is based on the `Designing Network Design Spaces +`_ paper. + + +Model builders +-------------- + +The following model builders can be used to instantiate a RegNet model, with or +without pre-trained weights. All the model builders internally rely on the +``torchvision.models.regnet.RegNet`` base class. Please refer to the `source code +`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + regnet_y_400mf + regnet_y_800mf + regnet_y_1_6gf + regnet_y_3_2gf + regnet_y_8gf + regnet_y_16gf + regnet_y_32gf + regnet_y_128gf + regnet_x_400mf + regnet_x_800mf + regnet_x_1_6gf + regnet_x_3_2gf + regnet_x_8gf + regnet_x_16gf + regnet_x_32gf diff --git a/docs/source/models/resnet.rst b/docs/source/models/resnet.rst index 8ab79fe885b..7976eb437e5 100644 --- a/docs/source/models/resnet.rst +++ b/docs/source/models/resnet.rst @@ -10,7 +10,7 @@ The ResNet model is based on the `Deep Residual Learning for Image Recognition Model builders -------------- -The following model builders can be used to instanciate a ResNet model, with or +The following model builders can be used to instantiate a ResNet model, with or without pre-trained weights. All the model builders internally rely on the ``torchvision.models.resnet.ResNet`` base class. Please refer to the `source code diff --git a/docs/source/models/squeezenet.rst b/docs/source/models/squeezenet.rst index a3e8603e92d..9771e5c623a 100644 --- a/docs/source/models/squeezenet.rst +++ b/docs/source/models/squeezenet.rst @@ -11,7 +11,7 @@ paper. Model builders -------------- -The following model builders can be used to instanciate a SqueezeNet model, with or +The following model builders can be used to instantiate a SqueezeNet model, with or without pre-trained weights. All the model builders internally rely on the ``torchvision.models.squeezenet.SqueezeNet`` base class. Please refer to the `source code diff --git a/docs/source/models/vgg.rst b/docs/source/models/vgg.rst index 068bd330c8b..a9fa9aabfb1 100644 --- a/docs/source/models/vgg.rst +++ b/docs/source/models/vgg.rst @@ -10,7 +10,7 @@ Image Recognition `_ paper. Model builders -------------- -The following model builders can be used to instanciate a VGG model, with or +The following model builders can be used to instantiate a VGG model, with or without pre-trained weights. All the model buidlers internally rely on the ``torchvision.models.vgg.VGG`` base class. Please refer to the `source code `_ for diff --git a/docs/source/models/vision_transformer.rst b/docs/source/models/vision_transformer.rst new file mode 100644 index 00000000000..914caa9311e --- /dev/null +++ b/docs/source/models/vision_transformer.rst @@ -0,0 +1,28 @@ +VisionTransformer +================= + +.. currentmodule:: torchvision.models + +The VisionTransformer model is based on the `An Image is Worth 16x16 Words: +Transformers for Image Recognition at Scale `_ paper. + + +Model builders +-------------- + +The following model builders can be used to instantiate a VisionTransformer model, with or +without pre-trained weights. All the model builders internally rely on the +``torchvision.models.vision_transformer.VisionTransformer`` base class. +Please refer to the `source code +`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + vit_b_16 + vit_b_32 + vit_l_16 + vit_l_32 + vit_h_14 diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst index a8fdad8efb0..d512d917d65 100644 --- a/docs/source/models_new.rst +++ b/docs/source/models_new.rst @@ -36,9 +36,11 @@ weights: .. toctree:: :maxdepth: 1 + models/regnet models/resnet models/squeezenet models/vgg + models/vision_transformer Table of all available classification weights diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index f878bdd5754..821d86f11f0 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -861,11 +861,20 @@ class RegNet_X_32GF_Weights(WeightsEnum): def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_400MF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_Y_400MF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_Y_400MF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_Y_400MF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_Y_400MF_Weights + :members: """ weights = RegNet_Y_400MF_Weights.verify(weights) @@ -877,11 +886,20 @@ def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_800MF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_Y_800MF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_Y_800MF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_Y_800MF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_Y_800MF_Weights + :members: """ weights = RegNet_Y_800MF_Weights.verify(weights) @@ -893,11 +911,20 @@ def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_1.6GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_Y_1_6GF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_Y_1_6GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_Y_1_6GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_Y_1_6GF_Weights + :members: """ weights = RegNet_Y_1_6GF_Weights.verify(weights) @@ -911,11 +938,20 @@ def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_3.2GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_Y_3_2GF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_Y_3_2GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_Y_3_2GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_Y_3_2GF_Weights + :members: """ weights = RegNet_Y_3_2GF_Weights.verify(weights) @@ -929,11 +965,20 @@ def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_8GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_Y_8GF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_Y_8GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_Y_8GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_Y_8GF_Weights + :members: """ weights = RegNet_Y_8GF_Weights.verify(weights) @@ -947,11 +992,20 @@ def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bo def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_16GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_Y_16GF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_Y_16GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_Y_16GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_Y_16GF_Weights + :members: """ weights = RegNet_Y_16GF_Weights.verify(weights) @@ -965,11 +1019,20 @@ def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_32GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_Y_32GF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_Y_32GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_Y_32GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_Y_32GF_Weights + :members: """ weights = RegNet_Y_32GF_Weights.verify(weights) @@ -983,12 +1046,20 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_128GF architecture from - `"Designing Network Design Spaces" `_. - NOTE: Pretrained weights are not available for this model. + `Designing Network Design Spaces `_. Args: - weights (RegNet_Y_128GF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_Y_128GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_Y_128GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_Y_128GF_Weights + :members: """ weights = RegNet_Y_128GF_Weights.verify(weights) @@ -1002,11 +1073,20 @@ def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_400MF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_X_400MF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_X_400MF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_X_400MF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_X_400MF_Weights + :members: """ weights = RegNet_X_400MF_Weights.verify(weights) @@ -1018,11 +1098,20 @@ def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_800MF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. Args: - weights (RegNet_X_800MF_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`torchvision.models.regnet.RegNet_X_800MF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_X_800MF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_X_800MF_Weights + :members: """ weights = RegNet_X_800MF_Weights.verify(weights) @@ -1034,7 +1123,20 @@ def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_1.6GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. + + Args: + weights (:class:`torchvision.models.regnet.RegNet_X_1_6GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_X_1_6GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_X_1_6GF_Weights + :members: Args: weights (RegNet_X_1_6GF_Weights, optional): The pretrained weights for the model @@ -1050,7 +1152,20 @@ def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_3.2GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. + + Args: + weights (:class:`torchvision.models.regnet.RegNet_X_3_2GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_X_3_2GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_X_3_2GF_Weights + :members: Args: weights (RegNet_X_3_2GF_Weights, optional): The pretrained weights for the model @@ -1066,7 +1181,20 @@ def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_8GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. + + Args: + weights (:class:`torchvision.models.regnet.RegNet_X_8GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_X_8GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_X_8GF_Weights + :members: Args: weights (RegNet_X_8GF_Weights, optional): The pretrained weights for the model @@ -1082,7 +1210,20 @@ def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bo def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_16GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. + + Args: + weights (:class:`torchvision.models.regnet.RegNet_X_16GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_X_16GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_X_16GF_Weights + :members: Args: weights (RegNet_X_16GF_Weights, optional): The pretrained weights for the model @@ -1098,7 +1239,20 @@ def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_32GF architecture from - `"Designing Network Design Spaces" `_. + `Designing Network Design Spaces `_. + + Args: + weights (:class:`torchvision.models.regnet.RegNet_X_32GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.regnet.RegNet_X_32GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.regnet.RegNet_X_32GF_Weights + :members: Args: weights (RegNet_X_32GF_Weights, optional): The pretrained weights for the model diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 6d881080d04..f85404c4bde 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -490,11 +490,20 @@ class ViT_H_14_Weights(WeightsEnum): def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_16 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. Args: - weights (ViT_B_16_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.vision_transformer.ViT_B_16_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.vision_transformer.ViT_B_16_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.vision_transformer.ViT_B_16_Weights + :members: """ weights = ViT_B_16_Weights.verify(weights) @@ -514,11 +523,20 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_32 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. Args: - weights (ViT_B_32_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.vision_transformer.ViT_B_32_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.vision_transformer.ViT_B_32_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.vision_transformer.ViT_B_32_Weights + :members: """ weights = ViT_B_32_Weights.verify(weights) @@ -538,11 +556,20 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_16 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. Args: - weights (ViT_L_16_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.vision_transformer.ViT_L_16_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.vision_transformer.ViT_L_16_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.vision_transformer.ViT_L_16_Weights + :members: """ weights = ViT_L_16_Weights.verify(weights) @@ -562,11 +589,20 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_32 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. Args: - weights (ViT_L_32_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.vision_transformer.ViT_L_32_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.vision_transformer.ViT_L_32_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.vision_transformer.ViT_L_32_Weights + :members: """ weights = ViT_L_32_Weights.verify(weights) @@ -585,11 +621,20 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_h_14 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. Args: - weights (ViT_H_14_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + weights (:class:`~torchvision.models.vision_transformer.ViT_H_14_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.vision_transformer.ViT_H_14_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.vision_transformer.ViT_H_14_Weights + :members: """ weights = ViT_H_14_Weights.verify(weights)