import pomdp_py
import random


class State(pomdp_py.State):
    """A commute-domain state, identified entirely by its name string.

    Equality and hashing delegate to ``name`` so states can be used as
    histogram/dict keys.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return isinstance(other, State) and self.name == other.name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"State: {self.name}"


class Action(pomdp_py.Action):
    """A commute-domain action, identified entirely by its name string.

    Hash/equality are name-based so actions compare across instances.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return isinstance(other, Action) and self.name == other.name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"Action: {self.name}"


class Observation(pomdp_py.Observation):
    """A commute-domain observation, identified entirely by its name string.

    Hash/equality are name-based so observations compare across instances.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return isinstance(other, Observation) and self.name == other.name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"Observation: {self.name}"


class ObservationModel(pomdp_py.ObservationModel):
    """Fully-observable observation model: the observation always carries
    the exact name of the next state.

    Fix: the observation list previously contained "medium-light", a name
    that appears nowhere else in the file — every other model (transition,
    reward, initial belief) uses "medium-traffic".
    """

    def probability(self, observation, next_state, action):
        """P(observation | next_state, action): 1.0 iff the names match."""
        return 1.0 if observation.name == next_state.name else 0.0

    def sample(self, next_state, action):
        """Deterministically observe the new state's own name."""
        return Observation(next_state.name)

    def get_all_observations(self):
        """One observation per domain state ("medium-light" typo corrected
        to "medium-traffic")."""
        return [Observation(s) for s in ("home", "wait-room", "train",
                                         "light-traffic", "medium-traffic",
                                         "heavy-traffic", "work")]


class TransitionModel(pomdp_py.TransitionModel):
    """Commute-domain dynamics.

    Fixes:
    * ``get_all_states`` listed "medium-light" instead of "medium-traffic"
      (the name used by ``probability``, ``sample``, the reward model, and
      the initial belief) — corrected.
    * ``probability`` and ``sample`` previously duplicated the same
      distributions in two long if/elif ladders; both now read a single
      shared table, so they cannot drift out of sync.
    """

    # (state name, action name) -> ((next state name, probability), ...).
    # Any pair not listed is an undefined move: probability 0 for every
    # successor, and sample() leaves the state unchanged.
    _TRANSITIONS = {
        ("home", "rail"): (("wait-room", 0.1), ("train", 0.9)),
        ("home", "car"): (("light-traffic", 0.2), ("medium-traffic", 0.7),
                          ("heavy-traffic", 0.1)),
        ("home", "cycle"): (("work", 1.0),),
        ("wait-room", "wait"): (("wait-room", 0.1), ("train", 0.9)),
        ("wait-room", "go-home"): (("home", 1.0),),
        ("train", "relax"): (("work", 1.0),),
        ("light-traffic", "drive"): (("work", 1.0),),
        ("medium-traffic", "drive"): (("work", 1.0),),
        ("heavy-traffic", "drive"): (("work", 1.0),),
    }

    def probability(self, next_state, state, action):
        """Return P(next_state | state, action)."""
        # "work" is absorbing no matter which action is taken.
        if state.name == "work":
            return 1.0 if next_state.name == "work" else 0.0
        for name, prob in self._TRANSITIONS.get((state.name, action.name), ()):
            if next_state.name == name:
                return prob
        return 0.0

    def sample(self, state, action):
        """Draw a successor state from P(. | state, action)."""
        if state.name == "work":
            return State("work")
        outcomes = self._TRANSITIONS.get((state.name, action.name))
        if outcomes is None:
            # Undefined (state, action) pair: stay put, matching the
            # original fall-through behavior.
            return state
        if len(outcomes) == 1:
            return State(outcomes[0][0])
        draw = random.uniform(0, 1)
        cumulative = 0.0
        for name, prob in outcomes:
            cumulative += prob
            if draw < cumulative:
                return State(name)
        # Guard against floating-point round-off at the final boundary.
        return State(outcomes[-1][0])

    def get_all_states(self):
        """Enumerate every state ("medium-light" typo corrected)."""
        return [State(s) for s in ("home", "wait-room", "train",
                                   "light-traffic", "medium-traffic",
                                   "heavy-traffic", "work")]


class RewardModel(pomdp_py.RewardModel):
    """Deterministic cost model: every reward is a negative travel cost."""

    # (state name, action name) -> reward for taking that action there.
    _REWARDS = {
        ("home", "rail"): -2,
        ("home", "car"): -1,
        ("home", "cycle"): -45,
        ("wait-room", "wait"): -3,
        ("wait-room", "go-home"): -2,
        ("train", "relax"): -35,
        ("light-traffic", "drive"): -20,
        ("medium-traffic", "drive"): -30,
        ("heavy-traffic", "drive"): -70,
    }

    def sample(self, state, action, next_state):
        """Return the reward; undefined (state, action) pairs are heavily
        penalized with -1000, exactly as in the original elif chain."""
        return self._REWARDS.get((state.name, action.name), -1000)


class PolicyModel(pomdp_py.RolloutPolicy):
    """Uniform-random rollout policy over each state's legal actions."""

    # Legal action names per non-terminal state.
    _LEGAL_ACTIONS = {
        "home": ("rail", "car", "cycle"),
        "wait-room": ("wait", "go-home"),
        "train": ("relax",),
        "light-traffic": ("drive",),
        "medium-traffic": ("drive",),
        "heavy-traffic": ("drive",),
    }
    # Full action vocabulary, used when no state constraint applies.
    _ALL_ACTIONS = ("rail", "car", "cycle", "wait", "go-home", "relax",
                    "drive")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def sample(self, state):
        """Pick one legal action uniformly at random."""
        return random.choice(self.get_all_actions(state))

    def rollout(self, state, history=None):
        """Rollout action: a uniformly random legal action (history unused)."""
        return self.sample(state)

    def get_all_actions(self, state=None, history=None):
        """List the actions available in ``state``; the whole vocabulary
        when no state is given or the state is terminal ("work")."""
        if state is None or state.name == "work":
            return [Action(a) for a in self._ALL_ACTIONS]
        names = self._LEGAL_ACTIONS.get(state.name)
        if names is None:
            # Unknown state name: mirror the original implicit fall-through.
            return None
        return [Action(a) for a in names]


def main():
    """Build the commute POMDP and run a short plan-act-observe loop."""
    init_true_state = State("home")

    # The agent starts fully certain it is at home: probability mass 1.0
    # on "home", 0.0 everywhere else.
    belief_dist = {State(name): 0.0
                   for name in ("wait-room", "train", "light-traffic",
                                "medium-traffic", "heavy-traffic", "work")}
    belief_dist[State("home")] = 1.0
    init_belief = pomdp_py.Histogram(belief_dist)

    agent = pomdp_py.Agent(init_belief, PolicyModel(), TransitionModel(),
                           ObservationModel(), RewardModel())
    env = pomdp_py.Environment(init_true_state, TransitionModel(),
                               RewardModel())
    agent.set_belief(init_belief, prior=True)

    planner = pomdp_py.ValueIteration(horizon=1, discount_factor=0.9)

    num_steps = 2
    for step in range(1, num_steps + 1):
        action = planner.plan(agent)

        print("==== Step %d ====" % step)
        print("True state:", env.state)
        print("Belief:", agent.cur_belief)
        print("Action:", action)

        # Advance the true environment state and collect the reward.
        reward = env.state_transition(action, execute=True)
        print("Reward:", reward)

        # Observe the (post-transition) environment state.
        observation = agent.observation_model.sample(env.state, action)
        print(">> Observation: %s" % observation)

        # Record the step, let the planner update, then do a full
        # histogram belief update.
        agent.update_history(action, observation)
        planner.update(agent, action, observation)
        new_belief = pomdp_py.update_histogram_belief(
            agent.cur_belief,
            action,
            observation,
            agent.observation_model,
            agent.transition_model)
        agent.set_belief(new_belief)


if __name__ == "__main__":
    main()