In [1]:
## network workflow
import datetime
import os

import keras
import numpy as np
from keras.models import Model
# from keras.layers.convolutional import Convolution2D
# from keras.layers.convolutional import Conv3d
# from keras.layers.convolutional import Conv3D
# from keras.layers import Conv2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D,Conv2DTranspose
from keras.layers.normalization import BatchNormalization
from keras.layers import (
    Input,
    Activation,
    merge,
    Dense,
    Lambda,
    Reshape,
    Dropout,
    Concatenate
)
from keras.optimizers import Adam
# Expose only GPU index 2 to CUDA. This is read when TensorFlow first initialises
# its devices, so it must run before any model/session is created.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

Using TensorFlow backend.


In [13]:
class vspGAN():
    '''
    Conditional GAN: a U-net generator maps an (nt, nr, nph) input to an
    (nt, nr, nop) output, and a discriminator scores (input, output) pairs.

    Attributes set in __init__:
        gf, df        -- base filter counts of the generator / discriminator
        nt, nr        -- spatial size of the input grid (axes 1 and 2)
        nph, nop      -- number of input / output channels
        disc_patch    -- shape (without batch) of the discriminator's validity map
        discriminator -- discriminator Model, compiled with mse loss + accuracy
        generator     -- generator Model (not compiled)
    '''
    def __init__(self):
        self.gf = 64      # base number of generator filters
        self.df = 64      # base number of discriminator filters
        self.nt = 2001    # input size along axis 1
        self.nr = 467     # input size along axis 2
        self.nph = 3      # input (conditioning) channels
        self.nop = 1      # output channels

        # The discriminator applies four stride-2, 'same'-padded convolutions and
        # each one maps a length n to ceil(n / 2), so the validity map is
        # ceil(n / 16) along each spatial axis (126 x 30 here). Flooring with
        # int(n / 16) gave 125 x 29, which does not match the model output.
        down_factor = 2**4
        self.disc_patch = ((self.nt + down_factor - 1) // down_factor,
                           (self.nr + down_factor - 1) // down_factor,
                           1)
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()


    def build_generator(self):
        '''
        U-net generator.

        Input:  (batch, nt, nr, nph).
        Output: (batch, nt, nr, nop), tanh-activated.

        Four stride-2 encoder convolutions are mirrored by four stride-2
        transposed convolutions. Because the spatial sizes are odd, each
        upsampled map is up to one pixel larger than its skip connection, so it
        is cropped (spatially only) before being concatenated with that skip.
        '''
        def crop_to(xx, target):
            """Crop xx at the bottom/right to target's spatial size; channels are untouched."""
            _, xx_h, xx_w, _ = keras.backend.int_shape(xx)
            _, tgt_h, tgt_w, _ = keras.backend.int_shape(target)
            if (xx_h, xx_w) == (tgt_h, tgt_w):
                return xx
            # Cropping2D is serializable, unlike a Lambda capturing a tensor argument.
            return keras.layers.Cropping2D(
                cropping=((0, xx_h - tgt_h), (0, xx_w - tgt_w)))(xx)

        def conv2d(layer_input, filters, f_size=(4,4), s_size=(2,2), bn=True):
            """Downsampling stage: strided Conv2D -> LeakyReLU(0.2) -> optional BatchNorm."""
            xx = Conv2D(filters=filters, kernel_size=f_size, strides=s_size, padding='same',
                        data_format='channels_last')(layer_input)
            # Without a nonlinearity the encoder would be a stack of linear maps.
            xx = LeakyReLU(alpha=0.2)(xx)
            if bn:
                xx = BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001)(xx)
            return xx

        def deconv2d(layer_input, skip_input, filters, f_size=(4,4), s_size=(2,2), dropout_rate=0, if_skip=True,
                     if_last=False):
            """Upsampling stage.

            Conv2DTranspose (tanh when ``if_last``, otherwise ReLU), cropped to
            ``skip_input``'s spatial size when ``if_skip``, optional Dropout, and,
            except for the last stage, BatchNorm followed by channel-wise
            concatenation with ``skip_input``.
            """
            # A single transposed convolution; the activation depends on the stage.
            xx = Conv2DTranspose(filters=filters, kernel_size=f_size, strides=s_size, padding='same',
                                 data_format='channels_last',
                                 activation='tanh' if if_last else 'relu')(layer_input)
            if if_skip:
                xx = crop_to(xx, skip_input)

            if dropout_rate:
                xx = Dropout(dropout_rate)(xx)

            if not if_last:
                xx = BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001)(xx)
                xx = Concatenate()([xx, skip_input])
            return xx

        # Image input
        d0 = Input(shape=(self.nt, self.nr, self.nph))

        # Downsampling (each stage halves the spatial size, rounding up)
        d1 = conv2d(d0, self.gf, bn=False)
        d2 = conv2d(d1, self.gf*8)
        d3 = conv2d(d2, self.gf*16)
        d4 = conv2d(d3, self.gf*16)   # bottleneck is fed by d3 (previously d2)

        # Upsampling, each stage joined to the matching encoder output
        u1 = deconv2d(d4, d3, self.gf*16)
        u2 = deconv2d(u1, d2, self.gf*16)
        u3 = deconv2d(u2, d1, self.gf*8)
        u4 = deconv2d(u3, d0, self.nop, if_last=True)

        return Model(inputs=d0, outputs=u4)


    def build_discriminator(self):
        '''
        Conditional patch discriminator.

        Inputs:  [data_B of shape (batch, nt, nr, nph), data_A of shape (batch, nt, nr, nop)].
        Output:  (batch, ceil(nt/16), ceil(nr/16), 1) validity map (linear, no activation),
                 i.e. one score per spatial patch rather than a single scalar.
        '''
        def d_layer(layer_input, filters, f_size=(4,4), s_size=(2,2), bn=True):
            """Strided Conv2D -> LeakyReLU(0.2) -> optional BatchNorm(momentum=0.8)."""
            d = Conv2D(filters, kernel_size=f_size, strides=s_size, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d
        data_B = Input(shape=(self.nt, self.nr, self.nph))
        data_A = Input(shape=(self.nt, self.nr, self.nop))

        # Concatenate image and conditioning image by channels to produce input
        combined_data = Concatenate(axis=-1)([data_B, data_A])

        d1 = d_layer(combined_data, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model([data_B, data_A], validity)

            

In [None]:
# Training-loop settings: number of epochs, samples per batch, and how often
# (in batches) generated samples are saved.
epochs, batch_size, sample_interval = 1, 2, 2

In [14]:
# opt = parameters().parse()
model = vspGAN()
# Reuse the networks built in vspGAN.__init__ (the discriminator there is compiled).
# Calling build_generator()/build_discriminator() again would create new,
# independent and uncompiled models, so D.train_on_batch would fail.
G = model.generator
D = model.discriminator

# NOTE(review): neither `opt` nor `Dataloader` is defined in any visible cell;
# restore `opt = parameters().parse()` above and import Dataloader before running.
# The grid size is taken from the model so the loader's batches match its Input shape.
my_data_loader = Dataloader(opt.data_path, model.nt, model.nr, model.nph)

In [17]:
# G.summary()

In [18]:
# D.summary()

In [None]:
# Labels must have the discriminator's output shape (batch dim excluded); read it
# from the compiled model so it always matches, whatever the input size.
patch_shape = model.discriminator.output_shape[1:]

# Combined model for the generator step: condition -> (discriminator score of the
# generated pair, generated data). The discriminator is frozen inside this model;
# it was compiled while trainable, so its own train_on_batch still updates it.
cond_input = Input(shape=(model.nt, model.nr, model.nph))
generated = model.generator(cond_input)
model.discriminator.trainable = False
combined = Model(inputs=cond_input,
                 outputs=[model.discriminator([cond_input, generated]), generated])
# NOTE(review): adversarial mse + reconstruction mae weighted 1:100 (the pix2pix
# choice) is an assumption -- tune for this data.
combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=Adam(0.0002, 0.5))

start_time = datetime.datetime.now()

for epoch in range(epochs):
    for batch_i, (data_B, data_A) in enumerate(my_data_loader.load_batch(batch_size)):
        # Built per batch because the last batch may hold fewer than batch_size samples.
        valid = np.ones((len(data_B),) + patch_shape)
        fake = np.zeros((len(data_B),) + patch_shape)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Condition on B and generate a translated version of A
        fake_A = model.generator.predict(data_B)

        # Real pairs (B, A) are labelled valid; generated pairs (B, G(B)) fake
        d_loss_real = model.discriminator.train_on_batch([data_B, data_A], valid)
        d_loss_fake = model.discriminator.train_on_batch([data_B, fake_A], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # -----------------
        #  Train Generator
        # -----------------

        # Fool the frozen discriminator while staying close to the real A
        g_loss = combined.train_on_batch(data_B, [valid, data_A])

        elapsed_time = datetime.datetime.now() - start_time
        # Plot the progress (assumes the loader exposes n_batches -- TODO confirm)
        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" %
              (epoch, epochs, batch_i, my_data_loader.n_batches,
               d_loss[0], 100*d_loss[1], g_loss[0], elapsed_time))

        # If at save interval => save generated samples
        if batch_i % sample_interval == 0:
            # No sample_images helper exists in this notebook; store the generated
            # batch as .npy so it can be inspected or plotted later.
            os.makedirs('samples', exist_ok=True)
            np.save(os.path.join('samples', 'epoch%d_batch%d.npy' % (epoch, batch_i)), fake_A)