In [2]:
import numpy as np

In [19]:
"""
state: [c1, c2, c3, c4, w]
        c1, c2, c3 -> cards in the agent's hand
        c4 -> card on the table
        w -> who won the first round
Compress the (state, action) pair into a single integer:
just as 3 coordinates in 3D space can be raveled into one flat index, we
treat the cards (plus w and the action) as coordinates and ravel them into
a single integer, which simplifies storing the Q values.
Card encoding:
4p -> 0; 7c -> 1; Ae -> 2; 7o -> 3; 3 -> 4; 2 -> 5; A -> 6 ...
4 -> 13
w is -1 if the agent lost the first round, 0 if it was a draw/undecided,
and 1 if the agent won it.

action: c -> card to be played
c = 0, 1 or 2, corresponding to the index of the card to be played

I assume that env.step returns (next_state, reward, done), where done is
True when the episode is over.

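TODO:
adopt a convention: always sort the player's cards to avoid duplicate
states!! e.g. [11, 10, 3] is the same hand as [3, 11, 10].
Also, always keeping 0 (an empty slot) at the end of the list is required
and helps a lot.
NOTE(review): the code treats 0 as an empty slot, but the encoding above
also maps 4p to 0 - confirm which one is intended.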
"""

class SarsaAgent:
    """Tabular SARSA agent for a three-card game.

    A state is a list ``[c1, c2, c3, c4, w]``:

    - ``c1``..``c3`` are the cards in the agent's hand (0 marks an empty,
      unplayable slot).
    - ``c4`` is the card on the table.
    - ``w`` is the first-round result (-1, 0 or 1).

    An action is the hand index (0, 1 or 2) of the card to play. Q values
    live in a dict keyed by the flattened (state, action) index.

    Args:
        eps: exploration probability of the epsilon-greedy policy.
        alpha: learning rate of the SARSA update.
        gamma: discount factor applied to the next state's Q value.
        env: environment providing ``reset() -> state`` and
            ``step(action) -> (next_state, reward, done)``.
    """

    def __init__(self, eps: float, alpha: float, gamma: float, env) -> None:
        self.eps = eps
        self.alpha = alpha
        self.gamma = gamma
        self.env = env
        # card values 0..13 for c1..c4, w + 1 in 0..2, action in 0..2
        self.state_action_shape = (14, 14, 14, 14, 3, 3)
        # maps flattened (state, action) index -> Q value; missing keys count as 0
        self.Q = {}

    def _convert_state_action(self, state: list[int], action: int) -> int:
        """Flatten a (state, action) pair into a single integer Q-table key.

        Args:
            state: ``[c1, c2, c3, c4, w]`` with cards in 0..13 and w in -1..1.
            action: hand index of the card to play (0, 1 or 2).

        Returns:
            The row-major index of ``(c1, c2, c3, c4, w + 1, action)`` within
            ``self.state_action_shape``.

        Raises:
            ValueError: if a coordinate is out of range for that shape.
        """
        coords = list(state)
        coords.append(action)
        # shift w from -1..1 to 0..2 so it is a valid (non-negative) index
        coords[4] += 1
        return int(np.ravel_multi_index(coords, self.state_action_shape))

    def _choose_greedy_action(self, state: list[int]) -> int:
        """Pick an action for ``state`` with an epsilon-greedy policy.

        With probability ``self.eps`` a playable card is chosen uniformly at
        random. Otherwise the playable card with the highest Q value is
        chosen; unseen pairs count as 0, and ties go to the lowest index.

        Returns:
            The hand index (0, 1 or 2) of the card to play.

        Raises:
            ValueError: if the hand has no playable card.
        """
        # only the first three entries are hand cards; 0 marks an empty slot
        possible_actions = [i for i, card in enumerate(state[:3]) if card != 0]
        if not possible_actions:
            raise ValueError(f"no playable card in state {state!r}")
        # exploration
        if np.random.uniform(0, 1) < self.eps:
            return int(np.random.choice(possible_actions))
        # exploitation: return the action itself, not its position in
        # possible_actions (the two differ once a card has been played)
        return max(
            possible_actions,
            key=lambda a: self.Q.get(self._convert_state_action(state, a), 0),
        )

    def run(self, n_episodes: int) -> None:
        """Train the Q table with on-policy SARSA for ``n_episodes`` episodes.

        On non-terminal steps the target is ``reward + gamma * Q(s', a')``.
        On the terminal step it is just ``reward``: no next action is chosen
        there, because the final hand may have no playable card.
        """
        for _ in range(n_episodes):
            state_1 = self.env.reset()
            action_1 = self._choose_greedy_action(state_1)
            while True:
                state_2, reward, done = self.env.step(action_1)
                sa1_compressed = self._convert_state_action(state_1, action_1)
                q1 = self.Q.get(sa1_compressed, 0)

                if done:
                    # terminal transition: nothing to bootstrap from
                    self.Q[sa1_compressed] = q1 + self.alpha * (reward - q1)
                    break

                action_2 = self._choose_greedy_action(state_2)
                sa2_compressed = self._convert_state_action(state_2, action_2)
                q2 = self.Q.get(sa2_compressed, 0)
                self.Q[sa1_compressed] = q1 + self.alpha * (
                    reward + self.gamma * q2 - q1
                )

                state_1, action_1 = state_2, action_2

In [22]:
# Quick sanity check of the key encoding and the epsilon-greedy policy.
# SarsaAgent requires eps, alpha, gamma and env (the old two-argument call
# raised a TypeError). Only helper methods are used here, so no env is needed.
agent = SarsaAgent(eps=0.1, alpha=1, gamma=1.0, env=None)
demo_state = [11, 10, 2, 4, -1]
# derive the Q-table key instead of hard-coding it (it equals 289584)
demo_key = agent._convert_state_action(demo_state, 0)
agent.Q[demo_key] = 3
# with eps=0.1 the greedy action 0 should dominate, with occasional random picks
select_action = [agent._choose_greedy_action(demo_state) for _ in range(10)]
select_action

[0, 2, 1, 0, 0, 0, 0, 0, 0, 0]