51 changes: 32 additions & 19 deletions docs/source/conf.py
@@ -334,25 +334,22 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
lines.append("")

for field in obj:
lines += [f"**{str(field)}**:", ""]

table = []

# the `meta` dict contains another embedded `metrics` dict. To
# simplify the table generation below, we create the
# `meta_with_metrics` dict, where the metrics dict has been "flattened"
meta = copy(field.meta)
metrics = meta.pop("metrics", {})
meta_with_metrics = dict(meta, **metrics)

lines += [meta_with_metrics.pop("_docs")]
lines += [f"**{str(field)}**:", ""]
lines += [meta.pop("_docs")]

if field == obj.DEFAULT:
lines += [f"Also available as ``{obj.__name__}.DEFAULT``."]

lines += [""]

for k, v in meta_with_metrics.items():
table = []
metrics = meta.pop("_metrics")
for dataset, dataset_metrics in metrics.items():
for metric_name, metric_value in dataset_metrics.items():
table.append((f"{metric_name} (on {dataset})", str(metric_value)))

for k, v in meta.items():
if k in {"recipe", "license"}:
v = f"`link <{v}>`__"
elif k == "min_size":
@@ -374,7 +371,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
lines.append("")


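For reference, a minimal standalone sketch of what the new flattening loop above does (the `meta` dict below is hypothetical but follows the dataset-keyed `_metrics` schema introduced in this PR): the nested `_metrics` dict is popped off `meta` and each entry becomes a table row whose label qualifies the metric with the dataset it was measured on.

```python
from copy import copy

# Hypothetical meta dict following the new "_metrics" schema from this PR.
meta = copy(
    {
        "num_params": 28589128,
        "_metrics": {"ImageNet-1K": {"acc@1": 82.520, "acc@5": 96.146}},
    }
)

table = []
metrics = meta.pop("_metrics")
for dataset, dataset_metrics in metrics.items():
    for metric_name, metric_value in dataset_metrics.items():
        # Each row qualifies the metric with its dataset,
        # e.g. ("acc@1 (on ImageNet-1K)", "82.52").
        table.append((f"{metric_name} (on {dataset})", str(metric_value)))

print(table)
```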
def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None):
def generate_weights_table(module, table_name, metrics, dataset, include_patterns=None, exclude_patterns=None):
weights_endswith = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
weight_enums = [getattr(module, name) for name in dir(module) if name.endswith(weights_endswith)]
weights = [w for weight_enum in weight_enums for w in weight_enum]
@@ -391,7 +388,7 @@ def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None):
content = [
(
f":class:`{w} <{type(w).__name__}>`",
*(w.meta["metrics"][metric] for metric in metrics_keys),
*(w.meta["_metrics"][dataset][metric] for metric in metrics_keys),
f"{w.meta['num_params']/1e6:.1f}M",
f"`link <{w.meta['recipe']}>`__",
)
@@ -408,29 +405,45 @@ def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None):
table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")


generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
generate_weights_table(
module=M.quantization, table_name="classification_quant", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")]
module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")], dataset="ImageNet-1K"
)
generate_weights_table(
module=M.quantization,
table_name="classification_quant",
metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")],
dataset="ImageNet-1K",
)
generate_weights_table(
module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")], exclude_patterns=["Mask", "Keypoint"]
module=M.detection,
table_name="detection",
metrics=[("box_map", "Box MAP")],
exclude_patterns=["Mask", "Keypoint"],
dataset="COCO-val2017",
)
generate_weights_table(
module=M.detection,
table_name="instance_segmentation",
metrics=[("box_map", "Box MAP"), ("mask_map", "Mask MAP")],
dataset="COCO-val2017",
include_patterns=["Mask"],
)
generate_weights_table(
module=M.detection,
table_name="detection_keypoint",
metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")],
dataset="COCO-val2017",
include_patterns=["Keypoint"],
)
generate_weights_table(
module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
module=M.segmentation,
table_name="segmentation",
metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")],
dataset="COCO-val2017-VOC-labels",
)
generate_weights_table(
module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")], dataset="Kinetics-400"
)
generate_weights_table(module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])


def setup(app):
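Taken together, the conf.py changes mean a table row is now built per dataset: `generate_weights_table` receives an explicit `dataset` argument and indexes `w.meta["_metrics"][dataset][metric]`. A hedged sketch with a stand-in weight object (not the real `WeightsEnum`; the `num_params` value is illustrative):

```python
# Sketch only: a stand-in for a real WeightsEnum member, using the AlexNet
# metric values from this diff under the new dataset-keyed "_metrics" layout.
class FakeWeight:
    meta = {
        "_metrics": {"ImageNet-1K": {"acc@1": 56.522, "acc@5": 79.066}},
        "num_params": 61100840,  # illustrative
    }

dataset = "ImageNet-1K"
metrics_keys = ["acc@1", "acc@5"]
w = FakeWeight()

row = (
    *(w.meta["_metrics"][dataset][metric] for metric in metrics_keys),
    f"{w.meta['num_params'] / 1e6:.1f}M",
)
print(row)  # (56.522, 79.066, '61.1M')
```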
26 changes: 18 additions & 8 deletions test/test_extended_models.py
@@ -85,27 +85,31 @@ def test_schema_meta_validation(model_fn):
"categories",
"keypoint_names",
"license",
"metrics",
"_metrics",
"min_size",
"num_params",
"recipe",
"unquantized",
"_docs",
}
# mandatory fields for each computer vision task
classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")}
classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
defaults = {
"all": {"metrics", "min_size", "num_params", "recipe", "_docs"},
"all": {"_metrics", "min_size", "num_params", "recipe", "_docs"},
"models": classification_fields,
"detection": {"categories", ("metrics", "box_map")},
"detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
"quantization": classification_fields | {"backend", "unquantized"},
"segmentation": {"categories", ("metrics", "miou"), ("metrics", "pixel_acc")},
"video": classification_fields,
"segmentation": {
"categories",
("_metrics", "COCO-val2017-VOC-labels", "miou"),
("_metrics", "COCO-val2017-VOC-labels", "pixel_acc"),
},
"video": {"categories", ("_metrics", "Kinetics-400", "acc@1"), ("_metrics", "Kinetics-400", "acc@5")},
"optical_flow": set(),
}
model_name = model_fn.__name__
module_name = model_fn.__module__.split(".")[-2]
fields = defaults["all"] | defaults[module_name]
expected_fields = defaults["all"] | defaults[module_name]

weights_enum = _get_model_weights(model_fn)
if len(weights_enum) == 0:
@@ -115,7 +119,13 @@ def test_schema_meta_validation(model_fn):
incorrect_params = []
bad_names = []
for w in weights_enum:
missing_fields = fields - (set(w.meta.keys()) | set(("metrics", x) for x in w.meta.get("metrics", {}).keys()))
actual_fields = set(w.meta.keys())
actual_fields |= set(
("_metrics", dataset, metric_key)
for dataset in w.meta.get("_metrics", {}).keys()
for metric_key in w.meta.get("_metrics", {}).get(dataset, {}).keys()
)
missing_fields = expected_fields - actual_fields
unsupported_fields = set(w.meta.keys()) - permitted_fields
if missing_fields or unsupported_fields:
problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
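A self-contained sketch of the field expansion the updated test performs: nested `_metrics` entries are projected into `("_metrics", dataset, metric)` tuples so they can be compared against the expected-field sets above (the `meta` values below are illustrative):

```python
# Standalone illustration of the tuple expansion in the updated test.
meta = {
    "categories": ["..."],  # placeholder; real weights carry the full list
    "_metrics": {"ImageNet-1K": {"acc@1": 74.434, "acc@5": 91.972}},
}

actual_fields = set(meta.keys())
actual_fields |= {
    ("_metrics", dataset, metric_key)
    for dataset, dataset_metrics in meta.get("_metrics", {}).items()
    for metric_key in dataset_metrics
}

expected_fields = {
    "categories",
    ("_metrics", "ImageNet-1K", "acc@1"),
    ("_metrics", "ImageNet-1K", "acc@5"),
}
assert not (expected_fields - actual_fields)  # no missing fields
```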
8 changes: 5 additions & 3 deletions torchvision/models/alexnet.py
@@ -61,9 +61,11 @@ class AlexNet_Weights(WeightsEnum):
"min_size": (63, 63),
"categories": _IMAGENET_CATEGORIES,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"metrics": {
"acc@1": 56.522,
"acc@5": 79.066,
"_metrics": {
"ImageNet-1K": {
"acc@1": 56.522,
"acc@5": 79.066,
}
},
"_docs": """
These weights reproduce closely the results of the paper using a simplified training recipe.
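With the new schema, reading a metric takes an extra dataset key. A usage sketch, assuming the multi-weight enum member naming used by torchvision at the time of this PR (`IMAGENET1K_V1`):

```python
from torchvision.models import AlexNet_Weights

weights = AlexNet_Weights.IMAGENET1K_V1
# Nested lookup: dataset first, then metric name.
top1 = weights.meta["_metrics"]["ImageNet-1K"]["acc@1"]
top5 = weights.meta["_metrics"]["ImageNet-1K"]["acc@5"]
print(top1, top5)  # 56.522 79.066
```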
32 changes: 20 additions & 12 deletions torchvision/models/convnext.py
@@ -222,9 +222,11 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 28589128,
"metrics": {
"acc@1": 82.520,
"acc@5": 96.146,
"_metrics": {
"ImageNet-1K": {
"acc@1": 82.520,
"acc@5": 96.146,
}
},
},
)
@@ -238,9 +240,11 @@ class ConvNeXt_Small_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 50223688,
"metrics": {
"acc@1": 83.616,
"acc@5": 96.650,
"_metrics": {
"ImageNet-1K": {
"acc@1": 83.616,
"acc@5": 96.650,
}
},
},
)
@@ -254,9 +258,11 @@ class ConvNeXt_Base_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 88591464,
"metrics": {
"acc@1": 84.062,
"acc@5": 96.870,
"_metrics": {
"ImageNet-1K": {
"acc@1": 84.062,
"acc@5": 96.870,
}
},
},
)
@@ -270,9 +276,11 @@ class ConvNeXt_Large_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 197767336,
"metrics": {
"acc@1": 84.414,
"acc@5": 96.976,
"_metrics": {
"ImageNet-1K": {
"acc@1": 84.414,
"acc@5": 96.976,
}
},
},
)
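Since every ConvNeXt variant now carries the same nested layout, consuming code can iterate uniformly over datasets. A short sketch (enum imports assume torchvision's multi-weight API):

```python
from torchvision.models import ConvNeXt_Tiny_Weights, ConvNeXt_Small_Weights

# Collect top-1 accuracy per weight entry and dataset under the new layout.
for enum in (ConvNeXt_Tiny_Weights, ConvNeXt_Small_Weights):
    for w in enum:
        for dataset, dataset_metrics in w.meta["_metrics"].items():
            print(f"{w}: acc@1={dataset_metrics['acc@1']} (on {dataset})")
```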
32 changes: 20 additions & 12 deletions torchvision/models/densenet.py
@@ -272,9 +272,11 @@ class DenseNet121_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 7978856,
"metrics": {
"acc@1": 74.434,
"acc@5": 91.972,
"_metrics": {
"ImageNet-1K": {
"acc@1": 74.434,
"acc@5": 91.972,
}
},
},
)
@@ -288,9 +290,11 @@ class DenseNet161_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 28681000,
"metrics": {
"acc@1": 77.138,
"acc@5": 93.560,
"_metrics": {
"ImageNet-1K": {
"acc@1": 77.138,
"acc@5": 93.560,
}
},
},
)
@@ -304,9 +308,11 @@ class DenseNet169_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 14149480,
"metrics": {
"acc@1": 75.600,
"acc@5": 92.806,
"_metrics": {
"ImageNet-1K": {
"acc@1": 75.600,
"acc@5": 92.806,
}
},
},
)
@@ -320,9 +326,11 @@ class DenseNet201_Weights(WeightsEnum):
meta={
**_COMMON_META,
"num_params": 20013928,
"metrics": {
"acc@1": 76.896,
"acc@5": 93.370,
"_metrics": {
"ImageNet-1K": {
"acc@1": 76.896,
"acc@5": 93.370,
}
},
},
)
24 changes: 16 additions & 8 deletions torchvision/models/detection/faster_rcnn.py
@@ -383,8 +383,10 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
**_COMMON_META,
"num_params": 41755286,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"metrics": {
"box_map": 37.0,
"_metrics": {
"COCO-val2017": {
"box_map": 37.0,
}
},
"_docs": """These weights were produced by following a similar training recipe as on the paper.""",
},
@@ -400,8 +402,10 @@ class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
**_COMMON_META,
"num_params": 43712278,
"recipe": "https://github.com/pytorch/vision/pull/5763",
"metrics": {
"box_map": 46.7,
"_metrics": {
"COCO-val2017": {
"box_map": 46.7,
}
},
"_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
},
@@ -417,8 +421,10 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
**_COMMON_META,
"num_params": 19386354,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"metrics": {
"box_map": 32.8,
"_metrics": {
"COCO-val2017": {
"box_map": 32.8,
}
},
"_docs": """These weights were produced by following a similar training recipe as on the paper.""",
},
@@ -434,8 +440,10 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
**_COMMON_META,
"num_params": 19386354,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"metrics": {
"box_map": 22.8,
"_metrics": {
"COCO-val2017": {
"box_map": 22.8,
}
},
"_docs": """These weights were produced by following a similar training recipe as on the paper.""",
},
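Detection weights follow the same pattern, with `COCO-val2017` as the dataset key. For example (member name assumed from the multi-weight API):

```python
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

w = FasterRCNN_ResNet50_FPN_Weights.COCO_V1
print(w.meta["_metrics"]["COCO-val2017"]["box_map"])  # 37.0
```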
6 changes: 4 additions & 2 deletions torchvision/models/detection/fcos.py
@@ -658,8 +658,10 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
"categories": _COCO_CATEGORIES,
"min_size": (1, 1),
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
"metrics": {
"box_map": 39.2,
"_metrics": {
"COCO-val2017": {
"box_map": 39.2,
}
},
"_docs": """These weights were produced by following a similar training recipe as on the paper.""",
},
16 changes: 10 additions & 6 deletions torchvision/models/detection/keypoint_rcnn.py
@@ -322,9 +322,11 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
**_COMMON_META,
"num_params": 59137258,
"recipe": "https://github.com/pytorch/vision/issues/1606",
"metrics": {
"box_map": 50.6,
"kp_map": 61.1,
"_metrics": {
"COCO-val2017": {
[review comment] @datumbox (Contributor, May 18, 2022): Non blocking question. @fmassa Any thoughts on the name of keypoints dataset here?
"box_map": 50.6,
"kp_map": 61.1,
}
},
"_docs": """
These weights were produced by following a similar training recipe as on the paper but use a checkpoint
Expand All @@ -339,9 +341,11 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
**_COMMON_META,
"num_params": 59137258,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
"metrics": {
"box_map": 54.6,
"kp_map": 65.0,
"_metrics": {
"COCO-val2017": {
"box_map": 54.6,
"kp_map": 65.0,
}
},
"_docs": """These weights were produced by following a similar training recipe as on the paper.""",
},