In [None]:

import os
import random
import dill as pickle
import numpy as np
import torch
from torch.distributions import Categorical
from tqdm import tqdm

from subrl.utils.environment import GridWorld
from subrl.utils.network import append_state
from subrl.utils.network import policy as agent_net
from subrl.utils.visualization import Visu
from subpo import calculate_submodular_reward, compute_subpo_advantages
workspace = "NM"

In [2]:

params = {
    "env": {
        "start": 1,
        "step_size": 0.1,
        "shape": {"x": 11, "y": 18},
        "horizon": 80,
        "node_weight": "constant",
        "disc_size": "small",
        "n_players": 3,
        "Cx_lengthscale": 2,
        "Cx_noise": 0.001,
        "Fx_lengthscale": 1,
        "Fx_noise": 0.001,
        "Cx_beta": 1.5,
        "Fx_beta": 1.5,
        "generate": False,
        "env_file_name": 'env_data.pkl',
        "cov_module": 'Matern',
        "stochasticity": 0.0,
        "domains": "two_room_2",
        "num": 1,  # 替代原来的args.env
        "initial": 80
    },
    "alg": {
        "gamma": 1,
        "type": "NM",
        "ent_coef": 0.0,
        "epochs": 140,
        "lr": 0.02
    },
    "common": {
        "a": 1,
        "subgrad": "greedy",
        "grad": "pytorch",
        "algo": "both",
        "init": "deterministic",
        "batch_size": 3000
    },
    "visu": {
        "wb": "disabled",
        "a": 1
    }
}
env_load_path = workspace + \
    "/environments/" + params["env"]["node_weight"]+ "/env_1" 

params['env']['num'] = 1
# start a new wandb run to track this script
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="code-" + params["env"]["node_weight"],
#     mode=params["visu"]["wb"],
#     config=params
# )

epochs = params["alg"]["epochs"]

H = params["env"]["horizon"]
MAX_Ret = 2*(H+1)
if params["env"]["disc_size"] == "large":
    MAX_Ret = 3*(H+2)
    
env = GridWorld(
    env_params=params["env"], common_params=params["common"], visu_params=params["visu"], env_file_path=env_load_path)
node_size = params["env"]["shape"]['x']*params["env"]["shape"]['y']
# TransitionMatrix = torch.zeros(node_size, node_size)

if params["env"]["node_weight"] == "entropy" or params["env"]["node_weight"] == "steiner_covering" or params["env"]["node_weight"] == "GP": 
    a_file = open(env_load_path +".pkl", "rb")
    data = pickle.load(a_file)
    a_file.close()

if params["env"]["node_weight"] == "entropy":
    env.cov = data
if params["env"]["node_weight"] == "steiner_covering":
    env.items_loc = data
if params["env"]["node_weight"] == "GP":
    env.weight = data

visu = Visu(env_params=params["env"])
# plt, fig = visu.stiener_grid( items_loc=env.items_loc, init=34)
# wandb.log({"chart": wandb.Image(fig)})
# plt.close()
# Hori_TransitionMatrix = torch.zeros(node_size*H, node_size*H)
# for node in env.horizon_transition_graph.nodes:
#     connected_edges = env.horizon_transition_graph.edges(node)
#     for u, v in connected_edges:
#         Hori_TransitionMatrix[u[0]*node_size+u[1], v[0]*node_size + v[1]] = 1.0
env.get_horizon_transition_matrix()
# policy = Policy(TransitionMatrix=TransitionMatrix, Hori_TransitionMatrix=Hori_TransitionMatrix, ActionTransitionMatrix=env.Hori_ActionTransitionMatrix[:, :, :, 0],
#                 agent_param=params["agent"], env_param=params["env"])


x_ticks [-0.5001, -0.4999, 0.4999, 0.5001, 1.4999, 1.5001, 2.4999, 2.5001, 3.4999, 3.5001, 4.4999, 4.5001, 5.4999, 5.5001, 6.4999, 6.5001, 7.4999, 7.5001, 8.4999, 8.5001, 9.4999, 9.5001, 10.4999, 10.5001, 11.4999, 11.5001, 12.4999, 12.5001, 13.4999, 13.5001, 14.4999, 14.5001, 15.4999, 15.5001, 16.4999, 16.5001, 17.4999, 17.5001]
y_ticks [-0.5001, -0.4999, 0.4999, 0.5001, 1.4999, 1.5001, 2.4999, 2.5001, 3.4999, 3.5001, 4.4999, 4.5001, 5.4999, 5.5001, 6.4999, 6.5001, 7.4999, 7.5001, 8.4999, 8.5001, 9.4999, 9.5001, 10.4999, 10.5001]


In [3]:
def select_cell_from_archive(archive):
    """
    Select a cell from the archive for exploration.
    Cells with the fewest selection counts are prioritized.
    """
    if not archive:
        return None, None

    # Find the minimum selection count
    min_times_selected = float('inf')
    for cell_id in archive:
        if archive[cell_id]['times_selected'] < min_times_selected:
            min_times_selected = archive[cell_id]['times_selected']
    
    # Find all cells with the minimum selection count
    least_visited_cells = []
    for cell_id in archive:
        if archive[cell_id]['times_selected'] == min_times_selected:
            least_visited_cells.append(cell_id)
            
    #  Randomly select one of these cells
    selected_cell_id = random.choice(least_visited_cells)
    
    return selected_cell_id, archive[selected_cell_id]

def sample_excellent_trajectories(filepath="two_Room_80_go_explore_archive_spacetime.pkl", 
                                  method='top_n', 
                                  n=10, 
                                  p=0.1, 
                                  threshold=0):
    """
        Load data from the Go-Explore archive and sample high-quality trajectories based on the specified method.

        Args:
            filepath (str): Path to the .pkl archive file.
            method (str): Sampling method. Options are 'top_n', 'top_p', or 'threshold'.
            n (int): Number of trajectories to sample for the 'top_n' method.
            p (float): Percentage of top trajectories to sample for the 'top_p' method (e.g., 0.1 means top 10%).
            threshold (float): Minimum reward threshold for the 'threshold' method.
        
        Returns:
            list: A list of trajectory dictionaries with high rewards, sorted in descending order of reward.
                  Returns an empty list if the file does not exist or the archive is empty.
    """
    # 1. Check if the file exists and load the data
    if not os.path.exists(filepath):
        print(f"Error: Archive file not found '{filepath}'")
        return []
    
    try:
        with open(filepath, "rb") as f:
            archive = pickle.load(f)
        if not archive:
            print("警告：存檔庫為空。")
            return []
    except Exception as e:
        print(f"讀取文件時出錯: {e}")
        return []

    # 2. 提取所有軌跡數據並按獎勵排序
    # archive.values() 返回的是包含 reward, states, actions 等信息的字典
    all_trajectories_data = list(archive.values())
    
    # 按 'reward' 鍵從高到低排序
    all_trajectories_data.sort(key=lambda x: x['reward'], reverse=True)

    # 3. 根據指定方法進行採樣
    sampled_trajectories = []
    if method == 'top_n':
        # 取獎勵最高的前 N 條
        num_to_sample = min(n, len(all_trajectories_data))
        sampled_trajectories = all_trajectories_data[:num_to_sample]
        print(f"方法: Top-N。從 {len(all_trajectories_data)} 條軌跡中篩選出最好的 {len(sampled_trajectories)} 條。")

    elif method == 'top_p':
        # 取獎勵最高的前 P%
        if not (0 < p <= 1):
            print("錯誤：百分比 'p' 必須在 (0, 1] 之間。")
            return []
        num_to_sample = int(len(all_trajectories_data) * p)
        sampled_trajectories = all_trajectories_data[:num_to_sample]
        print(f"方法: Top-P。從 {len(all_trajectories_data)} 條軌跡中篩選出最好的前 {p*100:.1f}% ({len(sampled_trajectories)} 條)。")

    elif method == 'threshold':
        # 取獎勵高於指定門檻的所有軌跡
        sampled_trajectories = [data for data in all_trajectories_data if data['reward'] >= threshold]
        print(f"方法: Threshold。從 {len(all_trajectories_data)} 條軌跡中篩選出 {len(sampled_trajectories)} 條獎勵不低於 {threshold} 的軌跡。")
        
    else:
        print(f"錯誤：未知的採樣方法 '{method}'。請使用 'top_n', 'top_p', 或 'threshold'。")

    return sampled_trajectories


In [4]:


# --- Core parameter setting ---
EXPLORATION_STEPS = 500000  # The total number of exploration steps can be adjusted as needed, and the larger the value, the more thorough the exploration
K_STEPS = 15            # Random exploration steps starting from selected cells 

# --- Go-Explore main  ---
def run_go_explore_phase1_spacetime(env=GridWorld, params=None,exploration_step=EXPLORATION_STEPS,isLoad=False, archive_path="go_explore_triplet_archive2.pkl",k=K_STEPS):
    """
    執行 Go-Explore 的第一階段探索（使用時空細胞）
    """
    print("--- Go-Explore Phase 1: Spacetime Exploration 開始 ---")
    
    # 1. 初始化存檔庫 (Archive)
    # ---【核心修改】：將細胞 Key 定義為 (時間, 狀態ID) ---
    # 結構: { (time, state_id): {'reward': float, 'states': list, 'actions': list, 'times_selected': int} }
    archive = {}
    if isLoad:
        print("正在從存檔庫載入已保存的時空細胞...")
        with open(archive_path, "rb") as f:
            archive = pickle.load(f)
        print(f"成功載入存檔庫，包含 {len(archive)} 個細胞。")


    
    env.common_params["batch_size"]=1
    env.initialize(params["env"]["initial"])
    initial_state_tensor = env.state.clone()
    initial_state_id = initial_state_tensor.item()
    
    # ---【核心修改】：初始細胞現在是 (時間=0, 狀態) ---
    initial_cell_key = (0, initial_state_id)
    
    initial_reward = calculate_submodular_reward([initial_state_tensor], env)
    
    archive[initial_cell_key] = {
        'reward': initial_reward,
        'states': [initial_state_tensor],
        'actions': [],
        'times_selected': 0
    }
    print(f"初始時空細胞加入存檔庫: Cell {initial_cell_key}, Reward: {initial_reward}")
    
    # 2. 執行 N 次探索循環
    pbar = tqdm(total=exploration_step, desc="Go-Exploring (Spacetime)")
    for step in range(exploration_step):
        # 2.1 選擇細胞 (選擇函數無需修改，它通用於任何 key)
        cell_key_to_explore_from, selected_cell_data = select_cell_from_archive(archive)
        
        if selected_cell_data is None:
            print("錯誤：存檔庫為空，無法繼續探索。")
            break
            
        archive[cell_key_to_explore_from]['times_selected'] += 1

        # 2.2 前往 (Go To) 該細胞狀態
        env.initialize(params["env"]["initial"])
        for action in selected_cell_data['actions']:
            env.step(0, torch.tensor([action]))

        # 2.3 從該狀態開始，隨機探索 (Explore) k 步
        current_states = selected_cell_data['states'][:]
        current_actions = selected_cell_data['actions'][:]
        
        k=random.randint(k,k)
        
        for _ in range(k):
            if len(current_actions) >= params["env"]["horizon"] - 1:
                break
                
            # random_action = random.randint(0, env.action_dim - 1)
            # env.step(0, torch.tensor([random_action]))
            # 
            # new_state_tensor = env.state.clone()
            
            current_state_tensor = env.state.clone()
            new_state_tensor=current_state_tensor # 初始化新状态张量
            while current_state_tensor== new_state_tensor:
                random_action = random.randint(0, env.action_dim - 1)
                while random_action==0:
                    random_action = random.randint(0, env.action_dim - 1)
                    
                env.step(0, torch.tensor([random_action]))
                
                new_state_tensor = env.state.clone()
                if current_state_tensor== new_state_tensor:
                    env.return_to_pre_step()
            
            current_states.append(new_state_tensor)
            current_actions.append(random_action)
            
            # ---【核心修改】：更新存檔庫時使用 (時間, 狀態) 作為 Key ---
            new_state_id = new_state_tensor.item()
            time_step = len(current_actions) # 當前時間步 = 已執行動作的數量
            new_cell_key = (time_step, new_state_id)
            
            new_reward = calculate_submodular_reward(current_states, env)
            
            if new_cell_key not in archive or new_reward > archive[new_cell_key]['reward']:
                archive[new_cell_key] = {
                    'reward': new_reward,
                    'states': current_states[:],
                    'actions': current_actions[:],
                    'times_selected': 0
                }
            # elif new_cell_key in archive and new_reward < archive[new_cell_key]['reward']:
            #     break
            
        if step% ( exploration_step / 10) == 0:
            print(f"探索步骤: {step / exploration_step * 100:.2f}%")
            print(f"當前存檔庫大小: {len(archive)}")
            _best_trajectory_data = max(archive.values(), key=lambda x: x['reward'])
            _max_reward = _best_trajectory_data['reward']
            print(f"找到的最佳獎勵值: {_max_reward}")
            print(f"對應的軌跡長度: {len(_best_trajectory_data['states'])}")
                
        pbar.update(1)
    pbar.close()

    # 3. 探索結束
    print("\n--- 探索完成 ---")
    if not archive:
        print("錯誤：存檔庫為空！")
        return None

    best_trajectory_data = max(archive.values(), key=lambda x: x['reward'])
    max_reward = best_trajectory_data['reward']
    print(f"找到的最佳獎勵值: {max_reward}")
    print(f"對應的軌跡長度: {len(best_trajectory_data['states'])}")

    # 4. 保存存檔庫
    archive_filename = archive_path
    with open(archive_filename, "wb") as f:
        pickle.dump(archive, f)
    print(f"完整的時空細胞存檔庫已保存至: {archive_filename}")
    
    return best_trajectory_data

In [6]:

# --- 運行最終升級版的 Go-Explore ---
best_found_trajectory = run_go_explore_phase1_spacetime(env, params,exploration_step=300000,archive_path="two_Room_80_go_explore_archive_file.pkl",isLoad=True,k=40)


--- Go-Explore Phase 1: Spacetime Exploration 開始 ---
正在從存檔庫載入已保存的時空細胞...
成功載入存檔庫，包含 5202 個細胞。
初始時空細胞加入存檔庫: Cell (0, 80), Reward: 0.0


Go-Exploring (Spacetime):   0%|          | 18/300000 [00:00<29:15, 170.86it/s]

探索步骤: 0.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  10%|█         | 30030/300000 [02:41<29:02, 154.97it/s]

探索步骤: 10.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  20%|██        | 60023/300000 [05:12<22:58, 174.13it/s]

探索步骤: 20.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  30%|███       | 90035/300000 [07:55<19:20, 180.87it/s]

探索步骤: 30.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  40%|████      | 120026/300000 [10:42<16:40, 179.90it/s]

探索步骤: 40.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  50%|█████     | 150023/300000 [13:29<14:53, 167.93it/s]

探索步骤: 50.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  60%|██████    | 180020/300000 [16:17<11:41, 170.95it/s]

探索步骤: 60.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  70%|███████   | 210026/300000 [19:04<08:19, 180.21it/s]

探索步骤: 70.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  80%|████████  | 240034/300000 [21:51<05:07, 195.07it/s]

探索步骤: 80.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime):  90%|█████████ | 270034/300000 [24:38<02:49, 177.17it/s]

探索步骤: 90.00%
當前存檔庫大小: 5202
找到的最佳獎勵值: 137
對應的軌跡長度: 78


Go-Exploring (Spacetime): 100%|██████████| 300000/300000 [27:31<00:00, 181.68it/s]



--- 探索完成 ---
找到的最佳獎勵值: 137
對應的軌跡長度: 78
完整的時空細胞存檔庫已保存至: two_Room_80_go_explore_archive_file.pkl


In [1]:

top_20_trajectories = sample_excellent_trajectories(filepath="go_explore_archive_file_two_Room_198.pkl",method='top_n', n=100)
if top_20_trajectories:
    print(f"其中最好的一條獎勵為: {top_20_trajectories[0]['reward']}")
    print(f"最差的一條（在這300條中）獎勵為: {top_20_trajectories[-1]['reward']}\n")
    

NameError: name 'sample_excellent_trajectories' is not defined

In [7]:
# 範例1：獲取獎勵最高的 20 條軌跡
print("--- 範例 1: 採樣 Top 20 ---")
top_20_trajectories = sample_excellent_trajectories(filepath=".pkl",method='top_n', n=100)
if top_20_trajectories:
    print(f"其中最好的一條獎勵為: {top_20_trajectories[0]['reward']}")
    print(f"最差的一條（在這300條中）獎勵為: {top_20_trajectories[-1]['reward']}\n")

print(top_20_trajectories[0])




--- 範例 1: 採樣 Top 20 ---
方法: Top-N。從 5202 條軌跡中篩選出最好的 100 條。
其中最好的一條獎勵為: 138
最差的一條（在這300條中）獎勵為: 134

{'reward': 138, 'states': [tensor([80]), tensor([79]), tensor([78]), tensor([77]), tensor([76]), tensor([75]), tensor([57]), tensor([39]), tensor([40]), tensor([22]), tensor([4]), tensor([3]), tensor([2]), tensor([1]), tensor([0]), tensor([18]), tensor([36]), tensor([37]), tensor([55]), tensor([73]), tensor([72]), tensor([90]), tensor([108]), tensor([126]), tensor([127]), tensor([145]), tensor([163]), tensor([164]), tensor([165]), tensor([147]), tensor([148]), tensor([130]), tensor([112]), tensor([94]), tensor([95]), tensor([96]), tensor([97]), tensor([98]), tensor([80]), tensor([81]), tensor([82]), tensor([83]), tensor([84]), tensor([102]), tensor([103]), tensor([104]), tensor([122]), tensor([140]), tensor([139]), tensor([157]), tensor([175]), tensor([176]), tensor([177]), tensor([178]), tensor([160]), tensor([142]), tensor([124]), tensor([106]), tensor([88]), tensor([70]), tensor([52]),

In [16]:
#基于embedding的模仿学习

import torch
import torch.nn.functional as F  # 这里导入 F
import torch.nn as nn

class TemporalStateEncoder(nn.Module):
    def __init__(self, num_states=900, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embedding = nn.Embedding(num_states, embed_dim)
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim)

    def forward(self, state_seq):
        indices = [i for i in state_seq if i >= 0]
        if not indices:
            return torch.zeros(self.lstm.hidden_size)

        input_emb = self.embedding(torch.tensor(indices).long()).unsqueeze(1)  # [T, 1, D]
        _, (h_n, _) = self.lstm(input_emb)
        return h_n.squeeze(0).squeeze(0)  # [hidden_dim]

def encode_temporal_state(state_seq, embed_table):
    """
    输入: state_seq: Tensor[H]，如 [34, 33, -1, ..., -1]
    输出: Tensor[embed_dim]，嵌入向量
    """
    indices = [i for i in state_seq if i >= 0]  # 去除 -1
    if not indices:
        return torch.zeros(embed_table.embedding_dim)
    indices_tensor = torch.tensor(indices, dtype=torch.long)
    embeds = embed_table(indices_tensor)
    return embeds.mean(dim=0)
def encode_temporal_state2(state_seq, encoder):
    return encoder(state_seq)

def submodular_selector_temporal(trajectories, embed_table,temporal_encoder, budget=50, lambda_div=0.5, per_traj_limit=True,horizon=40):
    """
    改进版子模选择器：
    - 使用欧几里得距离作为 diversity 惩罚项
    - 可选启用：每条轨迹最多选一个状态（保证轨迹多样性）
    """
    state_vectors = []
    action_labels = []
    traj_ids = []

    for traj_id, traj in enumerate(trajectories):
        states = [int(s.item()) for s in traj['states']]
        actions = traj['actions']
        for t, action in enumerate(actions):
            temporal_state = [-1]*horizon
            for h in range(t+1):
                temporal_state[h] = states[h]
            temporal_tensor = torch.tensor(temporal_state, dtype=torch.long)
            # encoded = encode_temporal_state(temporal_tensor, embed_table) # 使用嵌入表编码
            # temporal_encoder = TemporalStateEncoder() #使用lstm+embedding编码
            vec = encode_temporal_state2(temporal_tensor, temporal_encoder)
            state_vectors.append(vec.detach())
            action_labels.append(action)
            traj_ids.append(traj_id)

    print(f"Total states collected: {len(state_vectors)}")
    all_states = torch.stack(state_vectors)
    all_actions = torch.tensor(action_labels, dtype=torch.long)
    traj_ids = torch.tensor(traj_ids)

    selected_indices = []
    selected_vectors = []
    selected_trajs = set()

    for _ in range(min(budget, len(all_states))):
        best_score, best_idx = -float("inf"), -1

        for i in range(len(all_states)):
            if i in selected_indices:
                continue
            if per_traj_limit and traj_ids[i].item() in selected_trajs:
                continue

            candidate = all_states[i].unsqueeze(0)
            reward = torch.abs(candidate).mean().item()

            if selected_vectors:
                selected_tensor = torch.stack(selected_vectors)
                sims = ((candidate - selected_tensor)**2).sum(dim=1)  # Euclidean squared
                diversity_penalty = -sims.mean().item()  # maximize distance
            else:
                diversity_penalty = 0

            score = reward + lambda_div * diversity_penalty

            if score > best_score:
                best_score = score
                best_idx = i

        if best_idx == -1:
            break

        selected_indices.append(best_idx)
        selected_vectors.append(all_states[best_idx])
        selected_trajs.add(traj_ids[best_idx].item())

    return all_states[selected_indices], all_actions[selected_indices],all_states,all_actions

In [18]:

embed_dim = 16
num_states = 900
elite_trajectories_data = sample_excellent_trajectories(
        filepath="two_Room_80_go_explore_archive_file.pkl", 
        method='top_n', 
        n=50)

embed_table = torch.nn.Embedding(num_states, 16)
temporal_encoder = TemporalStateEncoder(num_states=198, embed_dim=16, hidden_dim=32)
print(embed_table)
selected_states, selected_actions, all_states,all_actions = submodular_selector_temporal(
    trajectories=elite_trajectories_data[:],
    embed_table=embed_table,
    temporal_encoder=temporal_encoder,
    budget=500,
    lambda_div=2.0,
    per_traj_limit=True,
    horizon=params["env"]["horizon"]
)

方法: Top-N。從 5202 條軌跡中篩選出最好的 50 條。
Embedding(900, 16)
Total states collected: 3866
