From 68f511eb3510ab6fffd258c8aaeffdc1a574212c Mon Sep 17 00:00:00 2001 From: Yiwen Song <34639474+sallysyw@users.noreply.github.com> Date: Mon, 10 Jan 2022 12:16:29 -0800 Subject: [PATCH] [ViT] Graduate ViT from prototype (#5173) * graduate vit from prototype * nit * add vit to docs and hubconf * ufmt * re-correct ufmt * again * fix linter --- docs/source/models.rst | 26 ++ hubconf.py | 9 +- torchvision/models/__init__.py | 1 + torchvision/models/vision_transformer.py | 429 ++++++++++++++++++ .../prototype/models/vision_transformer.py | 331 +------------- 5 files changed, 464 insertions(+), 332 deletions(-) create mode 100644 torchvision/models/vision_transformer.py diff --git a/docs/source/models.rst b/docs/source/models.rst index ee8503a0857..9c750908b06 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -40,6 +40,7 @@ architectures for image classification: - `MNASNet`_ - `EfficientNet`_ - `RegNet`_ +- `VisionTransformer`_ You can construct a model with random weights by calling its constructor: @@ -82,6 +83,10 @@ You can construct a model with random weights by calling its constructor: regnet_x_8gf = models.regnet_x_8gf() regnet_x_16gf = models.regnet_x_16gf() regnet_x_32gf = models.regnet_x_32gf() + vit_b_16 = models.vit_b_16() + vit_b_32 = models.vit_b_32() + vit_l_16 = models.vit_l_16() + vit_l_32 = models.vit_l_32() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -125,6 +130,10 @@ These can be constructed by passing ``pretrained=True``: regnet_x_8gf = models.regnet_x_8gf(pretrained=True) regnet_x_16gf = models.regnet_x_16gf(pretrained=True) regnet_x_32gf = models.regnet_x_32gf(pretrained=True) + vit_b_16 = models.vit_b_16(pretrained=True) + vit_b_32 = models.vit_b_32(pretrained=True) + vit_l_16 = models.vit_l_16(pretrained=True) + vit_l_32 = models.vit_l_32(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. 
This directory can be set using the `TORCH_HOME` environment variable. See @@ -233,6 +242,10 @@ regnet_y_3_2gf 78.948 94.576 regnet_y_8gf 80.032 95.048 regnet_y_16gf 80.424 95.240 regnet_y_32gf 80.878 95.340 +vit_b_16 81.072 95.318 +vit_b_32 75.912 92.466 +vit_l_16 79.662 94.638 +vit_l_32 76.972 93.070 ================================ ============= ============= @@ -250,6 +263,7 @@ regnet_y_32gf 80.878 95.340 .. _MNASNet: https://arxiv.org/abs/1807.11626 .. _EfficientNet: https://arxiv.org/abs/1905.11946 .. _RegNet: https://arxiv.org/abs/2003.13678 +.. _VisionTransformer: https://arxiv.org/abs/2010.11929 .. currentmodule:: torchvision.models @@ -433,6 +447,18 @@ RegNet regnet_x_16gf regnet_x_32gf +VisionTransformer +----------------- + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + vit_b_16 + vit_b_32 + vit_l_16 + vit_l_32 + Quantized Models ---------------- diff --git a/hubconf.py b/hubconf.py index e89754d87f3..81b15ff9ff1 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,7 +1,6 @@ # Optional list of dependencies required by the package dependencies = ["torch"] -# classification from torchvision.models.alexnet import alexnet from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 from torchvision.models.efficientnet import ( @@ -47,8 +46,6 @@ wide_resnet50_2, wide_resnet101_2, ) - -# segmentation from torchvision.models.segmentation import ( fcn_resnet50, fcn_resnet101, @@ -60,3 +57,9 @@ from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1 from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn +from torchvision.models.vision_transformer import ( + vit_b_16, + vit_b_32, + vit_l_16, + vit_l_32, +) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index c9d11f88f01..22e2e45d4ce 100644 --- a/torchvision/models/__init__.py +++ 
b/torchvision/models/__init__.py @@ -10,6 +10,7 @@ from .shufflenetv2 import * from .efficientnet import * from .regnet import * +from .vision_transformer import * from . import detection from . import feature_extraction from . import optical_flow diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py new file mode 100644 index 00000000000..11ecd9d97ad --- /dev/null +++ b/torchvision/models/vision_transformer.py @@ -0,0 +1,429 @@ +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Optional + +import torch +import torch.nn as nn + +from .._internally_replaced_utils import load_state_dict_from_url +from ..utils import _log_api_usage_once + +__all__ = [ + "VisionTransformer", + "vit_b_16", + "vit_b_32", + "vit_l_16", + "vit_l_32", +] + +model_urls = { + "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", + "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", +} + + +class MLPBlock(nn.Sequential): + """Transformer MLP block.""" + + def __init__(self, in_dim: int, mlp_dim: int, dropout: float): + super().__init__() + self.linear_1 = nn.Linear(in_dim, mlp_dim) + self.act = nn.GELU() + self.dropout_1 = nn.Dropout(dropout) + self.linear_2 = nn.Linear(mlp_dim, in_dim) + self.dropout_2 = nn.Dropout(dropout) + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.linear_1.weight) + nn.init.xavier_uniform_(self.linear_2.weight) + nn.init.normal_(self.linear_1.bias, std=1e-6) + nn.init.normal_(self.linear_2.bias, std=1e-6) + + +class EncoderBlock(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] 
= partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + self.num_heads = num_heads + + # Attention block + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) + self.dropout = nn.Dropout(dropout) + + # MLP block + self.ln_2 = norm_layer(hidden_dim) + self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) + + def forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}") + x = self.ln_1(input) + x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) + x = self.dropout(x) + x = x + input + + y = self.ln_2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + # Note that batch_size is on the first dim because + # we have batch_first=True in nn.MultiAttention() by default + self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT + self.dropout = nn.Dropout(dropout) + layers: OrderedDict[str, nn.Module] = OrderedDict() + for i in range(num_layers): + layers[f"encoder_layer_{i}"] = EncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.layers = nn.Sequential(layers) + self.ln = norm_layer(hidden_dim) + + def forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + input = input + self.pos_embedding + return self.ln(self.layers(self.dropout(input))) + + +class VisionTransformer(nn.Module): + """Vision Transformer as per 
https://arxiv.org/abs/2010.11929.""" + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0.0, + attention_dropout: float = 0.0, + num_classes: int = 1000, + representation_size: Optional[int] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + _log_api_usage_once(self) + torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.attention_dropout = attention_dropout + self.dropout = dropout + self.num_classes = num_classes + self.representation_size = representation_size + self.norm_layer = norm_layer + + input_channels = 3 + + # The conv_proj is a more efficient version of reshaping, permuting + # and projecting the input + self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size) + + seq_length = (image_size // patch_size) ** 2 + + # Add a class token + self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + seq_length += 1 + + self.encoder = Encoder( + seq_length, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.seq_length = seq_length + + heads_layers: OrderedDict[str, nn.Module] = OrderedDict() + if representation_size is None: + heads_layers["head"] = nn.Linear(hidden_dim, num_classes) + else: + heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) + heads_layers["act"] = nn.Tanh() + heads_layers["head"] = nn.Linear(representation_size, num_classes) + + self.heads = nn.Sequential(heads_layers) + self._init_weights() + + def _init_weights(self): + fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1] + nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) + 
nn.init.zeros_(self.conv_proj.bias) + + if hasattr(self.heads, "pre_logits"): + fan_in = self.heads.pre_logits.in_features + nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) + nn.init.zeros_(self.heads.pre_logits.bias) + + nn.init.zeros_(self.heads.head.weight) + nn.init.zeros_(self.heads.head.bias) + + def _process_input(self, x: torch.Tensor) -> torch.Tensor: + n, c, h, w = x.shape + p = self.patch_size + torch._assert(h == self.image_size, "Wrong image height!") + torch._assert(w == self.image_size, "Wrong image width!") + n_h = h // p + n_w = w // p + + # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) + x = self.conv_proj(x) + # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) + x = x.reshape(n, self.hidden_dim, n_h * n_w) + + # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) + # The self attention layer expects inputs in the format (N, S, E) + # where S is the source sequence length, N is the batch size, E is the + # embedding dimension + x = x.permute(0, 2, 1) + + return x + + def forward(self, x: torch.Tensor): + # Reshape and permute the input tensor + x = self._process_input(x) + n = x.shape[0] + + # Expand the class token to the full batch + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) + + x = self.encoder(x) + + # Classifier "token" as used by standard language architectures + x = x[:, 0] + + x = self.heads(x) + + return x + + +def _vision_transformer( + arch: str, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + pretrained: bool, + progress: bool, + **kwargs: Any, +) -> VisionTransformer: + image_size = kwargs.pop("image_size", 224) + + model = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + **kwargs, + ) + + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], 
progress=progress) + model.load_state_dict(state_dict) + + return model + + +def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_b_16 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vision_transformer( + arch="vit_b_16", + patch_size=16, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_b_32 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vision_transformer( + arch="vit_b_32", + patch_size=32, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_l_16 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vision_transformer( + arch="vit_l_16", + patch_size=16, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_l_32 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vision_transformer( + arch="vit_l_32", + patch_size=32, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + pretrained=pretrained, + progress=progress, + **kwargs, + ) + + +def interpolate_embeddings( + image_size: int, + patch_size: int, + model_state: "OrderedDict[str, torch.Tensor]", + interpolation_mode: str = "bicubic", + reset_heads: bool = False, +) -> "OrderedDict[str, torch.Tensor]": + """This function helps interpolating positional embeddings during checkpoint loading, + especially when you want to apply a pre-trained model on images with different resolution. + + Args: + image_size (int): Image size of the new model. + patch_size (int): Patch size of the new model. + model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model. + interpolation_mode (str): The algorithm used for upsampling. Default: bicubic. + reset_heads (bool): If true, not copying the state of heads. Default: False. + + Returns: + OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model. 
+ """ + # Shape of pos_embedding is (1, seq_length, hidden_dim) + pos_embedding = model_state["encoder.pos_embedding"] + n, seq_length, hidden_dim = pos_embedding.shape + if n != 1: + raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}") + + new_seq_length = (image_size // patch_size) ** 2 + 1 + + # Need to interpolate the weights for the position embedding. + # We do this by reshaping the positions embeddings to a 2d grid, performing + # an interpolation in the (h, w) space and then reshaping back to a 1d grid. + if new_seq_length != seq_length: + # The class token embedding shouldn't be interpolated so we split it up. + seq_length -= 1 + new_seq_length -= 1 + pos_embedding_token = pos_embedding[:, :1, :] + pos_embedding_img = pos_embedding[:, 1:, :] + + # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length) + pos_embedding_img = pos_embedding_img.permute(0, 2, 1) + seq_length_1d = int(math.sqrt(seq_length)) + torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!") + + # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d) + pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d) + new_seq_length_1d = image_size // patch_size + + # Perform interpolation. 
+ # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) + new_pos_embedding_img = nn.functional.interpolate( + pos_embedding_img, + size=new_seq_length_1d, + mode=interpolation_mode, + align_corners=True, + ) + + # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length) + new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length) + + # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim) + new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1) + new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1) + + model_state["encoder.pos_embedding"] = new_pos_embedding + + if reset_heads: + model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict() + for k, v in model_state.items(): + if not k.startswith("heads"): + model_state_copy[k] = v + model_state = model_state_copy + + return model_state diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index 60923a5d0b5..f5b6bfff790 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -2,18 +2,13 @@ # https://github.com/google-research/vision_transformer # https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py -import math -from collections import OrderedDict from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Optional -import torch -import torch.nn as nn -from torch import Tensor from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode -from ...utils import _log_api_usage_once +from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface @@ -31,217 
+26,6 @@ ] -class MLPBlock(nn.Sequential): - """Transformer MLP block.""" - - def __init__(self, in_dim: int, mlp_dim: int, dropout: float): - super().__init__() - self.linear_1 = nn.Linear(in_dim, mlp_dim) - self.act = nn.GELU() - self.dropout_1 = nn.Dropout(dropout) - self.linear_2 = nn.Linear(mlp_dim, in_dim) - self.dropout_2 = nn.Dropout(dropout) - self._init_weights() - - def _init_weights(self): - nn.init.xavier_uniform_(self.linear_1.weight) - nn.init.xavier_uniform_(self.linear_2.weight) - nn.init.normal_(self.linear_1.bias, std=1e-6) - nn.init.normal_(self.linear_2.bias, std=1e-6) - - -class EncoderBlock(nn.Module): - """Transformer encoder block.""" - - def __init__( - self, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - dropout: float, - attention_dropout: float, - norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), - ): - super().__init__() - self.num_heads = num_heads - - # Attention block - self.ln_1 = norm_layer(hidden_dim) - self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) - self.dropout = nn.Dropout(dropout) - - # MLP block - self.ln_2 = norm_layer(hidden_dim) - self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) - - def forward(self, input: Tensor): - torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}") - x = self.ln_1(input) - x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) - x = self.dropout(x) - x = x + input - - y = self.ln_2(x) - y = self.mlp(y) - return x + y - - -class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - - def __init__( - self, - seq_length: int, - num_layers: int, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - dropout: float, - attention_dropout: float, - norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), - ): - super().__init__() - # Note that batch_size is on the first dim because - 
# we have batch_first=True in nn.MultiAttention() by default - self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT - self.dropout = nn.Dropout(dropout) - layers: OrderedDict[str, nn.Module] = OrderedDict() - for i in range(num_layers): - layers[f"encoder_layer_{i}"] = EncoderBlock( - num_heads, - hidden_dim, - mlp_dim, - dropout, - attention_dropout, - norm_layer, - ) - self.layers = nn.Sequential(layers) - self.ln = norm_layer(hidden_dim) - - def forward(self, input: Tensor): - torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") - input = input + self.pos_embedding - return self.ln(self.layers(self.dropout(input))) - - -class VisionTransformer(nn.Module): - """Vision Transformer as per https://arxiv.org/abs/2010.11929.""" - - def __init__( - self, - image_size: int, - patch_size: int, - num_layers: int, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - dropout: float = 0.0, - attention_dropout: float = 0.0, - num_classes: int = 1000, - representation_size: Optional[int] = None, - norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), - ): - super().__init__() - _log_api_usage_once(self) - torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") - self.image_size = image_size - self.patch_size = patch_size - self.hidden_dim = hidden_dim - self.mlp_dim = mlp_dim - self.attention_dropout = attention_dropout - self.dropout = dropout - self.num_classes = num_classes - self.representation_size = representation_size - self.norm_layer = norm_layer - - input_channels = 3 - - # The conv_proj is a more efficient version of reshaping, permuting - # and projecting the input - self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size) - - seq_length = (image_size // patch_size) ** 2 - - # Add a class token - self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) - seq_length += 
1 - - self.encoder = Encoder( - seq_length, - num_layers, - num_heads, - hidden_dim, - mlp_dim, - dropout, - attention_dropout, - norm_layer, - ) - self.seq_length = seq_length - - heads_layers: OrderedDict[str, nn.Module] = OrderedDict() - if representation_size is None: - heads_layers["head"] = nn.Linear(hidden_dim, num_classes) - else: - heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) - heads_layers["act"] = nn.Tanh() - heads_layers["head"] = nn.Linear(representation_size, num_classes) - - self.heads = nn.Sequential(heads_layers) - self._init_weights() - - def _init_weights(self): - fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1] - nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) - nn.init.zeros_(self.conv_proj.bias) - - if hasattr(self.heads, "pre_logits"): - fan_in = self.heads.pre_logits.in_features - nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) - nn.init.zeros_(self.heads.pre_logits.bias) - - nn.init.zeros_(self.heads.head.weight) - nn.init.zeros_(self.heads.head.bias) - - def _process_input(self, x: torch.Tensor) -> torch.Tensor: - n, c, h, w = x.shape - p = self.patch_size - torch._assert(h == self.image_size, "Wrong image height!") - torch._assert(w == self.image_size, "Wrong image width!") - n_h = h // p - n_w = w // p - - # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) - x = self.conv_proj(x) - # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) - x = x.reshape(n, self.hidden_dim, n_h * n_w) - - # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) - # The self attention layer expects inputs in the format (N, S, E) - # where S is the source sequence length, N is the batch size, E is the - # embedding dimension - x = x.permute(0, 2, 1) - - return x - - def forward(self, x: torch.Tensor): - # Reshaping and permuting the input tensor - x = self._process_input(x) - n = x.shape[0] - - # Expand the class token to 
the full batch - batch_class_token = self.class_token.expand(n, -1, -1) - x = torch.cat([batch_class_token, x], dim=1) - - x = self.encoder(x) - - # Classifier "token" as used by standard language architectures - x = x[:, 0] - - x = self.heads(x) - - return x - - _COMMON_META = { "task": "image_classification", "architecture": "ViT", @@ -345,15 +129,6 @@ def _vision_transformer( @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.ImageNet1K_V1)) def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - """ - Constructs a vit_b_16 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. - - Args: - weights (ViT_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet. - Default: None. - progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. - """ weights = ViT_B_16_Weights.verify(weights) return _vision_transformer( @@ -370,15 +145,6 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru @handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.ImageNet1K_V1)) def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - """ - Constructs a vit_b_32 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. - - Args: - weights (ViT_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet. - Default: None. - progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. 
- """ weights = ViT_B_32_Weights.verify(weights) return _vision_transformer( @@ -395,15 +161,6 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru @handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.ImageNet1K_V1)) def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - """ - Constructs a vit_l_16 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. - - Args: - weights (ViT_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. - Default: None. - progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. - """ weights = ViT_L_16_Weights.verify(weights) return _vision_transformer( @@ -420,15 +177,6 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru @handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.ImageNet1K_V1)) def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - """ - Constructs a vit_l_32 architecture from - `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. - - Args: - weights (ViT_L_32Weights, optional): If not None, returns a model pre-trained on ImageNet. - Default: None. - progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. 
- """ weights = ViT_L_32_Weights.verify(weights) return _vision_transformer( @@ -441,78 +189,3 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru progress=progress, **kwargs, ) - - -def interpolate_embeddings( - image_size: int, - patch_size: int, - model_state: "OrderedDict[str, torch.Tensor]", - interpolation_mode: str = "bicubic", - reset_heads: bool = False, -) -> "OrderedDict[str, torch.Tensor]": - """This function helps interpolating positional embeddings during checkpoint loading, - especially when you want to apply a pre-trained model on images with different resolution. - - Args: - image_size (int): Image size of the new model. - patch_size (int): Patch size of the new model. - model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model. - interpolation_mode (str): The algorithm used for upsampling. Default: bicubic. - reset_heads (bool): If true, not copying the state of heads. Default: False. - - Returns: - OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model. - """ - # Shape of pos_embedding is (1, seq_length, hidden_dim) - pos_embedding = model_state["encoder.pos_embedding"] - n, seq_length, hidden_dim = pos_embedding.shape - if n != 1: - raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}") - - new_seq_length = (image_size // patch_size) ** 2 + 1 - - # Need to interpolate the weights for the position embedding. - # We do this by reshaping the positions embeddings to a 2d grid, performing - # an interpolation in the (h, w) space and then reshaping back to a 1d grid. - if new_seq_length != seq_length: - # The class token embedding shouldn't be interpolated so we split it up. 
- seq_length -= 1 - new_seq_length -= 1 - pos_embedding_token = pos_embedding[:, :1, :] - pos_embedding_img = pos_embedding[:, 1:, :] - - # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length) - pos_embedding_img = pos_embedding_img.permute(0, 2, 1) - seq_length_1d = int(math.sqrt(seq_length)) - torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!") - - # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d) - pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d) - new_seq_length_1d = image_size // patch_size - - # Perform interpolation. - # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) - new_pos_embedding_img = nn.functional.interpolate( - pos_embedding_img, - size=new_seq_length_1d, - mode=interpolation_mode, - align_corners=True, - ) - - # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length) - new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length) - - # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim) - new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1) - new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1) - - model_state["encoder.pos_embedding"] = new_pos_embedding - - if reset_heads: - model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict() - for k, v in model_state.items(): - if not k.startswith("heads"): - model_state_copy[k] = v - model_state = model_state_copy - - return model_state