In [None]:
import numpy as np

class ActionMasking:
    """Validity masks over a combined (power allocation x movement) action space.

    Every action index encodes two choices: ``action % 7`` selects the
    movement (see ``get_movement_action``) and ``action // 7`` selects the
    power-allocation scheme (see ``get_power_allocation``). Each cluster size
    owns one contiguous block of action indices in ``power_actions``. Every
    block starts at a multiple of 7 and its length is a multiple of 7, so all
    7 movements are already available inside whichever block is valid.
    """

    def __init__(self, action_dim=126):
        self.action_dim = action_dim
        self.movement_actions = 7  # up, down, forward, backward, left, right, hover

        # Action indices valid for each cluster size (number of users).
        # All other methods derive their ranges from this table.
        self.power_actions = {
            1: list(range(63, 84)),    # Single user
            2: list(range(0, 63)),     # Two users
            3: list(range(84, 105)),   # Three users
            4: list(range(105, 126))   # Four users
        }

    def _actions_for(self, cluster_size):
        """Return the action-index list for ``cluster_size`` or raise ValueError."""
        try:
            return self.power_actions[cluster_size]
        except (KeyError, TypeError):
            raise ValueError(f"Invalid cluster size: {cluster_size}") from None

    def get_mask(self, cluster_size):
        """Generate an additive mask: 0 for valid actions, -inf for invalid ones.

        Only the action block owned by ``cluster_size`` is valid. Indices 0-6
        are deliberately NOT unmasked as standalone "movement" actions. They
        belong to the two-user block (``get_cluster_size_from_action`` maps
        them to 2), and movement is already encoded in every action via
        ``action % 7``.

        Raises:
            ValueError: if ``cluster_size`` has no entry in ``power_actions``.
                An all -inf mask would otherwise make argmax silently pick
                action 0.
        """
        mask = np.full(self.action_dim, -np.inf)  # start with every action invalid
        mask[self._actions_for(cluster_size)] = 0
        return mask

    def apply_mask(self, q_values, mask):
        """Add ``mask`` to ``q_values`` so invalid actions become -inf."""
        return q_values + mask

    def get_valid_actions(self, mask):
        """Return the indices where ``mask`` is 0 (i.e. the valid actions)."""
        return np.where(mask == 0)[0]

    def validate_action(self, action, cluster_size):
        """Return True if ``action`` belongs to the block for ``cluster_size``.

        Negative or out-of-range actions and unknown cluster sizes return
        False. Membership is checked explicitly rather than by indexing a
        mask, so negative actions cannot wrap around to the end of the array.
        """
        actions = self.power_actions.get(cluster_size)
        return actions is not None and action in actions

    def get_random_valid_action(self, cluster_size):
        """Return a uniformly random valid action for ``cluster_size``.

        Raises:
            ValueError: if ``cluster_size`` has no entry in ``power_actions``.
        """
        actions = self._actions_for(cluster_size)
        # Blocks are contiguous ranges, so one randint over [first, last]
        # samples uniformly from the block.
        return np.random.randint(actions[0], actions[-1] + 1)

    def get_movement_action(self, action_number):
        """Extract the movement index (0-6) from a combined action number."""
        return action_number % self.movement_actions

    def get_power_allocation(self, action_number):
        """Extract the global power-allocation index from an action number.

        The index is ``action_number // 7`` over the whole action space. It is
        not relative to the start of the cluster's block.
        """
        return action_number // self.movement_actions

    def get_cluster_size_from_action(self, action):
        """Return the cluster size whose action block contains ``action``.

        Raises:
            ValueError: if ``action`` lies outside every block.
        """
        for size, actions in self.power_actions.items():
            if actions[0] <= action <= actions[-1]:
                return size
        raise ValueError(f"Invalid action: {action}")