In [None]:
import os
import shutil
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
#from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [None]:
bin_mode = True   # True: ConvBlock activations use the sign binarizer; False: ReLU
epochs = 8

loops = 4         # refinement iterations of MipmapNetwork
depth = 3         # pyramid depth of MipmapNetwork
ch = 64           # feature channels per pyramid level

# Number of tiles arranged per synthetic image
rows = 2
cols = 2
img_h = 32        # tile height after resizing
img_w = 32        # tile width after resizing

data_path = './data/mnist_iir_segmentation'
mini_batch_size = 32

# Seed every RNG in use (Python `random` drives the random input inversion,
# torch drives weight init and DataLoader shuffling) so Restart & Run All
# gives reproducible results.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

## 学習データ作成

In [None]:
# dataset: MNIST digits resized to one tile and converted to tensors.
# transforms.Resize takes (height, width), so pass (img_h, img_w).
transform = transforms.Compose([
                transforms.Resize((img_h, img_w)),
                transforms.ToTensor(),
            ])

dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transform, download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transform, download=True)


def make_teacher_image(gen, rows, cols, margin=0):
    """Tile ``rows x cols`` samples from ``gen`` into one input image and its target.

    Args:
        gen: iterator yielding ``(img, label)`` pairs; each ``img`` must be
            assignable into an ``(img_h, img_w)`` slice (module-level tile size).
        rows: number of tile rows.
        cols: number of tile columns.
        margin: if > 0, crop this many pixels from each border of the target only.

    Returns:
        ``(source_img, teaching_img)``:
        - ``source_img`` is a ``(1, rows*img_h, cols*img_w)`` float32 input,
          inverted (``1 - x``) when ``random.random() > 0.5``.
        - ``teaching_img`` is an 11-channel float32 target. Channel ``label`` is 1
          where the tile pixel is > 0.5, and channel 10 is 1 where ``1 - pixel``
          is > 0.5. Every channel is scaled by the module-level ``weight[channel]``.
    """
    full_h = rows * img_h
    full_w = cols * img_w
    source_img = np.zeros((1, full_h, full_w), dtype=np.float32)
    teaching_img = np.zeros((11, full_h, full_w), dtype=np.float32)

    # Fill tiles in row-major order, one sample per tile.
    for r in range(rows):
        ys = slice(r * img_h, (r + 1) * img_h)
        for c in range(cols):
            xs = slice(c * img_w, (c + 1) * img_w)
            tile, label = next(gen)
            source_img[0, ys, xs] = tile
            teaching_img[label, ys, xs] = tile
            teaching_img[10, ys, xs] = 1.0 - tile
    teaching_img = (teaching_img > 0.5).astype(np.float32)

    # Weight each class plane by its area-based weight (in place on the view).
    for cls, plane in enumerate(teaching_img):
        plane *= weight[cls]

    # Randomly invert the input so the network sees both polarities.
    flip = random.random() > 0.5
    if flip:
        source_img = 1.0 - source_img

    if margin > 0:
        teaching_img = teaching_img[:, margin:-margin, margin:-margin]
    return source_img, teaching_img

def transform_data(dataset, n, rows, cols, margin):
    """Build ``n`` tiled (source, teacher) pairs from ``dataset``.

    Samples are taken in index order and wrap around to the start when the
    dataset is exhausted, so consecutive composites use consecutive samples.

    Returns:
        Two lists of length ``n``: the source images and the teacher images
        returned by ``make_teacher_image``.
    """
    def cycle_indexed():
        # Index-based cycling: relies only on len() and [] of the dataset.
        size = len(dataset)
        pos = 0
        while True:
            yield dataset[pos % size]
            pos += 1

    sample_stream = cycle_indexed()
    pairs = [make_teacher_image(sample_stream, rows, cols, margin) for _ in range(n)]
    sources = [src for src, _ in pairs]
    teachers = [tch for _, tch in pairs]
    return sources, teachers

class MyDatasets(torch.utils.data.Dataset):
    """Map-style dataset over pre-built (source, teacher) image lists.

    Args:
        source_imgs: sequence of network inputs.
        teaching_imgs: sequence of targets, index-aligned with ``source_imgs``.
        transforms: optional callable ``(source, teacher) -> (source, teacher)``
            applied on every access; skipped when falsy.
    """
    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs
        self.transforms = transforms

    def __len__(self):
        # The source list defines the dataset length.
        return len(self.source_imgs)

    def __getitem__(self, index):
        src = self.source_imgs[index]
        tch = self.teaching_imgs[index]
        if not self.transforms:
            return src, tch
        src, tch = self.transforms(src, tch)
        return src, tch

# Tile MNIST digits into rows x cols composites for filtering, and cache them.
# The cache file name encodes the tiling configuration, so changing
# rows / cols / img_h / img_w regenerates the data. Otherwise it would silently
# load composites built with a different layout.
# NOTE: pickle.load can execute arbitrary code; only load caches this notebook wrote.
dataset_fname = os.path.join(
    data_path, 'dataset_%dx%d_%dx%d.pickle' % (rows, cols, img_h, img_w))
if os.path.exists(dataset_fname):
    with open(dataset_fname, 'rb') as f:
        source_imgs_train = pickle.load(f)
        teaching_imgs_train = pickle.load(f)
        source_imgs_test = pickle.load(f)
        teaching_imgs_test = pickle.load(f)
        weight = pickle.load(f)
else:
    # Class weights from area ratios. For each training sample:
    # - areas[label] accumulates its mean intensity;
    # - areas[10] accumulates its mean background (1 - intensity).
    # Both are then divided by the number of training samples.
    areas = np.zeros((11))
    for img, label in dataset_train:
        img = img.numpy()
        areas[label] += np.mean(img)
        areas[10] += np.mean(1.0 - img)
    areas /= len(dataset_train)

    # Inverse-area weights, normalized so the largest weight is 1.
    # Must exist before transform_data: make_teacher_image reads `weight`.
    weight = 1 / areas
    weight /= np.max(weight)

    # margin=0: teacher images keep the full composite size.
    source_imgs_train, teaching_imgs_train = transform_data(dataset_train, 4096, rows, cols, 0)
    source_imgs_test, teaching_imgs_test = transform_data(dataset_test, 128, rows, cols, 0)

    os.makedirs(data_path, exist_ok=True)
    with open(dataset_fname, 'wb') as f:
        pickle.dump(source_imgs_train, f)
        pickle.dump(teaching_imgs_train, f)
        pickle.dump(source_imgs_test, f)
        pickle.dump(teaching_imgs_test, f)
        pickle.dump(weight, f)

my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

loader_train = torch.utils.data.DataLoader(dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

In [None]:
# Visual check of the generated data. For a few samples, show the source
# composite, then the 10 class planes and the background plane of the teacher.
x, t = next(iter(loader_test))

n = min(x.shape[0], 4)
plt.figure(figsize=(18, 2 * n))
for i in range(n):
    plt.subplot(n, 12, i * 12 + 1)
    plt.title('source')
    plt.imshow(x[i][0], 'gray')
    for j in range(11):
        plt.subplot(n, 12, i * 12 + 2 + j)
        # Channel j < 10 is digit class j; channel 10 is the background.
        plt.title('class=%d' % j if j < 10 else 'background')
        plt.imshow(t[i][j], 'gray')
plt.tight_layout()
plt.show()

## ネットワーク定義

In [None]:
from torch.autograd import Function


class Binarize(Function):
    """Sign binarizer with a clipped straight-through estimator (STE).

    forward:  y = +1 where x >= 0, otherwise -1 (NaN maps to +1).
    backward: dx = dy where |x| < 1, and 0 where |x| >= 1.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Start from +1 everywhere and overwrite negatives, so every element is
        # defined. An uninitialized buffer would leave NaN positions as garbage.
        return torch.ones_like(x).masked_fill_(x < 0, -1.0)

    @staticmethod
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        # Pass the gradient straight through, but cut it off outside (-1, 1).
        return dy.masked_fill(x.abs() >= 1, 0.0)


binarize = Binarize.apply

In [None]:
# Alternative binarizer mapping to {0, 1} instead of {-1, +1} (unused; kept for reference).
#from torch.autograd import Function
#class Binarize(Function):
#    @staticmethod
#    def forward(ctx, x):
#        ctx.save_for_backward(x)
#        y = x.new(x.size())
#        y[x >  0] = 1.0
#        y[x <= 0] = 0
#        return y
#    
#    @staticmethod
#    def backward(ctx, dy):
#        x, = ctx.saved_tensors
#        dx = dy.clone()
#        dx[x.ge(1)]=0
#        dx[x.le(-1)]=0
#        return dx
#
#binarize = Binarize.apply

In [None]:
class ConvBlock(nn.Module):
    """Basic block: two (3x3 conv -> BatchNorm -> activation) stages.

    The activation is `binarize` when the module-level ``bin_mode`` is true,
    otherwise ReLU. With ``last=True`` and ``bin_mode`` false, the second stage
    is left linear (raw BatchNorm output). In ``bin_mode`` the second stage is
    binarized regardless of ``last``.

    Args:
        in_ch: input channels.
        out_ch: output channels of both convolutions.
        last: skip the final ReLU (non-binary mode only).
    """
    def __init__(self, in_ch=32, out_ch=32, last=False):
        super(ConvBlock, self).__init__()
        self.last  = last
        self.cnv0  = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')
        self.bn0   = nn.BatchNorm2d(out_ch)
        self.cnv1  = nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate')
        self.bn1   = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.bn0(self.cnv0(x))
        x = binarize(x) if bin_mode else F.relu(x)

        # The second stage must use its own BatchNorm (bn1). Sharing bn0 would
        # mix the running statistics of two different layers and leave bn1 unused.
        x = self.bn1(self.cnv1(x))
        if bin_mode:
            x = binarize(x)
        elif not self.last:
            x = F.relu(x)
        return x

class ScaledNetwork(nn.Module):
    """One pyramid level of the mipmap network.

    ``forward(x0, x1, u)`` does the following:
    - upsamples the coarser map ``x1`` by 2, concatenates it after ``x0`` and
      runs ``cnv0``;
    - returns the 2x2 max-pooled result as ``v``, the message for the next
      coarser level.
    Non-top levels then prepend ``u`` before ``cnv1``. The top level feeds the
    ``cnv0`` output directly to a ``cnv1`` that emits 11 channels.

    Args:
        ch: feature channels at this level.
        top: True for the full-resolution level, whose ``x0`` has 1 channel.
    """
    def __init__(self, ch=32, top=False):
        super(ScaledNetwork, self).__init__()

        self.top = top
        self.ch = ch
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.pool = nn.MaxPool2d(2, 2)
        # Top level: (1-channel input + upsampled features) -> ch -> 11 maps.
        # Other levels: (features + upsampled coarser features) -> ch, then
        # (message u + that result) -> ch.
        first_in = ch + 1 if top else ch * 2
        self.cnv0 = ConvBlock(first_in, ch)
        if top:
            self.cnv1 = ConvBlock(ch, 11, last=top)
        else:
            self.cnv1 = ConvBlock(ch * 2, ch)

    def forward(self, x0, x1, u, train=True):
        # `train` is accepted for interface compatibility and is not used.
        merged = torch.cat([x0, self.up(x1)], 1)
        feat = self.cnv0.forward(merged)
        v = self.pool(feat)
        head_in = feat if self.top else torch.cat([u, feat], 1)
        return self.cnv1.forward(head_in), v

class MipmapNetwork(nn.Module):
    """Iterative coarse-to-fine network over a feature pyramid ("mipmap").

    A pyramid of ``depth + 1`` zero-initialized feature maps is refined
    ``loop`` times; each level halves the resolution of the previous one.
    Each iteration:
    - runs the shared top level ``m_net`` on the input and the finest pyramid
      map, producing one output;
    - except after the last iteration, updates pyramid levels ``0..depth-1``
      with the shared ``s_net``.

    Args:
        loop: number of refinement iterations (must be >= 1).
        depth: number of pyramid levels updated by ``s_net``.
        ch: feature channels of every pyramid level.

    ``forward`` returns the outputs of all iterations concatenated along the
    batch axis. Iteration ``i`` occupies rows ``[i*N, (i+1)*N)`` for an input
    batch of size N.
    """
    def __init__(self, loop=4, depth=3, ch=32):
        super(MipmapNetwork, self).__init__()
        self.loop    = loop
        self.depth   = depth
        self.shape   = None
        self.ch      = ch
        self.up      = nn.Upsample(scale_factor=2, mode='nearest')
        self.m_net = ScaledNetwork(ch, top=True)
        self.s_net = ScaledNetwork(ch)

    def make_mipmap(self, n, h, w, device=None, dtype=None):
        """Return ``depth + 1`` zero maps of shape ``(n, ch, h // 2**k, w // 2**k)``, k = 1..depth+1.

        ``device`` defaults to the device of this module's parameters (the torch
        default if it has none) rather than a notebook-global ``device``.
        """
        if device is None:
            first_param = next(self.parameters(), None)
            device = first_param.device if first_param is not None else None
        mipmap = []
        for _ in range(self.depth + 1):
            h //= 2
            w //= 2
            mipmap.append(torch.zeros(n, self.ch, h, w, device=device, dtype=dtype))
        return mipmap

    def forward(self, x):
        if self.loop < 1:
            raise ValueError('MipmapNetwork.loop must be >= 1, got %r' % (self.loop,))
        n, h, w = x.shape[0], x.shape[2], x.shape[3]

        # Allocate the pyramid on the input's device/dtype.
        mipmap = self.make_mipmap(n, h, w, device=x.device, dtype=x.dtype)
        outputs = []
        for i in range(self.loop):
            y, v = self.m_net.forward(x, mipmap[0], None)
            outputs.append(y)
            # A pyramid update after the final iteration would never be read.
            if i < self.loop - 1:
                for j in range(self.depth):
                    mipmap[j], v = self.s_net.forward(mipmap[j], mipmap[j + 1], v)
        # One concatenation instead of re-growing the tensor every iteration.
        return torch.cat(outputs, 0)

## 学習実施

In [None]:
# Use the GPU when available; build the network from the config-cell
# hyperparameters (loops / depth / ch) instead of repeating literals here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

net_torch = MipmapNetwork(loop=loops, depth=depth, ch=ch).to(device)

In [None]:
# Weight the loss per class (including background) by the area-based weights.
w = torch.from_numpy(weight.astype(np.float32)).to(device)
criterion = nn.CrossEntropyLoss(weight=w)


# Training
epochs = 16  # NOTE: overrides `epochs` from the config cell
optimizer = optim.Adam(net_torch.parameters(), lr=0.001)
#optimizer = optim.SGD(net_torch.parameters(), lr=0.001)
net_torch.train()
for epoch in range(epochs):
    print('epoch:%d'%epoch)
    with tqdm(loader_train) as progress:
        for x, t in progress:
            x = x.to(device)
            t = t.to(device)

            optimizer.zero_grad()
            yy = net_torch(x)

            # The network stacks one prediction per refinement loop along the
            # batch axis. Repeat the per-pixel class targets once per loop to match.
            target = torch.argmax(t, dim=1).repeat(net_torch.loop, 1, 1)
            loss = criterion(yy, target)
            loss.backward()

            optimizer.step()
            progress.set_postfix(loss=loss.item())

## 結果表示

In [None]:
# Inference on a single test batch. Use eval mode so BatchNorm uses (and does
# not update) its running statistics; restore the previous mode afterwards.
was_training = net_torch.training
net_torch.eval()
with torch.no_grad():
    x, t = next(iter(loader_test))
    x = x.to(device)
    y = net_torch(x)
#   y = F.softmax(y, dim=1)
net_torch.train(was_training)

# Display
x = x.to('cpu').detach().numpy()
y = y.to('cpu').detach().numpy()
print(x.shape)
print(y.shape)

batch = x.shape[0]
# The output stacks `loop` predictions along the batch axis; show the last loop's.
last_offset = (net_torch.loop - 1) * batch
n = min(8, batch)
plt.figure(figsize=(16, 2*n))
for i in range(n):
    plt.subplot(n,12,i*12+1)
    plt.imshow(x[i][0], 'gray')
    for j in range(11):
        plt.subplot(n,12,i*12+j+2)
        plt.title('class=%d'%j)
        plt.imshow(y[last_offset + i][j], 'gray')
plt.tight_layout()
plt.show()