In [15]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import tqdm

from torchvision import models

# 事前学習済みのresnet18をロード
net = models.resnet18(pretrained=True)

# すべてのパラメータを微分対象外にする
for p in net.parameters():
    p.requires_grad=False
    
# 最後の線形層を付け替える
fc_input_dim = net.fc.in_features
net.fc = nn.Linear(fc_input_dim, 2)



In [16]:
from torchvision.datasets import ImageFolder
from torchvision import transforms

# ImageFolder関数を使用してDatasetを作成する
train_imgs = ImageFolder(
    "./train/",
    transform=transforms.Compose([
      transforms.RandomCrop(224),
      transforms.ToTensor()]
))
test_imgs = ImageFolder(
    "./test/",
    transform=transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor()]
))

# DataLoaderを作成
train_loader = DataLoader(
    train_imgs, batch_size=32, shuffle=True)
test_loader = DataLoader(
    test_imgs, batch_size=32, shuffle=False)

In [17]:
def eval_net(net, data_loader, device="cpu"):
    # DropoutやBatchNormを無効化
    net.eval()
    ys = []
    ypreds = []
    for x, y in data_loader:
        # toメソッドで計算を実行するデバイスに転送する
        x = x.to(device)
        y = y.to(device)
        # 確率が最大のクラスを予測(リスト2.14参照)
        # ここではforward（推論）の計算だけなので自動微分に
        # 必要な処理はoffにして余計な計算を省く
        with torch.no_grad():
            _, y_pred = net(x).max(1)
        ys.append(y)
        ypreds.append(y_pred)
    # ミニバッチごとの予測結果などを1つにまとめる
    ys = torch.cat(ys)
    ypreds = torch.cat(ypreds)
    # 予測精度を計算
    acc = (ys == ypreds).float().sum() / len(ys)
    return acc.item()

def train_net(net, train_loader, test_loader,
              only_fc=True,
              optimizer_cls=optim.Adam,
              loss_fn=nn.CrossEntropyLoss(),
              n_iter=10, device="cpu"):
    train_losses = []
    train_acc = []
    val_acc = []
    if only_fc:
        # 最後の線形層のパラメータのみを、
        # optimizerに渡す
        optimizer = optimizer_cls(net.fc.parameters())
    else:
        optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # ネットワークを訓練モードにする
        net.train()
        n = 0
        n_acc = 0
        # 非常に時間がかかるのでtqdmを使用してプログレスバーを出す
        for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader),
            total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            h = net(xx)
            loss = loss_fn(h, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n += len(xx)
            _, y_pred = h.max(1)
            n_acc += (yy == y_pred).float().sum().item()
        train_losses.append(running_loss / i)
        # 訓練データの予測精度
        train_acc.append(n_acc / n)
        # 検証データの予測精度
        val_acc.append(eval_net(net, test_loader, device))
        # このepochでの結果を表示
        print(epoch, train_losses[-1], train_acc[-1],
              val_acc[-1], flush=True)

In [18]:
train_net(net, train_loader, test_loader, n_iter=20, device="cpu")

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|██████████| 23/23 [00:20<00:00,  1.10it/s]


0 0.7120946401899512 0.5842696629213483 0.75


100%|██████████| 23/23 [00:18<00:00,  1.24it/s]


1 0.5713946196165952 0.7317415730337079 0.8500000238418579


100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


2 0.4928287890824405 0.8033707865168539 0.8500000238418579


100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


3 0.45581467991525476 0.7963483146067416 0.8500000238418579


100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


4 0.41439543528990314 0.8300561797752809 0.8666666746139526


100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


5 0.42009292813864624 0.8300561797752809 0.8666666746139526


100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


6 0.39981270920146594 0.8300561797752809 0.8666666746139526


100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


7 0.3975791030309417 0.851123595505618 0.8666666746139526


100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


8 0.35625700652599335 0.8764044943820225 0.8833333253860474


100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


9 0.37283470549366693 0.8412921348314607 0.8666666746139526


100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


10 0.38331797041676263 0.8623595505617978 0.8666666746139526


100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


11 0.3469416519457644 0.8665730337078652 0.8833333253860474


100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


12 0.33070391890677536 0.8665730337078652 0.8666666746139526


100%|██████████| 23/23 [00:20<00:00,  1.11it/s]


13 0.3393551043488763 0.8595505617977528 0.8833333253860474


100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


14 0.3280466009270061 0.8820224719101124 0.8833333253860474


100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


15 0.3208178004080599 0.8848314606741573 0.8833333253860474


100%|██████████| 23/23 [00:18<00:00,  1.27it/s]


16 0.3222892690788616 0.8721910112359551 0.8666666746139526


100%|██████████| 23/23 [00:20<00:00,  1.10it/s]


17 0.32595928622917697 0.8735955056179775 0.8999999761581421


100%|██████████| 23/23 [00:17<00:00,  1.29it/s]


18 0.2970754991878163 0.8820224719101124 0.8833333253860474


100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


19 0.3330279453234239 0.8665730337078652 0.8833333253860474
