diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 291041d7b5f..7d8008c4f27 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -2,7 +2,7 @@ import torch from torch import nn -from torch.jit.annotations import Dict +from typing import Dict class IntermediateLayerGetter(nn.ModuleDict): @@ -41,7 +41,7 @@ class IntermediateLayerGetter(nn.ModuleDict): "return_layers": Dict[str, str], } - def __init__(self, model, return_layers): + def __init__(self, model: nn.Module, return_layers: Dict[str, str]): if not set(return_layers).issubset([name for name, _ in model.named_children()]): raise ValueError("return_layers are not present in model") orig_return_layers = return_layers