In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from ..config import EPISODES, TIME_STEPS

def create_directory_structure():
    """Create the directories used for saving results.

    Creates ``results``, ``results/plots`` and ``results/data`` relative to
    the current working directory. Directories that already exist are left
    untouched, so calling this repeatedly is harmless.
    """
    directories = ['results', 'results/plots', 'results/data']
    # Iterate the list defined above (the previous code referenced an
    # undefined name and raised NameError on every call).
    for dir_name in directories:
        os.makedirs(dir_name, exist_ok=True)

def save_results(results):
    """Save training results as ``.npy`` files under ``results/data``.

    Args:
        results: Mapping that must contain the keys ``'throughput'``,
            ``'worst_user_throughput'``, ``'datarate'``,
            ``'final_uav_positions'``, ``'final_user_positions'``,
            ``'uav_trajectory'`` and ``'user_trajectory'``. Each value is
            written with ``np.save`` to ``results/data/<key>.npy``.

    Raises:
        KeyError: If one of the expected keys is missing. Files for keys
            earlier in the list have already been written at that point.
    """
    data_dir = "results/data"
    # Create the target directory so saving does not depend on
    # create_directory_structure() having been called beforehand.
    os.makedirs(data_dir, exist_ok=True)
    for key in (
        'throughput',
        'worst_user_throughput',
        'datarate',
        'final_uav_positions',
        'final_user_positions',
        'uav_trajectory',
        'user_trajectory',
    ):
        np.save(f"{data_dir}/{key}.npy", results[key])

def _plot_series(x, y, title, xlabel, ylabel, path):
    """Draw a single gridded line plot of *y* against *x* and save it to *path*."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(x, y)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True)
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)


def plot_metrics(results):
    """Plot training metrics and save them as PNGs under ``results/plots``.

    Args:
        results: Mapping with ``'throughput'`` and ``'worst_user_throughput'``
            (one value per episode) and ``'datarate'`` (one entry per time
            step).

    The x-axis of each plot is derived from the length of the series itself.
    For a full run (EPISODES episodes, TIME_STEPS steps) this gives the same
    axis as before. A series of a different length (e.g. a run stopped early)
    is plotted instead of failing with a length-mismatch error.
    """
    plot_dir = 'results/plots'
    os.makedirs(plot_dir, exist_ok=True)

    # Episodes are numbered from 1.
    throughput = results['throughput']
    _plot_series(range(1, len(throughput) + 1), throughput,
                 'System Throughput over Episodes', 'Episode', 'Throughput',
                 f'{plot_dir}/throughput.png')

    worst = results['worst_user_throughput']
    _plot_series(range(1, len(worst) + 1), worst,
                 'Worst User Throughput over Episodes', 'Episode', 'Throughput',
                 f'{plot_dir}/worst_user_throughput.png')

    # Time steps are numbered from 0.
    datarate = results['datarate']
    _plot_series(range(len(datarate)), datarate,
                 'Data Rate over Time Steps', 'Time Step', 'Data Rate',
                 f'{plot_dir}/datarate.png')

def _plot_agent_paths(ax, positions_per_step, color, line_style, alpha, name):
    """Draw path, start marker and end marker for every agent in one group.

    ``positions_per_step`` holds one array per time step. Each array is
    indexed as ``step[coord, agent]``, with row 0 as x and row 1 as y. Only
    the group's first agent carries legend labels, so each legend entry
    appears once. An empty sequence (no time steps) draws nothing.
    """
    if len(positions_per_step) == 0:
        return
    n_agents = positions_per_step[0].shape[1]
    for idx in range(n_agents):
        xs = [step[0, idx] for step in positions_per_step]
        ys = [step[1, idx] for step in positions_per_step]
        labelled = idx == 0
        ax.plot(xs, ys, f'{color}{line_style}', alpha=alpha,
                label=f'{name} Path' if labelled else "")
        ax.plot(xs[0], ys[0], f'{color}^',
                label=f'{name} Start' if labelled else "")
        ax.plot(xs[-1], ys[-1], f'{color}s',
                label=f'{name} End' if labelled else "")


def plot_trajectory(trajectory, episode):
    """Plot UAV and user trajectories for one episode.

    Args:
        trajectory: Mapping with ``'uav_trajectory'`` and ``'user_trajectory'``.
            Each is a sequence of per-time-step arrays indexed as
            ``step[coord, agent]``, with row 0 as x and row 1 as y.
        episode: Episode identifier, used in the title and the file name
            ``results/plots/trajectory_episode_{episode}.png``.

    UAVs are drawn in blue with solid paths, and users in red with dotted
    paths. Triangles mark start positions and squares mark end positions.
    """
    plot_dir = 'results/plots'
    os.makedirs(plot_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Label the UAV path entry generically ('UAV Path'), like the user
        # entry: it represents the shared style of every UAV path, not UAV 1.
        _plot_agent_paths(ax, trajectory['uav_trajectory'], 'b', '-', 0.5, 'UAV')
        _plot_agent_paths(ax, trajectory['user_trajectory'], 'r', ':', 0.3, 'User')

        ax.set_title(f'UAV and User Trajectories - Episode {episode}')
        ax.set_xlabel('X Position (m)')
        ax.set_ylabel('Y Position (m)')
        ax.grid(True)
        ax.legend()
        fig.savefig(f'{plot_dir}/trajectory_episode_{episode}.png')
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)