## UNet

[こちらのサイト](https://github.com/yoyoyo-yo/DeepLearningMugenKnock/blob/master/notes_pytorch/ImgSeg/UNet_VOC2012_pytorch.ipynb)を参考にUNetを実装する。UNetは全結合層を持たず、畳み込み層のみで構成されている。左右対称のEncoder-Decoder構造で、Encoderのpoolingを経てダウンサンプリングされた特徴マップをDecoderでアップサンプリングしていく。SegNetとの大きな違いは、Encoderの各層で出力される特徴マップをDecoderの対応する各層の特徴マップに連結(concatenation)するアプローチを導入した点。このアプローチはスキップ接続と呼ばれている。

セグメンテーションではピクセル単位で推論する。ラベルの形状は(-1,128,128)で、各ピクセルは22通りのクラス値のいずれかを取るので、モデルの出力の形状は(-1,22,128,128)となる。

<img src = 'module.png'>

## import

In [1]:
import os
import time

from tqdm.notebook import tqdm

import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from sklearn.model_selection import KFold, GroupKFold, StratifiedKFold

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
import albumentations.pytorch as Ap

import torchvision
import torchvision.models as models

In [12]:
# Spatial size passed to the Resize step of the augmentation pipelines.
IN_HEIGHT = 128
IN_WIDTH = 128

# Cross-validation settings read by train().
FOLD = 'KFOLD'
FOLD_N = 2
RANDOM_SEED = 42

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class Encoder(nn.Module):
    """Stack of ``stack_num`` Conv3x3 -> BatchNorm -> ReLU units.

    The first convolution maps ``dim1`` channels to ``dim2``; every following
    one maps ``dim2`` to ``dim2``. All convolutions use kernel 3, stride 1 and
    padding 1, so the spatial size of the input is preserved.

    Args:
        dim1: number of input channels.
        dim2: number of output channels.
        stack_num: how many Conv/BatchNorm/ReLU units to chain.
    """

    def __init__(self, dim1, dim2, stack_num = 2):
        super().__init__()

        layers = []
        in_channels = dim1
        for _ in range(stack_num):
            layers += [
                nn.Conv2d(in_channels, dim2, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(dim2),
                nn.ReLU(),
            ]
            # Only the first unit sees the external channel count.
            in_channels = dim2

        self.module = nn.Sequential(*layers)

    def forward(self, x):
        """Map (-1, dim1, H, W) to (-1, dim2, H, W)."""
        return self.module(x)
    
class UpConv(nn.Module):
    """Upsampling unit: stride-2 transposed convolution, then batch norm.

    Args:
        dim1: number of input channels.
        dim2: number of output channels.
    """

    def __init__(self, dim1, dim2):
        super().__init__()
        upsample = nn.ConvTranspose2d(dim1, dim2, kernel_size=2, stride=2)
        normalize = nn.BatchNorm2d(dim2)
        self.module = nn.Sequential(upsample, normalize)

    def forward(self, x):
        """Map (-1, dim1, H, W) to (-1, dim2, H * 2, W * 2)."""
        return self.module(x)
    
class UNet(nn.Module):
    """Encoder-decoder segmentation network with skip connections.

    Five ``Encoder`` stages (separated by four 2x2 max-poolings) form the
    contracting path. The expanding path alternates ``UpConv`` upsampling,
    channel-wise concatenation with the matching encoder feature map, and an
    ``Encoder`` stage; a final 1x1 convolution yields ``out_channel`` maps.

    Args:
        dim: channel width of the first stage; deeper stages use
            2x, 4x, 8x and 16x this width.
        in_channel: number of channels of the input image.
        out_channel: number of output channels (one per class).
    """

    def __init__(self, dim = 64, in_channel = 3, out_channel = 22):
        super().__init__()

        # Channel widths of the five resolution levels, shallow to deep.
        c1, c2, c3, c4, c5 = (dim * factor for factor in (1, 2, 4, 8, 16))

        # Contracting path.
        self.encoder1 = Encoder(in_channel, c1)
        self.encoder2 = Encoder(c1, c2)
        self.encoder3 = Encoder(c2, c3)
        self.encoder4 = Encoder(c3, c4)
        self.encoder5 = Encoder(c4, c5)

        # Expanding path. Each decoder stage receives twice the level width
        # because the upsampled map is concatenated with the skip feature.
        self.upconv4 = UpConv(c5, c4)
        self.decoder4 = Encoder(c4 * 2, c4)

        self.upconv3 = UpConv(c4, c3)
        self.decoder3 = Encoder(c3 * 2, c3)

        self.upconv2 = UpConv(c3, c2)
        self.decoder2 = Encoder(c2 * 2, c2)

        self.upconv1 = UpConv(c2, c1)
        self.decoder1 = nn.Sequential(
            Encoder(c1 * 2, c1),
            nn.Conv2d(c1, out_channel, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        """Return per-pixel class scores for the image batch ``x``."""
        # Contracting path: remember every pre-pooling feature map.
        skips = []
        for stage in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            feature = stage(x)
            skips.append(feature)
            x = F.max_pool2d(feature, 2, stride=2, padding=0)
        x = self.encoder5(x)

        # Expanding path: consume the skips from deepest to shallowest.
        levels = (
            (self.upconv4, self.decoder4),
            (self.upconv3, self.decoder3),
            (self.upconv2, self.decoder2),
            (self.upconv1, self.decoder1),
        )
        for (upsample, decode), skip in zip(levels, reversed(skips)):
            x = decode(torch.cat([upsample(x), skip], dim=1))

        return x

## transforms

In [4]:
def _make_pipeline(*augmentations):
    """Compose resize -> given augmentations -> normalize -> tensor conversion."""
    steps = [A.Resize(IN_HEIGHT, IN_WIDTH)]
    steps.extend(augmentations)
    # Normalize is called without mean/std, so the library defaults apply.
    steps.append(A.Normalize(max_pixel_value=255.0, p=1.0))
    steps.append(Ap.ToTensorV2(p=1.0))
    return A.Compose(steps)


# Training adds random horizontal and vertical flips; validation has no
# random augmentation.
transforms_train = _make_pipeline(
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
)

transforms_val = _make_pipeline()

## Datasets

In [5]:
class VOCDataset(Dataset):
    """Dataset pairing images with their per-pixel label masks.

    Args:
        Xs: indexable collection of images.
        ts: indexable collection of label masks, aligned with ``Xs``.
        transforms: optional callable invoked as
            ``transforms(image=x, mask=t)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
        ignore_border: when True, label value 255 is remapped to 0.
    """

    def __init__(self, Xs, ts, transforms = None, ignore_border = True):
        self.Xs = Xs
        self.ts = ts
        self.transforms = transforms
        self.data_num = len(Xs)
        self.ignore_border = ignore_border

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        """Return ``(image, mask)`` where the mask is an int64 tensor."""
        x = self.Xs[idx]
        t = self.ts[idx]

        if self.ignore_border:
            # Remap on a copy so the mask stored in self.ts (owned by the
            # caller) is never modified in place.
            t = t.clone() if torch.is_tensor(t) else np.array(t)
            t[t == 255] = 0

        if self.transforms:
            transformed = self.transforms(image = x, mask = t)
            x = transformed['image']
            t = transformed['mask']

        # The mask may still be a numpy array here (no transform, or one that
        # does not convert to tensors); as_tensor accepts both forms.
        return x, torch.as_tensor(t).long()

In [6]:
def show_sample(Xs, show_num = 8, name = 'input'):
    """Plot the first images of a batch side by side in one row.

    Args:
        Xs: image batch tensor; the transpose below treats it as
            (N, C, H, W).
        show_num: maximum number of images to draw; clamped to the batch
            size so a short batch does not raise IndexError.
        name: title drawn above every image.
    """
    # Channels-last layout for imshow: (N, C, H, W) -> (N, H, W, C).
    imgs = Xs.detach().cpu().numpy().transpose(0, 2, 3, 1)

    # Min-max scale to [0, 255]. This is done out of place on purpose:
    # .numpy() shares memory with a CPU tensor, so in-place arithmetic here
    # would overwrite the caller's batch.
    imgs = imgs - imgs.min()
    peak = imgs.max()
    if peak > 0:
        # A constant batch has peak == 0; skip the division to avoid NaNs.
        imgs = imgs / peak
    imgs = (imgs * 255).astype(np.uint8)

    show_num = min(show_num, len(imgs))

    plt.figure(figsize = (12, 1))

    for i in range(show_num):
        plt.subplot(1, show_num, i + 1)
        plt.imshow(imgs[i])
        plt.title(name)
        plt.axis('off')

    plt.show()
    
def show_sample_seg(Xs, show_num = 8, name = 'output'):
    """Plot the predicted class map of the first samples in one row.

    Args:
        Xs: batch of per-class score maps; the class axis is axis 1
            (the original note gives the shape as (8, 22, H, W)).
        show_num: maximum number of maps to draw; clamped to the batch size
            so a short batch does not raise IndexError.
        name: title drawn above every map.
    """
    # Highest-scoring class per pixel: (N, C, H, W) -> (N, H, W).
    preds = Xs.detach().cpu().numpy().argmax(axis = 1)

    show_num = min(show_num, len(preds))

    # Same figure size as the sibling helpers show_sample / show_sample_label.
    plt.figure(figsize = (12, 1))

    for i in range(show_num):
        plt.subplot(1, show_num, i + 1)
        # 'jet' colormap with fixed limits 0..21 so that a class index maps
        # to the same colour in every image.
        # Colormap list: https://beiznotes.org/matplot-cmap-list/
        plt.imshow(preds[i], cmap = 'jet', vmin = 0, vmax = 21)
        plt.title(name)
        plt.axis('off')

    plt.show()
    
def show_sample_label(Xs, show_num = 8, name = 'label'):
    """Plot the ground-truth label map of the first samples in one row.

    Args:
        Xs: batch of label maps (the original note gives the shape as
            (8, H, W)).
        show_num: maximum number of maps to draw; clamped to the batch size
            so a short batch does not raise IndexError.
        name: title drawn above every map.
    """
    labels = Xs.detach().cpu().numpy()

    show_num = min(show_num, len(labels))

    plt.figure(figsize = (12, 1))

    for i in range(show_num):
        plt.subplot(1, show_num, i + 1)
        # The label map has to be drawn with imshow; the earlier code passed
        # it to plt.title together with imshow-only keyword arguments.
        # Fixed limits 0..21 keep class colours identical across images.
        plt.imshow(labels[i], cmap = 'jet', vmin = 0, vmax = 21)
        plt.title(name)
        plt.axis('off')

    plt.show()

In [27]:
def _load_voc_pairs(ds, indices):
    """Decode the samples at ``indices`` into parallel image / mask lists.

    Every sample is fetched from ``ds`` exactly once. Item 0 becomes a
    float32 array; item 1 is converted to palette mode ('P') and then to an
    array.
    """
    images, masks = [], []
    for idx in indices:
        sample = ds[idx]
        images.append(np.array(sample[0], dtype = np.float32))
        masks.append(np.array(sample[1].convert('P')))
    return images, masks


def train():
    """Train a UNet on the VOC segmentation 'train' split with K-fold CV.

    The split strategy is selected by the module-level ``FOLD`` setting.
    Only the first fold is actually trained: the fold loop ends with an
    unconditional ``break``.

    Per epoch the function logs loss / pixel accuracy for the training and
    validation subsets; every 50 epochs it visualises the last validation
    batch, and every 100 epochs it writes a checkpoint. After the last epoch
    the final weights are saved and the learning curves are plotted.

    Returns:
        ``(train_models, train_model_paths)``: the trained models and the
        file names of their final weights, or ``None`` when ``FOLD`` is not
        a recognised value.
    """
    ds = torchvision.datasets.VOCSegmentation(root = './',
                                              image_set = 'train', download = True)
    _inds = np.arange(len(ds))

    # fold
    if FOLD == 'KFOLD':
        kf = KFold(n_splits = FOLD_N, shuffle = True, random_state = RANDOM_SEED)
        spl = kf.split(_inds)

    elif FOLD == 'GroupKFold':
        # NOTE(review): no groups are passed to split(); GroupKFold normally
        # requires them, so this branch likely fails as written - confirm.
        kf = GroupKFold(n_splits = FOLD_N)
        spl = kf.split(_inds)

    elif FOLD in ('StratifiedKFold', 'StatifiedKFold'):
        # The misspelled name is still accepted for backward compatibility.
        # NOTE(review): the sample indices are used as the stratification
        # target, i.e. every sample is its own class - confirm this is
        # intended.
        kf = StratifiedKFold(n_splits = FOLD_N, shuffle = True, random_state = RANDOM_SEED)
        spl = kf.split(_inds, _inds)

    else:
        print('invalid fold')
        return None

    train_models = []
    train_model_paths = []

    EPOCH = 200

    for fold_i, (train_idx, val_idx) in enumerate(spl):

        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []

        print(f'{FOLD} fold:{fold_i + 1}/{FOLD_N}')
        print(f'train_N={len(train_idx)}, val_N={len(val_idx)}')

        #---
        # dataset
        #---

        # Each sample is decoded once and split into image and mask.
        X_train, t_train = _load_voc_pairs(ds, train_idx)
        X_val, t_val = _load_voc_pairs(ds, val_idx)

        dataset_train = VOCDataset(X_train, t_train, transforms = transforms_train)
        dataset_val = VOCDataset(X_val, t_val, transforms = transforms_val)

        dataloader_train = DataLoader(dataset_train, batch_size = 16,
                                      num_workers = 0, shuffle = True, pin_memory = False)
        dataloader_val = DataLoader(dataset_val, batch_size = 16,
                                    num_workers = 0, shuffle = False, pin_memory = False)

        train_n = len(X_train)
        val_n = len(X_val)

        #---
        # model
        #---

        model = UNet()
        model = model.to(device)

        # Summed loss; it is divided by sample count and image area below so
        # that the per-epoch total is a per-pixel mean over the subset.
        criterion = nn.CrossEntropyLoss(reduction = 'sum')
        optimizer = optim.Adam(model.parameters(), lr = 0.0001)

        #---
        # epoch
        #---

        for epoch in range(EPOCH):

            model.train()

            tr_loss = 0
            correct = 0
            total = 0

            #---
            # train
            #---

            train_time_start = time.time()

            for step, batch in enumerate(dataloader_train):
                optimizer.zero_grad()

                xs = batch[0].to(device)
                ts = batch[1].to(device)
                # ts: (N, H, W)

                ys = model(xs)
                # ys: (N, C, H, W)

                # Flatten to one row of class scores per pixel. The class
                # count is read from the output instead of being hard-coded.
                num_classes = ys.size(1)
                _ys = ys.permute(0, 2, 3, 1).reshape(-1, num_classes)
                # (N * H * W, C)
                _ts = ts.reshape(-1)
                # (N * H * W,)

                loss = criterion(_ys, _ts) / train_n / IN_HEIGHT / IN_WIDTH
                loss.backward()

                tr_loss += loss.item()

                predicted = ys.detach().argmax(dim = 1)
                total += ys.size(0)
                correct += (predicted == ts).sum().item()

                optimizer.step()

            train_losses.append(tr_loss)

            train_accuracy = correct / total / IN_HEIGHT / IN_WIDTH
            train_accuracies.append(train_accuracy)

            train_time_end = time.time()

            #---
            # val
            #---

            model.eval()

            val_loss = 0
            val_correct = 0
            val_total = 0

            val_time_start = time.time()

            with torch.no_grad():
                for step, batch in enumerate(dataloader_val):
                    xs = batch[0].to(device)
                    ts = batch[1].to(device)

                    ys = model(xs)

                    num_classes = ys.size(1)
                    _ys = ys.permute(0, 2, 3, 1).reshape(-1, num_classes)
                    _ts = ts.reshape(-1)

                    loss = criterion(_ys, _ts)
                    val_loss += loss.item() / val_n / IN_HEIGHT / IN_WIDTH

                    predicted = ys.argmax(dim = 1)
                    val_total += ys.size(0)
                    val_correct += (predicted == ts).sum().item()

            val_time_end = time.time()
            train_time_total = train_time_end - train_time_start
            val_time_total = val_time_end - val_time_start
            total_time = train_time_total + val_time_total

            val_losses.append(val_loss)

            val_accuracy = val_correct / val_total / IN_HEIGHT / IN_WIDTH
            val_accuracies.append(val_accuracy)

            print(f'fold:{fold_i + 1} epoch:{epoch + 1}/{EPOCH} [tra]loss: {tr_loss:.4f} acc: {train_accuracy:.4f} [val]loss: {val_loss:.4f} acc:{val_accuracy:.4f} [time]total: {total_time:.2f}sec tra:{train_time_total:.2f}sec val:{val_time_total:.2f}sec')

            if (epoch + 1) % 50 == 0:
                # xs / ys / ts still hold the last validation batch here.
                show_sample(xs)
                show_sample_seg(ys)
                show_sample_label(ts)

            if (epoch + 1) % 100 == 0:
                savename = f'model_epoch{epoch + 1}_{EPOCH}_{FOLD}_{fold_i + 1}_{FOLD_N}.pth'
                torch.save(model.state_dict(), savename)
                print(f'model saved to >> {savename}')

        #---
        # save model
        #---

        # 1-based fold number, matching the periodic checkpoints and the logs
        # (this name previously used the 0-based index).
        savename = f'model_epoch{EPOCH}_{FOLD}_{fold_i + 1}_{FOLD_N}.pth'
        torch.save(model.state_dict(), savename)
        print(f'model saved to >> {savename}')

        train_models.append(model)
        train_model_paths.append(savename)

        # Learning curves: losses on the left axis, accuracies on the right.
        fig, ax1 = plt.subplots()
        ax2 = ax1.twinx()
        ax1.grid()
        ax1.plot(train_losses, marker = '.', markersize = 6, color = 'red', label = 'train loss')
        ax1.plot(val_losses, marker = '.', markersize = 6, color = 'blue', label = 'val losses')
        ax2.plot(train_accuracies, marker = '.', markersize = 6, color = 'green', label = 'train accuracy')
        # Validation accuracy was collected but never drawn before.
        ax2.plot(val_accuracies, marker = '.', markersize = 6, color = 'orange', label = 'val accuracy')
        h1, l1 = ax1.get_legend_handles_labels()
        h2, l2 = ax2.get_legend_handles_labels()
        ax1.legend(h1 + h2, l1 + l2, loc = 'upper right')
        ax1.set(xlabel = 'Epoch', ylabel = 'Loss')
        ax2.set(ylabel = 'Accuracy')

        # Only the first fold is trained.
        break

    return train_models, train_model_paths


In [None]:
# Run training; train() returns the trained models and the file names of
# their saved weights.
train_models, train_model_paths = train()

Using downloaded and verified file: ./VOCtrainval_11-May-2012.tar
KFOLD fold:1/2
train_N=732, val_N=732
fold:1 epoch:1/200 [tra]loss: 2.9165 acc: 0.3433 [val]loss: 2.8415 acc:0.4367 [time]total: 1645.97sec tra:1296.47sec val:349.50sec
fold:1 epoch:2/200 [tra]loss: 2.5667 acc: 0.5608 [val]loss: 2.6816 acc:0.6049 [time]total: 1671.01sec tra:1284.71sec val:386.30sec
fold:1 epoch:3/200 [tra]loss: 2.4096 acc: 0.6569 [val]loss: 2.4354 acc:0.6803 [time]total: 1973.96sec tra:1641.68sec val:332.28sec
fold:1 epoch:4/200 [tra]loss: 2.2817 acc: 0.6759 [val]loss: 2.1872 acc:0.6199 [time]total: 1490.37sec tra:1161.82sec val:328.56sec
fold:1 epoch:5/200 [tra]loss: 2.1460 acc: 0.6184 [val]loss: 2.0165 acc:0.6878 [time]total: 1504.64sec tra:1177.33sec val:327.31sec
fold:1 epoch:6/200 [tra]loss: 2.0322 acc: 0.6303 [val]loss: 2.0368 acc:0.6262 [time]total: 1497.87sec tra:1170.84sec val:327.03sec
fold:1 epoch:7/200 [tra]loss: 1.9252 acc: 0.6813 [val]loss: 1.7931 acc:0.7142 [time]total: 1507.65sec tra:1179