# IIR型MNISTセマンティックセグメンテーション

PyTorch との比較

In [None]:
import os
import shutil
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
#from tqdm import tqdm
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Function

import binarybrain as bb

In [None]:
# Backend switches: the wrapper layers below build/run a BinaryBrain path
# and/or a PyTorch path depending on these flags.
use_bb = True
use_torch = False

bin_mode = True
lut_mode = False
bin_modulation = False
frame_modulation_size = 1

# Force the modulation factor to 1 when binary modulation is disabled.
if not bin_modulation:
    frame_modulation_size = 1

# Layers print tensor summaries when verbose > 2.
verbose = 0

epochs = 16
mini_batch_size = 32

# Network structure
loops = 5
depth = 3
#ch = 32

# Number of MNIST tiles arranged per image, and the size of one tile
rows=2
cols=2
img_h = 32
img_w = 32

# Directory used for the cached dataset and the per-layer log files.
data_path = './data/cmp_torch_mnist_iir_segmentation_v2'
#if bin_mode:
#    data_path += '_bin'

In [None]:
# Query the installed BinaryBrain version (presumably displayed as the notebook cell output).
bb.get_version_string()

In [None]:
# `device` is only defined when the PyTorch path is enabled; prefer CUDA when available.
if use_torch:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

## 学習データ作成

In [None]:
# dataset
# Resize expects its size as (height, width); make_teacher_image() pastes each
# sample into an (img_h, img_w) slot, so the order matters for non-square tiles.
transform = transforms.Compose([
                transforms.Resize((img_h, img_w)),
                transforms.ToTensor(),
            ])

# MNIST is downloaded into dataset_path on first use.
dataset_path = './data/'
dataset_train = torchvision.datasets.MNIST(root=dataset_path, train=True, transform=transform, download=True)
dataset_test  = torchvision.datasets.MNIST(root=dataset_path, train=False, transform=transform, download=True)

def make_teacher_image(gen, rows, cols, margin=0):
    """Tile ``rows`` x ``cols`` samples from ``gen`` into one image plus its target.

    Args:
        gen: iterator yielding ``(img, label)`` pairs; one pair is consumed per tile.
        rows, cols: number of tiles vertically / horizontally.
        margin: when > 0, this many pixels are cropped from every border of
            the teaching image only (the source image is returned uncropped).

    Returns:
        ``(source_img, teaching_img)``: a 1-channel float32 source image and an
        11-channel float32 target. Channels 0-9 hold the digit mask of the
        matching label, channel 10 the inverse (background) mask; every
        channel is thresholded at 0.5 and scaled by the module-level ``weight``.
    """
    canvas_h = rows*img_h
    canvas_w = cols*img_w
    source_img = np.zeros((1, canvas_h, canvas_w), dtype=np.float32)
    teaching_img = np.zeros((11, canvas_h, canvas_w), dtype=np.float32)
    for tile_row in range(rows):
        ys = slice(tile_row*img_h, tile_row*img_h + img_h)
        for tile_col in range(cols):
            xs = slice(tile_col*img_w, tile_col*img_w + img_w)
            img, label = next(gen)
            source_img[0, ys, xs] = img
            teaching_img[label, ys, xs] = img
            teaching_img[10, ys, xs] = (1.0-img)
    teaching_img = (teaching_img > 0.5).astype(np.float32)

    # Apply the per-class weights (derived from class areas)
    for cls in range(11):
        teaching_img[cls] *= weight[cls]

    # Randomly invert the source image
    if random.random() > 0.5:
        source_img = 1.0 - source_img

    if margin > 0:
        inner = slice(margin, -margin)
        return source_img, teaching_img[:, inner, inner]
    return source_img, teaching_img

def transform_data(dataset, n, rows, cols, margin):
    """Build ``n`` tiled source/teaching image pairs from ``dataset``.

    The dataset is read sequentially and wraps around to index 0 when
    exhausted. Returns two lists of length ``n``: the source images and the
    matching teaching images, as produced by ``make_teacher_image``.
    """
    def cycle_samples():
        size = len(dataset)
        pos = 0
        while True:
            yield dataset[pos % size]
            pos += 1

    samples = cycle_samples()
    pairs = [make_teacher_image(samples, rows, cols, margin) for _ in range(n)]
    source_imgs = [src for src, _ in pairs]
    teaching_imgs = [tgt for _, tgt in pairs]
    return source_imgs, teaching_imgs

class MyDatasets(torch.utils.data.Dataset):
    """Dataset over two parallel sequences: source images and teaching images.

    ``transforms``, when given, is called as ``transforms(source, teaching)``
    and must return the transformed pair.
    """

    def __init__(self, source_imgs, teaching_imgs, transforms=None):
        self.source_imgs = source_imgs
        self.teaching_imgs = teaching_imgs
        self.transforms = transforms

    def __len__(self):
        return len(self.source_imgs)

    def __getitem__(self, index):
        pair = (self.source_imgs[index], self.teaching_imgs[index])
        if not self.transforms:
            return pair
        source_img, teaching_img = self.transforms(*pair)
        return source_img, teaching_img

# Tile the samples for filter processing.
# The tiled dataset is cached as a pickle; it is rebuilt only when the cache file is missing.
dataset_fname = os.path.join(data_path, 'dataset.pickle')
if os.path.exists(dataset_fname):
#if False:
    # NOTE: pickle.load executes whatever the file contains; only load a cache
    # that this notebook wrote itself.
    with open(dataset_fname, 'rb') as f:
        # The load order must mirror the dump order in the else-branch below.
        source_imgs_train = pickle.load(f)
        teaching_imgs_train = pickle.load(f)
        source_imgs_test = pickle.load(f)
        teaching_imgs_test = pickle.load(f)
        weight = pickle.load(f)
else:
    # Build per-class weights from the mean area of each class
    # (index 10 accumulates the background, 1 - img).
    areas = np.zeros((11))
    for img, label in dataset_train:
        img = img.numpy()
        areas[label] += np.mean(img)
        areas[10] += np.mean(1.0-img)
    areas /= len(dataset_train)
    
    # Inverse-area weighting, normalised so the weights sum to 1.
    # `weight` must exist before transform_data() runs: make_teacher_image() reads it.
    weight = 1 / areas
    weight /= np.sum(weight)
    
    source_imgs_train, teaching_imgs_train = transform_data(dataset_train, 4096, rows, cols, 0) #29)
    source_imgs_test, teaching_imgs_test = transform_data(dataset_test, 128, rows, cols, 0) #, 29)
    
    os.makedirs(data_path, exist_ok=True)
    with open(dataset_fname, 'wb') as f:
        pickle.dump(source_imgs_train, f)
        pickle.dump(teaching_imgs_train, f)
        pickle.dump(source_imgs_test, f)
        pickle.dump(teaching_imgs_test, f)
        pickle.dump(weight, f)

my_dataset_train = MyDatasets(source_imgs_train, teaching_imgs_train)
my_dataset_test = MyDatasets(source_imgs_test, teaching_imgs_test)

# Only the training loader is shuffled.
loader_train = torch.utils.data.DataLoader(dataset=my_dataset_train, batch_size=mini_batch_size, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset=my_dataset_test, batch_size=mini_batch_size, shuffle=False)

## アシスト関数定義

In [None]:
def plot_data(x, y, n=2):
    """Display up to ``n`` samples: the source image followed by its 11 target planes.

    Args:
        x: batch of source images; ``x[i][0]`` is shown as the source.
        y: batch of 11-plane images; planes 0-9 are titled by class, plane 10
            is titled as the background.
        n: maximum number of samples (rows of the figure) to show.

    Nothing is drawn when there are no samples to show.
    """
    n = min(x.shape[0], n)
    if n <= 0:
        # A zero-height figure is not a valid matplotlib figure size.
        return
    plt.figure(figsize=(18,2*n))
    for i in range(n):
        plt.subplot(n,12,i*12+1)
        plt.title('source')
        plt.imshow(x[i][0], 'gray')
        for j in range(11):
            plt.subplot(n,12,i*12+2+j)
            plt.title('class=%d'%j if j < 10 else 'background')
            plt.imshow(y[i][j], 'gray')
    plt.tight_layout()
    plt.show()

In [None]:
# Visual check of the training data (disabled; change the condition to run it)
if False:
    plt.figure(figsize=(16,8))
    # Take only the first mini-batch from the test loader.
    for x, t in loader_test:
        break

    plot_data(x, t, 4)

In [None]:
def to_numpy(x):
    """Convert ``x`` to a NumPy array.

    * ``np.ndarray`` (including subclasses) is returned as-is, without copying.
    * ``torch.Tensor`` (including subclasses such as ``nn.Parameter``) is moved
      to the CPU, detached and cloned before conversion.
    * ``bb.FrameBuffer`` / ``bb.Tensor`` are converted with their ``numpy()`` method.
    * Anything else is treated like a torch tensor, so it must provide
      ``to``/``detach``/``clone``/``numpy``.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        return x.to('cpu').detach().clone().numpy()
    if isinstance(x, (bb.FrameBuffer, bb.Tensor)):
        return x.numpy()
    # Fallback kept from the original: assume a torch-like object.
    return x.to('cpu').detach().clone().numpy()

def calc_corrcoef(a, b):
    """Return the correlation coefficient between ``a`` and ``b``, both flattened to 1-D."""
    flat_a = to_numpy(a).reshape(-1)
    flat_b = to_numpy(b).reshape(-1)
    corr = np.corrcoef(flat_a, flat_b)
    # Off-diagonal entry of the 2x2 correlation matrix.
    return corr[0][1]

def print_summary(x, text=""):
    """Print mean/std/min/max and a NaN flag for ``x``, prefixed with ``text``.

    ``x`` may be None, in which case only ``[text] None`` is printed.
    """
    if x is None:
        print("[%s] None"%(text))
        return
    arr = to_numpy(x)
    stats = (text, np.mean(arr), np.std(arr), np.min(arr), np.max(arr), np.isnan(arr).any())
    print("[%s] mean:%1.8f std:%1.8f min:%1.8f max:%1.8f isnan:%d"%stats)

def print_diff(a, b, text=""):
    """Print summary statistics of ``a - b``; ``b`` is reshaped to the shape of ``a``."""
    lhs = to_numpy(a)
    rhs = to_numpy(b).reshape(lhs.shape)
    print_summary(lhs - rhs, text=text)

def print_diff_summary(a_bb, a_torch, text=""):
    """Compare a BinaryBrain result with a PyTorch result.

    Prints their correlation coefficient, then summary statistics for each
    input and for their difference. ``a_torch`` is reshaped to match ``a_bb``.
    """
    arr_bb = to_numpy(a_bb)
    arr_torch = to_numpy(a_torch).reshape(arr_bb.shape)
    print('[corrcoef:%f]'%calc_corrcoef(arr_bb, arr_torch))
    rows_to_print = (
        (" bb   ", arr_bb),
        (" torch", arr_torch),
        (" diff", arr_bb - arr_torch),
    )
    for suffix, values in rows_to_print:
        print_summary(values, text=text+suffix)

def print_shape(x, text=""):
    """Print ``x.get_shape()`` prefixed with ``text``, or ``None`` when ``x`` is None."""
    if x is None:
        print("[%s] None"%text)
        return
    print("[%s]"%text, x.get_shape())

In [None]:
def dump_object(obj, file):
    """Pickle ``obj`` into the file at path ``file``, overwriting existing content."""
    with open(file, mode='wb') as sink:
        pickle.dump(obj, sink)

def load_object(file):
    """Return the first pickled object stored in the file at path ``file``."""
    with open(file, mode='rb') as src:
        obj = pickle.load(src)
    return obj

## ネットワーク定義

In [None]:
# バイナリ時は BIT型を使えばメモリ削減可能
# In binary mode the BIT type can be used to reduce memory usage; otherwise FP32 is used.
bin_dtype = bb.DType.BIT if bin_mode else bb.DType.FP32
# bin_dtype = bb.DType.FP32

### 汎用バッファ

In [None]:
class Buffer():
    """Container shared by the PyTorch and BinaryBrain paths.

    ``torch`` holds the PyTorch tensor and ``bb`` the BinaryBrain frame buffer;
    either stays None when the matching backend flag is off. On the
    BinaryBrain side the first dimension of ``shape`` is multiplied by
    ``frame_modulation_size`` to obtain the frame count.
    """
    def __init__(self, shape=None, dtype=bb.DType.FP32):
        self.torch = None
        self.bb    = None
        if shape is None:
            return
        if use_torch:
            self.torch = torch.zeros(shape).to(device)
        if use_bb:
            frames = shape[0]*frame_modulation_size
            self.bb = bb.FrameBuffer(frames, shape[1:], dtype=dtype)

    def get_shape(self):
        """Return the shape as a list (torch side preferred), or None when empty."""
        if use_torch and self.torch is not None:
            return list(self.torch.shape)
        if use_bb and self.bb is not None:
            batch = self.bb.get_frame_size()//frame_modulation_size
            return [batch] + self.bb.get_node_shape()
        return None

    def numpy(self):
        """Return the content as a NumPy array (torch side preferred), or None when empty."""
        if use_torch and self.torch is not None:
            return to_numpy(self.torch)
        if use_bb and self.bb is not None:
            return to_numpy(self.bb)
        return None

    @staticmethod
    def zeros(shape, dtype=bb.DType.FP32):
        """Create a zero-filled Buffer of ``shape`` for every enabled backend."""
        buf = Buffer()
        if use_torch:
            buf.torch = torch.zeros(shape).to(device)
        if use_bb:
            frames = shape[0]*frame_modulation_size
            buf.bb = bb.FrameBuffer.zeros(frames, shape[1:], dtype=dtype)
        return buf

    @staticmethod
    def zeros_like(x, dtype=bb.DType.FP32):
        """Create a zero-filled Buffer with the same shape as ``x``."""
        return Buffer.zeros(x.get_shape(), dtype=dtype)

### 基本ブロック定義

In [None]:
from torch.autograd import Function
class Binarize(Function):
    """Binarize to +1/-1 in the forward pass; pass the gradient only where |x| < 1."""
    @staticmethod
    def forward(ctx, x):
        """Return +1 where ``x > 0`` and -1 everywhere else (same dtype/device as ``x``)."""
        ctx.save_for_backward(x)
        # torch.where writes every element, so no output slot is left
        # uninitialised (the previous x.new(...) buffer was, for NaN inputs).
        one = torch.ones_like(x)
        return torch.where(x > 0, one, -one)

    @staticmethod
    def backward(ctx, dy):
        """Return ``dy`` with the entries where ``x >= 1`` or ``x <= -1`` zeroed."""
        x, = ctx.saved_tensors
        return dy.masked_fill(x.ge(1) | x.le(-1), 0)
binarize = Binarize.apply

class Through(Function):
    """Identity autograd function: copies the input forward and the gradient backward."""
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, dy):
        return dy.clone()

through = Through.apply

In [None]:
class UpSampling(bb.Sequential):
    """2x upsampling applied to both sides of a Buffer (torch: nearest-neighbour)."""
    def __init__(self, name=""):
        layers = []
        if use_torch:
            self.up_torch = nn.Upsample(scale_factor=2, mode='nearest').to(device)
        if use_bb:
            self.up_bb = bb.UpSampling((2, 2), fw_dtype=bin_dtype)
            layers.append(self.up_bb)
        super().__init__(layers, name=name)

    def parameters(self):
        """This layer has no torch parameters."""
        return []

    def forward(self, x, train=True):
        """Upsample every enabled side of Buffer ``x`` and return a new Buffer."""
        out = Buffer()
        if use_torch:
            out.torch = self.up_torch(x.torch)
        if use_bb:
            out.bb = self.up_bb.forward(x.bb, train=train)
        return out

class MaxPooling(bb.Sequential):
    """2x2 max pooling applied to both sides of a Buffer."""
    def __init__(self, name=""):
        layers = []
        if use_torch:
            self.pool_torch = nn.MaxPool2d(2, 2).to(device)
        if use_bb:
            self.pool_bb = bb.MaxPooling((2, 2), fw_dtype=bin_dtype)
            layers.append(self.pool_bb)
        super().__init__(layers, name=name)

    def parameters(self):
        """This layer has no torch parameters."""
        return []

    def forward(self, x, train=True):
        """Pool every enabled side of Buffer ``x`` and return a new Buffer."""
        out = Buffer()
        if use_torch:
            out.torch = self.pool_torch(x.torch)
        if use_bb:
            out.bb = self.pool_bb.forward(x.bb, train=train)
        return out

class Concatenate(bb.Sequential):
    """Concatenate two Buffers (torch side: along dimension 1)."""
    def __init__(self, name=""):
        self.cat_bb = bb.Concatenate()
        super().__init__([self.cat_bb], name=name)

    def parameters(self):
        """This layer has no torch parameters."""
        return []

    def set_input_shape(self, shape):
        """Forward the list of input shapes to the BinaryBrain layer and return its output shape."""
        return self.cat_bb.set_input_shape(shape)

    def forward(self, x, train=True):
        """Concatenate ``x[0]`` and ``x[1]`` on every enabled side; returns a new Buffer."""
        out = Buffer()
        if use_torch:
            out.torch = torch.cat([x[0].torch, x[1].torch], 1)
        if use_bb:
            out.bb = self.cat_bb.forward([x[0].bb, x[1].bb], train=train)
        return out

    def backward(self, dy):
        """Return whatever the BinaryBrain concatenate layer returns for ``dy``."""
        return self.cat_bb.backward(dy)


class ConvBlock(bb.Sequential):
    """Basic block: two convolutions applied back to back.

    NOTE(review): ``Convolution`` is not defined in the visible part of this
    file; confirm it exists before instantiating this class.
    """
    def __init__(self, in_ch=32, hid_ch=32, out_ch=32, name=""):
        self.cnv0 = Convolution(in_ch, hid_ch, name=name+"_cnv0")
        self.cnv1 = Convolution(hid_ch, out_ch, name=name+"_cnv1")
        super().__init__([self.cnv0, self.cnv1], name=name)

    def parameters(self):
        """Return the parameters of both convolutions, first layer first."""
        return self.cnv0.parameters() + self.cnv1.parameters()

    def copy_param_torch_to_bb(self):
        """Copy the torch parameters into the BinaryBrain layers."""
        for layer in (self.cnv0, self.cnv1):
            layer.copy_param_torch_to_bb()

    def copy_param_bb_to_torch(self):
        """Copy the BinaryBrain parameters into the torch layers."""
        for layer in (self.cnv0, self.cnv1):
            layer.copy_param_bb_to_torch()

    def set_input_shape(self, shape):
        """Propagate ``shape`` through both layers and return the output shape."""
        for layer in (self.cnv0, self.cnv1):
            shape = layer.set_input_shape(shape)
        return shape

    def forward(self, x, train=True):
        """Run both convolutions in order."""
        for layer in (self.cnv0, self.cnv1):
            x = layer.forward(x, train=train)
        return x

    def backward(self, dy):
        """Back-propagate through both convolutions in reverse order."""
        for layer in (self.cnv1, self.cnv0):
            dy = layer.backward(dy)
        return dy

In [None]:
class Distillation(bb.Sequential):
    """Sequential wrapper with optional I/O logging and dense-to-LUT distillation.

    Supported commands (see ``send_command``): ``log start|stop`` and
    ``distillation start|stop|info``.

    NOTE(review): this class still relies on ``self._forward``, ``self.lut_bb``
    and ``self.cnv_bb``, none of which are defined in the visible source, and
    ``forward`` hands a ``Buffer`` to ``backward``. It looks like an
    unfinished draft of the logic in ``BinaryConvolution``; confirm before use.
    """
    def __init__(self, model_list, name=""):
        # State read by send_command()/forward(); mirrors BinaryConvolution.
        self.log_file_x = None
        self.log_file_y = None
        self.log_x_name = os.path.join(data_path, "%s_x.pickle"%name)
        self.log_y_name = os.path.join(data_path, "%s_y.pickle"%name)
        self.distillation_bb = False
        self.criterion_bb = bb.LossMeanSquaredError()
        self.optimizer_bb = bb.OptimizerAdam(learning_rate=0.001)
        super(Distillation, self).__init__(model_list, name=name)
    
    def send_command(self, command, send_to='all'):
        """Forward ``command`` to the child models, then handle log/distillation commands."""
        # Must name this class: self is not a BinaryConvolution instance.
        super(Distillation, self).send_command(command, send_to)
        if send_to != "all" and send_to != self.get_name():
            return
        
        args = command.split()
        if len(args) == 2 and args[0] == 'log':
            if args[1] == 'start':
                print("[%s] log start"%self.get_name())
                self.log_file_x = open(self.log_x_name, 'wb')
                self.log_file_y = open(self.log_y_name, 'wb')
            if args[1] == 'stop':
                print("[%s] log stop"%self.get_name())
                if self.log_file_x is not None:
                    self.log_file_x.close()
                    self.log_file_x = None
                if self.log_file_y is not None:
                    self.log_file_y.close()
                    self.log_file_y = None
        
        if len(args) == 2 and args[0] == 'distillation':
            if args[1] == 'start':
                print("[%s] distillation start"%self.get_name())
                self.distillation_bb = True
                self.criterion_bb.clear()
                self.optimizer_bb.set_variables(self.lut_bb.get_parameters(), self.lut_bb.get_gradients())
            if args[1] == 'stop':
                self.distillation_bb = False                
                print("[%s] distillation stop loss=%f"%(self.get_name(), self.criterion_bb.get()))
            if args[1] == 'info':
                print("[%s] distillation loss=%f"%(self.get_name(), self.criterion_bb.get()))
    
    def forward(self, x, train=True):
        """Run the wrapped model; while distilling, also train the LUT model on the dense output."""
        if verbose > 2:
            print("[%s] forward"%self.name)
            print_summary(x.torch,                    text='in torch')
            print_summary(x.bb.astype(bb.DType.FP32), text='in bb   ')
        
        if self.log_file_x is not None:
            pickle.dump(x.bb.astype(bb.DType.FP32).numpy(), self.log_file_x)
        
        # The flag set by send_command() is distillation_bb.
        if self.distillation_bb:
            # Distillation: the dense output is the teacher for the LUT model.
            self.send_command("switch_model dense")                
            y = self._forward(x, train=train)
            self.send_command("switch_model lut")
            y_lut = self._forward(x, train=True)
            dy_lut = Buffer()
            dy_lut.bb = self.criterion_bb.calculate(y_lut.bb.astype(bb.DType.FP32), y.bb.astype(bb.DType.FP32))
            self.backward(dy_lut)
            self.optimizer_bb.update()
            self.send_command("switch_model dense")                
        else:
            y = self._forward(x, train=train)
            
        if self.log_file_y is not None:
            pickle.dump(y.bb.astype(bb.DType.FP32).numpy(), self.log_file_y)
        
        if verbose > 2:
            print_summary(y.torch,                    text='out torch')
            print_summary(y.bb.astype(bb.DType.FP32), text='out bb   ')
        return y
    
    def backward(self, dy):
        """Back-propagate ``dy`` through the BinaryBrain model when that backend is enabled."""
        if use_bb:
            dy = self.cnv_bb.backward(dy)
        return dy

### BinaryConvolution

In [None]:
class BinaryDenseAffineBb(bb.Sequential):
    """BinaryBrain affine layer followed by optional batch normalisation and binarisation.

    Args:
        out_ch: number of output channels (output shape is ``[out_ch, 1, 1]``).
        batch_norm: append a BatchNormalization layer.
        binarize: append a Binarize activation.
        depthwise: use DepthwiseDenseAffine instead of DenseAffine.
        name: prefix for the names of the child layers.
    """
    def __init__(self, out_ch, batch_norm=True, binarize=True, depthwise=False, name=""):
        self.batch_norm = batch_norm
        self.binarize   = binarize
        if depthwise:
            self.affine = bb.DepthwiseDenseAffine([out_ch, 1, 1], name=name+'_affine', initializer="")
        else:
            self.affine = bb.DenseAffine([out_ch, 1, 1], name=name+'_affine', initializer="")
        # bn/act are always created, but only registered when enabled.
        self.bn     = bb.BatchNormalization(name=name+'_bn')
        self.act    = bb.Binarize(name=name+'_act', bin_dtype=bin_dtype)
        layers = [self.affine]
        if batch_norm: layers.append(self.bn)
        if binarize:   layers.append(self.act)
        super(BinaryDenseAffineBb, self).__init__(layers, name=name)
    
    def forward(self, x, train=True):
        """Apply affine, then the enabled batch-norm / binarize stages.

        ``train`` is passed on to every stage; it was previously ignored and
        every stage always ran with ``train=True``.
        """
        x = self.affine.forward(x, train=train)
        if self.batch_norm:
            x = self.bn.forward(x, train=train)
        if self.binarize:
            x = self.act.forward(x, train=train)
        return x
    
    def backward(self, dy):
        """Back-propagate through the enabled stages in reverse order."""
        if self.binarize:
            dy = self.act.backward(dy)
        if self.batch_norm:
            dy = self.bn.backward(dy)
        dy = self.affine.backward(dy)
        if verbose > 2:
            print_summary(dy)
        return dy

class BinaryDenseConvolutionBb(bb.Sequential):
    """BinaryBrain 2-D convolution ('same' padding) built around a BinaryDenseAffineBb."""
    def __init__(self, out_ch, kernel_size=3, depthwise=False, batch_norm=True, binarize=True, name=""):
        point_op = BinaryDenseAffineBb(
            out_ch,
            name=name+'_affine',
            depthwise=depthwise,
            batch_norm=batch_norm,
            binarize=binarize,
        )
        self.cnv = bb.Convolution2d(
            point_op,
            filter_size=(kernel_size, kernel_size),
            padding='same',
            name=name+'_cnv',
            fw_dtype=bin_dtype,
        )
        super().__init__([self.cnv], name=name)

class LutConvolutionBb(bb.Sequential):
    """BinaryBrain 2-D convolution ('same' padding) built from DifferentiableLut layers.

    Depthwise mode uses a single 'depthwise'-connected LUT layer; otherwise
    three LUT layers of ``ch*36``, ``ch*6`` and ``ch`` nodes are chained.
    """
    def __init__(self, ch, kernel_size=3, depthwise=False, batch_norm=True, binarize=True, name=""):
        # Options shared by every LUT layer.
        common = dict(batch_norm=batch_norm, binarize=binarize, bin_dtype=bin_dtype)
        if depthwise:
            stages = [
                bb.DifferentiableLut([ch, 1, 1], connection='depthwise', name=name+'_lut', **common),
            ]
        else:
            stages = [
                bb.DifferentiableLut([ch*6*6, 1, 1], connection='random', name=name+'_lut2', momentum=0.9, **common),
                bb.DifferentiableLut([ch*6,   1, 1], connection='serial', name=name+'_lut1', momentum=0.9, **common),
                bb.DifferentiableLut([ch,     1, 1], connection='serial', name=name+'_lut0', momentum=0.9, **common),
            ]
        self.luts = bb.Sequential(stages)

        self.cnv = bb.Convolution2d(
            self.luts,
            filter_size=(kernel_size, kernel_size),
            padding='same',
            name=name+'_cnv',
            fw_dtype=bin_dtype,
        )
        super().__init__([self.cnv], name=name)

class BinaryConvolution(bb.Sequential):
    """Convolution + optional batch-norm + optional binarisation on both backends.

    The PyTorch side is a ``nn.Conv2d`` / ``nn.BatchNorm2d`` pair. The
    BinaryBrain side is a ``bb.Switcher`` holding a dense implementation
    ('dense', selected initially) and a LUT implementation ('lut').

    Commands handled by ``send_command``:
        ``log start|stop``: dump the bb-side input/output of every forward
        call as pickles under ``data_path``.
        ``distillation start|stop|info``: while active, every forward call
        also trains the LUT model to reproduce the dense output.
    """
    def __init__(self, in_ch, out_ch, kernel_size=3, depthwise=False, batch_norm=True, binarize=True, name=""):
        self.batch_norm = batch_norm
        self.binarize   = binarize
        
        self.log_file_x = None
        self.log_file_y = None
        self.log_x_name = os.path.join(data_path, "%s_x.pickle"%name)
        self.log_y_name = os.path.join(data_path, "%s_y.pickle"%name)
        
        self.distillation_bb = False
        self.criterion_bb = bb.LossMeanSquaredError()
        self.optimizer_bb = bb.OptimizerAdam(learning_rate=0.001)
        
        if use_torch:
            if kernel_size == 1:
                self.cnv_torch = nn.Conv2d(in_ch, out_ch, 1).to(device) # pointwise
            else:
                if depthwise:
                    self.cnv_torch = nn.Conv2d(in_ch, out_ch, kernel_size, padding=1, groups=in_ch, padding_mode='reflect').to(device)
                else:
                    self.cnv_torch = nn.Conv2d(in_ch, out_ch, kernel_size, padding=1, padding_mode='reflect').to(device) # 'replicate'
            self.bn_torch  = nn.BatchNorm2d(out_ch, eps=1e-07).to(device)
        
        if use_bb:
            # NOTE(review): kernel_size is not forwarded to the bb layers, so they
            # always use their default of 3, even when the torch side is built as
            # a 1x1 pointwise convolution. Confirm whether this is intended.
            self.dense_bb = BinaryDenseConvolutionBb(out_ch, name=name+'_cnv', depthwise=depthwise, batch_norm=batch_norm, binarize=binarize)
            self.lut_bb   = LutConvolutionBb(out_ch, name=name+'_lut', depthwise=depthwise, batch_norm=batch_norm, binarize=binarize)
            self.cnv_bb   = bb.Switcher({'lut': self.lut_bb, 'dense': self.dense_bb}, init_model_name = 'dense')
            super(BinaryConvolution, self).__init__([self.cnv_bb], name=name)
        else:
            super(BinaryConvolution, self).__init__([], name=name)
    
    def _close_logs(self):
        """Close whichever log files are open and mark logging as stopped."""
        if self.log_file_x is not None:
            self.log_file_x.close()
            self.log_file_x = None
        if self.log_file_y is not None:
            self.log_file_y.close()
            self.log_file_y = None
    
    def send_command(self, command, send_to='all'):
        """Forward ``command`` to the child models, then handle log/distillation commands.

        Only acts locally when ``send_to`` is ``'all'`` or this model's name.
        """
        super(BinaryConvolution, self).send_command(command, send_to)
        if send_to != "all" and send_to != self.get_name():
            return
        
        args = command.split()
        if len(args) == 2 and args[0] == 'log':
            if args[1] == 'start':
                print("[%s] log start"%self.get_name())
                # Release handles from an earlier 'log start' before reopening.
                self._close_logs()
                self.log_file_x = open(self.log_x_name, 'wb')
                self.log_file_y = open(self.log_y_name, 'wb')
            if args[1] == 'stop':
                print("[%s] log stop"%self.get_name())
                self._close_logs()
        
        if len(args) == 2 and args[0] == 'distillation':
            if args[1] == 'start':
                print("[%s] distillation start"%self.get_name())
                self.distillation_bb = True
                self.criterion_bb.clear()
                self.optimizer_bb.set_variables(self.lut_bb.get_parameters(), self.lut_bb.get_gradients())
            if args[1] == 'stop':
                self.distillation_bb = False                
                print("[%s] distillation stop loss=%f"%(self.get_name(), self.criterion_bb.get()))
            if args[1] == 'info':
                print("[%s] distillation loss=%f"%(self.get_name(), self.criterion_bb.get()))
                
    def parameters(self):
        """Return the torch parameters (conv, plus batch-norm when enabled).

        Returns an empty list when the PyTorch path is disabled, because the
        torch layers do not exist in that case.
        """
        if not use_torch:
            return []
        params = list(self.cnv_torch.parameters())
        if self.batch_norm:
            params += list(self.bn_torch.parameters())
        return params
    
    def copy_param_torch_to_bb(self):
        """Copy conv weights/bias and batch-norm state from torch into the dense bb model."""
        affine_bb    = self.cnv_bb['dense'].cnv[1].affine
        bn_bb        = self.cnv_bb['dense'].cnv[1].bn
        affine_torch = self.cnv_torch
        bn_torch     = self.bn_torch
        affine_bb.W().set_numpy(to_numpy(affine_torch.weight).reshape(affine_bb.W().get_shape()))
        affine_bb.b().set_numpy(to_numpy(affine_torch.bias).reshape(affine_bb.b().get_shape()))
        bn_bb.gamma().set_numpy(to_numpy(bn_torch.weight).reshape(bn_bb.gamma().get_shape()))
        bn_bb.beta().set_numpy(to_numpy(bn_torch.bias).reshape(bn_bb.beta().get_shape()))
        bn_bb.running_mean().set_numpy(to_numpy(bn_torch.running_mean).reshape(bn_bb.running_mean().get_shape()))
        bn_bb.running_var().set_numpy(to_numpy(bn_torch.running_var).reshape(bn_bb.running_var().get_shape()))

    def copy_param_bb_to_torch(self):
        """Copy conv weights/bias and batch-norm state from the dense bb model into torch."""
        affine_bb    = self.cnv_bb['dense'].cnv[1].affine
        bn_bb        = self.cnv_bb['dense'].cnv[1].bn
        affine_torch = self.cnv_torch
        bn_torch     = self.bn_torch
        affine_torch.weight = nn.Parameter(torch.tensor(affine_bb.W().numpy().reshape(list(affine_torch.weight.shape))).to(device))
        affine_torch.bias   = nn.Parameter(torch.tensor(affine_bb.b().numpy().reshape(list(affine_torch.bias.shape))).to(device))
        bn_torch.weight = nn.Parameter(torch.tensor(bn_bb.gamma().numpy().reshape(list(bn_torch.weight.shape))).to(device))
        bn_torch.bias = nn.Parameter(torch.tensor(bn_bb.beta().numpy().reshape(list(bn_torch.bias.shape))).to(device))
        bn_torch.running_mean = torch.tensor(bn_bb.running_mean().numpy().reshape(list(bn_torch.running_mean.shape))).to(device)
        bn_torch.running_var = torch.tensor(bn_bb.running_var().numpy().reshape(list(bn_torch.running_var.shape))).to(device)
        
    def set_input_shape(self, shape):
        """Set the bb-side input shape and return the resulting output shape."""
        shape = self.cnv_bb.set_input_shape(shape)
        return shape
    
    def forward(self, x, train=True):
        """Run the layer on every enabled side of Buffer ``x`` and return a new Buffer.

        Verbose summaries and logging skip the bb side when it holds no data
        (for example when only the PyTorch path is enabled).
        """
        if verbose > 2:
            print("[%s] forward"%self.name)
            print_summary(x.torch,                    text='in torch')
            print_summary(x.bb.astype(bb.DType.FP32) if x.bb is not None else None, text='in bb   ')
        
        if self.log_file_x is not None and x.bb is not None:
            pickle.dump(x.bb.astype(bb.DType.FP32).numpy(), self.log_file_x)
            
        y = Buffer()
        if use_torch:
            y.torch = self.cnv_torch(x.torch)
            if self.batch_norm:
                y.torch = self.bn_torch(y.torch)
            if self.binarize:
                y.torch = binarize(y.torch)
            y.torch = through(y.torch)
            
        if use_bb:
            if self.distillation_bb:
                # Distillation: the dense output is both the layer output and
                # the teacher signal for one training step of the LUT model.
                self.send_command("switch_model dense")                
                y.bb = self.cnv_bb.forward(x.bb, train=train)
                self.send_command("switch_model lut")
                y_lut  = self.cnv_bb.forward(x.bb, train=True)
                dy_lut = self.criterion_bb.calculate(y_lut.astype(bb.DType.FP32), y.bb.astype(bb.DType.FP32))
                self.cnv_bb.backward(dy_lut)
                self.optimizer_bb.update()
                # Leave the switcher on the dense model for the caller's backward pass.
                self.send_command("switch_model dense")                
            else:
                y.bb = self.cnv_bb.forward(x.bb, train=train)
        
        if self.log_file_y is not None and y.bb is not None:
            pickle.dump(y.bb.astype(bb.DType.FP32).numpy(), self.log_file_y)
        
        if verbose > 2:
            print_summary(y.torch,                    text='out torch')
            print_summary(y.bb.astype(bb.DType.FP32) if y.bb is not None else None, text='out bb   ')
        return y
    
    def backward(self, dy):
        """Back-propagate ``dy`` through the bb model when that backend is enabled."""
        if use_bb:
            dy = self.cnv_bb.backward(dy)
        return dy

In [None]:
class MobileBlock(bb.Sequential):
    """MobileNet-style block: three BinaryConvolutions requested with kernel sizes 1, 3 (depthwise), 1.

    ``top`` is accepted but not used.
    """
    def __init__(self, in_ch=36, hid_ch=36, out_ch=36, top=False, name=""):
        self.cnv0 = BinaryConvolution(in_ch, hid_ch, 1, name=name+"_pw0")
        self.cnv1 = BinaryConvolution(hid_ch, hid_ch, 3, depthwise=True, name=name+"_dw")
        self.cnv2 = BinaryConvolution(hid_ch, out_ch, 1, name=name+"_pw1")
        super().__init__([self.cnv0, self.cnv1, self.cnv2], name=name)

    def parameters(self):
        """Return the torch parameters of the three convolutions, in layer order."""
        params = []
        for layer in (self.cnv0, self.cnv1, self.cnv2):
            params = params + layer.parameters()
        return params

    def copy_param_torch_to_bb(self):
        """Copy the torch parameters of every convolution to its bb counterpart."""
        for layer in (self.cnv0, self.cnv1, self.cnv2):
            layer.copy_param_torch_to_bb()

    def copy_param_bb_to_torch(self):
        """Copy the bb parameters of every convolution to its torch counterpart."""
        for layer in (self.cnv0, self.cnv1, self.cnv2):
            layer.copy_param_bb_to_torch()

### サブブロック

In [None]:
class MainBlock(bb.Sequential):
    """Top-level (1/1 scale) stage of the hierarchy.

    Combines the input image (through ``cnv0``) with the upsampled feature map
    of the lower scale (through ``up`` + ``cnv1``), then runs two MobileBlocks.
    ``blk1`` yields the output ``y``; the pooled output of ``blk0`` is the
    feature map ``v`` handed to the lower scale. ``top`` is accepted but not used.
    """
    def __init__(self, top=False):
        self.up   = UpSampling()
        self.cat  = Concatenate()
        self.cnv0 = BinaryConvolution(1,  18, 3, name="m_cnv0")
        self.cnv1 = BinaryConvolution(32, 18, 1, name="m_cnv1")
        self.blk0 = MobileBlock(36, 36, 32, name="m_blk0")
        self.blk1 = MobileBlock(32, 36, 11, name="m_blk1")
        self.pool = MaxPooling()
        super().__init__([self.cnv0, self.cnv1, self.blk0, self.blk1])

    def parameters(self):
        """Return the torch parameters of cnv0, cnv1, blk0 and blk1, in that order."""
        params = []
        for layer in (self.cnv0, self.cnv1, self.blk0, self.blk1):
            params = params + layer.parameters()
        return params

    def copy_param_torch_to_bb(self):
        """Copy torch parameters to the bb side for every parameterised sub-layer."""
        for layer in (self.cnv0, self.cnv1, self.blk0, self.blk1):
            layer.copy_param_torch_to_bb()

    def copy_param_bb_to_torch(self):
        """Copy bb parameters to the torch side for every parameterised sub-layer."""
        for layer in (self.cnv0, self.cnv1, self.blk0, self.blk1):
            layer.copy_param_bb_to_torch()

    def set_input_shape(self, x0_shape, x1_shape):
        """Propagate shapes through the stage; returns ``(y_shape, v_shape)``."""
        img_shape = self.cnv0.set_input_shape(x0_shape)

        low_shape = self.up.set_input_shape(x1_shape)
        low_shape = self.cnv1.set_input_shape(low_shape)

        cat_shape = self.cat.set_input_shape([img_shape, low_shape])
        print(cat_shape)
        mid_shape = self.blk0.set_input_shape(cat_shape)
        y_shape = self.blk1.set_input_shape(mid_shape)
        v_shape = self.pool.set_input_shape(mid_shape)
        return y_shape, v_shape

    def forward(self, x0, x1, first, last, train=True):
        """Run one pass of the stage.

        Args:
            x0: input image Buffer.
            x1: feature Buffer from the lower scale; ignored when ``first``.
            first: substitute an all-zero 18-channel Buffer for the lower-scale branch.
            last: skip pooling; ``v`` is returned as None.

        Returns:
            ``(y, v)``: the ``blk1`` output and the pooled ``blk0`` output (or None).
        """
        x0 = self.cnv0.forward(x0, train=train)

        if first:
            zero_shape = x0.get_shape()
            zero_shape[1] = 18
            x1 = Buffer.zeros(zero_shape, dtype=bin_dtype)
        else:
            x1 = self.up.forward(x1, train=train)
            x1 = self.cnv1.forward(x1, train=train)

        feat = self.cat.forward([x0, x1], train=train)
        feat = self.blk0.forward(feat, train=train)
        y = self.blk1.forward(feat, train=train)
        v = None if last else self.pool.forward(feat, train=train)
        return y, v

    def backward(self, dy, dv, first, last):
        """Back-propagate one pass; returns ``(dx0, dx1)`` with ``dx1`` None when ``first``.

        ``first``/``last`` must match the values used in the corresponding forward call.
        """
        grad = self.blk1.backward(dy)
        if not last:
            # Add the gradient arriving through the pooled branch.
            grad = grad + self.pool.backward(dv)
        dx = self.blk0.backward(grad)
        dx0, dx1 = self.cat.backward([dx])
        if first:
            dx1 = None
        else:
            dx1 = self.cnv1.backward(dx1)
            dx1 = self.up.backward(dx1)

        dx0 = self.cnv0.backward(dx0)
        return dx0, dx1


class ScaleBlock(bb.Sequential):
    """Lower-level stage of the hierarchy.

    Three 12-channel branches are concatenated and fed to a MobileBlock:
    ``cnv0`` on the same-scale state ``x0``, ``up`` + ``cnv1`` on the state
    ``x1`` of the next lower scale, and ``cnv2`` on the feature map ``u``
    coming from the scale above. ``top`` is accepted but not used.
    """
    def __init__(self, top=False):
        self.up   = UpSampling()
        self.cat0 = Concatenate()
        self.cat1 = Concatenate()
        self.cnv0  = BinaryConvolution(32, 12, 1, name="s_cnv0")
        self.cnv1  = BinaryConvolution(32, 12, 1, name="s_cnv1")
        self.cnv2  = BinaryConvolution(32, 12, 1, name="s_cnv2")
        self.blk  = MobileBlock(36, 36, 32, name="s_blk")
        self.pool = MaxPooling()
        super(ScaleBlock, self).__init__([self.cnv0, self.cnv1, self.cnv2, self.blk])
    
    def parameters(self):
        """Return the torch parameters of cnv0, cnv1, cnv2 and blk, in that order."""
        return self.cnv0.parameters() + self.cnv1.parameters() + self.cnv2.parameters() + self.blk.parameters()

    def copy_param_torch_to_bb(self):
        """Copy torch parameters to the bb side for every parameterised sub-layer."""
        # The previous code called a nonexistent self.cnv attribute.
        for layer in (self.cnv0, self.cnv1, self.cnv2, self.blk):
            layer.copy_param_torch_to_bb()
    
    def copy_param_bb_to_torch(self):
        """Copy bb parameters to the torch side for every parameterised sub-layer."""
        for layer in (self.cnv0, self.cnv1, self.cnv2, self.blk):
            layer.copy_param_bb_to_torch()

    def set_input_shape(self, x0_shape, x1_shape, u_shape):
        """Propagate shapes through the stage; returns ``(y_shape, v_shape)``."""
        x0_shape = self.cnv0.set_input_shape(x0_shape)
        x1_shape = self.up.set_input_shape(x1_shape)
        x1_shape = self.cnv1.set_input_shape(x1_shape)
        u_shape = self.cnv2.set_input_shape(u_shape)
        x_shape = self.cat0.set_input_shape([x0_shape, x1_shape])
        x_shape = self.cat1.set_input_shape([x_shape, u_shape])
        y_shape = self.blk.set_input_shape(x_shape)
        v_shape = self.pool.set_input_shape(y_shape)
        return y_shape, v_shape
    
    def forward(self, x0, x1, u, bottom, first, last, train=True):
        """Run one pass of the stage.

        Args:
            x0: same-scale state Buffer; replaced by zeros when ``first``.
            x1: lower-scale state Buffer; replaced by zeros when ``first`` or ``bottom``.
            u: feature Buffer from the scale above.
            last: skip pooling; ``v`` is returned as None.

        Returns:
            ``(y, v)``: the block output and its pooled version (or None).
        """
        if first:
            u_shape = u.get_shape()
            u_shape[1] = 12
            x0 = Buffer.zeros(u_shape, dtype=bin_dtype)
        else:
            x0 = self.cnv0.forward(x0, train=train)
            
        if first or bottom:
            u_shape = u.get_shape()
            u_shape[1] = 12
            x1 = Buffer.zeros(u_shape, dtype=bin_dtype)
        else:
            x1 = self.up.forward(x1, train=train)
            x1 = self.cnv1.forward(x1, train=train)
        u = self.cnv2.forward(u, train=train)
        
        x = self.cat0.forward([x0, x1], train=train)
        x = self.cat1.forward([x, u], train=train)
        y = self.blk.forward(x, train=train)
        if last:
            v = None
        else:
            v = self.pool.forward(y, train=train)
        return y, v
       
    def backward(self, dy, dv, bottom, first, last):
        """Back-propagate one pass; returns ``(dx0, dx1, du)``.

        ``dx0`` is None when ``first``; ``dx1`` is None when ``first`` or
        ``bottom``. The flags must match the corresponding forward call.
        """
        if last:
            dx = self.blk.backward(dy)
        else:
            # Add the gradient arriving through the pooled branch.
            dv = self.pool.backward(dv)
            dx = self.blk.backward(dy + dv)
        dx, du = self.cat1.backward([dx])
        dx0, dx1 = self.cat0.backward([dx])
        if first:
            dx0 = None
        else:
            dx0 = self.cnv0.backward(dx0)
        
        if first or bottom:
            dx1 = None
        else:
            dx1 = self.cnv1.backward(dx1)
            dx1 = self.up.backward(dx1)
        du = self.cnv2.backward(du)
        return dx0, dx1, du

### Mipmap型ネット

In [None]:
class MipmapNetwork(bb.Sequential):
    """Mipmap-type network.

    ``forward`` applies ``m_blk`` (a ``MainBlock``) ``loops`` times to the same
    input.  From the second iteration on it also chains ``s_blk`` (a
    ``ScaleBlock``) calls, one per level; their outputs are kept in a
    per-level ``mipmap`` list and fed back in on the following iteration.
    """
    def __init__(self, loops=2, depth=4):
        """Create the sub-blocks.

        Args:
            loops: Number of iterations ``forward`` performs.
            depth: Requested number of ``s_blk`` levels; capped at ``loops-2``.
        """
        self.loops = loops
        self.depth = min(depth, loops-2)
        self.r2b   = bb.RealToBinary(frame_modulation_size=frame_modulation_size, bin_dtype=bin_dtype)
        self.b2r   = bb.BinaryToReal(frame_integration_size=frame_modulation_size, bin_dtype=bin_dtype)
        self.up    = UpSampling()
        self.m_blk = MainBlock()
        self.s_blk = ScaleBlock()
        # Only m_blk and s_blk are handed to bb.Sequential; r2b, b2r and up
        # are plain attributes (up is not used by the methods of this class).
        super(MipmapNetwork, self).__init__([self.m_blk, self.s_blk])
    
    def parameters(self):
        """Return the concatenated ``parameters()`` of ``m_blk`` and ``s_blk``."""
        return self.m_blk.parameters() + self.s_blk.parameters()
    
    def copy_param_torch_to_bb(self):
        """Delegate ``copy_param_torch_to_bb`` to both sub-blocks."""
        self.m_blk.copy_param_torch_to_bb()
        self.s_blk.copy_param_torch_to_bb()
    
    def copy_param_bb_to_torch(self):
        """Delegate ``copy_param_bb_to_torch`` to both sub-blocks."""
        self.m_blk.copy_param_bb_to_torch()
        self.s_blk.copy_param_bb_to_torch()

    def set_input_shape(self, shape):
        """Propagate the input shape through r2b, m_blk, s_blk and b2r.

        Args:
            shape: Input shape as a list; a copy is stored in ``self.shape``.

        Returns:
            The shape returned by ``b2r.set_input_shape`` for m_blk's ``y``.
        """
        self.shape = copy.copy(shape)
        shape = self.r2b.set_input_shape(shape)
        
        # Second input of m_blk: entry 0 forced to 32, entries 1 and 2 halved.
        # NOTE(review): presumably [channels, height, width] - confirm.
        x0_shape = copy.copy(shape)
        x1_shape = copy.copy(shape)
        x1_shape[0] = 32
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        y_shape, v_shape = self.m_blk.set_input_shape(x0_shape, x1_shape)
        
        # s_blk inputs are derived from m_blk's v shape in the same way.
        x0_shape = copy.copy(v_shape)
        x0_shape[0] = 32
        x1_shape = copy.copy(v_shape)
        x1_shape[0] = 32
        x1_shape[1] //= 2
        x1_shape[2] //= 2
        # The shapes returned by s_blk are not used.
        self.s_blk.set_input_shape(x0_shape, x1_shape, v_shape)
        
        y_shape = self.b2r.set_input_shape(y_shape)
        return y_shape
    
    def forward(self, x, train=True):
        """Run ``loops`` iterations and return every iteration's output.

        When ``use_bb`` is set, ``x.bb`` is overwritten with its binary-typed
        version, so the caller's buffer object is modified.

        Args:
            x: Input buffer object (its ``bb`` attribute is read when
                ``use_bb`` is set).
            train: Forwarded to every sub-layer.

        Returns:
            List with one ``y`` per iteration, in iteration order.
        """
        if use_bb:
            if bin_modulation:
                x.bb = self.r2b.forward(x.bb, train=train)
            else:
                x.bb = x.bb.astype(bin_dtype)
        
        # mipmap[j] holds the latest s_blk output of level j.
        mipmap = [None]*(self.depth+1)
        y_list = []
        for i in range(self.loops):
            # Fewer levels are evaluated on the final iterations.
            depth = min(self.depth, self.loops-1 - i)
            y, v = self.m_blk.forward(x, mipmap[0], first=(i<=1), last=(i==self.loops-1), train=train)
            if i >= 1:
                for j in range(depth):
                    bottom = (j==self.depth-1)   # deepest configured level
                    first  = (i==1)              # first iteration in which s_blk runs
                    last   = (j==(depth-1))      # deepest level evaluated this iteration
                    mipmap[j], v = self.s_blk.forward(mipmap[j], mipmap[j+1], v, bottom, first, last, train=train)
            if use_bb:
                if bin_modulation:
                    y.bb = self.b2r.forward(y.bb, train=train)
                else:
                    y.bb = y.bb.astype(bb.DType.FP32)
            y_list.append(y)
        return y_list
    
    def backward(self, dy_list):
        """Back-propagate, walking the iterations and levels in reverse.

        Args:
            dy_list: One gradient per entry of the list returned by
                ``forward``, indexed by iteration.

        Returns:
            ``dx0`` from the last ``m_blk.backward`` call (iteration 0).
        """
        du = None
        # mipmap[j] carries the gradient handed to level j's s_blk.backward.
        mipmap = [None]*(self.depth+1)
        for i in reversed(range(self.loops)):
            depth = min(self.depth, self.loops-1 - i)
            if i >= 1:
                for j in reversed(range(depth)):
                    bottom = (j==self.depth-1)
                    first  = (i==1)
                    last   = (j==(depth-1))
                    dx0, dx1, du = self.s_blk.backward(mipmap[j], du, bottom, first, last)
                    if last:
                        mipmap[j] = dx0
                        if not bottom:
                            mipmap[j+1] = dx1
                    elif not first:
                        mipmap[j] += dx0
                        mipmap[j+1] += dx1
            dy = dy_list[i]
            if use_bb and bin_modulation:
                dy = self.b2r.backward(dy_list[i])
            dx0, dx1 = self.m_blk.backward(dy, du, first=(i<=1), last=(i==self.loops-1))
            # NOTE(review): this assignment replaces whatever the s_blk loop
            # above stored in mipmap[0] during this iteration - confirm that
            # discarding it (rather than accumulating) is intended.
            mipmap[0] = dx1
        return dx0
        
def view(net, loader, n=2):
    """Run the first batch of ``loader`` through ``net`` and plot the result.

    Args:
        net: Network whose ``forward`` is called with ``train=False``.
        loader: Iterable of ``(input, target)`` batches; only the first batch
            is used.
        n: Passed through to ``plot_data``.
    """
    # Grab the first batch only.
    for x_torch, _ in loader:
        break

    if not bin_modulation:
        # Binarise the input in place to +1 / -1 around 0.5.
        x_torch[x_torch >  0.5] = +1
        x_torch[x_torch <= 0.5] = -1

    images = np.array(x_torch).astype(np.float32)
    inputs = Buffer()
    if use_bb:
        inputs.bb = bb.FrameBuffer.from_numpy(images)
    if use_torch:
        inputs.torch = x_torch.to(device)

    with torch.no_grad():
        outputs = net.forward(inputs, train=False)

    # Only the output of the final iteration is displayed.
    final = outputs[-1]
    if use_torch:
        print('PyTorch')
        plot_data(images, to_numpy(final.torch), n)
    if use_bb:
        print('BinaryBrain')
        plot_data(images, to_numpy(final.bb), n)

## 学習

### ネット生成

In [None]:
# Build the network
net = MipmapNetwork(loops=loops, depth=depth)
if use_bb:
    # Input shape is [1, tiled height, tiled width].
    net.set_input_shape([1, img_h*rows, img_w*cols])
    net.send_command("binary true")

In [None]:
# Compare histograms of the initial parameters (disabled by `if False`)
if False:
    W_bb    = to_numpy(net.m_blk.cnv0.cnv0.blk_bb.affine.W()).reshape(-1)
    W_torch = to_numpy(net.m_blk.cnv0.cnv0.cnv_torch.weight).reshape(-1)
    
    # Left: BinaryBrain weights, right: PyTorch weights.
    plt.figure(figsize=(10,4))
    plt.subplot(121)
    plt.title('BinaryBrain')
    plt.hist(W_bb, bins=30)
    plt.subplot(122)
    plt.title('PyTorch')
    plt.hist(W_torch, bins=30)

In [None]:
# Load the network from data_path (presumably the state written by
# bb.save_networks in learning() - confirm).
if use_bb:
    bb.load_networks(data_path, net)
#   bb.load_networks(data_path, net, name='dense_002_layers', read_layers=True)
    pass

# Select the 'dense' model for the whole network.
net.send_command("switch_model dense")

In [None]:
# To start both back ends from the same initial parameters (disabled)
if False:
    net.copy_param_torch_to_bb()
#if True and use_torch and use_bb:
#    net.copy_param_bb_to_torch()

### 学習実施

In [None]:
def learning(data_path, net, epochs, learning_rate=0.001):
    """Train ``net`` on ``loader_train`` and plot the per-batch loss.

    The module-level ``use_torch`` / ``use_bb`` switches select which back
    end(s) are optimised (both with Adam).  Each forward pass returns one
    output per network iteration; the same target is applied to all of them.

    Args:
        data_path: Directory passed to ``bb.save_networks`` after every epoch
            (BinaryBrain only).
        net: Network to train.
        epochs: Number of passes over ``loader_train``.
        learning_rate: Learning rate used for both optimisers.
    """
    if use_torch:
        # Classes are weighted according to their area.
        class_weight    = torch.from_numpy(weight.astype(np.float32)).clone().to(device)
        criterion_torch = nn.CrossEntropyLoss(weight=class_weight)
        optimizer_torch = optim.Adam(net.parameters(), lr=learning_rate)

    if use_bb:
        criterion_bb = bb.LossSoftmaxCrossEntropy()
        metrics_bb   = bb.MetricsCategoricalAccuracy()
        optimizer_bb = bb.OptimizerAdam(learning_rate=learning_rate)
        optimizer_bb.set_variables(net.get_parameters(), net.get_gradients())

        criterion_bb.clear()
        metrics_bb.clear()
        net.clear()

    history_torch, history_bb = [], []

    for epoch in range(epochs):
        if use_bb:
            criterion_bb.clear()
            metrics_bb.clear()
            net.clear()

        with tqdm(loader_train) as progress:
            for x_torch, t_torch in progress:
                if use_torch:
                    optimizer_torch.zero_grad()
                net.clear()

                if not bin_modulation:
                    # Binarise the input in place to +1 / -1 around 0.5.
                    x_torch[x_torch >  0.5] = +1
                    x_torch[x_torch <= 0.5] = -1

                x, t = Buffer(), Buffer()
                if use_bb:
                    x.bb = bb.FrameBuffer.from_numpy(np.array(x_torch).astype(np.float32))
                    t.bb = bb.FrameBuffer.from_numpy(np.array(t_torch).astype(np.float32))
                if use_torch:
                    x.torch = x_torch.to(device)
                    t.torch = t_torch.to(device)

                outputs = net.forward(x, train=True)

                if use_torch:
                    # Stack all iteration outputs along the batch axis and
                    # repeat the target to match.
                    stacked_y = torch.cat([out.torch for out in outputs], 0)
                    stacked_t = torch.cat([t.torch]*loops, 0)
                    loss_torch = criterion_torch(stacked_y, torch.argmax(stacked_t, dim=1))
                    loss_torch.backward()

                if use_bb:
                    # One gradient per iteration output, scaled by their count.
                    n_outputs = len(outputs)
                    grads = []
                    for out in outputs:
                        grad = criterion_bb.calculate(out.bb, t.bb)
                        grad /= n_outputs
                        grads.append(grad)
                    net.backward(grads)

                if use_torch:
                    optimizer_torch.step()
                if use_bb:
                    optimizer_bb.update()

                loss_val_torch = loss_torch.item() if use_torch else 0
                # criterion_bb is cleared only once per epoch, not per batch.
                loss_val_bb    = criterion_bb.get() if use_bb else 0

                history_torch.append(loss_val_torch)
                history_bb.append(loss_val_bb)

                progress.set_postfix(loss_bb=loss_val_bb, loss_torch=loss_val_torch)

        if use_bb:
            bb.save_networks(data_path, net, backups=3)

        # Show a preview every tenth epoch.
        if epoch % 10 == 9:
            view(net, loader_test, n=2)

    if use_torch:
        plt.plot(history_torch, label='PyTorch')
    if use_bb:
        plt.plot(history_bb, label='BinaryBrain')
    plt.ylabel('loss')
    plt.legend()


def inference(net, loader, n=0):
    """Forward batches from ``loader`` through ``net`` with ``train=False``.

    The network outputs are discarded; nothing is returned.

    Args:
        net: Network to run.
        loader: Iterable of ``(input, target)`` batches; targets are ignored.
        n: Stop after this many batches.  Values ``<= 0`` process every batch.
    """
    with torch.no_grad():
        for batch_no, (x_torch, _) in enumerate(tqdm(loader), start=1):
            net.clear()
            if not bin_modulation:
                # Binarise the input in place to +1 / -1 around 0.5.
                x_torch[x_torch >  0.5] = +1
                x_torch[x_torch <= 0.5] = -1

            x = Buffer()
            if use_bb:
                x.bb = bb.FrameBuffer.from_numpy(np.array(x_torch).astype(np.float32))
            if use_torch:
                x.torch = x_torch.to(device)

            net.forward(x, train=False)

            if n > 0 and batch_no >= n:
                break

In [None]:
# Train for one epoch with the 'dense' model selected.
if True:
    net.send_command("switch_model dense")
    learning(data_path, net, epochs=1)

In [None]:
# Train with the 'lut' model selected (disabled by `if False`).
if False:
#   bb.load_networks(data_path+'_distillation2', net)
    net.send_command("switch_model lut")
    learning(data_path, net, epochs=8, learning_rate=0.01)

In [None]:
view(net, loader_test, n=8)

In [None]:
net

In [None]:
-----------

In [None]:
# Sub-blocks that the loop over luts_list below switches to the 'lut' model,
# one at a time and in this order.
luts_list = [
    net.s_blk.blk,
    net.s_blk.cnv0,
    net.s_blk.cnv1,
    net.s_blk.cnv2,
    net.m_blk.blk0,
    net.m_blk.cnv0,
    net.m_blk.cnv1,
    net.m_blk.blk1,
]

In [None]:
net.send_command("switch_model dense")

In [None]:
# For each sub-block in turn: send 'parameter_lock true' to the whole network,
# switch that sub-block to 'lut' and send it 'parameter_lock false', then train
# at two learning rates.  Blocks switched in earlier passes stay on 'lut'.
# NOTE(review): presumably the lock freezes parameters so that only the
# current sub-block is trained - confirm against the command handler.
for i in range(len(luts_list)):
    print(i)
    net.send_command("parameter_lock true")
    luts_list[i].send_command("switch_model lut")
    luts_list[i].send_command("parameter_lock false")
    learning(os.path.join(data_path, 'dist_s'), net, epochs=8, learning_rate=0.01)
    learning(os.path.join(data_path, 'dist_s'), net, epochs=4, learning_rate=0.001)

In [None]:
learning(os.path.join(data_path, 'dist_s'), net, epochs=4, learning_rate=0.001)

In [None]:
--------

In [None]:
# Whole network on 'lut'; lock everything, then unlock m_blk only and train.
# Checkpoints go to <data_path>/dist_m.
net.send_command("switch_model lut")
net.send_command("parameter_lock true")
net.m_blk.send_command("parameter_lock false")
learning(os.path.join(data_path, 'dist_m'), net, epochs=8, learning_rate=0.01)
learning(os.path.join(data_path, 'dist_m'), net, epochs=4, learning_rate=0.001)

In [None]:
# Whole network on 'lut'; lock everything, then unlock s_blk only and train.
# Checkpoints go to <data_path>/dist_s.
net.send_command("switch_model lut")
net.send_command("parameter_lock true")
net.s_blk.send_command("parameter_lock false")
learning(os.path.join(data_path, 'dist_s'), net, epochs=8, learning_rate=0.01)
learning(os.path.join(data_path, 'dist_s'), net, epochs=4, learning_rate=0.001)
view(net, loader_test, n=8)

In [None]:
# Alternate eight times between training with only m_blk unlocked and
# training with only s_blk unlocked, previewing the result after each round.
for _ in range(8):
    net.send_command("switch_model lut")
    net.send_command("parameter_lock true")
    net.m_blk.send_command("parameter_lock false")
    learning(os.path.join(data_path, 'dist_m'), net, epochs=2, learning_rate=0.001)
    
    net.send_command("switch_model lut")
    net.send_command("parameter_lock true")
    net.s_blk.send_command("parameter_lock false")
    learning(os.path.join(data_path, 'dist_s'), net, epochs=2, learning_rate=0.001)
    view(net, loader_test, n=8)

In [None]:
# Whole network on 'lut' with 'parameter_lock false' sent to everything;
# train at a small learning rate.  Checkpoints go to <data_path>/lut.
net.send_command("switch_model lut")
net.send_command("parameter_lock false")
learning(os.path.join(data_path, 'lut'), net, epochs=20, learning_rate=0.0001)

In [None]:
view(net, loader_test, n=8)

In [None]:
bb.load_networks(os.path.join(data_path, 'dist_s'), net)

In [None]:
# Lock the whole network, switch the three s_blk convolutions to 'lut', and
# leave only cnv2 unlocked.
net.send_command("parameter_lock true")
net.s_blk.cnv0.send_command("switch_model lut")
net.s_blk.cnv1.send_command("switch_model lut")
net.s_blk.cnv2.send_command("switch_model lut")
net.s_blk.cnv0.send_command("parameter_lock true")
net.s_blk.cnv1.send_command("parameter_lock true")
net.s_blk.cnv2.send_command("parameter_lock false")

In [None]:
bb.save_networks(os.path.join(data_path, 'dist_indp'), net, backups=3)

In [None]:
-----------

In [None]:
net.send_command('distillation start')

In [None]:
#bb.load_networks(data_path+'_distillation2', net)
# Four inference-only passes over the training set; after each pass send
# 'distillation info' and save the network to <data_path>_distillation2.
for _ in range(4):
    inference(net, loader_train)
    net.send_command('distillation info')
    bb.save_networks(data_path+'_distillation2', net, backups=3)

In [None]:
net.send_command('distillation info')

In [None]:
net.send_command('distillation stop')

In [None]:
net.send_command("switch_model lut")
view(net, loader_test, n=8)

In [None]:
-----------

In [None]:
bb.save_networks(data_path, net, name='dense_001')

In [None]:
bb.save_networks(data_path, net, name='dense_002_layers', write_layers=True)

In [None]:
# Bracket one inference pass with 'log start' / 'log stop' (disabled).
# NOTE(review): presumably this records the per-block inputs/outputs that
# read_list() later loads from log_x_name / log_y_name - confirm.
if False:
    net.send_command('log start')
#   learning(net, 1)
    inference(net, loader_train)
    net.send_command('log stop')

In [None]:
def inference_x(net, loader):
    """Forward every batch of ``loader`` through ``net`` with ``train=False``.

    The outputs are discarded and nothing is returned.

    Args:
        net: Network to run.
        loader: Iterable of ``(input, target)`` batches; targets are ignored.
    """
    for x_torch, _ in loader:
        if not bin_modulation:
            # Binarise BEFORE taking the NumPy copy.  np.array() copies the
            # tensor data, so binarising afterwards (as before) left the
            # BinaryBrain input un-binarised, unlike view()/inference().
            x_torch[x_torch >  0.5] = +1
            x_torch[x_torch <= 0.5] = -1

        x_np = np.array(x_torch).astype(np.float32)
        x = Buffer()
        if use_bb:
            x.bb = bb.FrameBuffer.from_numpy(x_np)
        if use_torch:
            x.torch = x_torch.to(device)

        with torch.no_grad():
            net.forward(x, train=False)

In [None]:
def read_list(filename):
    """Load every object pickled back-to-back in ``filename``.

    Args:
        filename: Path of a file containing zero or more consecutive pickles.

    Returns:
        List of the unpickled objects in file order (empty for an empty file).

    Raises:
        pickle.UnpicklingError: If the file contains data that is not a valid
            pickle.  Previously a bare ``except`` hid this and silently
            truncated the result.
    """
    # NOTE: pickle.load must only be used on trusted files.
    x_list = []
    with open(filename, 'rb') as f:
        while True:
            try:
                x_list.append(pickle.load(f))
            except EOFError:
                # Normal termination: no more pickles in the file.
                break
    return x_list

def read_lists(cnv_blk):
    """Load the logged inputs and outputs of ``cnv_blk``.

    Args:
        cnv_blk: Object exposing ``log_x_name`` and ``log_y_name`` file paths.

    Returns:
        Tuple ``(x_list, y_list)`` as read by ``read_list``.
    """
    return read_list(cnv_blk.log_x_name), read_list(cnv_blk.log_y_name)

In [None]:
def distilation_lut(model, xy, epochs=2, learning_rate=0.001):
    """Fit ``model`` to recorded input/output pairs with an MSE loss.

    Args:
        model: BinaryBrain model to train.
        xy: Pair ``(x_list, y_list)`` of NumPy inputs and target outputs,
            indexed in parallel.
        epochs: Number of passes over the recorded pairs.
        learning_rate: Learning rate of the Adam optimiser.

    Note:
        After every epoch the module-level ``net`` (not ``model``) is saved
        to ``data_path``.
    """
    print('----')
    inputs, targets = xy
    criterion_bb = bb.LossMeanSquaredError()
    optimizer_bb = bb.OptimizerAdam(learning_rate=learning_rate)
    optimizer_bb.set_variables(model.get_parameters(), model.get_gradients())

    for _ in range(epochs):
        criterion_bb.clear()
        model.clear()

        with tqdm(range(len(inputs))) as progress:
            for idx in progress:
                x_np, t_np = inputs[idx], targets[idx]
                x = bb.FrameBuffer.from_numpy(x_np).astype(bin_dtype)
                t = bb.FrameBuffer.from_numpy(t_np)

                model.clear()
                y = model.forward(x, train=True)

                # The loss is evaluated on FP32 versions of both buffers.
                dy = criterion_bb.calculate(y.astype(bb.DType.FP32), t.astype(bb.DType.FP32))
                model.backward(dy)
                optimizer_bb.update()

                progress.set_postfix(loss=criterion_bb.get())
        bb.save_networks(data_path, net)

def distilation(cnv_blk, model=None, epochs=2, learning_rate=0.001):
    """Train a LUT model on the logged inputs/outputs of ``cnv_blk``.

    Args:
        cnv_blk: Block exposing ``cnv_bb['lut']`` and the log file names read
            by ``read_lists``.
        model: Model to train.  Defaults to ``cnv_blk.cnv_bb['lut']``; any
            other model is first given that LUT model's input shape.
        epochs: Number of passes, forwarded to ``distilation_lut``.
        learning_rate: Learning rate, forwarded to ``distilation_lut``.
    """
    if model is None:
        model = cnv_blk.cnv_bb['lut']
    else:
        model.set_input_shape(cnv_blk.cnv_bb['lut'].get_input_shape())
    xy = read_lists(cnv_blk)
    # Fixed keyword typo: this was 'pochs=', which raised TypeError.
    distilation_lut(model, xy, epochs=epochs, learning_rate=learning_rate)

In [None]:
xy = read_lists(net.s_blk.cnv.cnv1)

In [None]:
class LutConvBlock2(bb.Sequential):
    """Three stacked LUT convolutions: 1x1, 3x3 depthwise, then 1x1.

    Args:
        hid_ch: Channel count produced by the first and second convolutions.
        out_ch: Channel count produced by the final convolution.
        batch_norm: Passed as ``batch_norm`` to every ``DifferentiableLut``.
        activation: Passed as ``binarize`` to every ``DifferentiableLut``.
        name: Prefix for the names of all layers in the block.
    """
    def __init__(self, hid_ch, out_ch, batch_norm=True, activation=True, name=""):
        def lut(channels, connection, suffix):
            # One LUT layer with output shape [channels, 1, 1].
            return bb.DifferentiableLut([channels, 1, 1], connection=connection, name=name+suffix,
                                        batch_norm=batch_norm, binarize=activation, bin_dtype=bin_dtype)

        def conv(layers, filter_size, suffix):
            # 'same'-padded convolution wrapping a Sequential of LUT layers.
            return bb.Convolution2d(bb.Sequential(layers),
                                    filter_size=filter_size,
                                    padding='same',
                                    name=name+suffix,
                                    fw_dtype=bin_dtype)

        self.cnv0 = conv([lut(hid_ch*6, 'random', '_lut0_0'),
                          lut(hid_ch,   'serial', '_lut0_1')],
                         (1, 1), '_cnv0')
        self.cnv1 = conv([lut(hid_ch, 'depthwise', '_lut1_0')],
                         (3, 3), '_cnv1')
        self.cnv2 = conv([lut(out_ch*6*6, 'random', '_lut2_0'),
                          lut(out_ch*6,   'serial', '_lut2_1'),
                          lut(out_ch,     'serial', '_lut2_2')],
                         (1, 1), '_cnv2')
        super(LutConvBlock2, self).__init__([self.cnv0, self.cnv1, self.cnv2], name=name)


#lut_net = bb.Sequential([
#                LutConvBlock2(36*4, 36),
#                LutConvBlock2(36*2, 32),
#             ])

# Stand-alone LUT block, given the input shape of s_blk.cnv.cnv1's LUT model
# and trained on that block's logged data (xy) for one epoch.
lut_net = LutConvBlock2(36*4, 32)

lut_net.set_input_shape(net.s_blk.cnv.cnv1.cnv_bb['lut'].get_input_shape())
distilation_lut(lut_net, xy, epochs=1, learning_rate=0.01)
#distilation(net.s_blk.cnv.cnv1, net.s_blk.cnv.cnv1.cnv_bb['lut'])
#distilation_lut(net.s_blk.cnv.cnv1, lut_net, learning_rate=0.01)

In [None]:
distilation_lut(lut_net, xy, epochs=8, learning_rate=0.005)

In [None]:
print(lut_net.get_info())

In [None]:
net.send_command("switch_model lut")

In [None]:
distilation(net.s_blk.cnv.cnv1)

In [None]:
print(lut_net.get_info())

In [None]:
print(net.s_blk.cnv.cnv1.cnv_bb['lut'].get_info())

In [None]:
print(net.s_blk.cnv.cnv1.cnv_bb['lut'].get_input_shape())

In [None]:
# Distil each listed convolution block in turn, using the default epochs and
# learning rate and each block's own cnv_bb['lut'] model.
distilation(net.m_blk.cnv0.cnv0)
distilation(net.m_blk.cnv0.cnv1)
distilation(net.m_blk.cnv1.cnv0)
distilation(net.m_blk.cnv1.cnv1)
distilation(net.s_blk.cnv.cnv0)
distilation(net.s_blk.cnv.cnv1)

In [None]:
view(net, loader_test, n=2)

In [None]:
# learning() takes data_path first; the old call learning(net, 1) raised
# TypeError for the missing 'epochs' argument.
learning(data_path, net, epochs=1)

In [None]:
y_list[0].shape

In [None]:
t.get_frame_size()

In [None]:
-----

In [None]:

# Inline version of the training loop in distilation_lut().
# Relies on the globals model, x_list and y_list, which are assigned in the
# cell below; that cell must have been executed first.
criterion_bb = bb.LossMeanSquaredError()
optimizer_bb = bb.OptimizerAdam(learning_rate=0.01)
optimizer_bb.set_variables(model.get_parameters(), model.get_gradients())

for epoch in range(epochs):
    criterion_bb.clear()
    model.clear()
    
    with tqdm(range(len(x_list))) as tqdm_loadr:
        for i in tqdm_loadr:
            x = x_list[i]
            t = y_list[i]
            x = bb.FrameBuffer.from_numpy(x).astype(bin_dtype)
            t = bb.FrameBuffer.from_numpy(t)
            
            model.clear()
            
            y = model.forward(x, train=True)

            # The loss is evaluated on FP32 versions of both buffers.
            dy = criterion_bb.calculate(y.astype(bb.DType.FP32), t.astype(bb.DType.FP32))
            model.backward(dy)
            optimizer_bb.update()
            
            tqdm_loadr.set_postfix(loss=criterion_bb.get())
    # Saves the whole network (net), not just model.
    bb.save_networks(data_path, net)

In [None]:
# Pick the block to distil, load its logged inputs/outputs, and select its
# LUT model as the training target for the manual loop above.
cnv_blk = net.m_blk.cnv0.cnv0
x_list = read_list(cnv_blk.log_x_name)
y_list = read_list(cnv_blk.log_y_name)

#dense = cnv_blk.cnv_bb['dense']
model = cnv_blk.cnv_bb['lut']

In [None]:
#bb.save_networks(data_path, net, name='layers', write_layers=True)

In [None]:
bb.save_networks(data_path + '_lut', net)