# Piero Pettenà - RL project  

In [None]:
import itertools
import math
import random

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Maps a distance from the datum onto one of three classes, relative to the grid size L:
#   class 0 -> dist < L/10
#   class 1 -> L/10 <= dist < L/5
#   class 2 -> L/5 <= dist <= L
# Anything else (dist > L, or NaN) is reported on stdout and mapped to -1.
def get_dist_class(dist: float, L: int):
    if dist < L / 10:
        return 0
    if dist < L / 5:
        return 1
    if dist <= L:
        return 2
    print(f"Error: distance out of bounds: dist={dist}")
    return -1


# Maps a remaining time onto one of four classes, relative to the time budget B:
#   class 0 -> time < B/4
#   class 1 -> B/4 <= time < B/2
#   class 2 -> B/2 <= time < 3*B/4
#   class 3 -> 3*B/4 <= time <= B
# Anything else (time > B, or NaN) is reported on stdout and mapped to -1.
def get_time_class(time: int, B: int):
    upper_bounds = (B / 4, B / 2, 3 * B / 4)
    for klass, bound in enumerate(upper_bounds):
        if time < bound:
            return klass
    if time <= B:
        return 3
    print(f"Error: time out of bounds: time={time}")
    return -1

# Discretizes a continuous (distance, remaining time) pair into
# (distance class, time class). The single grid_size is used as the
# distance scale, so this assumes a SQUARE grid.
def discretize_state(grid_size: int, budget: int, ddist: float, rem_time: int):
    return (
        get_dist_class(dist=ddist, L=grid_size),
        get_time_class(time=rem_time, B=budget),
    )

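### Definition of a state
A state is the discretized triple (distance class, direction, time class):
1. Distance of the agent from the datum (3 distance classes)
2. Direction of search movement (index of the last action taken)
3. Remaining time to find the target (4 time classes)

### Definition of the Q function
The Q function assigns a value to each (state, action) pair, so it is indexed by:
1. Distance of the agent from the datum
2. Direction of search movement
3. Remaining time to find the target
4. Possible action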

In [None]:
# TYPICAL (GRID)WORLD


class World:
    """World environment class. It contains the grid, the datum and
    goal positions, info on the current and the explored cells of the grid.

    NOTE: should add a variable that improves visibility of adjacent cells

    Attributes:
        grid:                   grid of the world, shape (Lx, Ly). A cell contains -1 if not explored,
                                -2 if explored, 10 if goal
        goal:                   position of target in grid (index into grid)
        datum:                  position of datum (reference) in grid; the centre cell
        current:                current offset (future implementation)
        actions:                List of action that the agent can take. I want these to be part of the environment
        randomExplorerFlag:     if true, action is picked at random each time
        rem_time:               remaining time to find the target
        budget:                 initial budget of time to find the target (will not be updated)
        states:                 states q values of the table
    """

    def __init__(
        self,
        Ly: int = 20,
        Lx: int = 20,
        goal: "np.ndarray | None" = None,
        rem_time: int = 1000,
    ):
        """Build an Lx-by-Ly grid world.

        Args:
            Ly: size of the grid along the second index.
            Lx: size of the grid along the first index.
            goal: index of the target cell. None, or the legacy sentinel
                [0, 0], places the goal in a uniformly random cell.
            rem_time: time budget for the search (copied into ``budget``).
        """
        # None replaces the shared numpy default argument; [0, 0] keeps its
        # historical meaning of "random goal" so existing callers behave the same.
        if goal is None or np.array_equal(goal, [0, 0]):
            goal = np.array([np.random.randint(Lx), np.random.randint(Ly)])
        self.goal = np.asarray(goal)

        # Every cell which has not been explored contains -1
        self.grid = np.full(shape=(Lx, Ly), fill_value=-1)
        # Goal cell contains 10. Mark self.goal (not the raw argument) so that a
        # randomly drawn goal is the cell that actually gets marked.
        self.grid[self.goal[0], self.goal[1]] = 10

        # Actions = [Up, Down, Right, Left]
        self.actions = np.array([[1, 0], [-1, 0], [0, 1], [0, -1]])
        self.randomExplorerFlag = True
        self.datum = np.array([Lx // 2, Ly // 2])
        self.budget = rem_time
        self.rem_time = rem_time
        self.current = np.array([0, 0])  # maybe for addition of current
        # state is defined as a 3d numpy array of (distance, dir, time)
        # TODO: turn (3, n_actions, 4) into configurable numbers of distance
        # classes, directions and time classes.
        self.states = np.zeros(shape=(3, len(self.actions), 4))

    def plot_world(self):
        """Draw the grid: white = unexplored, gray = explored, red = goal.

        Each cell is annotated with its numeric value; cell (i, j) is drawn
        at x=j, y=i.
        """
        # Define the colors for each value in the array
        colors = {-2: "gray", -1: "white", 10: "red"}

        fig, ax = plt.subplots()

        for i in range(self.grid.shape[0]):
            for j in range(self.grid.shape[1]):
                value = self.grid[i, j]
                # Default to white if the value is not in the colors dictionary
                color = colors.get(value, "white")
                # facecolor (not color): passing `color` together with
                # `edgecolor` makes matplotlib warn and ignore the black edge.
                ax.add_patch(
                    plt.Rectangle(
                        (j, i), 1, 1, fill=True, facecolor=color, edgecolor="black"
                    )
                )

                # Display the cell value in the center
                ax.text(
                    j + 0.5,
                    i + 0.5,
                    str(value),
                    color="black",
                    ha="center",
                    va="center",
                    fontsize=10,
                )

        # Set axis limits
        ax.set_xlim(0, self.grid.shape[1])
        ax.set_ylim(0, self.grid.shape[0])

        # Major ticks on the cell borders
        ax.set_xticks(np.arange(0, self.grid.shape[1] + 1, 1))
        ax.set_yticks(np.arange(0, self.grid.shape[0] + 1, 1))

        # Remove minor ticks
        ax.set_xticks([], minor=True)
        ax.set_yticks([], minor=True)

        # Set grid lines to be at major ticks
        ax.grid(which="major", color="black", linewidth=2)

        plt.show()


In [None]:
class Agent:
    """Agent = explorer that moves on the grid of a World.

    Attributes:
        n_actions:      number of actions available in the world
        Lx, Ly:         grid size along the first / second index of world.grid
        pos:            current position of agent (index into world.grid)
        visibility:     how far the agent can see (1 means only adjacent cells); not used yet
        env:            the World being explored (a shared reference, not a copy)
        ddist:          distance from datum at the last call of get_current_state
        dir:            direction of movement (= last action); [0, 0] before the first move
        state_space:    all (distance class, direction index, time class) triples
        actions:        actions shared with the world
        q_table:        Q(state, action), one row per entry of state_space
    """

    # Number of classes produced by get_dist_class / get_time_class.
    N_DIST_CLASSES = 3
    N_TIME_CLASSES = 4
    # Cell values used by World: goal marker and "already visited".
    GOAL_VALUE = 10
    VISITED_VALUE = -2

    def __init__(self, world: "World"):
        self.n_actions = len(world.actions)
        # World builds its grid with shape (Lx, Ly): index 0 runs over Lx and
        # index 1 over Ly, which is what the bounds checks below rely on.
        self.Lx, self.Ly = world.grid.shape

        # Copy so the agent's position never aliases world.datum.
        self.pos = np.copy(world.datum)  # initial position is the datum
        self.visibility = 1  # TODO: let the agent see more than adjacent cells
        self.env = world  # reference to the environment, not a copy
        self.ddist = 0  # distance from datum
        self.dir = np.array([0, 0])
        self.state_space = np.array(
            list(
                itertools.product(
                    range(self.N_DIST_CLASSES),
                    range(self.n_actions),
                    range(self.N_TIME_CLASSES),
                )
            )
        )
        self.actions = world.actions
        self.q_table = np.zeros((len(self.state_space), self.n_actions))

    # Returns the 4 nearby cells to visit. This could be extended to diagonal movement.
    def get_nearby_cells(self):
        top = self.pos + [1, 0]
        bottom = self.pos + [-1, 0]
        right = self.pos + [0, 1]
        left = self.pos + [0, -1]
        nearby = np.array([bottom, top, left, right])

        return nearby

    def remove_outside_cells(self, coordinates: np.ndarray):
        """
        Remove coordinates that fall outside the agent's Lx-by-Ly grid.

        Parameters:
        - coordinates: numpy array of shape (n, 2) representing 2D integer coordinates.

        Returns:
        - filtered_coordinates: numpy array with the coordinates inside the grid.

        Raises:
        - ValueError if coordinates is not a NumPy array of shape (n, 2).
        """
        if not isinstance(coordinates, np.ndarray):
            raise ValueError("Input 'coordinates' must be a NumPy array.")

        if len(coordinates.shape) != 2 or coordinates.shape[1] != 2:
            raise ValueError(
                "Input 'coordinates' must be a 2D array with shape (n, 2)."
            )

        valid_mask = (
            (0 <= coordinates[:, 0])
            & (coordinates[:, 0] < self.Lx)
            & (0 <= coordinates[:, 1])
            & (coordinates[:, 1] < self.Ly)
        )
        return coordinates[valid_mask]

    def _valid_actions(self):
        """Return the rows of self.actions whose target cell lies inside the grid."""
        targets = self.pos + self.actions
        inside = (
            (targets[:, 0] >= 0)
            & (targets[:, 0] < self.Lx)
            & (targets[:, 1] >= 0)
            & (targets[:, 1] < self.Ly)
        )
        return self.actions[inside]

    def chooseAction(self):
        """Choose an action that keeps the agent inside the grid.

        With randomExplorerFlag set, the action is drawn uniformly among the
        moves whose target cell is inside the grid (visited cells are allowed).
        Otherwise a message is printed and -5 is returned.

        Raises:
            RuntimeError: if no action keeps the agent inside the grid
                (e.g. a 1x1 grid). Previously this looped forever.
        """
        if not self.env.randomExplorerFlag:
            # NOTE(review): update_position would add this -5 to the position;
            # replace with a real policy (e.g. epsilon-greedy on q_table).
            print("Needs implementation for randomExplorerFlag = False. Returning -5")
            return -5  # clear error

        valid_actions = self._valid_actions()
        if len(valid_actions) == 0:
            raise RuntimeError("No action keeps the agent inside the grid.")
        return random.choice(list(valid_actions))

    # Returns the reward of the system: the value of the cell the agent is on
    # (10 on the goal, -2 on an explored cell). Call it right after a move.
    def reward(self):
        # idea: reward is also influenced by the distance
        # reward = reward - self.ddist
        return self.env.grid[self.pos[0], self.pos[1]]

    def update_position(self):
        """Choose an action, move the agent, record direction and visited cell.

        The goal cell keeps its marker so that reward() can still see it after
        the agent steps on it; every other visited cell is set to -2.

        Returns:
            The action that was taken.
        """
        action = np.array([0, 0])
        while np.array_equal(action, [0, 0]):
            action = self.chooseAction()

        self.pos = self.pos + action
        self.dir = action  # update direction
        row, col = self.pos
        if self.env.grid[row, col] != self.GOAL_VALUE:
            self.env.grid[row, col] = self.VISITED_VALUE  # update visited cells

        return action

    def _action_index(self, action):
        """Return the index of `action` in self.actions (ValueError if absent)."""
        matches = np.flatnonzero((self.actions == action).all(axis=1))
        if matches.size == 0:
            raise ValueError(f"{action} is not one of the available actions")
        return int(matches[0])

    def _has_direction(self):
        """Return True once self.dir is one of the available actions."""
        return bool((self.actions == self.dir).all(axis=1).any())

    def _state_index(self, state):
        """Return the row of state_space equal to `state` (ValueError if absent)."""
        matches = np.flatnonzero(np.all(self.state_space == state, axis=1))
        if matches.size == 0:
            raise ValueError(f"State {state} is not in the state space")
        return int(matches[0])

    def get_current_state(self):
        """Return the current state as [distance class, direction index, time class].

        The direction index is the position of self.dir in self.actions, so
        this must be called after at least one move (ValueError otherwise).
        Also refreshes self.ddist.
        """
        ddist = np.linalg.norm(self.pos - self.env.datum)
        self.ddist = ddist
        dir_idx = self._action_index(self.dir)
        dist, time = discretize_state(
            grid_size=self.Lx,  # supposing square grid
            budget=self.env.budget,
            ddist=ddist,
            rem_time=self.env.rem_time,
        )
        return np.array([dist, dir_idx, time])

    def update_q_table(self, reward, action_idx, prev_state_index=None):
        """Apply one Q-learning update.

        Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a'))

        s' is the agent's current state. s is `prev_state_index` (the state
        before the move). When it is omitted, s falls back to s', which is
        the original two-argument behaviour.
        """
        alpha = 0.1  # Learning rate
        gamma = 0.9  # Discount factor

        next_state_index = self._state_index(self.get_current_state())
        state_index = (
            next_state_index if prev_state_index is None else prev_state_index
        )
        td_target = reward + gamma * np.max(self.q_table[next_state_index, :])
        self.q_table[state_index, action_idx] = (
            (1 - alpha) * self.q_table[state_index, action_idx] + alpha * td_target
        )

    def transition(self):
        """Take one step: move, collect the reward, spend one time unit, update Q.

        The Q update is skipped on the very first move. Before it, self.dir is
        [0, 0], which is not an action, so the starting configuration has no
        index in state_space.
        """
        prev_state_index = (
            self._state_index(self.get_current_state())
            if self._has_direction()
            else None
        )
        action = self.update_position()
        action_idx = self._action_index(action)
        reward = self.reward()
        # Spend the time unit before reading s', so its time class matches the
        # state the next transition starts from.
        self.env.rem_time -= 1
        if prev_state_index is not None:
            self.update_q_table(reward, action_idx, prev_state_index)


In [None]:
Lx, Ly = 20, 20

# 20x20 world with the goal fixed at (5, 5) and a budget of 100 steps.
world = World(Lx=Lx, Ly=Ly, rem_time=100, goal=np.array([5, 5]))
ag = Agent(world)
actions = {}  # currently unused

# Let the agent explore until the time budget is exhausted.
while world.rem_time > 0:
    ag.transition()

print(f"Q table: \n{ag.q_table}")
world.plot_world()