In [2]:
import tensorflow.keras.layers as layers
import tensorflow as tf

In [3]:
inputs = layers.Input(shape=(572, 572, 1))

conv0 = layers.Conv2D(64, activation='relu', kernel_size=3)(inputs) #(570, 570, 64)
conv1 = layers.Conv2D(64, activation='relu', kernel_size=3)(conv0) #(568, 568, 64)
conv2 = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(conv1) #(284, 284, 64)

conv3 = layers.Conv2D(128, activation='relu', kernel_size=3)(conv2) #(282, 282, 128)
conv4 = layers.Conv2D(128, activation='relu', kernel_size=3)(conv3) #(280, 280, 128)
conv5 = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(conv4) #(140, 140, 128)

conv6 = layers.Conv2D(256, activation='relu', kernel_size=3)(conv5) #(138, 138, 256)
conv7 = layers.Conv2D(256, activation='relu', kernel_size=3)(conv6) #(136, 136, 256)
conv8 = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(conv7) #(68, 68, 256)

conv9 = layers.Conv2D(512, activation='relu', kernel_size=3)(conv8) #(66, 66, 512)
conv10 = layers.Conv2D(512, activation='relu', kernel_size=3)(conv9) #(64, 64, 512)
conv11 = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(conv10) #(32, 32, 512)

conv12 = layers.Conv2D(1024, activation='relu', kernel_size=3)(conv11) #(30, 30, 1024)
conv13 = layers.Conv2D(1024, activation='relu', kernel_size=3)(conv12) #(28, 28, 1024)

trans01 = layers.Conv2DTranspose(512, kernel_size=2, strides=(2,2), activation='relu')(conv13) #(56, 56, 512)
crop01 = layers.Cropping2D(cropping=(4,4))(conv10) #(56, 56, 512)
concat01 = layers.concatenate([trans01, crop01], axis=-1) #(56, 56, 1024)

conv14 = layers.Conv2D(512, activation='relu', kernel_size=3)(concat01) #(54, 54, 512)
conv15 = layers.Conv2D(512, activation='relu', kernel_size=3)(conv14) #(52, 52, 512)

trans02 = layers.Conv2DTranspose(256, kernel_size=2, strides=(2,2), activation='relu')(conv15) #(104, 104, 256)
crop02 = layers.Cropping2D(cropping=(16, 16))(conv7) #(104, 104, 256)
concat02 = layers.concatenate([trans02, crop02], axis=-1) #(104, 104, 512)

conv16 = layers.Conv2D(256, activation='relu', kernel_size=3)(concat02) #(102, 102, 256)
conv17 = layers.Conv2D(256, activation='relu', kernel_size=3)(conv16) #(100, 100, 256)

trans03 = layers.Conv2DTranspose(128, kernel_size=2, strides=(2,2), activation='relu')(conv17) #(200, 200, 128)
crop03 = layers.Cropping2D(cropping=(40, 40))(conv4) #(200, 200, 128)
concat03 = layers.concatenate([trans03, crop03], axis=-1) #(200, 200, 256)

conv18 = layers.Conv2D(128, activation='relu', kernel_size=3)(concat03) #(198, 198, 128)
conv19 = layers.Conv2D(128, activation='relu', kernel_size=3)(conv18) #(196, 196, 128)

trans04 = layers.Conv2DTranspose(64, kernel_size=2, strides=(2,2), activation='relu')(conv19) #(392, 392, 64)
crop04 = layers.Cropping2D(cropping=(88, 88))(conv1) #(392, 392, 64)
concat04 = layers.concatenate([trans04, crop04], axis=-1) #(392, 392, 128)

conv20 = layers.Conv2D(64, activation='relu', kernel_size=3)(concat04) #(390, 390, 64)
conv21 = layers.Conv2D(64, activation='relu', kernel_size=3)(conv20) #(388, 388, 64)

outputs = layers.Conv2D(2, kernel_size=1)(conv21)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name='u-netmodel')

In [4]:
model.summary()

Model: "u-netmodel"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 572, 572, 1)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 570, 570, 64)         640       ['input_1[0][0]']             
                                                                                                  
 conv2d_1 (Conv2D)           (None, 568, 568, 64)         36928     ['conv2d[0][0]']              
                                                                                                  
 max_pooling2d (MaxPooling2  (None, 284, 284, 64)         0         ['conv2d_1[0][0]']            
 D)                                                                                      

 conv2d_17 (Conv2D)          (None, 388, 388, 64)         36928     ['conv2d_16[0][0]']           
                                                                                                  
 conv2d_18 (Conv2D)          (None, 388, 388, 2)          130       ['conv2d_17[0][0]']           
                                                                                                  
Total params: 31030658 (118.37 MB)
Trainable params: 31030658 (118.37 MB)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________


In [5]:
from IPython.display import SVG
from keras.utils import plot_model  # 수정된 부분

%matplotlib inline

SVG(plot_model(model, show_shapes=True, show_layer_names=True, dpi=80).create(prog='dot', format='svg'))

AttributeError: 'Image' object has no attribute 'create'

In [6]:
import numpy as np
import tensorflow as tf

In [7]:
X = np.asarray([[1, 2], [3, 4]])

In [8]:
print(X) 
print(X.shape)

[[1 2]
 [3 4]]
(2, 2)


In [9]:
X = X.reshape((1, 2, 2, 1)) # 2X2 크기의 흑백 이미지 한 개

In [10]:
print(X)
print(X.shape)

[[[[1]
   [2]]

  [[3]
   [4]]]]
(1, 2, 2, 1)


In [11]:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2DTranspose(1, (1, 1), strides=(2, 2), input_shape=(2, 2, 1)))

In [12]:
weights = [np.asarray([[[[1]]]]), np.asarray([1])]

In [13]:
weights

[array([[[[1]]]]), array([1])]

In [14]:
model.set_weights(weights)

In [15]:
yhat = model.predict(X) 
yhat = yhat.reshape((4, 4)) # 결과를 확인하기 편하게 reshaping
print(yhat)

[[2. 1. 3. 1.]
 [1. 1. 1. 1.]
 [4. 1. 5. 1.]
 [1. 1. 1. 1.]]


In [27]:
X = np.asarray([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
X = X.reshape((1, 3, 3, 1))

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), input_shape=(3, 3, 1)))

weights = [np.full((3, 3, 1, 1), 2), np.asarray([1])]
model.set_weights(weights)

yhat = model.predict(X) 
yhat = yhat.reshape((7, 7))
print(yhat)

[[ 3.  3.  5.  3.  5.  3.  3.]
 [ 3.  3.  5.  3.  5.  3.  3.]
 [ 7.  7. 13.  7. 13.  7.  7.]
 [ 5.  5.  9.  5.  9.  5.  5.]
 [11. 11. 21. 11. 21. 11. 11.]
 [ 7.  7. 13.  7. 13.  7.  7.]
 [ 7.  7. 13.  7. 13.  7.  7.]]
