diff --git a/docs/source/models/ssdlite.rst b/docs/source/models/ssdlite.rst new file mode 100644 index 00000000000..1f8437a6ff1 --- /dev/null +++ b/docs/source/models/ssdlite.rst @@ -0,0 +1,26 @@ +SSDlite +======= + +.. currentmodule:: torchvision.models.detection + +The SSDLite model is based on the `SSD: Single Shot MultiBox Detector +`__, `Searching for MobileNetV3 +`__ and `MobileNetV2: Inverted Residuals and Linear +Bottlenecks __` papers. + + +Model builders +-------------- + +The following model builders can be used to instantiate a SSD Lite model, with or +without pre-trained weights. All the model builders internally rely on the +``torchvision.models.detection.ssd.SSD`` base class. Please refer to the `source +code +`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + ssdlite320_mobilenet_v3_large diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst index b4945938824..43f0190f006 100644 --- a/docs/source/models_new.rst +++ b/docs/source/models_new.rst @@ -98,6 +98,7 @@ weights: models/fcos models/mask_rcnn models/retinanet + models/ssdlite Table of all available detection weights ---------------------------------------- diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index a18e166c429..b23490bc295 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -215,11 +215,9 @@ def ssdlite320_mobilenet_v3_large( norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any, ) -> SSD: - """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at - `"Searching for MobileNetV3" - `_ and - `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" - `_. + """SSDlite model architecture with input size 320x320 and a MobileNetV3 Large backbone, as + described at `Searching for MobileNetV3 `__ and + `MobileNetV2: Inverted Residuals and Linear Bottlenecks `__. See :func:`~torchvision.models.detection.ssd300_vgg16` for more details. @@ -231,15 +229,31 @@ def ssdlite320_mobilenet_v3_large( >>> predictions = model(x) Args: - weights (FasterRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int, optional): number of output classes of the model (including the background) - weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone - trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. - Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is - passed (the default) this value is set to 6. + weights (:class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model + (including the background). + weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained + weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers + starting from final block. Valid values are between 0 and 6, with 6 meaning all + backbone layers are trainable. If ``None`` is passed (the default) this value is + set to 6. norm_layer (callable, optional): Module specifying the normalization layer to use. + **kwargs: parameters passed to the ``torchvision.models.detection.ssd.SSD`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights + :members: """ + weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)