Skip to content

Commit

Permalink
Added second param set for the estimators in _proximity_forest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Abelarm committed Oct 2, 2022
1 parent 395346c commit b970962
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions sktime/classification/distance_based/_proximity_forest.py
Expand Up @@ -973,7 +973,7 @@ def _predict_proba(self, X) -> np.ndarray:
return distributions

@classmethod
def get_test_params(cls):
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Returns
Expand All @@ -984,10 +984,14 @@ def get_test_params(cls):
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
params = {
params1 = {
"random_state": 0,
}
return params
params2 = {
"random_state": 42,
"distance_measure": "dtw"
}
return [params1, params2]


class ProximityTree(BaseClassifier):
Expand Down Expand Up @@ -1226,7 +1230,16 @@ def get_test_params(cls, parameter_set="default"):
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`.
"""
return {"max_depth": 1, "n_stump_evaluations": 1}
params1 = {
"max_depth": 1,
"n_stump_evaluations": 1
}
params2 = {
"max_depth": 1,
"n_stump_evaluations": 1,
"distance_measure": "dtw"
}
return [params1, params2]


class ProximityForest(BaseClassifier):
Expand Down Expand Up @@ -1504,7 +1517,15 @@ def get_test_params(cls, parameter_set="default"):
if parameter_set == "results_comparison":
return {"n_estimators": 3, "max_depth": 2, "n_stump_evaluations": 2}
else:
return {"n_estimators": 2, "max_depth": 1, "n_stump_evaluations": 1}
param1 = {
"n_estimators": 2,
"max_depth": 1,
"n_stump_evaluations": 1}
param2 = {"n_estimators": 2,
"max_depth": 1,
"n_stump_evaluations": 1,
"distance_measure": "dtw"}
return [param1, param2]


# start of util functions
Expand Down

0 comments on commit b970962

Please sign in to comment.