In [None]:
# === 儲存格 1：掛載 Google Drive、設定工作目錄並安裝專案套件 ===
import os
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import warnings
from google.colab import drive

# --- 基本設定 ---
warnings.simplefilter(action='ignore', category=FutureWarning)
pd.options.mode.chained_assignment = None

# --- 掛載與路徑設定 ---
drive.mount('/content/drive')
WORK_DIR = '/content/drive/MyDrive/FRL_Slicing_Sim' # 確保此路徑與您的雲端硬碟一致
if not os.path.exists(WORK_DIR):
    os.makedirs(WORK_DIR)
os.chdir(WORK_DIR)
print(f"目前工作目錄: {os.getcwd()}")

# --- 安裝本地套件 ---
# 這一步確保 Notebook 會使用您在 src/ 中的程式碼
print("\n正在安裝 colosseum_oran_frl_demo 套件...")
!pip install -e . -q
print("✅ 套件安裝完成。")

Mounted at /content/drive
目前工作目錄: /content/drive/MyDrive/FRL_Slicing_Sim


In [None]:
# === 儲存格 2：載入數據與定義超參數 ===
from colosseum_oran_frl_demo.config import Paths

# --- Load data ---
print("--- 正在載入預處理數據 ---")
# Build the path from the shared Paths object so the location is consistent
# across the project.
# NOTE: this assumes the data-preparation script (or Notebook 1) already ran.
DATA_PATH = Paths.PROCESSED / 'kpi_traces_final_v_robust.parquet'

# Start from an empty frame so df_kpi is always defined; the training cell
# checks `df_kpi.empty` before it starts the simulation.
df_kpi = pd.DataFrame()
if DATA_PATH.exists():
    try:
        df_kpi = pd.read_parquet(DATA_PATH)
    except Exception as e:
        # Deliberately broad: any read failure is reported and df_kpi stays
        # empty instead of aborting the notebook cell.
        print(f"❌ 讀取 Parquet 檔案失敗: {e}")
    else:
        print(f"✅ 成功從 {DATA_PATH} 載入數據！")
        print(f"   數據形狀: {df_kpi.shape}")
else:
    print(f"❌ 錯誤：找不到預處理的數據檔案 {DATA_PATH}。請先執行數據準備腳本。")

# --- Training hyperparameters ---
# Adjust the experiment settings here.
SIM_BS_IDS = [1, 2, 6]  # base-station IDs to simulate (one client per ID)
NUM_CLIENTS = len(SIM_BS_IDS)  # derived so it cannot drift from SIM_BS_IDS
COMM_ROUNDS = 10
LOCAL_STEPS = 2000  # kept equal to the value used in the original source code
BATCH_SIZE_REPLAY = 64
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("\n--- 訓練參數設定完成 ---")
print(f"設備: {DEVICE}")
print(f"聯邦學習回合: {COMM_ROUNDS}")
print(f"本地訓練步數: {LOCAL_STEPS}")
print(f"模擬的 BS IDs: {SIM_BS_IDS}")


--- Cell 2: 載入預處理數據 ---
正在從 /content/drive/MyDrive/FRL_Slicing_Sim/colosseum-oran-coloran-dataset/kpi_traces_final_v_robust.parquet 載入預處理數據...
預處理數據載入成功！
數據形狀: (6183234, 38)


In [None]:
# === 儲存格 3：執行聯邦強化學習主循環 ===
from tqdm.notebook import tqdm
import copy

# 從您的套件中匯入核心組件
from colosseum_oran_frl_demo.envs.slice_sim_env import SliceSimEnv
from colosseum_oran_frl_demo.agents.rl_agent import RLAgent
from colosseum_oran_frl_demo.agents.fed_server import fedavg
from colosseum_oran_frl_demo.utils.plots import plot_training_results
from colosseum_oran_frl_demo.config import HP # 匯入超參數

def run_local_training(agent, env, local_steps, batch_size):
    """Run one client's local training for a single (possibly truncated) episode.

    The environment is reset once, then stepped at most ``local_steps`` times.
    Every transition is stored via ``agent.remember`` and followed by one
    ``agent.replay(batch_size)`` call. The loop stops early as soon as the
    environment reports ``terminated`` or ``truncated``; the environment is
    not reset again within the same call.

    Args:
        agent: object exposing ``act(state)``, ``remember(...)`` and
            ``replay(batch_size)`` (the latter may return ``None``).
        env: object whose ``reset()`` returns ``(state, info)`` and whose
            ``step(action)`` returns a 5-tuple
            ``(next_state, reward, terminated, truncated, info)``.
        local_steps: maximum number of environment steps to execute.
        batch_size: batch size forwarded to ``agent.replay``.

    Returns:
        A ``(avg_reward, losses)`` tuple: the mean reward per executed step
        (``0.0`` when no step was executed) and the list of non-``None``
        values returned by ``agent.replay``.
    """
    state, _ = env.reset()
    episode_reward = 0
    losses = []
    # Explicit counter instead of reusing the loop variable after the loop:
    # that variable is undefined (or stale) when local_steps is 0.
    steps_taken = 0

    for _ in range(local_steps):
        action = agent.act(state)
        next_state, reward, terminated, truncated, _info = env.step(action)

        # Either flag ends the episode for the purpose of the replay buffer.
        done = terminated or truncated
        agent.remember(state, action, reward, next_state, done)

        loss = agent.replay(batch_size)
        if loss is not None:
            losses.append(loss)

        episode_reward += reward
        steps_taken += 1
        state = next_state

        if done:
            # Simplification: stop local training instead of resetting and
            # starting another episode.
            break

    avg_reward = episode_reward / steps_taken if steps_taken else 0.0
    return avg_reward, losses


# --- [1/4] Initialise client environments and agents ---
print("--- [1/4] 正在初始化客戶端環境與 RL 代理... ---")

if 'df_kpi' in locals() and not df_kpi.empty:
    client_envs = [SliceSimEnv(df_kpi, gnb_id=bs_id) for bs_id in SIM_BS_IDS]
    # Drop base stations whose environment ended up with no timestamps.
    client_envs = [env for env in client_envs if len(env.ts) > 0]

    if not client_envs:
        print("❌ 錯誤：所選的 BS ID 都沒有有效的數據來創建模擬環境。")
    else:
        print(f"✅ 成功為 {len(client_envs)} 個客戶端創建環境。")

        # One agent per surviving environment; learning rate and gamma come
        # from the project's HP config.
        clients_agents = [
            RLAgent(
                state_size=env.state_size,
                action_size=env.action_size,
                lr=HP.LR,
                gamma=HP.GAMMA,
                device=DEVICE,
            )
            for env in client_envs
        ]

        # --- [2/4] Federated learning main loop ---
        print("\n--- [2/4] 聯邦學習模擬開始 ---")

        all_round_avg_rewards = []
        all_round_avg_losses = []

        for comm_round in tqdm(range(COMM_ROUNDS), desc="聯邦學習回合"):

            # Snapshot every client's weights (deep copies, so aggregation
            # works on copies rather than the live state dicts).
            client_model_states = [
                copy.deepcopy(agent.model.state_dict()) for agent in clients_agents
            ]

            # Aggregate the client models with federated averaging.
            global_model_dict = fedavg(client_model_states)

            # Broadcast the aggregated model back to every client.
            if global_model_dict:
                for agent in clients_agents:
                    agent.model.load_state_dict(global_model_dict)

            # --- Local training ---
            current_round_rewards = []
            current_round_losses = []

            for agent, env in zip(clients_agents, client_envs):
                client_avg_reward, client_losses = run_local_training(
                    agent, env, LOCAL_STEPS, BATCH_SIZE_REPLAY
                )
                current_round_rewards.append(client_avg_reward)
                current_round_losses.extend(client_losses)

            # Record this round's mean reward and mean loss (0 when empty).
            avg_reward_this_round = np.mean(current_round_rewards) if current_round_rewards else 0
            avg_loss_this_round = np.mean(current_round_losses) if current_round_losses else 0
            all_round_avg_rewards.append(avg_reward_this_round)
            all_round_avg_losses.append(avg_loss_this_round)

            print(f"--- 回合 {comm_round + 1}/{COMM_ROUNDS} ---")
            print(f"  平均獎勵: {avg_reward_this_round:.4f}")
            print(f"  平均損失: {avg_loss_this_round:.6f}")

        # --- [3/4] Save results ---
        print("\n--- [3/4] 儲存最終模型與訓練歷史 ---")
        # Aggregate once more so the saved model includes the last round of
        # local training.
        final_global_model = fedavg([agent.model.state_dict() for agent in clients_agents])
        torch.save(final_global_model, 'global_model.pt')
        print("✅ 全域模型已儲存至 global_model.pt")

        # --- [4/4] Visualise results ---
        print("\n--- [4/4] 正在繪製訓練結果圖表 ---")
        plot_training_results(all_round_avg_rewards, all_round_avg_losses)
else:
    print("❌ 錯誤：df_kpi 不存在或為空，無法開始模擬。")


--- [1/4] 定義優化後的 SliceSimEnv ---

--- [2/4] 定義 RLAgent 與 FLServer 框架 ---

--- [3/4] 準備運行聯邦學習主循環 ---
將使用設備: cuda
將模擬的 BS_IDs: [2 6 3]
正在初始化客戶端環境與 RL 代理...
      正在為 BS_ID=2 準備數據...
      正在對 867738 個唯一時間戳進行狀態預計算...


預計算進度:   0%|          | 0/867738 [00:00<?, ?it/s]

      ✅ 狀態預計算與快取完成！
      正在為 BS_ID=6 準備數據...
      正在對 889288 個唯一時間戳進行狀態預計算...


預計算進度:   0%|          | 0/889288 [00:00<?, ?it/s]

      ✅ 狀態預計算與快取完成！
      正在為 BS_ID=3 準備數據...
      正在對 873440 個唯一時間戳進行狀態預計算...


預計算進度:   0%|          | 0/873440 [00:00<?, ?it/s]