In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers, models
from collections import deque
import random

class NOMA_PPO:
    """PPO agent for NOMA uplink scheduling guided by hand-crafted priors.

    The actor is a branching network with one sigmoid output per user, i.e.
    one independent Bernoulli "schedule this user" probability per user.
    Two binary priors (Earliest-Deadline-First and channel quality) mask the
    actor output both when acting and when training, so the probabilities
    used in the PPO ratio are the same ones actions were sampled from.

    The state vector is expected to have ``5 * num_users + 1`` entries laid
    out as ``[buffer info (2N) | timing info (2N) | channel gains (N) |
    last reward (1)]``.
    """

    # Number of optimisation passes over the sampled batch in train().
    _TRAIN_EPOCHS = 5
    # Upper bound on the mini-batch size; clipped to batch_size when smaller.
    _MINIBATCH_SIZE = 32
    # Keeps log() finite when a probability is exactly 0 or 1 (e.g. masked users).
    _LOG_EPS = 1e-8

    def __init__(self,
                 num_users=10,
                 max_buffer_size=5,
                 discount_factor=0.99,
                 gae_lambda=0.95,
                 ppo_epsilon=0.2,
                 actor_learning_rate=0.0003,
                 critic_learning_rate=0.001,
                 entropy_coef=0.01,
                 batch_size=64,
                 channel_threshold=0.1,
                 tau_threshold=5,
                 max_noma_users=4):
        """
        Initialize the NOMA_PPO agent.

        Args:
            num_users: Number of users in the system
            max_buffer_size: Maximum buffer size for each user
            discount_factor: Discount factor for future rewards
            gae_lambda: Lambda parameter for GAE
            ppo_epsilon: Clipping parameter for PPO
            actor_learning_rate: Learning rate for the actor network
            critic_learning_rate: Learning rate for the critic network
            entropy_coef: Coefficient for entropy term in loss function
            batch_size: Number of transitions drawn from the buffer per training call
            channel_threshold: Threshold for channel quality
            tau_threshold: Threshold for channel age
            max_noma_users: Maximum number of users that can be scheduled simultaneously (B)
        """
        self.num_users = num_users
        self.max_buffer_size = max_buffer_size
        self.gamma = discount_factor
        self.gae_lambda = gae_lambda
        self.clip_ratio = ppo_epsilon
        self.entropy_coef = entropy_coef
        self.batch_size = batch_size
        self.eta_threshold = channel_threshold
        self.tau_threshold = tau_threshold
        self.max_noma_users = max_noma_users  # B value in the paper

        # Input size: buffer info (2N) + timing info (2N) + channel info (N) + last reward (1)
        self.input_size = 5 * num_users + 1

        # Build the actor and critic networks
        self.actor = self._build_actor_network()
        self.critic = self._build_critic_network()

        # Optimizers
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_learning_rate)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_learning_rate)

        # On-policy experience buffer, kept in collection order and cleared by train()
        self.buffer = []

    def _build_actor_network(self):
        """Build the branching actor: shared trunk, then one sigmoid head per user.

        Returns:
            A Keras model mapping ``(batch, input_size)`` to ``(batch, num_users)``
            per-user activation probabilities.
        """
        inputs = layers.Input(shape=(self.input_size,))

        # Shared layers
        x = layers.Dense(128, activation='relu')(inputs)
        x = layers.Dense(64, activation='relu')(x)

        # One branch per user, each ending in a single Bernoulli probability
        output_layers = []
        for _ in range(self.num_users):
            branch = layers.Dense(32, activation='relu')(x)
            output_layers.append(layers.Dense(1, activation='sigmoid')(branch))

        # Combine all branch outputs
        outputs = layers.Concatenate()(output_layers)

        return models.Model(inputs=inputs, outputs=outputs)

    def _build_critic_network(self):
        """Build the critic: an MLP mapping a state to a scalar value estimate."""
        inputs = layers.Input(shape=(self.input_size,))
        x = layers.Dense(128, activation='relu')(inputs)
        x = layers.Dense(64, activation='relu')(x)
        value = layers.Dense(1, activation=None)(x)

        return models.Model(inputs=inputs, outputs=value)

    def get_edf_prior(self, buffer_info):
        """
        Earliest Deadline First prior.

        Args:
            buffer_info: Sequence whose first ``num_users`` entries are read as
                per-user head-of-line values; a value ``<= 0`` means "no packet".

        Returns:
            Array of length ``num_users`` with 1.0 for the (at most
            ``max_noma_users``) users that hold a packet and have the smallest
            head-of-line value, 0.0 elsewhere.
        """
        hol_delays = np.asarray(buffer_info[:self.num_users], dtype=float)
        prior = np.zeros(self.num_users)

        # Rank only users that actually have a packet. Ranking everyone would
        # let empty users (value 0) take the B slots and then be discarded,
        # leaving users with pending packets unscheduled.
        with_packet = np.flatnonzero(hol_delays > 0)
        if with_packet.size:
            ranked = with_packet[np.argsort(hol_delays[with_packet], kind="stable")]
            prior[ranked[:self.max_noma_users]] = 1.0

        return prior

    def get_channel_prior(self, channel_info, channel_age):
        """
        Channel quality prior.

        A user is masked out (prior 0) only when its channel gain is at or
        below ``eta_threshold`` AND the age of that measurement is at or below
        ``tau_threshold``; every other user keeps prior 1.

        NOTE(review): the previous comment described this as "poor channels or
        outdated channel information", which does not match the condition.
        The condition is kept as written (a bad measurement that is still
        recent is trusted); confirm this is the intended rule.

        Args:
            channel_info: Sequence whose first ``num_users`` entries are channel gains.
            channel_age: Sequence whose first ``num_users`` entries are measurement ages.

        Returns:
            Array of length ``num_users`` containing 0.0 / 1.0.
        """
        channel_gains = np.asarray(channel_info[:self.num_users], dtype=float)
        ages = np.asarray(channel_age[:self.num_users], dtype=float)

        prior = np.ones(self.num_users)
        prior[(channel_gains <= self.eta_threshold) & (ages <= self.tau_threshold)] = 0.0
        return prior

    def combine_priors(self, edf_prior, channel_prior):
        """Combine EDF and channel priors element-wise"""
        return edf_prior * channel_prior

    def _combined_prior(self, agent_state):
        """Return the EDF x channel prior for a single state vector.

        Shared by acting and training so both use exactly the same mask.
        """
        n = self.num_users
        buffer_info = agent_state[:2 * n]        # buffer info
        channel_age = agent_state[2 * n:3 * n]   # first half of the timing info
        channel_info = agent_state[4 * n:5 * n]  # channel gains
        return self.combine_priors(
            self.get_edf_prior(buffer_info),
            self.get_channel_prior(channel_info, channel_age),
        )

    def get_policy_with_prior(self, agent_state):
        """
        Apply Bayesian policy using prior knowledge
        q(a|A; θπ) ∝ π(a|A; θπ) ⊙ f(a;A)

        Each actor output is an independent Bernoulli probability, so the
        binary prior is applied per user as a mask. The result is deliberately
        NOT normalised across users: doing so would couple the independent
        probabilities, make the 0.5 threshold in ``choose_action`` almost
        unreachable, and disagree with the probabilities used in ``train``.

        Args:
            agent_state: 1-D state vector of length ``input_size``.

        Returns:
            Array of ``num_users`` scheduling probabilities (0 where masked).
        """
        state = np.asarray(agent_state, dtype=np.float32)
        raw_policy = self.actor(state[np.newaxis, :]).numpy()[0]
        return raw_policy * self._combined_prior(state)

    def choose_action(self, agent_state, deterministic=False):
        """Choose a binary scheduling vector from the prior-masked policy.

        Args:
            agent_state: 1-D state vector of length ``input_size``.
            deterministic: If True, schedule the most probable users whose
                probability exceeds 0.5; otherwise sample each user independently.

        Returns:
            Float array of length ``num_users`` with at most ``max_noma_users`` ones.
        """
        policy = self.get_policy_with_prior(agent_state)
        action = np.zeros(self.num_users)

        if deterministic:
            # Most probable first, keep those above 0.5, capped at B users
            ranked = np.argsort(-policy)
            chosen = [i for i in ranked if policy[i] > 0.5][:self.max_noma_users]
            action[np.asarray(chosen, dtype=int)] = 1
            return action

        # Sample action for each user independently
        for i in range(self.num_users):
            if random.random() < policy[i]:
                action[i] = 1

        # Enforce the NOMA limit by keeping the most probable of the users
        # that were actually sampled (never add users that were not sampled).
        selected = np.flatnonzero(action)
        if selected.size > self.max_noma_users:
            keep = selected[np.argsort(-policy[selected])[:self.max_noma_users]]
            action = np.zeros(self.num_users)
            action[keep] = 1

        return action

    def store_experience(self, state, action, reward, next_state, done):
        """Append one transition to the on-policy buffer (order is preserved)."""
        self.buffer.append((state, action, reward, next_state, done))

    def compute_gae(self, rewards, values, next_values, dones):
        """Compute Generalized Advantage Estimation.

        Inputs must be in temporal order. All arithmetic is done in float64 so
        integer reward arrays are not silently truncated.

        Args:
            rewards: Per-step rewards.
            values: Critic estimates V(s_t).
            next_values: Critic estimates V(s_{t+1}).
            dones: Per-step flags; truthy where the episode ended at step t.

        Returns:
            ``(returns, advantages)`` as float64 arrays, with
            ``returns = advantages + values``.
        """
        rewards = np.asarray(rewards, dtype=np.float64)
        values = np.asarray(values, dtype=np.float64)
        next_values = np.asarray(next_values, dtype=np.float64)
        not_done = 1.0 - np.asarray(dones, dtype=np.float64)

        # TD residuals; no bootstrap from the next state at episode ends
        deltas = rewards + self.gamma * next_values * not_done - values

        advantages = np.zeros_like(deltas)
        running = 0.0
        for t in reversed(range(len(deltas))):
            # The (1 - done) factor stops advantages leaking across episodes
            running = deltas[t] + self.gamma * self.gae_lambda * not_done[t] * running
            advantages[t] = running

        return advantages + values, advantages

    def _joint_log_prob(self, actions, probs):
        """Log-probability of binary action vectors under independent Bernoullis.

        Args:
            actions: ``(batch, num_users)`` array/tensor of 0/1 values.
            probs: ``(batch, num_users)`` per-user probabilities of choosing 1.

        Returns:
            ``(batch,)`` float32 tensor of summed per-user log-probabilities.
        """
        actions = tf.convert_to_tensor(actions, dtype=tf.float32)
        probs = tf.convert_to_tensor(probs, dtype=tf.float32)
        per_user = (
            actions * tf.math.log(probs + self._LOG_EPS)
            + (1.0 - actions) * tf.math.log(1.0 - probs + self._LOG_EPS)
        )
        return tf.reduce_sum(per_user, axis=1)

    def get_metrics(self):
        """Return the buffer size and the metrics recorded by the last train() call."""
        return {
            "buffer_size": len(self.buffer),
            "actor_loss": getattr(self, "last_actor_loss", 0),
            "critic_loss": getattr(self, "last_critic_loss", 0),
            "entropy": getattr(self, "last_entropy", 0),
        }

    def train(self):
        """Run one PPO update on the buffered experience, then clear the buffer.

        Returns:
            ``{}`` when the buffer holds fewer than ``batch_size`` transitions
            (nothing is trained, buffer is kept). Otherwise a dict with the mean
            ``actor_loss``, ``critic_loss``, ``entropy`` and ``ratio`` over all
            mini-batch updates; the same values are stored on ``last_*`` attributes.
        """
        if len(self.buffer) < self.batch_size:
            return {}

        states, actions, rewards, next_states, dones = (
            np.asarray(column, dtype=np.float32) for column in zip(*self.buffer)
        )

        # GAE is a temporal recursion, so it has to run over the whole buffer
        # in collection order *before* any random sub-sampling.
        values = self.critic(states).numpy().reshape(-1)
        next_values = self.critic(next_states).numpy().reshape(-1)
        returns, advantages = self.compute_gae(rewards, values, next_values, dones)

        # Sample the training batch from the buffer
        batch_idx = np.asarray(random.sample(range(len(self.buffer)), self.batch_size))
        states = states[batch_idx]
        actions = actions[batch_idx]
        returns = returns[batch_idx].astype(np.float32)
        advantages = advantages[batch_idx]

        # Normalize advantages
        advantages = (
            (advantages - np.mean(advantages)) / (np.std(advantages) + 1e-8)
        ).astype(np.float32)

        # The priors depend only on the states, so compute them once, outside
        # the gradient tape, as a constant mask.
        priors = np.stack([self._combined_prior(s) for s in states]).astype(np.float32)

        # Log-probabilities of the taken actions under the pre-update policy
        old_log_probs = self._joint_log_prob(actions, self.actor(states) * priors).numpy()

        minibatch_size = min(self._MINIBATCH_SIZE, self.batch_size)
        actor_losses, critic_losses, entropy_values, ratios = [], [], [], []

        # PPO training loop
        for _ in range(self._TRAIN_EPOCHS):
            order = np.random.permutation(self.batch_size)

            for start in range(0, self.batch_size, minibatch_size):
                idx = order[start:start + minibatch_size]

                with tf.GradientTape() as actor_tape, tf.GradientTape() as critic_tape:
                    # Current policy with the prior mask applied (tensors are
                    # immutable, so the mask is a broadcast multiply)
                    masked_policy = self.actor(states[idx]) * priors[idx]

                    # Probability ratio, computed in log space for stability
                    new_log_probs = self._joint_log_prob(actions[idx], masked_policy)
                    ratio = tf.exp(new_log_probs - old_log_probs[idx])

                    # Clipped surrogate objective
                    surrogate1 = ratio * advantages[idx]
                    surrogate2 = tf.clip_by_value(
                        ratio, 1 - self.clip_ratio, 1 + self.clip_ratio
                    ) * advantages[idx]
                    actor_loss = -tf.reduce_mean(tf.minimum(surrogate1, surrogate2))

                    # Entropy bonus
                    entropy = -tf.reduce_mean(
                        masked_policy * tf.math.log(masked_policy + self._LOG_EPS)
                        + (1 - masked_policy) * tf.math.log(1 - masked_policy + self._LOG_EPS)
                    )
                    actor_loss = actor_loss - self.entropy_coef * entropy

                    # Critic loss (MSE). Squeeze the trailing axis so the
                    # (B, 1) predictions do not broadcast against (B,) returns.
                    value_preds = tf.squeeze(self.critic(states[idx]), axis=-1)
                    critic_loss = tf.reduce_mean(tf.square(returns[idx] - value_preds))

                # Compute gradients and apply updates
                actor_grads = actor_tape.gradient(actor_loss, self.actor.trainable_variables)
                critic_grads = critic_tape.gradient(critic_loss, self.critic.trainable_variables)

                self.actor_optimizer.apply_gradients(zip(actor_grads, self.actor.trainable_variables))
                self.critic_optimizer.apply_gradients(zip(critic_grads, self.critic.trainable_variables))

                actor_losses.append(float(actor_loss.numpy()))
                critic_losses.append(float(critic_loss.numpy()))
                entropy_values.append(float(entropy.numpy()))
                ratios.append(float(tf.reduce_mean(ratio).numpy()))

        self.last_actor_loss = float(np.mean(actor_losses))
        self.last_critic_loss = float(np.mean(critic_losses))
        self.last_entropy = float(np.mean(entropy_values))
        self.last_ratio = float(np.mean(ratios))

        # PPO is on-policy: the data is stale once the policy has been updated
        self.buffer = []

        return {
            "actor_loss": self.last_actor_loss,
            "critic_loss": self.last_critic_loss,
            "entropy": self.last_entropy,
            "ratio": self.last_ratio,
        }

# Simplified example of usage
def preprocess_state(buffer_status, timing_info, channel_info, last_reward):
    """
    Preprocess state information into the format expected by the agent.
    
    Args:
        buffer_status: Buffer information (B)
        timing_info: Timing information (τp, τa, τs)
        channel_info: Channel gains (η)
        last_reward: Last received reward r(t-1)
        
    Returns:
        Concatenated and preprocessed state vector
    """
    # Concatenate all information as described in the paper
    return np.concatenate([buffer_status, timing_info, channel_info, [last_reward]])

def example_usage():
    """Run a short demo of the agent against a purely random stand-in environment.

    States are random vectors and the reward is 0.5 per scheduled user, so
    this only exercises the act / store / train plumbing; it does not
    demonstrate learning on a real scheduling problem.
    """
    # Example parameters
    num_users = 10
    max_buffer_size = 5

    # Initialize the NOMA-PPO agent
    agent = NOMA_PPO(
        num_users=num_users,
        max_buffer_size=max_buffer_size,
        discount_factor=0.99,
        gae_lambda=0.95,
        ppo_epsilon=0.2,
        actor_learning_rate=0.0003,
        critic_learning_rate=0.001,
        entropy_coef=0.01,
        batch_size=64,
        channel_threshold=0.1,
        tau_threshold=5,
        max_noma_users=4
    )

    def random_state(last_reward):
        """Draw a random observation (would come from the environment in a real setup)."""
        buffer_status = np.random.randint(0, 2, size=num_users*2)  # Buffer info
        timing_info = np.random.rand(num_users*2)  # Timing info
        channel_info = np.random.rand(num_users)  # Channel gains
        return preprocess_state(buffer_status, timing_info, channel_info, last_reward)

    # Simulated environment interaction
    for episode in range(10):
        state = random_state(0)
        done = False
        episode_reward = 0

        while not done:
            # Choose action based on current state
            action = agent.choose_action(state)

            # Execute action in environment (simplified)
            # In real implementation, this would interact with a simulator
            reward = np.sum(action) * 0.5  # Simplified reward
            next_state = random_state(reward)

            # Decide termination BEFORE storing, so the transition records the
            # real episode boundary (5% chance of the episode ending per step).
            # Storing a constant False would let advantage estimation
            # bootstrap across unrelated episodes.
            done = np.random.random() < 0.05

            # Store experience
            agent.store_experience(state, action, reward, next_state, done)

            # Update state
            state = next_state
            episode_reward += reward

        # Train the agent (a no-op until the buffer holds a full batch)
        agent.train()

        print(f"Episode {episode+1}, Total Reward: {episode_reward}")

# Run the basic demo only when this code is executed as the main program.
if __name__ == "__main__":
    example_usage()

In [None]:
import numpy as np
import tensorflow as tf
import random
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from collections import deque

class NOMA_PPO:
    # NOTE(review): this cell is a modification template, not a runnable class.
    # It defines only get_metrics() and train(); no __init__ or other method
    # appears here, and the bracketed placeholder comments below mark places
    # where code from the full class definition is meant to be pasted in.
    # Executing the cell as-is rebinds the name NOMA_PPO to this partial class.
    # [All existing code remains the same]
    
    # Add this method to collect metrics
    def get_metrics(self):
        """Return current metrics for visualization.

        The loss and entropy entries fall back to 0 when the corresponding
        ``last_*`` attribute has not been set on the instance yet.
        """
        return {
            "buffer_size": len(self.buffer),
            "actor_loss": getattr(self, "last_actor_loss", 0),
            "critic_loss": getattr(self, "last_critic_loss", 0),
            "entropy": getattr(self, "last_entropy", 0),
        }
    
    # Modify the train method to store metrics
    def train(self):
        """Train the agent using PPO and report averaged metrics.

        Returns an empty dict when the buffer holds fewer than ``batch_size``
        transitions; otherwise a dict with the mean actor loss, critic loss,
        entropy and probability ratio collected over all mini-batch updates.
        """
        if len(self.buffer) < self.batch_size:
            return {}
        
        # Placeholder: batch sampling, value/advantage computation and the old
        # action probabilities are expected to be inserted here.
        # [Keep the existing code here]
        
        # Track losses and metrics
        actor_losses = []
        critic_losses = []
        entropy_values = []
        ratios = []
        
        # PPO training loop
        for epoch in range(5):  # Multiple epochs of training
            indices = np.arange(self.batch_size)
            np.random.shuffle(indices)
            
            # Fixed mini-batches of 32; a trailing partial mini-batch is skipped
            for start in range(0, self.batch_size, 32):
                end = start + 32
                if end <= self.batch_size:
                    idx = indices[start:end]
                    
                    with tf.GradientTape() as actor_tape, tf.GradientTape() as critic_tape:
                        # Placeholder: the loss computation must be inserted here.
                        # NOTE(review): actor_loss, critic_loss, entropy and ratio
                        # are not assigned anywhere in this method as written, so
                        # the lines below raise NameError until that is done.
                        # [Keep the existing code here]
                        
                        # Store metrics
                        actor_losses.append(actor_loss.numpy())
                        critic_losses.append(critic_loss.numpy())
                        entropy_values.append(entropy.numpy())
                        ratios.append(tf.reduce_mean(ratio).numpy())
                    
                    # Placeholder: gradient computation and optimizer updates go here.
                    # [Keep the existing code here]
        
        # Store metrics for visualization (means over every mini-batch update)
        self.last_actor_loss = np.mean(actor_losses)
        self.last_critic_loss = np.mean(critic_losses)
        self.last_entropy = np.mean(entropy_values)
        self.last_ratio = np.mean(ratios)
        
        # Clear buffer after training
        self.buffer = []
        
        # Return metrics for visualization
        return {
            "actor_loss": self.last_actor_loss,
            "critic_loss": self.last_critic_loss,
            "entropy": self.last_entropy,
            "ratio": self.last_ratio
        }

# Modify the example usage to include visualization
def improved_example_usage():
    """Run the random-environment demo, collect metrics and plot them.

    States are random vectors and the reward is 0.5 per scheduled user
    (halved with 20% probability to mimic a bad channel), so this exercises
    the act / store / train / plot plumbing rather than a real scheduler.
    Per-episode rewards, training metrics and the average number of scheduled
    users are passed to ``plot_training_results`` at the end.
    """
    # Example parameters
    num_users = 10
    max_buffer_size = 5
    num_episodes = 50  # Increased for better visualization
    max_steps = 20  # Limit steps per episode

    # Initialize the NOMA-PPO agent
    agent = NOMA_PPO(
        num_users=num_users,
        max_buffer_size=max_buffer_size,
        discount_factor=0.99,
        gae_lambda=0.95,
        ppo_epsilon=0.2,
        actor_learning_rate=0.0003,
        critic_learning_rate=0.001,
        entropy_coef=0.01,
        batch_size=64,
        channel_threshold=0.1,
        tau_threshold=5,
        max_noma_users=4
    )

    def random_state(last_reward):
        """Draw a random float32 observation (stand-in for a real environment)."""
        buffer_status = np.random.randint(0, 2, size=num_users*2).astype(np.float32)  # Buffer info
        timing_info = np.random.rand(num_users*2).astype(np.float32)  # Timing info
        channel_info = np.random.rand(num_users).astype(np.float32)  # Channel gains
        return preprocess_state(buffer_status, timing_info, channel_info, last_reward)

    # Metrics for plotting
    rewards_history = []
    avg_rewards_history = []
    actor_loss_history = []
    critic_loss_history = []
    entropy_history = []
    scheduled_users_history = []

    # Simulated environment interaction
    for episode in range(num_episodes):
        episode_scheduled_users = []
        state = random_state(0.0)
        episode_reward = 0

        for step in range(max_steps):
            # Choose action based on current state
            action = agent.choose_action(state)

            # Track number of scheduled users
            episode_scheduled_users.append(np.sum(action))

            # Execute action in environment (simplified)
            # In real implementation, this would interact with a simulator
            reward = np.sum(action) * 0.5  # Simplified reward

            # Add some randomness to model varying channel conditions
            if np.random.rand() < 0.2:  # 20% chance of bad channel
                reward *= 0.5  # Reduced reward due to poor channel

            next_state = random_state(reward)

            # 5% chance of the episode ending on its own
            done = np.random.random() < 0.05

            # Record the episode boundary with the transition. Hitting the step
            # limit is also stored as a boundary because the next stored
            # transition starts from an unrelated fresh state; otherwise
            # advantage estimation would chain across episodes.
            episode_over = done or step == max_steps - 1
            agent.store_experience(state, action, reward, next_state, episode_over)

            # Update state
            state = next_state
            episode_reward += reward

            if done:
                break

        # Store episode reward
        rewards_history.append(episode_reward)

        # Store average scheduled users
        scheduled_users_history.append(np.mean(episode_scheduled_users))

        # Moving average over (up to) the last 10 episodes
        avg_reward = np.mean(rewards_history[-10:])
        avg_rewards_history.append(avg_reward)

        # Train the agent if we have enough samples
        if len(agent.buffer) >= agent.batch_size:
            metrics = agent.train()

            # Store training metrics (skipped if train() reports nothing)
            if metrics:
                actor_loss_history.append(metrics["actor_loss"])
                critic_loss_history.append(metrics["critic_loss"])
                entropy_history.append(metrics["entropy"])

        print(f"Episode {episode+1}, Reward: {episode_reward:.2f}, Avg Reward: {avg_reward:.2f}")

    # Create visualizations
    plot_training_results(
        rewards_history,
        avg_rewards_history,
        actor_loss_history,
        critic_loss_history,
        entropy_history,
        scheduled_users_history
    )

def plot_training_results(rewards, avg_rewards, actor_losses, critic_losses, entropies, scheduled_users):
    """Draw a 3x2 grid of training diagnostics, save it as a PNG and show it.

    Panels: episode rewards with their moving average, actor loss, critic
    loss, policy entropy, average scheduled users per episode, and a
    histogram of those averages. The figure is written to
    'noma_ppo_results.png' in the current working directory.
    """

    def draw_update_series(position, series, y_label, heading):
        # One per-training-update panel; shows a notice when nothing was recorded.
        plt.subplot(3, 2, position)
        if not series:
            plt.text(0.5, 0.5, 'No training data yet', ha='center', va='center')
            plt.title(heading + ' (No Data)')
            return
        plt.plot(series)
        plt.xlabel('Training Update')
        plt.ylabel(y_label)
        plt.title(heading)
        plt.grid(True)

    plt.figure(figsize=(20, 15))

    # Panel 1: raw and smoothed episode rewards
    plt.subplot(3, 2, 1)
    plt.plot(rewards, label='Episode Reward')
    plt.plot(avg_rewards, label='Moving Average (10 episodes)', linewidth=2)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('Episode Rewards')
    plt.legend()
    plt.grid(True)

    # Panels 2-4: metrics recorded once per training update
    update_panels = (
        (2, actor_losses, 'Actor Loss', 'Actor Network Loss'),
        (3, critic_losses, 'Critic Loss', 'Critic Network Loss'),
        (4, entropies, 'Entropy', 'Policy Entropy'),
    )
    for position, series, y_label, heading in update_panels:
        draw_update_series(position, series, y_label, heading)

    # Panel 5: average number of scheduled users per episode
    plt.subplot(3, 2, 5)
    plt.plot(scheduled_users)
    plt.xlabel('Episode')
    plt.ylabel('Avg. Users Scheduled')
    plt.title('Average Number of Scheduled Users per Episode')
    plt.grid(True)

    # Panel 6: histogram of the per-episode averages
    user_ticks = range(0, 5)
    plt.subplot(3, 2, 6)
    plt.hist(scheduled_users, bins=user_ticks, alpha=0.7, rwidth=0.8)
    plt.xlabel('Number of Users Scheduled')
    plt.ylabel('Frequency')
    plt.title('Distribution of Scheduled Users')
    plt.xticks(user_ticks)
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('noma_ppo_results.png')
    plt.show()

# Run the demo with metric collection and plotting only when executed as the main program.
if __name__ == "__main__":
    improved_example_usage()

In [None]:
import numpy as np
import tensorflow as tf
import random
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from collections import deque

class NOMA_PPO:
    def __init__(self,
                 num_users=10,
                 max_buffer_size=5,
                 discount_factor=0.99,
                 gae_lambda=0.95,
                 ppo_epsilon=0.2,
                 actor_learning_rate=0.0003,
                 critic_learning_rate=0.001,
                 entropy_coef=0.01,
                 batch_size=64,
                 channel_threshold=0.1,
                 tau_threshold=5,
                 max_noma_users=4):
        """
        Store hyper-parameters and create networks, optimizers and bookkeeping.

        Args:
            num_users: Number of users in the system
            max_buffer_size: Maximum buffer size for each user
            discount_factor: Discount factor for future rewards (stored as ``gamma``)
            gae_lambda: Lambda parameter for GAE
            ppo_epsilon: PPO clipping parameter (stored as ``clip_ratio``)
            actor_learning_rate: Adam learning rate for the actor network
            critic_learning_rate: Adam learning rate for the critic network
            entropy_coef: Weight of the entropy term in the actor loss
            batch_size: Minimum number of stored transitions needed to train
            channel_threshold: Channel-gain threshold (stored as ``eta_threshold``)
            tau_threshold: Channel-age threshold
            max_noma_users: Maximum number of simultaneously scheduled users (B)
        """
        # System dimensions and scheduling limits.
        self.num_users, self.max_buffer_size = num_users, max_buffer_size
        self.max_noma_users = max_noma_users  # B value in the paper

        # PPO / GAE hyper-parameters.
        self.gamma, self.gae_lambda = discount_factor, gae_lambda
        self.clip_ratio, self.entropy_coef = ppo_epsilon, entropy_coef
        self.batch_size = batch_size

        # Thresholds consumed by the channel prior.
        self.eta_threshold, self.tau_threshold = channel_threshold, tau_threshold

        # State vector width: five per-user slots plus the last reward.
        self.input_size = 5 * num_users + 1

        # Networks first (actor, then critic), then one Adam optimizer each.
        self.actor = self._build_actor_network()
        self.critic = self._build_critic_network()
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_learning_rate)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_learning_rate)

        # Transitions collected since the last training call.
        self.buffer = []

        # Most recent training metrics, all starting at zero.
        self.last_actor_loss = self.last_critic_loss = 0
        self.last_entropy = self.last_ratio = 0
        
    def _build_actor_network(self):
        """Create the branching actor: a shared trunk followed by one sigmoid head per user."""
        state_in = layers.Input(shape=(self.input_size,))

        # Trunk shared by every user branch.
        trunk = layers.Dense(128, activation='relu')(state_in)
        trunk = layers.Dense(64, activation='relu')(trunk)

        def user_head(features):
            # One hidden layer, then a single activation probability in (0, 1).
            hidden = layers.Dense(32, activation='relu')(features)
            return layers.Dense(1, activation='sigmoid')(hidden)

        # Heads are created in user order so layer creation order is unchanged.
        heads = [user_head(trunk) for _ in range(self.num_users)]

        # Join the per-user outputs into one vector of length num_users.
        joint = layers.Concatenate()(heads)

        return models.Model(inputs=state_in, outputs=joint)
    
    def _build_critic_network(self):
        """Create the critic: two ReLU hidden layers and a linear scalar value output."""
        state_in = layers.Input(shape=(self.input_size,))

        hidden = state_in
        for width in (128, 64):
            hidden = layers.Dense(width, activation='relu')(hidden)

        # Linear (no activation) output so the value estimate is unbounded.
        value_out = layers.Dense(1, activation=None)(hidden)

        return models.Model(inputs=state_in, outputs=value_out)
    
    def get_edf_prior(self, buffer_info):
        """
        Earliest Deadline First prior
        Returns scheduling priority based on deadlines
        """
        # Extract head-of-line delays from buffer_info
        hol_delays = buffer_info[:self.num_users]  # Assuming first values are head-of-line delays
        
        # Sort users by increasing deadline (smaller deadline = higher priority)
        sorted_indices = np.argsort(hol_delays)
        
        # Initialize prior with zeros
        prior = np.zeros(self.num_users, dtype=np.float32)
        
        # Set priority 1 for the B users with the smallest deadline
        for i in range(min(self.max_noma_users, len(sorted_indices))):
            if hol_delays[sorted_indices[i]] > 0:  # Only if there's a packet
                prior[sorted_indices[i]] = 1.0
                
        return prior
    
    def get_channel_prior(self, channel_info, channel_age):
        """
        Channel quality prior
        Returns priority based on channel conditions
        """
        # Extract channel quality and age
        channel_gains = channel_info[:self.num_users]
        ages = channel_age[:self.num_users]
        
        # Initialize prior with ones
        prior = np.ones(self.num_users, dtype=np.float32)
        
        # Set to 0 for poor channels or outdated channel information
        for i in range(self.num_users):
            if channel_gains[i] <= self.eta_threshold and ages[i] <= self.tau_threshold:
                prior[i] = 0.0
                
        return prior
    
    def combine_priors(self, edf_prior, channel_prior):
        """Combine EDF and channel priors element-wise"""
        return edf_prior * channel_prior
    
    def get_policy_with_prior(self, agent_state):
        """
        Apply Bayesian policy using prior knowledge
        q(a|A; θπ) ∝ π(a|A; θπ) ⊙ f(a;A)
        """
        # Parse state components
        buffer_info = agent_state[:self.num_users * 2]  # Buffer info
        timing_info = agent_state[self.num_users * 2:self.num_users * 4]  # τp, τa, τs
        channel_info = agent_state[self.num_users * 4:self.num_users * 5]  # η
        
        # Get policy from neural network
        raw_policy = self.actor(np.array([agent_state], dtype=np.float32)).numpy()[0]
        
        # Get EDF prior
        edf_prior = self.get_edf_prior(buffer_info)
        
        # Get channel prior
        channel_prior = self.get_channel_prior(channel_info, timing_info[:self.num_users])
        
        # Combine priors
        combined_prior = self.combine_priors(edf_prior, channel_prior)
        
        # Apply Bayesian update (element-wise multiplication)
        posterior_policy = raw_policy * combined_prior
        
        # Normalize if needed
        if np.sum(posterior_policy) > 0:
            posterior_policy = posterior_policy / np.sum(posterior_policy)
            
        return posterior_policy
    
    def choose_action(self, agent_state, deterministic=False):
        """Choose action based on policy with prior knowledge"""
        policy = self.get_policy_with_prior(agent_state)
        
        if deterministic:
            # Select users with highest probability up to max_noma_users
            sorted_indices = np.argsort(-policy)
            action = np.zeros(self.num_users, dtype=np.float32)
            count = 0
            for idx in sorted_indices:
                if policy[idx] > 0.5 and count < self.max_noma_users:
                    action[idx] = 1
                    count += 1
        else:
            # Sample action for each user independently
            action = np.zeros(self.num_users, dtype=np.float32)
            for i in range(self.num_users):
                if random.random() < policy[i]:
                    action[i] = 1
            
            # Ensure we don't exceed max_noma_users
            if np.sum(action) > self.max_noma_users:
                # If too many users selected, keep only the top max_noma_users
                indices = np.argsort(-policy)
                action = np.zeros(self.num_users, dtype=np.float32)
                for i in range(self.max_noma_users):
                    action[indices[i]] = 1
                    
        return action
    
    def store_experience(self, state, action, reward, next_state, done):
        """Store experience in replay buffer"""
        self.buffer.append((state, action, reward, next_state, done))
    
    def compute_gae(self, rewards, values, next_values, dones):
        """Compute Generalized Advantage Estimation"""
        advantages = np.zeros_like(rewards, dtype=np.float32)
        last_advantage = 0
        
        for t in reversed(range(len(rewards))):
            # If episode is done, use reward as terminal value
            if dones[t]:
                delta = rewards[t] - values[t]
            else:
                delta = rewards[t] + self.gamma * next_values[t] - values[t]
            
            # Recursive update of advantage
            advantages[t] = delta + self.gamma * self.gae_lambda * last_advantage * (1 - dones[t])
            last_advantage = advantages[t]
            
        # Returns are advantage + value
        returns = advantages + values
        
        return returns, advantages
    
    def get_metrics(self):
        """Return current metrics for visualization"""
        return {
            "buffer_size": len(self.buffer),
            "actor_loss": self.last_actor_loss,
            "critic_loss": self.last_critic_loss,
            "entropy": self.last_entropy,
        }
    
    def train(self):
        """
        Run one PPO update over the stored rollout, then clear the buffer.

        Nothing happens (``{}`` is returned) while the buffer holds fewer
        than ``self.batch_size`` transitions.

        Design notes:
            * Every stored transition is used, in the order it was collected.
              The GAE recursion in ``compute_gae`` assumes step ``t + 1``
              follows step ``t``, which a random sample does not guarantee.
            * Old and new action probabilities are computed by the same
              function (actor output, multiplied by the prior, rescaled to
              sum to 1 when positive, mirroring ``get_policy_with_prior``),
              so the probability ratio is exactly 1 before the first step.
            * The ratio is formed in log space for numerical stability.
            * Critic predictions are squeezed to shape ``(batch,)`` before
              being compared with the returns; otherwise the subtraction
              broadcasts to ``(batch, batch)``.
            * The final, possibly smaller, minibatch is no longer dropped.

        Returns:
            Dict with ``actor_loss``, ``critic_loss``, ``entropy`` and
            ``ratio`` (means over all minibatch updates), or ``{}`` when
            there is not enough data.
        """
        num_transitions = len(self.buffer)
        if num_transitions == 0 or num_transitions < self.batch_size:
            return {}

        num_epochs = 5
        minibatch_size = 32
        epsilon = 1e-8  # keeps log() away from zero

        # Whole rollout, in collection order.
        states, actions, rewards, next_states, dones = zip(*self.buffer)
        states = np.array(states, dtype=np.float32)
        actions = np.array(actions, dtype=np.float32)
        rewards = np.array(rewards, dtype=np.float32)
        next_states = np.array(next_states, dtype=np.float32)
        dones = np.array(dones, dtype=np.float32)

        # Value estimates for states and successor states.
        values = self.critic(states).numpy().flatten()
        next_values = self.critic(next_states).numpy().flatten()

        # Returns and normalized advantages.
        returns, advantages = self.compute_gae(rewards, values, next_values, dones)
        advantages = (advantages - np.mean(advantages)) / (np.std(advantages) + 1e-8)

        # Priors depend only on the state, so they are computed once.
        priors = np.stack(
            [self._prior_for_state(state) for state in states]
        ).astype(np.float32)

        # Log-probability of the taken actions under the pre-update policy.
        old_policy = self._posterior_policy(
            self.actor(states), tf.convert_to_tensor(priors, dtype=tf.float32)
        )
        old_log_probs = self._action_log_prob(
            old_policy, tf.convert_to_tensor(actions, dtype=tf.float32), epsilon
        ).numpy()

        actor_losses = []
        critic_losses = []
        entropy_values = []
        ratios = []

        for _ in range(num_epochs):
            order = np.random.permutation(num_transitions)

            for start in range(0, num_transitions, minibatch_size):
                idx = order[start:start + minibatch_size]

                batch_states = tf.convert_to_tensor(states[idx], dtype=tf.float32)
                batch_actions = tf.convert_to_tensor(actions[idx], dtype=tf.float32)
                batch_priors = tf.convert_to_tensor(priors[idx], dtype=tf.float32)
                batch_advantages = tf.convert_to_tensor(advantages[idx], dtype=tf.float32)
                batch_returns = tf.convert_to_tensor(returns[idx], dtype=tf.float32)
                batch_old_log_probs = tf.convert_to_tensor(old_log_probs[idx], dtype=tf.float32)

                with tf.GradientTape() as actor_tape, tf.GradientTape() as critic_tape:
                    # Current prior-adjusted policy.
                    policy = self._posterior_policy(self.actor(batch_states), batch_priors)

                    # Importance ratio of the joint per-user Bernoulli action.
                    log_probs = self._action_log_prob(policy, batch_actions, epsilon)
                    ratio = tf.exp(log_probs - batch_old_log_probs)

                    # Clipped surrogate objective.
                    surrogate1 = ratio * batch_advantages
                    surrogate2 = tf.clip_by_value(
                        ratio, 1 - self.clip_ratio, 1 + self.clip_ratio
                    ) * batch_advantages
                    actor_loss = -tf.reduce_mean(tf.minimum(surrogate1, surrogate2))

                    # Mean Bernoulli entropy, used as an exploration bonus.
                    entropy = -tf.reduce_mean(
                        policy * tf.math.log(policy + epsilon) +
                        (1 - policy) * tf.math.log(1 - policy + epsilon)
                    )
                    actor_loss = actor_loss - self.entropy_coef * entropy

                    # Squeeze (batch, 1) -> (batch,) to match the returns.
                    value_preds = tf.squeeze(self.critic(batch_states), axis=-1)
                    critic_loss = tf.reduce_mean(tf.square(batch_returns - value_preds))

                actor_grads = actor_tape.gradient(actor_loss, self.actor.trainable_variables)
                critic_grads = critic_tape.gradient(critic_loss, self.critic.trainable_variables)

                self.actor_optimizer.apply_gradients(zip(actor_grads, self.actor.trainable_variables))
                self.critic_optimizer.apply_gradients(zip(critic_grads, self.critic.trainable_variables))

                actor_losses.append(actor_loss.numpy())
                critic_losses.append(critic_loss.numpy())
                entropy_values.append(entropy.numpy())
                ratios.append(tf.reduce_mean(ratio).numpy())

        # Store metrics for visualization
        self.last_actor_loss = np.mean(actor_losses) if actor_losses else 0
        self.last_critic_loss = np.mean(critic_losses) if critic_losses else 0
        self.last_entropy = np.mean(entropy_values) if entropy_values else 0
        self.last_ratio = np.mean(ratios) if ratios else 0

        # The rollout is on-policy data; discard it once it has been used.
        self.buffer = []

        return {
            "actor_loss": self.last_actor_loss,
            "critic_loss": self.last_critic_loss,
            "entropy": self.last_entropy,
            "ratio": self.last_ratio
        }

    def _prior_for_state(self, agent_state):
        """Return the combined EDF/channel prior for one state, sliced as in get_policy_with_prior."""
        n = self.num_users
        buffer_info = agent_state[:n * 2]
        timing_info = agent_state[n * 2:n * 4]
        channel_info = agent_state[n * 4:n * 5]
        return self.combine_priors(
            self.get_edf_prior(buffer_info),
            self.get_channel_prior(channel_info, timing_info[:n]),
        )

    @staticmethod
    def _posterior_policy(raw_policy, prior):
        """Multiply a batch of actor outputs by the prior and rescale rows with a positive sum to 1."""
        masked = raw_policy * prior
        total = tf.reduce_sum(masked, axis=1, keepdims=True)
        # Divide by 1 where the row is all zero so no NaN enters the gradient.
        safe_total = tf.where(total > 0, total, tf.ones_like(total))
        return masked / safe_total

    @staticmethod
    def _action_log_prob(policy, actions, epsilon):
        """Return per-sample log-probability of a binary action vector under per-user Bernoulli probabilities."""
        per_user = actions * policy + (1.0 - actions) * (1.0 - policy)
        return tf.reduce_sum(tf.math.log(per_user + epsilon), axis=1)

def preprocess_state(buffer_status, timing_info, channel_info, last_reward):
    """
    Flatten the separate observation parts into one float32 state vector.

    Args:
        buffer_status: Buffer information (B)
        timing_info: Timing information (τp, τa, τs)
        channel_info: Channel gains (η)
        last_reward: Last received reward r(t-1)

    Returns:
        1-D float32 array: buffer status, timing info, channel info and the
        last reward, concatenated in that order.
    """
    # The scalar reward is wrapped in a list so it joins as a one-element part.
    pieces = [buffer_status, timing_info, channel_info, [last_reward]]
    return np.concatenate(pieces, dtype=np.float32)

def improved_example_usage():
    # Example parameters
    num_users = 10
    max_buffer_size = 5
    num_episodes = 50  # Increased for better visualization
    
    # Initialize the NOMA-PPO agent
    agent = NOMA_PPO(
        num_users=num_users,
        max_buffer_size=max_buffer_size,
        discount_factor=0.99,
        gae_lambda=0.95,
        ppo_epsilon=0.2,
        actor_learning_rate=0.0003,
        critic_learning_rate=0.001,
        entropy_coef=0.01,
        batch_size=64,
        channel_threshold=0.1,
        tau_threshold=5,
        max_noma_users=4
    )
    
    # Metrics for plotting
    rewards_history = []
    avg_rewards_history = []
    