Merge pull request #125 from st-tech/add-error-meta
fix error of meta
usaito committed Aug 30, 2021
2 parents a4b61e9 + 020a78e commit 37ccdd7
Showing 4 changed files with 91 additions and 9 deletions.
21 changes: 17 additions & 4 deletions obp/ope/meta.py
@@ -12,7 +12,7 @@
from pandas import DataFrame
import seaborn as sns

from .estimators import BaseOffPolicyEstimator
from .estimators import BaseOffPolicyEstimator, DirectMethod as DM, DoublyRobust as DR
from ..types import BanditFeedback
from ..utils import check_confidence_interval_arguments

@@ -84,8 +84,11 @@ def __post_init__(self) -> None:
if key_ not in self.bandit_feedback:
raise RuntimeError(f"Missing key of {key_} in 'bandit_feedback'.")
self.ope_estimators_ = dict()
self.is_model_dependent = False
for estimator in self.ope_estimators:
self.ope_estimators_[estimator.estimator_name] = estimator
if isinstance(estimator, DM) or isinstance(estimator, DR):
self.is_model_dependent = True

def _create_estimator_inputs(
self,
@@ -102,9 +105,7 @@ def _create_estimator_inputs(
f"action_dist.ndim must be 3-dimensional, but is {action_dist.ndim}"
)
if estimated_rewards_by_reg_model is None:
logger.warning(
"`estimated_rewards_by_reg_model` is not given; model dependent estimators such as DM or DR cannot be used."
)
pass
elif isinstance(estimated_rewards_by_reg_model, dict):
for estimator_name, value in estimated_rewards_by_reg_model.items():
if not isinstance(value, np.ndarray):
@@ -171,6 +172,12 @@ def estimate_policy_values(
Dictionary containing estimated policy values by OPE estimators.
"""
if self.is_model_dependent:
if estimated_rewards_by_reg_model is None:
raise ValueError(
"When model dependent estimators such as DM or DR are used, `estimated_rewards_by_reg_model` must be given"
)

policy_value_dict = dict()
estimator_inputs = self._create_estimator_inputs(
action_dist=action_dist,
@@ -222,6 +229,12 @@ def estimate_intervals(
using nonparametric bootstrap procedure.
"""
if self.is_model_dependent:
if estimated_rewards_by_reg_model is None:
raise ValueError(
"When model dependent estimators such as DM or DR are used, `estimated_rewards_by_reg_model` must be given"
)

check_confidence_interval_arguments(
alpha=alpha,
n_bootstrap_samples=n_bootstrap_samples,
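
To see the effect of this change outside the test suite, the following is a minimal sketch (not part of the commit). It assumes a synthetic logged dataset from obp.dataset.SyntheticBanditDataset and a uniform evaluation policy, neither of which is touched by this PR. With a model-dependent estimator such as DirectMethod, calling estimate_policy_values without estimated_rewards_by_reg_model now raises a ValueError instead of only logging a warning.

import numpy as np

from obp.dataset import SyntheticBanditDataset
from obp.ope import OffPolicyEvaluation, DirectMethod

# Logged bandit feedback from a synthetic environment (any valid BanditFeedback dict works).
dataset = SyntheticBanditDataset(n_actions=10, dim_context=5, random_state=12345)
bandit_feedback = dataset.obtain_batch_bandit_feedback(n_rounds=1000)

# DirectMethod is model dependent, so __post_init__ sets is_model_dependent = True.
ope = OffPolicyEvaluation(
    bandit_feedback=bandit_feedback, ope_estimators=[DirectMethod()]
)

# A uniform evaluation policy: shape (n_rounds, n_actions, len_list).
action_dist = np.ones((bandit_feedback["n_rounds"], bandit_feedback["n_actions"], 1))
action_dist /= action_dist.sum(axis=1, keepdims=True)

# Before this commit, omitting estimated_rewards_by_reg_model only triggered a warning
# in _create_estimator_inputs; now the meta class raises ValueError up front.
try:
    ope.estimate_policy_values(action_dist=action_dist)
except ValueError as err:
    print(err)

The same check guards estimate_intervals, as the new test below exercises.
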
24 changes: 20 additions & 4 deletions obp/ope/meta_continuous.py
@@ -12,7 +12,10 @@
from pandas import DataFrame
import seaborn as sns

from .estimators_continuous import BaseContinuousOffPolicyEstimator
from .estimators_continuous import (
BaseContinuousOffPolicyEstimator,
KernelizedDoublyRobust as KDR,
)
from ..types import BanditFeedback
from ..utils import check_confidence_interval_arguments

@@ -96,8 +99,11 @@ def __post_init__(self) -> None:
"action"
]
self.ope_estimators_ = dict()
self.is_model_dependent = False
for estimator in self.ope_estimators:
self.ope_estimators_[estimator.estimator_name] = estimator
if isinstance(estimator, KDR):
self.is_model_dependent = True

def _create_estimator_inputs(
self,
@@ -115,9 +121,7 @@ def _create_estimator_inputs(
"action_by_evaluation_policy must be 1-dimensional ndarray"
)
if estimated_rewards_by_reg_model is None:
logger.warning(
"`estimated_rewards_by_reg_model` is not given; model dependent estimators such as DM or DR cannot be used."
)
pass
elif isinstance(estimated_rewards_by_reg_model, dict):
for estimator_name, value in estimated_rewards_by_reg_model.items():
if not isinstance(value, np.ndarray):
@@ -186,6 +190,12 @@ def estimate_policy_values(
Dictionary containing estimated policy values by OPE estimators.
"""
if self.is_model_dependent:
if estimated_rewards_by_reg_model is None:
raise ValueError(
"When model dependent estimators such as DM or DR are used, `estimated_rewards_by_reg_model` must be given"
)

policy_value_dict = dict()
estimator_inputs = self._create_estimator_inputs(
action_by_evaluation_policy=action_by_evaluation_policy,
@@ -237,6 +247,12 @@ def estimate_intervals(
using nonparametric bootstrap procedure.
"""
if self.is_model_dependent:
if estimated_rewards_by_reg_model is None:
raise ValueError(
"When model dependent estimators such as DM or DR are used, `estimated_rewards_by_reg_model` must be given"
)

check_confidence_interval_arguments(
alpha=alpha,
n_bootstrap_samples=n_bootstrap_samples,
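
The continuous-action meta class gets the same guard. The sketch below (not part of the commit) mirrors the added test: KernelizedDoublyRobust is model dependent, so both entry points now fail fast when no regression-model predictions are supplied. The hand-built bandit_feedback dict is a stand-in; its keys ("action", "reward", "pscore") are assumed to be what ContinuousOffPolicyEvaluation checks in __post_init__, and the error is raised before any estimation touches the data.

import numpy as np
import pytest

from obp.ope import ContinuousOffPolicyEvaluation, KernelizedDoublyRobust

# Minimal stand-in for continuous logged bandit feedback (keys assumed, see note above).
n_rounds = 100
rng = np.random.default_rng(12345)
bandit_feedback = dict(
    n_rounds=n_rounds,
    action=rng.uniform(size=n_rounds),   # continuous actions of the behavior policy
    reward=rng.normal(size=n_rounds),
    pscore=np.ones(n_rounds),            # generalized propensity scores
)

# KernelizedDoublyRobust is model dependent, so __post_init__ flags it.
kdr = KernelizedDoublyRobust(kernel="cosine", bandwidth=0.1)
ope = ContinuousOffPolicyEvaluation(bandit_feedback=bandit_feedback, ope_estimators=[kdr])

# Continuous actions chosen by the evaluation policy: shape (n_rounds,).
action_by_evaluation_policy = np.zeros(n_rounds)

# Both calls now raise ValueError without estimated_rewards_by_reg_model.
with pytest.raises(ValueError):
    ope.estimate_policy_values(
        action_by_evaluation_policy=action_by_evaluation_policy,
        estimated_rewards_by_reg_model=None,
    )
with pytest.raises(ValueError):
    ope.estimate_intervals(
        action_by_evaluation_policy=action_by_evaluation_policy,
        estimated_rewards_by_reg_model=None,
    )
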
28 changes: 27 additions & 1 deletion tests/ope/test_meta.py
@@ -10,7 +10,7 @@
import torch

from obp.types import BanditFeedback
from obp.ope import OffPolicyEvaluation, BaseOffPolicyEstimator
from obp.ope import OffPolicyEvaluation, BaseOffPolicyEstimator, DirectMethod
from obp.utils import check_confidence_interval_arguments


@@ -310,6 +310,32 @@ def test_meta_post_init(synthetic_bandit_feedback: BanditFeedback) -> None:
)


def test_meta_estimated_rewards_by_reg_model_inputs(
synthetic_bandit_feedback: BanditFeedback,
) -> None:
"""
Test the estimate_policy_values/estimate_intervals functions wrt estimated_rewards_by_reg_model
"""
ope_ = OffPolicyEvaluation(
bandit_feedback=synthetic_bandit_feedback, ope_estimators=[DirectMethod()]
)

action_dist = np.zeros(
(synthetic_bandit_feedback["n_rounds"], synthetic_bandit_feedback["n_actions"])
)
with pytest.raises(ValueError):
ope_.estimate_policy_values(
action_dist=action_dist,
estimated_rewards_by_reg_model=None,
)

with pytest.raises(ValueError):
ope_.estimate_intervals(
action_dist=action_dist,
estimated_rewards_by_reg_model=None,
)


# action_dist, estimated_rewards_by_reg_model, description
invalid_input_of_create_estimator_inputs = [
(
27 changes: 27 additions & 0 deletions tests/ope/test_meta_continuous.py
@@ -12,6 +12,7 @@
from obp.ope import (
ContinuousOffPolicyEvaluation,
BaseContinuousOffPolicyEstimator,
KernelizedDoublyRobust,
)
from obp.utils import check_confidence_interval_arguments

@@ -186,6 +187,32 @@ def test_meta_post_init(synthetic_continuous_bandit_feedback: BanditFeedback) ->
)


def test_meta_estimated_rewards_by_reg_model_inputs(
synthetic_bandit_feedback: BanditFeedback,
) -> None:
"""
Test the estimate_policy_values/estimate_intervals functions wrt estimated_rewards_by_reg_model
"""
kdr = KernelizedDoublyRobust(kernel="cosine", bandwidth=0.1)
ope_ = ContinuousOffPolicyEvaluation(
bandit_feedback=synthetic_bandit_feedback,
ope_estimators=[kdr],
)

action_by_evaluation_policy = np.zeros((synthetic_bandit_feedback["n_rounds"],))
with pytest.raises(ValueError):
ope_.estimate_policy_values(
action_by_evaluation_policy=action_by_evaluation_policy,
estimated_rewards_by_reg_model=None,
)

with pytest.raises(ValueError):
ope_.estimate_intervals(
action_by_evaluation_policy=action_by_evaluation_policy,
estimated_rewards_by_reg_model=None,
)


# action_by_evaluation_policy, estimated_rewards_by_reg_model, description
invalid_input_of_create_estimator_inputs = [
(
