In [42]:
from dataclasses import dataclass
from typing import Mapping, Dict
from rl.distribution import Categorical
from rl.markov_process import Transition, FiniteMarkovProcess
from scipy.stats import poisson
import numpy as np
import itertools

In [43]:
@dataclass(frozen=True)
class SALState:
    """Immutable state that records a single board position.

    ``frozen=True`` (together with the default ``eq=True``) gives instances a
    value-based ``__hash__``, which is what allows them to be used as
    dictionary keys in the transition map.
    """

    # Board position occupied in this state.
    position: int

    def SAL_position(self) -> int:
        """Return the board position held by this state."""
        pos = self.position
        return pos

In [52]:
class SimpleSALMPFinite(FiniteMarkovProcess[SALState]):
    """Three-state finite Markov process over ``SALState`` positions 1, 2 and 3.

    Transitions:
      * from 1: to 2 or 3 with probability 0.5 each;
      * from 2: to 1 or 3 with probability 0.5 each;
      * from 3: stays at 3 with probability 1.0 (absorbing).
    """

    def __init__(
        self,
        initial_position: int
    ):
        """Build the process.

        Args:
            initial_position: Starting position. It is stored on the instance
                for callers (e.g. to build a start state); it does not affect
                the transition probabilities.
        """
        self.initial_position = initial_position
        # Build the transition map once and reuse it in ``next_state``;
        # previously every simulated step rebuilt all Categorical objects.
        self._sal_transition_map: Dict[SALState, Categorical[SALState]] = \
            self.get_transition_map()
        # The base class receives its own freshly built map, exactly as before.
        super().__init__(self.get_transition_map())

    def get_transition_map(self) -> Transition[SALState]:
        """Return a freshly built transition map.

        Returns:
            A dict mapping each ``SALState`` (positions 1, 2, 3, in that
            insertion order) to a ``Categorical`` over its successor states.
        """
        # successor probabilities keyed by source position; insertion order is
        # kept identical to the original so printed output does not change
        successor_probs: Mapping[int, Mapping[int, float]] = {
            1: {2: 0.5, 3: 0.5},
            2: {1: 0.5, 3: 0.5},
            3: {3: 1.0},  # absorbing state
        }

        d: Dict[SALState, Categorical[SALState]] = {}
        for src, successors in successor_probs.items():
            state_probs_map: Mapping[SALState, float] = {
                SALState(dst): prob for dst, prob in successors.items()
            }
            d[SALState(src)] = Categorical(state_probs_map)

        return d

    def next_state(self, state: SALState) -> SALState:
        """Sample a successor of ``state`` from its transition distribution.

        Raises:
            KeyError: if ``state`` is not one of positions 1, 2 or 3.
        """
        return self._sal_transition_map[state].sample()

In [53]:
def simulation(process, start_state):
    """Lazily yield an unbounded sample path of ``process``.

    The first value produced is ``start_state`` itself. Each later value is
    obtained by calling ``process.next_state`` on the previously yielded one.
    The generator never terminates, so callers must truncate it (e.g. with
    ``itertools.islice``).
    """
    current = start_state
    while True:
        yield current
        # advance only after the consumer has received the current state
        current = process.next_state(current)

In [54]:
def traces(time_steps: int, num_traces: int, initial_position: int = 1) -> np.ndarray:
    """Simulate independent sample paths of ``SimpleSALMPFinite``.

    Args:
        time_steps: Number of transitions per trace. Each row has
            ``time_steps + 1`` entries because the start state is included.
        num_traces: Number of independent traces (rows) to generate.
        initial_position: Starting position for every trace. Defaults to 1,
            the value that was previously hard-coded.

    Returns:
        A float array of shape ``(num_traces, time_steps + 1)`` with the
        visited positions. When ``num_traces`` is 0, an empty array of shape
        ``(0, time_steps + 1)`` is returned; ``np.vstack`` raises on an empty
        list.
    """
    # instantiate the Markov process (it has no actions, so it is not an MDP)
    process = SimpleSALMPFinite(initial_position=initial_position)

    # every trace starts from the process's configured initial position
    start_state = SALState(process.initial_position)
    path_length = time_steps + 1

    rows = [
        np.fromiter(
            (s.position for s in itertools.islice(
                simulation(process, start_state),
                path_length
            )),
            float,
        )
        for _ in range(num_traces)
    ]

    if not rows:
        return np.empty((0, path_length), dtype=float)
    return np.vstack(rows)


In [55]:
# Starting position used to build the demo process below.
initial_position: int = 1

In [56]:
# Build the three-state process; initial_position is stored on the instance
# and does not change the transition probabilities.
si_mp = SimpleSALMPFinite(initial_position=initial_position)

In [57]:
# Header lines followed by the process's own string rendering (one print,
# newline-separated: same text as three consecutive print calls).
print("Transition Map", "--------------", si_mp, sep="\n")

Transition Map
--------------
From State SALState(position=1):
  To State SALState(position=2) with Probability 0.500
  To State SALState(position=3) with Probability 0.500
From State SALState(position=2):
  To State SALState(position=1) with Probability 0.500
  To State SALState(position=3) with Probability 0.500
From State SALState(position=3):
  To State SALState(position=3) with Probability 1.000



In [58]:
# Header, then the process's stationary distribution.
print("Stationary Distribution", "-----------------------", sep="\n")
si_mp.display_stationary_distribution()

Stationary Distribution
-----------------------
{SALState(position=3): 1.0,
 SALState(position=2): 0.0,
 SALState(position=1): 0.0}


In [60]:
# Demo parameters: transitions per trace and number of traces.
T = 5
num_traces = 5

print("Run traces")
tr = traces(T, num_traces)
print(tr)

Run traces
[[1. 2. 3. 3. 3. 3.]
 [1. 2. 3. 3. 3. 3.]
 [1. 3. 3. 3. 3. 3.]
 [1. 2. 1. 3. 3. 3.]
 [1. 2. 3. 3. 3. 3.]]
