Permalink
Browse files

make snakes example faster with opengm

  • Loading branch information...
amueller committed Sep 10, 2013
1 parent fe8c030 commit f382630a35264e124e7b4fcc7da619a62709a22b
Showing with 11 additions and 6 deletions.
  1. +11 −6 examples/plot_snakes.py
View
@@ -40,6 +40,7 @@
from pystruct.datasets import load_snakes
from pystruct.utils import make_grid_edges, edge_list_to_features
from pystruct.models import EdgeFeatureGraphCRF
+from pystruct.inference import get_installed
def one_hot_colors(x):
@@ -101,11 +102,15 @@ def main():
Y_train_flat = [y_.ravel() for y_ in Y_train]
X_train_directions, X_train_edge_features = prepare_data(X_train)
-
+
+ if 'ogm' in get_installed():
+ inference = ('ogm', {'alg': 'fm'})
+ else:
+ inference = 'qpbo'
# first, train on X with directions only:
- crf = EdgeFeatureGraphCRF(inference_method='qpbo')
- ssvm = OneSlackSSVM(crf, inference_cache=50, C=.1, tol=.1, switch_to='ad3',
- n_jobs=-1)
+ crf = EdgeFeatureGraphCRF(inference_method=inference)
+ ssvm = OneSlackSSVM(crf, inference_cache=50, C=.1, tol=.1, max_iter=100,
+ n_jobs=1)
ssvm.fit(X_train_directions, Y_train_flat)
# Evaluate using confusion matrix.
@@ -121,9 +126,9 @@ def main():
print(confusion_matrix(np.hstack(Y_test_flat), np.hstack(Y_pred)))
# now, use more informative edge features:
- crf = EdgeFeatureGraphCRF(inference_method='qpbo')
+ crf = EdgeFeatureGraphCRF(inference_method=inference)
ssvm = OneSlackSSVM(crf, inference_cache=50, C=.1, tol=.1, switch_to='ad3',
- n_jobs=-1)
+ n_jobs=1)
ssvm.fit(X_train_edge_features, Y_train_flat)
Y_pred2 = ssvm.predict(X_test_edge_features)
print("Results using also input features for edges")

0 comments on commit f382630

Please sign in to comment.