# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from models.system_model import SystemModel
from models.dqn import DQN
from utils.helpers import create_directory_structure
import json

# Experiment configuration shared by training, evaluation and plotting.
# Unit notes are carried over from the original inline annotations.
config = {
    # --- Geometry of the service area ---
    "service_zone_x": 500,
    "service_zone_y": 500,
    "height_limit": 150,

    # --- Network layout ---
    "num_uavs": 3,
    "users_per_cell": 2,

    # --- Radio parameters ---
    "frequency": 2,      # GHz
    "bandwidth": 30,     # kHz
    "r_require": 0.15,   # kb; rate threshold used for the QoS penalty
    "uav_speed": 5,      # m/s
    # NOTE(review): the original annotation here reads "20mW"; the 10000
    # factor is the same one applied to noise_power -- TODO confirm units.
    "power_unit": 100 * 10000,
    "noise_power": 10 ** (-9) * 10000,

    # --- Training schedule ---
    "episodes": 300,
    "time_steps": 120,
    "clustering_interval": 40,

    # --- Exploration and learning hyper-parameters ---
    "epsilon_start": 0.9,
    "epsilon_end": 0.05,
    "epsilon_decay": 200,
    "gamma": 0.99,
    "batch_size": 128,
}

def train(model_path="results/final_model.h5"):
    """Train the DQN agent, then save metrics, plots and the learned weights.

    For each of ``config['episodes']`` episodes the environment is reset and
    run for ``config['time_steps']`` steps. Users are re-clustered every
    ``config['clustering_interval']`` steps. Within a step each UAV in turn
    chooses an action, applies it, receives the resulting sum rate as reward
    (halved when the worst user rate is below ``config['r_require']``),
    stores the transition and triggers one agent training call.

    Exploration decays linearly per episode: ``epsilon_start`` minus
    ``episode / epsilon_decay``, floored at ``epsilon_end``.

    Args:
        model_path: File the trained weights are written to once training
            has finished. The default equals the default ``model_path`` of
            ``evaluate`` so the two can be run back to back.

    Returns:
        None. Metrics and plots are written by ``save_results``; the
        weights are written to ``model_path``.
    """
    # Create necessary directories
    create_directory_structure()

    # Initialize environment and agent
    env = SystemModel(config)
    state_dim = env.num_uavs * 3 + env.num_users
    agent = DQN(state_dim=state_dim,
                num_uavs=env.num_uavs,
                num_users=env.num_users)

    # Per-episode training metrics
    episode_rewards = []
    throughput_history = []
    worst_user_rates = []

    for episode in range(config['episodes']):
        state = env.reset()
        episode_reward = 0

        # Fallbacks so the end-of-episode bookkeeping and the progress print
        # stay defined even if no UAV step runs in this episode.
        sum_rate = 0.0
        worst_rate = 0.0

        # Linear epsilon decay, clamped at the configured floor.
        epsilon = max(
            config['epsilon_end'],
            config['epsilon_start'] - episode / config['epsilon_decay']
        )

        for t in range(config['time_steps']):
            # Periodic user clustering (always runs at t == 0, so
            # user_association is bound before the UAV loop needs it).
            if t % config['clustering_interval'] == 0:
                user_association = agent.cluster_users(env.uav_positions, env.user_positions)

            # UAVs act one after another within the same time step.
            for uav_idx in range(env.num_uavs):
                # Number of users currently associated with this UAV
                cluster_size = len(np.where(user_association.iloc[0, :] == uav_idx)[0])

                # Choose and take action
                action = agent.choose_action(state, epsilon, uav_idx, user_association)
                power_allocation = agent.get_power_allocation(action, cluster_size)
                env.take_action(action, uav_idx, power_allocation)

                # Rates after this UAV's move
                sinr = env.calculate_sinr(user_association)
                _, sum_rate, worst_rate = env.calculate_rates(sinr)

                # Apply penalty if QoS requirement not met
                reward = sum_rate
                if worst_rate < config['r_require']:
                    reward *= 0.5

                # Get next state and store experience
                next_state = env.get_state(uav_idx)
                agent.remember(state, action, next_state, reward, uav_idx, cluster_size)

                # Train agent
                agent.train(gamma=config['gamma'])

                state = next_state
                episode_reward += reward

            # Record the rates reached at the last step of the episode.
            if t == config['time_steps'] - 1:
                throughput_history.append(sum_rate)
                worst_user_rates.append(worst_rate)

        episode_rewards.append(episode_reward)

        # Print progress every 10 episodes
        if (episode + 1) % 10 == 0:
            print(f"Episode {episode + 1}/{config['episodes']}")
            print(f"Average Reward: {episode_reward/config['time_steps']:.2f}")
            print(f"Throughput: {sum_rate:.2f}")
            print(f"Worst User Rate: {worst_rate:.2f}")
            print("-" * 50)

    # Save metrics first so they survive even if writing the weights fails.
    save_results(episode_rewards, throughput_history, worst_user_rates)

    # Persist the weights that evaluate() restores through
    # agent.model.load_weights(); without this the file is never created.
    # NOTE(review): assumes agent.model exposes a Keras-style save_weights();
    # newer Keras releases restrict the accepted file suffix -- confirm
    # against the installed version and keep evaluate()'s default in sync.
    agent.model.save_weights(model_path)

def save_results(episode_rewards, throughput_history, worst_user_rates):
    """Write the training metrics, one line plot per metric, and the config.

    Everything is written beneath ``results/``: three ``.npy`` arrays, three
    ``.png`` figures plotted against the episode index, and ``config.json``
    holding the module-level ``config`` dict.

    Args:
        episode_rewards: Total reward collected in each episode.
        throughput_history: Throughput value recorded for each episode.
        worst_user_rates: Worst user rate recorded for each episode.
    """
    # Raw metric arrays, written before any figure is drawn.
    array_targets = (
        ("results/episode_rewards.npy", episode_rewards),
        ("results/throughput_history.npy", throughput_history),
        ("results/worst_user_rates.npy", worst_user_rates),
    )
    for array_path, values in array_targets:
        np.save(array_path, values)

    # One figure per metric: (series, title, y-axis label, output image).
    figure_specs = (
        (episode_rewards, 'Episode Rewards over Time',
         'Total Reward', 'results/episode_rewards.png'),
        (throughput_history, 'System Throughput over Episodes',
         'Throughput (kb/s)', 'results/throughput.png'),
        (worst_user_rates, 'Worst User Rate over Episodes',
         'Rate (kb/s)', 'results/worst_user_rates.png'),
    )
    for series, title, y_label, image_path in figure_specs:
        plt.figure(figsize=(10, 6))
        plt.plot(series)
        plt.title(title)
        plt.xlabel('Episode')
        plt.ylabel(y_label)
        plt.grid(True)
        plt.savefig(image_path)
        # Close each figure so pyplot does not keep it open between plots.
        plt.close()

    # Keep the configuration next to the results it produced.
    with open('results/config.json', 'w') as config_file:
        json.dump(config, config_file, indent=4)

def evaluate(model_path="results/final_model.h5"):
    """Run five greedy episodes with stored weights and report the outcome.

    A fresh environment and agent are built from ``config``, the weights at
    ``model_path`` are loaded into ``agent.model``, and every action is
    chosen with ``epsilon=0``. Before the UAVs act in a time step, copies of
    the UAV and user positions are recorded. The sum rate and worst user
    rate seen at the last time step of each episode are collected.

    The collected values are saved as ``.npy`` files under ``results/``,
    their mean and standard deviation are printed, and the first episode's
    trajectory is passed to ``plot_trajectory``.

    Args:
        model_path: Weights file handed to ``agent.model.load_weights``.
    """
    env = SystemModel(config)
    agent = DQN(
        state_dim=env.num_uavs * 3 + env.num_users,
        num_uavs=env.num_uavs,
        num_users=env.num_users,
    )

    # Restore the trained weights before any action is selected.
    agent.model.load_weights(model_path)

    throughputs, worst_rates, trajectories = [], [], []

    for _ in range(5):  # five evaluation episodes
        state = env.reset()
        snapshots = []

        for t in range(config['time_steps']):
            # Re-cluster users on the same schedule as during training.
            if t % config['clustering_interval'] == 0:
                user_association = agent.cluster_users(env.uav_positions, env.user_positions)

            # Positions as they are before this step's actions are applied.
            snapshot = {
                'uav_positions': env.uav_positions.copy(),
                'user_positions': env.user_positions.copy(),
            }
            snapshots.append(snapshot)

            for uav_idx in range(env.num_uavs):
                served = np.where(user_association.iloc[0, :] == uav_idx)[0]
                cluster_size = len(served)

                # epsilon=0 requests the agent's action without exploration.
                action = agent.choose_action(
                    state,
                    epsilon=0,
                    uav_idx=uav_idx,
                    user_association=user_association,
                )
                power_allocation = agent.get_power_allocation(action, cluster_size)
                env.take_action(action, uav_idx, power_allocation)

                sinr = env.calculate_sinr(user_association)
                _, sum_rate, worst_rate = env.calculate_rates(sinr)
                state = env.get_state(uav_idx)

            # Keep only the rates reached at the end of the episode.
            is_last_step = t == config['time_steps'] - 1
            if is_last_step:
                throughputs.append(sum_rate)
                worst_rates.append(worst_rate)

        trajectories.append(snapshots)

    # Save evaluation results
    np.save("results/eval_throughputs.npy", throughputs)
    np.save("results/eval_worst_rates.npy", worst_rates)
    np.save("results/eval_trajectories.npy", trajectories)

    # Summary statistics over the evaluation episodes
    print("\nEvaluation Results:")
    print(f"Average Throughput: {np.mean(throughputs):.2f} ± {np.std(throughputs):.2f}")
    print(f"Average Worst User Rate: {np.mean(worst_rates):.2f} ± {np.std(worst_rates):.2f}")

    # Visualise the first episode only.
    first_episode = trajectories[0]
    plot_trajectory(first_episode)

def plot_trajectory(trajectory):
    """Plot the 2-D paths of UAVs and users and save them as a PNG.

    Args:
        trajectory: Sequence of per-step dicts with the keys
            ``'uav_positions'`` and ``'user_positions'``. Each value is
            indexed as ``.iloc[row, column]``: row 0 is plotted on the x
            axis, row 1 on the y axis, and the column selects the UAV or
            user.

    Returns:
        None. The figure is written to ``results/trajectory.png``. An empty
        ``trajectory`` has no start or end point to mark, so nothing is
        drawn or written in that case.
    """
    # len() rather than truthiness so array-like sequences are accepted too.
    if len(trajectory) == 0:
        return

    plt.figure(figsize=(12, 8))

    # Plot UAV trajectories; legend labels are attached to index 0 only so
    # each kind of marker appears once in the legend.
    for uav_idx in range(config['num_uavs']):
        uav_x = [step['uav_positions'].iloc[0, uav_idx] for step in trajectory]
        uav_y = [step['uav_positions'].iloc[1, uav_idx] for step in trajectory]

        plt.plot(uav_x, uav_y, 'b-', alpha=0.5, label=f'UAV {uav_idx+1}' if uav_idx==0 else "")
        plt.plot(uav_x[0], uav_y[0], 'b^', label='UAV Start' if uav_idx==0 else "")
        plt.plot(uav_x[-1], uav_y[-1], 'bs', label='UAV End' if uav_idx==0 else "")

    # Plot user positions; the user count is taken from the configuration.
    for user_idx in range(config['num_uavs'] * config['users_per_cell']):
        user_x = [step['user_positions'].iloc[0, user_idx] for step in trajectory]
        user_y = [step['user_positions'].iloc[1, user_idx] for step in trajectory]

        plt.plot(user_x, user_y, 'r:', alpha=0.3, label='User Path' if user_idx==0 else "")
        plt.plot(user_x[0], user_y[0], 'r^', label='User Start' if user_idx==0 else "")
        plt.plot(user_x[-1], user_y[-1], 'rs', label='User End' if user_idx==0 else "")

    plt.title('UAV and User Trajectories')
    plt.xlabel('X Position (m)')
    plt.ylabel('Y Position (m)')
    plt.grid(True)
    plt.legend()
    plt.savefig('results/trajectory.png')
    plt.close()

if __name__ == "__main__":
    # Script entry point: run the full training loop first.
    train()

    # Then evaluate with the default model_path ("results/final_model.h5").
    evaluate()