In [1]:
import matplotlib.pyplot as plt

In [2]:
from keras import models, backend
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D, \
    BatchNormalization, Concatenate, Activation

Model

In [None]:
class UNET(models.Model):
    """Small two-level U-Net, built and compiled inside the constructor.

    Encoder: one full-resolution conv stage (16 filters), then two
    max-pool + conv stages (32 and 64 filters). Decoder: two stages that
    upsample by 2, concatenate the encoder feature map of matching
    resolution along the channel axis (skip connection), and convolve.
    A final 3x3 conv with sigmoid activation produces ``n_ch`` output
    channels. The model is compiled with the 'adadelta' optimizer and MSE
    loss.

    Args:
        original_shape: Shape of a single input sample (no batch axis),
            ordered according to ``backend.image_data_format()``:
            (rows, cols, ch) for 'channels_last', (ch, rows, cols) otherwise.
        n_ch: Number of channels of the decoded output.

    Raises:
        ValueError: if a known spatial dimension of ``original_shape`` is not
            divisible by 4. The two 2x poolings followed by two 2x upsamplings
            would otherwise produce tensors whose sizes do not match for the
            skip concatenations.
    """

    def __init__(self, original_shape, n_ch):
        # Index of the channel axis in a batched 4-D tensor:
        # (batch, rows, cols, ch) for channels_last -> 3,
        # (batch, ch, rows, cols) for channels_first -> 1.
        ic = 3 if backend.image_data_format() == 'channels_last' else 1

        # Fail early with a readable message instead of an obscure shape
        # mismatch inside Concatenate (pooling with padding='same' rounds up,
        # so odd sizes do not round-trip through pool -> upsample).
        spatial_dims = original_shape[:2] if ic == 3 else original_shape[1:3]
        for dim in spatial_dims:
            if dim is not None and dim % 4 != 0:
                raise ValueError(
                    'UNET needs spatial dimensions divisible by 4, got shape %r'
                    % (tuple(original_shape),))

        def conv(x, n_f, mp_flag=True):
            """Optional 2x2 max-pool, then two (3x3 conv -> BN -> tanh) stages.

            A Dropout(0.05) layer sits between the two conv stages.
            """
            x = MaxPooling2D(pool_size=(2, 2), padding='same')(x) if mp_flag else x
            x = Conv2D(n_f, (3, 3), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('tanh')(x)
            x = Dropout(0.05)(x)
            x = Conv2D(n_f, (3, 3), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('tanh')(x)
            return x

        def deconv_unet(x, e, n_f):
            """2x upsample ``x``, concatenate skip tensor ``e`` on the channel
            axis, then apply two (3x3 conv -> BN -> tanh) stages."""
            x = UpSampling2D(size=(2, 2))(x)
            x = Concatenate(axis=ic)([x, e])
            x = Conv2D(n_f, (3, 3), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('tanh')(x)
            x = Conv2D(n_f, (3, 3), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('tanh')(x)
            return x

        # input
        original = Input(shape=original_shape)

        # encoding: full resolution, then 1/2 resolution
        c1 = conv(original, 16, mp_flag=False)
        c2 = conv(c1, 32)

        # bottleneck: 1/4 resolution
        encoded = conv(c2, 64)

        # decoding: back to 1/2, then full resolution, using skip connections
        x = deconv_unet(encoded, c2, 32)
        x = deconv_unet(x, c1, 16)

        decoded = Conv2D(filters=n_ch, kernel_size=(3, 3), activation='sigmoid', padding='same')(x)

        # Functional-API construction of the subclassed Model.
        super().__init__(original, decoded)
        self.compile(optimizer='adadelta', loss='mse')
            

데이터 불러오기

In [3]:
from keras import datasets, utils

class DATA():
    def __init__(self, in_ch=None):
        (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
        if x_train.ndim == 4:
            if backend.image_data_format() == 'channels_first':
                n_ch, img_rows, img_cols = x_train.shape[1:]
            else:
                img_rows, img_cols, n_ch = x_train.shape[1:]
        else:
            img_rows, img_cols = x_train.shape[1:]
            
        in_ch = n_ch if in_ch is None else in_ch