[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ultimatemagic79/vit-hands-on/blob/main/vit.ipynb)

In [1]:
# Mount Google Drive (Colab only) so the project folder under MyDrive becomes reachable.
from google.colab import drive as colab_drive

colab_drive.mount('/content/drive/')

Mounted at /content/drive/
[Errno 2] No such file or directory: '/content/drive/MyDrive/Colab Notebook/vit-hands-on'
/content


In [2]:
# Move into the project folder on the mounted Drive; relative paths used later (e.g. './data') resolve from here.
%cd '/content/drive/MyDrive/Colab Notebooks/vit-hands-on'

/content/drive/MyDrive/Colab Notebooks/vit-hands-on




---



# Vision Transformerを実装してみよう
ViTモデルを実際に実装し，CIFAR10の画像分類タスクに取り組んでみましょう．
## 1. データセット準備
画像データセット[CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)をダウンロードし，前処理を行います．


In [8]:
# モジュールのインポート
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader, Subset
from torchvision.datasets import CIFAR10

import warnings
warnings.filterwarnings("ignore")

In [4]:
# データロード関数を定義
## 引数batch_sizeはミニバッチの大きさ

def load_data(batch_size):

    # クラスのラベル名
    classes = ('airplane', 'automobile', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    ## 前処理関数の準備
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # ViT に合わせて画像サイズを変更
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # CIFAR10の準備（ローカルにデータがない場合はダウンロードされる）
    # 訓練用データセット
    trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
    # 評価用データセット
    testset = CIFAR10(root='./data', train=False, download=True, transform=transform)

    # 訓練用データローダー
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    # 評価用データローダー
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    return (trainloader, testloader, classes)

In [6]:
# Build the CIFAR10 loaders; `classes` holds the label names indexed by class id.
BATCH_SIZE = 64
trainloader, testloader, classes = load_data(BATCH_SIZE)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 16079236.37it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## 2. VisionTransformerの定義
実際に学習するVision Transformer（ViT）モデルを定義してみましょう．ViTの全体像は以下のようになっており，

* Patch Embedding
* Positional Encoding
* Transformer Encoder
* Head

の大きく4つの要素を組み立てることで構築できます．
![](https://production-media.paperswithcode.com/methods/Screen_Shot_2021-01-26_at_9.43.31_PM_uI4jjMq.png)


### Patch Embedding

In [None]:
# Patch Embedding (exercise: implement this class, or run the worked example below)
class PatchEmbedding(nn.Module):
    """Exercise skeleton: split an image into patches and embed each patch as a token.

    Constructor arguments follow the worked example below; see it for the
    expected input/output shapes.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # TODO: create the layer(s) that project each non-overlapping
        # patch_size x patch_size patch to an embed_dim vector.
        raise NotImplementedError("PatchEmbedding.__init__ is left as an exercise")

    def forward(self, x):
        # TODO: turn a batch of images into a sequence of patch embeddings.
        raise NotImplementedError("PatchEmbedding.forward is left as an exercise")

#### 例

In [9]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A Conv2d whose kernel size and stride both equal ``patch_size`` projects every
    non-overlapping patch to an ``embed_dim`` vector in a single operation.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        feature_map = self.proj(x)                         # (B, C, H, W)
        tokens = torch.flatten(feature_map, start_dim=2)   # (B, C, N)
        return torch.transpose(tokens, 2, 1)               # (B, N, C)

### Positional Encoding

In [None]:
# Positional Encoding (exercise: implement this class, or run the worked example below)
class PositionalEncoding(nn.Module):
    """Exercise skeleton: add position information to a sequence of token embeddings.

    Constructor arguments follow the worked example below.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # TODO: precompute an encoding for every position up to max_len.
        raise NotImplementedError("PositionalEncoding.__init__ is left as an exercise")

    def forward(self, x):
        # TODO: add the encoding for the first x.size(1) positions to x.
        raise NotImplementedError("PositionalEncoding.forward is left as an exercise")

#### 例

In [10]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    Even embedding dimensions receive sin(pos * w_i) and odd dimensions cos(pos * w_i),
    with frequencies w_i = 10000^(-2i / d_model).

    Args:
        d_model: embedding size of the tokens.
        max_len: longest sequence the precomputed table supports.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Register as a buffer so .to(device) / .cuda() / .double() move it together with
        # the module (a plain tensor attribute stays on the CPU). It is non-persistent
        # because it is recomputed deterministically, which keeps state_dict unchanged.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)  # (1, max_len, d_model)

    def forward(self, x):
        """Add the encoding to ``x``; dim 1 of ``x`` indexes positions (at most ``max_len``)."""
        return x + self.encoding[:, :x.size(1)]

### Encoder

In [None]:
# Transformer Encoder (exercise: implement this class, or run the worked example below)
class TransformerEncoder(nn.Module):
    """Exercise skeleton: one Transformer encoder block (self-attention + feed-forward).

    Constructor arguments follow the worked example below.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # TODO: create the normalisation, attention and feed-forward sub-layers.
        raise NotImplementedError("TransformerEncoder.__init__ is left as an exercise")

    def forward(self, src):
        # TODO: apply attention and feed-forward sub-layers with residual connections.
        raise NotImplementedError("TransformerEncoder.forward is left as an exercise")

#### 例

In [11]:
class TransformerEncoder(nn.Module):
    """One pre-norm Transformer encoder block.

    ``x = x + Dropout(MHSA(LN(x)))`` followed by ``x = x + Dropout(FFN(LN(x)))``.

    Args:
        embed_dim: token embedding size.
        num_heads: number of attention heads (must divide ``embed_dim``).
        ff_dim: hidden size of the feed-forward network.
        dropout: dropout rate used inside attention and on both residual branches.

    Input and output are ``(batch, seq_len, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True: the model passes (B, N, C). With the default (False),
        # dim 0 would be treated as the sequence, making tokens attend across images.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # Normalise once and reuse it as query, key and value (self-attention).
        normed = self.norm1(src)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        src = src + self.dropout(attn_out)
        src = src + self.dropout(self.ff(self.norm2(src)))
        return src

### Vision Transformer

In [None]:
# Whole Vision Transformer (exercise: assemble the components above, or run the worked example below)
class VisionTransformer(nn.Module):
    """Exercise skeleton: combine patch embedding, a [CLS] token, positional encoding,
    a stack of encoder blocks and a classification head.

    Constructor arguments follow the worked example below.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_heads=12, ff_dim=3072, num_layers=12):
        super().__init__()
        # TODO: build the sub-modules.
        raise NotImplementedError("VisionTransformer.__init__ is left as an exercise")

    def forward(self, x):
        # TODO: map a batch of images to class logits.
        raise NotImplementedError("VisionTransformer.forward is left as an exercise")

#### 例

In [12]:
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add positional
    encoding -> ``num_layers`` encoder blocks -> final LayerNorm -> linear head
    applied to the [CLS] token.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: token embedding size.
        num_heads: attention heads per encoder block.
        ff_dim: hidden size of each block's feed-forward network.
        num_layers: number of encoder blocks.
        num_classes: number of output logits (default 1000; pass 10 for CIFAR10).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_heads=12, ff_dim=3072, num_layers=12, num_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = PositionalEncoding(embed_dim)
        self.transformer = nn.ModuleList([
            TransformerEncoder(embed_dim, num_heads, ff_dim) for _ in range(num_layers)
        ])
        # The encoder blocks apply LayerNorm *before* each sub-layer, so the residual
        # stream leaving the last block is un-normalised; normalise it once before the head.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map an image batch to class logits of shape ``(B, num_classes)``."""
        x = self.patch_embed(x)                                # (B, N, embed_dim)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)  # one [CLS] token per image
        x = torch.cat((cls_tokens, x), dim=1)                  # (B, N + 1, embed_dim)
        x = self.pos_embed(x)
        for layer in self.transformer:
            x = layer(x)
        x = self.norm(x)
        return self.head(x[:, 0])                              # classify from the [CLS] token

## 3. モデルの学習

In [None]:
# Training function
def train(model, data_loader, criterion, optimizer, device):
    """Run one training epoch over ``data_loader``.

    Args:
        model: network to optimise; it is switched to training mode.
        data_loader: iterable of ``(images, labels)`` mini-batches.
        criterion: loss function; assumed to return the mean loss over a
            mini-batch (``reduction='mean'``, the ``nn.CrossEntropyLoss`` default).
        optimizer: optimiser over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``, both averaged over every sample seen this epoch.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    n_seen = 0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass and loss
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the per-batch mean loss by batch size so a smaller final batch
        # is not over-represented in the epoch average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        # Count samples actually seen (correct even with drop_last or custom samplers).
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("data_loader yielded no samples")
    return total_loss / n_seen, correct / n_seen

# Evaluation function
def evaluate(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` without updating its parameters.

    Args:
        model: network to evaluate; it is switched to eval mode.
        data_loader: iterable of ``(images, labels)`` mini-batches.
        criterion: loss function; assumed to return the mean loss over a
            mini-batch (``reduction='mean'``, the ``nn.CrossEntropyLoss`` default).
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``, both averaged over every sample in the loader.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Weight by batch size so a smaller final batch is not over-represented.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("data_loader yielded no samples")
    return total_loss / n_seen, correct / n_seen


In [None]:
# Setup: device, model, optimizer and loss function
torch.manual_seed(0)  # reproducible weight initialisation and shuffling
NUM_CLASSES = 10      # CIFAR10 has 10 classes
LEARNING_RATE = 1e-4  # 1e-3 without warmup is typically too aggressive for a ViT trained from scratch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VisionTransformer()
# VisionTransformer() builds a 1000-way head by default; replace it with a CIFAR10-sized one.
model.head = nn.Linear(model.head.in_features, NUM_CLASSES)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [None]:
# Training: each epoch makes one optimisation pass over the training set,
# then measures loss/accuracy on the test set (no parameter updates there).
num_epochs = 10
for epoch in range(num_epochs):
    train_loss, train_accuracy = train(
        model, trainloader, criterion, optimizer, device
    )
    test_loss, test_accuracy = evaluate(
        model, testloader, criterion, device
    )
    print(
        f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, '
        f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}'
    )
