In [None]:
%pip install -q transformers mediapy

In [None]:
import torch
import mediapy as media

from transformers import BeitFeatureExtractor
from transformers import BeitForMaskedImageModeling, BeitForImageClassification
from PIL import Image
import requests

# Sample COCO val2017 image used as the feature-extraction input.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout keeps the notebook from hanging forever on a stalled connection.
response = requests.get(url, stream=True, timeout=30)
# Fail loudly on HTTP errors instead of handing an error page to PIL.
response.raise_for_status()
# Normalise to 3-channel RGB so grayscale / palette / RGBA images from other
# URLs are handled the same way as this (already RGB) JPEG.
image = Image.open(response.raw).convert('RGB')

# Model size to load: 'base' or 'large'.
keyword = 'base'
# Pooling used for the pre-trained (non fine-tuned) model:
#   'simple'  -> the [CLS] token only
#   'complex' -> [CLS] concatenated with the mean of the patch tokens
feature_style = 'complex'

# Checkpoint to load. Swap in one of the commented alternatives to change it.
model_name = f'microsoft/beit-{keyword}-patch16-224-pt22k'
# model_name = f'microsoft/beit-{keyword}-patch16-224-pt22k-ft22k'
# model_name = f'microsoft/beit-{keyword}-patch16-224'

# Checkpoints whose name ends in 'pt22k' are loaded with the masked image
# modeling architecture; every other checkpoint is treated as fine-tuned and
# loaded with the image classification architecture.
is_finetuned = not model_name.endswith('pt22k')
arch = BeitForImageClassification if is_finetuned else BeitForMaskedImageModeling

feature_extractor = BeitFeatureExtractor.from_pretrained(model_name)
model = arch.from_pretrained(model_name)

# Replace the task-specific output layer with an Identity so the forward pass
# returns features rather than predictions. The classification model keeps
# that layer in `classifier`, the masked image modeling model in `lm_head`.
# `model.layernorm` is deliberately left untouched: Hugging Face suggests it
# may belong to the head, but here it is kept as part of the feature path.
head_name = 'classifier' if is_finetuned else 'lm_head'
setattr(model, head_name, torch.nn.Identity())

inputs = feature_extractor(images=image, return_tensors="pt")

# Pure feature extraction: no gradients are needed.
with torch.no_grad():
  if is_finetuned:
    outputs = model(**inputs)
    # With `model.classifier` replaced by an Identity, `logits` holds the
    # backbone's pooled representation instead of class scores.
    pooled_output = outputs.logits
  else:
    # Do NOT read `model(**inputs).logits` here: BeitForMaskedImageModeling
    # applies `lm_head` only to `sequence_output[:, 1:]`, so those "logits"
    # contain just the patch tokens and no [CLS] token (index 0 would be the
    # first patch). Run the backbone directly and apply the model's own final
    # `layernorm`, which reproduces the MIM forward pass without the slice.
    outputs = model.beit(**inputs)
    hidden_states = model.layernorm(outputs.last_hidden_state)

    # The final hidden state of the [CLS] token
    cls_token = hidden_states[:, 0]
    # The final hidden states of the patch tokens
    patch_tokens = hidden_states[:, 1:, :]

    if feature_style == 'simple':
      pooled_output = cls_token
    elif feature_style == 'complex':
      # Mean pool
      # https://github.com/huggingface/transformers/blob/bcc3f7b6560c1ed427f051107c7755956a27a9f2/src/transformers/models/beit/modeling_beit.py#L664-L670
      mean_token = patch_tokens.mean(1)
      # Concatenate, as with ViT-Base for the eval_linear of DINO
      # https://github.com/facebookresearch/dino/blob/499d9e2b3d903355f67e86f9def06bccb4222b1f/eval_linear.py#L256-L260
      pooled_output = torch.cat([cls_token, mean_token], dim=-1)
    else:
      raise ValueError(
          f"feature_style must be 'simple' or 'complex', got {feature_style!r}")

# Show which fields the forward pass produced and the shape of each tensor.
print(outputs.keys())
for tensor in outputs.values():
  print(tensor.size())

# In the fine-tuned case `pooled_output` is `outputs.logits`, whose size was
# already printed above, so only report it for the pre-trained model.
if not is_finetuned:
  print(pooled_output.size())