diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index 43e4d315cec..3d68968eada 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -76,7 +76,8 @@ def __init__(
         self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

     def forward(self, input: torch.Tensor):
-        torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
+        if input.dim() != 3:
+            raise ValueError(f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
         x = self.ln_1(input)
         x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
         x = self.dropout(x)
@@ -120,7 +121,8 @@ def __init__(
         self.ln = norm_layer(hidden_dim)

     def forward(self, input: torch.Tensor):
-        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
+        if input.dim() != 3:
+            raise ValueError(f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
         input = input + self.pos_embedding
         return self.ln(self.layers(self.dropout(input)))

@@ -145,7 +147,10 @@ def __init__(
     ):
         super().__init__()
         _log_api_usage_once(self)
-        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
+        if image_size % patch_size != 0:
+            raise ValueError(
+                f"Input shape indivisible by patch size! Instead got image_size = {image_size} and patch_size = {patch_size}"
+            )
         self.image_size = image_size
         self.patch_size = patch_size
         self.hidden_dim = hidden_dim
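
As a sanity check on the new constructor-time validation, here is a minimal sketch of a test, assuming the public `VisionTransformer` constructor keeps its current `image_size`/`patch_size`/`num_layers`/`num_heads`/`hidden_dim`/`mlp_dim` keyword arguments and that the test suite uses pytest; the test name is hypothetical and not part of this diff:

```python
import pytest

from torchvision.models.vision_transformer import VisionTransformer


def test_indivisible_image_size_raises_value_error():
    # image_size=224 is not divisible by patch_size=13, so the new check in
    # __init__ should surface as a ValueError (with both values in the message)
    # instead of the previous torch._assert failure.
    with pytest.raises(ValueError, match="indivisible by patch size"):
        VisionTransformer(
            image_size=224,
            patch_size=13,
            num_layers=1,
            num_heads=1,
            hidden_dim=64,
            mlp_dim=128,
        )
```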