# Assignment 4 - Problem 3

## Job-Hopping and Wages-Utility Maximization

On a given day, you are either unemployed or working one of $n$ jobs. The **state space** is therefore $S = \{0, 1, \dots, n\}$, where state 0 denotes unemployment and states $1, \dots, n$ denote the $n$ jobs. The **action space** is $\mathcal{A} = \{A, R\}$, since you can either accept ($A$) or reject ($R$) an offered job.

We have the following **transition probabilities** $\mathcal{P}(s,a,s')$:

$$\mathcal{P}(0, A, 0) = 0, \quad \mathcal{P}(0, A, s') = p_{s'} \text{ for all } s' \in \{1, \dots, n\} \\ \text{ since if you are unemployed and accept, you move to job } s' \text{ with job-offer probability } p_{s'}$$

$$\mathcal{P}(0, R, 0) = 1, \quad \mathcal{P}(0, R, s') = 0 \text{ for all } s' \neq 0 \\ \text{ since if you are unemployed and reject a job, you are again unemployed}$$

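For an employed state $s \in \{1, \dots, n\}$:

$$\mathcal{P}(s, R, s) = 1, \quad \mathcal{P}(s, R, s') = 0 \text{ for all } s' \neq s \\ \text{ since if you are employed there is no offer to reject, so } R \text{ leaves you in your current job}$$

$$\mathcal{P}(s, A, 0) = \alpha, \quad \mathcal{P}(s, A, s) = 1-\alpha, \quad \mathcal{P}(s, A, s') = 0 \text{ for all } s' \notin \{0, s\} \\ \text{ since you lose your job with prob } \alpha \text{ and stay at your current job with prob } 1-\alpha$$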
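A quick way to sanity-check these definitions is to confirm that every row $\mathcal{P}(s, a, \cdot)$ sums to 1. The sketch below is only illustrative: `transition_row` is a hypothetical helper, and the offer probabilities and $\alpha$ are assumed example numbers. The rows for $(0, A)$ sum to 1 only if $p_1 + \dots + p_n = 1$.

```python
from typing import List

def transition_row(s: int, a: str, probs: List[float], alpha: float) -> List[float]:
    """Return [P(s, a, s') for s' = 0..n]. Hypothetical helper; probs[0] is ignored."""
    row = [0.0] * len(probs)
    if s == 0 and a == 'A':
        row = list(probs)       # unemployed + accept: move to job s' w.p. p_{s'}
        row[0] = 0.0            # never remain unemployed after accepting
    elif a == 'R':
        row[s] = 1.0            # rejecting keeps you in the current state
    else:                       # employed + 'A'
        row[0] = alpha          # lose the job
        row[s] = 1.0 - alpha    # keep the job
    return row

# Assumed example numbers: 5 jobs whose offer probabilities sum to 1.
probs = [0.0, 0.1, 0.2, 0.4, 0.1, 0.2]
alpha = 0.1
row_sums = [sum(transition_row(s, a, probs, alpha))
            for s in range(len(probs)) for a in ('A', 'R')]
print(all(abs(total - 1.0) < 1e-9 for total in row_sums))
```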


We have the following **reward functions** $\mathcal{R}(s,a)$:

$$\mathcal{R}(0, A) = \log(w_{1} \cdot p_{1} + \dots + w_{n} \cdot p_{n})$$
$$\mathcal{R}(0, R) = \log(w_{0})$$
$$\mathcal{R}(s, A) = \log(w_{s}) \text{ for all } s \in \{1, \dots, n\}$$
$$\mathcal{R}(s, R) = \log(w_{s}) \text{ for all } s \in \{1, \dots, n\}$$

Finally, we can use the **Bellman Optimality equation**:
\begin{equation} 
     V(s) = \max_{a\in \mathcal{A}}\left\{\mathcal{R}(s,a) + \gamma \cdot \sum_{s'\in S} \mathcal{P}(s,a,s') \cdot V(s') \right\}
\end{equation}
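As a minimal sketch of what one application of this equation looks like for a single state, the function below computes the maximizing value and action given callables for $\mathcal{R}$ and $\mathcal{P}$. The name `bellman_backup` and its signature are assumptions made for illustration, not part of the assignment.

```python
from typing import Callable, Sequence, Tuple

def bellman_backup(
    s: int,
    vf: Sequence[float],
    states: Sequence[int],
    actions: Sequence[str],
    reward: Callable[[int, str], float],
    trans_prob: Callable[[int, str, int], float],
    gamma: float,
) -> Tuple[float, str]:
    """Return (max_a Q(s, a), argmax_a Q(s, a)) for one state s (hypothetical helper)."""
    best_value, best_action = float("-inf"), ""
    for a in actions:
        # Q(s, a) = R(s, a) + gamma * sum_{s'} P(s, a, s') * V(s')
        q = reward(s, a) + gamma * sum(trans_prob(s, a, s2) * vf[s2] for s2 in states)
        if q > best_value:  # strict '>' keeps the first action on ties
            best_value, best_action = q, a
    return best_value, best_action
```

Value iteration repeats this backup over all states until the value function stops changing.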

In [133]:
from dataclasses import dataclass, field
from typing import List, Callable, Mapping, Tuple, Dict, TypeVar
import numpy as np
from operator import itemgetter 


In [196]:
@dataclass
class WageMDP():
    gamma: float
    alpha: float
    states: List[int] #jobs

    wages: List[float]
    probs: List[float]
    actions: List[str] = field(default_factory= lambda:['A', 'R'])

    def trans_prob(self, state: int, action: str, next_state: int) -> float:
        # Unemployed: accepting moves you to job s' with probability p_{s'};
        # rejecting keeps you unemployed.
        if state == 0:
            if action == 'A':
                return 0.0 if next_state == 0 else self.probs[next_state]
            return 1.0 if next_state == 0 else 0.0
        # Employed, action 'A': lose the job w.p. alpha, otherwise keep it.
        if action == 'A':
            if next_state == state:
                return 1 - self.alpha
            if next_state == 0:
                return self.alpha
            return 0.0
        # Employed, action 'R': stay in the current job with certainty.
        return 1.0 if next_state == state else 0.0
    

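    def reward(self, state: int, action: str) -> float:
        # Unemployed and accepting: log of the expected offered wage.
        # State 0 contributes nothing to the sum as long as probs[0] == 0.
        if (state, action) == (0, 'A'):
            exp_wage = sum(self.wages[s] * self.probs[s] for s in self.states)
            return np.log(exp_wage)
        # Every other case: log of the current wage (w_0 when unemployed).
        return np.log(self.wages[state])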

    def value_iteration(self) -> Tuple[List[float], List[str]]:
        # Start from V = 0 everywhere and an empty policy.
        vf = [0 for _ in self.states]
        pi = ['' for _ in self.states]
        tol = 1e-6

        while True:
            next_vf = vf.copy()
            for s in self.states:
                # Bellman optimality backup: best (value, action) pair for state s.
                next_vf_pi = max([(self.reward(s, a) + sum(self.gamma*self.trans_prob(s, a, next_s) * vf[next_s] 
                                                           for next_s in self.states), a)
                                  for a in self.actions], key = itemgetter(0))
                next_vf[s] = next_vf_pi[0]
                pi[s] = next_vf_pi[1]
            # Stop once successive value functions differ by less than tol (L2 norm).
            if np.linalg.norm(np.array(next_vf) - np.array(vf)) < tol:
                return vf, pi
            vf = next_vf
    

In [201]:
gamma: float = 0.1
alpha: float = 0.1
states: List[int] = [0, 1, 2, 3, 4, 5]
wages: List[float] = [2, 2, 100, 100, 1, 3]
probs: List[float] = [0, 0.1, 0.2, 0.4, 0.1, 0.2]
    
mdp = WageMDP(
    gamma=gamma,
    alpha=alpha,
    states = states,
    wages = wages,
    probs = probs
)

In [202]:
mdp.value_iteration()

([4.449962479933997,
  0.8106006537576071,
  5.116855250523415,
  5.116855250523415,
  0.04890049177196922,
  1.2561666853124358],
 ['A', 'A', 'R', 'R', 'A', 'A'])
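One way to check part of this output by hand: in any state where $R$ is the chosen action, the Bellman equation reduces to $V(s) = \log(w_s) + \gamma V(s)$, so $V(s) = \log(w_s) / (1 - \gamma)$. A small sketch of that check for the two jobs with wage 100, using the $\gamma$ and the printed value above (the variable names are only illustrative):

```python
import math

gamma = 0.1
wage = 100

# Closed form for a state where 'R' (stay with certainty) is chosen forever.
v_closed_form = math.log(wage) / (1 - gamma)

# Value reported by value_iteration() for states 2 and 3.
v_reported = 5.116855250523415

# The two should agree up to roughly the stopping tolerance of the iteration.
print(abs(v_closed_form - v_reported) < 1e-5)
```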

In [209]:
# The policies don't really make sense... the state space may need to be changed to separate out unemployment status?