Skip to content

Commit

Permalink
saving and loading from svmlight (no hdf5)
Browse files Browse the repository at this point in the history
  • Loading branch information
PirosB3 committed Mar 11, 2014
1 parent f0a94fb commit 9af5447
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
9 changes: 4 additions & 5 deletions applications/classification/evaluate_multiclass_labels.py
Expand Up @@ -31,9 +31,10 @@

import argparse
import logging
import numpy as np
from contextlib import contextmanager, closing
from modshogun import (LibSVMFile, MulticlassLabels,
SerializableHdf5File, MulticlassAccuracy)
MulticlassAccuracy)
from utils import get_features_and_labels

LOGGER = logging.getLogger(__file__)
Expand All @@ -55,10 +56,8 @@ def main(actual, predicted):
feats, labels = get_features_and_labels(LibSVMFile(actual))

# Load predicted labels
predicted_labels = MulticlassLabels()
predicted_labels_file = SerializableHdf5File(predicted, 'r')
with closing(predicted_labels_file):
predicted_labels.load_serializable(predicted_labels_file)
with open(predicted, 'r') as f:
predicted_labels = MulticlassLabels(np.array([float(l) for l in f]))

multiclass_measures = MulticlassAccuracy()
LOGGER.info("Accuracy = %s" % multiclass_measures.evaluate(
Expand Down
7 changes: 4 additions & 3 deletions applications/classification/test_multiclass_svm.py
Expand Up @@ -62,9 +62,10 @@ def main(classifier, testset, output):
test_feats, test_labels = get_features_and_labels(LibSVMFile(testset))
predicted_labels = svm.apply(test_feats)

predicted_labels_output = SerializableHdf5File(output, 'w')
with closing(predicted_labels_output):
predicted_labels.save_serializable(predicted_labels_output)
with open(output, 'w') as f:
for cls in predicted_labels.get_labels():
f.write("%s\n" % int(cls))

LOGGER.info("Predicted labels saved in: '%s'" % output)


Expand Down

0 comments on commit 9af5447

Please sign in to comment.