# In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

from models.system_model import SystemModel
from models.dqn import DQN

# System parameters: module-level constants read by main() and forwarded to
# the environment (SystemModel) and the agent (DQN).
NumberOfUAVs = 3  # number of UAVs; main() iterates over range(NumberOfUAVs) every time step
NumberOfUsers = 6  # number of users
Bandwidth = 30  # kHz; passed to env.Calcullate_Datarate
F_c = 2  # GHz; passed to env.Get_Propergation_Loss (presumably the carrier frequency -- confirm)
R_require = 0.15  # kb; main() halves the reward when the worst user's rate is below this value
NoisePower = 10**(-9) * 10000  # passed to the channel-gain and SINR helpers; unit not stated here
# NOTE(review): both speeds are divided by 3, which matches NumberOfUAVs; main()
# moves the users once per UAV inside a time step, so this looks like a
# per-sub-step share of the per-time-step speed -- confirm before changing
# NumberOfUAVs.
MAXUserspeed = 0.1/3  # m/s; passed to env.User_randomMove
PMAXUserspeed = 0.4/3  # m/s; passed to env.User_Purposive_Move_6

def main():
    """Train the DQN agent in the NOMA UAV environment and save the results.

    Runs ``Episodes_number`` episodes of ``T`` time steps. Within each time
    step every UAV in turn observes a state, picks an epsilon-greedy action,
    applies it to the environment, receives a reward based on the sum rate
    (halved when the worst user's rate is below ``R_require``), stores the
    transition and triggers one training call on the agent. Users are moved
    after every UAV sub-step.

    Per-episode throughput is printed, and once all episodes are done the
    collected sequences, the final positions and the trajectories recorded
    during the last episode are handed to ``save_and_plot_results``.
    """
    Episodes_number = 300
    T = 120  # Service duration and time steps
    T_AS = np.arange(0, T, 40)  # time steps at which user association is recomputed
    env = SystemModel()
    agent = DQN(NumberOfUAVs, NumberOfUsers)
    Epsilon = 0.9
    # Linear decay sized so that Epsilon reaches zero 50 episodes before the
    # end of training (accelerates convergence).
    epsilon_step = 0.9 / (Episodes_number - 50)
    datarate_seq = np.zeros(T)  # Sum rate per time step (overwritten every episode)
    WorstuserRate_seq = np.zeros(T)  # Worst-user rate per time step
    Through_put_seq = np.zeros(Episodes_number)
    Worstuser_TP_seq = np.zeros(Episodes_number)
    UAV_trajectory = []
    User_trajectory = []

    for episode in range(Episodes_number):
        env.Reset_position()
        # Clamp at zero: without the clamp the decay keeps running for the
        # last 50 episodes and Epsilon becomes a negative "probability".
        Epsilon = max(Epsilon - epsilon_step, 0.0)
        p = 0  # number of sub-steps in this episode where the rate requirement was violated

        for t in range(T):
            # t == 0 is in T_AS, so User_AS_List is always bound before use.
            if t in T_AS:
                User_AS_List = agent.User_association(env.PositionOfUAVs, env.PositionOfUsers,
                                                      NumberOfUAVs, NumberOfUsers)

            # Record trajectories only for the final episode.
            if episode == Episodes_number - 1:
                UAV_trajectory.append(env.PositionOfUAVs.values.copy())
                User_trajectory.append(env.PositionOfUsers.values.copy())

            for UAV in range(NumberOfUAVs):
                # Calculate channel gains from the current positions
                Distence_CG = env.Get_Distance_U2K(env.PositionOfUAVs, env.PositionOfUsers,
                                                   NumberOfUAVs, NumberOfUsers)
                PL_for_CG = env.Get_Propergation_Loss(Distence_CG, env.PositionOfUAVs,
                                                      NumberOfUAVs, NumberOfUsers, F_c)
                CG = env.Get_Channel_Gain_NOMA(NumberOfUAVs, NumberOfUsers, PL_for_CG,
                                               User_AS_List, NoisePower)
                # NOTE(review): Eq_CG comes from a call identical to the one
                # that produced CG; both calls are kept because it is not
                # visible here whether the helper is deterministic -- confirm
                # whether a single call could be reused.
                Eq_CG = env.Get_Channel_Gain_NOMA(NumberOfUAVs, NumberOfUsers, PL_for_CG,
                                                  User_AS_List, NoisePower)

                # Generate current state and choose action
                State = env.Create_state_Noposition(UAV, User_AS_List, CG)
                action_name = agent.Choose_action(State, Epsilon, UAV, User_AS_List)
                env.take_action_NOMA(action_name, UAV, User_AS_List, Eq_CG)

                # Calculate reward; distances and losses are recomputed
                # because they are read again after the action was applied.
                Distence = env.Get_Distance_U2K(env.PositionOfUAVs, env.PositionOfUsers,
                                                NumberOfUAVs, NumberOfUsers)
                P_L = env.Get_Propergation_Loss(Distence, env.PositionOfUAVs,
                                                NumberOfUAVs, NumberOfUsers, F_c)
                SINR = env.Get_SINR_NNOMA(NumberOfUAVs, NumberOfUsers, P_L,
                                          User_AS_List, Eq_CG, NoisePower)
                DataRate, SumRate, WorstuserRate = env.Calcullate_Datarate(SINR,
                                                                           NumberOfUsers, Bandwidth)

                Reward = SumRate
                if WorstuserRate < R_require:
                    # Penalise violating the minimum-rate requirement.
                    Reward = Reward / 2
                    p += 1

                # Get next state
                CG_next = env.Get_Channel_Gain_NOMA(NumberOfUAVs, NumberOfUsers, P_L,
                                                    User_AS_List, NoisePower)
                Next_state = env.Create_state_Noposition(UAV, User_AS_List, CG_next)

                # Store experience
                agent.remember(State[0], action_name, Next_state[0], Reward)
                agent.train()

                # Update user positions (done after every UAV sub-step)
                env.User_randomMove(MAXUserspeed, NumberOfUsers)
                env.User_Purposive_Move_6(PMAXUserspeed)

                # The rates seen after the last UAV has acted are the ones
                # logged for this time step.
                if UAV == NumberOfUAVs - 1:
                    datarate_seq[t] = SumRate
                    WorstuserRate_seq[t] = WorstuserRate

        # Calculate throughput
        Through_put = np.sum(datarate_seq)
        Worstuser_TP = np.sum(WorstuserRate_seq)
        Through_put_seq[episode] = Through_put
        Worstuser_TP_seq[episode] = Worstuser_TP

        print(f'Episode={episode}, Epsilon={Epsilon:.4f}, Punishment={p}, Through_put={Through_put:.2f}')

    # Save results
    save_and_plot_results(Through_put_seq, Worstuser_TP_seq, datarate_seq,
                          env.PositionOfUsers, env.PositionOfUAVs,
                          UAV_trajectory, User_trajectory)

def save_and_plot_results(Through_put_seq, Worstuser_TP_seq, datarate_seq, 
                         final_user_positions, final_uav_positions,
                         uav_trajectory, user_trajectory):
    """Save the training results as ``.npy`` files and plot the metrics.

    All files are written into the ``results/`` directory (relative to the
    current working directory), which is created if it does not exist yet.

    Args:
        Through_put_seq: Throughput per episode.
        Worstuser_TP_seq: Worst-user throughput per episode.
        datarate_seq: Data rate per time step.
        final_user_positions: User positions at the end of training.
        final_uav_positions: UAV positions at the end of training.
        uav_trajectory: Sequence of UAV position snapshots.
        user_trajectory: Sequence of user position snapshots.
    """
    # np.save does not create missing directories; without this a missing
    # results/ folder would discard a finished training run.
    os.makedirs("results", exist_ok=True)

    # Save numpy arrays
    np.save("results/Through_put_NOMA.npy", Through_put_seq)
    np.save("results/WorstUser_Through_put_NOMA.npy", Worstuser_TP_seq)
    np.save("results/Total_Data_Rate_NOMA.npy", datarate_seq)
    np.save("results/PositionOfUsers_end_NOMA.npy", final_user_positions)
    np.save("results/PositionOfUAVs_end_NOMA.npy", final_uav_positions)
    np.save('results/UAV_trajectory.npy', uav_trajectory)
    np.save('results/User_trajectory.npy', user_trajectory)

    # Plot results
    plot_training_results(Through_put_seq, Worstuser_TP_seq, datarate_seq)

def plot_training_results(Through_put_seq, Worstuser_TP_seq, datarate_seq):
    """Render the three training curves and write them as PNG files.

    Produces, in order: throughput per episode, worst-user throughput per
    episode (both with episodes numbered from 1), and data rate per time
    slot (slots numbered from 0). Each curve is drawn in its own figure,
    saved under ``results/`` and closed again.
    """
    episode_axis = range(1, len(Through_put_seq) + 1)
    slot_axis = range(len(datarate_seq))

    # One row per figure:
    # (figure number, x values, y values, x label, y label, output file)
    figure_specs = (
        (1, episode_axis, Through_put_seq,
         'Episodes', 'Throughput',
         'results/Through_put_NOMA.png'),
        (2, episode_axis, Worstuser_TP_seq,
         'Episodes', 'Throughput of Worst User',
         'results/WorstUser_Through_put_NOMA.png'),
        (3, slot_axis, datarate_seq,
         'Time slots', 'Data Rate',
         'results/Total_Data_Rate_NOMA.png'),
    )

    for number, xs, ys, x_label, y_label, out_file in figure_specs:
        plt.figure(number)
        plt.plot(xs, ys)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.grid(True)
        plt.savefig(out_file)
        plt.close()

# Run the training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()