Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
"The model builder above accepts the following values as the ``weights`` parameter.",
f"``{obj.__name__}.DEFAULT`` is equivalent to ``{obj.DEFAULT}``.",
]

if obj.__doc__ != "An enumeration.":
# We only show the custom enum doc if it was overriden. The default one from Python is "An enumeration"
lines.append("")
lines.append(obj.__doc__)

lines.append("")

for field in obj:
lines += [f"**{str(field)}**:", ""]
if field == obj.DEFAULT:
Expand All @@ -335,10 +342,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
metrics = meta.pop("metrics", {})
meta_with_metrics = dict(meta, **metrics)

meta_with_metrics.pop("categories", None) # We don't want to document these, they can be too long

custom_docs = meta_with_metrics.pop("_docs", None) # Custom per-Weights docs
if custom_docs is not None:
lines += [custom_docs, ""]

for k, v in meta_with_metrics.items():
if k == "categories":
continue
elif k == "recipe":
if k == "recipe":
v = f"`link <{v}>`__"
table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst")
Expand Down
1 change: 1 addition & 0 deletions test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_schema_meta_validation(model_fn):
"num_params",
"recipe",
"unquantized",
"_docs",
}
# mandatory fields for each computer vision task
classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")}
Expand Down
7 changes: 7 additions & 0 deletions torchvision/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 76.130,
"acc@5": 92.862,
},
"_docs": """
These are standard weights using the basic recipe of the paper.
""",
},
)
IMAGENET1K_V2 = Weights(
Expand All @@ -368,6 +371,10 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 80.858,
"acc@5": 95.434,
},
"_docs": """
These are improved weights, using TorchVision's `new recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
},
)
DEFAULT = IMAGENET1K_V2
Expand Down