Skip to content

Commit

Permalink
Temporarily allows to set algorithm switching through an environment …
Browse files Browse the repository at this point in the history
…variable
  • Loading branch information
arjoly committed Nov 6, 2014
1 parent c31d565 commit bf98916
Show file tree
Hide file tree
Showing 3 changed files with 2,927 additions and 2,647 deletions.
42 changes: 21 additions & 21 deletions benchmarks/bench_20newsgroups.py
Expand Up @@ -15,12 +15,12 @@

ESTIMATORS = {
"dummy": DummyClassifier(),
"random_forest": RandomForestClassifier(n_estimators=200,
max_features="log2",
"random_forest": RandomForestClassifier(n_estimators=100,
max_features="sqrt",
# min_samples_split=10
),
"extra_trees": ExtraTreesClassifier(n_estimators=200,
max_features="log2",
"extra_trees": ExtraTreesClassifier(n_estimators=100,
max_features="sqrt",
# min_samples_split=10
),
}
Expand All @@ -44,21 +44,21 @@
y_train = data_train.target
y_test = data_test.target

print("20 newsgroups")
print("=============")
print("X_train.shape = {0}".format(X_train.shape))
print("X_train.format = {0}".format(X_train.format))
print("X_train.dtype = {0}".format(X_train.dtype))
print("X_train density = {0}"
"".format(X_train.nnz / np.product(X_train.shape)))
print("y_train {0}".format(y_train.shape))
print("X_test {0}".format(X_test.shape))
print("X_test.format = {0}".format(X_test.format))
print("X_test.dtype = {0}".format(X_test.dtype))
print("y_test {0}".format(y_test.shape))
print()
print("Classifier Training")
print("===================")
# print("20 newsgroups")
# print("=============")
# print("X_train.shape = {0}".format(X_train.shape))
# print("X_train.format = {0}".format(X_train.format))
# print("X_train.dtype = {0}".format(X_train.dtype))
# print("X_train density = {0}"
# "".format(X_train.nnz / np.product(X_train.shape)))
# print("y_train {0}".format(y_train.shape))
# print("X_test {0}".format(X_test.shape))
# print("X_test.format = {0}".format(X_test.format))
# print("X_test.dtype = {0}".format(X_test.dtype))
# print("y_test {0}".format(y_test.shape))
# print()
# print("Classifier Training")
# print("===================")
accuracy, train_time, test_time = {}, {}, {}
for name in sorted(args["estimators"]):
clf = ESTIMATORS[name]
Expand All @@ -67,15 +67,15 @@
except (TypeError, ValueError):
pass

print("Training %s ... " % name, end="")
# print("Training %s ... " % name, end="")
t0 = time()
clf.fit(X_train, y_train)
train_time[name] = time() - t0
t0 = time()
y_pred = clf.predict(X_test)
test_time[name] = time() - t0
accuracy[name] = accuracy_score(y_test, y_pred)
print("done")
# print("done")

print()
print("Classification performance:")
Expand Down

0 comments on commit bf98916

Please sign in to comment.