In [7]:
import numpy as np
import argparse
from typing import Dict, Any

def _try_stack(steps):
    """Stack per-timestep entries into a single array.

    Returns:
        ``(stacked, None)`` on success, or ``(None, error)`` when ``np.stack``
        raises ``ValueError`` (e.g. the entries do not share one shape).
    """
    try:
        return np.stack(steps), None
    except ValueError as err:
        return None, err


def _per_dimension_stats(stacked):
    """Return per-feature (mean, std) over axis 0, with features flattened to 1-D.

    Flattening everything after the first axis keeps the per-dimension report
    working when each timestep entry is a scalar or has more than one axis.
    """
    flat = stacked.reshape(stacked.shape[0], -1)
    return np.mean(flat, axis=0), np.std(flat, axis=0)


def analyze_dataset(dataset_directory: str):
    """
    Load a teleoperation dataset archive and print summary statistics.

    The report covers the timestep and episode counts, the stacked
    observation/action shapes and dtypes, episode-length statistics, and the
    per-dimension mean/std of observations and actions. All results are
    printed; nothing is returned. Errors are reported on stdout rather than
    raised.

    Args:
        dataset_directory: Path to the dataset (.npz) file. The archive must
            contain the keys "observations", "actions" and "done".
    """
    print(f"--- 開始分析數據集: {dataset_directory} ---")

    try:
        # allow_pickle=True is needed to read object arrays from the archive.
        # NOTE(security): unpickling can execute arbitrary code, so only load
        # dataset files that come from a trusted source.
        # The context manager closes the underlying file handle on every exit
        # path, including the early returns below.
        with np.load(dataset_directory, allow_pickle=True) as loaded_data:
            # Verify the required keys before reading anything.
            for key in ("observations", "actions", "done"):
                if key not in loaded_data:
                    print(f"錯誤: 數據集中缺少必要的鍵 '{key}'。")
                    return

            observations = loaded_data["observations"]
            actions = loaded_data["actions"]
            done = loaded_data["done"]

        # --- Basic statistics ---

        total_timesteps = len(observations)

        if total_timesteps == 0:
            print("數據集為空，沒有時間步可供分析。")
            return

        # An episode ends wherever the done flag is truthy (True or 1.0).
        # Counting non-zero entries yields an integer even for float flags.
        total_episodes = int(np.count_nonzero(done))

        print("**總結**")
        print(f"  總時間步數 (Total Timesteps): {total_timesteps}")
        print(f"  總 Episode 數 (Total Episodes): {total_episodes}")

        # --- Conversion and shape analysis ---

        # Stack observations and actions independently so that a failure in
        # one of them does not discard (or half-initialise) the other.
        observations_array, obs_error = _try_stack(observations)
        actions_array, action_error = _try_stack(actions)

        if observations_array is not None or actions_array is not None:
            print("\n**數據形狀與類型**")
        if observations_array is not None:
            print(f"  所有觀察值陣列形狀: {observations_array.shape}")
            print(f"  單個觀察值形狀 (Observation Shape): {observations_array.shape[1:]}")
            print(f"  觀察值數據類型 (Observation Dtype): {observations_array.dtype}")
        if actions_array is not None:
            print(f"  所有行動值陣列形狀: {actions_array.shape}")
            print(f"  單個行動值形狀 (Action Shape): {actions_array.shape[1:]}")
            print(f"  行動值數據類型 (Action Dtype): {actions_array.dtype}")

        for e in (obs_error, action_error):
            if e is not None:
                print(f"警告: 數據元素無法堆疊成單一陣列。這可能表示觀察值或行動值的形狀不一致。原始錯誤: {e}")
        # Fall back to the first entry's shape for whatever could not be
        # stacked; np.shape also accepts plain lists, unlike the .shape attribute.
        if observations_array is None:
            print(f"  單個觀察值形狀 (Observation Shape): {np.shape(observations[0])}")
        if actions_array is None and len(actions) > 0:
            print(f"  單個行動值形狀 (Action Shape): {np.shape(actions[0])}")

        # --- Episode length analysis ---

        if total_episodes > 0:
            # Each episode spans from the step after the previous done flag
            # (or step 0) up to and including its own done flag. Timesteps
            # after the final done flag are not counted as an episode.
            done_indices = np.where(done)[0]
            start_indices = np.concatenate(([0], done_indices[:-1] + 1))
            episode_lengths = done_indices - start_indices + 1

            avg_length = np.mean(episode_lengths)
            min_length = np.min(episode_lengths)
            max_length = np.max(episode_lengths)

            print("\n**Episode 長度統計 (Timesteps)**")
            print(f"  平均長度 (Mean Length): {avg_length:.2f}")
            print(f"  最小長度 (Min Length): {min_length}")
            print(f"  最大長度 (Max Length): {max_length}")

        # --- Observation / action distribution analysis ---

        if observations_array is not None:
            print("\n**觀察值 (Observations) 分佈統計**")
            # Mean and standard deviation of every dimension along the time axis.
            obs_mean, obs_std = _per_dimension_stats(observations_array)

            # The averages of the per-dimension values give a one-line overview.
            print(f"  觀察值總平均 (Overall Mean): {np.mean(obs_mean):.4f}")
            print(f"  觀察值總標準差 (Overall Std): {np.mean(obs_std):.4f}")
            # Detailed statistics for at most the first 10 dimensions.
            for i in range(min(obs_mean.size, 10)):
                print(f"    維度 {i:02d}: Mean={obs_mean[i]:.4f}, Std={obs_std[i]:.4f}")
            if obs_mean.size > 10:
                print("    ... (僅顯示前 10 個維度的統計)")

        if actions_array is not None:
            print("\n**行動值 (Actions) 分佈統計**")
            action_mean, action_std = _per_dimension_stats(actions_array)

            print(f"  行動值總平均 (Overall Mean): {np.mean(action_mean):.4f}")
            print(f"  行動值總標準差 (Overall Std): {np.mean(action_std):.4f}")
            # Detailed statistics for at most the first 6 dimensions.
            for i in range(min(action_mean.size, 6)):
                print(f"    維度 {i}: Mean={action_mean[i]:.4f}, Std={action_std[i]:.4f}")
            if action_mean.size > 6:
                print("    ... (僅顯示前 6 個維度的統計)")

        print("\n--- 分析結束 ---")

    except FileNotFoundError:
        print(f"錯誤: 未找到文件 {dataset_directory}。請確認路徑是否正確。")
    except Exception as e:
        # Top-level boundary of the script: report any other failure instead
        # of surfacing a traceback.
        print(f"分析時發生意外錯誤: {e}")

if __name__ == "__main__":
    # The text is passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, so passing it positionally would show the
    # whole sentence as the program name in the usage/help output.
    parser = argparse.ArgumentParser(
        description="Analysis script for teleoperation dataset."
    )
    parser.add_argument(
        "--data_set_directory",
        type=str,
        default="./data/dataset_kick.npz",
        help="Dataset file path to analyze."
    )

    # parse_known_args() tolerates the extra arguments that Jupyter/IPython
    # passes to the kernel, which parse_args() would reject with an error.
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"警告: 忽略未識別的參數 (如 Jupyter kernel 參數): {unknown}")

    analyze_dataset(args.data_set_directory)

警告: 忽略未識別的參數 (如 Jupyter kernel 參數): ['--f=c:\\Users\\ricky\\AppData\\Roaming\\jupyter\\runtime\\kernel-v3b5bddf7451c4bae7431db2d9a837a0186abfec93.json']
--- 開始分析數據集: ./data/dataset_kick.npz ---
**總結**
  總時間步數 (Total Timesteps): 15511
  總 Episode 數 (Total Episodes): 7

**數據形狀與類型**
  所有觀察值陣列形狀: (15511, 89)
  單個觀察值形狀 (Observation Shape): (89,)
  觀察值數據類型 (Observation Dtype): float64
  所有行動值陣列形狀: (15511, 12)
  單個行動值形狀 (Action Shape): (12,)
  行動值數據類型 (Action Dtype): float32

**Episode 長度統計 (Timesteps)**
  平均長度 (Mean Length): 2215.86
  最小長度 (Min Length): 590
  最大長度 (Max Length): 2797

**觀察值 (Observations) 分佈統計**
  觀察值總平均 (Overall Mean): 0.4015
  觀察值總標準差 (Overall Std): 0.8510
    維度 00: Mean=-0.4143, Std=0.2370
    維度 01: Mean=0.0684, Std=0.1063
    維度 02: Mean=-0.0268, Std=0.1444
    維度 03: Mean=0.8665, Std=0.3877
    維度 04: Mean=-0.3890, Std=0.2338
    維度 05: Mean=-0.0980, Std=0.1212
    維度 06: Mean=-0.3845, Std=0.2641
    維度 07: Mean=-0.0748, Std=0.1328
    維度 08: Mean=0.0444, Std=0.0894
  