In [None]:
# coding: utf-8
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import cuda, Variable
from chainer import optimizers
import numpy as np
import data

In [None]:
class CNN_deCNN(chainer.Chain):
    """Convolutional encoder-decoder mapping a 1-channel image to a 1-channel map.

    Encoder: four Convolution2D layers (kernel 4, stride 2, pad 1) with
    channels 1 -> 10 -> 25 -> 50 -> 100.
    Decoder: four Deconvolution2D layers (kernel 4, stride 2, pad 1) with
    channels 100 -> 50 -> 25 -> 10 -> 1.
    Every conv/deconv is followed by BatchNormalization (bn1..bn8) and ReLU,
    including the final 1-channel output layer.

    NOTE(review): BN + ReLU on the output forces a non-negative, batch-normalized
    map; confirm this is intended for the regression target it is trained on.
    """

    def __init__(self):
        encoder_widths = [1, 10, 25, 50, 100]
        decoder_widths = encoder_widths[::-1]  # [100, 50, 25, 10, 1]

        # Links are created in the same order as a flat keyword list would
        # create them (conv1..4, deconv1..4, bn1..8), so random weight
        # initialization consumes the RNG identically.
        links = {}
        for idx in range(4):
            links['conv%d' % (idx + 1)] = L.Convolution2D(
                encoder_widths[idx], encoder_widths[idx + 1], 4, stride=2, pad=1)
        for idx in range(4):
            links['deconv%d' % (idx + 1)] = L.Deconvolution2D(
                decoder_widths[idx], decoder_widths[idx + 1], 4, stride=2, pad=1)
        # bn1..bn4 follow conv1..conv4, bn5..bn8 follow deconv1..deconv4.
        for idx, width in enumerate(encoder_widths[1:] + decoder_widths[1:]):
            links['bn%d' % (idx + 1)] = L.BatchNormalization(width)

        super(CNN_deCNN, self).__init__(**links)
        # Set by the caller; every BatchNormalization call receives test=not self.train.
        self.train = True

    def __call__(self, x):
        """Run ``x`` through the encoder and decoder and return the output Variable."""
        in_test_mode = not self.train
        pipeline = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.deconv1, self.bn5),
            (self.deconv2, self.bn6),
            (self.deconv3, self.bn7),
            (self.deconv4, self.bn8),
        )
        h = x
        for layer, norm in pipeline:
            h = F.relu(norm(layer(h), test=in_test_mode))
        return h

In [None]:
class Discriminator(chainer.Chain):
    """Classifier producing 2 logits for a 2-channel input.

    Four Convolution2D layers (kernel 4, stride 2, pad 1) with channels
    2 -> 10 -> 25 -> 50 -> 100, then a Linear layer to 2 outputs.
    """
    def __init__(self):
        super(Discriminator, self).__init__(
            conv1=L.Convolution2D(2, 10, 4, stride=2, pad=1),
            conv2=L.Convolution2D(10, 25, 4, stride=2, pad=1),
            conv3=L.Convolution2D(25, 50, 4, stride=2, pad=1),
            conv4=L.Convolution2D(50, 100, 4, stride=2, pad=1),
            l5 = L.Linear(None, 2),  # input size inferred on the first call
            # bn1 is registered but not applied in __call__ (conv1 has no
            # normalization). It is kept so the set of links/parameters is unchanged.
            bn1 = L.BatchNormalization(10),
            bn2 = L.BatchNormalization(25),
            bn3 = L.BatchNormalization(50),
            bn4 = L.BatchNormalization(100),
        )
        # Set by the caller; False puts the BatchNormalization layers in test mode.
        self.train = True

    def __call__(self, x, test=False):
        """Return the 2-class logits Variable (batch, 2) for ``x``.

        Args:
            x: input Variable with 2 channels (conv1 expects 2 input channels).
            test: when True, run BatchNormalization in test mode even if
                ``self.train`` is True. The default False defers entirely to
                ``self.train``, so existing callers see no change.
        """
        use_test = test or not self.train
        h = F.relu(self.conv1(x))
        h = F.relu(self.bn2(self.conv2(h), test=use_test))
        h = F.relu(self.bn3(self.conv3(h), test=use_test))
        h = F.relu(self.bn4(self.conv4(h), test=use_test))
        y = self.l5(h)
        return y

In [None]:
# Build the models.
gpu = 0  # GPU device id to use (0 = first GPU); set to -1 to run on the CPU

g_model = CNN_deCNN()
d_model = Discriminator()

if gpu >= 0:
    # Raises if CUDA is not available, before any device is touched.
    cuda.check_cuda_available()
    xp = cuda.cupy
    cuda.get_device(gpu).use()
    g_model.to_gpu()
    d_model.to_gpu()
else:
    xp = np

# One Adam optimizer (default hyperparameters) per network.
g_optimizer = optimizers.Adam()
d_optimizer = optimizers.Adam()
g_optimizer.setup(g_model)
d_optimizer.setup(d_model)

In [None]:
# imgA: load both the input (x) and the target (t) data.
DataA = data.get_ori_data_pos_1pic(fpath='imgA.png')
# imgB: load only the input (x); the training loop evaluates the model on it.
DataB = data.get_ori_data_x_1pic(fpath='imgB.png')

In [None]:
# Live display, refreshed from the training loop via push_notebook:
#   left  (plt1/rend1): model output on imgB
#   right (plt2/rend2): imgB input with circles drawn at detected points
from bokeh.plotting import figure
from bokeh.io import push_notebook, show, output_notebook
from bokeh.layouts import gridplot  # bokeh.io.gridplot is a deprecated alias

output_notebook()

palette_256 = ['#%02x%02x%02x' % (i, i, i) for i in range(256)]  # 256-level grayscale palette

# Size both panels from the image actually displayed (imgB, shape rows x cols),
# rather than from imgA, so the extents stay correct if the two images differ.
img_h, img_w = DataB['x'][0][0].shape

plt1 = figure(title='train N = --', x_range=[0, img_w], y_range=[0, img_h])
rend1 = plt1.image(image=[np.zeros_like(DataB['x'][0][0])], x=[0], y=[0],
                   dw=[img_w], dh=[img_h], palette=palette_256)

plt2 = figure(title='count  = --', x_range=plt1.x_range, y_range=plt1.y_range)
rend2 = plt2.image(image=[DataB['x'][0][0]], x=[0], y=[0],
                   dw=[img_w], dh=[img_h], palette=palette_256)

plts = gridplot([[plt1, plt2]], plot_width=400, plot_height=400)
handle = show(plts, notebook_handle=True)

In [None]:
# Train on imgA, then test on imgB after every outer loop.
batchSize = 100
N_train = 2000        # samples drawn from imgA per outer loop
nLoop = 100
PEAK_THRESHOLD = 0.3  # threshold passed to data.get_local_max_point
# False: fit the generator directly to t_core with an MSE loss.
# True: adversarial (pix2pix-style) training against d_model.
USE_PIX2PIX = False

for loop in range(nLoop):

    print('loop = ' + str(loop))

    # <Training>
    g_model.train = True

    for _ in range(0, N_train, batchSize):

        # Fetch a random mini-batch from imgA.
        DataN = data.get_data_N_rand(DataA, batchSize)
        x_batch = Variable(xp.asarray(DataN['x']))
        t_batch = Variable(xp.asarray(DataN['t_core']))

        # Generator forward pass.
        y_batch = g_model(x_batch)

        if not USE_PIX2PIX:
            # Direct case. Link.cleargrads() replaces the deprecated
            # Optimizer.zero_grads(); backward() re-creates the gradients.
            g_model.cleargrads()
            g_loss = F.mean_squared_error(y_batch, t_batch)
            g_loss.backward()
            g_optimizer.update()
        else:
            # pix2pix case. Discriminator labels: 0 = real (x, t) pair,
            # 1 = generated (x, y) pair. F.hstack joins along the channel axis.
            n = len(x_batch.data)
            real_labels = Variable(xp.zeros(n, dtype=np.int32))
            fake_labels = Variable(xp.ones(n, dtype=np.int32))

            # Generator step: push D to label the generated pair as real.
            y_fake = d_model(F.hstack((x_batch, y_batch)))
            g_model.cleargrads()
            g_loss = F.softmax_cross_entropy(y_fake, real_labels)
            g_loss.backward()
            g_optimizer.update()
            # NOTE(review): pix2pix usually also adds an L1 term against t_batch
            # to g_loss; this branch keeps the adversarial-only loss of the sketch.

            # Discriminator step. The generated image is detached so D's loss
            # does not backpropagate into the generator.
            y_detached = Variable(y_batch.data)
            d_fake = d_model(F.hstack((x_batch, y_detached)))
            d_real = d_model(F.hstack((x_batch, t_batch)))
            d_model.cleargrads()  # also discards grads g_loss left on D
            d_loss = (F.softmax_cross_entropy(d_fake, fake_labels)
                      + F.softmax_cross_entropy(d_real, real_labels))
            d_loss.backward()
            d_optimizer.update()

    # <Test> on imgB
    g_model.train = False
    x_batch = chainer.Variable(xp.asarray(DataB['x']), volatile='on')
    y_batch = g_model(x_batch)

    y_batch.to_cpu()
    DataB['y'] = y_batch.data
    DataB['y_point'] = data.get_local_max_point(DataB['y'], PEAK_THRESHOLD)

    print(', y_count = ' + str(DataB['y_point'][0].sum()))

    # Graphics. np.minimum clips to 1 for display and returns new arrays, so the
    # stored prediction DataB['y'] is not modified (the former in-place clip was).
    DataB['y_circle'] = data.draw_circle(DataB['y_point'])
    img1 = np.minimum(DataB['y'][0][0], 1)
    img2 = np.minimum(DataB['x'][0][0] + DataB['y_circle'][0][0], 1)
    rend1.data_source.data['image'] = [img1]
    rend2.data_source.data['image'] = [img2]
    plt1.title.text = 'train N = ' + str((loop + 1) * N_train)
    plt2.title.text = 'count = ' + str(int(DataB['y_point'][0].sum()))
    push_notebook(handle=handle)  # refresh the displayed plots