In [None]:
# coding: utf-8
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import cuda, Variable
from chainer import optimizers
import numpy as np
import data

In [None]:
class Generator(chainer.Chain):
    """Convolutional encoder / deconvolutional decoder generator.

    The encoder is four ``Convolution2D`` links (kernel 4, stride 2, pad 1)
    with 64, 128, 256 and 512 output channels; the decoder is four
    ``Deconvolution2D`` links with the same kernel/stride/pad and 256, 128,
    64 and 3 output channels. Every layer except the first convolution and
    the last deconvolution is followed by batch normalization.
    """

    def __init__(self):
        links = dict(
            # encoder
            conv1=L.Convolution2D(None, 64, 4, 2, 1),
            conv2=L.Convolution2D(None, 128, 4, 2, 1),
            norm2=L.BatchNormalization(128),
            conv3=L.Convolution2D(None, 256, 4, 2, 1),
            norm3=L.BatchNormalization(256),
            conv4=L.Convolution2D(None, 512, 4, 2, 1),
            norm4=L.BatchNormalization(512),
            # decoder
            deconv1=L.Deconvolution2D(None, 256, 4, 2, 1),
            dnorm1=L.BatchNormalization(256),
            deconv2=L.Deconvolution2D(None, 128, 4, 2, 1),
            dnorm2=L.BatchNormalization(128),
            deconv3=L.Deconvolution2D(None, 64, 4, 2, 1),
            dnorm3=L.BatchNormalization(64),
            deconv4=L.Deconvolution2D(None, 3, 4, 2, 1),
        )
        super(Generator, self).__init__(**links)

    def __call__(self, x, test=False):
        """Translate a batch of images.

        Args:
            x: Input passed to the first convolution.
            test: Forwarded to every batch-normalization link.

        Returns:
            The output of the last deconvolution (3 channels), with no
            activation applied to it.
        """
        # The first convolution has no batch normalization.
        h = F.leaky_relu(self.conv1(x))

        # Remaining encoder stages: conv -> batch norm -> leaky ReLU.
        encoder_stages = (
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
            (self.conv4, self.norm4),
        )
        for conv, norm in encoder_stages:
            h = F.leaky_relu(norm(conv(h), test=test))

        # Decoder stages: deconv -> batch norm -> leaky ReLU.
        decoder_stages = (
            (self.deconv1, self.dnorm1),
            (self.deconv2, self.dnorm2),
            (self.deconv3, self.dnorm3),
        )
        for deconv, norm in decoder_stages:
            h = F.leaky_relu(norm(deconv(h), test=test))

        # Raw deconvolution output; the tanh variant
        # (F.tanh(self.deconv4(h))) is disabled.
        return self.deconv4(h)


class Discriminator(chainer.Chain):
    """Convolutional discriminator that also exposes intermediate features.

    Four ``Convolution2D`` links (kernel 4, stride 2, pad 1) with 512, 256,
    128 and 64 output channels are followed by a final kernel-4 convolution
    with a single output channel. Batch normalization follows the second,
    third and fourth convolutions.
    """

    def __init__(self):
        links = dict(
            conv1=L.Convolution2D(None, 512, 4, 2, 1),
            conv2=L.Convolution2D(None, 256, 4, 2, 1),
            norm2=L.BatchNormalization(256),
            conv3=L.Convolution2D(None, 128, 4, 2, 1),
            norm3=L.BatchNormalization(128),
            conv4=L.Convolution2D(None, 64, 4, 2, 1),
            norm4=L.BatchNormalization(64),
            conv5=L.Convolution2D(None, 1, 4),
        )
        super(Discriminator, self).__init__(**links)

    def __call__(self, x, test=False):
        """Score a batch of images.

        Args:
            x: Input passed to the first convolution.
            test: Forwarded to every batch-normalization link.

        Returns:
            A pair ``(y, feats)``: ``y`` is the output of the final
            convolution and ``feats`` is the list of activations after the
            second, third and fourth convolution stages, in that order.
        """
        # The first convolution has no batch normalization and is not
        # part of the returned feature list.
        h = F.leaky_relu(self.conv1(x))

        feats = []
        normalized_stages = (
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
            (self.conv4, self.norm4),
        )
        for conv, norm in normalized_stages:
            h = F.leaky_relu(norm(conv(h), test=test))
            feats.append(h)

        # Original author's note: the feature map is 4x4 at this point, so a
        # convolution with a 4x4 filter acts much like a fully connected layer.
        return self.conv5(h), feats



In [None]:
class DiscoGANUpdater():
    """Runs DiscoGAN training iterations for two generators and two discriminators.

    The four models are taken from the ``target`` attribute of the optimizers
    passed in. Each call to :meth:`update_core` draws one batch from each
    data set, builds the generator and discriminator losses from a single
    forward pass, and then updates either both discriminators or both
    generators.
    """

    def __init__(self, DataA, DataB, opt_g_ab, opt_g_ba, opt_d_a, opt_d_b):
        """Store the data sets and the optimizers with their target models.

        Args:
            DataA: Data set for domain A, handed to ``data.get_data_N_rand``.
            DataB: Data set for domain B, handed to ``data.get_data_N_rand``.
            opt_g_ab: Optimizer whose ``target`` is the A->B generator.
            opt_g_ba: Optimizer whose ``target`` is the B->A generator.
            opt_d_a: Optimizer whose ``target`` is the domain-A discriminator.
            opt_d_b: Optimizer whose ``target`` is the domain-B discriminator.
        """
        self.generator_ab = opt_g_ab.target
        self.generator_ba = opt_g_ba.target
        self.discriminator_a = opt_d_a.target
        self.discriminator_b = opt_d_b.target
        self._optimizers = {'generator_ab': opt_g_ab,
                            'generator_ba': opt_g_ba,
                            'discriminator_a': opt_d_a,
                            'discriminator_b': opt_d_b}
        # Number of completed update_core() calls; drives the loss schedule
        # and the discriminator/generator alternation.
        self.iteration = 0
        # Array module of the A->B generator; batches are converted with it.
        self.xp = self.generator_ab.xp
        self.DataA = DataA
        self.DataB = DataB

    def compute_loss_gan(self, y_real, y_fake):
        """Compute the adversarial losses from discriminator outputs.

        Args:
            y_real: Discriminator output for real samples.
            y_fake: Discriminator output for generated samples.

        Returns:
            A pair ``(loss_dis, loss_gen)``, each a softplus-based sum
            divided by ``y_real.shape[0]``.
        """
        batchsize = y_real.shape[0]
        loss_dis = 0.5 * F.sum(F.softplus(-y_real) + F.softplus(y_fake))
        loss_gen = F.sum(F.softplus(-y_fake))
        return loss_dis / batchsize, loss_gen / batchsize

    def compute_loss_feat(self, feats_real, feats_fake):
        """Compute the feature-matching loss over paired feature maps.

        For each pair, the features are averaged over axis 0 and the mean
        squared difference between the two averages is taken; the per-pair
        values are summed.

        Args:
            feats_real: Iterable of discriminator features for real samples.
            feats_fake: Iterable of discriminator features for generated
                samples, paired element-wise with ``feats_real``.

        Returns:
            The summed loss, or the integer ``0`` when there are no pairs.
        """
        losses = 0
        for feat_real, feat_fake in zip(feats_real, feats_fake):
            feat_real_mean = F.sum(feat_real, 0) / feat_real.shape[0]
            feat_fake_mean = F.sum(feat_fake, 0) / feat_fake.shape[0]
            l2 = (feat_real_mean - feat_fake_mean) ** 2
            loss = F.sum(l2) / l2.size
            losses += loss
        return losses

    def update_core(self, batchSize=100, imgH=64, imgW=64):
        """Run one training iteration and advance ``self.iteration``.

        Args:
            batchSize: Forwarded to ``data.get_data_N_rand`` as ``N_pic``.
            imgH: Forwarded to ``data.get_data_N_rand`` as ``imgH``.
            imgW: Forwarded to ``data.get_data_N_rand`` as ``imgW``.
        """
        # read data: one batch per domain, taken from the 'x' entry of the
        # object returned by data.get_data_N_rand
        batch_a = data.get_data_N_rand(self.DataA, N_pic=batchSize, imgH=imgH, imgW=imgW)
        x_a = Variable(self.xp.asarray(batch_a['x']))

        batch_b = data.get_data_N_rand(self.DataB, N_pic=batchSize, imgH=imgH, imgW=imgW)
        x_b = Variable(self.xp.asarray(batch_b['x']))

        # conversion
        x_ab = self.generator_ab(x_a)
        x_ba = self.generator_ba(x_b)

        # reconversion
        x_aba = self.generator_ba(x_ab)
        x_bab = self.generator_ab(x_ba)

        # reconstruction loss
        recon_loss_a = F.mean_squared_error(x_a, x_aba)
        recon_loss_b = F.mean_squared_error(x_b, x_bab)

        # discriminate
        y_a_real, feats_a_real = self.discriminator_a(x_a)
        y_a_fake, feats_a_fake = self.discriminator_a(x_ba)

        y_b_real, feats_b_real = self.discriminator_b(x_b)
        y_b_fake, feats_b_fake = self.discriminator_b(x_ab)

        # GAN loss
        gan_loss_dis_a, gan_loss_gen_a = self.compute_loss_gan(y_a_real, y_a_fake)
        feat_loss_a = self.compute_loss_feat(feats_a_real, feats_a_fake)

        gan_loss_dis_b, gan_loss_gen_b = self.compute_loss_gan(y_b_real, y_b_fake)
        feat_loss_b = self.compute_loss_feat(feats_b_real, feats_b_fake)

        # compute loss
        # Weight of the reconstruction term: 0.01 for the first 10000
        # iterations, 0.5 afterwards.
        if self.iteration < 10000:
            rate = 0.01
        else:
            rate = 0.5

        # Each generator total mixes the adversarial and feature-matching
        # losses of the discriminator that judged its output (weights 0.1 and
        # 0.9) with the reconstruction loss of its own source domain.
        total_loss_gen_a = (1.-rate)*(0.1*gan_loss_gen_b + 0.9*feat_loss_b) + rate * recon_loss_a
        total_loss_gen_b = (1.-rate)*(0.1*gan_loss_gen_a + 0.9*feat_loss_a) + rate * recon_loss_b

        gen_loss = total_loss_gen_a + total_loss_gen_b
        dis_loss = gan_loss_dis_a + gan_loss_dis_b

        # Every third iteration updates the discriminators; the other two
        # update the generators. Only one loss is back-propagated per call.
        if self.iteration % 3 == 0:
            self.discriminator_a.cleargrads()
            self.discriminator_b.cleargrads()
            dis_loss.backward()
            self._optimizers['discriminator_a'].update()
            self._optimizers['discriminator_b'].update()
        else:
            self.generator_ab.cleargrads()
            self.generator_ba.cleargrads()
            gen_loss.backward()
            self._optimizers['generator_ab'].update()
            self._optimizers['generator_ba'].update()

        self.iteration = self.iteration+1

In [None]:
# Build the models
gpu = 0  # 0: use the GPU, -1: do not use the GPU
generator_ab = Generator()
generator_ba = Generator()
discriminator_a = Discriminator()
discriminator_b = Discriminator()
if gpu >= 0:
    chainer.cuda.get_device(gpu).use()
    generator_ab.to_gpu()
    generator_ba.to_gpu()
    discriminator_a.to_gpu()
    discriminator_b.to_gpu()

# Array module matching the device the models live on.
xp = cuda.cupy if gpu >= 0 else np


def _make_optimizer(model):
    """Return an Adam optimizer set up on ``model`` with weight decay.

    All four networks share the same settings: Adam(2e-4, beta1=0.5,
    beta2=0.999) plus a WeightDecay(1e-4) hook.
    """
    optimizer = chainer.optimizers.Adam(2e-4, beta1=0.5, beta2=0.999)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(1e-4))
    return optimizer


opt_g_ab = _make_optimizer(generator_ab)
opt_g_ba = _make_optimizer(generator_ba)

opt_d_a = _make_optimizer(discriminator_a)
opt_d_b = _make_optimizer(discriminator_b)

In [None]:
# Load the data
DataA = data.get_ori_data_x_1pic(fpath='imgA1.png')
# Per the original note: creates sphere-shaped data at random positions
DataB = data.get_rand_core()

updater = DiscoGANUpdater(
    DataA=DataA,
    DataB=DataB,
    opt_g_ab=opt_g_ab,
    opt_g_ba=opt_g_ba,
    opt_d_a=opt_d_a,
    opt_d_b=opt_d_b)

In [None]:
from bokeh.plotting import figure
from bokeh.io import gridplot, push_notebook, show, output_notebook

# Send bokeh output to the notebook.
output_notebook()

# The first entry of DataA['x'] unpacks as (channels, height, width);
# height and width size the plot ranges and the image glyphs below.
ch, imgH, imgW = DataA['x'][0].shape

# Initial pictures: the first entry of each data set, converted for display.
img1 = data.get_view_img(DataA['x'][0])
img2 = data.get_view_img(DataB['x'][0])

# Left panel; its title is overwritten with the training count later.
plt1 = figure(
    title='train N = --',
    x_range=[0, imgW],
    y_range=[0, imgH])
rend1 = plt1.image_rgba(
    image=[img1],
    x=[0], y=[0],
    dw=[imgW], dh=[imgH])

# Right panel; shares the axis ranges of the left panel.
plt2 = figure(
    title='count  = --',
    x_range=plt1.x_range,
    y_range=plt1.y_range)
rend2 = plt2.image_rgba(
    image=[img2],
    x=[0], y=[0],
    dw=[imgW], dh=[imgH])

# Show both panels side by side and keep the handle for push_notebook().
plts = gridplot([[plt1, plt2]], plot_width=400, plot_height=400)
handle = show(plts, notebook_handle=True)

In [None]:
# Training and test on imgA
batchSize = 100
N_train = 1000
nLoop = 1000

for loop in range(nLoop):

    # <Training>
    for i in range(0, N_train, batchSize):

        # Update. batchSize is passed explicitly so that the batch drawn by
        # the updater matches the step used by this loop instead of silently
        # falling back to update_core's own default.
        updater.update_core(batchSize=batchSize)

    # <Test>
    # NOTE(review): the generators are called with test=False here even
    # though this is the test pass -- confirm that this is intended.
    x_a = chainer.Variable(xp.asarray(DataA['x']), volatile='on')
    x_ab = generator_ab(x_a, test=False)
    x_aba = generator_ba(x_ab, test=False)

    x_ab.to_cpu()
    x_aba.to_cpu()

    DataA['ab'] = x_ab.data
    DataA['aba'] = x_aba.data

    # Extract the positions of local maxima and overlay red circles on the
    # original picture.
    # Average over axis 1 (three channels, given the division by 3).
    img_temp = DataA['ab'].sum(axis=1) / 3
    DataA['ab_point'] = data.get_local_max_point(img_temp[:, np.newaxis, :, :], threshold=50)
    DataA['ab_circle'] = data.draw_circle(DataA['ab_point'])
    DataA['x_ab_circle'] = DataA['x'].copy()
    # The circles are added to channel 0 of the copy only.
    DataA['x_ab_circle'][:, 0, :, :] = DataA['x_ab_circle'][:, 0, :, :] + DataA['ab_circle']

    # Graphic display
    img1 = data.get_view_img(DataA['x_ab_circle'][0])
    img2 = data.get_view_img(DataA['ab'][0])
    rend1.data_source.data['image'] = [img1]
    rend2.data_source.data['image'] = [img2]
    plt1.title.text = 'train N = ' + str((loop+1)*N_train)
    plt2.title.text = 'count  =  ' + str(DataA['ab_point'].sum()/255)
    push_notebook(handle=handle)  # refresh the display

In [None]:
# Save both generators to .npz files in the working directory.
for npz_path, trained_generator in (
        ('generator_ab.npz', generator_ab),
        ('generator_ba.npz', generator_ba)):
    chainer.serializers.save_npz(npz_path, trained_generator)