In [1]:
import numpy as np
from matplotlib import pyplot as plt
from uuid import uuid4, UUID
from typing import Dict, List, Tuple, Literal, Callable, Any
from typing_extensions import override
from dataclasses import dataclass
from abc import ABC, abstractmethod

In [2]:
class NamedHashable(ABC):
    name: str
    id: UUID

    def __init__(self, name: str, id_gen: Callable[..., UUID] | None = None) -> None:
        if id_gen is not None:
            self.id = id_gen()
        else:
            self.id = uuid4()
        
        self.name = name
    
    def __hash__(self) -> int:
        return self.id.int

    @abstractmethod
    def __repr__(self) -> str:
        pass

In [29]:
class State(NamedHashable, ABC):
    """Abstract MDP state: a ``NamedHashable`` whose repr can carry metadata.

    Subclasses implement ``_repr_metadata`` and return extra text to show after
    the name, or ``None`` / an empty string to show the name alone.
    """

    def __init__(self, name: str, **kwargs) -> None:
        # NOTE(review): extra keyword arguments are accepted but ignored here
        # (not forwarded to NamedHashable).
        super().__init__(name=name)

    @abstractmethod
    def _repr_metadata(self) -> str | None:
        """Return extra text to display after the name, or ``None`` for none."""
        pass

    @override
    def __repr__(self) -> str:
        """Render as ``[`name`]``, or ``[`name` > metadata]`` when metadata is non-empty."""
        metadata = self._repr_metadata()
        # Only emit the " > " separator when there is something to show after it.
        if metadata:
            return f"[`{self.name}` > {metadata}]"
        return f"[`{self.name}`]"


class Action(NamedHashable):
    """An MDP action: hashed by its ``NamedHashable`` UUID, displayed by name."""

    def __init__(self, name: str) -> None:
        # No custom id factory is exposed, so NamedHashable falls back to uuid4.
        super().__init__(name=name)

    @override
    def __repr__(self) -> str:
        """Render as the name wrapped in backticks and angle brackets, e.g. <`eat`>."""
        return "<`{}`>".format(self.name)

In [4]:
# Maps each State to a float value estimate.
# NOTE(review): the type-alias cell further down redefines ValueFunction
# identically; one definition would suffice.
ValueFunction = Dict[State, float]

In [5]:
class StudentAction(Action):
    """An ``Action`` in the student example, tagged with a difficulty level."""

    # Allowed difficulty labels; keep in sync with the Literal annotations below.
    _DIFFICULTIES = ("EASY", "MED", "HARD")

    difficulty: Literal["EASY", "MED", "HARD"]

    def __init__(self, name: str, difficulty: Literal["EASY", "MED", "HARD"]) -> None:
        """Create the action and record its difficulty.

        Args:
            name: action label.
            difficulty: one of "EASY", "MED", "HARD".

        Raises:
            ValueError: if ``difficulty`` is not an allowed label. ``Literal`` is
                only a static type hint, so the check is done here at runtime.
        """
        if difficulty not in self._DIFFICULTIES:
            raise ValueError(
                f"difficulty must be one of {self._DIFFICULTIES}, got {difficulty!r}"
            )
        super().__init__(name=name)
        self.difficulty = difficulty

# Sanity check: a StudentAction is hashable, so it can key a dict
# (as the inner Action -> float mapping of a Policy does).
student_action1 = StudentAction(name="study", difficulty="MED")

dict([(student_action1, 0.5)])

{<`study`>: 0.5}

In [38]:
# Per-action transition matrix. MarkovDecisionProcess allocates each one as an
# (n_states, n_states) array indexed [src_state, dst_state].
TransitionProbaMatrix = np.ndarray
# Per-action reward matrix, allocated and indexed like the transition matrix.
ActionRewardFunction = np.ndarray
# One reward matrix per action.
RewardFunction = Dict[Action, ActionRewardFunction]
# Float value estimate per state.
ValueFunction = Dict[State, float]
# Per state, a float weight for each action — presumably the probability of
# taking that action in that state; confirm with the solvers' output.
Policy = Dict[State, Dict[Action, float]]

In [41]:
class MarkovDecisionProcess:
    """Finite Markov decision process with per-action transition and reward matrices.

    The order of ``states`` and ``actions`` given to the constructor fixes the
    integer indices used everywhere else. For each action ``a``:

    * ``transition_proba[a][i, j]`` is the probability of moving from state
      ``i`` to state ``j`` when ``a`` is taken;
    * ``reward[a][i, j]`` is the reward for that same transition.

    Both start as all-zero ``(n_states, n_states)`` arrays. ``values`` starts
    with one standard-normal draw per state and is replaced by the solvers.
    Setters accept either the State/Action object or its integer index.
    """

    _states: List[State]
    _actions: List[Action]
    _transition_proba: Dict[Action, TransitionProbaMatrix]
    _reward_function: RewardFunction
    _gamma: float  # discount factor used by the solvers
    _values: ValueFunction

    def __init__(self, name: str, states: List[State], actions: List[Action], gamma: float) -> None:
        """Create an MDP with zeroed transitions/rewards and random initial values.

        Args:
            name: label for this MDP.
            states: states, in matrix-index order.
            actions: actions, in index order.
            gamma: discount factor.
        """
        self.name = name
        self._states = states
        self._actions = actions
        n_states = len(states)
        self._transition_proba = {
            action: np.zeros(shape=(n_states, n_states))
            for action in actions
        }
        self._reward_function = {
            action: np.zeros(shape=(n_states, n_states))
            for action in actions
        }
        self._values = {
            state: np.random.randn()
            for state in states
        }
        self._gamma = gamma

    @property
    def states(self) -> List[State]:
        """States in matrix-index order."""
        return self._states

    @property
    def actions(self) -> List[Action]:
        """Actions in index order."""
        return self._actions

    @property
    def transition_proba(self) -> Dict[Action, TransitionProbaMatrix]:
        """Per-action (n_states, n_states) transition matrices, indexed [src, dst]."""
        return self._transition_proba

    @property
    def reward(self) -> RewardFunction:
        """Per-action (n_states, n_states) reward matrices, indexed [src, dst]."""
        return self._reward_function

    @property
    def values(self) -> ValueFunction:
        """Current value estimate per state (random until a solver runs)."""
        return self._values

    def _get_state_index(self, state: State | int) -> int:
        """Resolve a State or a positional index to a matrix row/column index.

        Raises:
            ValueError: if ``state`` is a State not in this MDP.
            IndexError: if ``state`` is an int outside ``[0, n_states)``; negative
                indices are rejected rather than silently wrapping.
        """
        if isinstance(state, State):
            return self._states.index(state)
        if not 0 <= state < len(self._states):
            raise IndexError(
                f"state index {state} out of range for {len(self._states)} states"
            )
        return state

    def _get_action_index(self, action: Action | int) -> int:
        """Resolve an Action or a positional index to an action index.

        Raises:
            ValueError: if ``action`` is an Action not in this MDP.
            IndexError: if ``action`` is an int outside ``[0, n_actions)``.
        """
        if isinstance(action, Action):
            return self._actions.index(action)
        if not 0 <= action < len(self._actions):
            raise IndexError(
                f"action index {action} out of range for {len(self._actions)} actions"
            )
        return action

    def _get_action(self, action: Action | int) -> Action:
        """Resolve ``action`` to the Action object that keys the matrix dicts.

        Action objects are returned unchanged (an unknown one then fails with
        KeyError at the dict lookup); ints are looked up by position.
        """
        if isinstance(action, Action):
            return action
        return self._actions[self._get_action_index(action)]

    def set_reward(
        self,
        src_state: State | int,
        action: Action | int,
        dst_state: State | int,
        reward: float,
    ) -> None:
        """Set the reward for the transition src_state --action--> dst_state."""
        src_inx: int = self._get_state_index(state=src_state)
        dst_inx: int = self._get_state_index(state=dst_state)
        # The matrices are keyed by Action objects, so an int index must be resolved first.
        self._reward_function[self._get_action(action)][src_inx, dst_inx] = reward

    def set_state_reward(self, state: State | int, reward: float) -> None:
        """Set the reward of every transition leaving ``state``, under every action."""
        inx: int = self._get_state_index(state=state)
        for action in self._actions:
            self._reward_function[action][inx, :] = reward

    def set_transition_proba(
        self,
        src_state: State | int,
        action: Action | int,
        dst_state: State | int,
        proba: float,
    ) -> None:
        """Set P(dst_state | src_state, action)."""
        src_inx: int = self._get_state_index(state=src_state)
        dst_inx: int = self._get_state_index(state=dst_state)
        # The matrices are keyed by Action objects, so an int index must be resolved first.
        self._transition_proba[self._get_action(action)][src_inx, dst_inx] = proba

    def add_terminal_state(self, state: State | int) -> None:
        """Zero the reward of every transition leaving ``state``, under every action.

        Transition rows are left untouched. If the row stays all-zero, the
        solvers' Bellman backup for this state evaluates to 0.
        """
        inx: int = self._get_state_index(state=state)
        for action in self._actions:
            self._reward_function[action][inx, :] = 0.0

    def _stacked_model(self) -> Tuple[np.ndarray, np.ndarray]:
        """Stack the per-action matrices for vectorized backups.

        Returns:
            ``(P, r)``: ``P`` has shape (n_actions, n_states, n_states); ``r`` is the
            expected immediate reward ``sum_j P[a, i, j] * R[a, i, j]`` with shape
            (n_actions, n_states).

        Raises:
            ValueError: if the MDP has no actions.
        """
        if not self._actions:
            raise ValueError("cannot solve an MDP with no actions")
        P = np.stack([self._transition_proba[a] for a in self._actions])
        R = np.stack([self._reward_function[a] for a in self._actions])
        return P, (P * R).sum(axis=2)

    def _values_array(self) -> np.ndarray:
        """Current values as a float vector in state-index order."""
        return np.array([self._values[s] for s in self._states], dtype=float)

    def _store_solution(self, v: np.ndarray, action_idx: np.ndarray) -> Tuple[ValueFunction, Policy]:
        """Write ``v`` into ``self._values`` and build a deterministic policy.

        The policy gives 1.0 to the action at ``action_idx[i]`` in state ``i``
        and 0.0 to every other action.
        """
        self._values = {s: float(v[i]) for i, s in enumerate(self._states)}
        policy: Policy = {
            s: {a: (1.0 if j == action_idx[i] else 0.0) for j, a in enumerate(self._actions)}
            for i, s in enumerate(self._states)
        }
        return dict(self._values), policy

    def solve_policy_iter(self, tol: float = 1e-8, max_iter: int = 1_000) -> Tuple[ValueFunction, Policy]:
        """Solve the MDP by policy iteration.

        Alternates iterative policy evaluation (``v <- r_pi + gamma * P_pi v``
        until the max change is below ``tol``) with greedy improvement. Ties keep
        the current action, which prevents oscillation between equally good
        actions. The value estimates in ``self.values`` are replaced.

        Args:
            tol: convergence threshold for policy evaluation and for tie detection.
            max_iter: cap on improvement rounds, and on evaluation sweeps per round
                (guards against non-convergence, e.g. gamma == 1 with reward cycles).

        Returns:
            ``(values, policy)``; the policy is deterministic (one-hot per state).

        Raises:
            ValueError: if the MDP has no actions.
        """
        P, r = self._stacked_model()
        n_states = len(self._states)
        rows = np.arange(n_states)
        policy_idx = np.zeros(n_states, dtype=int)  # start with the first action everywhere
        v = self._values_array()
        for _ in range(max_iter):
            # Policy evaluation for the current deterministic policy.
            P_pi = P[policy_idx, rows]  # (n_states, n_states)
            r_pi = r[policy_idx, rows]  # (n_states,)
            for _sweep in range(max_iter):
                v_new = r_pi + self._gamma * (P_pi @ v)
                converged = np.max(np.abs(v_new - v), initial=0.0) < tol
                v = v_new
                if converged:
                    break
            # Policy improvement: act greedily, but keep the current action on ties.
            q = r + self._gamma * (P @ v)  # (n_actions, n_states)
            best = q.argmax(axis=0)
            keep = q[policy_idx, rows] >= q[best, rows] - tol
            new_idx = np.where(keep, policy_idx, best)
            if np.array_equal(new_idx, policy_idx):
                break
            policy_idx = new_idx
        return self._store_solution(v, policy_idx)

    def solve_value_iter(self, tol: float = 1e-8, max_iter: int = 10_000) -> Tuple[ValueFunction, Policy]:
        """Solve the MDP by value iteration.

        Repeats the Bellman optimality backup
        ``v(i) <- max_a sum_j P[a, i, j] * (R[a, i, j] + gamma * v(j))``
        until the max change is below ``tol``, then reads off the greedy policy.
        The value estimates in ``self.values`` are replaced.

        Args:
            tol: convergence threshold on the max absolute value change.
            max_iter: cap on backups (guards against non-convergence, e.g.
                gamma == 1 with reward cycles).

        Returns:
            ``(values, policy)``; the policy is deterministic (one-hot per state).

        Raises:
            ValueError: if the MDP has no actions.
        """
        P, r = self._stacked_model()
        v = self._values_array()
        for _ in range(max_iter):
            v_new = (r + self._gamma * (P @ v)).max(axis=0)
            converged = np.max(np.abs(v_new - v), initial=0.0) < tol
            v = v_new
            if converged:
                break
        q = r + self._gamma * (P @ v)  # (n_actions, n_states)
        return self._store_solution(v, q.argmax(axis=0))

In [42]:
class StudentState(State):
    """A ``State`` of the student example; it supplies no repr metadata."""

    def __init__(self, name: str) -> None:
        # Only the name is passed on; State's **kwargs are not used here.
        super().__init__(name=name)

    @override
    def _repr_metadata(self) -> str | None:
        """Student states carry no extra metadata."""
        return None

In [32]:
class StudentMarkovDecisionProcess(MarkovDecisionProcess):
    """MDP for the student example, built from one reward per state.

    ``rewards[i]`` becomes the reward of every transition leaving ``states[i]``,
    under every action (via ``set_state_reward``).
    """

    def __init__(
        self,
        states: List[State],
        actions: List[Action],
        rewards: List[float],
        gamma: float = 0.9,
        name: str = "student",
    ) -> None:
        """Build the MDP and assign the per-state rewards.

        Args:
            states: states, in matrix-index order.
            actions: available actions.
            rewards: one reward per state, aligned with ``states``.
            gamma: discount factor (0.9 is an illustrative default, not derived
                from the problem — tune as needed).
            name: label for this MDP.

        Raises:
            ValueError: if ``rewards`` and ``states`` differ in length.
        """
        if len(rewards) != len(states):
            raise ValueError(
                f"expected one reward per state: got {len(rewards)} rewards "
                f"for {len(states)} states"
            )
        super().__init__(name=name, states=states, actions=actions, gamma=gamma)
        for state, reward in zip(states, rewards):
            self.set_state_reward(state, reward)

<__main__.MarkovDecisionProcess at 0x7ac2045342f0>

In [None]:
# Discount factor for the student example (illustrative value; tune as needed).
GAMMA = 0.9

hostel: StudentState = StudentState(name="hostel")
canteen: StudentState = StudentState(name="canteen")
academic_building: StudentState = StudentState(name="academic_building")

eat: Action = Action(name="eat")
attend_class: Action = Action(name="attend_class")

# MarkovDecisionProcess has no defaults: name, states, actions and gamma are all required.
mdp = MarkovDecisionProcess(
    name="student",
    states=[hostel, canteen, academic_building],
    actions=[eat, attend_class],
    gamma=GAMMA,
)

In [40]:
# Scratch check: assigning a scalar to one row broadcasts it across that row
# (the same pattern set_state_reward applies to each reward matrix).
a = np.random.randn(3, 4)
a[1] = 0
a

array([[-0.11287461,  0.08922628, -0.33319502, -1.26860087],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.24872478,  0.02246836,  0.96927444,  0.76323481]])