diff --git a/applications/classification/evaluate_multiclass_labels.py b/applications/classification/evaluate_multiclass_labels.py index ce09fb51278..f11d99f25d2 100644 --- a/applications/classification/evaluate_multiclass_labels.py +++ b/applications/classification/evaluate_multiclass_labels.py @@ -31,9 +31,10 @@ import argparse import logging +import numpy as np from contextlib import contextmanager, closing from modshogun import (LibSVMFile, MulticlassLabels, - SerializableHdf5File, MulticlassAccuracy) + MulticlassAccuracy) from utils import get_features_and_labels LOGGER = logging.getLogger(__file__) @@ -55,10 +56,8 @@ def main(actual, predicted): feats, labels = get_features_and_labels(LibSVMFile(actual)) # Load predicted labels - predicted_labels = MulticlassLabels() - predicted_labels_file = SerializableHdf5File(predicted, 'r') - with closing(predicted_labels_file): - predicted_labels.load_serializable(predicted_labels_file) + with open(predicted, 'r') as f: + predicted_labels = MulticlassLabels(np.array([float(l) for l in f])) multiclass_measures = MulticlassAccuracy() LOGGER.info("Accuracy = %s" % multiclass_measures.evaluate( diff --git a/applications/classification/test_multiclass_svm.py b/applications/classification/test_multiclass_svm.py index 4f4834ca19c..8772a233ab2 100644 --- a/applications/classification/test_multiclass_svm.py +++ b/applications/classification/test_multiclass_svm.py @@ -62,9 +62,10 @@ def main(classifier, testset, output): test_feats, test_labels = get_features_and_labels(LibSVMFile(testset)) predicted_labels = svm.apply(test_feats) - predicted_labels_output = SerializableHdf5File(output, 'w') - with closing(predicted_labels_output): - predicted_labels.save_serializable(predicted_labels_output) + with open(output, 'w') as f: + for cls in predicted_labels.get_labels(): + f.write("%s\n" % int(cls)) + LOGGER.info("Predicted labels saved in: '%s'" % output)