From f9ec739bbc66da6f2a7a6bb1e781f53b70c53450 Mon Sep 17 00:00:00 2001 From: nmasahiro Date: Sat, 6 Feb 2021 21:29:52 +0900 Subject: [PATCH] refactor base dataset class --- obp/dataset/base.py | 25 ++++++++++--------------- obp/dataset/multiclass.py | 4 ++-- obp/dataset/synthetic.py | 4 ++-- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/obp/dataset/base.py b/obp/dataset/base.py index bc0ffe37..e4744ae7 100644 --- a/obp/dataset/base.py +++ b/obp/dataset/base.py @@ -5,7 +5,16 @@ from abc import ABCMeta, abstractmethod -class BaseRealBanditDataset(metaclass=ABCMeta): +class BaseBanditDataset(metaclass=ABCMeta): + """Base Class for Synthetic Bandit Dataset.""" + + @abstractmethod + def obtain_batch_bandit_feedback(self) -> None: + """Obtain batch logged bandit feedback.""" + raise NotImplementedError + + +class BaseRealBanditDataset(BaseBanditDataset): """Base Class for Real-World Bandit Dataset.""" @abstractmethod @@ -17,17 +26,3 @@ def load_raw_data(self) -> None: def pre_process(self) -> None: """Preprocess raw dataset.""" raise NotImplementedError - - @abstractmethod - def obtain_batch_bandit_feedback(self) -> None: - """Obtain batch logged bandit feedback.""" - raise NotImplementedError - - -class BaseSyntheticBanditDataset(metaclass=ABCMeta): - """Base Class for Synthetic Bandit Dataset.""" - - @abstractmethod - def obtain_batch_bandit_feedback(self) -> None: - """Obtain batch logged bandit feedback.""" - raise NotImplementedError diff --git a/obp/dataset/multiclass.py b/obp/dataset/multiclass.py index ca58e160..6c9ef8f7 100644 --- a/obp/dataset/multiclass.py +++ b/obp/dataset/multiclass.py @@ -11,12 +11,12 @@ from sklearn.model_selection import train_test_split from sklearn.utils import check_random_state, check_X_y -from .base import BaseSyntheticBanditDataset +from .base import BaseBanditDataset from ..types import BanditFeedback @dataclass -class MultiClassToBanditReduction(BaseSyntheticBanditDataset): +class MultiClassToBanditReduction(BaseBanditDataset): """Class for handling multi-class classification data as logged bandit feedback data. Note diff --git a/obp/dataset/synthetic.py b/obp/dataset/synthetic.py index 76e65ea3..770f9a40 100644 --- a/obp/dataset/synthetic.py +++ b/obp/dataset/synthetic.py @@ -9,13 +9,13 @@ from scipy.stats import truncnorm from sklearn.utils import check_random_state -from .base import BaseSyntheticBanditDataset +from .base import BaseBanditDataset from ..types import BanditFeedback from ..utils import sigmoid, softmax @dataclass -class SyntheticBanditDataset(BaseSyntheticBanditDataset): +class SyntheticBanditDataset(BaseBanditDataset): """Class for generating synthetic bandit dataset. Note