27 changes: 11 additions & 16 deletions src/pytorch_tabular/tabular_model.py
@@ -1424,9 +1424,8 @@ def predict(
                 for regression. For classification, the previous options are applied to the confidence
                 scores (soft voting) and then converted to final prediction. An additional option
                 "hard_voting" is available for classification.
-                If callable, should be a function that takes in a list of 2D arrays (num_samples, num_targets)
-                and returns a 2D array (num_samples, num_targets). Defaults to "mean".
-
+                If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
+                and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".
 
         Returns:
             DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
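
Note on the new callable contract: the aggregate function now receives the raw probability arrays and must itself return the final 2D probabilities; the predicted class is derived afterwards by `_combine_predictions`. A minimal sketch of a conforming callable, with made-up probabilities (only the shapes follow the updated docstring):

```python
import numpy as np

# Made-up probabilities from 3 TTA runs / folds,
# each of shape (num_samples=2, num_targets=2).
pred_prob_l = [
    np.array([[0.7, 0.3], [0.4, 0.6]]),
    np.array([[0.6, 0.4], [0.5, 0.5]]),
    np.array([[0.8, 0.2], [0.3, 0.7]]),
]

def median_aggregate(probs):
    # numpy stacks the list along a new leading axis, so reducing over
    # axis=0 yields the final (num_samples, num_targets) probabilities,
    # mirroring the lambda used in the updated tests.
    return np.median(probs, axis=0)

print(median_aggregate(pred_prob_l))
# [[0.7 0.3]
#  [0.4 0.6]]
```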
@@ -1454,7 +1453,6 @@ def add_noise(module, input, output):
 
             # Register the hook to the embedding_layer
             handle = self.model.embedding_layer.register_forward_hook(add_noise)
-            pred_l = []
             pred_prob_l = []
             for _ in range(num_tta):
                 pred_df = self._predict(
@@ -1468,11 +1466,10 @@
                 )
                 pred_idx = pred_df.index
                 if self.config.task == "classification":
-                    pred_l.append(pred_df.values[:, -len(self.config.target) :].astype(int))
                     pred_prob_l.append(pred_df.values[:, : -len(self.config.target)])
                 elif self.config.task == "regression":
                     pred_prob_l.append(pred_df.values)
-            pred_df = self._combine_predictions(pred_l, pred_prob_l, pred_idx, aggregate_tta, None)
+            pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate_tta, None)
             # Remove the hook
             handle.remove()
         else:
@@ -1993,7 +1990,6 @@ def cross_validate(
 
     def _combine_predictions(
         self,
-        pred_l: List[DataFrame],
         pred_prob_l: List[DataFrame],
         pred_idx: Union[pd.Index, List],
         aggregate: Union[str, Callable],
@@ -2008,15 +2004,16 @@
         elif aggregate == "max":
             bagged_pred = np.max(pred_prob_l, axis=0)
         elif aggregate == "hard_voting" and self.config.task == "classification":
+            pred_l = [np.argmax(p, axis=1) for p in pred_prob_l]
             final_pred = np.apply_along_axis(
                 lambda x: np.argmax(np.bincount(x)),
                 axis=0,
-                arr=[p[:, -1].astype(int) for p in pred_l],
+                arr=pred_l,
             )
         elif callable(aggregate):
-            final_pred = bagged_pred = aggregate(pred_prob_l)
+            bagged_pred = aggregate(pred_prob_l)
         if self.config.task == "classification":
-            if aggregate == "hard_voting" or callable(aggregate):
+            if aggregate == "hard_voting":
                 pred_df = pd.DataFrame(
                     np.concatenate(pred_prob_l, axis=1),
                     columns=[
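
For reference, the rewritten `hard_voting` branch now recovers per-fold labels from the stored probabilities via `np.argmax` instead of threading a separate `pred_l` argument through the API, then takes a majority vote per sample. A standalone sketch with toy numbers, assuming a single binary target:

```python
import numpy as np

# Per-fold probabilities, each of shape (num_samples=3, num_classes=2).
pred_prob_l = [
    np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]),
    np.array([[0.8, 0.2], [0.4, 0.6], [0.3, 0.7]]),
    np.array([[0.7, 0.3], [0.1, 0.9], [0.2, 0.8]]),
]

# Labels are now derived from the probabilities themselves.
pred_l = [np.argmax(p, axis=1) for p in pred_prob_l]  # three (3,) label arrays

# Majority class per sample across folds.
final_pred = np.apply_along_axis(
    lambda x: np.argmax(np.bincount(x)),
    axis=0,
    arr=pred_l,
)
print(final_pred)  # [0 1 1]
```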
@@ -2094,8 +2091,8 @@ def bagging_predict(
                 for regression. For classification, the previous options are applied to the confidence
                 scores (soft voting) and then converted to final prediction. An additional option
                 "hard_voting" is available for classification.
-                If callable, should be a function that takes in a list of 2D arrays (num_samples, num_targets)
-                and returns a 2D array (num_samples, num_targets). Defaults to "mean".
+                If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
+                and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".
 
             weights (Optional[List[float]], optional): The weights to be used for aggregating the predictions
                 from each fold. If None, will use equal weights. This is only used when `aggregate` is "mean".
@@ -2122,7 +2119,6 @@
         assert aggregate != "hard_voting", "hard_voting is only available for classification"
         cv = self._check_cv(cv)
         prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
-        pred_l = []
         pred_prob_l = []
         datamodule = None
         model = None
@@ -2149,15 +2145,14 @@
             fold_preds = self.predict(test, include_input_features=False)
             pred_idx = fold_preds.index
             if self.config.task == "classification":
-                pred_l.append(fold_preds.values[:, -len(self.config.target) :].astype(int))
                 pred_prob_l.append(fold_preds.values[:, : -len(self.config.target)])
             elif self.config.task == "regression":
                 pred_prob_l.append(fold_preds.values)
             if verbose:
                 logger.info(f"Fold {fold+1}/{cv.get_n_splits()} prediction done")
             self.model.reset_weights()
-        pred_df = self._combine_predictions(pred_l, pred_prob_l, pred_idx, aggregate, weights)
+        pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate, weights)
         if return_raw_predictions:
-            return pred_df, pred_l, pred_prob_l
+            return pred_df, pred_prob_l
         else:
             return pred_df
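
With a callable aggregate, `bagging_predict` now hands back only the combined dataframe, plus the per-fold probability list when `return_raw_predictions=True`. A usage sketch under assumptions: `tabular_model`, `train`, and `test` are placeholders for a configured `TabularModel` and two pandas DataFrames; the keyword arguments are the ones visible in the diff above:

```python
import numpy as np

# Assumed names: `tabular_model`, `train`, and `test` are placeholders;
# `cv`, `aggregate`, and `return_raw_predictions` follow the diff above.
pred_df, pred_prob_l = tabular_model.bagging_predict(
    cv=2,
    train=train,
    test=test,
    aggregate=lambda x: np.median(x, axis=0),  # callable returns final probabilities
    return_raw_predictions=True,  # now a 2-tuple: (pred_df, pred_prob_l)
)
```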
4 changes: 2 additions & 2 deletions tests/test_common.py
@@ -916,7 +916,7 @@ def _run_bagging(
 @pytest.mark.parametrize("cv", [2])
 @pytest.mark.parametrize(
     "aggregate",
-    ["mean", "median", "min", "max", "hard_voting", lambda x: np.argmax(np.median(x, axis=0), axis=1)],
+    ["mean", "median", "min", "max", "hard_voting", lambda x: np.median(x, axis=0)],
 )
 def test_bagging_classification(
     classification_data,
@@ -1040,7 +1040,7 @@ def _run_tta(
 @pytest.mark.parametrize("categorical_cols", [["feature_0_cat"]])
 @pytest.mark.parametrize(
     "aggregate",
-    ["mean", "median", "min", "max", "hard_voting", lambda x: np.argmax(np.median(x, axis=0), axis=1)],
+    ["mean", "median", "min", "max", "hard_voting", lambda x: np.median(x, axis=0)],
 )
 def test_tta_classification(
     classification_data,