In [1]:
!pip install -q -U  datasets

In [2]:
!pip install -q -U torchinfo

In [3]:
!pip install -q -U transformers

In [4]:
!pip install accelerate -U



In [5]:
import torch
from torch import nn
from torchinfo import summary
from datasets import load_dataset
from transformers import ViTFeatureExtractor
import numpy as np
from datasets import load_metric
from transformers import TrainingArguments
from transformers import Trainer


# Model

In [6]:
class PatchEmbedding(nn.Module):
    """Turns a 2D input image into a 1D sequence learnable embedding vector."""
    def __init__(self,
                 in_channels:int=3,
                 patch_size:int=16,
                 embedding_dim:int=768):
        super().__init__()

        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)

        self.flatten = nn.Flatten(start_dim=2,
                                  end_dim=3)

    def forward(self, x):
        image_resolution = x.shape[-1]
        # assert image_resolution % patch_size == 0, f"Input image size must be divisble by patch size, image shape: {image_resolution}, patch size: {patch_size}"

        x_patched = self.patcher(x)
        x_flattened = self.flatten(x_patched)
        return x_flattened.permute(0, 2, 1)

In [7]:
class MultiheadSelfAttentionBlock(nn.Module):
    """Creates a MSA block """
    def __init__(self,
                 embedding_dim:int=768,
                 num_heads:int=12,
                 attn_dropout:float=0): # according to Appendix B.1, dropout isn't used after the qkv-projections
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x):
        x = self.layer_norm(x)
        attn_output, _ = self.multihead_attn(query=x,
                                             key=x,
                                             value=x,
                                             need_weights=False)
        return attn_output

In [8]:
class MLPBlock(nn.Module):
    """Creates a MLP layer"""
    def __init__(self,
                 embedding_dim:int=768,
                 mlp_size:int=3072,
                 dropout:float=0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim,
                      out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size,
                      out_features=embedding_dim),
            nn.Dropout(p=dropout)
        )

    def forward(self, x):
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x

In [9]:
class TransformerEncoderBlock(nn.Module):
    """Creates a Transformer Encoder block."""
    def __init__(self,
                 embedding_dim:int=768,
                 num_heads:int=12,
                 mlp_size:int=3072,
                 mlp_dropout:float=0.1,
                 attn_dropout:float=0):
        super().__init__()

        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)

        self.mlp_block =  MLPBlock(embedding_dim=embedding_dim,
                                   mlp_size=mlp_size,
                                   dropout=mlp_dropout)

    def forward(self, x):
        x =  self.msa_block(x) + x      #Residual Connection
        x = self.mlp_block(x) + x       #Residual Connection
        return x

In [10]:
class ClassifierHead(nn.Module):
  """Creates a Classifier ViT."""
  def __init__(self,
               embedding_dim: int= 768,
               num_classes:int = 1000):
    super().__init__()
    self.classifier = nn.Sequential(nn.LayerNorm(normalized_shape=embedding_dim),
                                    nn.Linear(in_features=embedding_dim,
                                              out_features=num_classes))

  def forward(self, x):
    x = self.classifier(x)
    return x

In [11]:
class ViT(nn.Module):
    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
    def __init__(self,
                 img_size:int=224,
                 in_channels:int=3,
                 patch_size:int=16,
                 num_transformer_layers:int=12,
                 embedding_dim:int=768,
                 mlp_size:int=3072,
                 num_heads:int=12,
                 attn_dropout:float=0,
                 mlp_dropout:float=0.1,
                 embedding_dropout:float=0.1,
                 num_classes:int=1000):
        super().__init__()

        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

        self.num_patches = (img_size * img_size) // patch_size**2

        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)

        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
                                               requires_grad=True)

        self.embedding_dropout = nn.Dropout(p=embedding_dropout)
        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                              patch_size=patch_size,
                                              embedding_dim=embedding_dim)
        # Stack 12 (num_transformer_layers) Transformer Encoder blocks
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                            num_heads=num_heads,
                                                                            mlp_size=mlp_size,
                                                                            mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])

        self.classifier = ClassifierHead(embedding_dim=embedding_dim,
                                         num_classes=num_classes)

    def forward(self, x):

        batch_size = x.shape[0]
        class_token = self.class_embedding.expand(batch_size, -1, -1) # creates copies of the class embedding batch_size times along 0th dim keeping first and second same

        x = self.patch_embedding(x)
        x = torch.cat((class_token, x), dim=1)
        x = self.position_embedding + x
        x = self.embedding_dropout(x)

        x = self.transformer_encoder(x)

        x = self.classifier(x[:, 0]) # Keeping all batch and taking the first element of x from 197
        return x

In [12]:
MNIST_ViT = ViT(img_size=28,
                 in_channels=1,
                 patch_size=7,
                 num_transformer_layers=3,
                 embedding_dim=49,    # patch**2 *  color channel -> 7 **2 *1 = 49
                 mlp_size=196,
                 num_heads=7,
                 attn_dropout=0,
                 mlp_dropout=0.1,
                 embedding_dropout=0.1,
                 num_classes=10)

In [13]:
summary(model=MNIST_ViT,
        input_size=(32, 1, 28, 28),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
ViT (ViT)                                                    [32, 1, 28, 28]      [32, 10]             882                  True
├─PatchEmbedding (patch_embedding)                           [32, 1, 28, 28]      [32, 16, 49]         --                   True
│    └─Conv2d (patcher)                                      [32, 1, 28, 28]      [32, 49, 4, 4]       2,450                True
│    └─Flatten (flatten)                                     [32, 49, 4, 4]       [32, 49, 16]         --                   --
├─Dropout (embedding_dropout)                                [32, 17, 49]         [32, 17, 49]         --                   --
├─Sequential (transformer_encoder)                           [32, 17, 49]         [32, 17, 49]         --                   True
│    └─TransformerEncoderBlock (0)                           [32, 17, 49]         [32, 17, 49]  

# Data

In [14]:
train = load_dataset("mnist", split='train', data_dir='/content/train')
test = load_dataset("mnist", split='test', data_dir='/content/test')



In [15]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)



In [None]:
from torchvision.transforms import Compose, ColorJitter, ToTensor

jitter = Compose(
    [ColorJitter(brightness=0.5, hue=0.5), ToTensor()]
)

In [None]:
def transforms(examples):
    examples["pixel_values"] = [jitter(image.convert("RGB")) for image in examples["image"]]
    return examples

In [31]:
type(train[0]["image"].to_numpy())

AttributeError: ignored

In [16]:
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

In [17]:
prepared_trainds = train.with_transform(transform)
prepared_testds = test.with_transform(transform)

In [18]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

In [19]:
metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

  metric = load_metric("accuracy")


In [20]:
training_args = TrainingArguments(
  output_dir="./minist_vit",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  load_best_model_at_end=True,
)

In [28]:
prepared_trainds[0]["image"]

ValueError: ignored

In [23]:
trainer = Trainer(
    model=MNIST_ViT,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_trainds,
    # eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)

In [24]:
trainer.train()



ValueError: ignored