In [56]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import tqdm

from torchvision.datasets import FashionMNIST
from torchvision import transforms

# Seed PyTorch's global RNG before anything random happens, so the shuffled
# batch order, the weight initialisation and the dropout masks are repeatable
# under "Restart Kernel -> Run All" (on CPU).
SEED = 0
torch.manual_seed(SEED)

# Both splits are stored under the same root directory and share one transform.
DATA_ROOT = "./FashionMNIST"
to_tensor = transforms.ToTensor()

# download=True fetches the dataset into DATA_ROOT if it is not already there.
fashion_mnist_train = FashionMNIST(DATA_ROOT, train=True, download=True, transform=to_tensor)
fashion_mnist_test = FashionMNIST(DATA_ROOT, train=False, download=True, transform=to_tensor)

batch_size = 128
# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(fashion_mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(fashion_mnist_test, batch_size=batch_size, shuffle=False)

In [57]:
from torch import nn, optim

class FlattenLayer(nn.Module):
    """Collapse every dimension after the first into a single one.

    An input of shape ``(N, d1, d2, ...)`` is returned as a view of shape
    ``(N, d1 * d2 * ...)``; the leading (batch) dimension is kept as is.
    """

    def forward(self, x):
        batch_dim = x.size(0)
        # -1 lets view() infer the merged length from the remaining elements.
        return x.view(batch_dim, -1)

# Convolutional feature extractor: two stages of
# Conv(5x5) -> MaxPool(2) -> ReLU -> BatchNorm -> Dropout2d, then flatten
# the result to (batch, features).
conv_net = nn.Sequential(
    nn.Conv2d(1, 32, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.Dropout2d(0.25),
    nn.Conv2d(32, 64, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.Dropout2d(0.25),
    FlattenLayer(),
)

# Measure the flattened feature width by pushing one dummy 1-channel 28x28
# image through the extractor (presumably the FashionMNIST input shape).
# The probe runs in eval mode under no_grad so that it does not update the
# BatchNorm running statistics with the all-ones batch, does not draw dropout
# masks from the RNG, and does not build an autograd graph. The output shape
# is the same in eval and training mode.
test_input = torch.ones(1, 1, 28, 28)
conv_net.eval()
with torch.no_grad():
    conv_output_size = conv_net(test_input).size()[-1]
# Freshly built modules start in training mode; restore it after the probe.
conv_net.train()

# Classifier head: one hidden layer of 200 units, then 10 output scores.
mlp = nn.Sequential(
    nn.Linear(conv_output_size, 200),
    nn.ReLU(),
    nn.BatchNorm1d(200),
    nn.Dropout(0.25),
    nn.Linear(200, 10)
)

# Full model: feature extractor followed by the classifier head.
net = nn.Sequential(
    conv_net,
    mlp
)
    

In [60]:
# Evaluation helper
def eval_net(net, data_loader, device="cpu"):
    """Return the classification accuracy of ``net`` over ``data_loader``.

    Args:
        net: model called on each input batch; its output is reduced with
            ``max(1)``, so dimension 1 is treated as the class axis.
        data_loader: iterable yielding ``(inputs, targets)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        The fraction of samples whose predicted class equals the target, as a
        Python float.

    Note:
        ``net`` is switched to eval mode and is left in that mode on return.
    """
    # Eval mode turns off Dropout and makes BatchNorm use its running stats.
    net.eval()
    targets_seen = []
    predictions = []
    for inputs, targets in data_loader:
        # Move the batch to the device the computation runs on.
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Only a forward pass is needed here, so gradient tracking is disabled.
        with torch.no_grad():
            # max(1) returns (values, indices); the indices are the predicted classes.
            batch_pred = net(inputs).max(1)[1]
        targets_seen.append(targets)
        predictions.append(batch_pred)
    # Merge the per-batch results into one tensor each.
    all_targets = torch.cat(targets_seen)
    all_predictions = torch.cat(predictions)
    # Accuracy = number of matches / number of samples.
    n_correct = (all_targets == all_predictions).float().sum()
    accuracy = n_correct / len(all_targets)
    return accuracy.item()

# Training helper
def train_net(net, train_loader, test_loader,
              optimizer_cls=optim.Adam,
              loss_fn=nn.CrossEntropyLoss(),
              n_iter=10, device="cpu", writer=None):
    """Train ``net`` for ``n_iter`` epochs and report metrics after each epoch.

    Args:
        net: model to optimise. It is called on each input batch and its
            ``parameters()`` are handed to the optimizer. It is not moved to
            ``device`` here, so the caller must already have placed it there.
        train_loader: iterable of ``(inputs, targets)`` batches that also
            supports ``len()`` (used as the progress-bar total).
        test_loader: batches passed to ``eval_net`` for the validation accuracy.
        optimizer_cls: optimizer factory, called as
            ``optimizer_cls(net.parameters())`` (library-default settings).
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar
            loss tensor. The default instance is created once, when the
            function is defined, and shared between calls.
        n_iter: number of epochs to run.
        device: device each training batch is moved to.
        writer: optional logger exposing ``add_scalar`` / ``add_scalars``
            (presumably a TensorBoard ``SummaryWriter`` -- confirm with caller).

    Returns:
        None. Each epoch prints ``epoch, mean train loss, train accuracy,
        validation accuracy`` and, if ``writer`` is given, logs the same values.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    train_losses = []
    train_acc = []
    val_acc = []
    optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # Put the network in training mode (eval_net leaves it in eval mode).
        net.train()
        n_batches = 0
        n = 0
        n_acc = 0
        # An epoch takes a while, so show a progress bar with tqdm.
        for xx, yy in tqdm.tqdm(train_loader, total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            h = net(xx)
            loss = loss_fn(h, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
            n += len(xx)
            # Training accuracy uses the outputs computed before this step's
            # parameter update (and with dropout active).
            _, y_pred = h.max(1)
            n_acc += (yy == y_pred).sum().item()
        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")
        # Mean loss per batch. Divide by the number of batches, not by the
        # last zero-based index: that would overstate the loss and divide by
        # zero when there is exactly one batch.
        train_losses.append(running_loss / n_batches)
        # Accuracy on the training data
        train_acc.append(n_acc / n)
        # Accuracy on the validation data
        val_acc.append(eval_net(net, test_loader, device))
        # Show this epoch's results
        print(epoch, train_losses[-1], train_acc[-1], val_acc[-1], flush=True)
        if writer is not None:
            writer.add_scalar('train_loss', train_losses[-1], epoch)
            writer.add_scalars('accuracy', {
                "train": train_acc[-1],
                "validation": val_acc[-1]
            }, epoch)
    

In [62]:


# Train the model for 20 epochs on the CPU; train_net prints the loss and the
# train/validation accuracy after every epoch.
# NOTE(review): the recorded output starts at ~94% training accuracy in
# epoch 0, which suggests `net` had already been trained earlier in the same
# kernel session -- re-run from a fresh kernel to confirm the numbers.
train_net(
    net=net,
    train_loader=train_loader,
    test_loader=test_loader,
    n_iter=20,
    device="cpu",
)


100%|██████████| 469/469 [00:19<00:00, 23.76it/s]


0 0.1523215719538494 0.9416666666666667 0.9207000136375427


100%|██████████| 469/469 [00:19<00:00, 24.10it/s]


1 0.14783135146443915 0.9445166666666667 0.9205999970436096


100%|██████████| 469/469 [00:19<00:00, 23.94it/s]


2 0.14491991791078168 0.9452166666666667 0.9257000088691711


100%|██████████| 469/469 [00:19<00:00, 24.03it/s]


3 0.141862417714527 0.9462166666666667 0.9236999750137329


100%|██████████| 469/469 [00:19<00:00, 24.03it/s]


4 0.13965138497674823 0.9473166666666667 0.9239000082015991


100%|██████████| 469/469 [00:19<00:00, 23.99it/s]


5 0.1361343022595104 0.9486166666666667 0.9217000007629395


100%|██████████| 469/469 [00:20<00:00, 22.89it/s]


6 0.1346049709802764 0.9497166666666667 0.9162999987602234


100%|██████████| 469/469 [00:19<00:00, 24.01it/s]


7 0.13284392349230936 0.94895 0.9248999953269958


100%|██████████| 469/469 [00:19<00:00, 24.11it/s]


8 0.1287074496046218 0.9521666666666667 0.926800012588501


100%|██████████| 469/469 [00:19<00:00, 23.98it/s]


9 0.12882915394715008 0.9506666666666667 0.9240000247955322


100%|██████████| 469/469 [00:19<00:00, 23.95it/s]


10 0.12582424450984114 0.9526333333333333 0.9246000051498413


100%|██████████| 469/469 [00:19<00:00, 24.10it/s]


11 0.12319397196396549 0.9525 0.9243000149726868


100%|██████████| 469/469 [00:21<00:00, 21.55it/s]


12 0.123963701327801 0.9523 0.9258000254631042


100%|██████████| 469/469 [00:21<00:00, 22.17it/s]


13 0.11899544309394863 0.95465 0.923799991607666


100%|██████████| 469/469 [00:23<00:00, 20.03it/s]


14 0.11595542249707584 0.9557166666666667 0.9229000210762024


100%|██████████| 469/469 [00:19<00:00, 23.86it/s]


15 0.11746128170520194 0.9556 0.9223999977111816


100%|██████████| 469/469 [00:19<00:00, 23.84it/s]


16 0.11759430525872187 0.9558 0.9269000291824341


100%|██████████| 469/469 [00:21<00:00, 22.10it/s]


17 0.11632552545549524 0.9560166666666666 0.9218000173568726


100%|██████████| 469/469 [00:19<00:00, 24.19it/s]


18 0.11836416720468391 0.9563166666666667 0.9251000285148621


100%|██████████| 469/469 [00:20<00:00, 23.20it/s]


19 0.11131187988461083 0.9586833333333333 0.9247999787330627
