Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,15 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
lines.append("")


def generate_weights_table(module, table_name, metrics):
def generate_weights_table(module, table_name, metrics, include_pattern=None, exclude_pattern=None):
weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
weights = [w for weight_enum in weight_enums for w in weight_enum]

if include_pattern is not None:
weights = [w for w in weights if include_pattern in str(w)]
if exclude_pattern is not None:
weights = [w for w in weights if exclude_pattern not in str(w)]
Comment on lines +355 to +358
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a die-hard fan of this, but we have to special-case this generation function (or its inputs) in some way. Happy to consider other suggestions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other alternative makes the code simpler here but moves complexity to the caller, i.e., they could provide a validation pattern as regex.

def generate_weights_table(module, table_name, metrics, include_pattern=".*"):
    weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
    weights = [w for weight_enum in weight_enums for w in weight_enum if re.match(include_pattern, w)]
    ...

regex would include both include and exclude patterns

Copy link
Member Author

@NicolasHug NicolasHug May 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about using regex, but I could not find a include_pattern where

re.match(include_pattern, w)

would be Truthy for weights that do not contain "Keypoint". I'm sure it's possible, but the complexity of the resulting regex might outweight the complexity of the current code. Any pointer @jdsgomes ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should do the trick:

In [59]:  re.match("^((?!Keypoint).)*$", "no_keypoint_in_name")
Out[59]: <re.Match object; span=(0, 19), match='no_keypoint_in_name'>

vs

In [58]: re.match("^((?!Keypoint).)*$", "xxxKeypoint")
Out[58]: None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jdsgomes , that seems to work indeed. I'm a bit on the fence with this. In general, I try to avoid regex like the plague. This one typically doesn't read easily to me and would require some extra comment IMHO.

I'll yield to whichever you prefer. Do you have a preference here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we are including/excluding simple patterns I would agree that regex here might just complicate things and just save a few lines of code, so I would say to leave as is.


metrics_keys, metrics_names = zip(*metrics)
column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
column_names = [f"**{name}**" for name in column_names] # Add bold
Expand All @@ -377,7 +382,15 @@ def generate_weights_table(module, table_name, metrics):


generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
generate_weights_table(module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")])
generate_weights_table(
module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")], exclude_pattern="Keypoint"
)
generate_weights_table(
module=M.detection,
table_name="detection_keypoint",
metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")],
include_pattern="Keypoint",
)
generate_weights_table(
module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
)
Expand Down
24 changes: 24 additions & 0 deletions docs/source/models/keypoint_rcnn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Keypoint R-CNN
==============

.. currentmodule:: torchvision.models.detection

The Keypoint R-CNN model is based on the `Mask R-CNN
<https://arxiv.org/abs/1703.06870>`__ paper.


Model builders
--------------

The following model builders can be used to instantiate a Keypoint R-CNN model,
with or without pre-trained weights. All the model builders internally rely on
the ``torchvision.models.detection.KeypointRCNN`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/keypoint_rcnn.py>`__
for more details about this class.

.. autosummary::
:toctree: generated/
:template: function.rst

keypointrcnn_resnet50_fpn
25 changes: 23 additions & 2 deletions docs/source/models_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ All models are evaluated on COCO val2017:



Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================
Object Detection
================

.. currentmodule:: torchvision.models.detection

Expand All @@ -114,6 +114,27 @@ Box MAPs are reported on COCO
.. include:: generated/detection_table.rst


Keypoint detection
==================

.. currentmodule:: torchvision.models.detection

The following keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
:maxdepth: 1

models/keypoint_rcnn

Table of all available Keypoint detection weights
-------------------------------------------------

Box and Keypoint MAPs are reported on COCO:

.. include:: generated/detection_keypoint_table.rst


Video Classification
====================

Expand Down
14 changes: 11 additions & 3 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def keypointrcnn_resnet50_fpn(
"""
Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

Reference: `"Mask R-CNN" <https://arxiv.org/abs/1703.06870>`_.
Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.

The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes.
Expand Down Expand Up @@ -410,14 +410,22 @@ def keypointrcnn_resnet50_fpn(
>>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)

Args:
weights (KeypointRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model
weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int, optional): number of output classes of the model (including the background)
num_keypoints (int, optional): number of keypoints
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
pretrained weights for the backbone.
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.

.. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
:members:
"""
weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
Expand Down