In [5]:
import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig

# Load the pretrained ViT model.
# `model_name` is the checkpoint identifier passed to `ViTModel.from_pretrained`;
# the resulting `vit_model` is the backbone later wrapped by `ViTWithClassifier`.
# NOTE(review): the recorded cell output reports the pooler weights
# ('vit.pooler.dense.*') as newly initialized. The code in this cell only reads
# `last_hidden_state`, so that looks harmless here -- confirm no later cell
# relies on `pooler_output`.
model_name = "google/vit-base-patch16-384"
vit_model = ViTModel.from_pretrained(model_name)

# Define a self-attention classifier
class SelfAttentionClassifier(nn.Module):
    """One multi-head self-attention layer followed by a linear classifier.

    The input is used as query, key and value of the attention layer, and the
    linear layer is then applied to every attended token independently, so the
    logits keep the token dimension of the input (no pooling is done here).

    Args:
        embed_dim: Size of the last (feature) dimension of the input. Must be
            divisible by ``num_heads`` (enforced by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
        num_classes: Output size of the linear classifier.
        batch_first: Layout of the input. ``True`` (default) means
            ``(batch, seq_len, embed_dim)``, which is the layout of the ViT
            hidden states this head is fed with. Pass ``False`` for the
            ``(seq_len, batch, embed_dim)`` layout.
    """

    def __init__(self, embed_dim, num_heads, num_classes, batch_first=True):
        super().__init__()
        # nn.MultiheadAttention defaults to batch_first=False, i.e. it treats
        # dim 0 as the sequence. With batch-first inputs that silently makes
        # tokens attend across different samples of the batch instead of
        # across the tokens of one sample, so the layout is set explicitly.
        self.self_attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=batch_first
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Apply self-attention over the tokens of ``x`` and classify each token.

        Args:
            x: Tensor laid out according to ``batch_first`` (by default
                ``(batch, seq_len, embed_dim)``).

        Returns:
            A ``(logits, attn_weights)`` tuple. ``logits`` has the shape of
            ``x`` with the last dimension replaced by ``num_classes``;
            ``attn_weights`` are the weights returned by
            ``nn.MultiheadAttention`` (for batched input:
            ``(batch, seq_len, seq_len)``).
        """
        attn_output, attn_weights = self.self_attn(x, x, x)
        logits = self.classifier(attn_output)
        return logits, attn_weights

# Attach the self-attention classifier to the ViT model
class ViTWithClassifier(nn.Module):
    """ViT backbone whose patch-token outputs feed a SelfAttentionClassifier.

    Args:
        vit_model: Backbone module. It must expose ``config.hidden_size`` and,
            when called with ``pixel_values=...``, return an object that has a
            ``last_hidden_state`` attribute.
        num_heads: Number of attention heads for the classifier head.
        num_classes: Number of output classes for the classifier head.
    """

    def __init__(self, vit_model, num_heads, num_classes):
        super().__init__()
        # The head's embedding size follows the backbone's hidden size.
        hidden_size = vit_model.config.hidden_size
        self.vit_model = vit_model
        self.classifier = SelfAttentionClassifier(
            embed_dim=hidden_size,
            num_heads=num_heads,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Run the backbone on ``x`` and classify its patch embeddings.

        Args:
            x: Input forwarded to the backbone as ``pixel_values``.

        Returns:
            The ``(logits, attn_weights)`` pair produced by the classifier
            head for the patch tokens.
        """
        # last_hidden_state shape: (batch_size, seq_len, hidden_size)
        hidden_states = self.vit_model(pixel_values=x).last_hidden_state
        # Index 0 along the sequence axis is the class token; keep only the
        # patch embeddings that follow it.
        patch_tokens = hidden_states[:, 1:, :]
        head_logits, head_attn = self.classifier(patch_tokens)
        return head_logits, head_attn

# Build the full model: the pretrained ViT backbone plus the classifier head
num_heads = 8  # Number of attention heads
num_classes = 10  # Number of target classes for your classification task
model = ViTWithClassifier(
    vit_model=vit_model,
    num_heads=num_heads,
    num_classes=num_classes,
)

  from .autonotebook import tqdm as notebook_tqdm
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-384 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
