# VisionTransformerモデルの学習

## 必要なライブラリをインポート

In [14]:
import os
import sys
sys.path.append('../tools')
from make_dataloader import make_dataloader
import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import onnx
import onnxruntime as ort
import numpy as np

## データの準備

In [2]:
# archive内の画像を前処理して、データローダーを作成
train_loader, val_loader, test_loader = make_dataloader()

画像の読み込み完了
0〜5000の範囲でバッチ処理完了
5001〜10000の範囲でバッチ処理完了
10001〜15000の範囲でバッチ処理完了
15001〜の範囲でバッチ処理完了
画像の前処理完了
DataLoader作成完了
train_loader:  547
val_loader:  97
test_loader:  114


## モデルの準備

In [3]:
# ViTのロード (例: ViT-B/16)
model = models.vit_b_16(pretrained=True)

# クラス数の変更（10クラス分類）
num_classes = 10
model.heads.head = torch.nn.Linear(model.hidden_dim, num_classes)

# 出力層のみ学習させる
for param in model.parameters():
    param.requires_grad = False

for param in model.heads.parameters():
    param.requires_grad = True

# 損失関数の定義
criterion = nn.CrossEntropyLoss()

# 最適化手法の定義
optimizer = optim.Adam(model.parameters(), lr=0.0001)


# GPUの準備
device = torch.device('mps')
model.to(device)



VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

## モデルを学習させる

In [4]:
num_epochs = 5

# モデルの学習
model.train()

for epoch in range(5):
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        if i % 100 == 0:
            print(f'epoch: {epoch}, batch: {i}, loss: {loss.item()}')
    accuracy = 100 * correct / total
    print(f'epoch: {epoch}, loss: {loss.item()}, accuracy: {accuracy}%')

epoch: 0, batch: 0, loss: 2.41745662689209
epoch: 0, batch: 100, loss: 1.261074185371399
epoch: 0, batch: 200, loss: 0.8919990062713623
epoch: 0, batch: 300, loss: 0.5796253681182861
epoch: 0, batch: 400, loss: 0.4079235792160034
epoch: 0, batch: 500, loss: 0.2765389680862427
epoch: 0, loss: 0.5733827948570251, accuracy: 87.29559748427673%
epoch: 1, batch: 0, loss: 0.2306126058101654
epoch: 1, batch: 100, loss: 0.3346799612045288
epoch: 1, batch: 200, loss: 0.24965627491474152
epoch: 1, batch: 300, loss: 0.1704227179288864
epoch: 1, batch: 400, loss: 0.18623177707195282
epoch: 1, batch: 500, loss: 0.1392485499382019
epoch: 1, loss: 0.16541068255901337, accuracy: 97.03259005145797%
epoch: 2, batch: 0, loss: 0.3399019241333008
epoch: 2, batch: 100, loss: 0.20481735467910767
epoch: 2, batch: 200, loss: 0.3353480398654938
epoch: 2, batch: 300, loss: 0.06539370119571686
epoch: 2, batch: 400, loss: 0.09884653240442276
epoch: 2, batch: 500, loss: 0.06973402947187424
epoch: 2, loss: 0.36529296

## モデルの保存

In [8]:
os.makedirs('../models', exist_ok = True)
torch.save(model.state_dict(),'../models/vit_cls.pth')

## 推論させる

In [13]:
# モデルのロード
model = models.vit_b_16(weights=None)

# クラス数の変更（10クラス分類）
num_classes = 10
model.heads.head = torch.nn.Linear(model.hidden_dim, num_classes)

model.load_state_dict(torch.load('../models/vit_cls.pth'))

# モデルをデバイスに移動
model.to(device)

# モデルの評価
model.eval()

batch_size = 32
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = model(images)
        
        # 予測ラベル取得
        _, predicted = torch.max(outputs, 1)
        
        # 正解数のカウント
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
# 精度計算
accuracy = 100 * correct/total

print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 97.66%


# 変換したONNX形式のモデルで再度推論

onnx-optimizer-tool > convert_to_onnx.pyを使用して変換した`.onnx`ファイルを使う

In [None]:
# モデルのロード