In [1]:
# === 儲存格 1：掛載 Google Drive 並設定工作目錄 ===
from google.colab import drive
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)
pd.options.mode.chained_assignment = None

drive.mount('/content/drive')
WORK_DIR = '/content/drive/MyDrive/FRL_Slicing_Sim' # 確保與 Notebook 1 一致
if not os.path.exists(WORK_DIR): os.makedirs(WORK_DIR)
os.chdir(WORK_DIR)
print(f"目前工作目錄: {os.getcwd()}")




Mounted at /content/drive
目前工作目錄: /content/drive/MyDrive/FRL_Slicing_Sim


In [6]:
# === 儲存格 2：定義相關路徑和載入預處理數據 ===
print("\n--- Cell 2: 載入預處理數據 ---")
DATASET_LOCAL_NAME = 'colosseum-oran-coloran-dataset'
DATASET_DIR = os.path.join(WORK_DIR, DATASET_LOCAL_NAME)
PREPARED_DATA_PATH = os.path.join(DATASET_DIR, 'kpi_traces_final_v_robust.parquet')

df_kpi = pd.DataFrame()
if os.path.exists(PREPARED_DATA_PATH):
    print(f"正在從 {PREPARED_DATA_PATH} 載入預處理數據...")
    try:
        df_kpi = pd.read_parquet(PREPARED_DATA_PATH)
        print("預處理數據載入成功！")
        print("數據形狀:", df_kpi.shape)
        if 'Timestamp' in df_kpi.columns and not pd.api.types.is_datetime64_any_dtype(df_kpi):
            df_kpi = pd.to_datetime(df_kpi, errors='coerce')
            print("\nTimestamp 欄位已轉換為 datetime 物件。")
    except Exception as e: print(f"讀取 Parquet 檔案失敗: {e}")
else:
    print(f"錯誤：找不到預處理的數據檔案 {PREPARED_DATA_PATH}。請先執行 Notebook 1。")






--- Cell 2: 載入預處理數據 ---
正在從 /content/drive/MyDrive/FRL_Slicing_Sim/colosseum-oran-coloran-dataset/kpi_traces_final_v_robust.parquet 載入預處理數據...
預處理數據載入成功！
數據形狀: (6183234, 38)


In [None]:
# @title
# === 最終儲存格：模擬環境、聯邦學習框架與主循環 (最終修正版) ===
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import random
from collections import deque
import copy
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import warnings

# 忽略在 groupby.diff() 中可能出現的 FutureWarning
warnings.simplefilter(action='ignore', category=FutureWarning)

# --- 1. SliceSimEnv: 採用預計算與快取優化 ---
print("\n--- [1/4] 定義優化後的 SliceSimEnv ---")
class SliceSimEnv:
    def __init__(self, df_data, gnb_id_to_sim=1, slice_ids_to_sim=None):
        self.ts_col = 'timestamp'
        self.gnb_id = gnb_id_to_sim
        self.sim_slice_ids = slice_ids_to_sim if slice_ids_to_sim is not None else [0, 2]

        self.action_space_map = {0: "Prioritize_eMBB", 1: "Balanced", 2: "Prioritize_URLLC"}
        self.action_size = len(self.action_space_map)
        self.state_size = len(self.sim_slice_ids) * 2
        self.default_state = np.zeros(self.state_size)

        self.precomputed_states = {}
        self.timestamps_array = np.array([])

        if df_data.empty:
            print("錯誤：傳入數據為空，無法初始化環境。")
            return

        print(f"      正在為 BS_ID={self.gnb_id} 準備數據...")
        mask = (df_data['BS_ID'] == self.gnb_id) & (df_data['Slice_ID'].isin(self.sim_slice_ids))
        slice_data = df_data[mask].copy()

        if slice_data.empty:
            print(f"警告：在 BS_ID={self.gnb_id} 和 Slice_IDs={self.sim_slice_ids} 下找不到數據。")
            return

        slice_data.sort_values(by=self.ts_col, inplace=True)
        self.timestamps_array = slice_data[self.ts_col].unique()

        print(f"      正在對 {len(self.timestamps_array)} 個唯一時間戳進行狀態預計算...")
        grouped_by_ts = slice_data.groupby(self.ts_col)
        for timestamp, group in tqdm(grouped_by_ts, desc="預計算進度"):
            state_vector = np.zeros(self.state_size)
            for i, slice_id in enumerate(self.sim_slice_ids):
                slice_specific_data = group[group['Slice_ID'] == slice_id]

                tput = 0.0
                lat_proxy = 0.0

                if not slice_specific_data.empty:
                    # --- 修正處: 先判斷欄位是否存在，再安全取值 ---
                    if 'Throughput_DL_Mbps' in slice_specific_data.columns:
                        tput = slice_specific_data['Throughput_DL_Mbps'].iloc[0]

                    if 'Buffer_Occupancy_DL_bytes' in slice_specific_data.columns:
                        lat_proxy = slice_specific_data['Buffer_Occupancy_DL_bytes'].iloc[0]

                state_vector[i*2] = tput if pd.notna(tput) else 0.0
                state_vector[i*2 + 1] = lat_proxy if pd.notna(lat_proxy) else 0.0

            self.precomputed_states[timestamp] = state_vector

        print("      ✅ 狀態預計算與快取完成！")
        self.reset()

    def _get_state_at_timestamp(self, timestamp):
        return self.precomputed_states.get(timestamp, self.default_state)

    def reset(self):
        if len(self.timestamps_array) == 0:
            return self.default_state
        self.current_timestamp_idx = np.random.randint(0, len(self.timestamps_array))
        self.current_timestamp = self.timestamps_array[self.current_timestamp_idx]
        return self._get_state_at_timestamp(self.current_timestamp)

    def step(self, action_id):
        if len(self.timestamps_array) < 2:
            return self.default_state, 0, True, {}

        self.current_timestamp_idx = (self.current_timestamp_idx + 1) % len(self.timestamps_array)
        done = (self.current_timestamp_idx == 0)

        next_timestamp = self.timestamps_array[self.current_timestamp_idx]
        state = self._get_state_at_timestamp(next_timestamp)

        reward = 0
        if len(state) == self.state_size and len(self.sim_slice_ids) == 2:
            s0_tput, s0_lat_proxy, s2_tput, s2_lat_proxy = state
            w_s0_tput, w_s2_lat = (0.5, 0.5)
            if action_id == 0: w_s0_tput, w_s2_lat = (0.7, 0.3)
            elif action_id == 2: w_s0_tput, w_s2_lat = (0.3, 0.7)

            reward_s0 = np.log1p(s0_tput)
            reward_s2 = - (s2_lat_proxy * 0.00001)
            reward = w_s0_tput * reward_s0 + w_s2_lat * reward_s2

        return state, reward, done, {'timestamp': next_timestamp, 'action_taken': self.action_space_map.get(action_id)}

# --- 2. RLAgent 與 FLServer 框架 ---
print("\n--- [2/4] 定義 RLAgent 與 FLServer 框架 ---")
class RLAgent:
    def __init__(self, state_size, action_size, learning_rate=0.001, device='cpu'):
        self.state_size, self.action_size = state_size, action_size
        self.memory = deque(maxlen=20000)
        self.gamma, self.epsilon, self.epsilon_min, self.epsilon_decay = 0.95, 1.0, 0.01, 0.998
        self.device = device
        self.model = nn.Sequential(
            nn.Linear(self.state_size, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, self.action_size)
        ).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.criterion = nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            act_values = self.model(state_tensor)
        return torch.argmax(act_values).item()

    def replay(self, batch_size):
        if len(self.memory) < batch_size: return 0.0
        minibatch = random.sample(self.memory, batch_size)
        states = torch.from_numpy(np.vstack([e[0] for e in minibatch if e is not None])).float().to(self.device)
        actions = torch.tensor([e[1] for e in minibatch if e is not None], device=self.device).long().view(-1, 1)
        rewards = torch.tensor([e[2] for e in minibatch if e is not None], device=self.device).float().view(-1, 1)
        next_states = torch.from_numpy(np.vstack([e[3] for e in minibatch if e is not None])).float().to(self.device)
        dones = torch.tensor([e[4] for e in minibatch if e is not None], device=self.device).bool().view(-1, 1)

        current_q_values = self.model(states).gather(1, actions)
        with torch.no_grad():
            max_next_q_values = self.model(next_states).max(dim=1)[0].unsqueeze(1)
            target_q_values = rewards + (self.gamma * max_next_q_values * (~dones))

        loss = self.criterion(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if self.epsilon > self.epsilon_min: self.epsilon *= self.epsilon_decay
        return loss.item()

class FLServer:
    def aggregate_models(self, client_model_states):
        if not client_model_states: return None
        processed_states = [{k: v.cpu() for k, v in state_dict.items()} for state_dict in client_model_states]
        global_dict = copy.deepcopy(processed_states[0])
        for k in global_dict.keys():
            global_dict[k] = torch.stack([state[k] for state in processed_states], 0).mean(0)
        return global_dict

    def distribute_model(self, clients_agents, global_model_dict):
        if global_model_dict:
            for agent in clients_agents:
                agent.model.load_state_dict(global_model_dict)

# --- 3. 聯邦學習主循環 ---
print("\n--- [3/4] 準備運行聯邦學習主循環 ---")
if 'df_kpi' in locals() and not df_kpi.empty:
    # --- 參數定義 ---
    NUM_CLIENTS = 3
    COMM_ROUNDS = 10
    LOCAL_EPISODES = 30
    STEPS_PER_EPISODE = 100
    BATCH_SIZE_REPLAY = 64
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"將使用設備: {DEVICE}")

    # --- 客戶端選擇 ---
    server = FLServer()
    available_bs_ids = df_kpi['BS_ID'].unique()
    num_to_sample = min(NUM_CLIENTS, len(available_bs_ids))
    if num_to_sample == 0:
        print("錯誤：數據中沒有可用的 BS_ID。")
    else:
        sim_bs_ids = np.random.choice(available_bs_ids, num_to_sample, replace=False)
        print(f"將模擬的 BS_IDs: {sim_bs_ids}")

        # --- 環境與代理初始化 ---
        print("正在初始化客戶端環境與 RL 代理...")
        client_envs = [SliceSimEnv(df_kpi, gnb_id_to_sim=bs_id, slice_ids_to_sim=[0, 2]) for bs_id in sim_bs_ids]

        # 檢查是否有任何環境成功初始化
        if not any(len(env.timestamps_array) > 0 for env in client_envs):
            print("錯誤：所有被選中的 BS_ID 都沒有有效的數據來創建模擬環境。")
        else:
            temp_env_for_sizes = next(env for env in client_envs if len(env.timestamps_array) > 0)
            STATE_SIZE, ACTION_SIZE = temp_env_for_sizes.state_size, temp_env_for_sizes.action_size
            clients_agents = [RLAgent(STATE_SIZE, ACTION_SIZE, device=DEVICE) for _ in sim_bs_ids]
            all_round_avg_rewards = []

            # --- 聯邦學習主循環 ---
            print("\n--- [4/4] 聯邦學習模擬開始 ---")
            for comm_round in tqdm(range(COMM_ROUNDS), desc="聯邦學習回合"):
                print(f"\n--- 聯邦學習回合 {comm_round + 1}/{COMM_ROUNDS} ---")
                client_model_states, current_round_rewards = [], []

                for i, agent in enumerate(clients_agents):
                    env = client_envs[i]
                    if len(env.timestamps_array) < 2:
                        print(f"  客戶端 {i+1} (BS_ID: {env.gnb_id}) 數據不足，本回合跳過訓練。")
                        if hasattr(agent, 'model'): client_model_states.append(agent.model.state_dict())
                        continue

                    print(f"  訓練客戶端 {i+1}/{len(clients_agents)} (BS_ID: {env.gnb_id})...")
                    for episode in range(LOCAL_EPISODES):
                        state = env.reset()
                        for step in range(STEPS_PER_EPISODE):
                            action = agent.act(state)
                            next_state, reward, done, _ = env.step(action)
                            agent.remember(state, action, reward, next_state, done)
                            state = next_state
                            if len(agent.memory) > BATCH_SIZE_REPLAY:
                                agent.replay(BATCH_SIZE_REPLAY)
                            if done: break
                    client_model_states.append(agent.model.state_dict())

                if client_model_states:
                    global_model_dict = server.aggregate_models(client_model_states)
                    if global_model_dict:
                        server.distribute_model(clients_agents, global_model_dict)
                        print("  模型聚合與分發完成。")

            print("\n=== 模擬結束 ===")

            # --- [視覺化] 繪製結果圖表 ---
            print("\n--- 正在繪製訓練結果圖表 ---")

            # 創建一個 1x2 的子圖佈局
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

            # 圖 1: 平均獎勵變化
            ax1.plot(range(1, COMM_ROUNDS + 1), all_round_avg_rewards, marker='o', linestyle='-', color='b')
            ax1.set_title('Average Reward per Communication Round', fontsize=16)
            ax1.set_xlabel('Communication Round', fontsize=12)
            ax1.set_ylabel('Average Reward', fontsize=12)
            ax1.grid(True)

            # 圖 2: 平均損失變化
            ax2.plot(range(1, COMM_ROUNDS + 1), all_round_avg_loss, marker='o', linestyle='-', color='r')
            ax2.set_title('Average Loss per Communication Round', fontsize=16)
            ax2.set_xlabel('Communication Round', fontsize=12)
            ax2.set_ylabel('Average Loss', fontsize=12)
            ax2.grid(True)

            plt.tight_layout() # 自動調整子圖間距
            plt.show()

else:
    print("錯誤：預處理數據 (df_kpi) 為空，無法開始聯邦學習模擬。")


--- [1/4] 定義優化後的 SliceSimEnv ---

--- [2/4] 定義 RLAgent 與 FLServer 框架 ---

--- [3/4] 準備運行聯邦學習主循環 ---
將使用設備: cuda
將模擬的 BS_IDs: [2 6 3]
正在初始化客戶端環境與 RL 代理...
      正在為 BS_ID=2 準備數據...
      正在對 867738 個唯一時間戳進行狀態預計算...


預計算進度:   0%|          | 0/867738 [00:00<?, ?it/s]

      ✅ 狀態預計算與快取完成！
      正在為 BS_ID=6 準備數據...
      正在對 889288 個唯一時間戳進行狀態預計算...


預計算進度:   0%|          | 0/889288 [00:00<?, ?it/s]

      ✅ 狀態預計算與快取完成！
      正在為 BS_ID=3 準備數據...
      正在對 873440 個唯一時間戳進行狀態預計算...


預計算進度:   0%|          | 0/873440 [00:00<?, ?it/s]