Skip to content

Commit

Permalink
refactor base dataset class
Browse files Browse the repository at this point in the history
  • Loading branch information
nomuramasahir0 committed Feb 6, 2021
1 parent 36f544a commit f9ec739
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
25 changes: 10 additions & 15 deletions obp/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
from abc import ABCMeta, abstractmethod


class BaseRealBanditDataset(metaclass=ABCMeta):
class BaseBanditDataset(metaclass=ABCMeta):
"""Base Class for Synthetic Bandit Dataset."""

@abstractmethod
def obtain_batch_bandit_feedback(self) -> None:
"""Obtain batch logged bandit feedback."""
raise NotImplementedError


class BaseRealBanditDataset(BaseBanditDataset):
"""Base Class for Real-World Bandit Dataset."""

@abstractmethod
Expand All @@ -17,17 +26,3 @@ def load_raw_data(self) -> None:
def pre_process(self) -> None:
"""Preprocess raw dataset."""
raise NotImplementedError

@abstractmethod
def obtain_batch_bandit_feedback(self) -> None:
"""Obtain batch logged bandit feedback."""
raise NotImplementedError


class BaseSyntheticBanditDataset(metaclass=ABCMeta):
"""Base Class for Synthetic Bandit Dataset."""

@abstractmethod
def obtain_batch_bandit_feedback(self) -> None:
"""Obtain batch logged bandit feedback."""
raise NotImplementedError
4 changes: 2 additions & 2 deletions obp/dataset/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state, check_X_y

from .base import BaseSyntheticBanditDataset
from .base import BaseBanditDataset
from ..types import BanditFeedback


@dataclass
class MultiClassToBanditReduction(BaseSyntheticBanditDataset):
class MultiClassToBanditReduction(BaseBanditDataset):
"""Class for handling multi-class classification data as logged bandit feedback data.
Note
Expand Down
4 changes: 2 additions & 2 deletions obp/dataset/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from scipy.stats import truncnorm
from sklearn.utils import check_random_state

from .base import BaseSyntheticBanditDataset
from .base import BaseBanditDataset
from ..types import BanditFeedback
from ..utils import sigmoid, softmax


@dataclass
class SyntheticBanditDataset(BaseSyntheticBanditDataset):
class SyntheticBanditDataset(BaseBanditDataset):
"""Class for generating synthetic bandit dataset.
Note
Expand Down

0 comments on commit f9ec739

Please sign in to comment.