In [None]:
import os
import shutil
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
#from tqdm import tqdm
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Function

import binarybrain as bb

In [None]:
# Run PyTorch on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Report the installed BinaryBrain version (handy when reproducing results)
bb.get_version_string()

In [None]:
# configuration
bb.set_device(0)

verbose = 0   # >= 1 prints per-layer gradient ranges during backward

net_name               = 'cmp_torch_mnist_iir_segmentation'
data_path              = os.path.join('./data/', net_name)

bin_mode               = False
#frame_modulation_size  = 1
#depth_integration_size = 1
epochs                 = 16
mini_batch_size        = 16

loops = 3   # number of recurrent passes over the mipmap pyramid
depth = 4   # number of coarser pyramid levels handled by the shared s_net

# Seed the Python, NumPy and PyTorch RNGs (random tile inversion, weight
# initialisation, DataLoader shuffling) so Restart & Run All is reproducible.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Number of tiles arranged per mosaic (rows x cols MNIST digits per sample)
rows=2
cols=2

# Size every MNIST digit is resized to before tiling
img_h = 32
img_w = 32

# dataset
transform = transforms.Compose([
                transforms.Resize((img_h, img_w)),  # torchvision expects (height, width)
                transforms.ToTensor(),
            ])

dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transform, download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transform, download=True)


def make_teacher_image(gen, rows, cols, margin=0, class_weight=None, tile_h=None, tile_w=None):
    """Build one rows x cols mosaic of digits together with its per-class teacher planes.

    gen          : iterator yielding (img, label); img must broadcast into a
                   (tile_h, tile_w) slice (e.g. a (1, H, W) tensor) with values in [0, 1]
    rows, cols   : number of tiles vertically / horizontally
    margin       : if > 0, crop this many pixels from every edge of the teacher image
    class_weight : 11 per-channel weights; defaults to the notebook-level `weight`
    tile_h       : tile height; defaults to the notebook-level `img_h`
    tile_w       : tile width; defaults to the notebook-level `img_w`

    Returns (source_img, teaching_img):
      source_img   float32 (1, rows*tile_h, cols*tile_w); replaced by 1 - source_img when
                   random.random() > 0.5
      teaching_img float32 (11, H', W'); channel `label` marks digit pixels (> 0.5),
                   channel 10 marks background pixels, each channel scaled by its weight
    """
    if class_weight is None:
        class_weight = weight  # notebook-level weights; must be defined before calling
    if tile_h is None:
        tile_h = img_h
    if tile_w is None:
        tile_w = img_w

    num_classes = 10  # digits 0-9; channel 10 holds the background
    height, width = rows*tile_h, cols*tile_w
    source_img   = np.zeros((1, height, width), dtype=np.float32)
    teaching_img = np.zeros((num_classes+1, height, width), dtype=np.float32)
    for row in range(rows):
        for col in range(cols):
            x = col*tile_w
            y = row*tile_h
            img, label = next(gen)
            source_img[0, y:y+tile_h, x:x+tile_w] = img
            teaching_img[label, y:y+tile_h, x:x+tile_w] = img
            teaching_img[num_classes, y:y+tile_h, x:x+tile_w] = 1.0 - img

    # Binarise, then scale every channel by its class weight in one broadcast multiply
    teaching_img = (teaching_img > 0.5).astype(np.float32)
    teaching_img *= np.asarray(class_weight, dtype=np.float32)[:num_classes+1, None, None]

    # Randomly invert the input so the network sees both polarities
    if random.random() > 0.5:
        source_img = 1.0 - source_img

    if margin > 0:
        return source_img, teaching_img[:, margin:-margin, margin:-margin]
    return source_img, teaching_img

def transform_data(dataset, n, rows, cols, margin):
    """Create `n` tiled (source, teacher) pairs by cycling through `dataset` in order.

    Returns two parallel lists: source images and teacher images, as produced by
    make_teacher_image(…, rows, cols, margin).
    """
    def endless_samples():
        # Walk the dataset forever, wrapping around at the end
        total = len(dataset)
        idx = 0
        while True:
            yield dataset[idx % total]
            idx += 1

    samples = endless_samples()
    pairs = [make_teacher_image(samples, rows, cols, margin) for _ in range(n)]
    source_imgs = [src for src, _ in pairs]
    teaching_imgs = [tch for _, tch in pairs]
    return source_imgs, teaching_imgs

class MyDatasets(torch.utils.data.Dataset):
    """Map-style dataset over pre-built, index-aligned source / teacher image lists.

    transforms, if given, is called as transforms(source, teacher) and must return a
    (source, teacher) pair.
    """
    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.transforms = transforms
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs

    def __len__(self):
        return len(self.source_imgs)

    def __getitem__(self, index):
        pair = (self.source_imgs[index], self.teaching_imgs[index])
        if not self.transforms:
            return pair
        src, tch = self.transforms(*pair)
        return src, tch

# Tile MNIST digits into rows x cols mosaics for the segmentation task, caching the result.
# NOTE: pickle.load can execute arbitrary code -- only load a cache written by this notebook.
dataset_fname = os.path.join(data_path, 'dataset.pickle')
expected_source_shape = (1, rows*img_h, cols*img_w)

cached = None
if os.path.exists(dataset_fname):
    with open(dataset_fname, 'rb') as f:
        # train sources, train teachers, test sources, test teachers, class weights
        cached = [pickle.load(f) for _ in range(5)]
    # The file name does not encode rows/cols/img size, so reject a cache built with a
    # different tiling instead of silently training on mismatched images.
    if len(cached[0]) == 0 or np.shape(cached[0][0]) != expected_source_shape:
        cached = None

if cached is not None:
    source_imgs_train, teaching_imgs_train, source_imgs_test, teaching_imgs_test, weight = cached
else:
    # Class weights: inverse of each class's mean pixel intensity (a soft area) over the
    # training set, channel 10 being the background, normalised so the largest weight is 1.
    areas = np.zeros((11))
    for img, label in dataset_train:
        img = img.numpy()
        areas[label] += np.mean(img)
        areas[10] += np.mean(1.0-img)
    areas /= len(dataset_train)

    weight = 1 / areas
    weight /= np.max(weight)

    # `weight` must already exist here: make_teacher_image reads it
    source_imgs_train, teaching_imgs_train = transform_data(dataset_train, 4096, rows, cols, 0)
    source_imgs_test, teaching_imgs_test = transform_data(dataset_test, 128, rows, cols, 0)

    os.makedirs(data_path, exist_ok=True)
    with open(dataset_fname, 'wb') as f:
        for obj in (source_imgs_train, teaching_imgs_train, source_imgs_test, teaching_imgs_test, weight):
            pickle.dump(obj, f)

my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

loader_train = torch.utils.data.DataLoader(dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

In [None]:
# Visual check of training data / network outputs
def plt_data(x, y):
    """Plot a source mosaic next to each of its 11 class planes in one row.

    x : source image indexable as x[0] -> (H, W); only channel 0 is drawn
    y : 11 planes indexable as y[i] -> (H, W); 0-9 are digit classes, 10 is background
    """
    plt.figure(figsize=(16,8))
    plt.subplot(1,12,1)
    plt.title('source')
    plt.imshow(x[0], 'gray')
    for i in range(11):
        plt.subplot(1,12,2+i)
        plt.title('class=%d'%i if i < 10 else 'background')
        plt.imshow(y[i], 'gray')
    plt.tight_layout()
    _ = plt.show()

# Set the condition to True to preview a few test samples (source + 11 teacher planes)
if False:
    for source_imgs, teaching_imgs in loader_test:
        print(source_imgs[0].shape)
        print(teaching_imgs[0].shape)
        # plt_data opens its own figure; a short final batch may hold fewer than 4 samples
        for i in range(min(len(source_imgs), 4)):
            plt_data(source_imgs[i], teaching_imgs[i])
        break

In [None]:
def view(net, loader, num=2):
    """Show the network's per-class output for the first frame of the first `num` batches.

    net    : BinaryBrain model exposing forward(FrameBuffer, train=False)
    loader : iterable of (source, teacher) batches; only the sources are fed to `net`
    num    : number of batches to display (the first batch is always shown)
    """
    for n, (x_imgs, _t_imgs) in enumerate(loader, start=1):
        # plt_data creates its own figure, so no plt.figure() here (it would stay empty)
        x_buf = bb.FrameBuffer.from_numpy(np.array(x_imgs).astype(np.float32))
        y_buf = net.forward(x_buf, train=False)
        plt_data(x_imgs[0], y_buf.numpy()[0])
        if n >= num:
            break

## BinaryBrain版ネット定義

In [None]:
def create_dense_conv(output_ch, filter_size=(3, 3), padding='same'):
    """Build a BinaryBrain Convolution2d whose per-window kernel is one DenseAffine.

    The kernel produces `output_ch` channels. No binarisation or BatchNormalization layer
    is added here.
    """
    kernel = bb.Sequential([
        bb.DenseAffine([output_ch, 1, 1]),
    ])
    return bb.Convolution2d(kernel, filter_size=filter_size, padding=padding)

class ConvBlock_bb(bb.Sequential):
    """Basic block: dense 3x3 conv -> ReLU -> dense 3x3 conv [-> ReLU] (BinaryBrain).

    in_ch    : accepted for signature parity with ConvBlock_torch; not used here
    out_ch   : channels produced by both convolutions
    last     : if True, omit the trailing ReLU (used for the output head)
    log_name : tag for the gradient-range print emitted by backward when verbose >= 1
    """
    def __init__(self, in_ch=32, out_ch=32, last=False, log_name=""):
        self.log_name = log_name
        self.cnv0  = create_dense_conv(out_ch)
        self.relu0 = bb.ReLU()
        self.cnv1  = create_dense_conv(out_ch)
        self.relu1 = None if last else bb.ReLU()
        layers = [self.cnv0, self.relu0, self.cnv1]
        if self.relu1 is not None:
            layers.append(self.relu1)
        super(ConvBlock_bb, self).__init__(layers)

    def backward(self, dy):
        """Propagate dy back through the layers in reverse order; returns the input gradient."""
        grad = dy if self.relu1 is None else self.relu1.backward(dy)
        grad = self.relu0.backward(self.cnv1.backward(grad))
        if verbose >= 1:
            grad_np = grad.numpy()
            print('[bb %s] min:%f max:%f'%(self.log_name, np.min(grad_np), np.max(grad_np)))
        return self.cnv0.backward(grad)

class ScaledNetwork_bb(bb.Sequential):
    """One pyramid level of the mipmap network, wired by hand from BinaryBrain layers.

    forward(x0, x1, u) -> (y, v):
      * x1 (the next coarser map) is upsampled 2x, concatenated after x0, and run through cnv0.
      * v is the 2x2 max-pool of the cnv0 output; the caller hands it to the next coarser
        level as that level's `u`.
      * For non-top levels, u is concatenated in front of the cnv0 output before cnv1;
        the top level ignores u.
    top=True builds the full-resolution head, whose cnv1 emits 11 channels with no final ReLU.
    The in_ch values passed to ConvBlock_bb mirror the PyTorch model; ConvBlock_bb ignores them.
    """
    def __init__(self, ch=32, top=False):
        self.top = top
        self.ch  = ch
        
        self.up   = bb.UpSampling((2, 2))
        self.pool = bb.MaxPooling((2, 2))
        # cat0/cat1 are not handed to bb.Sequential below; forward/backward drive them manually
        self.cat0 = bb.Concatenate()
        self.cat1 = bb.Concatenate()
        if self.top:
            self.cnv0 = ConvBlock_bb(self.ch+1, self.ch, log_name="m_cnv0")
            self.cnv1 = ConvBlock_bb(self.ch,   11, last=self.top, log_name="m_cnv1")
        else:
            self.cnv0 = ConvBlock_bb(self.ch*2, self.ch, log_name="s_cnv0")
            self.cnv1 = ConvBlock_bb(self.ch*2, self.ch, log_name="s_cnv1")
        
        # Register the sub-layers with bb.Sequential (presumably so parameters, gradients
        # and clear() reach them -- TODO confirm against the BinaryBrain API)
        super(ScaledNetwork_bb, self).__init__([self.up, self.pool, self.cnv0, self.cnv1])
        
    def set_input_shape(self, x0_shape, x1_shape, u_shape):
        """Propagate node shapes through every sub-layer; returns (y_shape, v_shape)."""
        x1_shape = self.up.set_input_shape(x1_shape)
        x_shape = self.cat0.set_input_shape([x0_shape, x1_shape])
        x_shape = self.cnv0.set_input_shape(x_shape)
        v_shape = self.pool.set_input_shape(x_shape)
        if not self.top:
            x_shape = self.cat1.set_input_shape([u_shape, x_shape])
        y_shape = self.cnv1.set_input_shape(x_shape)
        return y_shape, v_shape
    
    def forward(self, x0, x1, u, train=True):
        """Compute (y, v) as described in the class docstring; u is unused when top=True."""
        x1 = self.up.forward(x1, train=train)
        x = self.cat0.forward([x0, x1])
        x = self.cnv0.forward(x, train=train)        
        v = self.pool.forward(x, train=train)
        if not self.top:
            x = self.cat1.forward([u, x])
        y = self.cnv1.forward(x, train=train)
        return y, v
    
    def backward(self, dy, dv):
        """Back-propagate dy (grad of y) and dv (grad of v) in exact reverse order of forward.

        Returns (dx0, dx1, du); du is None for the top level.
        NOTE(review): assumes each bb layer's backward pairs with its most recent
        un-consumed forward call -- TODO confirm in BinaryBrain.
        """
        dy = self.cnv1.backward(dy)
        if not self.top:
            du, dy0 = self.cat1.backward([dy])
        else:
            du, dy0 = None, dy
        dy1 = self.pool.backward(dv)
        # The cnv0 output feeds both pool (-> v) and cnv1, so the two gradients add
        dy = self.cnv0.backward(dy0 + dy1)
        dx0, dx1 = self.cat0.backward([dy])
        dx1 = self.up.backward(dx1)
        return dx0, dx1, du


class MipmapNetwork_bb(bb.Sequential):
    """Recurrent multi-scale ("mipmap") segmentation network built from BinaryBrain layers.

    A full-resolution head (m_net) and one shared lower-level block (s_net) are applied
    `loop` times. Each pass refines a pyramid of depth+1 feature maps ("mipmap") that
    starts at half the input resolution and halves again at every level.
    """
    def __init__(self, loop=3, depth=4, ch=32):
        self.loop    = loop
        self.depth   = depth
        self.ch      = ch
        self.shape   = None
        self.up      = bb.UpSampling((2, 2))  # NOTE(review): not used by forward/backward
        self.m_net = ScaledNetwork_bb(ch, top=True)
        self.s_net = ScaledNetwork_bb(ch)
        super(MipmapNetwork_bb, self).__init__([self.m_net, self.s_net])
    
    def set_input_shape(self, shape):
        """Propagate the input node shape [c, h, w] through m_net and s_net; returns y's shape."""
        self.shape = copy.copy(shape)
        
        x0_shape = copy.copy(shape)
        x1_shape = copy.copy(shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        u_shape = copy.copy(shape)  # ignored by the top-level m_net
        y_shape, v_shape = self.m_net.set_input_shape(x0_shape, x1_shape, u_shape)
        
        x0_shape = copy.copy(x1_shape)
        x0_shape[0] = self.ch
        x1_shape = copy.copy(x1_shape)
        x1_shape[0] = self.ch
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        self.s_net.set_input_shape(x0_shape, x1_shape, v_shape)
        
        return y_shape
    
    def make_mipmap(self, frame_size, dtype=bb.DType.FP32):
        """Return depth+1 zero-filled buffers; level k has node shape (ch, h/2**(k+1), w/2**(k+1))."""
        h = self.shape[1]
        w = self.shape[2]
        mipmap = []
        for i in range(self.depth+1):
            h //= 2
            w //= 2
            buf = bb.FrameBuffer(frame_size, (self.ch, h, w), dtype=dtype)
            buf.fill_zero()
            mipmap.append(buf)
        return mipmap
    
    def forward(self, x, train=True):
        """Run `loop` refinement passes and return the m_net output of the last pass."""
        self.m_net.clear()
        self.s_net.clear()
        
        frame_size = x.get_frame_size()
        self.shape = x.get_node_shape()
        
        mipmap = self.make_mipmap(frame_size)
        for _ in range(self.loop):
            y, v = self.m_net.forward(x, mipmap[0], None, train=train)
            for i in range(self.depth):
                # Pass `train` through as well, so inference does not run s_net in training mode
                mipmap[i], v = self.s_net.forward(mipmap[i], mipmap[i+1], v, train=train)
        
        self.v_shape = v.get_node_shape()
        return y
    
    def backward(self, dy):
        """Back-propagate dy through every pass in reverse order; returns the gradient w.r.t. x.

        mipmap[i] holds the gradient w.r.t. pyramid level i as it was at the start of the
        pass being unwound.
        """
        frame_size = dy.get_frame_size()
        # The v produced by the deepest s_net call is never consumed, so its gradient is zero
        dv = bb.FrameBuffer(frame_size, self.v_shape, dtype=bb.DType.FP32)
        dv.fill_zero()
        
        mipmap = self.make_mipmap(frame_size, dtype=bb.DType.FP32)
        dx = None
        for _ in range(self.loop):
            du = dv
            for i in reversed(range(self.depth)):
                dx0, dx1, du = self.s_net.backward(mipmap[i], du)
                mipmap[i]    = dx0
                mipmap[i+1] += dx1
            dx0, dx1, du = self.m_net.backward(dy, du)
            # m_net and s_net level 0 both read the same mipmap[0], so their gradients add
            mipmap[0] = mipmap[0] + dx1
            # x is fed to m_net on every pass, so its gradient accumulates across passes
            dx = dx0 if dx is None else dx + dx0
            # Only the last pass's y is the network output; earlier y's get zero gradient
            dy = bb.FrameBuffer(frame_size, dy.get_node_shape(), dtype=bb.DType.FP32)
            dy.fill_zero()
        return dx

def check_param_affine_bb(affine_bb):
    """Assert that the layer's W, b, dW and db (checked in that order) contain no NaN."""
    for getter in (affine_bb.W, affine_bb.b, affine_bb.dW, affine_bb.db):
        assert not np.isnan(getter().numpy()).any()

def check_param_bb(net_bb):
    """Run check_param_affine_bb on the eight affine layers of m_net and s_net.

    Order: m_net before s_net, cnv0 before cnv1, and inside each ConvBlock_bb its first
    (index 0) then second (index 2) convolution.
    """
    for sub in (net_bb.m_net, net_bb.s_net):
        for block in (sub.cnv0, sub.cnv1):
            for idx in (0, 2):
                # [1][0] presumably reaches the DenseAffine inside the Convolution2d -- TODO confirm
                check_param_affine_bb(block[idx][1][0])
    
def clamp_param_affine_bb(affine_bb, a, b, rate=1.0):
    """Clamp an affine layer's weights and bias to [a, b], then scale them by `rate`.

    affine_bb : layer exposing W() / b() tensors with clamp_inplace and in-place *=
    a, b      : lower / upper clamp bounds
    rate      : multiplicative decay applied after clamping (1.0 = unchanged)
    """
    # `b` is the clamp upper bound; keep the tensors under separate names so the
    # bound is never shadowed by the bias tensor.
    weight_t = affine_bb.W()
    weight_t.clamp_inplace(a, b)
    weight_t *= rate

    bias_t = affine_bb.b()
    bias_t.clamp_inplace(a, b)
    bias_t *= rate

def clamp_param_bb(net_bb, a, b, rate=1.0):
    """Apply clamp_param_affine_bb(a, b, rate) to all eight affine layers of the network.

    Visits m_net then s_net, cnv0 then cnv1, and each ConvBlock_bb's index-0 then
    index-2 convolution.
    """
    for sub in (net_bb.m_net, net_bb.s_net):
        for block in (sub.cnv0, sub.cnv1):
            for idx in (0, 2):
                clamp_param_affine_bb(block[idx][1][0], a, b, rate)
    
#net_bb = MipmapNetwork_bb(loop=loops, depth=depth)
#net_bb.set_input_shape([1, img_h*rows, img_w*cols])

In [None]:
def print_param_status(param, name=""):
    """Print NaN-ignoring mean/std/min/max of an array and whether it contains any NaN."""
    stats = (name, np.nanmean(param), np.nanstd(param), np.nanmin(param), np.nanmax(param))
    has_nan = np.isnan(param).any()
    print('[%s] mean:%f std:%f min:%f max:%f nan:' % stats, has_nan)

def print_param_affine_bb(affine_bb, name=""):
    """Print statistics of the layer's W, b, dW and db, each labelled name + suffix."""
    for suffix, getter in (('.W', affine_bb.W), ('.b', affine_bb.b),
                           ('.dW', affine_bb.dW), ('.db', affine_bb.db)):
        print_param_status(getter().numpy(), name + suffix)

def print_param_bb(net_bb):
    """Print parameter/gradient statistics for all eight affine layers of the network."""
    labelled_layers = (
        ('m_cnv00', lambda: net_bb.m_net.cnv0[0][1][0]),
        ('m_cnv01', lambda: net_bb.m_net.cnv0[2][1][0]),
        ('m_cnv10', lambda: net_bb.m_net.cnv1[0][1][0]),
        ('m_cnv11', lambda: net_bb.m_net.cnv1[2][1][0]),
        ('s_cnv00', lambda: net_bb.s_net.cnv0[0][1][0]),
        ('s_cnv01', lambda: net_bb.s_net.cnv0[2][1][0]),
        ('s_cnv10', lambda: net_bb.s_net.cnv1[0][1][0]),
        ('s_cnv11', lambda: net_bb.s_net.cnv1[2][1][0]),
    )
    for label, layer in labelled_layers:
        print_param_affine_bb(layer(), label)

In [None]:
# BinaryBrain mipmap network: 3 refinement passes over a 4-level pyramid
loops, depth = 3, 4
net_bb = MipmapNetwork_bb(loop=loops, depth=depth)
net_bb.set_input_shape([1, rows*img_h, cols*img_w])

In [None]:
# Train the BinaryBrain network on its own, previewing two test predictions per epoch
epochs = 32

criterion_bb = bb.LossSoftmaxCrossEntropy()
metrics_bb   = bb.MetricsCategoricalAccuracy()
optimizer_bb = bb.OptimizerAdam(learning_rate=0.001)
optimizer_bb.set_variables(net_bb.get_parameters(), net_bb.get_gradients())

num = 0
for epoch in range(epochs):
    # Fresh running loss / accuracy and cleared layer state for every epoch
    criterion_bb.clear()
    metrics_bb.clear()
    net_bb.clear()
    with tqdm(loader_train) as progress:
        for x_torch, t_torch in progress:
            x_bb = bb.FrameBuffer.from_numpy(np.array(x_torch, dtype=np.float32))
            t_bb = bb.FrameBuffer.from_numpy(np.array(t_torch, dtype=np.float32))

            y_bb = net_bb.forward(x_bb, train=True)
            dy_bb = criterion_bb.calculate(y_bb, t_bb)
            metrics_bb.calculate(y_bb, t_bb)

            net_bb.backward(dy_bb)
            optimizer_bb.update()

            num += 1
            progress.set_postfix(loss=criterion_bb.get(), acc=metrics_bb.get())

    # Flip to True to dump parameter statistics and NaN checks after each epoch
    if False:
        print_param_status(y_bb.numpy(), name="y_bb")
        print_param_bb(net_bb)
        check_param_bb(net_bb)
    view(net_bb, loader_test, num=2)

In [None]:
# Inspect BinaryBrain predictions on the first 8 test batches
view(net=net_bb, loader=loader_test, num=8)

In [None]:
# Persist the trained BinaryBrain network under data_path
bb.save_networks(data_path, net_bb)

## PyTorch版ネット定義

In [None]:
class Through(Function):
    """Identity autograd op used as a debugging probe.

    Forward returns a copy of its input. Backward returns a copy of the incoming gradient
    and, when verbose >= 1, prints its min/max (the PyTorch analogue of ConvBlock_bb's
    gradient print).
    """
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, dy):
        dx = dy.clone()
        if verbose >= 1:
            grad_np = dx.to('cpu').detach().numpy()
            print('[torch] min:%f max:%f'%(np.min(grad_np), np.max(grad_np)))
        return dx

through = Through.apply

In [None]:
class ConvBlock_torch(nn.Module):
    """PyTorch twin of ConvBlock_bb: conv3x3 -> ReLU -> conv3x3 [-> ReLU].

    Both convolutions use replicate padding so the spatial size is preserved.
    last=True omits the trailing ReLU (used for the output head).
    """
    def __init__(self, in_ch=32, out_ch=32, last=False):
        super(ConvBlock_torch, self).__init__()
        self.last  = last
        self.cnv0  = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')
        self.relu0 = nn.ReLU(inplace=True)
        self.cnv1  = nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate')
        self.relu1 = None if self.last else nn.ReLU(inplace=True)

    def forward(self, x):
        # `through` is an identity that can print the gradient reaching cnv0's output
        h = through(self.cnv0(x))
        h = self.cnv1(self.relu0(h))
        return h if self.relu1 is None else self.relu1(h)

class ScaledNetwork_torch(nn.Module):
    """PyTorch twin of ScaledNetwork_bb (one pyramid level).

    forward(x0, x1, u) upsamples x1, concatenates it after x0, runs cnv0, max-pools the
    result into v, prepends u (non-top levels only) and runs cnv1; returns (y, v).
    The `train` argument is accepted for API parity and is ignored.
    """
    def __init__(self, ch=32, top=False):
        super(ScaledNetwork_torch, self).__init__()

        self.top  = top
        self.ch   = ch
        self.up   = nn.Upsample(scale_factor=2, mode='nearest')
        self.pool = nn.MaxPool2d(2, 2)
        if top:
            in0, in1, out1, last = ch + 1, ch, 11, True
        else:
            in0, in1, out1, last = ch * 2, ch * 2, ch, False
        self.cnv0 = ConvBlock_torch(in0, ch)
        self.cnv1 = ConvBlock_torch(in1, out1, last=last)

    def forward(self, x0, x1, u, train=True):
        feat = self.cnv0.forward(torch.cat([x0, self.up(x1)], 1))
        v = self.pool(feat)
        if not self.top:
            feat = torch.cat([u, feat], 1)
        return self.cnv1.forward(feat), v

class MipmapNetwork_torch(nn.Module):
    """PyTorch twin of MipmapNetwork_bb; autograd handles the weight sharing across passes.

    A full-resolution head (m_net) and one shared lower-level block (s_net) are applied
    `loop` times over a pyramid of depth+1 feature maps starting at half resolution.
    """
    def __init__(self, loop=3, depth=4, ch=32):
        super(MipmapNetwork_torch, self).__init__()
        self.loop    = loop
        self.depth   = depth
        self.shape   = None
        self.ch      = ch
        self.up      = nn.Upsample(scale_factor=2, mode='nearest')  # NOTE(review): unused
        self.m_net = ScaledNetwork_torch(ch, top=True)
        self.s_net = ScaledNetwork_torch(ch)

    def make_mipmap(self, n, h, w, like=None):
        """Return depth+1 zero tensors shaped (n, ch, h/2**(k+1), w/2**(k+1)) for k = 0..depth.

        like : optional tensor whose device and dtype the buffers should match; when None,
               buffers use the default dtype on the notebook-level `device`.
        """
        mipmap = []
        for _ in range(self.depth+1):
            h //= 2
            w //= 2
            if like is None:
                buf = torch.zeros(n, self.ch, h, w).to(device)
            else:
                buf = like.new_zeros((n, self.ch, h, w))
            mipmap.append(buf)
        return mipmap

    def forward(self, x):
        n, _, h, w = x.shape
        # Allocate the pyramid directly on x's device/dtype so the model also works when
        # it and x do not live on the notebook-level `device`
        mipmap = self.make_mipmap(n, h, w, like=x)
        for _ in range(self.loop):
            y, v = self.m_net.forward(x, mipmap[0], None)
            for i in range(self.depth):
                mipmap[i], v = self.s_net.forward(mipmap[i], mipmap[i+1], v)
        return y
    
#net_torch = MipmapNetwork_torch().to(device)

## 比較

In [None]:
# Build both implementations with identical hyper-parameters for the gradient comparison
loops, depth = 4, 3

net_torch = MipmapNetwork_torch(loop=loops, depth=depth).to(device)

net_bb = MipmapNetwork_bb(loop=loops, depth=depth)
net_bb.set_input_shape([1, rows*img_h, cols*img_w])

def compare_param(param_bb, param_torch, name="", show_diff=False):
    """Print min/max of a BinaryBrain buffer next to the matching PyTorch tensor.

    param_bb    : object with .numpy() (BinaryBrain Tensor / FrameBuffer)
    param_torch : torch tensor on any device; reshaped to param_bb's numpy shape
    name        : label prefixed to every printed line
    show_diff   : also print std/min/max of the element-wise difference (bb - torch)
    """
    pb = param_bb.numpy()
    pt = param_torch.to('cpu').detach().numpy().reshape(pb.shape)
    if show_diff:
        df = pb - pt
        print('[%s] diff:%f min:%f max:%f'%(name, np.std(df), np.min(df), np.max(df)))
    print('[%s] min bb:%f torch:%f'%(name, np.min(pb), np.min(pt)))
    print('[%s] max bb:%f torch:%f'%(name, np.max(pb), np.max(pt)))

def compare_affine_param(affine_bb, affine_torch, name=""):
    """Compare weight and bias gradients of a BinaryBrain affine layer with its PyTorch twin."""
    grad_pairs = (
        ('.dW', affine_bb.dW, affine_torch.weight),
        ('.db', affine_bb.db, affine_torch.bias),
    )
    for suffix, get_grad_bb, torch_param in grad_pairs:
        compare_param(get_grad_bb(), torch_param.grad, name=name + suffix)

def compare_net_param(net_bb, net_torch):
    """Print gradient ranges of every BinaryBrain affine layer next to its PyTorch twin."""
    layer_pairs = (
        ('m_cnv00', lambda: net_bb.m_net.cnv0[0][1][0], lambda: net_torch.m_net.cnv0.cnv0),
        ('m_cnv01', lambda: net_bb.m_net.cnv0[2][1][0], lambda: net_torch.m_net.cnv0.cnv1),
        ('m_cnv10', lambda: net_bb.m_net.cnv1[0][1][0], lambda: net_torch.m_net.cnv1.cnv0),
        ('m_cnv11', lambda: net_bb.m_net.cnv1[2][1][0], lambda: net_torch.m_net.cnv1.cnv1),
        ('s_cnv00', lambda: net_bb.s_net.cnv0[0][1][0], lambda: net_torch.s_net.cnv0.cnv0),
        ('s_cnv01', lambda: net_bb.s_net.cnv0[2][1][0], lambda: net_torch.s_net.cnv0.cnv1),
        ('s_cnv10', lambda: net_bb.s_net.cnv1[0][1][0], lambda: net_torch.s_net.cnv1.cnv0),
        ('s_cnv11', lambda: net_bb.s_net.cnv1[2][1][0], lambda: net_torch.s_net.cnv1.cnv1),
    )
    for label, layer_bb, layer_torch in layer_pairs:
        compare_affine_param(layer_bb(), layer_torch(), label)
    
def copy_affine_param(model_bb, model_torch):
    """Copy a PyTorch layer's weight and bias into a BinaryBrain affine layer.

    Each torch tensor is moved to CPU and reshaped to the BinaryBrain tensor's shape.
    """
    for get_target, source in ((model_bb.W, model_torch.weight), (model_bb.b, model_torch.bias)):
        values = source.detach().cpu().numpy()
        get_target().set_numpy(values.reshape(get_target().get_shape()))

# Start both networks from identical weights: copy every PyTorch conv into its BinaryBrain twin
# (m_net before s_net, cnv0 before cnv1, first conv -> index 0, second conv -> index 2)
for sub_bb, sub_torch in ((net_bb.m_net, net_torch.m_net), (net_bb.s_net, net_torch.s_net)):
    for blk_bb, blk_torch in ((sub_bb.cnv0, sub_torch.cnv0), (sub_bb.cnv1, sub_torch.cnv1)):
        copy_affine_param(blk_bb[0][1][0], blk_torch.cnv0)
        copy_affine_param(blk_bb[2][1][0], blk_torch.cnv1)

# Loss and optimiser for each framework, with matching learning rates
weight_torch = torch.from_numpy(weight.astype(np.float32)).clone().to(device)
criterion_torch = nn.CrossEntropyLoss(weight=weight_torch)  # per-class weights from `weight`
optimizer_torch = optim.Adam(net_torch.parameters(), lr=0.0001)

criterion_bb = bb.LossSoftmaxCrossEntropy()
metrics_bb   = bb.MetricsCategoricalAccuracy()
optimizer_bb = bb.OptimizerAdam(learning_rate=0.0001)
optimizer_bb.set_variables(net_bb.get_parameters(), net_bb.get_gradients())

criterion_bb.clear()
metrics_bb.clear()
net_bb.clear()

# Train both networks side by side on identical mini-batches. A single epoch is enough
# for the comparison; set verbose >= 1 to print per-layer gradient ranges every step.
num = 0
epochs = 1
for epoch in range(epochs):
    with tqdm(loader_train) as tqdm_loadr:
        for x_torch, t_torch in tqdm_loadr:

            x = x_torch.detach().numpy()
            t = t_torch.detach().numpy()
            x_torch = x_torch.to(device)
            t_torch = t_torch.to(device)
            x_bb = bb.FrameBuffer.from_numpy(x)
            t_bb = bb.FrameBuffer.from_numpy(t)

            optimizer_torch.zero_grad()
            y_torch = net_torch(x_torch)
            loss_torch = criterion_torch(y_torch, torch.argmax(t_torch, dim=1))
            loss_torch.backward()

            y_bb = net_bb.forward(x_bb, train=True)
            dy_bb = criterion_bb.calculate(y_bb, t_bb)
            metrics_bb.calculate(y_bb, t_bb)
            net_bb.backward(dy_bb)

            if verbose >= 1:
                print('--------------')
                compare_net_param(net_bb, net_torch)

            optimizer_torch.step()
            optimizer_bb.update()

            num += 1
            tqdm_loadr.set_postfix(loss_bb=criterion_bb.get(), loss_torch=loss_torch.item())
            criterion_bb.clear()

In [None]:
# Steps taken, then the last PyTorch loss and the last BinaryBrain loss
for value in (num, loss_torch.item(), criterion_bb.get()):
    print(value)

In [None]:
# ----------------  (separator; a bare run of dashes is a SyntaxError in a code cell)

In [None]:
 compare_net_param(net_bb=net_bb, net_torch=net_torch)

In [None]:
# Range of the final outputs of both networks
compare_param(param_bb=y_bb, param_torch=y_torch)

In [None]:
# Locate the output element where the two implementations disagree most (by magnitude;
# a signed argmax would miss large negative differences)
diff = y_torch.to('cpu').detach().numpy() - y_bb.numpy().reshape(y_torch.shape)
print(np.unravel_index(np.argmax(np.abs(diff)), diff.shape))

In [None]:
# BinaryBrain output vector (all channels) at one hand-picked (frame, y, x) position
y_bb.numpy()[13][:, 63, 15]

In [None]:
# PyTorch output vector (all channels) at the same (frame, y, x) position
y_torch.detach().cpu().numpy()[13][:, 63, 15]

In [None]:
# --------  (separator; a bare run of dashes is a SyntaxError in a code cell)

In [None]:
# .cpu() is required before .numpy() when the model lives on a CUDA device
net_torch.m_net.cnv0.cnv0.bias.detach().cpu().numpy()

In [None]:
# BinaryBrain counterpart of the PyTorch bias shown above
# (net_torch.m_net.cnv0.cnv0 was copied into net_bb.m_net.cnv0[0][1][0])
net_bb.m_net.cnv0[0][1][0].b().numpy()

## PyTorch 版学習テスト

In [None]:
# Per-class loss weights taken from the notebook-level `weight` array
weight_torch = torch.from_numpy(weight.astype(np.float32)).clone().to(device)
criterion = nn.CrossEntropyLoss(weight=weight_torch)

# Train the PyTorch network on its own
epochs = 16
optimizer = optim.Adam(net_torch.parameters(), lr=0.0001)
for epoch in range(epochs):
    with tqdm(loader_train) as progress:
        for x, t in progress:
            x, t = x.to(device), t.to(device)

            optimizer.zero_grad()
            y = net_torch(x)
            # teacher planes are weighted one-hot maps; argmax recovers the class index
            loss = criterion(y, t.argmax(dim=1))
            loss.backward()
            optimizer.step()
            progress.set_postfix(loss=loss.item())

In [None]:
# Show PyTorch predictions for the first test batch: source image, then the 11 output planes
with torch.no_grad():
    for x, t in loader_test:
        x = x.to(device)
        y = net_torch(x)
        break

x = x.to('cpu').detach().numpy()
y = y.to('cpu').detach().numpy()
for i in range(min(8, len(x))):  # the batch may hold fewer than 8 samples
    plt.figure(figsize=(16, 6))
    plt.subplot(1,12,1)
    plt.title('source')
    plt.imshow(x[i][0], 'gray')
    for j in range(11):
        plt.subplot(1,12,j+2)
        plt.title('class=%d'%j if j < 10 else 'background')
        plt.imshow(y[i][j], 'gray')
    plt.tight_layout()
    plt.show()