# In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import time
from caro_env import CaroEnv
from dqn_agent import DQNAgent
from pg_agent import PGAgent
from a2c_agent import A2CAgent

def calculate_moving_average(rewards, window=100):
    """Return the trailing moving average of ``rewards``.

    Output element ``i`` is the mean of
    ``rewards[max(0, i - window + 1):i + 1]``, so the first ``window - 1``
    entries average over the shorter prefix that is available. When fewer
    than ``window`` values are supplied, the values are returned unsmoothed.

    Args:
        rewards: Sequence of numeric per-episode rewards.
        window: Maximum number of trailing values per average; must be >= 1.

    Returns:
        A new list with the same length as ``rewards``. The input is never
        returned by reference, so callers may mutate the result freely.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        # A non-positive window would divide by zero (or by a negative
        # count) further down; fail loudly with a clear message instead.
        raise ValueError(f"window must be a positive integer, got {window}")
    if len(rewards) < window:
        # Too few samples to smooth: hand back a copy of the raw values so
        # the caller's list is not aliased.
        return list(rewards)
    averages = []
    for end in range(1, len(rewards) + 1):
        start = max(0, end - window)
        averages.append(sum(rewards[start:end]) / (end - start))
    return averages

def plot_training_progress(rewards, algorithm, episodes, save_path):
    """Plot raw and smoothed episode rewards for one algorithm and save the figure.

    Args:
        rewards: Per-episode total rewards.
        algorithm: Name used in the legend labels and the title.
        episodes: Episode count shown in the title.
        save_path: Destination path handed to ``plt.savefig``.

    Any exception raised while plotting or saving is reported with ``print``
    rather than propagated; the current figure is closed in every case.
    """
    try:
        plt.figure(figsize=(10, 6))

        raw_label = f"{algorithm} Raw Rewards"
        plt.plot(rewards, alpha=0.3, label=raw_label, color='gray')

        smoothed = calculate_moving_average(rewards)
        smoothed_label = f"{algorithm} Moving Avg"
        plt.plot(smoothed, label=smoothed_label, color='blue')

        plt.title(f"{algorithm} Training Progress ({episodes} Episodes, 5x5 Board)")
        plt.xlabel("Episode")
        plt.ylabel("Total Reward")
        plt.grid(True)
        plt.legend()

        plt.savefig(save_path)
        absolute_path = os.path.abspath(save_path)
        print(f"Plot saved to: {absolute_path}")

        # Confirm the file really landed on disk before showing the figure.
        status = (
            f"Verified: Plot file '{save_path}' exists."
            if os.path.exists(save_path)
            else f"Error: Plot file '{save_path}' was not created."
        )
        print(status)
        plt.show()
    except Exception as exc:
        print(f"Error while plotting {algorithm}: {exc}")
    finally:
        plt.close()

def evaluate_agent(agent, agent_type, env, episodes=100):
    """Evaluate ``agent`` as player 1 against a uniformly random opponent.

    Args:
        agent: Trained agent exposing ``act``. For ``agent_type == 'DQN'`` it
            is called with ``epsilon=0`` and must return the action; for any
            other type it must return a 2-tuple whose first item is the
            action.
        agent_type: Algorithm tag; only ``'DQN'`` is treated specially.
        env: Environment exposing ``reset``, ``step``, ``get_valid_actions``,
            ``current_player`` and ``max_steps``.
        episodes: Number of evaluation games; must be positive.

    Returns:
        ``(avg_reward, success_rate)``: the mean over episodes of the rewards
        collected on the agent's own moves, and the fraction of episodes that
        ended on an agent move with a positive reward.

    Raises:
        ValueError: If ``episodes`` is not positive (the averages would be
            undefined).
    """
    if episodes <= 0:
        raise ValueError(f"episodes must be positive, got {episodes}")

    total_rewards = []
    success_count = 0
    for _ in range(episodes):
        state = env.reset()
        total_reward = 0
        for _step in range(env.max_steps):
            valid_actions = env.get_valid_actions()
            if len(valid_actions) == 0:
                break
            if env.current_player == 1:
                if agent_type == 'DQN':
                    action = agent.act(state, epsilon=0, valid_actions=valid_actions)
                else:
                    action, _extra = agent.act(state, valid_actions=valid_actions)
                state, reward, done, _info = env.step(action)
                total_reward += reward
                # Only an episode finished by the agent's own move with a
                # positive reward counts as a success.
                if done and reward > 0:
                    success_count += 1
            else:
                # Opponent turn: random legal move; its reward is not
                # credited to the agent.
                action = np.random.choice(valid_actions)
                state, _opponent_reward, done, _info = env.step(action)
            if done:
                break
        total_rewards.append(total_reward)

    avg_reward = np.mean(total_rewards)
    success_rate = success_count / episodes
    return avg_reward, success_rate

def _run_episode(env, choose_action, on_agent_step):
    """Play one episode with the learner as player 1 against a random opponent.

    Args:
        env: Environment exposing ``reset``, ``step``, ``get_valid_actions``,
            ``current_player`` and ``max_steps``.
        choose_action: Callable ``(state, valid_actions) -> action`` used on
            the learner's turns.
        on_agent_step: Callable ``(state, action, reward, next_state, done)``
            invoked after every learner move (e.g. to store the transition).

    Returns:
        ``(total_reward, final_state, done)`` where ``total_reward`` sums only
        the rewards of the learner's own moves and ``done`` is the flag of the
        last step taken (``False`` if no step was taken).
    """
    state = env.reset()
    total_reward = 0
    done = False
    for _ in range(env.max_steps):
        valid_actions = env.get_valid_actions()
        if len(valid_actions) == 0:
            break
        if env.current_player == 1:
            action = choose_action(state, valid_actions)
            next_state, reward, done, _info = env.step(action)
            on_agent_step(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
        else:
            # Opponent turn: uniformly random legal move; its reward is ignored.
            action = np.random.choice(valid_actions)
            state, _opponent_reward, done, _info = env.step(action)
        if done:
            break
    return total_reward, state, done


def _sample_policy_action(agent, state, valid_actions):
    """Ask a policy agent for an action, falling back to a random legal move."""
    action, _prob = agent.act(state, valid_actions=valid_actions)
    if action not in valid_actions:
        print(f"Invalid action {action}, choosing random")
        action = np.random.choice(valid_actions)
    return action


def _train_dqn(env, agent, episodes, print_interval, update_target_freq):
    """Train a DQN agent for ``episodes`` games and return the per-episode rewards."""
    rewards = []
    for e in range(episodes):
        total_reward, _final_state, _done = _run_episode(
            env,
            lambda state, valid_actions: agent.act(state, valid_actions=valid_actions),
            agent.remember,
        )
        rewards.append(total_reward)
        agent.train()
        if e % update_target_freq == 0:
            agent.update_target()
        if agent.epsilon > agent.epsilon_min:
            agent.epsilon *= agent.epsilon_decay
        if (e + 1) % print_interval == 0:
            avg_reward = np.mean(rewards[-print_interval:])
            print(f"DQN Episode: {e + 1}/{episodes}, Avg Reward: {avg_reward:.2f}, Epsilon: {agent.epsilon:.3f}")
    return rewards


def _train_on_policy(env, agent, label, episodes, print_interval, update):
    """Train a trajectory-based agent (PG / A2C) and return the per-episode rewards.

    ``update`` is called as ``update(states, actions, rewards, final_state, done)``
    after every episode in which the learner made at least one move.
    """
    episode_rewards = []
    for e in range(episodes):
        states, actions, rewards_list = [], [], []

        def record(state, action, reward, _next_state, _done):
            states.append(state)
            actions.append(action)
            rewards_list.append(reward)

        total_reward, final_state, done = _run_episode(
            env,
            lambda state, valid_actions: _sample_policy_action(agent, state, valid_actions),
            record,
        )
        episode_rewards.append(total_reward)
        if len(states) > 0:
            update(states, actions, rewards_list, final_state, done)
        if (e + 1) % print_interval == 0:
            avg_reward = np.mean(episode_rewards[-print_interval:])
            print(f"{label} Episode: {e + 1}/{episodes}, Avg Reward: {avg_reward:.2f}")
    return episode_rewards


def _plot_combined_progress(rewards_dqn, rewards_pg, rewards_a2c, episodes):
    """Plot the smoothed reward curves of all three algorithms on one figure."""
    try:
        plt.figure(figsize=(10, 6))
        plt.plot(calculate_moving_average(rewards_dqn), label="DQN", color='blue')
        plt.plot(calculate_moving_average(rewards_pg), label="PG", color='red')
        plt.plot(calculate_moving_average(rewards_a2c), label="A2C", color='green')
        plt.xlabel("Episode")
        plt.ylabel("Moving Average Reward (window=100)")
        plt.title(f"Training Progress ({episodes} Episodes, 5x5 Board)")
        plt.legend()
        plt.grid(True)
        save_path = 'combined_training_progress_5x5.png'
        plt.savefig(save_path)
        print(f"Combined plot saved to: {os.path.abspath(save_path)}")
        if os.path.exists(save_path):
            print(f"Verified: Combined plot file '{save_path}' exists.")
        else:
            print(f"Error: Combined plot file '{save_path}' was not created.")
        plt.show()
    except Exception as exc:
        print(f"Error while plotting combined graph: {exc}")
    finally:
        plt.close()


def main():
    """Train DQN, PG and A2C agents on a 5x5 Caro board and compare them.

    Each algorithm is trained for the same number of episodes against a
    random opponent, its weights are saved with ``torch.save``, and its
    reward curve is plotted. All three agents are then evaluated with
    ``evaluate_agent``.

    Returns:
        ``(best_model, best_agent)``: the name (``'DQN'``, ``'PG'`` or
        ``'A2C'``) and the agent object with the highest average evaluation
        reward.
    """
    # Set up the environment
    env = CaroEnv(verbose=False, board_size=5, win_length=3)
    state_shape = env.observation_space
    action_size = env.action_space.n

    print("=" * 50)
    print("BẮT ĐẦU TRAINING 5000 EPISODES CHO MỖI ALGORITHM")
    print("=" * 50)

    episodes = 5000
    print_interval = 200
    update_target_freq = 10

    # Train DQN
    print("\n🤖 Bắt đầu huấn luyện DQN...")
    start_time = time.time()
    dqn_agent = DQNAgent(state_shape, action_size)
    rewards_dqn = _train_dqn(env, dqn_agent, episodes, print_interval, update_target_freq)
    dqn_time = time.time() - start_time
    torch.save(dqn_agent.model.state_dict(), 'dqn_model_5x5.pth')
    print(f"✅ DQN training completed in {dqn_time:.2f} seconds!")
    plot_training_progress(rewards_dqn, "DQN", episodes, 'dqn_training_progress_5x5.png')

    # Train Policy Gradient
    print("\n🎯 Bắt đầu huấn luyện Policy Gradient...")
    start_time = time.time()
    pg_agent = PGAgent(state_shape, action_size)
    rewards_pg = _train_on_policy(
        env, pg_agent, "PG", episodes, print_interval,
        # PG updates from the trajectory alone; final state and done are unused.
        lambda states, actions, rewards, _final_state, _done: pg_agent.train(states, actions, rewards),
    )
    pg_time = time.time() - start_time
    torch.save(pg_agent.policy.state_dict(), 'pg_model_5x5.pth')
    print(f"✅ PG training completed in {pg_time:.2f} seconds!")
    plot_training_progress(rewards_pg, "PG", episodes, 'pg_training_progress_5x5.png')

    # Train A2C
    print("\n🎭 Bắt đầu huấn luyện A2C...")
    start_time = time.time()
    a2c_agent = A2CAgent(state_shape, action_size)
    rewards_a2c = _train_on_policy(
        env, a2c_agent, "A2C", episodes, print_interval,
        lambda states, actions, rewards, final_state, done: a2c_agent.train(
            states, actions, rewards, final_state, done
        ),
    )
    a2c_time = time.time() - start_time
    torch.save(a2c_agent.model.state_dict(), 'a2c_model_5x5.pth')
    print(f"✅ A2C training completed in {a2c_time:.2f} seconds!")
    plot_training_progress(rewards_a2c, "A2C", episodes, 'a2c_training_progress_5x5.png')

    # Plot a comparison chart of all algorithms
    print("\n📈 Vẽ biểu đồ so sánh tất cả thuật toán...")
    _plot_combined_progress(rewards_dqn, rewards_pg, rewards_a2c, episodes)

    # Evaluate performance
    print("\n📊 Đánh giá hiệu suất các model...")
    dqn_avg, dqn_success = evaluate_agent(dqn_agent, 'DQN', env, 100)
    pg_avg, pg_success = evaluate_agent(pg_agent, 'PG', env, 100)
    a2c_avg, a2c_success = evaluate_agent(a2c_agent, 'A2C', env, 100)

    print("\n" + "=" * 60)
    print("KẾT QUẢ ĐÁNH GIÁ CUỐI CÙNG (100 episodes test)")
    print("=" * 60)
    print(f"DQN  - Avg Reward: {dqn_avg:.2f}, Success Rate: {dqn_success:.1%}")
    print(f"PG   - Avg Reward: {pg_avg:.2f}, Success Rate: {pg_success:.1%}")
    print(f"A2C  - Avg Reward: {a2c_avg:.2f}, Success Rate: {a2c_success:.1%}")

    # Pick the best model by average evaluation reward
    best_scores = {'DQN': dqn_avg, 'PG': pg_avg, 'A2C': a2c_avg}
    best_model = max(best_scores, key=best_scores.get)
    print(f"\n🏆 MODEL TỐT NHẤT: {best_model} với điểm số {best_scores[best_model]:.2f}")

    return best_model, {'DQN': dqn_agent, 'PG': pg_agent, 'A2C': a2c_agent}[best_model]

# Script entry point: run the full train/plot/evaluate pipeline and keep the
# name and agent object of the best-scoring algorithm returned by main().
if __name__ == "__main__":
    best_model, best_agent = main()