Skip to content

Commit

Permalink
Merge pull request #62 from nmasahiro/refactor-dataset-base
Browse files Browse the repository at this point in the history
refactor base dataset class
  • Loading branch information
usaito committed Feb 6, 2021
2 parents 36f544a + f9ec739 commit 01aef32
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
25 changes: 10 additions & 15 deletions obp/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
from abc import ABCMeta, abstractmethod


class BaseRealBanditDataset(metaclass=ABCMeta):
class BaseBanditDataset(metaclass=ABCMeta):
"""Base Class for Synthetic Bandit Dataset."""

@abstractmethod
def obtain_batch_bandit_feedback(self) -> None:
"""Obtain batch logged bandit feedback."""
raise NotImplementedError


class BaseRealBanditDataset(BaseBanditDataset):
"""Base Class for Real-World Bandit Dataset."""

@abstractmethod
Expand All @@ -17,17 +26,3 @@ def load_raw_data(self) -> None:
def pre_process(self) -> None:
"""Preprocess raw dataset."""
raise NotImplementedError

@abstractmethod
def obtain_batch_bandit_feedback(self) -> None:
"""Obtain batch logged bandit feedback."""
raise NotImplementedError


class BaseSyntheticBanditDataset(metaclass=ABCMeta):
"""Base Class for Synthetic Bandit Dataset."""

@abstractmethod
def obtain_batch_bandit_feedback(self) -> None:
"""Obtain batch logged bandit feedback."""
raise NotImplementedError
4 changes: 2 additions & 2 deletions obp/dataset/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state, check_X_y

from .base import BaseSyntheticBanditDataset
from .base import BaseBanditDataset
from ..types import BanditFeedback


@dataclass
class MultiClassToBanditReduction(BaseSyntheticBanditDataset):
class MultiClassToBanditReduction(BaseBanditDataset):
"""Class for handling multi-class classification data as logged bandit feedback data.
Note
Expand Down
4 changes: 2 additions & 2 deletions obp/dataset/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from scipy.stats import truncnorm
from sklearn.utils import check_random_state

from .base import BaseSyntheticBanditDataset
from .base import BaseBanditDataset
from ..types import BanditFeedback
from ..utils import sigmoid, softmax


@dataclass
class SyntheticBanditDataset(BaseSyntheticBanditDataset):
class SyntheticBanditDataset(BaseBanditDataset):
"""Class for generating synthetic bandit dataset.
Note
Expand Down

0 comments on commit 01aef32

Please sign in to comment.