From 0d089d2e9158757048972fb635c6a8ad0686f135 Mon Sep 17 00:00:00 2001
From: Joao Gomes
Date: Thu, 17 Mar 2022 12:35:02 +0000
Subject: [PATCH 1/2] replace remaining asserts

---
 torchvision/models/vision_transformer.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index 43e4d315cec..6e8e71be904 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -76,7 +76,8 @@ def __init__(
         self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
 
     def forward(self, input: torch.Tensor):
-        torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
+        if input.dim() != 3:
+            raise ValueError(f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
         x = self.ln_1(input)
         x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
         x = self.dropout(x)
@@ -120,7 +121,8 @@ def __init__(
         self.ln = norm_layer(hidden_dim)
 
     def forward(self, input: torch.Tensor):
-        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
+        if input.dim() != 3:
+            raise ValueError(f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
         input = input + self.pos_embedding
         return self.ln(self.layers(self.dropout(input)))
 
@@ -145,7 +147,10 @@ def __init__(
     ):
         super().__init__()
         _log_api_usage_once(self)
-        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
+        if image_size % patch_size != 0:
+            raise ValueError(
+                f"Input shape indivisible by patch size! Instead got image_size = {image_size} and patch_size = {patch_size}"
+            )
         self.image_size = image_size
         self.patch_size = patch_size
         self.hidden_dim = hidden_dim
@@ -236,8 +241,10 @@ def __init__(
 
     def _process_input(self, x: torch.Tensor) -> torch.Tensor:
         n, c, h, w = x.shape
         p = self.patch_size
-        torch._assert(h == self.image_size, "Wrong image height!")
-        torch._assert(w == self.image_size, "Wrong image width!")
+        if h != self.image_size:
+            raise ValueError(f"Wrong image height! Expected {self.image_size} but got {h}")
+        if w != self.image_size:
+            raise ValueError(f"Wrong image width! Expected {self.image_size} but got {w}")
         n_h = h // p
         n_w = w // p

From 519831769537920ee09b12a1c9460654f6a8aac1 Mon Sep 17 00:00:00 2001
From: Joao Gomes
Date: Mon, 21 Mar 2022 20:53:45 +0000
Subject: [PATCH 2/2] fix tests

---
 torchvision/models/vision_transformer.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index 6e8e71be904..3d68968eada 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -241,10 +241,8 @@ def __init__(
 
     def _process_input(self, x: torch.Tensor) -> torch.Tensor:
         n, c, h, w = x.shape
         p = self.patch_size
-        if h != self.image_size:
-            raise ValueError(f"Wrong image height! Expected {self.image_size} but got {h}")
-        if w != self.image_size:
-            raise ValueError(f"Wrong image width! Expected {self.image_size} but got {w}")
+        torch._assert(h == self.image_size, "Wrong image height!")
+        torch._assert(w == self.image_size, "Wrong image width!")
         n_h = h // p
         n_w = w // p