# VisionTransformerモデルの学習

## 必要なライブラリをインポート

In [23]:
import os
import sys
sys.path.append('../tools')
from _make_dataloader import _make_dataloader
import torch
import torchvision.models as models
from torchvision.models.vision_transformer import ViT_B_16_Weights
import torch.nn as nn
import torch.optim as optim
import onnx
import onnxruntime as ort
import numpy as np

## データの準備

In [24]:
# archive内の画像を前処理して、データローダーを作成
train_loader, val_loader, test_loader = _make_dataloader()

画像の読み込み完了
0〜5000の範囲でバッチ処理完了
5001〜10000の範囲でバッチ処理完了
10001〜15000の範囲でバッチ処理完了
15001〜の範囲でバッチ処理完了
画像の前処理完了
DataLoader作成完了
train_loader:  547
val_loader:  97
test_loader:  114


## モデルの準備

In [6]:
# モデルの定義
class ModelClass(nn.Module):
    def __init__(self):
        super(ModelClass, self).__init__()
        # モデルのロード
        self.model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        # クラス数の変更（10クラス分類）
        num_classes = 10
        self.model.heads.head = torch.nn.Linear(self.model.hidden_dim, num_classes)

    def forward(self, x):
        return self.model(x)
    

model = ModelClass()

# 出力層のみ学習させる
for param in model.model.parameters():
    param.requires_grad = False

for param in model.model.heads.parameters():
    param.requires_grad = True

# 損失関数の定義
criterion = nn.CrossEntropyLoss()

# 最適化手法の定義
optimizer = optim.Adam(model.parameters(), lr=0.0001)


# GPUの準備
device = torch.device('mps')
model.to(device)

ModelClass(
  (model): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
          (ln_1): Laye

## モデルを学習させる

In [7]:
num_epochs = 3

# モデルの学習
model.train()

for epoch in range(num_epochs):
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        if i % 100 == 0:
            print(f'epoch: {epoch}, batch: {i}, loss: {loss.item()}')
    accuracy = 100 * correct / total
    print(f'epoch: {epoch}, loss: {loss.item()}, accuracy: {accuracy}%')

epoch: 0, batch: 0, loss: 2.3399062156677246
epoch: 0, batch: 100, loss: 1.2661019563674927
epoch: 0, batch: 200, loss: 0.9616289734840393
epoch: 0, batch: 300, loss: 0.496268093585968
epoch: 0, batch: 400, loss: 0.3515598475933075
epoch: 0, batch: 500, loss: 0.260417103767395
epoch: 0, loss: 0.2682090997695923, accuracy: 88.14751286449399%
epoch: 1, batch: 0, loss: 0.35107049345970154
epoch: 1, batch: 100, loss: 0.14413659274578094
epoch: 1, batch: 200, loss: 0.17832806706428528
epoch: 1, batch: 300, loss: 0.23385021090507507
epoch: 1, batch: 400, loss: 0.2473762184381485
epoch: 1, batch: 500, loss: 0.17811515927314758
epoch: 1, loss: 0.1396082639694214, accuracy: 97.18124642652944%
epoch: 2, batch: 0, loss: 0.08862470090389252
epoch: 2, batch: 100, loss: 0.23910005390644073
epoch: 2, batch: 200, loss: 0.12897047400474548
epoch: 2, batch: 300, loss: 0.19811570644378662
epoch: 2, batch: 400, loss: 0.19043220579624176
epoch: 2, batch: 500, loss: 0.1310405433177948
epoch: 2, loss: 0.0989

## モデルの保存

In [8]:
os.makedirs('../models', exist_ok = True)
torch.save(model.state_dict(),'../models/vit_cls.pth')

## 推論させる

In [9]:
# モデルのインスタンス化
model = ModelClass()

model.load_state_dict(torch.load('../models/vit_cls.pth'))


# モデルをデバイスに移動
model.to(device)

# モデルの評価
model.eval()

batch_size = 32
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = model(images)
        
        # 予測ラベル取得
        _, predicted = torch.max(outputs, 1)
        
        # 正解数のカウント
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
# 精度計算
accuracy = 100 * correct/total

print(f"Test Accuracy: {accuracy:.2f}%")

  model.load_state_dict(torch.load('../models/vit_cls.pth'))


Test Accuracy: 97.58%
