In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import cv2
import pickle
import tensorflow as tf

### SegLink Model

In [None]:
from sl_model import SL512, DSODSL512
from sl_utils import PriorUtil
from ssd_data import InputGenerator
from ssd_data import preprocess

from utils.model import load_weights

In [None]:
# Detector configuration: plain SegLink (SL512) checkpoint.
# If the DSODSL512 configuration cell that follows is also executed, it
# overrides every value set here, so run only the configuration you want.
Model = SL512
weights_path = './checkpoints/201809231008_sl512_synthtext/weights.002.h5'
# Thresholds handed to prior_util.decode() when raw predictions are decoded.
segment_threshold = 0.6
link_threshold = 0.25
# Prefix for exported plot file names.
# NOTE(review): 'sythtext' looks like a typo of 'synthtext'; left unchanged
# because it determines the output file names.
plot_name = 'sl512_crnn_sythtext'

In [None]:
# Detector configuration: DSOD-based SegLink (DSODSL512) checkpoint.
# Executing this cell replaces the values from the SL512 configuration cell
# above, so on a full top-to-bottom run this is the configuration in effect.
Model = DSODSL512
weights_path = './checkpoints/201806021007_dsodsl512_synthtext/weights.012.h5'
# Thresholds handed to prior_util.decode() when raw predictions are decoded.
segment_threshold = 0.55
link_threshold = 0.45
# Prefix for exported plot file names.
# NOTE(review): 'sythtext' looks like a typo of 'synthtext'; left unchanged
# because it determines the output file names.
plot_name = 'dsodsl512_crnn_sythtext'

In [None]:
# Build the text detector from whichever configuration cell was executed last
# (Model and weights_path are set there).
det_model = Model()
# Helper bound to the detector; its decode() method is used further down to
# turn raw predictions into rotated boxes.
prior_util = PriorUtil(det_model)
det_model.load_weights(weights_path)

# Input resolution of the detector; unpacked as (height, width) where images
# are resized for plotting.
image_size = det_model.image_size

### CRNN Model

In [None]:
from crnn_model import CRNN
from crnn_utils import alphabet87 as alphabet

# Size of the word crops fed to the recognizer; the same values are passed to
# crop_words() in the recognition cell.
input_width, input_height = 256, 32

# NOTE(review): this rebinds `weights_path`, which the detector cell also
# reads. Re-running the detector cell after this one would try to load the
# CRNN checkpoint into the detector; re-run a detector configuration cell first.
weights_path = './checkpoints/201806190711_crnn_gru_synthtext/weights.400000.h5'

# Recognition network in prediction-only mode, GRU variant, with one output
# class per alphabet entry.
rec_model = CRNN(
    (input_width, input_height, 1),
    len(alphabet),
    gru=True,
    prediction_only=True,
)
rec_model.load_weights(weights_path, by_name=True)

### Detection on real-world images

In [None]:
# Load the real-world test images and run the detector on them.
#
# Produces, index-aligned per image:
#   images_orig - untouched BGR images as read by OpenCV (used for word crops)
#   inputs      - network inputs from preprocess(img, image_size)
#   images      - RGB float images in [0, 1], resized to image_size (for plotting)
#   preds       - raw detector output for the whole batch
inputs = []
images = []
images_orig = []
data = []

# Loop-invariant: target size for the display copies.
h, w = image_size

# sorted(): glob returns paths in arbitrary, filesystem-dependent order, while
# the recognition cell selects samples by fixed index. Sorting makes those
# indices refer to the same images on every machine and every run.
for img_path in sorted(glob.glob('data/images/test_images_seglink/*')):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable / non-image file by returning None
        # instead of raising; skip it so the lists stay aligned.
        print('skipping unreadable file: %s' % img_path)
        continue
    images_orig.append(np.copy(img))
    inputs.append(preprocess(img, image_size))
    # The interpolation flag must be passed by keyword: the third positional
    # parameter of cv2.resize is `dst`, not `interpolation`.
    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR).astype('float32')
    img = img[:, :, (2,1,0)] / 255 # BGR to RGB
    images.append(img)

inputs = np.asarray(inputs)

preds = det_model.predict(inputs, batch_size=1, verbose=1)

In [None]:
%%timeit
# Benchmark: run the detector on one image at a time over the whole batch.
num_inputs = len(inputs)
for sample_idx in range(num_inputs):
    single_input = inputs[sample_idx:sample_idx+1]  # slice keeps the batch axis
    preds = det_model.predict(single_input, batch_size=1, verbose=0)
    # Decoding is excluded from the timing; enable to measure it as well:
    #res = prior_util.decode(preds[0], segment_threshold, link_threshold)

In [None]:
# Sanity check of the raw detector output; the first axis indexes the input
# images (preds[k] is decoded per image in the recognition cell).
preds.shape

### Detection on SynthText

In [None]:
import pickle
# Imported but not referenced by name in this cell.
# NOTE(review): presumably needed so pickle can resolve the class of the
# stored object - confirm before removing.
from data_synthtext import GTUtility

# Load the pickled SynthText ground-truth utility from the working directory.
file_name = 'gt_util_synthtext_seglink.pkl'
with open(file_name, mode='rb') as gt_file:
    # NOTE(review): pickle.load can execute arbitrary code; only open
    # pickle files that were generated locally / come from a trusted source.
    gt_util = pickle.load(gt_file)

# Split into two utilities; only the second one (gt_util_val) is sampled from
# in the next cell.
# NOTE(review): 0.9 is presumably the training fraction - confirm in GTUtility.split.
gt_util_train, gt_util_val = gt_util.split(0.9)

In [None]:
# Draw a random batch of 32 validation samples, sized for the detector.
# This overwrites inputs / images / data from the real-world cell.
idxs, inputs, images, data = gt_util_val.sample_random_batch(batch_size=32, input_size=image_size)

# Re-read the corresponding source images from disk so that word crops can be
# taken from the untouched originals.
images_orig = []
for idx in idxs:
    orig_path = os.path.join(gt_util_val.image_path, gt_util_val.image_names[idx])
    images_orig.append(cv2.imread(orig_path))

# Raw detector output for the sampled batch.
preds = det_model.predict(inputs, batch_size=1, verbose=1)

### Recognition

In [None]:
from crnn_data import crop_words
from crnn_utils import decode
from sl_utils import rbox_to_polygon, polygon_to_rbox
from utils.vis import plot_box, escape_latex

# End-to-end visualisation: for each selected sample
#   1. decode the detector output into rotated boxes (x, y, w, h, theta),
#   2. slightly enlarge the boxes and drop those lying outside the image,
#   3. crop the words from the original image and run the CRNN on them,
#   4. draw the boxes and the recognised strings on top of the resized image.

sample_idxs = [0,2,3,9,10,11,24]   # use range(len(preds)) to plot everything
show_word_crops = False            # True: additionally show every word crop
box_margin = 10                    # tolerance in pixels for the inside-image test
corner_colors = 'rgby'             # marker colour of polygon corners 0..3

# Box coordinates are in pixels of the detector input, flattened per box as
# (x0, y0, x1, y1, x2, y2, x3, y3); build matching per-coordinate limits from
# image_size (height, width) instead of hard-coding 512.
img_h, img_w = image_size
coord_limits = np.tile([img_w, img_h], 4)

# The fixed indices above may exceed the number of loaded samples.
num_available = min(len(preds), len(images), len(images_orig))

for k in sample_idxs:
    if k >= num_available:
        print('skipping sample %i: only %i samples available' % (k, num_available))
        continue

    plt.figure(figsize=[8]*2)
    plt.imshow(images[k])
    res = prior_util.decode(preds[k], segment_threshold, link_threshold)

    img = images_orig[k]
    rboxes = res[:,:5]
    if len(rboxes) == 0:
        # nothing detected: show the bare image
        plt.axis('off')
        plt.show()
        continue

    # Enlarge each box relative to its height (width +10%, height +20%).
    # bh is a view on the height column, but the right-hand side is evaluated
    # before the in-place add, so both updates use the original height.
    bh = rboxes[:,3]
    rboxes[:,2] += bh * 0.1
    rboxes[:,3] += bh * 0.2

    boxes = np.asarray([rbox_to_polygon(r) for r in rboxes])
    boxes = np.flip(boxes, axis=1) # TODO: fix order of points, why?
    boxes = np.reshape(boxes, (-1, 8))

    # keep only boxes lying (up to box_margin) inside the image
    outside_low = np.any(boxes < -box_margin, axis=1)
    outside_high = np.any(boxes > coord_limits + box_margin, axis=1)
    boxes_mask = ~(outside_low | outside_high)

    boxes = boxes[boxes_mask]
    rboxes = rboxes[boxes_mask]
    if len(boxes) == 0:
        boxes = np.empty((0,8))

    # plot boxes, corners colour-coded to make the point order visible
    for box in boxes:
        for i in range(4):
            x, y = box[i*2:i*2+2]
            plt.plot(x,y, corner_colors[i], marker='o', markersize=4)
        plot_box(box, 'polygon')

    # crop_words receives the polygons scaled to [0, 1] relative to the
    # detector input size; the crops themselves come from the original image.
    words = crop_words(img, np.clip(boxes/coord_limits,0,1), input_height, width=input_width, grayscale=True)
    # swap the first two axes of every crop to match the CRNN input shape
    # (input_width, input_height, 1)
    words = np.asarray([w.transpose(1,0,2) for w in words])

    if len(words) > 0:
        res_crnn = rec_model.predict(words)

        for i in range(len(words)):
            # most likely symbol per time step, then collapse to the final string
            chars = [alphabet[c] for c in np.argmax(res_crnn[i], axis=1)]
            res_str = decode(chars)

            x, y, w, h, theta = rboxes[i]

            # place the text half a box height away from the box centre,
            # rotated like the box
            plt.text(x+h*np.sin(theta)/2, y+h*np.cos(theta)/2, escape_latex(res_str), rotation=theta/np.pi*180,
                     horizontalalignment='center', size='xx-large' , color='lime') # magenta, lime

    plt.axis('off')

    file_name = 'plots/%s_endtoend_realworld_%03i.pgf' % (plot_name, k)
    #plt.savefig(file_name, bbox_inches='tight')
    #print(file_name)

    plt.show()

    if show_word_crops:
        for i in range(len(words)):
            plt.figure(figsize=[30,0.5])
            plt.imshow(words[i][:,:,0].T, cmap='gray')
            plt.axis('off')
            plt.show()