In [None]:
# Google Colab only: mount Google Drive at /content/gdrive.
# force_remount=True asks Colab to mount again even if a mount already exists.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

_**Inspired by [Ozturk et al., "Energy Optimization in Ultra-Dense Radio Access Networks via Traffic-Aware Cell Switching," in IEEE Transactions on Green Communications and Networking, vol. 5, no. 2, pp. 832-845, June 2021, doi: 10.1109/TGCN.2021.3056235.](https://ieeexplore.ieee.org/abstract/document/9344664)**_

In [1]:
%matplotlib inline

import numpy as np

import gymnasium as gym
from gymnasium import spaces

# **Network Scenario**

- #### Inspired by Ozturk et al. (2021) in the BS activation problem for energy saving

- #### Consider a simple network topology of 25 macro BSs in a 5x5 grid

  #### $$ B = \{B_1, B_2, B_3, B_4, ..., B_{25}\} $$

  <center>
    <img src="./network_scenario/imgs/1.png" alt="1.png" title="1.png" width="320"/>
  </center><br>

- #### Assume that all BSs in the grid have coverage overlaps

- #### Assume a centralized controller can monitor and manage the entire network topology

  <center>
    <img src="./network_scenario/imgs/2.png" alt="2.png" title="2.png" width="480"/>
  </center>

  #### **$ \rightarrow $ Objective:** Formulate an MDP for turning BSs ON/OFF. When a BS is turned OFF, its load is shared among its neighboring BSs. The controller acts on the BSs sequentially (one BS per step).

<div class="alert alert-block alert-info">
  <b>Note:</b> MDP stands for Markov decision process
</div>


### __Declaration and Initialization__

In [None]:
class NetworkGridEnv(gym.Env):
    """Gymnasium environment for switching base stations (BSs) ON/OFF in a square grid.

    State
        Two ``size x size`` matrices: ``traffic_demand`` (the demand of each BS)
        and ``traffic_state`` (the load each BS currently serves; 0 means the BS
        carries no load). Both start at 0.1 everywhere.

    Actions
        ``action = bs_index * 17 + bs_action`` where ``bs_index`` is the
        row-major index of the BS and ``bs_action`` is:

        * ``0``      - turn the BS ON (its load is left unchanged),
        * ``1..15``  - turn the BS OFF and split its load equally among the
          selected neighbours. The 4 bits select, from most to least
          significant: left, top, right, bottom. Directions that point outside
          the grid or to an inactive BS are ignored. If no selected neighbour
          can take the load, the load is dropped.
        * ``16``     - do nothing.

    Rejected actions
        Acting on the same BS as in the previous accepted step, or turning OFF
        a BS that is already OFF, is rejected: the state, the clock and
        ``prev_bs`` are left untouched and the regular 5-tuple is returned.
    """

    # render() prints the performance metrics to stdout when called explicitly.
    metadata = {"render_modes": ["human"]}

    # Number of actions available for every BS (see class docstring).
    ACTIONS_PER_BS = 17

    # (bit mask in a turn-OFF action, row offset, column offset)
    _DIRECTIONS = (
        (0b1000, 0, -1),  # left
        (0b0100, -1, 0),  # top
        (0b0010, 0, 1),   # right
        (0b0001, 1, 0),   # bottom
    )

    # Power model parameters in watts (P_0, eta, P_T, P_S in the notebook text).
    P_IDLE = 130.0
    ETA = 4.7
    P_TX = 20.0
    P_SLEEP = 75.0

    # Weight of the traffic loss in the reward.
    LOSS_WEIGHT = 100.0

    # Initial demand and load of every BS.
    INITIAL_LOAD = 0.1

    def __init__(self, render_mode=None, size=5):
        """Create the environment.

        Args:
            render_mode: ``None`` or one of ``metadata["render_modes"]``.
            size: side length of the square BS grid (must be >= 1).

        Raises:
            ValueError: if ``size`` is smaller than 1.
            AssertionError: if ``render_mode`` is not supported.
        """
        super().__init__()
        if size < 1:
            raise ValueError("size must be a positive integer.")

        self.size = size  # The size of the square grid
        self.num_bs = self.size ** 2
        self.num_actions = self.num_bs * self.ACTIONS_PER_BS
        self.max_steps = 24  # 24 steps for 24 hours in a day

        # Observations are dictionaries with the two traffic matrices.
        grid_shape = (self.size, self.size)
        self.observation_space = spaces.Dict(
            {
                "traffic_demand": spaces.Box(low=0, high=1, shape=grid_shape, dtype=np.float32),
                "traffic_state": spaces.Box(low=0, high=1, shape=grid_shape, dtype=np.float32),
            }
        )
        self.action_space = spaces.Discrete(self.num_actions)

        assert render_mode is None or render_mode in self.metadata["render_modes"], (
            f"Unsupported render_mode: {render_mode!r}"
        )
        self.render_mode = render_mode

        self._init_state()

    def _init_state(self):
        """Set time, activation flags and traffic matrices to their initial values."""
        grid_shape = (self.size, self.size)
        self.current_time = 0
        self.active_bs = np.ones(grid_shape, dtype=bool)
        self.prev_bs = (-1, -1)  # No BS has been acted on yet
        self._demand_matrix = np.full(grid_shape, self.INITIAL_LOAD)
        self._state_matrix = np.full(grid_shape, self.INITIAL_LOAD)

    def _power(self, load):
        """Return the per-BS power for a load matrix (loads are clipped to [0, 1]).

        A BS with positive load draws ``P_IDLE + ETA * P_TX * load``; a BS with
        zero load draws ``P_SLEEP``.
        """
        load = np.clip(np.asarray(load, dtype=np.float64), 0.0, 1.0)
        return np.where(load > 0, self.P_IDLE + self.ETA * self.P_TX * load, self.P_SLEEP)

    def _get_reward(self):
        """Return ``-100 * loss + sum(max_power - power_i)`` for the current state.

        ``loss`` is the demand that is not served once every load is capped at 1.
        """
        served = np.clip(self._state_matrix, 0.0, 1.0)
        loss = float(np.sum(self._demand_matrix) - np.sum(served))
        power = self._power(served)
        return float(-self.LOSS_WEIGHT * loss + np.sum(np.max(power) - power))

    def _get_obs(self):
        """Return float32 copies of both matrices, matching ``observation_space``."""
        return {
            "traffic_demand": self._demand_matrix.astype(np.float32),
            "traffic_state": self._state_matrix.astype(np.float32),
        }

    def _get_info(self):
        """Return the performance metrics of the current state.

        * ``traffic_coverage`` = sum(L) / sum(D) * 100
        * ``energy_saving`` = (P_all_ON - P) / P_all_ON * 100, where P_all_ON is
          the power drawn if every BS were ON and served its own demand.
        """
        served = np.clip(self._state_matrix, 0.0, 1.0)
        demand = np.clip(self._demand_matrix, 0.0, 1.0)

        total_demand = float(np.sum(self._demand_matrix))
        # With no demand at all there is nothing left uncovered.
        coverage = float(np.sum(served)) / total_demand * 100 if total_demand > 0 else 100.0

        total_power = float(np.sum(self._power(served)))
        # In the baseline every BS is ON, so none of them is in sleep mode.
        all_on_power = float(np.sum(self.P_IDLE + self.ETA * self.P_TX * demand))

        return {
            "traffic_coverage": coverage,
            "energy_saving": (all_on_power - total_power) / all_on_power * 100,
        }

    def reset(self, seed=None, options=None):
        """Reset the environment and return ``(observation, info)``."""
        super().reset(seed=seed)  # Seeds self.np_random as required by the Gymnasium API
        self._init_state()
        return self._get_obs(), self._get_info()

    def _rejected_step(self):
        """Return the step tuple for a rejected action without changing any state."""
        reward = self._get_reward()
        terminated = self.current_time >= self.max_steps
        return self._get_obs(), reward, terminated, False, self._get_info()

    def _turn_off_and_shift(self, bs_row, bs_col, bs_action):
        """Turn the BS OFF and split its load equally among the selected neighbours."""
        # Remember the load BEFORE zeroing it, otherwise nothing is left to shift.
        load = float(self._state_matrix[bs_row, bs_col])
        self.active_bs[bs_row, bs_col] = False
        self._state_matrix[bs_row, bs_col] = 0.0

        # Keep only the selected directions that lead to an active BS inside the grid.
        targets = []
        for bit, d_row, d_col in self._DIRECTIONS:
            row, col = bs_row + d_row, bs_col + d_col
            if (
                bs_action & bit
                and 0 <= row < self.size
                and 0 <= col < self.size
                and self.active_bs[row, col]
            ):
                targets.append((row, col))

        for row, col in targets:
            self._state_matrix[row, col] += load / len(targets)

        # A BS cannot serve more than load 1; the excess is lost traffic.
        np.clip(self._state_matrix, 0.0, 1.0, out=self._state_matrix)

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, terminated, truncated, info)``.

        Raises:
            ValueError: if ``action`` is outside ``[0, num_actions - 1]``.
        """
        action = int(action)
        if action < 0 or action >= self.num_actions:
            raise ValueError(
                f"Invalid action. Action must be in the range [0, {self.num_actions - 1}]."
            )

        # Decode the BS location and the per-BS action.
        bs_index, bs_action = divmod(action, self.ACTIONS_PER_BS)
        bs_row, bs_col = divmod(bs_index, self.size)

        # Action masking: never act on the same BS twice in a row.
        if (bs_row, bs_col) == self.prev_bs:
            return self._rejected_step()

        if bs_action == 0:
            # Action 0: BS turning ON
            self.active_bs[bs_row, bs_col] = True
        elif 1 <= bs_action <= 15:
            # Action 1-15: BS turning OFF and shifting loads
            if not self.active_bs[bs_row, bs_col]:
                # Do not turn off an already deactivated BS
                return self._rejected_step()
            self._turn_off_and_shift(bs_row, bs_col, bs_action)
        # Action 16: do nothing

        self.current_time += 1

        # An episode is done after max_steps accepted steps.
        terminated = self.current_time >= self.max_steps
        truncated = False
        reward = self._get_reward()
        observation = self._get_obs()
        info = self._get_info()

        self.prev_bs = (bs_row, bs_col)

        return observation, reward, terminated, truncated, info

    def render(self):
        """Print the current performance metrics."""
        info = self._get_info()
        print("Performance Metrics:")
        print("- Traffic Coverage: {:.2f}%".format(info["traffic_coverage"]))
        print("- Energy Saving: {:.2f}%".format(info["energy_saving"]))

    def close(self):
        """Nothing to clean up."""
        pass

### __Constructing Observations From Environment States__

In [None]:
def _get_obs(self):
    return {"traffic_demand": self._demand_matrix, "traffic_state": self._state_matrix}

def _get_info(self):
    energy_csm = []
    all_on_energy_csm = []
    for i in range(self.size):
        for j in range(self.size):
            if (self._state_matrix[i, j] > 1):
                energy_csm.append(130 + 4.7 * 20)
                all_on_energy_csm.append(130 + 4.7 * 20)
            elif (0 < self._state_matrix[i, j] <= 1):
                energy_csm.append(130 + 4.7 * 20 * self._state_matrix[i, j])
                all_on_energy_csm.append(130 + 4.7 * 20 * self._state_matrix[i, j])
            else:
                energy_csm.append(75)
    return {
        "traffic_coverage": np.sum(self._demand_matrix) / np.sum(self._state_matrix) * 100,
        "energy_saving": (np.sum(all_on_energy_csm) - np.sum(energy_csm)) / np.sum(all_on_energy_csm) * 100
    }

## **State space**

#### State space will contain 2 traffic matrices: **$ s^t = (D^t, L^t) $**

- #### **Traffic demand (D):** The original traffic demand from [Milan dataset](https://ieee-dataport.org/documents/milan-dataset) at time t <br>
  - #### $ 0 \leq {d_j}^t \leq 1 $
  - #### Initialize all $ {d_j}^0 = {l_j}^0 = 0.1 $

- #### **Traffic state (L):** The traffic load matrix considering BS (de)activation evolution from time to time (from $ t = 0 $ to $ t $) <br>
  - #### $ {l_j}^t = 0 $ indicates BS j is OFF
  - #### $ 0 \leq {l_j}^t \leq 1 $
  - #### $ {l_j}^t $ is used as the load factor $ {\lambda_j}^t $ when calculating power

<center>
  <img src="./network_scenario/imgs/3.png" alt="3.png" title="3.png" width="480"/>
</center>

In [None]:
# Excerpt from __init__: both traffic matrices are size x size float32 boxes in [0, 1].
self.observation_space = spaces.Dict({
    matrix_name: spaces.Box(low=0, high=1, shape=(self.size, self.size), dtype=np.float32)
    for matrix_name in ("traffic_demand", "traffic_state")
})

### __Reset__

In [None]:
def reset(self, seed=None, options=None):
    """Put the environment back into its initial state.

    Time goes back to 0, every BS is marked active, no BS counts as the
    previously used one, and every demand and load entry is set to 0.1.
    ``seed`` and ``options`` are accepted but not used here.

    Returns:
        The ``(observation, info)`` pair built by ``_get_obs`` and ``_get_info``.
    """
    grid_shape = (self.size, self.size)

    self._state_matrix = np.full(grid_shape, 0.1)
    self._demand_matrix = np.full(grid_shape, 0.1)
    self.active_bs = np.ones(grid_shape, dtype=bool)
    self.prev_bs = (-1, -1)  # Sentinel: no BS has been acted on yet
    self.current_time = 0

    return self._get_obs(), self._get_info()

## **Action space**

- #### There will be **17 actions for each BS**

- #### Load can be split **equally** in <u>up to</u> 4 directions

- #### Total action space of **25 * 17 = 425 possible actions**

  #### $$ a^t \in \{0, 1, 2,..., 424\} $$

<center>
  <img src="./network_scenario/imgs/4.png" alt="4.png" title="4.png" width="600"/>
</center>

In [None]:
# Excerpt from __init__: one discrete action per (BS, per-BS action) pair;
# num_actions is num_bs * 17 (425 for the 5x5 grid).
self.action_space = spaces.Discrete(self.num_actions)

### __Step__

In [None]:
def step(self, action):
    """Apply ``action`` and return ``(observation, reward, terminated, truncated, info)``.

    ``action = bs_index * 17 + bs_action`` where ``bs_index`` is the row-major
    index of the BS and ``bs_action`` is:

    * ``0``     - turn the BS ON (its load is left unchanged),
    * ``1..15`` - turn the BS OFF and split its load equally among the selected
      neighbours. The 4 bits select, from most to least significant: left,
      top, right, bottom. Directions that leave the grid or point to an
      inactive BS are ignored; if none remains, the load is dropped.
    * ``16``    - do nothing.

    Acting on the same BS as in the previous accepted step, or turning OFF a BS
    that is already OFF, is rejected: nothing changes (the clock and
    ``prev_bs`` included) and the regular 5-tuple is still returned.

    Raises:
        ValueError: if ``action`` is outside the valid range for this grid size.
    """
    actions_per_bs = 17
    num_actions = self.size ** 2 * actions_per_bs
    # (bit mask in a turn-OFF action, row offset, column offset)
    directions = (
        (0b1000, 0, -1),  # left
        (0b0100, -1, 0),  # top
        (0b0010, 0, 1),   # right
        (0b0001, 1, 0),   # bottom
    )

    action = int(action)
    if action < 0 or action >= num_actions:
        raise ValueError(f"Invalid action. Action must be in the range [0, {num_actions - 1}].")

    # Decode the BS location and the per-BS action.
    bs_index, bs_action = divmod(action, actions_per_bs)
    bs_row, bs_col = divmod(bs_index, self.size)

    # Action masking: same BS as before, or turning OFF a BS that is already OFF.
    same_bs_as_before = (bs_row, bs_col) == self.prev_bs
    already_off = 1 <= bs_action <= 15 and not self.active_bs[bs_row, bs_col]
    if same_bs_as_before or already_off:
        reward = self._get_reward()
        observation = self._get_obs()
        info = self._get_info()
        return observation, reward, self.current_time >= self.max_steps, False, info

    if bs_action == 0:
        # Action 0: BS turning ON
        self.active_bs[bs_row, bs_col] = 1
    elif 1 <= bs_action <= 15:
        # Action 1-15: BS turning OFF and shifting loads.
        # Remember the load BEFORE zeroing it, otherwise nothing is left to shift.
        load_to_shift = float(self._state_matrix[bs_row, bs_col])
        self.active_bs[bs_row, bs_col] = 0
        self._state_matrix[bs_row, bs_col] = 0

        # Keep only the selected directions that lead to an active BS inside the grid.
        targets = []
        for bit, d_row, d_col in directions:
            row, col = bs_row + d_row, bs_col + d_col
            if (
                bs_action & bit
                and 0 <= row < self.size
                and 0 <= col < self.size
                and self.active_bs[row, col]
            ):
                targets.append((row, col))

        for row, col in targets:
            self._state_matrix[row, col] += load_to_shift / len(targets)

        # A BS cannot serve more than load 1; cap here so the reward sees the loss.
        np.minimum(self._state_matrix, 1.0, out=self._state_matrix)
    # Action 16: do nothing

    self.current_time += 1

    # An episode is done after max_steps accepted steps.
    terminated = self.current_time >= self.max_steps
    truncated = False
    reward = self._get_reward()
    observation = self._get_obs()
    info = self._get_info()

    self.prev_bs = bs_row, bs_col

    return observation, reward, terminated, truncated, info

### __Action space masking__

- #### **Action mask:** Valid actions to perform at every step t
- #### The purpose of action masking is to ensure:
    - #### Do not act on the same BS as in the previous step (avoids toggling a BS ON/OFF continuously)
    - #### Allow only feasible actions in each state:
        - #### Do not share load to OFF BS
        - #### Do not turn off an already deactivated BS
        - #### BSs at edge may have less than 4 neighbors

<center>
  <img src="./network_scenario/imgs/5.png" alt="5.png" title="5.png" width="800"/>
</center>

In [None]:
'''
Implement action masking to ensure valid actions and update traffic state based on load shifts
'''
if ((bs_row, bs_col) == self.prev_bs):
# Check if the current BS is the same as the previous BS
    observation = self._get_obs()
    info = self._get_info()
    return observation, info
        
if (bs_action == 0):
    # Action 0: BS turning ON
    self.active_bs[bs_row, bs_col] = 1
elif (bs_action >= 1 and bs_action <= 15):
    # Action 1-15: Bs turning OFF and shifting loads
    if (self.active_bs[bs_row, bs_col] == 0):
        # Do not turn off an already deactivated BS
        observation = self._get_obs()
        info = self._get_info()
        return observation, info
    else:
        self.active_bs[bs_row, bs_col] = 0
        self._state_matrix[bs_row, bs_col] = 0
        bs_left, bs_top = (bs_action & 0b1000) >> 3, (bs_action & 0b0100) >> 2
        bs_right, bs_bottom = (bs_action & 0b0010) >> 1, (bs_action & 0b0001)
        corner_idx = self.size - 1
                
        # BSs at edge may have less than 4 neighbors, don't share load to OFF BS
        if ((bs_row, bs_col) == (0, 0)):
            shift_load = bs_right * self.active_bs[0, 1] + bs_bottom * self.active_bs[1, 0]
            if (shift_load != 0):
                if (self.active_bs[0, 1] == 0): self._state_matrix[1, 0] += self._state_matrix[bs_row, bs_col]
                elif (self.active_bs[1, 0] == 0): self._state_matrix[0, 1] += self._state_matrix[bs_row, bs_col]
                else:
                    self._state_matrix[0, 1] += (self._state_matrix[bs_row, bs_col] / shift_load)
                    self._state_matrix[1, 0] += (self._state_matrix[bs_row, bs_col] / shift_load)
                    
        elif ((bs_row, bs_col) == (0, corner_idx)):
            shift_load = bs_left * self.active_bs[0, corner_idx - 1] + bs_bottom * self.active_bs[1, corner_idx]
            if (shift_load != 0):
                if (self.active_bs[0, corner_idx - 1] == 0): self._state_matrix[1, corner_idx] += self._state_matrix[bs_row, bs_col]
                elif (self.active_bs[1, corner_idx] == 0): self._state_matrix[0, corner_idx - 1] += self._state_matrix[bs_row, bs_col]
                else:
                    self._state_matrix[0, corner_idx - 1] += (self._state_matrix[bs_row, bs_col] / shift_load)
                    self._state_matrix[1, corner_idx] += (self._state_matrix[bs_row, bs_col] / shift_load)
                    
        elif ((bs_row, bs_col) == (corner_idx, 0)):
            shift_load = bs_top * self.active_bs[corner_idx - 1, 0] + bs_right * self.active_bs[corner_idx, 1]
            if (shift_load != 0):
                if (self.active_bs[corner_idx - 1, 0] == 0): self._state_matrix[corner_idx, 1] += self._state_matrix[bs_row, bs_col]
                elif (self.active_bs[corner_idx, 1] == 0): self._state_matrix[corner_idx - 1, 0] += self._state_matrix[bs_row, bs_col]
                else:
                    self._state_matrix[corner_idx - 1, 0] += (self._state_matrix[bs_row, bs_col] / shift_load)
                    self._state_matrix[corner_idx, 1] += (self._state_matrix[bs_row, bs_col] / shift_load)
                    
        elif ((bs_row, bs_col) == (corner_idx, corner_idx)):
            shift_load = bs_top * self.active_bs[corner_idx - 1, corner_idx] + bs_left * self.active_bs[corner_idx, corner_idx - 1]
            if (shift_load != 0):
                if (self.active_bs[corner_idx - 1, corner_idx] == 0): self._state_matrix[corner_idx, corner_idx - 1] += self._state_matrix[bs_row, bs_col]
                elif (self.active_bs[corner_idx, corner_idx - 1] == 0): self._state_matrix[corner_idx - 1, corner_idx] += self._state_matrix[bs_row, bs_col]
                else:
                    self._state_matrix[corner_idx - 1, corner_idx] += (self._state_matrix[bs_row, bs_col] / shift_load)
                    self._state_matrix[corner_idx, corner_idx - 1] += (self._state_matrix[bs_row, bs_col] / shift_load)
                    
        elif (bs_row == 0):
            shift_load = bs_left * self.active_bs[0, bs_col - 1] + bs_right * self.active_bs[0, bs_col + 1] + bs_bottom * self.active_bs[1, bs_col]
            if (shift_load != 0):
                for dr, dc in [(0, -1), (0, 1), (1, 0)]:
                    if (self.active_bs[dr, bs_col + dc] != 0):
                        self._state_matrix[dr, bs_col + dc] += (self._state_matrix[bs_row, bs_col] / shift_load)

        elif (bs_row == corner_idx):
            shift_load = bs_left * self.active_bs[corner_idx, bs_col - 1] + bs_right * self.active_bs[corner_idx, bs_col + 1] + bs_top * self.active_bs[corner_idx - 1, bs_col]
            if (shift_load != 0):          
                for dr, dc in [(0, -1), (0, 1), (-1, 0)]:
                    if (self.active_bs[corner_idx + dr, bs_col + dc] != 0):
                        self._state_matrix[corner_idx + dr, bs_col + dc] += (self._state_matrix[bs_row, bs_col] / shift_load)

        elif (bs_col == 0):
            shift_load = bs_right * self.active_bs[bs_row, 1] + bs_top * self.active_bs[bs_row - 1, 0] + bs_bottom * self.active_bs[bs_row + 1, 0]
            if (shift_load != 0):       
                for dr, dc in [(-1, 0), (1, 0), (0, 1)]:
                    if (self.active_bs[bs_row + dr, dc] != 0):
                        self._state_matrix[bs_row + dr, dc] += (self._state_matrix[bs_row, bs_col] / shift_load)

        elif (bs_col == corner_idx):
            shift_load = bs_left * self.active_bs[bs_row, corner_idx - 1] + bs_top * self.active_bs[bs_row - 1, corner_idx] + bs_bottom * self.active_bs[bs_row + 1, corner_idx]
            if (shift_load != 0): 
                for dr, dc in [(-1, 0), (1, 0), (0, -1)]:
                    if (self.active_bs[bs_row + dr, corner_idx + dc] != 0):
                        self._state_matrix[bs_row + dr, corner_idx + dc] += (self._state_matrix[bs_row, bs_col] / shift_load)

        else:
            shift_load = bs_left * self.active_bs[bs_row, bs_col - 1] + bs_top * self.active_bs[bs_row - 1, bs_col] + bs_right * self.active_bs[bs_row, bs_col + 1] + bs_bottom * self.active_bs[bs_row + 1, bs_col]
            if (shift_load != 0): 
                for dr, dc in [(-1, 0), (0, -1), (1, 0), (0, 1)]:
                    if (self.active_bs[bs_row + dr, bs_col + dc] != 0):
                        self._state_matrix[bs_row + dr, bs_col + dc] += (self._state_matrix[bs_row, bs_col] / shift_load)
       
else: pass  # Action 16: Do nothing

## **Reward function**

- #### **Energy Consumption**

    - #### Derived from the _[Energy Aware Radio and neTwork tecHnologies (EARTH) power consumption model](https://doi.org/10.1109/MWC.2011.6056691)_

    <center>
      <img src="./network_scenario/imgs/6.png" alt="6.png" title="6.png" width="400"/>
    </center><br>

    - #### With $ P_0 = 130W, \eta = 4.7, P_T = 20W, P_S = 75 W $ for Macro BS

- #### **Traffic Loss**

  #### The maximum load of each BS is 1 $ \rightarrow $ Migrating load to a busy BS sacrifices traffic, because each BS's load is clipped at 1

  #### $$ loss = sum(D^t) - sum(L^t) $$

- #### **Reward Function**

  <center>
    <img src="./network_scenario/imgs/7.png" alt="7.png" title="7.png" width="600"/>
  </center>

In [None]:
def _get_reward(self):
    # Calculate traffic loss
    loss = np.sum(self._demand_matrix) - np.sum(self._state_matrix)
    
    # Calculate energy consumption
    energy_csm = []
    for i in range(self.size):
        for j in range(self.size):
            if (self._state_matrix[i, j] > 1):
                self._state_matrix[i, j] = 1
                energy_csm.append(130 + 4.7 * 20)
            elif (0 < self._state_matrix[i, j] <= 1):
                energy_csm.append(130 + 4.7 * 20 * self._state_matrix[i, j])
            else:
                energy_csm.append(75)
        
    max_energy_csm = np.max(energy_csm)
    reward = - 100 * loss
    for i in range(self.num_bs):
        reward += (max_energy_csm - energy_csm[i])
        
    return reward

## **Environment configuration**

- #### State transitions:

<center>
  <img src="./network_scenario/imgs/8.png" alt="8.png" title="8.png" width="600"/>
</center>

- #### Each episode contains: 24 (time intervals per day) * 1 (days per week) = **24 steps**

- #### Performance metrics (to be collected from `_get_info`):

    - #### Traffic coverage (%): `traffic_coverage = sum(L^t)/sum(D^t) * 100`

    - #### Energy saving (%): `energy_saving = (total_P_all_ON - total_P) / total_P_all_ON * 100`