diff --git a/docs/source/conf.py b/docs/source/conf.py index 3b12fedfb0e..f4b38075c8b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -320,7 +320,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines): "The model builder above accepts the following values as the ``weights`` parameter.", f"``{obj.__name__}.DEFAULT`` is equivalent to ``{obj.DEFAULT}``.", ] + + if obj.__doc__ != "An enumeration.": + # We only show the custom enum doc if it was overriden. The default one from Python is "An enumeration" + lines.append("") + lines.append(obj.__doc__) + lines.append("") + for field in obj: lines += [f"**{str(field)}**:", ""] if field == obj.DEFAULT: @@ -335,10 +342,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines): metrics = meta.pop("metrics", {}) meta_with_metrics = dict(meta, **metrics) + meta_with_metrics.pop("categories", None) # We don't want to document these, they can be too long + + custom_docs = meta_with_metrics.pop("_docs", None) # Custom per-Weights docs + if custom_docs is not None: + lines += [custom_docs, ""] + for k, v in meta_with_metrics.items(): - if k == "categories": - continue - elif k == "recipe": + if k == "recipe": v = f"`link <{v}>`__" table.append((str(k), str(v))) table = tabulate(table, tablefmt="rst") diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 651585cfa7d..a39ca62ca78 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -90,6 +90,7 @@ def test_schema_meta_validation(model_fn): "num_params", "recipe", "unquantized", + "_docs", } # mandatory fields for each computer vision task classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")} diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index bc5d952368e..66d3b9d6370 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -355,6 +355,9 @@ class ResNet50_Weights(WeightsEnum): "acc@1": 76.130, "acc@5": 92.862, }, + "_docs": """ + These are standard weights using the basic recipe of the paper. + """, }, ) IMAGENET1K_V2 = Weights( @@ -368,6 +371,10 @@ class ResNet50_Weights(WeightsEnum): "acc@1": 80.858, "acc@5": 95.434, }, + "_docs": """ + These are improved weights, using TorchVision's `new recipe + `_. + """, }, ) DEFAULT = IMAGENET1K_V2