# In [None]:
from tensorflow.keras.layers import Lambda
import numpy as np
from ..config.parameters import *

class ActionMasking:
    """Build validity masks over the discrete action space and a Keras layer
    that applies them to Q-values.

    Indices in ``movement_actions`` (0-6) are always valid. Which
    power-allocation indices are valid depends on the cluster size (number of
    users); see ``power_actions``.
    """

    def __init__(self, action_dim=126, invalid_action_value=-1e9):
        """
        Args:
            action_dim: Length of the action vector / mask. The index tables
                below are hard-coded up to index 125, so a value smaller than
                126 makes ``get_mask`` raise ``IndexError``.
            invalid_action_value: Value written into masked (invalid)
                positions by ``masking_layer``. It must be lower than any
                reachable Q-value so argmax never selects an invalid action.
                NOTE: float16 cannot represent -1e9 (it overflows to -inf and
                yields NaNs); pass a smaller magnitude under mixed precision.
        """
        self.action_dim = action_dim
        self.invalid_action_value = invalid_action_value
        self.movement_actions = list(range(7))  # up, down, forward, backward, left, right, hover
        # NOTE(review): indices 0-6 are both movement actions and part of the
        # two-user power range (0-62); confirm this overlap is intended.
        self.power_actions = {
            1: list(range(63, 84)),    # Single user
            2: list(range(0, 63)),     # Two users
            3: list(range(84, 105)),   # Three users
            4: list(range(105, 126))   # Four users
        }

        # Create masking layer
        self.masking_layer = self.create_masking_layer()

    @staticmethod
    def _apply_mask(q_values, mask, invalid_value):
        """Keep Q-values where ``mask`` is 1; use ``invalid_value`` where it is 0.

        Plain ``q_values * mask`` is not a real mask: it sets invalid actions
        to 0, which argmax still prefers whenever every valid Q-value is
        negative. Only ``*``, ``+`` and ``-`` are used, so this works on NumPy
        arrays and TensorFlow tensors alike.
        """
        # Valid entries: q * 1 + 0 * invalid == q (exact for finite values).
        # Invalid entries: q * 0 + 1 * invalid == invalid.
        return q_values * mask + (1.0 - mask) * invalid_value

    def create_masking_layer(self):
        """Create a Keras ``Lambda`` layer for action masking.

        The layer is called on ``[q_values, mask]``. It returns the Q-values
        with masked (mask == 0) entries replaced by ``self.invalid_action_value``.
        """
        invalid_value = self.invalid_action_value
        apply_mask = self._apply_mask

        def mask_actions(inputs):
            q_values, mask = inputs
            return apply_mask(q_values, mask, invalid_value)

        return Lambda(mask_actions)

    def get_mask(self, cluster_size):
        """Return a binary mask of the actions valid for ``cluster_size``.

        Args:
            cluster_size: Number of users in the cluster. Keys 1-4 of
                ``power_actions`` are recognised.

        Returns:
            ``np.ndarray`` of length ``action_dim``, with 1.0 at valid indices
            and 0.0 elsewhere. Movement actions are always valid. For an
            unrecognised cluster size, only the movement actions are valid.
        """
        mask = np.zeros(self.action_dim)  # Initialize all actions as invalid

        # Movement actions are valid regardless of cluster size
        mask[self.movement_actions] = 1.0

        # Power allocation actions for this cluster size (none if unknown)
        mask[self.power_actions.get(cluster_size, [])] = 1.0

        return mask