# This cell has two classes:
1. `EnvChannel`: Defines the environment
2. `ControlAgent`: Defines the control agent

# Parameters:
1. `alpha`: Learning rate of the agent
2. `gamma`: Discount factor
3. `epsilon`: Exploration probability of the epsilon-greedy policy
4. `num_states`: Number of states in the system; currently three (Good, Medium, Bad)
5. `d1`: Delay threshold between the Good and Medium states
6. `d2`: Delay threshold between the Medium and Bad states
7. `random_actions`: If `True`, the agent always takes random actions and does not update its Q-table

# Return:
`get_signal` returns a tuple `(action, state)`:
1. `action (int)`: -1: Decrease resolution; 0: No change; 1: Increase resolution
2. `state (int)`: 0: Bad; 1: Medium; 2: Good


# Arguments (of `get_signal`):
1. `delay_list`: List of packet delays; their average determines the state.
2. `avg_confidence`: (float) Average confidence; added to the reward.
    
# Sample use:

Initialize a `ControlAgent`:
```
agent = ControlAgent(d1=10, d2=20)
```
To train the agent and get its (epsilon-greedy) actions:
```action, state = agent.get_signal(delay_list, avg_confidence)```

To get random actions, create the agent with `random_actions=True`:
```
agent = ControlAgent(d1=10, d2=20, random_actions=True)
action, state = agent.get_signal(delay_list, avg_confidence)
```

# To Do:
1. ~~Instead of taking state_list directly, estimate the states indirectly from packet delay.~~ **Done**
2. ~~Update the function get_delay_factor.~~ **Done** 


In [None]:
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np


class EnvChannel:
    """Channel environment whose state is derived from the average packet delay.

    The state is estimated by ``estimate_state``:
    ``avg_delay <= d1`` -> 2 (Good), ``avg_delay <= d2`` -> 1 (Medium),
    otherwise 0 (Bad).

    Args:
        num_states: Number of channel states.
        d1: Delay threshold between the Good and Medium states.
        d2: Delay threshold between the Medium and Bad states.
    """

    def __init__(self, num_states=3, d1=.01, d2=.02):
        self.num_states = num_states
        self.states = np.arange(self.num_states)
        # Random initial state; overwritten by the first estimate_state() call.
        self.curr_state = np.random.choice(self.states)
        self.prev_state = -1
        self.reward = 0
        self.action = 0
        self.num_actions = 3
        self.valid_actions = [-1, 0, 1]  # Reduce, No change, Increase
        self.d1 = d1
        self.d2 = d2
        self.avg_delay = 0

    def sample_action(self):
        """Return a uniformly random action index in ``[0, num_actions)``.

        The index selects an entry of ``valid_actions``.
        """
        return np.random.randint(self.num_actions)

    def step(self, action, partial_reward):
        """Record ``action`` and return its reward.

        The reward is ``-10 * avg_delay + partial_reward``. ``avg_delay`` is the
        value computed by the most recent ``estimate_state`` call.
        """
        self.action = action
        self.reward = -10 * self.avg_delay + partial_reward
        return self.reward

    def estimate_state(self, delay_list):
        """Classify the channel from ``delay_list`` and return the new state.

        Returns:
            2 (Good) if the average delay is <= d1, 1 (Medium) if it is <= d2,
            otherwise 0 (Bad).
        """
        # Sum of all delays divided by the largest dimension; for a 1-D
        # delay_list this is the mean delay.
        # NOTE(review): an empty delay_list gives 0/0 = nan, which falls
        # through to the Bad state.
        self.avg_delay = np.sum(delay_list) / np.max(np.shape(delay_list))
        if self.avg_delay <= self.d1:
            self.curr_state = 2  # Good state
        elif self.avg_delay <= self.d2:
            self.curr_state = 1  # Medium state
        else:
            self.curr_state = 0  # Bad state
        return self.curr_state

    def reset(self):
        """Re-initialise the environment, keeping its configuration.

        Returns:
            The new (randomly chosen) current state.
        """
        # Pass the thresholds through explicitly; otherwise __init__ would
        # silently revert them to their default values.
        self.__init__(num_states=self.num_states, d1=self.d1, d2=self.d2)
        return self.curr_state


class ControlAgent:
    """Tabular Q-learning agent that chooses resolution changes from packet delay.

    Each ``get_signal`` call does the following:
    - estimates the channel state from the observed delays;
    - scores the previous action;
    - updates the Q-table, unless ``random_actions`` is set;
    - picks the next action epsilon-greedily.

    Args:
        d1: Delay threshold between the Good and Medium states.
        d2: Delay threshold between the Medium and Bad states.
        alpha: Learning rate.
        gamma: Discount factor.
        epsilon: Probability of taking a random (exploratory) action.
        num_states: Number of channel states.
        random_actions: If True, always act randomly and never update the
            Q-table.
        checkpoint_dir: Directory in which Q-table checkpoints are written.
            It is created if it does not exist.
        checkpoint_every: Write a checkpoint every this many iterations.
            0 or None disables checkpointing.
    """

    def __init__(self,
                 d1,
                 d2,
                 alpha=0.1,
                 gamma=.99,
                 epsilon=.95,
                 num_states=3,
                 random_actions=False,
                 checkpoint_dir="data",
                 checkpoint_every=100):
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.env = EnvChannel(num_states, d1, d2)
        # One row per state, one column per action index.
        self.q_table = np.zeros([self.env.num_states, self.env.num_actions])
        self.all_epochs = []
        # Running sum of rewards collected while learning.
        self.penalties = 0
        self.iteration_i = 0
        self.prev_state = None
        self.prev_action = None
        self.prev_avg_delay = 0
        self.random_actions = random_actions
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_every = checkpoint_every

    def get_delay_factor(self):
        """Return the average delay from the environment's last state estimate."""
        # The average delay lives on the environment, not on the agent.
        return self.env.avg_delay

    def get_signal(self, delay_list, avg_confidence):
        """Observe new delays, learn from the previous action, and pick the next one.

        Args:
            delay_list: Packet delays observed since the previous call.
            avg_confidence: Average confidence. It is added to the reward of
                the previous action.

        Returns:
            Tuple ``(action, state)``.
            ``action`` is -1 (decrease resolution), 0 (no change) or
            1 (increase resolution).
            ``state`` is 0 (Bad), 1 (Medium) or 2 (Good).
        """
        state = self.env.estimate_state(delay_list)
        self.iteration_i += 1

        if self.iteration_i == 1:
            # Nothing to learn from yet: start with a random action.
            action = self.env.sample_action()
        else:
            # The delays just observed are the outcome of the previous action.
            reward = self.env.step(self.prev_action, avg_confidence)

            if self.random_actions:
                action = self.env.sample_action()  # Explore action space
            else:
                self._update_q_table(reward, state)
                self.penalties += reward

                if np.random.uniform(0, 1) < self.epsilon:
                    action = self.env.sample_action()  # Explore action space
                else:
                    # Exploit learned values
                    action = np.argmax(self.q_table[state, :])

            if (self.checkpoint_every
                    and self.iteration_i % self.checkpoint_every == 0):
                self._save_checkpoint()

        self.prev_state = state
        self.prev_action = action
        return self.env.valid_actions[action], state

    def _update_q_table(self, reward, state):
        """Apply one Q-learning update to the (prev_state, prev_action) entry."""
        old_qvalue = self.q_table[self.prev_state, self.prev_action]
        # Bootstrap from the best Q-*value* of the new state (max, not argmax).
        next_max = np.max(self.q_table[state, :])
        self.q_table[self.prev_state, self.prev_action] = (
            (1 - self.alpha) * old_qvalue
            + self.alpha * (reward + self.gamma * next_max))

    def _save_checkpoint(self):
        """Pickle ``[q_table, penalties]`` to ``<checkpoint_dir>/iteration_<i>``."""
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.checkpoint_dir, f"iteration_{self.iteration_i}")
        with open(path, "wb") as fp:
            pickle.dump([self.q_table, self.penalties], fp)