From 3f320ce00783cd71d4e3c014fd8cf46c1538e693 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 7 Dec 2021 12:20:09 +0000 Subject: [PATCH 1/2] Adding logging calls for raft and vit --- torchvision/models/optical_flow/raft.py | 2 ++ torchvision/prototype/models/vision_transformer.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 02705a7ebdb..45a979eaa0b 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -9,6 +9,7 @@ from torchvision.ops import ConvNormActivation from ._utils import grid_sample, make_coords_grid, upsample_flow +from ...utils import _log_api_usage_once __all__ = ( @@ -432,6 +433,7 @@ def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block If ``None`` (default), the flow is upsampled using interpolation. """ super().__init__() + _log_api_usage_once(self) self.feature_encoder = feature_encoder self.context_encoder = context_encoder diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index 9794559745d..ddfb3ef45c7 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -13,6 +13,7 @@ from ._api import WeightsEnum from ._utils import handle_legacy_interface +from ...utils import _log_api_usage_once __all__ = [ @@ -139,6 +140,7 @@ def __init__( norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): super().__init__() + _log_api_usage_once(self) torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") self.image_size = image_size self.patch_size = patch_size From 7930c4649fbad57246489912f7753fe050078bad Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 7 Dec 2021 12:31:34 +0000 Subject: [PATCH 2/2] Linter fix --- torchvision/models/optical_flow/raft.py | 2 +- torchvision/prototype/models/vision_transformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 45a979eaa0b..ba1cc8499d8 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -8,8 +8,8 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import ConvNormActivation -from ._utils import grid_sample, make_coords_grid, upsample_flow from ...utils import _log_api_usage_once +from ._utils import grid_sample, make_coords_grid, upsample_flow __all__ = ( diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index ddfb3ef45c7..ae8eee45539 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -11,9 +11,9 @@ import torch.nn as nn from torch import Tensor +from ...utils import _log_api_usage_once from ._api import WeightsEnum from ._utils import handle_legacy_interface -from ...utils import _log_api_usage_once __all__ = [