61 changes: 34 additions & 27 deletions research/attention_ocr/python/demo_inference.py
@@ -1,9 +1,9 @@
"""A script to run inference on a set of image files.

NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
it will work only for images which look more or less similar to french street
names. In order to apply it to images from a different distribution you need
to retrain (or at least fine-tune) it using images from that distribution.
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
it will work only for images which look more or less similar to french street
names. In order to apply it to images from a different distribution you need
to retrain (or at least fine-tune) it using images from that distribution.

NOTE #2: This script exists for demo purposes only. It is highly recommended
to use tools and mechanisms provided by the TensorFlow Serving system to run
@@ -20,10 +20,11 @@

 import tensorflow as tf
 from tensorflow.python.platform import flags
+from tensorflow.python.training import monitored_session
 
 import common_flags
 import datasets
-import model as attention_ocr
+import data_provider
 
 FLAGS = flags.FLAGS
 common_flags.define()
@@ -44,7 +45,7 @@ def get_dataset_image_size(dataset_name):
 def load_images(file_pattern, batch_size, dataset_name):
   width, height = get_dataset_image_size(dataset_name)
   images_actual_data = np.ndarray(shape=(batch_size, height, width, 3),
-                                  dtype='float32')
+                                  dtype='uint8')
   for i in range(batch_size):
     path = file_pattern % i
     print("Reading %s" % path)
@@ -53,34 +54,40 @@ def load_images(file_pattern, batch_size, dataset_name):
   return images_actual_data
 
 
-def load_model(checkpoint, batch_size, dataset_name):
+def create_model(batch_size, dataset_name):
   width, height = get_dataset_image_size(dataset_name)
   dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
   model = common_flags.create_model(
-    num_char_classes=dataset.num_char_classes,
-    seq_length=dataset.max_sequence_length,
-    num_views=dataset.num_of_views,
-    null_code=dataset.null_code,
-    charset=dataset.charset)
-  images_placeholder = tf.placeholder(tf.float32,
-                                      shape=[batch_size, height, width, 3])
-  endpoints = model.create_base(images_placeholder, labels_one_hot=None)
-  init_fn = model.create_init_fn_to_restore(checkpoint)
-  return images_placeholder, endpoints, init_fn
+      num_char_classes=dataset.num_char_classes,
+      seq_length=dataset.max_sequence_length,
+      num_views=dataset.num_of_views,
+      null_code=dataset.null_code,
+      charset=dataset.charset)
+  raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
+  images = tf.map_fn(data_provider.preprocess_image, raw_images,
+                     dtype=tf.float32)
+  endpoints = model.create_base(images, labels_one_hot=None)
+  return raw_images, endpoints
 
 
+def run(checkpoint, batch_size, dataset_name, image_path_pattern):
+  images_placeholder, endpoints = create_model(batch_size,
+                                               dataset_name)
+  images_data = load_images(image_path_pattern, batch_size,
+                            dataset_name)
+  session_creator = monitored_session.ChiefSessionCreator(
+      checkpoint_filename_with_path=checkpoint)
+  with monitored_session.MonitoredSession(
+      session_creator=session_creator) as sess:
+    predictions = sess.run(endpoints.predicted_text,
+                           feed_dict={images_placeholder: images_data})
+  return predictions.tolist()
+
+
 def main(_):
-  images_placeholder, endpoints, init_fn = load_model(FLAGS.checkpoint,
-                                                      FLAGS.batch_size,
-                                                      FLAGS.dataset_name)
-  images_data = load_images(FLAGS.image_path_pattern, FLAGS.batch_size,
-                            FLAGS.dataset_name)
-  with tf.Session() as sess:
-    tf.tables_initializer().run()  # required by the CharsetMapper
-    init_fn(sess)
-    predictions = sess.run(endpoints.predicted_text,
-                           feed_dict={images_placeholder: images_data})
   print("Predicted strings:")
+  predictions = run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
+                    FLAGS.image_path_pattern)
   for line in predictions:
     print(line)
 
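Taken together, these changes split graph construction (create_model) from checkpoint restore and inference (run), with preprocessing applied in-graph over a uint8 placeholder so callers feed raw image arrays. A minimal usage sketch of the new run() helper; the checkpoint name and image pattern are borrowed from the test below and assumed to be present in the working directory, with default flag values left unchanged:

import demo_inference

# Sketch: invoke the refactored inference entry point directly.
predictions = demo_inference.run(
    checkpoint='model.ckpt-399731',
    batch_size=32,
    dataset_name='fsns',
    image_path_pattern='testdata/fsns_train_%02d.png')
for line in predictions:
  print(line)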
87 changes: 87 additions & 0 deletions research/attention_ocr/python/demo_inference_test.py
@@ -0,0 +1,87 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import demo_inference
import tensorflow as tf
from tensorflow.python.training import monitored_session

_CHECKPOINT = 'model.ckpt-399731'
_CHECKPOINT_URL = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'


class DemoInferenceTest(tf.test.TestCase):
def setUp(self):
super(DemoInferenceTest, self).setUp()
for suffix in ['.meta', '.index', '.data-00000-of-00001']:
filename = _CHECKPOINT + suffix
self.assertTrue(tf.gfile.Exists(filename),
msg='Missing checkpoint file %s. '
'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL))
self._batch_size = 32

def test_moving_variables_properly_loaded_from_a_checkpoint(self):
batch_size = 32
dataset_name = 'fsns'
images_placeholder, endpoints = demo_inference.create_model(batch_size,
dataset_name)
image_path_pattern = 'testdata/fsns_train_%02d.png'
images_data = demo_inference.load_images(image_path_pattern, batch_size,
dataset_name)
tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
moving_mean_tf = tf.get_default_graph().get_tensor_by_name(
tensor_name + ':0')
reader = tf.train.NewCheckpointReader(_CHECKPOINT)
moving_mean_expected = reader.get_tensor(tensor_name)

session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=_CHECKPOINT)
with monitored_session.MonitoredSession(
session_creator=session_creator) as sess:
moving_mean_np = sess.run(moving_mean_tf,
feed_dict={images_placeholder: images_data})

self.assertAllEqual(moving_mean_expected, moving_mean_np)

def test_correct_results_on_test_data(self):
image_path_pattern = 'testdata/fsns_train_%02d.png'
predictions = demo_inference.run(_CHECKPOINT, self._batch_size,
'fsns',
image_path_pattern)
self.assertEqual([
'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░', # GT='Rue Thérésa'
'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
'Rue de la Libération░░░░░░░░░░░░░░░░░',
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
'Avenue de la Grand Mare░░░░░░░░░░░░░░',
'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
], predictions)


if __name__ == '__main__':
tf.test.main()
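Before the test can run, the pre-trained FSNS checkpoint referenced by _CHECKPOINT_URL has to be downloaded and unpacked next to the test files. A minimal setup sketch, assuming a Python 3 interpreter and that the archive places the model.ckpt-399731.* files at its top level:

import tarfile
import urllib.request

# Fetch and unpack the checkpoint expected by DemoInferenceTest.setUp().
# Adjust the extraction path if the archive nests the files in a directory.
url = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'
archive_path, _ = urllib.request.urlretrieve(url)
with tarfile.open(archive_path, 'r:gz') as archive:
  archive.extractall('.')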