-
Notifications
You must be signed in to change notification settings - Fork 35
/
util.py
170 lines (143 loc) · 4.18 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import time
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import keras.backend as K
import glob
def one_hot_encoder(data, whole_set, char2idx):
"""
Encode the whole list, not a single record
"""
ret = []
for i in data:
idx = char2idx[i]
tmp = np.zeros(len(whole_set), dtype=np.int8)
tmp[idx] = 1
ret.append(tmp)
return np.asarray(ret)
def one_hot_decoder(data, whole_set):
ret = []
if data.ndim == 1: # keras bug ?
data = np.expand_dims(data, 0)
for probs in data:
idx = np.argmax(probs)
# print idx, whole_set[idx], probs[idx]
ret.append(whole_set[idx])
return ret
def top_one_prob(data):
ret = []
if data.ndim == 1: # keras bug ?
data = np.expand_dims(data, 0)
for probs in data:
idx = np.argmax(probs)
ret.append(probs[idx])
return ret
def list2str(data):
return ''.join([i if i != 'empty' else '' for i in data])
def plot_loss_figure(history, save_path):
train_loss = history.history['loss']
val_loss = history.history['val_loss']
plt.plot(train_loss, 'b', val_loss, 'r')
plt.xlabel('train_loss: blue val_loss: red epoch')
plt.ylabel('loss')
plt.title('loss figure')
plt.savefig(save_path)
def load_data(input_dir, max_nb_cha, width, height, channels, char_set, char2idx):
"""
The format of the file folder
All image file, named as 'id.jpg', id starts from 1
label.txt, the first row correspond to 1.jpg's label, the file should be encoded in utf-8
"""
print 'Loading data...'
tag = time.time()
x = []
y = []
for dirpath, dirnames, filenames in os.walk(input_dir):
nb_pic = len(glob.glob(dirpath + os.sep + '*.jpg'))
if nb_pic <= 0:
continue
for i in range(1, nb_pic+1):
filename = str(i) + '.jpg'
filepath = dirpath + os.sep + filename
pixels = load_img(filepath, width, height, channels)
x.append(pixels)
label_path = dirpath + os.sep + 'label.txt'
with open(label_path) as f:
for raw in f:
raw = raw.decode('utf-8').strip('\n\r')
if len(raw) == 0:
break
y.append([])
for i in range(max_nb_cha):
if i < len(raw): # make sure there isn't empty label
y[-1].append(raw[i])
else:
y[-1].append('empty')
# transform to keras format, and do one hot encoding
x = np.asarray(x)
x /= 255 # normalized
y = [one_hot_encoder(i, char_set, char2idx) for i in y]
y = np.asarray(y)
if y.shape[1] == 1: # keras bug ?
y = y[:,0,:]
print 'Data loaded, spend time(m) :', (time.time()-tag)/60
return [x, y]
def my_resize(im, width, height):
box = max(im.size)
back = Image.new('RGB', (box, box), color='white')
back.paste(im,(0,0))
back = back.resize((width, height))
return back
def load_img(path, width, height, channels):
img = Image.open(path)
img = img.resize((width, height))
# img = my_resize(img, width, height)
if channels==1: # convert the image to gray scale image if it's RGB
img = img.convert('L')
img = np.asarray(img, dtype='float32')
if channels > 1:
img = np.rollaxis(img, 2, 0)
else:
img = np.expand_dims(img, 0)
return img
def get_char_set(file_dir):
file_path = file_dir+'label.txt'
ret = set(['empty'])
with open(file_path) as f:
for raw in f:
raw = raw.decode('utf-8').strip('\r\n')
for i in raw:
ret.add(i)
char_set = list(ret)
char2idx = dict(zip(char_set, range(len(char_set))))
return char_set, char2idx
def get_maxnb_char(file_dir):
file_path = file_dir+'label.txt'
ret = 1
with open(file_path) as f:
for raw in f:
raw = raw.decode('utf-8').strip('\r\n')
ret = max(ret, len(raw))
return ret
def categorical_accuracy_per_sequence(y_true, y_pred):
return K.mean(K.min(K.equal(K.argmax(y_true, axis=-1),
K.argmax(y_pred, axis=-1)), axis=-1))
def get_sample_weight(label, whole_set):
if label.ndim < 3: # in case output_size==1
return None
ret = []
for i in label:
ret.append([])
tag = False
for j in i:
cha = whole_set[np.argmax(j)]
weight = 0
if cha == 'empty' and tag == False:
weight = 1 # TODO
tag = True
if cha != 'empty':
weight = 1
ret[-1].append(weight)
ret = np.asarray(ret)
return ret