# Example 07: PyTorch Fine-tuning

参考
- [PyTorchによるファインチューニングの実装](https://venoda.hatenablog.com/entry/2020/10/18/014516)
- [【PyTorch】畳み込みニューラルネットワーク（CNN）で転移学習・ファインチューニングをする方法（VGG16を題材に添えて）](https://qiita.com/harutine/items/d37656affad4ce7e088d#%E8%BB%A2%E7%A7%BB%E5%AD%A6%E7%BF%92)

## 事前準備

In [1]:
import torch

# Choose the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
# NOTE: a tensor or module is moved onto it with `x = x.to(device)`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

## ファインチューニング

In [2]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

### DataLoader設定

In [3]:
# Building blocks of the preprocessing pipelines.
resize = transforms.Resize(size=(224, 224))  # resize to the 224x224 input size used for VGG
normalize = transforms.Normalize(
    mean=(0.0, 0.0, 0.0),
    std=(1.0, 1.0, 1.0),
)  # subtracts 0 and divides by 1 per channel, so values pass through unchanged
flip = transforms.RandomHorizontalFlip(p=0.5)
affine = transforms.RandomAffine(degrees=(-30, 30), scale=(0.8, 1.2))

# Training pipeline: random augmentation first, then resize -> tensor -> normalize.
transform_train = transforms.Compose(
    [affine, flip, resize, transforms.ToTensor(), normalize]
)

# Evaluation pipeline: the same preprocessing without any random augmentation.
transform_test = transforms.Compose(
    [resize, transforms.ToTensor(), normalize]
)

In [4]:
# Build the CIFAR-10 train and test datasets; `download=True` fetches the data
# into the root directory when it is not already there.
_cifar10_root = '../cache/data'
cifar10_train = CIFAR10(_cifar10_root, train=True, transform=transform_train, download=True)
cifar10_test = CIFAR10(_cifar10_root, train=False, transform=transform_test, download=True)

# Class names, indexed by integer label.
cifar10_classes = cifar10_train.classes

In [5]:
# Mini-batch loaders over the two datasets.
batch_size = 128

# Training batches are reshuffled on every pass; test batches keep a fixed order.
train_loader = DataLoader(dataset=cifar10_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=cifar10_test, batch_size=batch_size, shuffle=False)

In [6]:
# Sanity check: (number of training samples, number of test samples, number of classes).
tuple(map(len, (cifar10_train, cifar10_test, cifar10_classes)))

(50000, 10000, 10)

### ネットワークの作成

VGGの学習済みモデルをファインチューニングのベースとする

#### モデル読み込み

In [7]:
# Use the pretrained weights rather than a randomly initialised network.
use_pretrained = True

# Load VGG16; with `pretrained=True` torchvision downloads the checkpoint on first use.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version before changing this call.
net = models.vgg16(pretrained=use_pretrained)



Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


100% 528M/528M [00:48<00:00, 11.5MB/s] 


#### パラメータ固定

学習させるパラメータ以外は勾配計算を無効にし、値が更新されないように固定する  
(本来はすべてのパラメータを調整したいが、学習時間を減らすために一部を固定している)

In [8]:
# Names of the parameters that stay trainable: the weight and bias of
# classifier layers 0, 3 and 6. Every other parameter is frozen below.
update_param_names = [
    f"classifier.{layer}.{kind}"
    for layer in (0, 3, 6)
    for kind in ("weight", "bias")
]

# Enable gradients only for the listed parameters; all others are frozen.
for param_name, parameter in net.named_parameters():
    parameter.requires_grad_(param_name in update_param_names)

# NOTE1: the parameter names can be listed with:
#for name, param in net.named_parameters():
#    print(name)

# NOTE2: for transfer learning instead of fine-tuning, freeze the parameters with the
#        code below. It freezes every parameter, so once the final output layer is
#        replaced only that layer is trained; whether to train more layers than that
#        is up to you.
#for param in net.parameters():
#    param.requires_grad = False

#### 最終出力層の書き換え

In [9]:
# Replace the last classifier layer with a new Linear layer that outputs one
# score per CIFAR-10 class. A freshly constructed nn.Linear has
# requires_grad=True on its parameters, so this layer is trainable.
n_classes = len(cifar10_classes)
net.classifier[6] = nn.Linear(4096, n_classes)

# Move the whole model onto the selected device.
net = net.to(device)

In [10]:
# Display the module tree to check the model structure after the layer replacement.
net

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

### 学習

In [None]:
%%time
# Fine-tuning loop: each epoch runs one training pass and one evaluation pass
# and records the mean per-batch loss of both.
criterion = nn.CrossEntropyLoss()
# Give the optimizer only the parameters that are still trainable; frozen
# parameters never receive gradients, so there is nothing for Adam to update.
optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad])

record_loss_train = []  # mean training loss per epoch
record_loss_test = []   # mean test loss per epoch

n_epochs = 30
verbose = 1  # print every `verbose` epochs; 0 disables the per-epoch print

for epoch in tqdm(range(n_epochs)):
    # ---- training pass ----
    net.train()  # training mode (e.g. Dropout active)
    loss_train = 0.0

    for (x, t) in tqdm(train_loader):
        x, t = x.to(device), t.to(device)
        y = net(x)

        loss = criterion(y, t)
        loss_train += loss.item()

        # Clear gradients left over from the previous step before backpropagating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_train /= len(train_loader)
    record_loss_train.append(loss_train)

    # ---- evaluation pass ----
    net.eval()  # inference mode (e.g. Dropout disabled)
    loss_test = 0.0

    # No backward pass happens here, so skip building autograd graphs:
    # this saves memory and time without changing the loss values.
    with torch.no_grad():
        for (x, t) in tqdm(test_loader):
            x, t = x.to(device), t.to(device)
            y = net(x)

            loss = criterion(y, t)
            loss_test += loss.item()

    loss_test /= len(test_loader)
    record_loss_test.append(loss_test)

    # Guard against verbose == 0, which would otherwise raise ZeroDivisionError.
    if verbose and epoch % verbose == 0:
        print(f'epoch: {epoch + 1}, loss_train: {loss_train:.4f}, loss_test: {loss_test:.4f}')

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 1, loss_train: 1.3702, loss_test: 0.8135


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 2, loss_train: 1.2119, loss_test: 0.7544


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 3, loss_train: 1.1769, loss_test: 0.7875


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 4, loss_train: 1.1336, loss_test: 0.7067


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 5, loss_train: 1.1014, loss_test: 0.7188


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 6, loss_train: 1.0929, loss_test: 0.7259


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 7, loss_train: 1.0892, loss_test: 0.7039


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 8, loss_train: 1.0732, loss_test: 0.6686


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 9, loss_train: 1.0600, loss_test: 0.6887


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 10, loss_train: 1.0526, loss_test: 0.6784


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 11, loss_train: 1.0460, loss_test: 0.6822


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 12, loss_train: 1.0289, loss_test: 0.6739


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 13, loss_train: 1.0257, loss_test: 0.6281


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 14, loss_train: 1.0182, loss_test: 0.6440


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 15, loss_train: 1.0070, loss_test: 0.6183


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 16, loss_train: 1.0165, loss_test: 0.6346


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 17, loss_train: 1.0005, loss_test: 0.6476


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 18, loss_train: 0.9935, loss_test: 0.6227


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 19, loss_train: 0.9887, loss_test: 0.6192


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 20, loss_train: 0.9842, loss_test: 0.6079


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 21, loss_train: 0.9761, loss_test: 0.6049


  0%|          | 0/391 [00:00<?, ?it/s]

### 誤差の推移

In [None]:
# Plot the recorded per-epoch losses. With no explicit x values, pyplot uses
# the sample index 0..N-1, which is the epoch index.
for history, legend_name in ((record_loss_train, 'Train'), (record_loss_test, 'Test')):
    plt.plot(history, label=legend_name)

plt.xlabel("Epochs")
plt.ylabel("Error")
plt.legend()
plt.show()

### 正解率算出

In [None]:
# Accuracy over the test loader: correctly classified samples / all samples.
correct = 0
total = 0

net.eval()

# Inference only: disable gradient tracking so no autograd graph is built.
with torch.no_grad():
    for x, t in test_loader:
        x, t = x.to(device), t.to(device)
        y = net(x)

        # argmax over dim 1 picks the highest-scoring class for each sample.
        correct += (y.argmax(1) == t).sum().item()
        total += len(x)

print(f"accuracy: {correct / total}")

### 訓練済みモデルを使用した予測

In [None]:
def get_sample_image(root: str = '../cache/data') -> tuple[torch.Tensor, torch.Tensor]:
    """Return one randomly drawn CIFAR-10 test image and its label.

    Args:
        root: Directory holding (or receiving) the CIFAR-10 files. The default
            is the directory used for the datasets earlier in the notebook, so
            the data is not downloaded a second time.

    Returns:
        ``(image, label)``: the image after ``transform_test`` and its class
        index, both taken from a shuffled batch of size 1.
    """
    dataset = CIFAR10(root=root, train=False, download=True, transform=transform_test)
    # shuffle=True makes the first batch a random sample.
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    images, labels = next(iter(loader))

    # Drop the batch dimension by taking the only element.
    return images[0], labels[0]

def show_image(image: torch.Tensor) -> None:
    """Display an image tensor with matplotlib, without axis ticks or labels.

    Args:
        image: Image tensor whose first dimension is the channel dimension;
            it is permuted so that channels come last before plotting.
            May live on any device and may be tracking gradients.
    """
    # Detach and copy to host memory first, so tensors on an accelerator or
    # tensors that track gradients can be plotted as well.
    plt.imshow(image.detach().cpu().permute(1, 2, 0))
    # Hide tick labels and tick marks on both axes.
    plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    plt.show()

In [None]:
# Draw one random test image, show it, and compare the prediction with the true label.
image, label = get_sample_image()
show_image(image)

# Add a leading batch dimension of size 1 before feeding the network.
image = image.unsqueeze(dim=0)

net.eval()
image, label = image.to(device), label.to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    y = net(image)

# Convert the 0-d tensors to plain ints before indexing the class-name list.
print(f"answer: {cifar10_classes[label.item()]}, predict: {cifar10_classes[y.argmax().item()]}")