
Commit b30b8ea

Merge branch 'develop' of https://github.com/shogun-toolbox/shogun into develop

votjakovr committed Apr 30, 2013
2 parents bf4df78 + c7a100a
Showing 4 changed files with 73 additions and 32 deletions.
@@ -29,6 +29,7 @@ def serialization_complex_example (num=5, dist=1, dim=10, C=2.0, width=10):

feats=RealFeatures(data)
#feats.io.set_loglevel(MSG_DEBUG)
#feats.io.enable_file_and_line()
kernel=GaussianKernel(feats, feats, width)

labels=MulticlassLabels(lab)
@@ -50,9 +51,9 @@ def serialization_complex_example (num=5, dist=1, dim=10, C=2.0, width=10):
status = svm.save_serializable(fstream)
check_status(status,'asc')

#fstream = SerializableJsonFile("blaah.json", "w")
#status = svm.save_serializable(fstream)
#check_status(status,'json')
fstream = SerializableJsonFile("blaah.json", "w")
status = svm.save_serializable(fstream)
check_status(status,'json')

fstream = SerializableXmlFile("blaah.xml", "w")
status = svm.save_serializable(fstream)
@@ -71,11 +72,11 @@ def serialization_complex_example (num=5, dist=1, dim=10, C=2.0, width=10):
check_status(status,'asc')
new_svm.train()

#fstream = SerializableJsonFile("blaah.json", "r")
#new_svm=GMNPSVM()
#status = new_svm.load_serializable(fstream)
#check_status(status,'json')
#new_svm.train()
fstream = SerializableJsonFile("blaah.json", "r")
new_svm=GMNPSVM()
status = new_svm.load_serializable(fstream)
check_status(status,'json')
new_svm.train()

fstream = SerializableXmlFile("blaah.xml", "r")
new_svm=GMNPSVM()
@@ -85,7 +86,7 @@ def serialization_complex_example (num=5, dist=1, dim=10, C=2.0, width=10):

os.unlink("blaah.h5")
os.unlink("blaah.asc")
#os.unlink("blaah.json")
os.unlink("blaah.json")
os.unlink("blaah.xml")
return svm,new_svm
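
For context, the JSON path re-enabled above follows the same save/load pattern as the HDF5, ascii, and XML streams already exercised by this example. Below is a minimal, hypothetical round-trip sketch of that pattern, assuming the modular Python interface used here (modshogun exporting SerializableJsonFile and GMNPSVM) and an already-trained svm; it illustrates the pattern and is not part of the commit.

# Hypothetical sketch: JSON save/load round trip, mirroring the example above.
# Assumes `svm` is a trained GMNPSVM from the modular interface (modshogun).
from modshogun import SerializableJsonFile, GMNPSVM

def json_roundtrip(svm, fname="blaah.json"):
    # write the trained model to JSON
    fstream = SerializableJsonFile(fname, "w")
    if not svm.save_serializable(fstream):
        raise RuntimeError("saving '%s' failed" % fname)

    # read it back into a fresh classifier and retrain, as the example does
    fstream = SerializableJsonFile(fname, "r")
    new_svm = GMNPSVM()
    if not new_svm.load_serializable(fstream):
        raise RuntimeError("loading '%s' failed" % fname)
    new_svm.train()
    return new_svm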

src/shogun/io/SerializableJsonFile.cpp: 30 additions & 16 deletions
@@ -123,9 +123,8 @@ CSerializableJsonFile::close()

if (m_stack_stream.get_num_elements() == 1) {
if (m_task == 'w'
&& is_error(
json_object_to_file(m_filename, m_stack_stream.back())
)) {
&& json_object_to_file(m_filename, m_stack_stream.back()))
{
SG_WARNING("Could not close file `%s' for writing!\n",
m_filename);
}
@@ -193,7 +192,8 @@ CSerializableJsonFile::write_scalar_wrapped(
return false;
}

if (is_error(m_stack_stream.back())) return false;
if (is_error(m_stack_stream.back()))
return false;

return true;
}
@@ -248,8 +248,8 @@ CSerializableJsonFile::write_stringentry_end_wrapped(
json_object* array = m_stack_stream.get_element(
m_stack_stream.get_num_elements() - 2);

if (is_error(json_object_array_put_idx(
array, y, m_stack_stream.back()))) return false;
if (json_object_array_put_idx(array, y, m_stack_stream.back()))
return false;

pop_object();
return true;
@@ -262,7 +262,9 @@ CSerializableJsonFile::write_sparse_begin_wrapped(
push_object(json_object_new_object());

json_object* buf = json_object_new_array();
if (is_error(buf)) return false;
if (is_error(buf))
return false;

json_object_object_add(m_stack_stream.back(),
STR_KEY_SPARSE_FEATURES, buf);

@@ -284,12 +286,15 @@ CSerializableJsonFile::write_sparseentry_begin_wrapped(
index_t feat_index, index_t y)
{
json_object* buf = json_object_new_object();
if (is_error(json_object_array_put_idx(m_stack_stream.back(), y,
buf))) return false;
if (json_object_array_put_idx(m_stack_stream.back(), y, buf))
return false;

push_object(buf);

buf = json_object_new_int(feat_index);
if (is_error(buf)) return false;
if (is_error(buf))
return false;

json_object_object_add(m_stack_stream.back(),
STR_KEY_SPARSE_FEATINDEX, buf);

@@ -340,28 +345,34 @@ CSerializableJsonFile::write_sgserializable_begin_wrapped(
EPrimitiveType generic)
{
if (*sgserializable_name == '\0') {
push_object(NULL); return true;
push_object(NULL);
return true;
}

push_object(json_object_new_object());

json_object* buf;
buf = json_object_new_string(sgserializable_name);
if (is_error(buf)) return false;
if (is_error(buf))
return false;

json_object_object_add(m_stack_stream.back(),
STR_KEY_INSTANCE_NAME, buf);

if (generic != PT_NOT_GENERIC) {
string_t buf_str;
TSGDataType::ptype_to_string(buf_str, generic, STRING_LEN);
buf = json_object_new_string(buf_str);
if (is_error(buf)) return false;
if (is_error(buf))
return false;

json_object_object_add(m_stack_stream.back(),
STR_KEY_GENERIC_NAME, buf);
}

buf = json_object_new_object();
if (is_error(buf)) return false;
if (is_error(buf))
return false;
json_object_object_add(m_stack_stream.back(), STR_KEY_INSTANCE,
buf);
push_object(buf);
@@ -385,15 +396,18 @@ CSerializableJsonFile::write_type_begin_wrapped(
const TSGDataType* type, const char* name, const char* prefix)
{
json_object* buf = json_object_new_object();
if (is_error(buf)) return false;
if (is_error(buf))
return false;

json_object_object_add(m_stack_stream.back(), name, buf);
push_object(buf);

string_t str_buf;
type->to_string(str_buf, STRING_LEN);
buf = json_object_new_string(str_buf);
if (is_error(buf)) return false;
if (is_error(buf))
return false;

json_object_object_add(m_stack_stream.back(), STR_KEY_TYPE, buf);

return true;
src/shogun/io/SerializableJsonReader00.cpp: 2 additions & 2 deletions
@@ -304,8 +304,8 @@ SerializableJsonReader00::read_type_begin_wrapped(
if (strcmp(str_buf, json_object_get_string(buf)) != 0)
return false;

if (!m_file->get_object_any(&buf, buf_type, STR_KEY_DATA))
return false;
// data (and so buf) can be NULL for empty objects
m_file->get_object_any(&buf, buf_type, STR_KEY_DATA);
m_file->push_object(buf);

return true;
tests/integration/python_modular/tester.py: 31 additions & 5 deletions
@@ -6,6 +6,7 @@
import filecmp
import numpy
import sys
import difflib

from generator import setup_tests, get_fname, blacklist, get_test_mod, run_test

@@ -84,6 +85,23 @@ def compare_dbg_helper(a, b, tolerance):
print "b", b
return False

def get_fail_string(a):
failed_string = []
if type(a) in (tuple,list):
for i in xrange(len(a)):
failed_string.append(get_fail_string(a[i]))
elif isinstance(a, modshogun.SGObject):
failed_string.append(pickle.dumps(a))
else:
failed_string.append(str(a))
return failed_string

def get_split_string(a):
strs=[]
for l in a:
strs.extend(l[0].replace('\\n','\n').splitlines())
return strs

def tester(tests, cmp_method, tolerance, failures, missing):
failed=[]

@@ -112,10 +130,7 @@ def tester(tests, cmp_method, tolerance, failures, missing):
print "%-60s OK" % setting_str
else:
if not missing:
failed_string = []
failed_string.append(a)
failed_string.append(b)
failed.append((setting_str, failed_string))
failed.append((setting_str, get_fail_string(a), get_fail_string(b)))
print "%-60s ERROR" % setting_str
except Exception, e:
print setting_str, e
@@ -154,6 +169,17 @@ def tester(tests, cmp_method, tolerance, failures, missing):
print
print "The following tests failed!"
for f in failed:
print "\t", f
print "\t", f[0]
expected=get_split_string(f[1])
got=get_split_string(f[2])
print "=== EXPECTED =========="
#import pdb
#pdb.set_trace()
print '\n'.join(expected)
print "=== GOT ==============="
print '\n'.join(got)
print "====DIFF================"
print '\n'.join(difflib.unified_diff(expected, got, fromfile='expected', tofile='got'))
print "====EOT================"
sys.exit(1)
sys.exit(0)
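
The failure report added above reduces to three steps: flatten the expected and obtained results into lists of lines (get_fail_string plus get_split_string), print both, and let difflib.unified_diff pinpoint where they diverge. A minimal standalone sketch of that last reporting step, with made-up line lists and Python 2 style prints to match tester.py:

# Standalone sketch of the new failure report: two line lists are compared
# with difflib.unified_diff, just as tester.py now does per failed setting.
# The expected/got values below are illustrative only.
import difflib

expected = ["accuracy 0.95", "labels [1 2 1]"]
got      = ["accuracy 0.90", "labels [1 2 2]"]

print "=== EXPECTED =========="
print '\n'.join(expected)
print "=== GOT ==============="
print '\n'.join(got)
print "====DIFF================"
print '\n'.join(difflib.unified_diff(expected, got, fromfile='expected', tofile='got'))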
