Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
foulwall
committed
Jul 18, 2013
1 parent
41158b6
commit 20cd8b9
Showing
12 changed files
with
317 additions
and
18 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# File : $HeadURL$ | ||
# Version: $Id$ | ||
|
||
from modshogun import RealFeatures, MulticlassLabels | ||
from modshogun import GaussianKernel | ||
from modshogun import GMNPSVM | ||
|
||
import numpy as np | ||
import gzip as gz | ||
import pickle as pkl | ||
|
||
TRAIN_SVM_FNAME_GZ = "data/ocr.svm.gz" | ||
|
||
NEAR_ZERO_POS = 1e-8 | ||
NEAR_ONE_NEG = 1-NEAR_ZERO_POS | ||
|
||
TRAIN_X_FNAME = "data/train_data_x.asc.gz" | ||
TRAIN_Y_FNAME = "data/train_data_y.asc.gz" | ||
|
||
MATIX_IMAGE_SIZE = 16 | ||
FEATURE_DIM = MATIX_IMAGE_SIZE * MATIX_IMAGE_SIZE | ||
|
||
HISTORY_WIDTH = 5 | ||
HISTORY_HEIGHT = 2 | ||
|
||
FEATURE_RANGE_MAX = 1.0 | ||
|
||
|
||
class Ai: | ||
def __init__(self): | ||
self.x = None | ||
self.y = None | ||
|
||
self.x_test = None | ||
self.y_test = None | ||
|
||
self.svm = None | ||
|
||
def load_train_data(self, x_fname, y_fname): | ||
Ai.__init__(self) | ||
|
||
self.x = np.loadtxt(x_fname) | ||
self.y = np.loadtxt(y_fname) - 1.0 | ||
|
||
self.x_test = self.x | ||
self.y_test = self.y | ||
|
||
def _svm_new(self, kernel_width, c, epsilon): | ||
if self.x == None or self.y == None: | ||
raise Exception("No training data loaded.") | ||
|
||
x = RealFeatures(self.x) | ||
y = MulticlassLabels(self.y) | ||
|
||
self.svm = GMNPSVM(c, GaussianKernel(x, x, kernel_width), y) | ||
self.svm.set_epsilon(epsilon) | ||
|
||
def write_svm(self): | ||
gz_stream = gz.open(TRAIN_SVM_FNAME_GZ, 'wb', 9) | ||
pkl.dump(self.svm, gz_stream) | ||
gz_stream.close() | ||
|
||
def read_svm(self): | ||
gz_stream = gz.open(TRAIN_SVM_FNAME_GZ, 'rb') | ||
self.svm = pkl.load(gz_stream) | ||
gz_stream.close() | ||
|
||
def enable_validation(self, train_frac): | ||
x = self.x | ||
y = self.y | ||
|
||
idx = np.arange(len(y)) | ||
np.random.shuffle(idx) | ||
train_idx=idx[:np.floor(train_frac*len(y))] | ||
test_idx=idx[np.ceil(train_frac*len(y)):] | ||
|
||
self.x = x[:,train_idx] | ||
self.y = y[train_idx] | ||
self.x_test = x[:,test_idx] | ||
self.y_test = y[test_idx] | ||
|
||
def train(self, kernel_width, c, epsilon): | ||
self._svm_new(kernel_width, c, epsilon) | ||
|
||
x = RealFeatures(self.x) | ||
self.svm.io.enable_progress() | ||
self.svm.train(x) | ||
self.svm.io.disable_progress() | ||
|
||
def load_classifier(self): self.read_svm() | ||
|
||
def classify(self, matrix): | ||
cl = self.svm.apply( | ||
RealFeatures( | ||
np.reshape(matrix, newshape=(FEATURE_DIM, 1), | ||
order='F') | ||
) | ||
).get_label(0) | ||
|
||
return int(cl + 1.0) % 10 | ||
|
||
def get_test_error(self): | ||
self.svm.io.enable_progress() | ||
l = self.svm.apply(RealFeatures(self.x_test)).get_labels() | ||
self.svm.io.disable_progress() | ||
|
||
return 1.0 - np.mean(l == self.y_test) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from Ai import Ai | ||
ai = Ai() | ||
ai.load_classifier() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,48 @@ | ||
{% extends "default.html" %} | ||
{% block javascript %} | ||
<script> | ||
function recognize_action(){ | ||
$.ajax( | ||
{ | ||
url: 'recognize', | ||
type: 'POST', | ||
dataType: "text", | ||
data: { | ||
'csrfmiddlewaretoken': '{{ csrf_token }}', | ||
'lines': JSON.stringify(lines)}, | ||
success: show_digit, | ||
}); | ||
} | ||
function show_digit(data) | ||
{ | ||
data = JSON.parse(data); | ||
preview_svg.selectAll(".preview_blocks") | ||
.data(data['thumb']).enter() | ||
.append('g') | ||
.selectAll(".preview_blocks") | ||
.data(Object) | ||
.enter() | ||
.append('rect') | ||
.attr("class", "preview_blocks") | ||
.attr("fill", function(d) { if (d) return "black"; else return "none";}) | ||
.attr("x", function(d,i,j){return preview_x(i);}) | ||
.attr("y", function(d,i,j){return preview_y(j);}) | ||
.attr("width", preview_x(1)-preview_x(0)) | ||
.attr("height", preview_y(1)-preview_y(0)); | ||
if (!$("#digit").length) | ||
{ | ||
var digit = document.createElement("div"); | ||
$(".span3").append(digit); | ||
digit.id = "digit"; | ||
} | ||
$("#digit").html("<hr><p>The predict is</p><h2 style='text-align:center;'>" + data['predict'] + "</h2>"); | ||
} | ||
function clear_action() | ||
{ | ||
d3.selectAll(".preview_blocks").remove(); | ||
d3.selectAll(".drew").remove(); | ||
d3.selectAll("#digit").remove(); | ||
lines = []; | ||
} | ||
</script> | ||
{% endblock %} |
Oops, something went wrong.