# ViT:Vision-Transformer

reference: [Transfromers](https://github.com/huggingface/transformers)

Notebook Author: [xiaodongguaAIGC](https://github.com/dhcode-cpp)

Author: xiaodongguaAIGC

github: dhcode-cpp

gmail: dhcode95@gmail.com

![ViT](./images/ViT.png)

In [2]:
import collections.abc
from collections import OrderedDict
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# config 

In [3]:
class ViTConfig():
    model_type = "vit"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        encoder_stride=16,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.encoder_stride = encoder_stride

hidden_size = 64 
num_attention_heads= 8
num_hidden_layers = 2
intermediate_size = 256
image_size = 224
patch_size = 16

config = ViTConfig(hidden_size = hidden_size,
                  num_attention_heads= num_attention_heads,
                  num_hidden_layers = num_hidden_layers,
                  intermediate_size = intermediate_size,
                  image_size = image_size,
                  patch_size = patch_size)
print(config.__dict__)

# Dummy Dataset

In [4]:
batch_size = 8
num_channels = 3
x_src = torch.randn(batch_size, num_channels, image_size, image_size)
x_src.shape

torch.Size([8, 3, 224, 224])

# Model

## ViT Embedding

In [5]:
class ViTPatchEmbeddings(nn.Module):
    '''
    input:(batch_size, num_channels, height, width)
    output:(batch_size, seq_length, hidden_size)
    '''
    
    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # channel -> hidden size
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) 

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        # embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        projection = self.projection(pixel_values) # 8, 64, 16, 16
        flatten = projection.flatten(2)            # 8, 64, 196
        embeddings = flatten.transpose(1, 2)       # 8, 196, 64
        return embeddings

ViT_patch = ViTPatchEmbeddings(config)
print(ViT_patch)

In [7]:
print(ViT_patch.projection)
y_patch = ViT_patch(x_src)

print(config.hidden_size)
print(x_src.shape)
print(y_patch.shape)

In [8]:
# 每16*16 窗口的patch，算一次卷积
min_input = torch.randn(batch_size, num_channels, 16, 16)
min_output = ViT_patch(min_input)
print(min_output.shape)


min_input = torch.randn(batch_size, num_channels, 16, 32)
min_output = ViT_patch(min_input)
print(min_output.shape)


min_input = torch.randn(batch_size, num_channels, 32, 32)
min_output = ViT_patch(min_input)
print(min_output.shape)

In [9]:
class ViTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = ViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config


    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1) 
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings)

        return embeddings
        
ViT_embedding = ViTEmbeddings(config, use_mask_token=False)
print(ViT_embedding)

In [10]:
print(ViT_embedding.cls_token.shape)
print(ViT_embedding.patch_embeddings.projection.weight.shape) 
print(ViT_embedding.position_embeddings.shape) # 196+1

x_embd = ViT_embedding.patch_embeddings(x_src)
x_cls_embd = ViT_embedding(x_src)

print(x_src.shape)
print(x_embd.shape)
print('add cls token:', x_cls_embd.shape)

print('cls token expand')
print(ViT_embedding.cls_token.expand(batch_size, -1, -1).shape) # 1, 1, 64 -> 8, 1, 64

## ViT: Attention

In [11]:
class ScaleDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        batch_size, head, seq_len, head_dim = k.size()
        k_t = k.transpose(2,3)
        score = (q @ k_t) / math.sqrt(head_dim) 
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)
        score = self.softmax(score)
        v = score @ v
        return v, score

In [12]:
class ViTSelfAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        # self.output = nn.Linear(config.hidden_size, config.hidden_size)
        # self.dropout_output = nn.Dropout(config.hidden_dropout_prob)

        self.attention = ScaleDotProductAttention()
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        
        batch_size, seq_len, _ = hidden_states.shape
        q, k, v = self.query(hidden_states), self.key(hidden_states), self.value(hidden_states)
        q = q.view(batch_size, seq_len, self.num_attention_heads, self.attention_head_size)
        k = k.view(batch_size, seq_len, self.num_attention_heads, self.attention_head_size)
        v = v.view(batch_size, seq_len, self.num_attention_heads, self.attention_head_size)

        out,score = self.attention(q,k,v,head_mask)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # attention_probs = self.dropout(attention_probs)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.all_head_size)
        # out = self.dropout_output(self.output(out))
        
        outputs = (out, score) if output_attentions else (out,)

        return outputs
        
ViT_attention = ViTSelfAttention(config)
print(ViT_attention)

In [13]:
x_attn = ViT_attention(x_cls_embd)[0]
print(x_cls_embd.shape)
print(x_attn.shape)

In [14]:
class ViTSelfOutput(nn.Module):
    """
    The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states
ViT_selfout = ViTSelfOutput(config)
print(ViT_selfout)

In [15]:
x_selfout = ViT_selfout(x_attn, None)
print(x_selfout.shape)

In [16]:
class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        # self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
        self_outputs[0]
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them 
        return outputs
        
ViT_attention = ViTAttention(config)
print(ViT_attention)

In [17]:
print(x_cls_embd.shape)
x_attn = ViT_attention(x_cls_embd)
print(x_attn[0].shape)

## ViT: ffn

### ViT: activation

gelu & gelu difference

In [18]:
class GELUActivation(nn.Module):
    """
    Original Implementation of the GELU activation function in Google BERT repo when initially created. For
    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
    Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu

    def _gelu_python(self, input: torch.Tensor) -> torch.Tensor:
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.act(input)
gelu = GELUActivation()
dummy_tensor = torch.randn(1,2,3)
print(dummy_tensor)
activate_tensor = gelu(dummy_tensor)
print(activate_tensor)

In [19]:
class ClassInstantier(OrderedDict):
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)
ACT2CLS = {
    "gelu": GELUActivation,
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    "swish": nn.SiLU,
}
ACT2FN = ClassInstantier(ACT2CLS)

In [20]:
class ViTIntermediate(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

ViT_ffn = ViTIntermediate(config)
print(ViT_ffn)

In [21]:
x_ffn = ViT_ffn(x_attn[0])
print(x_ffn.shape)

In [22]:
class ViTOutput(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor # resnet

        return hidden_states
ViT_down = ViTOutput(config)
print(ViT_down)

In [23]:
x_down = ViT_down(x_ffn, x_attn[0])
print(x_down.shape)

## ViT: Layer

In [24]:
class ViTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        # self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTAttention(config)
        self.intermediate = ViTIntermediate(config)
        self.output = ViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
ViT_layer = ViTLayer(config)
print(ViT_layer)

In [25]:
x_layer = ViT_layer(x_cls_embd)
print(x_layer[0].shape)

## ViT: Encoder

In [26]:
from dataclasses import dataclass
@dataclass
class BaseModelOutput(OrderedDict):
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

In [27]:
class ViTEncoder(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            # if self.gradient_checkpointing and self.training:
            #     layer_outputs = self._gradient_checkpointing_func(
            #         layer_module.__call__,
            #         hidden_states,
            #         layer_head_mask,
            #         output_attentions,
            #     )
            # else:
            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
ViT_encoder = ViTEncoder(config)
print(ViT_encoder)

In [28]:
output = ViT_encoder(x_cls_embd, output_attentions=True, output_hidden_states=True)
print(output.last_hidden_state.shape)
print(output.hidden_states[0].shape)
print(output.attentions[0].shape)

## ViT: PretrainedModel

In [29]:
class ViTPreTrainedModel(nn.Module):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
    _supports_sdpa = True

    def __init__(self, config: ViTConfig, *inputs, **kwargs):
        super().__init__()
        self.config = config
        self.warnings_issued = {}
        

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)
            
    def post_init(self):
        pass
        # self._init_weights()

## ViT: Model

In [30]:
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)  # switch to float if need + fp16 compatibility
        return head_mask

def get_head_mask(
        self, head_mask: Optional[torch.Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
    ) -> torch.Tensor:
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            if is_attention_chunked is True:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers
        return head_mask

In [35]:
@dataclass
class BaseModelOutputWithPooling(nn.Module):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class ViTPooler(nn.Module):
    def __init__(self, config: ViTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

In [36]:
class ViTModel(ViTPreTrainedModel):
# class ViTModel(nn.Module):
    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        super().__init__(config)
        self.config = config

        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()
        self.config.output_attentions = False
        self.config.output_hidden_states = False
        self.config.use_return_dict = True

        # def get_head_mask(
        #     self, head_mask: Optional[torch.Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
        # ) -> torch.Tensor:
        #     if head_mask is not None:
        #         head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        #         if is_attention_chunked is True:
        #             head_mask = head_mask.unsqueeze(-1)
        #     else:
        #         head_mask = [None] * num_hidden_layers
        #     return head_mask
        self.get_head_mask = get_head_mask

    def get_input_embeddings(self) -> ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        # head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(
            pixel_values, 
            # bool_masked_pos=bool_masked_pos, 
            # interpolate_pos_encoding=interpolate_pos_encoding
        )
        # print(embedding_output.shape)

        encoder_outputs = self.encoder(
            embedding_output,
            # head_mask=head_mask,
            # output_attentions=output_attentions,
            # output_hidden_states=output_hidden_states,
            # return_dict=return_dict,
        )
        # print(encoder_outputs)
        # sequence_output = encoder_outputs[0]
        sequence_output = encoder_outputs.last_hidden_state
        # print(sequence_output.shape)
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

ViT = ViTModel(config)
print(ViT)

In [37]:
output = ViT(x_src,)
output.last_hidden_state.shape

torch.Size([8, 197, 64])

## ViT: Calssifier

In [38]:
class ViTForImageClassification(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vit = ViTModel(config, add_pooling_layer=False)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()
        self.config.use_return_dict=True

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # sequence_output = outputs[0]
        sequence_output = outputs.last_hidden_state

        logits = self.classifier(sequence_output[:, 0, :]) # 第0号token向量进行分类

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # if self.config.problem_type is None:
            #     if self.num_labels == 1:
            #         self.config.problem_type = "regression"
            #     elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
            #         self.config.problem_type = "single_label_classification"
            #     else:
            #         self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            # elif self.config.problem_type == "multi_label_classification":
            #     loss_fct = BCEWithLogitsLoss()
            #     loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return (
            loss,
            logits,
            outputs.hidden_states,
            outputs.attentions,
        )
config.problem_type = "single_label_classification"
config.num_labels = 2
ViT_classifier = ViTForImageClassification(config)
print(ViT_classifier)

In [39]:
print(ViT_classifier.classifier.weight.shape)
y = ViT_classifier(x_src)
# print(y)
# print(y[0].shape)
print(y[1].shape)

# Loss

In [40]:
labels = torch.randint(0, 2, (batch_size,))
print(labels.shape)
loss = ViT_classifier(x_src, labels = labels)[0]
print(loss)

# Dataset Loader

ref:[transformers official tutorials](https://huggingface.co/docs/transformers/tasks/image_classification)

In [41]:
from datasets import load_dataset

# keremberke/pokemon-classification
# datasets = load_dataset("JannikB/food101_sample_n100", split="train[:5000]")
# datasets_eval = load_dataset("JannikB/food101_sample_n100", split="validation[:1024]")
datasets = load_dataset("keremberke/pokemon-classification",name="full", split="train")
datasets_eval = load_dataset("keremberke/pokemon-classification", name="full", split="test")

In [42]:
print(datasets[0])
print(datasets)

In [43]:
# datasets = datasets.train_test_split(test_size=0.2)
# print(datasets)

In [44]:
labels = datasets.features["labels"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label
# print(label2id)
# print(id2label)
print(len(label2id))
print(id2label['1'])
print(label2id['Goldeen'])

In [45]:
# from transformers import AutoImageProcessor

# checkpoint = "google/vit-base-patch16-224-in21k"
# image_processor = AutoImageProcessor.from_pretrained(checkpoint)
# print(image_processor)

In [46]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# mean=image_processor.image_mean

# normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
size = (
    224, 224
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

In [47]:
datasets_transforms = datasets.with_transform(transforms)
datasets_eval = datasets_eval.with_transform(transforms)

In [48]:
print(datasets_transforms[0]['pixel_values'].shape)
print(datasets_transforms[0]['labels'])

In [49]:
data = datasets_transforms.__getitem__(0)
data['pixel_values'].shape
data['labels']

57

In [50]:
def collate_fn(batch):
    # print(batch[0])
    # print(batch['pixel_values'][0].shape)
    images = [ item['pixel_values']  for item in batch ]
    labels = [ item['labels']  for item in batch ]
    # print(labels)
    images = torch.stack(images, 0)
    labels = torch.tensor(labels, dtype=torch.int32)
    return (images, labels)

In [51]:
from torch.utils.data import DataLoader
batch_size = 16
dataloader = DataLoader(datasets_transforms, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dataloader_eval = DataLoader(datasets_eval, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [52]:
for i, batch in enumerate(dataloader):
    print(i)
    print(batch[0].shape)
    print(batch[1].shape)
    break

# Trainer

[ref](https://learn.microsoft.com/en-us/windows/ai/windows-ml/tutorials/pytorch-train-model)

In [53]:
 # mps for mac, if you haven't mps or cuda, use 'cpu'
device = torch.device("cuda:0" if torch.cuda.is_available() else "mps") # mps for mac 
# device = 'cuda:0'

config.hidden_size = 256 
config.num_attention_heads= 8
config.num_hidden_layers = 8
config.intermediate_size = 512
config.image_size = 224
config.patch_size = 16

config.problem_type = "single_label_classification"
config.num_labels = len(label2id)

model = ViTForImageClassification(config)
model.to(device)
print(model)

In [54]:
import torch.optim
# loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)

In [55]:
def testAccuracy():
    model.eval()
    accuracy = 0.0
    total = 0.0
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    with torch.no_grad():
        # for data in dataloader_eval:
        for data in dataloader_eval:
            images, labels = data
            # print(images.shape)
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs[1], 1)
            # print(outputs[1].shape)
            # print(predicted)
            # print(labels)
            total += labels.size(0)
            accuracy += (predicted == labels.to(device)).sum().item()
            # break
    accuracy = (100.0 * accuracy / total)
    return(accuracy)
testAccuracy()   

0.0

In [56]:
# num_epochs = 1
import tqdm #
num_epochs = 10
best_accuracy = 0.0
for epoch in range(num_epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    running_acc = 0.0

    total_steps = len(dataloader)
    progress_bar = tqdm.tqdm(enumerate(dataloader), total=total_steps, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    # for step, batch in enumerate(progress_bar):
    for step, batch in progress_bar:
        optimizer.zero_grad()
        loss = model(batch[0].to(device), labels = batch[1].to(device))[0]
        loss.backward()
        optimizer.step()
        running_loss += loss.item()     # extract the loss value
        if step % 10 == 0:    
            # print('[%d, %3d] loss: %f' %
            #       (epoch + 1, step, running_loss / 10))
            
            progress_bar.set_postfix(mse=float(running_loss/10))
            running_loss = 0.0
            
    accuracy = testAccuracy()
    progress_bar.set_postfix(acc=float(accuracy))
    print('For epoch', epoch+1,'the test accuracy over the whole test set is %f %%' % (accuracy))
    
    # we want to save the model if the accuracy is the best
    # if accuracy > best_accuracy:
    #     saveModel()
    #     best_accuracy = accuracy

Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████| 305/305 [01:14<00:00,  4.07it/s, mse=4.2]


Epoch 2/10:   1%|▉                                                                   | 4/305 [00:01<01:18,  3.83it/s, mse=0.389]


# Other

## Pooler out

In [57]:
class Pooler(nn.Module):
    def __init__(self, config: ViTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0] # seq position 0 token is [cls] token
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

pooler = Pooler(config)

hidden_state = torch.randn(1, 197, config.hidden_size)
print(hidden_state.shape)
pooler_output = pooler(hidden_state)
print(pooler_output.shape)

## position interploration

In [61]:
# implemention in 
# class ViTEmbeddings(nn.Module): 
#     def interpolate_pos_encoding(....):

import copy

width_scale = 3
height_scale = 2
embedding = torch.randn(1, 196 * height_scale * width_scale + 1, 256) # other embedding

print(model.vit.embeddings.position_embeddings.shape)
basic_pe = copy.deepcopy(model.vit.embeddings.position_embeddings).to('cpu')

height = config.image_size * height_scale
width = config.image_size * width_scale

num_patches = embedding.shape[1] - 1
num_positions = basic_pe.shape[1] - 1
if num_patches == num_positions and height == width:
    print(basic_pe.shape)
    
class_pos_embed = basic_pe[:, 0]
patch_pos_embed = basic_pe[:, 1:]
dim = embeddings.shape[-1]
h0 = height // config.patch_size
w0 = width // config.patch_size

# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

print(patch_pos_embed.shape)

patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)


print(patch_pos_embed.shape)


patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed,
    scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
    mode="bicubic",
    align_corners=False,
)


print(patch_pos_embed.shape)


patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
print(patch_pos_embed.shape)
print(class_pos_embed.shape)

new_pe = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

print(new_pe.shape)

In [62]:

original_embed = torch.randn(1,4,2,2) # 4是位置编码特征维度， (2,2)表示 4个patch
print(original_embed.shape)
print(num_positions)


print(h0) # original 14
print(w0)
print(h0 / math.sqrt(num_positions))
print(w0 / math.sqrt(num_positions))

interpolate_embed = nn.functional.interpolate(
    original_embed,
    scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
    mode="bilinear", # bilinear, bicubic
    align_corners=False,
)
print(interpolate_embed.shape)

print(original_embed)
print(interpolate_embed)

# 
print(original_embed[0,0,...]) # 取第1号的向量， 取第0个位置编码特征值， (2,2)表示 4个patch
print(interpolate_embed[0,0,...]) # 取第1号的向量， 取第0个位置编码特征值， (4,6)表示 25个patch

## head mask

In [63]:
# 可以将特定头进行mask, 个人理解本质上与dropout作用相近
import torch

batch_size = 1  # 批次大小
num_patches = 2  # 图像patch数量
num_heads = 3   # 注意力头数量
dim = 4  # 每个头的维度
num_layers = 5  # 层数

head_mask = torch.rand(num_layers, num_heads)  # 随机生成mask
head_mask = (head_mask > 0.5).float()  # 将mask二值化
# print(head_mask)
print(head_mask.shape)

# 假设我们有一个attention输出
attention_output = torch.rand(batch_size, num_patches, num_heads, dim)
print(attention_output.shape)

# 扩展head_mask以匹配attention输出的形状
expanded_mask = head_mask[None, None, :, :, None].expand(batch_size, num_patches, -1, -1, dim)
print(expanded_mask.shape)

# 应用mask
masked_attention_output = attention_output * expanded_mask[:, :, 0]  # 假设我们只对第一层进行mask