In [1]:
import functools
from typing import Optional,Tuple

import matplotlib.pylab as plt
import numpy as np
from tqdm import tqdm
# Make NumPy ignore floating-point divide-by-zero for the rest of the session.
# seterr returns the error-handling settings that were active before the call,
# which is the dict this cell echoes as its output.
np.seterr(divide='ignore')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [2]:
def input_checker(func):
    """
    A decorator that validates the bandit configuration before calling ``func``.

    The returned wrapper accepts ``std_``, ``mean_``, ``k``, ``n`` and ``alpha``
    (positionally or by keyword), checks them, and forwards them positionally to
    ``func``. Any additional keyword arguments (for example ``c`` or ``method``)
    are passed through to ``func`` unchanged and unchecked. ``n`` is forwarded
    without validation.

    Args:
        func: The callable to be decorated. It is invoked as
            ``func(std_, mean_, k, n, alpha, **kwargs)``.

    Returns:
        The validating wrapper around ``func``. It carries ``func``'s name and
        docstring and exposes the original callable as ``__wrapped__``.

    Raises:
        AssertionError: If any of the checked input arguments is invalid.
    """
    # updated=() keeps wraps() from merging func.__dict__ into the wrapper, which
    # would copy every class attribute when the decorated object is a class.
    @functools.wraps(func, updated=())
    def wrapper(std_: np.ndarray = None,
                 mean_: np.ndarray = None,
                 k: int = None,
                 n: int = None,
                 alpha: float = None,
                 **kwargs
                 ):
        # Check the format of 'mean_'
        assert isinstance(mean_, (np.ndarray, list)), "invalid format of 'mean_'"

        # Check the format of 'std_'
        assert isinstance(std_, (np.ndarray, list)), "invalid format of 'std_'"

        # Check if 'k' is an integer
        assert isinstance(k, int), 'k needs to be an integer'

        # Check if 'k' is greater than 1
        assert k > 1, 'the value of k should be more than 1, {0} is given!'.format(k)

        # Check consistency of 'k', 'mean_', and 'std_'
        assert k == len(mean_) == len(std_), 'The length of std_, mean, and k are not consistent'

        # Check the step size type first: None is an accepted value, and comparing
        # None (or a non-number) with 0 would raise TypeError instead of the
        # documented AssertionError.
        assert isinstance(alpha, (type(None), float)), 'the format of alpha is incorrect'

        # Only a step size that was actually supplied has to be positive.
        if alpha is not None:
            assert alpha > 0 , 'the value of alpha needs to be positive. {0} is given'.format(alpha)

        # Call the original function with the validated arguments plus any extras
        return func(std_, mean_, k, n, alpha, **kwargs)
    return wrapper


In [None]:
@input_checker
class EpsilonGreedy:
    """
    Agent for the k-armed bandit problem.

    Tracks action-value estimates ``Q``, action preferences ``H`` and the softmax
    probabilities derived from ``H``. Despite the class name, no ``epsilon``
    parameter is accepted.

    NOTE(review): because of the ``input_checker`` decorator, the module-level
    name ``EpsilonGreedy`` is bound to the validating wrapper function returned
    by the decorator, not to this class object.
    """

    def __init__(self,
                 std_: np.ndarray = None,
                 mean_: np.ndarray = None,
                 k: int = None,
                 n: int = None,
                 alpha: float = None,
                 c:float = None,
                 method: str = 'mean') -> None:
        """
        Set up the agent state for a bandit-k run.

        Args:
            std_ (np.ndarray): Array of standard deviations for each arm. Defaults to None.
            mean_ (np.ndarray): Array of means for each arm. Defaults to None.
            k (int): Number of arms. Defaults to None.
            n (int): Number of time steps. Defaults to None.
            alpha (float): Step size parameter for constant step size update. Defaults to None.
            c (float): Stored on the instance; not used by any method of this class.
                Defaults to None.
            method (str): Stored on the instance; not used by any method of this class.
                Defaults to 'mean'.
        """

        self.std_ = std_
        self.mean_ = mean_
        self.k = k
        self.n = n
        self.alpha = alpha
        # These two were accepted by the signature but silently discarded before.
        self.c = c
        self.method = method

        self.Q = np.zeros((self.k))   # action-value estimate per arm
        self.H = np.zeros((self.k))   # preference (logit) per arm

        # Arm used for the first step of run(); later steps call selector().
        self.A = np.random.choice(range(self.k))

        self.probs = self.logits2prob()
        self.Q_hist = np.zeros((self.k, self.n))
        self.A_hist = np.full((self.n, ), np.nan)   # NaN marks steps not played yet
        self.R = 0
        self.action_counter = np.zeros((self.k, ))

    def logits2prob(self)->np.ndarray:
        """
        Convert the preferences ``H`` into probabilities with a softmax.

        The maximum preference is subtracted before exponentiating. This leaves
        the result mathematically unchanged but prevents ``np.exp`` from
        overflowing to ``inf`` (and the ratio becoming ``nan``) for large ``H``.

        Returns:
            np.ndarray: Probabilities with the same shape as ``H``, summing to 1.
        """
        shifted = self.H - np.max(self.H)
        elements_exp = np.exp(shifted)
        return elements_exp / elements_exp.sum()

    def update_preference(self)->None:
        """
        Placeholder for the preference update; currently does nothing.

        It is not called by ``run`` and leaves ``H`` untouched.
        """
        return

    def selector(self) -> int:
        """
        Choose the arm to play on the current step.

        NOTE(review): ``run`` called this method but the class did not define it.
        The behaviour here is an assumption based on the state the class keeps:
        the arm is sampled from the softmax of ``H``. Confirm against the
        intended algorithm.

        Returns:
            int: Index of the selected arm in ``[0, k)``.
        """
        self.probs = self.logits2prob()
        return int(np.random.choice(self.k, p=self.probs))

    def reward(self, ind: int) -> float:
        """
        Draw the reward for playing arm ``ind``.

        NOTE(review): ``run`` called this method but the class did not define it.
        The reward is assumed to be normally distributed with the arm's entry in
        ``mean_`` as mean and in ``std_`` as standard deviation; confirm.

        Args:
            ind (int): Index of the arm that was played.

        Returns:
            float: The sampled reward.
        """
        return float(np.random.normal(self.mean_[ind], self.std_[ind]))

    def weight(self) -> float:
        """
        Step size used in the action-value update of ``run``.

        NOTE(review): ``run`` called this method but the class did not define it.
        Assumed semantics: the constant ``alpha`` when one was given, otherwise
        the sample-average step ``1 / N(A)``; confirm.

        Returns:
            float: The step size for the current arm ``A``.
        """
        if self.alpha is not None:
            return self.alpha
        # run() increments the counter before calling weight(), so it is >= 1 here.
        return 1.0 / self.action_counter[self.A]

    def run(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Play the bandit for ``n`` steps and record the history.

        The first step uses the arm drawn in ``__init__``; every later step asks
        ``selector`` for the arm. After each pull the estimate of the played arm
        is moved towards the observed reward by ``weight()``.

        Returns:
            Tuple: ``(Q_hist, A_hist, action_counter)`` where ``Q_hist`` has shape
            ``(k, n)`` with the estimates after each step, ``A_hist`` has shape
            ``(n,)`` with the arm played at each step, and ``action_counter`` has
            shape ``(k,)`` with the number of pulls per arm.
        """
        # The loop index is kept on the instance (self.itr), as in the original code.
        for self.itr in range(self.n):
            if self.itr > 0:
                self.A = self.selector()

            # One-hot mask of the played arm. np.where(self.A, True, False) on the
            # scalar index collapsed to a single truthiness flag (False for arm 0).
            self.mask = np.arange(self.k) == self.A
            self.action_counter[self.A] += 1
            self.R = self.reward(ind=self.A)
            self.Q[self.A] = self.Q[self.A] + self.weight() * (self.R - self.Q[self.A])
            self.Q_hist[:, self.itr] = self.Q
            self.A_hist[self.itr] = self.A
        return self.Q_hist, self.A_hist, self.action_counter


In [3]:
# Scratch check: materialize range(4) as a list to display the integers 0..3.
list(range(4))

[0, 1, 2, 3]