diff --git a/docs/source/conf.py b/docs/source/conf.py index d0cb718f4fa..137f4f86122 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -297,12 +297,33 @@ def inject_minigalleries(app, what, name, obj, options, lines): def inject_weight_metadata(app, what, name, obj, options, lines): + """This hook is used to generate docs for the models weights. + + Objects like ResNet18_Weights are enums with fields, where each field is a Weight object. + Enums aren't easily documented in Python so the solution we're going for is to: + + - add an autoclass directive in the model's builder docstring, e.g. + + ``` + .. autoclass:: torchvision.models.ResNet34_Weights + :members: + ``` + + (see resnet.py for an example) + - then this hook is called automatically when building the docs, and it generates the text that gets + used within the autoclass directive. + """ if obj.__name__.endswith("_Weights"): - lines[:] = ["The model builder above accepts the following values as the ``weights`` parameter:"] + lines[:] = [ + "The model builder above accepts the following values as the ``weights`` parameter.", + f"``{obj.__name__}.DEFAULT`` is equivalent to ``{obj.DEFAULT}``.", + ] lines.append("") for field in obj: lines += [f"**{str(field)}**:", ""] + if field == obj.DEFAULT: + lines += [f"This weight is also available as ``{obj.__name__}.DEFAULT``.", ""] table = [] for k, v in field.meta.items():