In [1]:
%matplotlib inline
import tensorflow as tf
from functools import partial

import sys
import os
# Locate the `utils` package that provides SerializationDataset.
# If a `utils` entry exists two directories up, import it from there.
if 'utils' in os.listdir('../../'):
    sys.path.append("../../")
    from utils.dataset import SerializationDataset
else:
    # Otherwise install `wget`, clone the all-about-mnist repository into the
    # working directory, and import `utils` from that clone.
    # NOTE(review): presumably `wget` is needed by the dataset utilities; confirm.
    !pip install wget
    !git clone https://github.com/you-just-want-attention/all-about-mnist.git
    sys.path.append("./all-about-mnist/")
    from utils.dataset import SerializationDataset
    
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import keras
from tqdm import tqdm

from tf_models.text_recognition.models import CRNN
from tf_models.text_recognition.generator import DataGenerator

Using TensorFlow backend.


### [Optional] TensorFlow Graph Visualization

---

> _Jupyter에서 TensorFlow로 구성한 Graph를 시각적으로 보여주기 위한 helper 메소드입니다._<br>

In [2]:
from IPython.display import clear_output, Image, display, HTML
import numpy as np    

def strip_consts(graph_def, max_const_size=32):
    """Return a copy of ``graph_def`` with large constant payloads stripped.

    Every node of ``graph_def`` is merged into a freshly created
    ``tf.GraphDef``. For nodes whose op is ``'Const'``, a ``tensor_content``
    longer than ``max_const_size`` bytes is replaced by a short
    ``<stripped N bytes>`` placeholder.

    Args:
        graph_def: Object exposing a ``node`` sequence to copy from.
        max_const_size: Largest ``tensor_content`` length, in bytes, that is
            kept as-is.

    Returns:
        The new ``tf.GraphDef``; only the copied nodes are altered.
    """
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                # tensor_content is a protobuf `bytes` field: assigning a str
                # raises TypeError on Python 3, so build the placeholder as bytes.
                tensor.tensor_content = b"<stripped %d bytes>" % size
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Render a TensorFlow graph inline in the notebook.

    Args:
        graph_def: A graph definition, or any object with an
            ``as_graph_def()`` method, which is then called to obtain one.
        max_const_size: Forwarded to ``strip_consts``; constants larger than
            this many bytes are replaced by a placeholder before rendering.

    The stripped graph is embedded as text into an HTML page that loads the
    ``tf-graph-basic`` component, and that page is displayed inside an iframe.
    """
    resolved_def = graph_def.as_graph_def() if hasattr(graph_def, 'as_graph_def') else graph_def
    stripped = strip_consts(resolved_def, max_const_size=max_const_size)

    page_template = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """
    # The graph text is injected as a Python-repr'd string literal.
    pbtxt_literal = repr(str(stripped))
    # A random suffix gives each rendered graph its own element id.
    element_id = 'graph' + str(np.random.rand())
    page = page_template.format(data=pbtxt_literal, id=element_id)

    frame_template = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """
    # Double quotes must be escaped because the page sits inside srcdoc="...".
    frame = frame_template.format(page.replace('"', '&quot;'))

    display(HTML(frame))

### Graph 만들기
----


In [3]:
# Create the CRNN model object; the stages below are attached to it in order.
crnn = CRNN()

# Build the model graph: CNN, RNN, transcription, loss, decoder, metric and
# optimizer are attached through one chained call. The chain is the cell's
# last expression, so its final return value is what the notebook displays.
(crnn
 ._attach_cnn(num_features=16)
 ._attach_rnn(num_features=128, num_depth=2)
 ._attach_transcription()
 ._attach_loss()
 ._attach_decoder()
 ._attach_metric()
 ._attach_optimizer(weight_decay=1e-5)
)

Instructions for updating:
Colocations handled automatically by placer.


<tf_models.text_recognition.models.CRNN at 0x10af69eb8>

In [4]:
# Render the graph held by the CRNN instance inline using the helper above.
show_graph(crnn.graph)


모델 학습시키기
----

### 1) 데이터 셋 가져오기 

In [None]:
# Train and validation splits are built with digit=5; the test split is built
# with digit=(3,8). All three share pad_range=(3,10).
# NOTE(review): presumably `digit` is the number of digits per sample and the
# tuple is a range; confirm against SerializationDataset.
train_set = SerializationDataset(
    'mnist', 'train',
    digit=5,
    pad_range=(3,10),
)
validation_set = SerializationDataset(
    'mnist', 'validation',
    digit=5,
    pad_range=(3,10),
)
test_set = SerializationDataset(
    'mnist', 'test',
    digit=(3,8),
    pad_range=(3,10),
)

### 2) 데이터 Generator 구성하기

In [None]:
# Wrap each dataset in a batch generator. The training generator keeps the
# DataGenerator default for `shuffle`; validation and test disable shuffling.
train_gen = DataGenerator(
    train_set,
    batch_size=32,
)
valid_gen = DataGenerator(
    validation_set,
    batch_size=100,
    shuffle=False,
)
test_gen = DataGenerator(
    test_set,
    batch_size=500,
    shuffle=False,
)

# Take the first test batch as a fixed sample for inspecting predictions later.
test_images, test_labels = test_gen[0]

### 3) 학습시키기

In [None]:
# Training hyper-parameters.
num_epochs = 5
# NOTE(review): num_batch is not referenced by any visible cell; the batch size
# actually used comes from train_gen (batch_size=32). Confirm whether it is needed.
num_batch = 32
learning_rate = 0.001

# Train on train_gen, passing valid_gen as the second generator and './logs/'
# as the summary path.
crnn.fit_generator(
    train_gen,
    valid_gen,
    num_epoch=num_epochs,
    learning_rate=learning_rate,
    summary_path='./logs/',
)

### 4) 결과 확인하기

In [None]:
# Run the model on the first five images of the sampled test batch.
outputs = crnn.predict(test_images[:5])

# zip() stops at the shorter sequence, so only images that have a prediction
# are drawn. Each figure shows channel 0 of the image, titled with the prediction.
for output, image in zip(outputs, test_images):
    plt.title(output)
    plt.imshow(image[:, :, 0])
    plt.show()