-
Notifications
You must be signed in to change notification settings - Fork 590
/
ocropus-tleaves
executable file
·148 lines (123 loc) · 5.34 KB
/
ocropus-tleaves
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/python
import sys
sys.path = ["."]+sys.path
import re
import random as pyrandom
from pylab import *
import tables
import ocrolib; reload(ocrolib)
from collections import Counter
import ocrolib
from ocrolib import patrec
from ocrolib.patrec import Dataset,showgrid
from ocrolib.ligatures import lig
import argparse
# Command-line interface.  NOTE: the banner text must be passed as
# `description` -- ArgumentParser's first positional parameter is `prog`,
# so the original call used the banner as the program name, which garbled
# the usage line in --help output.
parser = argparse.ArgumentParser(description="""
Perform training for the leaves of a vector quantizer.
""")
parser.add_argument('-N','--maxtotal',type=int,default=10000000000,help="max # of samples used")
parser.add_argument('-t','--testset',type=int,default=-1,help="use testset sequence t (-1=use all samples)")
parser.add_argument('-T','--maxtrain',type=int,default=100000,help="max # of training samples per classifier")
parser.add_argument('-d','--data',default="chardata.h5",help="data file")
parser.add_argument('-s','--splitter',default="split.splitter",help="input model to be updated")
parser.add_argument('-o','--output',default="trained.cmodel",help="output with per-leaf classifiers")
# '--exec' shadows the Python keyword, so the value lands in args.execute
parser.add_argument('--exec',dest="execute",default=None,help="additional modules to import (e.g., to load additional classifiers)")
parser.add_argument('-C','--classifier',default="patrec.LogisticCmodel()",help="factory for leaf classifier")
parser.add_argument('-D','--debug',action="store_true")
parser.add_argument('-q','--quiet',action="store_true")
parser.add_argument('-Q','--parallel',type=int,default=0,help="number of CPUs to use for training")
parser.add_argument('--exclude',default=r"[ _\000-\037]",help="regular expression matching characters to exclude")
args = parser.parse_args()
# (leftover debugging invocation, kept for reference)
#args = parser.parse_args(["-s","split100k.smodel","-N","100000"])
# Build a zero-argument factory for leaf classifiers from the --classifier
# expression, e.g. "patrec.LogisticCmodel()"; a fresh instance is created
# per leaf by calling cfactory().
# NOTE(review): eval/exec of command-line text executes arbitrary code;
# acceptable only because this is a local training tool.
cfactory = eval("lambda:"+args.classifier)
if args.execute is not None:
    print args.execute
    # execute the user-supplied statement, e.g. to import modules that
    # define additional classifier classes referenced by --classifier
    exec args.execute
print "loading splitter"
# the splitter is a previously trained vector quantizer that routes each
# character patch to one of splitter.nclusters() buckets (leaves)
splitter = ocrolib.load_object(args.splitter)
print "got",splitter
print "#splits",splitter.nclusters()
print "excluding",args.exclude
# "sizemode" records how patches were size-normalized when the splitter
# was trained; default "perchar" for older models lacking the attribute
splitter_sizemode = getattr(splitter,"sizemode","perchar")
print "sizemode",splitter_sizemode
def testset(i):
    """Tell whether sample number `i` belongs to the held-out test set.

    When no test set was requested (--testset < 0) every sample is kept
    for training and 0 is returned; otherwise the decision is delegated
    to ocrolib.testset with the configured sequence number.
    """
    sequence = args.testset
    if sequence >= 0:
        return ocrolib.testset(i, sequence=sequence)
    return 0
# Load the dataset and compute, for every sample, which splitter bucket
# (leaf) it falls into; samples in excluded classes or in the test set
# are assigned the sentinel bucket -1 so they are skipped during training.
with tables.openFile(args.data,"r") as h5:
    print "loading dataset"
    # consider at most --maxtotal samples
    N = min(args.maxtotal,len(h5.root.classes))
    patches = Dataset(h5.root.patches,maxsize=N)
    # the data file records how its patches were size-normalized; it must
    # match the splitter's sizemode or the bucket assignments are invalid
    data_sizemode = h5.getNodeAttr("/","sizemode")
    print "sizemode (data)",data_sizemode
    assert splitter_sizemode==data_sizemode,"sizemode for splitter (%s) and data (%s) don't agree"%(splitter_sizemode,data_sizemode)
    print "splitting"
    # assign every patch to a bucket, optionally using multiple CPUs
    splits = patrec.parallel_predict(splitter,patches,parallel=args.parallel,verbose=not args.quiet)
    for i in range(len(patches)):
        # route test-set samples and excluded classes to bucket -1
        if testset(i) or re.search(args.exclude,lig.chr(h5.root.classes[i])):
            splits[i] = -1
# give the user some feedback about cluster distributions
splits = array(splits,'i')
histogram = Counter(splits)
if args.debug:
    # debug: plot per-bucket sample counts, largest first, on a log scale
    counts = sorted(histogram.values(),key=lambda x:-x)
    ion(); gray(); clf()
    yscale('log')
    plot(counts)
    ginput(1,1)
clusters = sorted(histogram.keys())
if len(clusters)<splitter.nclusters():
    print "warning: not all clusters present",len(clusters),splitter.nclusters()
# create a number of "jobs"; we work through these either serially or in parallel
# (one job per real bucket: its id plus the indexes of the samples assigned
# to it; the sentinel bucket -1 is dropped here)
jobs = [(cluster,find(splits==cluster)) for cluster in clusters if cluster>=0]
def process1(job):
cluster,indexes = job
if len(indexes)>args.maxtrain:
indexes = pyrandom.sample(indexes,args.maxtrain)
note = "cluster %4d len %6d"%(cluster,len(indexes))
with tables.openFile(args.data,"r") as h5:
# load the classes and training data
cclasses = [lig.chr(h5.root.classes[i]) for i in indexes]
patches = Dataset(h5.root.patches)
data = array([patches[i] for i in indexes],'f')
# give the user some feedback about what the classes and samples are
counts = Counter(cclasses).most_common(5)
cinfo = " / ".join(["%s %s"%(k,v) for k,v in counts])
note += " "+cinfo
if args.debug:
clf();
if len(data)>=49: showgrid(patrec.vecsort(pyrandom.sample(data,49)))
else: showgrid(data)
suptitle(cinfo)
ginput(1,0.1)
assert data.ndim==2
# now just train the classifier and return it; the `cfactory` expression
# should take care of any parameters
if not args.quiet: print note
classifier = cfactory()
classifier.fit(data,cclasses)
return (cluster,classifier)
# Work through the jobs.  With --debug, or with fewer than two CPUs
# requested, train serially (keeps the matplotlib debug windows usable);
# otherwise fan out over a multiprocessing pool.
if (args.parallel>=0 and args.parallel<2) or args.debug:
    classifiers = []
    for job in jobs:
        classifiers.append(process1(job))
else:
    import multiprocessing
    pool = multiprocessing.Pool(args.parallel)
    classifiers = pool.map_async(process1,jobs)
    pool.close()
    pool.join()
    del pool
    # map_async returned an AsyncResult; fetch the (cluster,classifier) pairs
    classifiers = classifiers.get()
# nsplitter = patrec.HierarchicalSplitter()
# nsplitter.update(splitter)
# Assemble the final model: the splitter routes a sample to a leaf and the
# per-leaf classifier trained above makes the actual prediction.
cmodel = patrec.LocalCmodel(splitter=splitter)
for i,c in classifiers:
    cmodel.setCmodel(i,c)
# propagate the size-normalization mode from the data file onto the model
cmodel.sizemode = data_sizemode
print "writing"
ocrolib.save_object(args.output,cmodel)