# In [None]:
import numpy as np
import random
from collections import deque
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.optimizers import Adam
from .action_masking import ActionMasking

class DQN:
    """Deep Q-network agent with action masking and experience replay.

    The agent owns two Keras models with identical architecture (an online
    model that is trained and a target model that is periodically synced
    from it), a bounded replay buffer, and an ``ActionMasking`` helper that
    supplies per-cluster-size action masks.
    """

    def __init__(self, state_dim, num_uavs, num_users, action_dim=126):
        """Create the replay buffer, the action masker and both networks.

        Args:
            state_dim: Length of the (abstracted) state vector fed to the
                network.
            num_uavs: Number of UAVs; the first ``num_uavs * 3`` state
                entries are treated as their positions.
            num_users: Number of users; the last ``num_users`` state
                entries are compared against the serving UAV index.
            action_dim: Size of the discrete action space (default 126,
                the value that used to be hard-coded).
        """
        self.state_dim = state_dim
        self.num_uavs = num_uavs
        self.num_users = num_users
        self.action_dim = action_dim  # Total action space size

        # Network parameters
        self.update_freq = 2000   # training steps between target-network syncs
        self.replay_size = 20000  # maximum number of stored transitions
        self.batch_size = 128     # transitions sampled per training step
        self.learning_rate = 0.001

        # Initialize replay memory (oldest transitions are dropped when full)
        self.replay_buffer = deque(maxlen=self.replay_size)
        self.step = 0  # number of training steps performed so far

        # Initialize action masking
        self.action_masker = ActionMasking(self.action_dim)

        # Create networks; the target starts as an exact copy of the online one
        self.model = self.build_model()
        self.target_model = self.build_model()
        self.target_model.set_weights(self.model.get_weights())

    def build_model(self):
        """Build the Q-network with the action mask applied to its output.

        The model takes two inputs, a state vector of length ``state_dim``
        and a mask vector of length ``action_dim``, and passes the raw
        Q-values together with the mask through
        ``self.action_masker.masking_layer``.

        Returns:
            A compiled Keras ``Model`` (Adam optimizer, MSE loss).
        """
        # State input
        state_input = Input(shape=(self.state_dim,))

        # Action mask input
        mask_input = Input(shape=(self.action_dim,))

        # Hidden layers
        x = Dense(90, activation='relu')(state_input)
        x = Dense(90, activation='relu')(x)

        # Q-value output before masking
        q_values = Dense(self.action_dim)(x)

        # Apply action mask
        masked_q_values = self.action_masker.masking_layer([q_values, mask_input])

        # Create model
        model = Model(inputs=[state_input, mask_input], outputs=masked_q_values)
        model.compile(optimizer=Adam(learning_rate=self.learning_rate),
                      loss='mse')
        return model

    def _mask_vector(self, cluster_size):
        """Return the action mask for ``cluster_size`` as a 1-D array.

        The original call sites disagreed on the shape returned by
        ``ActionMasking.get_mask`` (one used the result directly, the other
        took element ``[0]``).  Leading axes are stripped here so both call
        sites feed the network a vector regardless of which shape it is.
        NOTE(review): confirm the actual return shape of ``get_mask``.
        """
        mask = np.asarray(self.action_masker.get_mask(cluster_size))
        while mask.ndim > 1:
            mask = mask[0]
        return mask

    def state_abstraction(self, state, serving_uav):
        """Create abstracted state representation as per paper.

        The UAV position rows are reordered so the serving UAV comes first
        (its row is swapped with row 0), and entries of the remaining part
        of the state are swapped towards the front for every user whose
        trailing state entry equals ``serving_uav``.

        Args:
            state: 1-D array whose first ``num_uavs * 3`` entries are UAV
                positions; the remainder is treated as channel gains.
            serving_uav: Index of the serving UAV.

        Returns:
            A new 1-D array (flattened positions followed by the reordered
            remainder).  ``state`` itself is left untouched.
        """
        state = np.asarray(state)
        n_pos = self.num_uavs * 3

        # Work on a copy: reshape() of a slice is a view, so swapping in
        # place would silently corrupt the caller's state.
        uav_positions = state[:n_pos].reshape(-1, 3).copy()
        # Fancy indexing copies the right-hand side first, which makes this a
        # true swap (a tuple swap of row views would duplicate one row).
        uav_positions[[0, serving_uav]] = uav_positions[[serving_uav, 0]]

        # Reorder channel gains based on serving UAV
        channel_gains = state[n_pos:].copy()
        # NOTE(review): these indices are relative to the last num_users
        # entries but are applied to everything after the positions; the two
        # only line up if that remainder has exactly num_users entries.
        serving_users = np.where(state[-self.num_users:] == serving_uav)[0]
        for i, user in enumerate(serving_users):
            channel_gains[[i, user]] = channel_gains[[user, i]]

        return np.concatenate([uav_positions.flatten(), channel_gains])

    def choose_action(self, state, epsilon, serving_uav, user_association):
        """Choose action using epsilon-greedy policy with action masking.

        Args:
            state: Raw state vector (see ``state_abstraction``).
            epsilon: Probability of taking a random valid action.
            serving_uav: Index of the serving UAV.
            user_association: Object indexable via ``.iloc`` (e.g. a pandas
                DataFrame); its first row is compared with ``serving_uav``
                to obtain the cluster size.

        Returns:
            The chosen action: either whatever
            ``get_random_valid_action`` returns, or the argmax of the
            masked Q-values predicted by the online model.
        """
        served = np.asarray(user_association.iloc[0, :] == serving_uav)
        cluster_size = int(np.count_nonzero(served))

        if np.random.random() < epsilon:
            return self.action_masker.get_random_valid_action(cluster_size)

        # Get Q-values with masking (mask is only needed on the greedy path)
        state_batch = np.expand_dims(
            self.state_abstraction(state, serving_uav), axis=0)
        mask_batch = np.expand_dims(self._mask_vector(cluster_size), axis=0)
        q_values = self.model.predict([state_batch, mask_batch], verbose=0)[0]

        return np.argmax(q_values)

    def remember(self, state, action, next_state, reward, serving_uav, cluster_size):
        """Store experience in replay buffer with abstracted states.

        Both ``state`` and ``next_state`` are abstracted with the same
        ``serving_uav`` and stored as the tuple
        ``(state, action, next_state, reward, cluster_size)``.
        """
        abstracted_state = self.state_abstraction(state, serving_uav)
        abstracted_next_state = self.state_abstraction(next_state, serving_uav)
        self.replay_buffer.append((abstracted_state, action, abstracted_next_state,
                                   reward, cluster_size))

    def train(self, gamma=0.99):
        """Train the network using experience replay and action masking.

        Does nothing until the replay buffer holds at least ``batch_size``
        transitions.  Otherwise samples a batch, builds targets
        ``reward + gamma * max(target_Q(next_state))`` for the taken
        actions, fits the online model on them, and copies the online
        weights to the target model every ``update_freq`` training steps.

        Args:
            gamma: Discount factor applied to the next-state value.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        self.step += 1

        # Sample batch
        batch = random.sample(self.replay_buffer, self.batch_size)
        states = np.array([exp[0] for exp in batch])
        actions = np.array([exp[1] for exp in batch])
        next_states = np.array([exp[2] for exp in batch])
        rewards = np.array([exp[3] for exp in batch])
        cluster_sizes = [exp[4] for exp in batch]

        # One mask vector per sample -> shape (batch_size, action_dim).
        masks = np.array([self._mask_vector(size) for size in cluster_sizes])
        # Only the current cluster size is stored with a transition, so the
        # same mask is reused for the next state.
        next_masks = masks

        # Get current Q values
        current_q = self.model.predict([states, masks], verbose=0)

        # Get next Q values from target network
        next_q = self.target_model.predict([next_states, next_masks], verbose=0)

        # Overwrite only the taken action's Q-value with its Bellman target.
        # NOTE(review): no terminal flag is stored, so every transition
        # bootstraps from the next state.
        rows = np.arange(self.batch_size)
        current_q[rows, actions] = rewards + gamma * np.max(next_q, axis=1)

        # Train the model
        self.model.fit([states, masks], current_q, batch_size=self.batch_size,
                       verbose=0)

        # Update target network
        if self.step % self.update_freq == 0:
            self.target_model.set_weights(self.model.get_weights())

    def get_power_allocation(self, action, cluster_size):
        """Get power allocation scheme from action.

        Thin wrapper around
        ``self.action_masker.get_power_allocation_scheme``.
        """
        return self.action_masker.get_power_allocation_scheme(action, cluster_size)