Skip to content

Commit

Permalink
perfect ocr
Browse files Browse the repository at this point in the history
  • Loading branch information
foulwall committed Jul 18, 2013
1 parent 41158b6 commit 20cd8b9
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 18 deletions.
Binary file added data/ocr.svm.gz
Binary file not shown.
107 changes: 107 additions & 0 deletions demo/Ai.py
@@ -0,0 +1,107 @@
# File : $HeadURL$
# Version: $Id$

from modshogun import RealFeatures, MulticlassLabels
from modshogun import GaussianKernel
from modshogun import GMNPSVM

import numpy as np
import gzip as gz
import pickle as pkl

TRAIN_SVM_FNAME_GZ = "data/ocr.svm.gz"

NEAR_ZERO_POS = 1e-8
NEAR_ONE_NEG = 1-NEAR_ZERO_POS

TRAIN_X_FNAME = "data/train_data_x.asc.gz"
TRAIN_Y_FNAME = "data/train_data_y.asc.gz"

MATIX_IMAGE_SIZE = 16
FEATURE_DIM = MATIX_IMAGE_SIZE * MATIX_IMAGE_SIZE

HISTORY_WIDTH = 5
HISTORY_HEIGHT = 2

FEATURE_RANGE_MAX = 1.0


class Ai:
def __init__(self):
self.x = None
self.y = None

self.x_test = None
self.y_test = None

self.svm = None

def load_train_data(self, x_fname, y_fname):
Ai.__init__(self)

self.x = np.loadtxt(x_fname)
self.y = np.loadtxt(y_fname) - 1.0

self.x_test = self.x
self.y_test = self.y

def _svm_new(self, kernel_width, c, epsilon):
if self.x == None or self.y == None:
raise Exception("No training data loaded.")

x = RealFeatures(self.x)
y = MulticlassLabels(self.y)

self.svm = GMNPSVM(c, GaussianKernel(x, x, kernel_width), y)
self.svm.set_epsilon(epsilon)

def write_svm(self):
gz_stream = gz.open(TRAIN_SVM_FNAME_GZ, 'wb', 9)
pkl.dump(self.svm, gz_stream)
gz_stream.close()

def read_svm(self):
gz_stream = gz.open(TRAIN_SVM_FNAME_GZ, 'rb')
self.svm = pkl.load(gz_stream)
gz_stream.close()

def enable_validation(self, train_frac):
x = self.x
y = self.y

idx = np.arange(len(y))
np.random.shuffle(idx)
train_idx=idx[:np.floor(train_frac*len(y))]
test_idx=idx[np.ceil(train_frac*len(y)):]

self.x = x[:,train_idx]
self.y = y[train_idx]
self.x_test = x[:,test_idx]
self.y_test = y[test_idx]

def train(self, kernel_width, c, epsilon):
self._svm_new(kernel_width, c, epsilon)

x = RealFeatures(self.x)
self.svm.io.enable_progress()
self.svm.train(x)
self.svm.io.disable_progress()

def load_classifier(self): self.read_svm()

def classify(self, matrix):
cl = self.svm.apply(
RealFeatures(
np.reshape(matrix, newshape=(FEATURE_DIM, 1),
order='F')
)
).get_label(0)

return int(cl + 1.0) % 10

def get_test_error(self):
self.svm.io.enable_progress()
l = self.svm.apply(RealFeatures(self.x_test)).get_labels()
self.svm.io.disable_progress()

return 1.0 - np.mean(l == self.y_test)
3 changes: 3 additions & 0 deletions demo/__init__.py
@@ -0,0 +1,3 @@
from Ai import Ai
ai = Ai()
ai.load_classifier()
75 changes: 73 additions & 2 deletions demo/ocr.py
Expand Up @@ -2,18 +2,89 @@
from django.template import RequestContext
from django.shortcuts import render_to_response

import demo
import modshogun as sg
import numpy as np
import json

MATIX_IMAGE_SIZE = 16
FEATURE_RANGE_MAX = 1.0
NEAR_ZERO_POS = 1e-8
NEAR_ONE_NEG = 1-NEAR_ZERO_POS

def entrance(request):
properties = { 'title' : 'Digit Recognize',
'template': {'type': 'drawing'},
'panels': [
{
'panel_name': 'result',
'panel_label': 'Digit'}]}
'panel_name': 'preview',
'panel_label': 'Preview'}]}
return render_to_response("ocr/index.html",
properties,
context_instance = RequestContext(request))

def _draw_line(image, start, end):
start = np.array(start, dtype=np.int)
end = np.array(end, dtype=np.int)

delta = abs(end - start)

e = delta[0]/2.0
x, y = start
image[y, x] = FEATURE_RANGE_MAX
while np.any((x, y) != end):
if e < 0.0 or x == end[0]:
y += -1 if start[1] > end[1] else 1
e += delta[0]
if e >= 0.0 and x != end[0]:
x += -1 if start[0] > end[0] else 1
e -= delta[1]
image[y, x] = FEATURE_RANGE_MAX

def _get_coords(data):
result = map(lambda line: np.array(line), data)
result = map(lambda line: np.transpose(line), result)
minx = 2.0
miny = 2.0
for line in result:
minx = min(minx, min(line[0]))
miny = min(miny, min(line[1]))
for line in result:
line[0] -= minx
line[1] -= miny

maxxy = 0.0
for line in result: maxxy = max(maxxy, line.max())
for line in result: line /= maxxy + NEAR_ZERO_POS

maxx = 0.0
maxy = 0.0
for line in result:
maxx = max(maxx, max(line[0]))
maxy = max(maxy, max(line[1]))
for line in result:
line[0] += (1 - maxx)/2
line[1] += (1 - maxy)/2

result = map(lambda line: np.transpose(line), result)
return result

def recognize(request):
image = np.zeros((16, 16), dtype=np.float)
try:
data = json.loads(request.POST['lines'])
coords = map(lambda line: MATIX_IMAGE_SIZE*line,
_get_coords(data))
image = np.zeros((MATIX_IMAGE_SIZE, MATIX_IMAGE_SIZE),
dtype=np.float)
for line in coords:
for i in range(line.shape[0]-1):
_draw_line(image, line[i], line[i+1])
digit = demo.ai.classify(image)
return HttpResponse(json.dumps({'predict': digit,
'thumb': image.tolist()}))
except:
raise Http404

print image

1 change: 1 addition & 0 deletions shogun_demo/urls.py
Expand Up @@ -29,6 +29,7 @@
url(r'^kernel_matrix/entrance', 'demo.kernel_matrix.entrance'),
url(r'^kernel_matrix/generate', 'demo.kernel_matrix.generate'),
url(r'^ocr/entrance', 'demo.ocr.entrance'),
url(r'^ocr/recognize', 'demo.ocr.recognize'),
url(r'^toy_data/generator/generate', 'toy_data.generator.generate'),
url(r'^toy_data/importer/dump', 'toy_data.importer.dump'),

Expand Down
1 change: 1 addition & 0 deletions templates/coordinate_2dims.js
Expand Up @@ -98,6 +98,7 @@ var start = d3.svg.line()
.x(function (d) {return x(d.x); })
.y(function (d) {return y(0); })
.interpolate('basis');

{% if template.heatmap %}
var heatmap_legend = document.createElement("div");
$('.span9').append(heatmap_legend);
Expand Down
2 changes: 2 additions & 0 deletions templates/default.html
Expand Up @@ -31,6 +31,8 @@
{% include "toy_data_generator_form.html" %}
{% elif panel.panel_name == 'arguments' %}
{% include "arguments_form.html" %}
{% elif panel.panel_name == 'preview' %}
{% include "preview.html" %}
{% endif %}
</div>
{% endfor %}
Expand Down
7 changes: 5 additions & 2 deletions templates/drawing.js
Expand Up @@ -2,8 +2,8 @@ var margin = {top: 20, right: 20, bottom: 30, left: 40};
var width = 660 - margin.left - margin.right;
var height = 610 - margin.top - margin.bottom;

var x = d3.scale.linear().range( [0, width] );
var y = d3.scale.linear().range( [height, 0] );
var x = d3.scale.linear().range( [0, width] ).domain([0,1]);
var y = d3.scale.linear().range( [0, height] ).domain([0,1]);

var xGrid = d3.svg.axis().scale(x).orient("bottom");
var yGrid = d3.svg.axis().scale(y).orient("left");
Expand All @@ -20,14 +20,17 @@ svg.append("g")
.attr("transform", "translate(0, " + height + ")")
.call(xGrid
.tickSize(-height, 0, 0)
.ticks(20)
.tickFormat("")
);
svg.append("g")
.attr("class", "grid")
.attr("id", "y_grid")
.call(yGrid
.tickSize(-width, 0, 0)
.ticks(20)
.tickFormat("")
);


{% include "mouse_click.js" %}
4 changes: 2 additions & 2 deletions templates/kernel_matrix/index.html
Expand Up @@ -17,7 +17,7 @@
return;
}

var z = data['z'];
z = data['z'];
var domain = data['domain'];
var minimum = domain[0];
var maximum = domain[1];
Expand All @@ -35,7 +35,7 @@
svg.selectAll(".heatmap")
.data(z).enter()
.append('g')
.selectAll('.box')
.selectAll('.heatmap')
.data(Object)
.enter()
.append('rect')
Expand Down
49 changes: 37 additions & 12 deletions templates/mouse_click.js
Expand Up @@ -64,33 +64,58 @@ d3.select("svg").node().oncontextmenu = function(){return false;}; //disable rig
{% endif %}
{% elif template.type == 'drawing' %}
var pressed = false;
var line_dots=[];
var last_dot=[];
var lines = [];
canvas_div
.on("mousedown", mouse_down)
.on("mousemove", mouse_move)
.on("mouseup", mouse_up);

function mouse_move(event) {
function mouse_move() {
if(pressed)
{
if (d3.mouse(this)[0]-margin.left < 0 || d3.mouse(this)[1]-margin.top > height)
if (d3.mouse(this)[0]-margin.left <= 0
|| d3.mouse(this)[1]-margin.top >= height
|| d3.mouse(this)[0]-margin.left >= width
|| d3.mouse(this)[1]-margin.top <= 0)
{
mouse_up();
return;
}
var point = d3.mouse(this);
var e = window.event || d3.event;
if(e.button == 2 || e.button == 3)
return;
point[0]-=margin.left;
point[1]-=margin.top;
svg.append("circle")
.attr("class", "dot")
.attr("r", 6)
.attr("cx", point[0])
.attr("cy", point[1]);
if (point.toString() != last_dot.toString())
{
line_dots.push([x.invert(point[0]), y.invert(point[1])]);
last_dot = point.concat();
}

var line = d3.svg.line()
.x(function(d) {return x(d[0]);})
.y(function(d) {return y(d[1]);})
.interpolate('basis');
svg.selectAll(".drawing").remove();
svg.append("path")
.attr("class", "drawing")
.attr("d",line(line_dots))
.style("stroke-width", "30")
.style("stroke", "green")
.style("stroke-linecap","round")
.style("fill", "transparent");
}
}
function mouse_down(event){
function mouse_down(){
pressed = true;
}
function mouse_up(event){
function mouse_up(){
svg.selectAll(".drawing")
.attr("class", "drew");
if(line_dots.length)
lines.push(line_dots);
line_dots = [];

pressed = false;
}

Expand Down
47 changes: 47 additions & 0 deletions templates/ocr/index.html
@@ -1 +1,48 @@
{% extends "default.html" %}
{% block javascript %}
<script>
function recognize_action(){
$.ajax(
{
url: 'recognize',
type: 'POST',
dataType: "text",
data: {
'csrfmiddlewaretoken': '{{ csrf_token }}',
'lines': JSON.stringify(lines)},
success: show_digit,
});
}
function show_digit(data)
{
data = JSON.parse(data);
preview_svg.selectAll(".preview_blocks")
.data(data['thumb']).enter()
.append('g')
.selectAll(".preview_blocks")
.data(Object)
.enter()
.append('rect')
.attr("class", "preview_blocks")
.attr("fill", function(d) { if (d) return "black"; else return "none";})
.attr("x", function(d,i,j){return preview_x(i);})
.attr("y", function(d,i,j){return preview_y(j);})
.attr("width", preview_x(1)-preview_x(0))
.attr("height", preview_y(1)-preview_y(0));
if (!$("#digit").length)
{
var digit = document.createElement("div");
$(".span3").append(digit);
digit.id = "digit";
}
$("#digit").html("<hr><p>The predict is</p><h2 style='text-align:center;'>" + data['predict'] + "</h2>");
}
function clear_action()
{
d3.selectAll(".preview_blocks").remove();
d3.selectAll(".drew").remove();
d3.selectAll("#digit").remove();
lines = [];
}
</script>
{% endblock %}

0 comments on commit 20cd8b9

Please sign in to comment.