Skip to content

Commit

Permalink
Fixed key error in h2o.balance
Browse files Browse the repository at this point in the history
  • Loading branch information
tgsmith61591 committed Sep 7, 2016
1 parent 3bf6726 commit 8b67ac7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion skutil/h2o/balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def _validate_x_y_ratio(X, y, ratio):
y = _validate_target(y) # cast to string type

# generate cts. Have to get kludgier in h2o...
unq_vals = X[y].unique().as_data_frame(use_pandas=True)[y].values # numpy array of unique vals
unq_vals = X[y].unique()
unq_vals = unq_vals.as_data_frame(use_pandas=True)[unq_vals.columns[0]].values # numpy array of unique vals
unq_cts = dict([(val, X[y][X[y]==val].shape[0]) for val in unq_vals])

# validate is < max classes
Expand Down
4 changes: 2 additions & 2 deletions skutil/h2o/tests/test_h2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,15 +1192,15 @@ def balance():

# do a real undersample
x = Y[:60, :] # 50 zeros, 10 ones
b = UndersamplingClassBalancer(y='species', ratio=0.5).balance(x).as_data_frame(use_pandas=True)
b = H2OUndersamplingClassBalancer(target_feature='species', ratio=0.5).balance(x).as_data_frame(use_pandas=True)
assert b.shape[0] == 30
cts = b.species.value_counts()
assert cts[0] == 20
assert cts[1] == 10

# assert oversampling works
y = Y[:105, :]
d = H2OOversamplingClassBalancer(y='species', ratio=1.0).balance(y).as_data_frame(use_pandas=True)
d = H2OOversamplingClassBalancer(target_feature='species', ratio=1.0).balance(y).as_data_frame(use_pandas=True)
assert d.shape[0] == 150

cts= d.species.value_counts()
Expand Down

0 comments on commit 8b67ac7

Please sign in to comment.