In [None]:

import os
import random
import dill as pickle
import numpy as np
import torch
from torch.distributions import Categorical
from tqdm import tqdm

from subrl.utils.environment import GridWorld
from subrl.utils.network import append_state
from subrl.utils.network import policy as agent_net
from subrl.utils.visualization import Visu
from subpo import calculate_submodular_reward, compute_subpo_advantages
# from sub_go_explore import run_go_explore_for_dataset_generation

# Root directory for environment files (used to build env_load_path in the config cell).
workspace = "NM"

# Seed every RNG this notebook relies on (Go-Explore samples with `random`; numpy and
# torch are seeded as well) so exploration runs are reproducible after Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:

# Experiment configuration: environment, algorithm, shared, and visualization settings.
params = {
    "env": {
        "start": 1,
        "step_size": 0.1,
        "shape": {"x": 7, "y": 14},
        "horizon": 40,
        "node_weight": "constant",
        "disc_size": "small",
        "n_players": 3,
        "Cx_lengthscale": 2,
        "Cx_noise": 0.001,
        "Fx_lengthscale": 1,
        "Fx_noise": 0.001,
        "Cx_beta": 1.5,
        "Fx_beta": 1.5,
        "generate": False,
        "env_file_name": "env_data.pkl",
        "cov_module": "Matern",
        "stochasticity": 0.0,
        "domains": "two_room",
        # Environment index; env_load_path below hard-codes "env_1" — keep them in sync.
        "num": 1,
    },
    "alg": {
        "gamma": 1,
        "type": "NM",
        "ent_coef": 0.03,
        "epochs": 500,
        "lr": 0.01
    },
    "common": {
        "a": 1,
        "subgrad": "greedy",
        "grad": "pytorch",
        "algo": "both",
        "init": "deterministic",
        "batch_size": 300
    },
    "visu": {
        "wb": "disabled",
        "a": 1
    }
}
env_load_path = workspace + "/environments/" + params["env"]["node_weight"] + "/env_1"

epochs = params["alg"]["epochs"]

H = params["env"]["horizon"]
# Presumably the maximum achievable return for the chosen discretization — NOTE(review): confirm.
MAX_Ret = 3 * (H + 2) if params["env"]["disc_size"] == "large" else 2 * (H + 1)

env = GridWorld(
    env_params=params["env"], common_params=params["common"], visu_params=params["visu"], env_file_path=env_load_path)
node_size = params["env"]["shape"]['x'] * params["env"]["shape"]['y']

# Some node-weight types need extra data stored in "<env_load_path>.pkl".
# Map each such type to the GridWorld attribute that the loaded data populates.
_NODE_WEIGHT_ATTR = {"entropy": "cov", "steiner_covering": "items_loc", "GP": "weight"}
node_weight_attr = _NODE_WEIGHT_ATTR.get(params["env"]["node_weight"])
if node_weight_attr is not None:
    # `with` guarantees the file is closed even if loading fails.
    # NOTE: pickle/dill can execute arbitrary code — only load trusted environment files.
    with open(env_load_path + ".pkl", "rb") as a_file:
        data = pickle.load(a_file)
    setattr(env, node_weight_attr, data)

visu = Visu(env_params=params["env"])
env.get_horizon_transition_matrix()
# policy = Policy(TransitionMatrix=TransitionMatrix, Hori_TransitionMatrix=Hori_TransitionMatrix, ActionTransitionMatrix=env.Hori_ActionTransitionMatrix[:, :, :, 0],
#                 agent_param=params["agent"], env_param=params["env"])


x_ticks [-0.5001, -0.4999, 0.4999, 0.5001, 1.4999, 1.5001, 2.4999, 2.5001, 3.4999, 3.5001, 4.4999, 4.5001, 5.4999, 5.5001, 6.4999, 6.5001, 7.4999, 7.5001, 8.4999, 8.5001, 9.4999, 9.5001, 10.4999, 10.5001, 11.4999, 11.5001, 12.4999, 12.5001, 13.4999, 13.5001]
y_ticks [-0.5001, -0.4999, 0.4999, 0.5001, 1.4999, 1.5001, 2.4999, 2.5001, 3.4999, 3.5001, 4.4999, 4.5001, 5.4999, 5.5001, 6.4999, 6.5001]


In [5]:
def select_cell_from_archive(archive):
    """
    Choose the archive cell to explore from next.

    Only the cells whose 'times_selected' equals the smallest count in the archive
    are candidates; one of them is drawn uniformly with ``random.choice``.

    Args:
        archive (dict): Maps a cell key to a dict containing at least 'times_selected'.

    Returns:
        tuple: ``(cell_key, cell_data)``, or ``(None, None)`` if the archive is empty.
    """
    if not archive:
        return None, None

    fewest_selections = min(entry['times_selected'] for entry in archive.values())
    candidates = [key for key, entry in archive.items()
                  if entry['times_selected'] == fewest_selections]

    chosen_key = random.choice(candidates)
    return chosen_key, archive[chosen_key]

def sample_excellent_trajectories(filepath="go_explore_archive_file_two_Room_98.pkl", 
                                  method='top_n', 
                                  n=10, 
                                  p=0.1, 
                                  threshold=0):
    """
    Load a Go-Explore archive and return its highest-reward trajectories.

    The archive is a pickled dict whose values are trajectory dicts with at least a
    'reward' key (archives written by the Go-Explore cell also store 'states',
    'actions' and 'times_selected').

    Args:
        filepath (str): Path to the .pkl archive file.
        method (str): Sampling method: 'top_n', 'top_p', or 'threshold'.
        n (int): For 'top_n', number of trajectories to return. Values <= 0 return [].
        p (float): For 'top_p', fraction in (0, 1] of trajectories to return
            (e.g. 0.1 means top 10%). The count is floor(len * p), but at least 1,
            so a valid p never yields an empty result from a non-empty archive.
        threshold (float): For 'threshold', minimum reward (inclusive).

    Returns:
        list: Trajectory dicts sorted by descending reward (ties keep archive order).
              Empty if the file is missing, unreadable, or empty, or if the
              method / p is invalid.

    Warning:
        The file is deserialized with pickle/dill, which can execute arbitrary code.
        Only load archives you produced yourself.
    """
    # 1. Load the archive; every failure is reported and yields an empty result.
    if not os.path.exists(filepath):
        print(f"Error: Archive file not found '{filepath}'")
        return []
    
    try:
        with open(filepath, "rb") as f:
            archive = pickle.load(f)  # trusted, locally produced files only
        if not archive:
            print("警告：存檔庫為空。")
            return []
    except Exception as e:
        # Deliberate best-effort: report the problem and return nothing.
        print(f"讀取文件時出錯: {e}")
        return []

    # 2. Sort all trajectory entries by reward, highest first (sorted() is stable).
    all_trajectories_data = sorted(archive.values(), key=lambda x: x['reward'], reverse=True)
    total = len(all_trajectories_data)

    # 3. Select according to the requested method.
    sampled_trajectories = []
    if method == 'top_n':
        # Clamp to [0, total]: a negative n would otherwise slice from the end.
        num_to_sample = max(0, min(n, total))
        sampled_trajectories = all_trajectories_data[:num_to_sample]
        print(f"方法: Top-N。從 {len(all_trajectories_data)} 條軌跡中篩選出最好的 {len(sampled_trajectories)} 條。")

    elif method == 'top_p':
        if not (0 < p <= 1):
            print("錯誤：百分比 'p' 必須在 (0, 1] 之間。")
            return []
        # Truncation alone can round a small fraction of a small archive down to 0.
        num_to_sample = max(1, int(total * p))
        sampled_trajectories = all_trajectories_data[:num_to_sample]
        print(f"方法: Top-P。從 {len(all_trajectories_data)} 條軌跡中篩選出最好的前 {p*100:.1f}% ({len(sampled_trajectories)} 條)。")

    elif method == 'threshold':
        # Keep every trajectory whose reward reaches the threshold.
        sampled_trajectories = [data for data in all_trajectories_data if data['reward'] >= threshold]
        print(f"方法: Threshold。從 {len(all_trajectories_data)} 條軌跡中篩選出 {len(sampled_trajectories)} 條獎勵不低於 {threshold} 的軌跡。")
        
    else:
        print(f"錯誤：未知的採樣方法 '{method}'。請使用 'top_n', 'top_p', 或 'threshold'。")

    return sampled_trajectories


In [12]:


# --- Core parameter setting ---
# Number of Go-Explore iterations (select a cell -> return to it -> explore from it).
# Each iteration takes several environment steps, so this is not an env-step count.
# Larger values explore more thoroughly at proportionally higher runtime.
EXPLORATION_STEPS = 500000
# Upper bound on the random exploration steps taken from the selected cell in one
# iteration; run_go_explore_phase1_spacetime draws the count uniformly from [5, K_STEPS].
K_STEPS = 15

# --- Go-Explore main  ---
def _best_archive_entry(archive):
    """Return the archive entry with the highest 'reward' (the first one in archive order on ties)."""
    return max(archive.values(), key=lambda entry: entry['reward'])


def run_go_explore_phase1_spacetime(env=GridWorld, params=None,
                                    exploration_steps=None,
                                    k_steps=None,
                                    min_k_steps=5,
                                    log_every=10000,
                                    archive_filename="go_explore_archive_spacetime_.pkl"):
    """
    Run Go-Explore phase 1 (exploration) over spacetime cells.

    A cell is keyed by ``(time_step, state_id)``, where ``time_step`` is the number of
    actions taken so far. Keying by time lets the same grid state reached at different
    times be kept as distinct cells. Each iteration does three things:
    select the least-selected cell, replay its actions from a fresh
    ``env.initialize()``, then take a random number of random actions. Every visited
    cell is stored if it is new or improves on that cell's best submodular reward.

    Args:
        env: A constructed GridWorld *instance*. The class default is kept only for
            signature compatibility and cannot be used as-is. Side effect:
            ``env.common_params["batch_size"]`` is set to 1.
        params: Experiment config dict. Only ``params["env"]["horizon"]`` is read;
            trajectories are capped at ``horizon`` states (``horizon - 1`` actions).
        exploration_steps: Number of iterations. Defaults to ``EXPLORATION_STEPS``.
        k_steps: Max random actions per iteration. Defaults to ``K_STEPS``.
        min_k_steps: Min random actions per iteration (clamped to ``k_steps``).
        log_every: Print an archive summary every this many iterations; <= 0 disables.
        archive_filename: Path the final archive is pickled to.

    Returns:
        dict | None: The best archive entry, with keys 'reward', 'states', 'actions'
        and 'times_selected'. None if the archive is empty.

    Raises:
        TypeError: If ``params`` is None.
    """
    if params is None:
        raise TypeError("run_go_explore_phase1_spacetime() requires params (reads params['env']['horizon'])")
    if exploration_steps is None:
        exploration_steps = EXPLORATION_STEPS
    if k_steps is None:
        k_steps = K_STEPS

    print("--- Go-Explore Phase 1: Spacetime Exploration 開始 ---")

    # 1. Initialize the archive.
    # Structure: { (time, state_id): {'reward': float, 'states': list, 'actions': list, 'times_selected': int} }
    archive = {}

    # States are read back with .item() below, which needs a single environment instance.
    # NOTE: this mutates env.common_params in place (possibly the caller's params["common"]).
    env.common_params["batch_size"] = 1
    env.initialize()
    initial_state_tensor = env.state.clone()
    initial_state_id = initial_state_tensor.item()

    # The initial cell is (time=0, start state).
    initial_cell_key = (0, initial_state_id)
    initial_reward = calculate_submodular_reward([initial_state_tensor], env)
    archive[initial_cell_key] = {
        'reward': initial_reward,
        'states': [initial_state_tensor],
        'actions': [],
        'times_selected': 0
    }
    print(f"初始時空細胞加入存檔庫: Cell {initial_cell_key}, Reward: {initial_reward}")

    # Loop invariants: the action budget per trajectory and the random-walk length range.
    max_actions = params["env"]["horizon"] - 1
    low_k = min(min_k_steps, k_steps)

    # 2. Main loop. The context manager closes the bar even if the loop is interrupted.
    with tqdm(total=exploration_steps, desc="Go-Exploring (Spacetime)") as pbar:
        for step in range(exploration_steps):
            # 2.1 Select a cell (the selector is generic over key types).
            cell_key_to_explore_from, selected_cell_data = select_cell_from_archive(archive)
            if selected_cell_data is None:
                print("錯誤：存檔庫為空，無法繼續探索。")
                break
            archive[cell_key_to_explore_from]['times_selected'] += 1

            # 2.2 "Go": reset and replay the stored actions.
            # NOTE(review): this reproduces the stored cell only if transitions are
            # deterministic (params stochasticity is 0.0 here) — confirm for other settings.
            env.initialize()
            for action in selected_cell_data['actions']:
                env.step(0, torch.tensor([action]))

            # 2.3 "Explore": random actions from the reached state.
            # Cells already at max_actions can still be selected, but they explore zero steps.
            current_states = selected_cell_data['states'][:]
            current_actions = selected_cell_data['actions'][:]
            n_explore_steps = random.randint(low_k, k_steps)

            for _ in range(n_explore_steps):
                if len(current_actions) >= max_actions:
                    break

                random_action = random.randint(0, env.action_dim - 1)
                env.step(0, torch.tensor([random_action]))

                new_state_tensor = env.state.clone()
                current_states.append(new_state_tensor)
                current_actions.append(random_action)

                # Time step = number of actions taken so far.
                new_cell_key = (len(current_actions), new_state_tensor.item())
                new_reward = calculate_submodular_reward(current_states, env)

                if new_cell_key not in archive or new_reward > archive[new_cell_key]['reward']:
                    archive[new_cell_key] = {
                        'reward': new_reward,
                        'states': current_states[:],
                        'actions': current_actions[:],
                        'times_selected': 0
                    }

            if log_every > 0 and step % log_every == 0:
                # tqdm.write keeps the progress bar intact, unlike print().
                best_so_far = _best_archive_entry(archive)
                tqdm.write(f"探索步骤: {step / exploration_steps * 100:.2f}%")
                tqdm.write(f"當前存檔庫大小: {len(archive)}")
                tqdm.write(f"找到的最佳獎勵值: {best_so_far['reward']}")
                tqdm.write(f"對應的軌跡長度: {len(best_so_far['states'])}")

            pbar.update(1)

    # 3. Exploration finished.
    print("\n--- 探索完成 ---")
    if not archive:
        print("錯誤：存檔庫為空！")
        return None

    best_trajectory_data = _best_archive_entry(archive)
    print(f"找到的最佳獎勵值: {best_trajectory_data['reward']}")
    print(f"對應的軌跡長度: {len(best_trajectory_data['states'])}")

    # 4. Persist the full archive.
    with open(archive_filename, "wb") as f:
        pickle.dump(archive, f)
    print(f"完整的時空細胞存檔庫已保存至: {archive_filename}")

    return best_trajectory_data

# --- Run the spacetime Go-Explore exploration ---
# Long-running (EXPLORATION_STEPS iterations). Pickles the resulting archive to a file
# in the working directory and returns the archive's highest-reward entry.
best_found_trajectory = run_go_explore_phase1_spacetime(env, params)



--- Go-Explore Phase 1: Spacetime Exploration 開始 ---
初始時空細胞加入存檔庫: Cell (0, 34), Reward: 0.0


Go-Exploring (Spacetime):   0%|          | 0/500000 [00:00<?, ?it/s]

探索步骤: 0.00%
當前存檔庫大小: 15
找到的最佳獎勵值: 8
對應的軌跡長度: 10
探索步骤: 2.00%
當前存檔庫大小: 2083
找到的最佳獎勵值: 47
對應的軌跡長度: 40
探索步骤: 4.00%
當前存檔庫大小: 2196
找到的最佳獎勵值: 56
對應的軌跡長度: 40
探索步骤: 6.00%
當前存檔庫大小: 2246
找到的最佳獎勵值: 62
對應的軌跡長度: 40
探索步骤: 8.00%
當前存檔庫大小: 2267
找到的最佳獎勵值: 63
對應的軌跡長度: 40
探索步骤: 10.00%
當前存檔庫大小: 2298
找到的最佳獎勵值: 63
對應的軌跡長度: 40
探索步骤: 12.00%
當前存檔庫大小: 2298
找到的最佳獎勵值: 66
對應的軌跡長度: 40
探索步骤: 14.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 16.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 18.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 20.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 22.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 24.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 26.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 28.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 30.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 32.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 34.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 36.00%
當前存檔庫大小: 2312
找到的最佳獎勵值: 68
對應的軌跡長度: 40
探索步骤: 38.00%
當前存檔庫大小

In [40]:
# Inspect the best trajectory found: its reward, visited states, actions, and selection count.
print(best_found_trajectory)

{'reward': 68, 'states': [tensor([34]), tensor([33]), tensor([32]), tensor([32]), tensor([31]), tensor([30]), tensor([44]), tensor([58]), tensor([72]), tensor([71]), tensor([70]), tensor([56]), tensor([42]), tensor([28]), tensor([14]), tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([17]), tensor([31]), tensor([32]), tensor([33]), tensor([34]), tensor([35]), tensor([36]), tensor([37]), tensor([38]), tensor([24]), tensor([10]), tensor([11]), tensor([12]), tensor([26]), tensor([40]), tensor([54]), tensor([68]), tensor([82]), tensor([81]), tensor([80]), tensor([66])], 'actions': [3, 3, 4, 3, 3, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 4, 4, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4], 'times_selected': 4106}


In [6]:

# Load the saved Go-Explore archive and keep its N highest-reward trajectories.
# (The variable name is historical; it holds up to N_TOP_TRAJECTORIES entries.)
N_TOP_TRAJECTORIES = 300
top_20_trajectories = sample_excellent_trajectories(
    filepath="go_explore_archive_file_two_Room_98.pkl", method='top_n', n=N_TOP_TRAJECTORIES)
if top_20_trajectories:
    print(f"其中最好的一條獎勵為: {top_20_trajectories[0]['reward']}")
    # Report the actual count: the archive may hold fewer than N_TOP_TRAJECTORIES entries.
    print(f"最差的一條（在這{len(top_20_trajectories)}條中）獎勵為: {top_20_trajectories[-1]['reward']}\n")
    

方法: Top-N。從 2312 條軌跡中篩選出最好的 300 條。
其中最好的一條獎勵為: 68
最差的一條（在這300條中）獎勵為: 59



In [11]:
# Example 1: the highest-reward trajectories (top-N).
# Headings and counts are derived from the constants so the printed text matches the call.
TOP_N = 100
print(f"--- 範例 1: 採樣 Top {TOP_N} ---")
top_20_trajectories = sample_excellent_trajectories(method='top_n', n=TOP_N)  # name is historical
if top_20_trajectories:
    print(f"其中最好的一條獎勵為: {top_20_trajectories[0]['reward']}")
    print(f"最差的一條（在這{len(top_20_trajectories)}條中）獎勵為: {top_20_trajectories[-1]['reward']}\n")

# Example 2: the top fraction of trajectories by reward.
TOP_P = 0.05
print(f"--- 範例 2: 採樣 Top {TOP_P:.0%} ---")
top_5_percent_trajectories = sample_excellent_trajectories(method='top_p', p=TOP_P)
if top_5_percent_trajectories:
    # Spot-check the best trajectory in the sample.
    sample_traj_data = top_5_percent_trajectories[0]
    print(f"抽樣檢查最好的一條軌跡：獎勵={sample_traj_data['reward']}, 長度={len(sample_traj_data['states'])}\n")

# Example 3: every trajectory whose reward reaches the threshold.
REWARD_THRESHOLD = 68
print(f"--- 範例 3: 採樣獎勵 >= {REWARD_THRESHOLD} 的軌跡 ---")
high_reward_trajectories = sample_excellent_trajectories(method='threshold', threshold=REWARD_THRESHOLD)
if high_reward_trajectories:
    print(f"所有高分軌跡的平均獎勵為: {np.mean([d['reward'] for d in high_reward_trajectories]):.2f}\n")


--- 範例 1: 採樣 Top 20 ---
方法: Top-N。從 2312 條軌跡中篩選出最好的 100 條。
其中最好的一條獎勵為: 68
最差的一條（在這300條中）獎勵為: 64

--- 範例 2: 採樣 Top 5% ---
方法: Top-P。從 2312 條軌跡中篩選出最好的前 5.0% (115 條)。
抽樣檢查最好的一條軌跡：獎勵=68, 長度=40

--- 範例 3: 採樣獎勵 >= 45 的軌跡 ---
方法: Threshold。從 2312 條軌跡中篩選出 11 條獎勵不低於 68 的軌跡。
所有高分軌跡的平均獎勵為: 68.00

