In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Lambda

class ActionMasking:
    """Build additive masks that restrict a discrete action space by cluster size.

    The flat action index appears to jointly encode a power-allocation scheme and
    a movement as ``action = power_scheme * 7 + movement`` (this is how
    ``get_action_components`` decodes it). Each cluster size owns a contiguous
    range of action indices in ``power_actions``. Masks hold ``0`` on allowed
    indices and ``-inf`` elsewhere, so ``q_values + mask`` suppresses disallowed
    actions.
    """

    # Multiplier applied to the relative power levels in
    # get_power_allocation_scheme. Kept as a class-level default so every
    # instance has it; __init__ may override it per instance.
    power_unit = 1.0

    def __init__(self, action_dim=126, power_unit=1.0):
        """Create the masker.

        Args:
            action_dim: Total number of discrete actions (width of each mask row).
            power_unit: Scale applied to the relative power levels returned by
                ``get_power_allocation_scheme``. Defaults to 1.0, which returns
                the raw relative levels.
        """
        self.action_dim = action_dim
        self.power_unit = power_unit
        self.movement_actions = list(range(7))  # up, down, forward, backward, left, right, hover

        # Power allocation actions per cluster size (indices into the flat action space).
        # NOTE(review): indices 0-6 (movement_actions) also lie inside the two-user
        # range. get_action_components decodes every index as (movement, power_scheme),
        # so "always allow movement actions" effectively always allows two-user power
        # scheme 0, whatever the cluster size. Confirm this is intended.
        self.power_actions = {
            1: list(range(63, 84)),    # Single user
            2: list(range(0, 63)),     # Two users
            3: list(range(84, 105)),   # Three users
            4: list(range(105, 126)),  # Four users
        }

        # Keras layer that adds a mask to Q-values (see create_masking_layer).
        self.masking_layer = self.create_masking_layer()

    def create_masking_layer(self):
        """Return a Keras ``Lambda`` layer computing ``q_values + mask``.

        The layer is called with a ``[q_values, mask]`` pair and adds them
        element-wise, so ``-inf`` mask entries drive the matching Q-values to
        ``-inf``.
        """
        def mask_actions(inputs):
            q_values, mask = inputs
            return q_values + mask
        return Lambda(mask_actions)

    def get_mask(self, cluster_size, batch_size=1):
        """Return a ``(batch_size, action_dim)`` additive mask for one cluster size.

        Movement actions and the power actions registered for ``cluster_size``
        are ``0``; every other entry is ``-inf``. For a cluster size with no
        entry in ``power_actions``, only the movement actions are enabled.
        All rows are identical.
        """
        mask = np.full((batch_size, self.action_dim), -np.inf)
        mask[:, self.movement_actions] = 0
        mask[:, self.power_actions.get(cluster_size, [])] = 0
        return mask

    def get_batch_masks(self, cluster_sizes, batch_size=None):
        """Return one mask row per entry of ``cluster_sizes``.

        Args:
            cluster_sizes: Iterable of cluster sizes, one per batch row.
            batch_size: Number of rows in the result. Defaults to
                ``len(cluster_sizes)``. Rows beyond the number of cluster sizes
                stay entirely ``-inf``. Passing more cluster sizes than
                ``batch_size`` raises ``IndexError``.

        Returns:
            ``np.ndarray`` of shape ``(batch_size, action_dim)``.
        """
        cluster_sizes = list(cluster_sizes)
        if batch_size is None:
            batch_size = len(cluster_sizes)
        masks = np.full((batch_size, self.action_dim), -np.inf)
        for row, cluster_size in enumerate(cluster_sizes):
            masks[row] = self.get_mask(cluster_size)[0]
        return masks

    def apply_mask(self, q_values, mask):
        """Add ``mask`` to ``q_values`` through the Keras masking layer."""
        return self.masking_layer([q_values, mask])

    def get_valid_actions(self, cluster_size):
        """Return the indices of all actions allowed for ``cluster_size``."""
        mask = self.get_mask(cluster_size)[0]
        return np.where(mask == 0)[0]

    def get_random_valid_action(self, cluster_size):
        """Sample one allowed action uniformly using NumPy's global RNG."""
        valid_actions = self.get_valid_actions(cluster_size)
        return np.random.choice(valid_actions)

    def get_action_components(self, action):
        """Split a flat action into ``(movement, power_scheme)``.

        Decoding is ``action = power_scheme * 7 + movement``.
        """
        power_alloc, movement = divmod(action, 7)
        return movement, power_alloc

    def validate_action(self, action, cluster_size):
        """Return True if ``action`` is allowed for ``cluster_size``.

        Indices outside ``[0, action_dim)`` are reported as invalid rather than
        wrapping around (negative indices) or raising ``IndexError``.
        """
        if not 0 <= action < self.action_dim:
            return False
        mask = self.get_mask(cluster_size)[0]
        return bool(mask[action] == 0)

    def get_power_allocation_scheme(self, action_number, cluster_size):
        """Map an action to power levels, scaled by ``power_unit``.

        Args:
            action_number: Flat action index. Only its power-scheme component
                (``action_number // 7``) is used.
            cluster_size: Number of users in the cluster.

        Returns:
            For 1 user, a single scaled power value. For 2 users, a dict with
            ``'stronger'`` and ``'weaker'`` scaled power values. ``None`` for any
            other cluster size (3- and 4-user schemes are not implemented yet).

        Raises:
            IndexError: For 2 users, when the power scheme is 9 or higher
                (i.e. the action lies outside the two-user range).
        """
        _, power_scheme = self.get_action_components(action_number)
        unit = self.power_unit

        if cluster_size == 1:
            power_values = [7, 3, 1]
            return power_values[power_scheme % 3] * unit
        if cluster_size == 2:
            stronger_user_values = [2, 4, 7]
            weaker_user_values = [1, 0.5, 0.25]
            # Treat the scheme as a base-3 pair:
            # - low digit -> stronger user's level
            # - high digit -> weaker user's level
            scheme_idx = power_scheme % 3
            base_idx = power_scheme // 3
            return {
                'stronger': stronger_user_values[scheme_idx] * unit,
                'weaker': weaker_user_values[base_idx] * unit,
            }
        # TODO: add schemes for 3 and 4 users.
        return None

    def get_cluster_power_constraints(self, cluster_size):
        """Return the min/max power constraints for a cluster size.

        Values are returned unscaled (not multiplied by ``power_unit``).
        Returns ``None`` for cluster sizes other than 1 and 2 (not implemented yet).
        """
        if cluster_size == 1:
            return {'min': 0.5, 'max': 7}
        if cluster_size == 2:
            return {
                'stronger': {'min': 2, 'max': 7},
                'weaker': {'min': 0.25, 'max': 1},
            }
        # TODO: add constraints for 3 and 4 users.
        return None