# Use a pretrained ViT and add two new tokens

Two new tokens, `[CMD]` (command) and `[SPD]` (speed), will be concatenated at the end of the input sequence, after the class token and the patch tokens.

In [1]:
import torch
import torchvision

# Load torchvision's ViT-B/32 with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — consider switching if the installed version supports it.
model = torchvision.models.vit_b_32(pretrained=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the module tree of the pretrained model.
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (linear_1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (dropout_1): Dropout(p=0.0, inplace=False)
          (linear_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout_2): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
 

In [4]:
# Dummy input of shape [1, 3, 224, 224] (batch, channels, height, width) with random values.
random_img = torch.rand(1, 3, 224, 224)

In [6]:
# Geometry of the pretrained model, reused by the manual forward pass below.
image_size, patch_size, hidden_dim = model.image_size, model.patch_size, model.hidden_dim

## Typical ViT pipeline

The following cell re-implements, step by step, the `forward` method (including `_process_input`) of the `VisionTransformer` class in `torchvision.models`.

In [27]:
# Sanity check: the patch-embedding conv turns a batch of 3 images into a 7x7 grid of 768-dim patches.
model.conv_proj(torch.randn(3, 3, 224, 224)).shape

torch.Size([3, 768, 7, 7])

In [18]:
# `nn` is otherwise only imported in a later cell; import it here so this cell
# runs on its own (in file order it would raise NameError).
import torch.nn as nn

# Pass the image through the convolutional/patch embedding layer
# (_process_input). The batch size is read from the input instead of hard-coding 1.
n = random_img.shape[0]
img = model.conv_proj(random_img)  # [n, 3, 224, 224] => [n, 768, 7, 7]
img = img.reshape(n, hidden_dim, (image_size // patch_size) ** 2)  # [n, 768, 7, 7] => [n, 768, 49]
img = img.permute(0, 2, 1)  # [n, 768, 49] => [n, 49, 768]

# Now the rest of forward
# Prepend the class token, expanded to match the batch size
cls_token = model.class_token.expand(n, -1, -1)  # [1, 1, 768] => [n, 1, 768]
img = torch.cat([cls_token, img], dim=1)  # [n, 50, 768]

# TODO: here, add the Command and Speed "tokens" (next for ablation study, for now just the mapping)
# For now each placeholder input (all ones) is projected to hidden_dim and ADDED
# to every token (broadcast over the sequence) instead of being concatenated, so
# the sequence stays at 50 and the pretrained pos_embedding still fits.
# The projections are created anew each time this cell runs, so their weights
# are re-initialised on every run.
speed_proj = nn.Linear(1, hidden_dim)
command_proj = nn.Linear(4, hidden_dim)
img = img + speed_proj(torch.ones(n, 1)).unsqueeze(1)  # SPEED: [n, 1, 768] broadcast onto [n, 50, 768]
img = img + command_proj(torch.ones(n, 4)).unsqueeze(1)  # COMMAND: [n, 1, 768] broadcast onto [n, 50, 768]

# Pass the sequence through the encoder (we implicitly add the pos_embedding here)
img = model.encoder(img)  # [n, 50, 768] => [n, 50, 768]

# Pass the result to the head for classification (only the CLS token is used)
img = img[:, 0]  # [n, 50, 768] => [n, 768]
img = model.heads(img)  # [n, 768] => [n, 1000]

In [19]:
# Shape of the classification-head output from the manual forward pass.
print(img.shape)

torch.Size([1, 1000])


In [11]:
# Shape of the class token that the forward pass prepends to the patch sequence.
model.class_token.shape

torch.Size([1, 1, 768])

In [15]:
# The positional embedding is stored on the encoder; check what kind of object it is.
type(model.encoder.pos_embedding)

torch.nn.parameter.Parameter

So, if we add the `[CMD]` and `[SPD]` tokens to the input sequence, it grows from 50 to 52 tokens, and the positional embedding must be resized to `[1, 52, 768]`. As a starting point, we redefine the `interpolate_embeddings` function from `torchvision.models.vision_transformer`:

In [14]:
# Inspect torchvision's helper for resizing positional embeddings in a state dict.
help(torchvision.models.vision_transformer.interpolate_embeddings)

Help on function interpolate_embeddings in module torchvision.models.vision_transformer:

interpolate_embeddings(image_size: int, patch_size: int, model_state: 'OrderedDict[str, torch.Tensor]', interpolation_mode: str = 'bicubic', reset_heads: bool = False) -> 'OrderedDict[str, torch.Tensor]'
    This function helps interpolating positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.
    
    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, not copying the state of heads. Default: False.
    
    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.



In [20]:
# State dict of the pretrained model (the input format taken by interpolate_embeddings).
state = model.state_dict()

In [23]:
# 50 positions = 1 class token + 7 * 7 patches.
state['encoder.pos_embedding'].shape

torch.Size([1, 50, 768])

In [None]:
import math
from collections import OrderedDict
import torch.nn as nn

def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """Resize the positional embedding of a ViT state dict to a new patch grid.

    Local copy of ``torchvision.models.vision_transformer.interpolate_embeddings``,
    meant for loading a pre-trained checkpoint into a model that uses a different
    image resolution (or patch size). The first position (class token) is kept
    unchanged; the remaining positions are viewed as a square 2-D grid, resized
    with ``nn.functional.interpolate`` and flattened back. Nothing is resampled
    when the grid size is already correct.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained
            model. It is NOT modified; a shallow copy is updated and returned.
        interpolation_mode (str): Mode passed to ``nn.functional.interpolate``.
            Default: bicubic. ``align_corners=True`` is only passed for the modes
            that accept it, so e.g. ``"nearest"`` works too.
        reset_heads (bool): If true, drop every entry whose key starts with
            ``"heads"``. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.

    Raises:
        ValueError: If ``encoder.pos_embedding`` does not have leading dimension 1.
        AssertionError: (from ``torch._assert``) if a resize is needed and the
            existing patch positions do not form a square grid.
    """
    # Work on a shallow copy so the caller's dict is left untouched, but carry over
    # the `_metadata` attribute that state dicts may have (a plain OrderedDict copy
    # would drop it).
    metadata = getattr(model_state, "_metadata", None)
    model_state = OrderedDict(model_state)
    if metadata is not None:
        model_state._metadata = metadata

    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length_1d = image_size // patch_size
    new_seq_length = new_seq_length_1d**2 + 1  # +1 for the class token

    # Only resample when the grid actually changes (torchvision does the same).
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it off.
        pos_embedding_token = pos_embedding[:, :1, :]  # (1, 1, hidden_dim)
        pos_embedding_img = pos_embedding[:, 1:, :]  # (1, num_patches, hidden_dim)

        num_patches = seq_length - 1
        seq_length_1d = math.isqrt(num_patches)
        torch._assert(seq_length_1d * seq_length_1d == num_patches, "seq_length is not a perfect square!")

        # (1, num_patches, hidden_dim) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1).reshape(
            1, hidden_dim, seq_length_1d, seq_length_1d
        )

        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        # interpolate() rejects align_corners for modes such as "nearest" or "area".
        supports_align_corners = interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear")
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True if supports_align_corners else None,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, new_seq_length - 1, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.reshape(
            1, hidden_dim, new_seq_length - 1
        ).permute(0, 2, 1)
        model_state["encoder.pos_embedding"] = torch.cat(
            [pos_embedding_token, new_pos_embedding_img], dim=1
        )

    if reset_heads:
        # Drop classification-head weights; remaining keys keep their order.
        for key in [k for k in model_state if k.startswith("heads")]:
            del model_state[key]

    return model_state

In [None]:
def interpolate_pos_embedding(model: nn.Module, new_pos_embed_seq_len: int) -> 'OrderedDict[str, torch.Tensor]':
    """Interpolate position encoding to the new sequence length.

    Still unfinished: only the trivial case (the length already matches) is
    handled. Any real resize raises ``NotImplementedError`` rather than silently
    returning ``None``, which would contradict the return annotation.

    Args:
        model: Module whose state dict contains ``"encoder.pos_embedding"``.
        new_pos_embed_seq_len: Desired number of positions in the embedding.

    Returns:
        The model's state dict, unchanged, when no resize is needed.

    Raises:
        ValueError: If the position embedding's leading dimension is not 1.
        NotImplementedError: If the sequence length would have to change.
    """
    old_state_dict = model.state_dict()

    # Get the position embedding from the old model: (1, old_seq_len, hidden_dim)
    pos_embed = old_state_dict["encoder.pos_embedding"]
    n, old_seq_len, hidden_dim = pos_embed.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embed.shape}")

    if new_pos_embed_seq_len == old_seq_len:
        return old_state_dict

    # TODO: decide how positions beyond the pretrained ones are obtained
    # (e.g. interpolation along the sequence vs. appending new embeddings for
    # the extra tokens) before implementing this.
    raise NotImplementedError(
        f"Resizing encoder.pos_embedding from [1, {old_seq_len}, {hidden_dim}] "
        f"to [1, {new_pos_embed_seq_len}, {hidden_dim}] is not implemented yet"
    )
    
        