-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Closed
Description
Lines 403 to 413 in 297e2b8
@pytest.mark.parametrize("model_fn", [models.mobilenet_v2, models.mobilenet_v3_large, models.mobilenet_v3_small]) | |
def test_mobilenet_norm_layer(model_fn): | |
model = model_fn() | |
assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules()) | |
def get_gn(num_channels): | |
return nn.GroupNorm(32, num_channels) | |
model = model_fn(norm_layer=get_gn) | |
assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) | |
assert any(isinstance(x, nn.GroupNorm) for x in model.modules()) |
fails on main
due to pytorch/pytorch#74293.