In [3]:
import torch

from transformers import ViTConfig, ViTModel
from transformers import ViTForImageClassification
from transformers import ViTFeatureExtractor

 * To load the PyTorch native [scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html), specify `attn_implementation="sdpa"`

In [4]:
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    attn_implementation="sdpa",
    torch_dtype=torch.float16
)

model

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [5]:
model = ViTModel.from_pretrained(
    "google/vit-base-patch16-224",
    attn_implementation="sdpa",
)

model

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTSdpaAttention(
          (attention): ViTSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUAct

In [9]:
list(model.state_dict().keys())

['embeddings.cls_token',
 'embeddings.position_embeddings',
 'embeddings.patch_embeddings.projection.weight',
 'embeddings.patch_embeddings.projection.bias',
 'encoder.layer.0.attention.attention.query.weight',
 'encoder.layer.0.attention.attention.query.bias',
 'encoder.layer.0.attention.attention.key.weight',
 'encoder.layer.0.attention.attention.key.bias',
 'encoder.layer.0.attention.attention.value.weight',
 'encoder.layer.0.attention.attention.value.bias',
 'encoder.layer.0.attention.output.dense.weight',
 'encoder.layer.0.attention.output.dense.bias',
 'encoder.layer.0.intermediate.dense.weight',
 'encoder.layer.0.intermediate.dense.bias',
 'encoder.layer.0.output.dense.weight',
 'encoder.layer.0.output.dense.bias',
 'encoder.layer.0.layernorm_before.weight',
 'encoder.layer.0.layernorm_before.bias',
 'encoder.layer.0.layernorm_after.weight',
 'encoder.layer.0.layernorm_after.bias',
 'encoder.layer.1.attention.attention.query.weight',
 'encoder.layer.1.attention.attention.query.b

In [12]:
model.encoder.layer[0].attention.attention.query.weight

Parameter containing:
tensor([[ 0.0182,  0.2526, -0.1410,  ...,  0.0096,  0.2832, -0.0863],
        [ 0.1067,  0.3417,  0.0522,  ...,  0.2693, -0.0120, -0.0830],
        [ 0.0321, -0.1272, -0.1964,  ...,  0.0236,  0.0351, -0.1173],
        ...,
        [ 0.0295,  0.0133,  0.1008,  ...,  0.0110, -0.0184, -0.1565],
        [ 0.0468,  0.0314, -0.0288,  ..., -0.0640,  0.0282, -0.2290],
        [ 0.0561,  0.1137, -0.0054,  ..., -0.0041,  0.0535,  0.0439]],
       requires_grad=True)

In [13]:
model.encoder.layer[0].attention.attention.query.bias

Parameter containing:
tensor([-2.0960e-01, -4.7015e-02, -4.5674e-01, -4.2780e-02, -4.1338e-01,
        -4.6941e-01, -4.6764e-01,  4.9690e-01, -4.6630e-01,  3.0337e-01,
        -3.1801e-01, -4.9101e-02, -1.3426e-01, -3.3898e-01,  1.3503e-02,
         2.1375e-01, -4.8420e-02,  4.7879e-01,  3.1575e-01, -4.9546e-01,
         5.1784e-02,  7.6204e-02,  4.1113e-01, -8.2276e-02, -1.3096e-01,
        -4.8452e-01, -4.3603e-01, -1.0970e-01,  2.7601e-01, -2.7287e-02,
         3.6353e-01,  2.1364e-01,  4.9591e-01, -1.0548e-01, -1.8365e-01,
         6.5965e-03,  5.3930e-02, -9.0674e-02, -1.5761e-01,  2.4053e-01,
         4.3008e-01,  1.0813e-01,  3.5249e-01,  4.7268e-01,  2.0266e-01,
         4.1082e-01,  7.8327e-02,  1.0505e-01,  1.0408e-01, -4.1858e-01,
        -1.4877e-01, -4.9772e-01, -3.5321e-02, -7.3208e-02,  3.9252e-02,
         4.9966e-01, -9.2619e-02, -4.3568e-01,  1.4587e-01, -4.1887e-01,
         4.5389e-01, -1.1517e-01,  1.8625e-01,  6.2355e-04, -9.4567e-02,
         1.7681e-01, -3.4719e