Skip to content

Commit

Permalink
Add predict sample code for lstm+ctc ocr. Also update it's README.md (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
BobLiu20 authored and Rahul Ravu committed Jan 21, 2017
1 parent bc3ae3f commit 06b55c9
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 6 deletions.
20 changes: 17 additions & 3 deletions example/warpctc/README.md
Expand Up @@ -34,9 +34,11 @@ I implement two examples, one is just a toy example which can be used to prove c
cd example/warpctc
python lstm_ocr.py
```
Note:
* Please modify ```contexts = [mx.context.gpu(1)]``` in this file according to your hardware. If you only have one GPU, please change 1 to 0 (the number selects which GPU is used).
* Please copy your font file to the current folder, and replace './data/Xerox.ttf' with your font file name. On Ubuntu you can find fonts in /usr/share/fonts/truetype/.

Notes:
* Please modify ```contexts = [mx.context.gpu(0)]``` in this file according to your hardware.
* Please review the font path ```'./font/Ubuntu-M.ttf'``` in the code. Copy your font to font/yourfont.ttf and update the path accordingly. You can get a free font from [here](http://font.ubuntu.com/).
* A checkpoint is saved automatically at the end of each epoch. You can then use this checkpoint to run a prediction.

The OCR example is constructed as follows:

Expand Down Expand Up @@ -92,3 +94,15 @@ Following code show detail construction of the net:
If your label length is smaller than or equal to b, you should provide labels with length b; for samples whose label length is smaller than b, append 0 to the label data so that it has length b.

Here, 0 is reserved for blank label.

## Do a predict

Please run:

```
python ocr_predict.py
```

Notes:
* Change the code to match the names of your params and json files.
* You have to run ```make``` in the amalgamation folder (a libmxnet_predict.so will be created in the lib folder).
8 changes: 5 additions & 3 deletions example/warpctc/lstm_ocr.py
Expand Up @@ -48,7 +48,8 @@ def get_label(buf):
class OCRIter(mx.io.DataIter):
def __init__(self, count, batch_size, num_label, init_states):
super(OCRIter, self).__init__()
self.captcha = ImageCaptcha(fonts=['./data/Xerox.ttf'])
# you can get this font from http://font.ubuntu.com/
self.captcha = ImageCaptcha(fonts=['./font/Ubuntu-M.ttf'])
self.batch_size = batch_size
self.count = count
self.num_label = num_label
Expand Down Expand Up @@ -140,7 +141,7 @@ def Accuracy(label, pred):
momentum = 0.9
num_label = 4

contexts = [mx.context.gpu(1)]
contexts = [mx.context.gpu(0)]

def sym_gen(seq_len):
return lstm_unroll(num_lstm_layer, seq_len,
Expand Down Expand Up @@ -172,6 +173,7 @@ def sym_gen(seq_len):

model.fit(X=data_train, eval_data=data_val,
eval_metric = mx.metric.np(Accuracy),
batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),)
batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),
epoch_end_callback = mx.callback.do_checkpoint(prefix, 1))

model.save("ocr")
83 changes: 83 additions & 0 deletions example/warpctc/ocr_predict.py
@@ -0,0 +1,83 @@
#!/usr/bin/env python2.7
# coding=utf-8
from __future__ import print_function
import sys, os
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append("../../amalgamation/python/")
sys.path.append("../../python/")

from mxnet_predict import Predictor
import mxnet as mx

import numpy as np
import cv2
import os

class lstm_ocr_model(object):
    """Run a trained LSTM+CTC OCR network on a single image.

    The network is loaded once at construction time from a symbol JSON file
    and a params file; `forward_ocr` then turns an image into the decoded
    digit string.
    """

    # Label index 0 is reserved for the CTC blank, so label i (i >= 1)
    # maps to CONST_CHAR[i - 1].
    CONST_CHAR = '0123456789'

    def __init__(self, path_of_json, path_of_params):
        """Store the model file paths and build the predictor.

        Args:
            path_of_json: path to the network symbol JSON file.
            path_of_params: path to the trained parameter file.
        """
        super(lstm_ocr_model, self).__init__()
        self.path_of_json = path_of_json
        self.path_of_params = path_of_params
        self.predictor = None
        self.__init_ocr()

    def __init_ocr(self):
        """Create `self.predictor` from the stored JSON and params paths."""
        # Maximum label length; presumably must match the value the network
        # was trained with (the training script uses num_label = 4).
        num_label = 4
        batch_size = 1

        num_hidden = 100
        num_lstm_layer = 2
        # One initial cell state and one initial hidden state per LSTM layer.
        init_c = [('l%d_init_c' % layer, (batch_size, num_hidden))
                  for layer in range(num_lstm_layer)]
        init_h = [('l%d_init_h' % layer, (batch_size, num_hidden))
                  for layer in range(num_lstm_layer)]
        init_states = init_c + init_h

        all_shapes = ([('data', (batch_size, 80 * 30))]
                      + init_states
                      + [('label', (batch_size, num_label))])
        all_shapes_dict = dict(all_shapes)

        # Read both files through context managers so the handles are closed
        # promptly instead of being left to the garbage collector.
        with open(self.path_of_json) as json_file:
            symbol_json = json_file.read()
        # The params file is raw binary data: read it in binary mode so that
        # no newline translation or text decoding is applied to it.
        with open(self.path_of_params, 'rb') as params_file:
            params_raw = params_file.read()

        self.predictor = Predictor(symbol_json, params_raw, all_shapes_dict)

    def forward_ocr(self, img):
        """Recognize the text in `img` and return it as a string.

        Args:
            img: 2-D (single-channel) image array; it is resized to 80x30,
                transposed, flattened and scaled by 1/255 before inference.

        Returns:
            The decoded string after applying the CTC collapse rule.
        """
        img = cv2.resize(img, (80, 30))
        img = img.transpose(1, 0)
        img = img.reshape((80 * 30))
        img = np.multiply(img, 1 / 255.0)
        self.predictor.forward(data=img)
        prob = self.predictor.get_output(0)
        # Greedy decoding: take the most probable label for every output row.
        # argmax finds it directly without sorting the whole row.
        label_list = [np.argmax(p) for p in prob]
        return self.__get_string(label_list)

    def __get_string(self, label_list):
        """Collapse a raw per-step label sequence into the final string.

        CTC rule: drop blanks (label 0) and merge consecutive repeats of the
        same label. The surviving labels are mapped through CONST_CHAR;
        labels outside its range contribute nothing to the result.

        Args:
            label_list: sequence of integer labels, one per time step.

        Returns:
            The decoded string (empty if nothing survives the collapse).
        """
        collapsed = []
        previous = 0  # Treat the step before the sequence as a blank.
        for current in label_list:
            if current != 0 and current != previous:
                collapsed.append(current)
            previous = current

        chars = lstm_ocr_model.CONST_CHAR
        # change to ascii
        return ''.join(chars[label - 1] for label in collapsed
                       if 0 < label <= len(chars))

if __name__ == '__main__':
    # Demo entry point: load the trained model, run it on the sample image
    # and print the recognized text. All paths are relative to the current
    # working directory; adjust the file names to match your checkpoint.
    _lstm_ocr_model = lstm_ocr_model('ocr-symbol.json', 'ocr-0010.params')
    # Flag 0 asks OpenCV to load the image as single-channel grayscale.
    img = cv2.imread('sample.jpg', 0)
    # cv2.imread does not raise on a missing/unreadable file, it returns
    # None; fail here with a clear message instead of inside cv2.resize.
    if img is None:
        raise IOError("Could not read image 'sample.jpg'")
    _str = _lstm_ocr_model.forward_ocr(img)
    print('Result: ', _str)

Binary file added example/warpctc/sample.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 06b55c9

Please sign in to comment.