# In [None]:
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import random
from collections import deque
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from .action_masking import ActionMasking
from ..config.parameters import *

class DQN(object):
    """Deep Q-network agent whose Q-values are masked per acting UAV.

    The online network (``self.model``) and the target network
    (``self.target_model``) share the architecture built by
    :meth:`create_model`. The target copy is refreshed from the online
    weights every ``update_freq`` training steps. Action masks come from
    ``ActionMasking.get_mask`` and depend only on how many users are
    currently associated with the acting UAV.
    """

    def __init__(self):
        self.update_freq = 2000  # training steps between target-network syncs
        self.replay_size = 20000  # replay capacity; training starts once it is full
        self.step = 0  # number of train() calls that actually trained
        self.replay_queue = deque(maxlen=self.replay_size)

        # State and action dimensions. The state layout is presumably
        # 3 values per UAV plus 1 per user -- confirm against the environment.
        self.state_dim = NumberOfUAVs * 3 + NumberOfUsers
        self.action_dim = 126  # must equal the length of masks from ActionMasking

        self.action_masker = ActionMasking()

        self.model = self.create_model()
        self.target_model = self.create_model()

    def create_model(self):
        """Build and compile a Q-network that takes ``[state, mask]`` inputs.

        Returns:
            A compiled Keras ``Model`` (Adam, lr=0.001, MSE loss). Its output
            is the raw Q-vector passed through
            ``ActionMasking.masking_layer`` together with the mask input.
        """
        state_input = Input(shape=(self.state_dim,))
        mask_input = Input(shape=(self.action_dim,))

        # Hidden layers
        x = Dense(90, activation='relu')(state_input)
        x = Dense(90, activation='relu')(x)

        # Q-value output before masking
        q_values = Dense(self.action_dim)(x)

        # Per the original design the mask is applied by element-wise
        # multiplication, so masked-out actions read as 0 rather than -inf --
        # NOTE(review): confirm in ActionMasking.masking_layer.
        masked_q_values = self.action_masker.masking_layer([q_values, mask_input])

        model = Model(inputs=[state_input, mask_input], outputs=masked_q_values)
        model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
        return model

    def _action_mask(self, acting_UAV, User_asso_list):
        """Return the action mask for ``acting_UAV``.

        The mask is sized by the number of users that row 0 of
        ``User_asso_list`` associates with that UAV.
        """
        cluster_size = np.count_nonzero(User_asso_list.iloc[0, :] == acting_UAV)
        return self.action_masker.get_mask(cluster_size)

    @staticmethod
    def _best_valid_action(q_values, action_mask):
        """Return the index of the highest Q-value among actions with mask > 0.

        A plain argmax over a multiplicatively-masked vector can pick an
        invalid action, because those read as 0 and beat all-negative valid
        Q-values. Falls back to an unrestricted argmax when no action is valid.
        """
        valid = np.flatnonzero(np.asarray(action_mask) > 0)
        if valid.size == 0:
            return np.argmax(q_values)
        return valid[np.argmax(np.asarray(q_values)[valid])]

    @staticmethod
    def _masked_max(q_values, masks):
        """Return the row-wise max of ``q_values`` over entries whose mask is > 0.

        Rows without any valid entry fall back to the unmasked row max.
        """
        q_values = np.asarray(q_values, dtype=float)
        valid = np.asarray(masks) > 0
        best_valid = np.where(valid, q_values, -np.inf).max(axis=1)
        return np.where(valid.any(axis=1), best_valid, q_values.max(axis=1))

    @staticmethod
    def _assign_users_to_uavs(uav_xy, cluster_centers, user_cluster):
        """Map each user to a UAV via a min-total-distance UAV/cluster pairing.

        Args:
            uav_xy: (n_uavs, 2) UAV coordinates.
            cluster_centers: (n_clusters, 2) cluster centres.
            user_cluster: per-user cluster label (index into cluster_centers).

        Returns:
            Integer array giving, for each user, the index of its serving UAV.
            Users in a cluster left unpaired (only possible if there are more
            clusters than UAVs) keep UAV 0.
        """
        uav_xy = np.asarray(uav_xy, dtype=float)
        cluster_centers = np.asarray(cluster_centers, dtype=float)
        # cost[u, c] = Euclidean distance from UAV u to cluster centre c.
        cost = np.linalg.norm(uav_xy[:, None, :] - cluster_centers[None, :, :], axis=2)
        uav_idx, cluster_idx = linear_sum_assignment(cost)
        uav_of_cluster = np.zeros(len(cluster_centers), dtype=int)
        uav_of_cluster[cluster_idx] = uav_idx
        return uav_of_cluster[np.asarray(user_cluster)]

    def Choose_action(self, s, epsilon, acting_UAV, User_asso_list):
        """Pick an action for ``acting_UAV`` epsilon-greedily, among valid actions only.

        Args:
            s: state vector of length ``self.state_dim``.
            epsilon: probability of exploring with a uniformly random valid action.
            acting_UAV: UAV index as stored in ``User_asso_list``.
            User_asso_list: one-row DataFrame mapping each user to its UAV.

        Returns:
            The chosen action index. The random branch raises ``ValueError``
            if the mask marks no action valid.
        """
        action_mask = self._action_mask(acting_UAV, User_asso_list)

        if np.random.uniform() < epsilon:
            valid_actions = np.where(action_mask > 0)[0]
            return np.random.choice(valid_actions)

        state = np.expand_dims(s, axis=0)
        mask = np.expand_dims(action_mask, axis=0)
        q_values = self.model.predict([state, mask], verbose=0)[0]
        return self._best_valid_action(q_values, action_mask)

    def remember(self, s, a, next_s, reward, acting_UAV, User_asso_list):
        """Append ``(s, a, next_s, reward, mask)`` to the replay buffer.

        The stored mask is computed from the current association of ``acting_UAV``.
        """
        mask = self._action_mask(acting_UAV, User_asso_list)
        self.replay_queue.append((s, a, next_s, reward, mask))

    def train(self, batch_size=128, lr=1, factor=1):
        """Run one fitting pass on a random minibatch once the replay buffer is full.

        Args:
            batch_size: number of transitions sampled from the buffer.
            lr: blend factor between the old Q-value and the new target
                (1 replaces it outright). This is not the optimizer learning rate.
            factor: discount applied to the bootstrapped next-state value.
        """
        if len(self.replay_queue) < self.replay_size:
            return
        self.step += 1

        if self.step % self.update_freq == 0:
            self.target_model.set_weights(self.model.get_weights())

        replay_batch = random.sample(self.replay_queue, batch_size)
        states, actions, next_states, rewards, masks = (
            np.array(column) for column in zip(*replay_batch)
        )

        current_q = self.model.predict([states, masks], verbose=0)
        # NOTE(review): the next-state mask is not stored, so the current
        # transition's mask is reused as a proxy for the next state.
        next_q = self.target_model.predict([next_states, masks], verbose=0)

        # Restrict the bootstrap max to valid actions. Masked-out entries read
        # as 0 under multiplicative masking and would otherwise inflate targets.
        targets = rewards + factor * self._masked_max(next_q, masks)
        rows = np.arange(len(replay_batch))
        current_q[rows, actions] = (1 - lr) * current_q[rows, actions] + lr * targets

        self.model.fit([states, masks], current_q, verbose=0)

    def User_association(self, UAV_Position, User_Position, UAVsnumber, Usersnumber):
        """Cluster users with K-means and pair each cluster with one UAV.

        Clusters and UAVs are matched one-to-one by minimum total
        UAV-to-centre distance (Hungarian algorithm).

        Args:
            UAV_Position: DataFrame with x in row 0 and y in row 1, one column per UAV.
            User_Position: DataFrame with x in row 0 and y in row 1, one column per user.
            UAVsnumber: number of UAVs, which is also the number of clusters.
            Usersnumber: number of users.

        Returns:
            One-row int DataFrame with columns ``0..Usersnumber-1`` holding
            each user's serving UAV index.
        """
        import pandas as pd

        user_xy = np.zeros([Usersnumber, 2])
        user_xy[:, 0] = User_Position.iloc[0, :]
        user_xy[:, 1] = User_Position.iloc[1, :]

        kmeans = KMeans(n_clusters=UAVsnumber).fit(user_xy)

        uav_xy = np.zeros([UAVsnumber, 2])
        uav_xy[:, 0] = UAV_Position.iloc[0, :]
        uav_xy[:, 1] = UAV_Position.iloc[1, :]

        # Use the UAVsnumber argument rather than the global NumberOfUAVs so
        # the pairing matrix always matches the arrays actually built here.
        serving_uav = self._assign_users_to_uavs(
            uav_xy, kmeans.cluster_centers_, kmeans.labels_
        )

        return pd.DataFrame(
            serving_uav.reshape(1, -1).astype(int),
            columns=np.arange(Usersnumber).tolist(),
        )