-
-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
299 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
# coding=utf-8 | ||
from enum import Enum | ||
|
||
import PIL.ImageOps | ||
import numpy | ||
from PIL import Image, ImageDraw, ImageFont | ||
|
||
import letter | ||
from extensions import * | ||
|
||
# Size of the character alphabet, taken from the letter module. Used in this
# file as the width of one-hot padding rows (pad) and as the class count passed
# to many_hot.
num_characters = letter.nLetters  # ascii: 127
|
||
class Target(Enum):  # labels
    """Kinds of label a batch can be asked to produce for each sample.

    Only ``word``, ``size`` and ``position`` are handled by ``batch.next_batch``
    in this file.
    """
    word = 1
    text = 2
    box = 3  # start,end
    position = 4
    # NOTE: same value as ``position``, so Enum turns ``start`` into an alias:
    # ``Target.start is Target.position``.
    start = 4  # position
    end = 5  # position
    style = 6
    angle = 7
    size = 8
|
||
|
||
|
||
def pos_to_arr(pos):
    """Convert a position mapping with 'x' and 'y' keys into an [x, y] list."""
    return [pos[axis] for axis in ('x', 'y')]
|
||
|
||
|
||
# Size limit taken from the letter module; squared to form the input width in batch.shape.
max_size = letter.max_size
# Default number of rows that pad() and many_hot() fill label sequences up to.
max_word_length = 15
# Side length of the square canvas each word is drawn on; also the exclusive
# upper bound for the random x/y text position.
canvas_size = 300  # Arbitrary, shouldn't need to be specified in advance when doing inference
|
||
def pad(vec, pad_to=max_word_length, one_hot=False, terminal_symbol=0):
    """Extend *vec* in place until it holds *pad_to* entries, then return it.

    With ``one_hot`` set, each filler entry is a fresh list of
    ``num_characters`` copies of *terminal_symbol*; otherwise the filler is
    *terminal_symbol* itself. A *vec* that is already long enough is left
    untouched.
    """
    missing = pad_to - len(vec)
    if one_hot:
        vec.extend([terminal_symbol] * num_characters for _ in range(missing))
    else:
        vec.extend(terminal_symbol for _ in range(missing))
    return vec
|
||
class data():
    """Abstract source of training batches that can be used as an iterator.

    Subclasses implement ``next_batch``; ``next(obj)`` and ``for batch in obj``
    then both yield whatever ``next_batch`` returns.
    """

    def __init__(self):
        self.input_shape = None
        self.output_shape = None

    def __next__(self):
        """Return the next batch (delegates to ``next_batch``)."""
        return self.next_batch()

    def __iter__(self):
        """Return the object itself, which is its own iterator.

        Returning a batch here (as before) made ``iter(obj)`` fail, because a
        batch is not an iterator.
        """
        return self

    def next_batch(self):
        """Produce one batch; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this abstract base class.
        """
        raise NotImplementedError("abstract data class must be implemented")
|
||
|
||
class batch(data):
    """Endless generator of freshly rendered random words with their labels."""

    def __init__(self, target=Target.word, batch_size=64):
        """Configure the generator.

        Args:
            target: the ``Target`` member selecting which label is produced.
            batch_size: default number of samples per batch.
        """
        super().__init__()
        self.batch_size = batch_size
        self.target = target
        self.shape = [max_size * max_size, max_word_length * letter.nLetters]
        # The same generator serves as both train and test set.
        self.train = self
        self.test = self

    def next_batch(self, batch_size=None):
        """Render a batch of random words.

        Args:
            batch_size: number of samples; falls back to ``self.batch_size``.

        Returns:
            ``(xs, ys)``: the list of ``word.matrix()`` arrays and the list of
            labels matching ``self.target``.

        Raises:
            ValueError: if ``self.target`` is not word, size or position.
        """
        words = [word() for _ in range(batch_size or self.batch_size)]
        # Dump the whole abstract representation as an image matrix.
        xs = [w.matrix() for w in words]
        if self.target == Target.word:
            # many_hot indexes by integer, so encode each character of *this*
            # word as its code point.
            # NOTE(review): offset 0 assumes classes are indexed directly by
            # code point and that every code point is < num_characters - confirm
            # against the letter module.
            ys = [many_hot([ord(char) for char in w.text], num_characters, 0)
                  for w in words]
        elif self.target == Target.size:
            ys = [w.size for w in words]
        elif self.target == Target.position:
            ys = [pos_to_arr(w.pos) for w in words]
        else:
            raise ValueError("unsupported target: %s" % (self.target,))
        return xs, ys
|
||
|
||
def pick(xs):
    """Return one element of the sequence *xs*, selected by a random index.

    The index is drawn with ``randint(0, len(xs) - 1)``.
    NOTE(review): this presumes ``randint`` includes both end points - confirm
    against the helper it is imported from.
    """
    last_index = len(xs) - 1
    chosen = randint(0, last_index)
    return xs[chosen]
|
||
|
||
def many_hot(items, num_classes, offset=0, limit=None):
    """Encode a sequence of class indices as exactly *limit* one-hot rows.

    Args:
        items: iterable of integer class ids, or of single characters (which
            are converted with ``ord``), e.g. a plain string.
        num_classes: width of every one-hot row.
        offset: value subtracted from each id to obtain the row index.
            Defaults to 0 so that callers without an offset keep working.
        limit: number of rows to return; ``None`` means ``max_word_length``
            (looked up at call time).

    Returns:
        A list of *limit* ``numpy`` vectors of length *num_classes*. Items
        beyond *limit* are dropped (with a printed warning); missing rows are
        all-zero padding.
    """
    if limit is None:
        limit = max_word_length
    codes = list(items)
    if len(codes) > limit:
        # Truncate up front; checking after the append let limit + 1 rows through.
        print("items > limit %d > %d" % (len(codes), limit))
        codes = codes[:limit]

    labels_many_hot = []
    for item in codes:
        index = (ord(item) if isinstance(item, str) else item) - offset
        labels_one_hot = numpy.zeros(num_classes)
        labels_one_hot[index] = 1
        labels_many_hot.append(labels_one_hot)

    # Pad with zero rows of the same width and type as the real rows.
    while len(labels_many_hot) < limit:
        labels_many_hot.append(numpy.zeros(num_classes))
    return labels_many_hot
|
||
|
||
# Word lists already read from disk, keyed by file path, so the dictionary
# file is loaded only once per process.
_WORD_CACHE = {}


def random_word():
    """Return a random entry from the system dictionary file.

    The file is read on the first call and cached afterwards; the handle is
    closed as soon as the content has been read.

    Raises:
        OSError: if the dictionary file cannot be opened.
    """
    # TODO: don't (just) use the dictionary, because we really want to OCR
    # passwords too.
    word_file = "/usr/share/dict/words"
    words = _WORD_CACHE.get(word_file)
    if words is None:
        with open(word_file) as handle:
            words = handle.read().splitlines()
        _WORD_CACHE[word_file] = words
    return pick(words)
|
||
|
||
class word(): | ||
|
||
|
||
def __init__(self, *margs, **args): # optional arguments | ||
if not args: | ||
if margs: args=margs[0] # ruby style hash args | ||
else:args={} | ||
# super(Argument, self).__init__(*margs, **args) | ||
# self.name = args['name'] if 'name' in args else None | ||
# self.family = args['family'] if 'family' in args else pick(families) | ||
self.font = args['font'] if 'font' in args else pick(letter.fontnames) | ||
self.size = args['size'] if 'size' in args else pick(letter.sizes) | ||
self.color= args['color'] if 'color' in args else 'black'#'white'#self.random_color() # #None #pick(range(-90, 180)) | ||
self.back = args['back'] if 'back' in args else letter.random_color() | ||
self.angle= args['angle'] if 'angle' in args else 0 #pick(range(-max_angle,max_angle)) | ||
self.pos = args['pos'] if 'pos' in args else {'x':pick(range(0,canvas_size)),'y':pick(range(0, canvas_size))} | ||
# self.style= args['style'] if 'style' in args else self.get_style(self.font)# pick(styles) | ||
self.invert = args['invert'] if 'invert' in args else pick([-1, 0, 1]) | ||
self.text = args['text'] if 'text' in args else random_word() | ||
# if chaotic: # captcha style (or syntax highlighting?) | ||
# self.letters=[letter.letter(args,char=char) for char in self.word] # almost java style ;) | ||
# else: one word, one style! | ||
|
||
def projection(self): | ||
return self.matrix(),self.ord | ||
|
||
def matrix(self, normed=true): | ||
# type: (bool) -> ndarray | ||
matrix = np.array(self.image()) | ||
if normed: matrix = matrix / 255. | ||
if self.invert == -1: | ||
matrix = 1 - 2 * matrix # -1..1 | ||
elif self.invert: | ||
matrix = 1 - matrix # 0..1 | ||
return matrix | ||
# except: | ||
# return np.array(max_size*(max_size+extra_y)) | ||
|
||
def image(self): | ||
ttf_font = self.load_font() | ||
padding = self.pos | ||
size = [canvas_size, canvas_size] | ||
if self.back: | ||
img = Image.new('RGBA', size, self.back) # background_color | ||
else: | ||
img = Image.new('L', size, 'white') # grey | ||
draw = ImageDraw.Draw(img) | ||
draw.text((padding['x'], padding['y']), self.text, font=ttf_font, fill=self.color) | ||
if self.angle: | ||
rot = img.rotate(self.angle, expand=1).resize(size) | ||
if self.back: | ||
img = Image.new('RGBA', size, self.back) # background_color | ||
else: | ||
img = Image.new('L', size,'#FFF')#FUCK BUG! 'white')#,'grey') # # grey | ||
img.paste(PIL.ImageOps.colorize(rot, (0, 0, 0),self.back ) (0, 0), rot) | ||
return img | ||
|
||
def load_font(self): | ||
fontPath = self.font if '/' in self.font else letter.fonts_dir + self.font | ||
try: | ||
fontPath = fontPath.strip() | ||
ttf_font = ImageFont.truetype(fontPath, self.size) | ||
except: | ||
raise Exception("BAD FONT: " + fontPath) | ||
return ttf_font | ||
|
||
def show(self): | ||
self.image().show() | ||
|
||
def __str__(self): | ||
format="text{'%s',size=%d,font='%s',position=%s}" # angle=%d, | ||
return format % (self.text, self.size, self.font, self.pos) | ||
|
||
def __repr__(self): | ||
return self.__str__() | ||
|
||
def save(self, path): | ||
self.image().save(path) | ||
|
||
|
||
# @classmethod # can access class cls | ||
# def ls(cls, mypath=None): | ||
|
||
# @staticmethod # CAN'T access class | ||
# def ls(mypath): | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
def show_matrix(mat):
    """Display the matrix *mat* in figure 1 and block until a key or mouse button is pressed."""
    plt.matshow(mat, fignum=1)
    # plt.imshow(image)
    plt.draw()
    plt.waitforbuttonpress()
|
||
|
||
def show_image(image):
    """Display *image* with matplotlib and block until a key or mouse button is pressed."""
    plt.imshow(image)
    plt.draw()
    plt.waitforbuttonpress()
|
||
|
||
if __name__ == "__main__":
    # Demo: keep generating random words, print their description and show the
    # rendered image until the user interrupts with Ctrl-C.
    while 1:
        # l = word(text='hello')
        w = word()
        # l.save("letters/letter_%s_%d.png"%(l.char,l.size))
        print(w)
        try:
            # show_matrix(mat)
            image = w.image()
            show_image(image)
            del(image)
            # mat = w.matrix()
            # print(np.average(mat))
            # print(np.max(mat))
            # print(np.min(mat))
            #
        except KeyboardInterrupt:
            print("HOW??")
            exit()
            # return
            # NOTE(review): nesting was reconstructed from a flattened listing;
            # the break is presumed to belong to this handler, where it is
            # unreachable after exit() - confirm against the original file.
            break
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#!/usr/bin/env python | ||
#!/usr/bin/python | ||
import layer | ||
import letter | ||
# import tensorflow as tf | ||
# import layer.baselines | ||
|
||
# layer.clear_tensorboard() # Get rid of old runs | ||
|
||
# Batch generator supplying the training data; its shape attribute gives the
# flat input and output widths.
data = letter.batch()
input_width, output_width = data.shape[0], data.shape[1]

# learning_rate = 0.03 # divergence even on overfit
# learning_rate = 0.003 # quicker overfit
learning_rate = 0.0003

nClasses = letter.nLetters
# NOTE(review): training_steps and batch_size are only referenced by the
# commented-out train() calls visible in this file.
training_steps = 500000
batch_size = 64
# Side length used to reshape the flat input into a square picture in denseConv.
size = letter.max_size
|
||
|
||
# OH, it does converge | ||
# Test Accuracy: ~0.875 Step 1.000.000 52148s | ||
def denseConv(net):
    # type: (layer.net) -> None
    """Describe the model on *net*: reshape the input to [-1, size, size, 1],
    add one dense-conv block, then attach the classifier."""
    print("Building dense-net")
    net.reshape(shape=[-1, size, size, 1])  # Reshape input picture
    net.buildDenseConv(nBlocks=1)
    net.classifier()  # 10 classes auto
|
||
|
||
""" Baseline tests to see that your model doesn't have any bugs and can learn small test sites without efforts """ | ||
|
||
# net = layer.net(layer.baseline, input_width=size, output_width=nClasses, learning_rate=learning_rate) | ||
# learning_rate: 0.003: full overfit at Step 800 | ||
# learning_rate: 0.0003: full overfit at Step 2400 | ||
|
||
# net = layer.net(layer.baselineDeep3, input_width=size, output_width=nClasses, learning_rate=learning_rate) | ||
# learning_rate: 0.003: overfit 98% at Step 5000 | ||
# learning_rate: 0.0003: full overfit at Step 24000 | ||
|
||
# net = layer.net(layer.baselineBatchNormDeep, input_width=size, output_width=nClasses, learning_rate=learning_rate) | ||
# learning_rate: 0.003: overfit 98% at Step 3000 ++ | ||
|
||
# net = layer.net(layer.baselineDenseConv, input_width=size, output_width=nClasses, learning_rate=learning_rate) | ||
# learning_rate: 0.003: overfit 98% at Step 3000 ++ | ||
|
||
# alex = broken baseline! lol, how? | ||
# net = layer.net(layer.alex, input_width=size, output_width=nClasses, learning_rate=.001) | ||
|
||
# net.train(data=data, test_step=1000) # run | ||
|
||
""" here comes the real network """ | ||
|
||
# net=layer.net(alex,input_width=28, output_width=nClasses, learning_rate=learning_rate) # NOPE!? | ||
# Build the network from the denseConv model description defined above.
net = layer.net(denseConv, input_width=size, output_width=nClasses, learning_rate=learning_rate)

# net.train(data=data,steps=50000,dropout=0.6,display_step=1,test_step=1) # debug
# net.train(data=data, steps=training_steps,dropout=0.6,display_step=5,test_step=20) # test
# Active configuration: no explicit step count is passed here.
net.train(data=data, dropout=.6, display_step=10, test_step=1000)  # run resume
|
||
# net.predict() # nil=random | ||
# net.generate(3) # nil=random |