In [7]:
import argparse
import errno
import os
import random
from importlib.metadata import requires
from timeit import timeit
import dill as pickle
import numpy as np
import scipy
import torch
import wandb
import yaml
from sympy import Matrix, MatrixSymbol, derive_by_array, symarray
from torch.distributions import Categorical
from tqdm.notebook import tqdm

from utils.environment import GridWorld
from utils.network import append_state
from utils.network import policy as agent_net
from utils.visualization import Visu
from utils.subpo import calculate_submodular_reward, compute_subpo_advantages


workspace = "NM"

In [8]:

params = {
    "env": {
        "start": 1,
        "step_size": 0.1,
        "shape": {"x": 7, "y": 14},
        "horizon": 40,
        "node_weight": "constant",
        "disc_size": "small",
        "n_players": 3,
        "Cx_lengthscale": 2,
        "Cx_noise": 0.001,
        "Fx_lengthscale": 1,
        "Fx_noise": 0.001,
        "Cx_beta": 1.5,
        "Fx_beta": 1.5,
        "generate": False,
        "env_file_name": "env_data.pkl",
        "cov_module": "Matern",
        "stochasticity": 0.0,
        "domains": "two_room"
    },
    "alg": {
        "gamma": 1,
        "type": "NM",
        "ent_coef": 0.05,
        "epochs": 140,
        "lr": 0.01
    },
    "common": {
        "a": 1,
        "subgrad": "greedy",
        "grad": "pytorch",
        "algo": "both",
        "init": "deterministic",
        "batch_size": 300
    },
    "visu": {
        "wb": "disabled",
        "a": 1
    }
}
env_load_path = workspace + \
    "/environments/" + params["env"]["node_weight"]+ "/env_1" 

params['env']['num'] = 1
# start a new wandb run to track this script
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="code-" + params["env"]["node_weight"],
#     mode=params["visu"]["wb"],
#     config=params
# )

epochs = params["alg"]["epochs"]

H = params["env"]["horizon"]
MAX_Ret = 2*(H+1)
if params["env"]["disc_size"] == "large":
    MAX_Ret = 3*(H+2)
    
env = GridWorld(
    env_params=params["env"], common_params=params["common"], visu_params=params["visu"], env_file_path=env_load_path)
node_size = params["env"]["shape"]['x']*params["env"]["shape"]['y']
# TransitionMatrix = torch.zeros(node_size, node_size)

if params["env"]["node_weight"] == "entropy" or params["env"]["node_weight"] == "steiner_covering" or params["env"]["node_weight"] == "GP": 
    a_file = open(env_load_path +".pkl", "rb")
    data = pickle.load(a_file)
    a_file.close()

if params["env"]["node_weight"] == "entropy":
    env.cov = data
if params["env"]["node_weight"] == "steiner_covering":
    env.items_loc = data
if params["env"]["node_weight"] == "GP":
    env.weight = data

visu = Visu(env_params=params["env"])
# plt, fig = visu.stiener_grid( items_loc=env.items_loc, init=34)
# wandb.log({"chart": wandb.Image(fig)})
# plt.close()
# Hori_TransitionMatrix = torch.zeros(node_size*H, node_size*H)
# for node in env.horizon_transition_graph.nodes:
#     connected_edges = env.horizon_transition_graph.edges(node)
#     for u, v in connected_edges:
#         Hori_TransitionMatrix[u[0]*node_size+u[1], v[0]*node_size + v[1]] = 1.0
env.get_horizon_transition_matrix()
# policy = Policy(TransitionMatrix=TransitionMatrix, Hori_TransitionMatrix=Hori_TransitionMatrix, ActionTransitionMatrix=env.Hori_ActionTransitionMatrix[:, :, :, 0],
#                 agent_param=params["agent"], env_param=params["env"])


x_ticks [-0.5001, -0.4999, 0.4999, 0.5001, 1.4999, 1.5001, 2.4999, 2.5001, 3.4999, 3.5001, 4.4999, 4.5001, 5.4999, 5.5001, 6.4999, 6.5001, 7.4999, 7.5001, 8.4999, 8.5001, 9.4999, 9.5001, 10.4999, 10.5001, 11.4999, 11.5001, 12.4999, 12.5001, 13.4999, 13.5001]
y_ticks [-0.5001, -0.4999, 0.4999, 0.5001, 1.4999, 1.5001, 2.4999, 2.5001, 3.4999, 3.5001, 4.4999, 4.5001, 5.4999, 5.5001, 6.4999, 6.5001]


In [26]:
def select_cell_from_archive(archive):
    """
    Select a cell from the archive for exploration.
    Cells with the fewest selection counts are prioritized.
    """
    if not archive:
        return None, None

    # Find the minimum selection count
    min_times_selected = float('inf')
    for cell_id in archive:
        if archive[cell_id]['times_selected'] < min_times_selected:
            min_times_selected = archive[cell_id]['times_selected']
    
    # Find all cells with the minimum selection count
    least_visited_cells = []
    for cell_id in archive:
        if archive[cell_id]['times_selected'] == min_times_selected:
            least_visited_cells.append(cell_id)
            
    #  Randomly select one of these cells
    selected_cell_id = random.choice(least_visited_cells)
    
    return selected_cell_id, archive[selected_cell_id]

def sample_excellent_trajectories(filepath="go_explore_archive_spacetime.pkl", 
                                  method='top_n', 
                                  n=10, 
                                  p=0.1, 
                                  threshold=0):
    """
        Load data from the Go-Explore archive and sample high-quality trajectories based on the specified method.

        Args:
            filepath (str): Path to the .pkl archive file.
            method (str): Sampling method. Options are 'top_n', 'top_p', or 'threshold'.
            n (int): Number of trajectories to sample for the 'top_n' method.
            p (float): Percentage of top trajectories to sample for the 'top_p' method (e.g., 0.1 means top 10%).
            threshold (float): Minimum reward threshold for the 'threshold' method.
        
        Returns:
            list: A list of trajectory dictionaries with high rewards, sorted in descending order of reward.
                  Returns an empty list if the file does not exist or the archive is empty.
    """
    # 1. Check if the file exists and load the data
    if not os.path.exists(filepath):
        print(f"Error: Archive file not found '{filepath}'")
        return []
    
    try:
        with open(filepath, "rb") as f:
            archive = pickle.load(f)
        if not archive:
            print("警告：存檔庫為空。")
            return []
    except Exception as e:
        print(f"讀取文件時出錯: {e}")
        return []

    # 2. 提取所有軌跡數據並按獎勵排序
    # archive.values() 返回的是包含 reward, states, actions 等信息的字典
    all_trajectories_data = list(archive.values())
    
    # 按 'reward' 鍵從高到低排序
    all_trajectories_data.sort(key=lambda x: x['reward'], reverse=True)

    # 3. 根據指定方法進行採樣
    sampled_trajectories = []
    if method == 'top_n':
        # 取獎勵最高的前 N 條
        num_to_sample = min(n, len(all_trajectories_data))
        sampled_trajectories = all_trajectories_data[:num_to_sample]
        print(f"方法: Top-N。從 {len(all_trajectories_data)} 條軌跡中篩選出最好的 {len(sampled_trajectories)} 條。")

    elif method == 'top_p':
        # 取獎勵最高的前 P%
        if not (0 < p <= 1):
            print("錯誤：百分比 'p' 必須在 (0, 1] 之間。")
            return []
        num_to_sample = int(len(all_trajectories_data) * p)
        sampled_trajectories = all_trajectories_data[:num_to_sample]
        print(f"方法: Top-P。從 {len(all_trajectories_data)} 條軌跡中篩選出最好的前 {p*100:.1f}% ({len(sampled_trajectories)} 條)。")

    elif method == 'threshold':
        # 取獎勵高於指定門檻的所有軌跡
        sampled_trajectories = [data for data in all_trajectories_data if data['reward'] >= threshold]
        print(f"方法: Threshold。從 {len(all_trajectories_data)} 條軌跡中篩選出 {len(sampled_trajectories)} 條獎勵不低於 {threshold} 的軌跡。")
        
    else:
        print(f"錯誤：未知的採樣方法 '{method}'。請使用 'top_n', 'top_p', 或 'threshold'。")

    return sampled_trajectories


In [23]:


# --- Core parameter setting ---
EXPLORATION_STEPS = 1000000  # The total number of exploration steps can be adjusted as needed, and the larger the value, the more thorough the exploration
K_STEPS = 15               # Random exploration steps starting from selected cells 

# --- Go-Explore main  ---
def run_go_explore_phase1_spacetime(env=GridWorld, params=None):
    """
    執行 Go-Explore 的第一階段探索（使用時空細胞）
    """
    print("--- Go-Explore Phase 1: Spacetime Exploration 開始 ---")
    
    # 1. 初始化存檔庫 (Archive)
    # ---【核心修改】：將細胞 Key 定義為 (時間, 狀態ID) ---
    # 結構: { (time, state_id): {'reward': float, 'states': list, 'actions': list, 'times_selected': int} }
    archive = {}
    
    env.common_params["batch_size"]=1
    env.initialize()
    initial_state_tensor = env.state.clone()
    initial_state_id = initial_state_tensor.item()
    
    # ---【核心修改】：初始細胞現在是 (時間=0, 狀態) ---
    initial_cell_key = (0, initial_state_id)
    
    initial_reward = calculate_submodular_reward([initial_state_tensor], env)
    
    archive[initial_cell_key] = {
        'reward': initial_reward,
        'states': [initial_state_tensor],
        'actions': [],
        'times_selected': 0
    }
    print(f"初始時空細胞加入存檔庫: Cell {initial_cell_key}, Reward: {initial_reward}")
    
    # 2. 執行 N 次探索循環
    pbar = tqdm(total=EXPLORATION_STEPS, desc="Go-Exploring (Spacetime)")
    for step in range(EXPLORATION_STEPS):
        # 2.1 選擇細胞 (選擇函數無需修改，它通用於任何 key)
        cell_key_to_explore_from, selected_cell_data = select_cell_from_archive(archive)
        
        if selected_cell_data is None:
            print("錯誤：存檔庫為空，無法繼續探索。")
            break
            
        archive[cell_key_to_explore_from]['times_selected'] += 1

        # 2.2 前往 (Go To) 該細胞狀態
        env.initialize()
        for action in selected_cell_data['actions']:
            env.step(0, torch.tensor([action]))

        # 2.3 從該狀態開始，隨機探索 (Explore) k 步
        current_states = selected_cell_data['states'][:]
        current_actions = selected_cell_data['actions'][:]
        
        for _ in range(K_STEPS):
            if len(current_actions) >= params["env"]["horizon"] - 1:
                break
                
            random_action = random.randint(0, env.action_dim - 1)
            env.step(0, torch.tensor([random_action]))
            
            new_state_tensor = env.state.clone()
            
            current_states.append(new_state_tensor)
            current_actions.append(random_action)
            
            # ---【核心修改】：更新存檔庫時使用 (時間, 狀態) 作為 Key ---
            new_state_id = new_state_tensor.item()
            time_step = len(current_actions) # 當前時間步 = 已執行動作的數量
            new_cell_key = (time_step, new_state_id)
            
            new_reward = calculate_submodular_reward(current_states, env)
            
            if new_cell_key not in archive or new_reward > archive[new_cell_key]['reward']:
                archive[new_cell_key] = {
                    'reward': new_reward,
                    'states': current_states[:],
                    'actions': current_actions[:],
                    'times_selected': 0
                }
        
        pbar.update(1)
    pbar.close()

    # 3. 探索結束
    print("\n--- 探索完成 ---")
    if not archive:
        print("錯誤：存檔庫為空！")
        return None

    best_trajectory_data = max(archive.values(), key=lambda x: x['reward'])
    max_reward = best_trajectory_data['reward']
    print(f"找到的最佳獎勵值: {max_reward}")
    print(f"對應的軌跡長度: {len(best_trajectory_data['states'])}")

    # 4. 保存存檔庫
    archive_filename = "go_explore_archive_spacetime.pkl"
    with open(archive_filename, "wb") as f:
        pickle.dump(archive, f)
    print(f"完整的時空細胞存檔庫已保存至: {archive_filename}")
    
    return best_trajectory_data

# --- 運行最終升級版的 Go-Explore ---
best_found_trajectory = run_go_explore_phase1_spacetime(env, params)



--- Go-Explore Phase 1: Spacetime Exploration 開始 ---
初始時空細胞加入存檔庫: Cell (0, 34), Reward: 0.0


Go-Exploring (Spacetime):   0%|          | 0/1000000 [00:00<?, ?it/s]


--- 探索完成 ---
找到的最佳獎勵值: 68
對應的軌跡長度: 40
完整的時空細胞存檔庫已保存至: go_explore_archive_spacetime.pkl


In [33]:
print(best_found_trajectory)
print(calculate_submodular_reward(best_found_trajectory['states'],env))
# --- 使用範例 ---

# 範例1：獲取獎勵最高的 20 條軌跡
print("--- 範例 1: 採樣 Top 20 ---")
top_20_trajectories = sample_excellent_trajectories(method='top_n', n=300)
if top_20_trajectories:
    print(f"其中最好的一條獎勵為: {top_20_trajectories[0]['reward']}")
    print(f"最差的一條（在這20條中）獎勵為: {top_20_trajectories[-1]['reward']}\n")

# 範例2：獲取獎勵排名前 5% 的軌跡
print("--- 範例 2: 採樣 Top 5% ---")
top_5_percent_trajectories = sample_excellent_trajectories(method='top_p', p=0.05)
if top_5_percent_trajectories:
    # 打印其中一條軌跡的詳細信息以供檢查
    sample_traj_data = top_5_percent_trajectories[0]
    print(f"抽樣檢查最好的一條軌跡：獎勵={sample_traj_data['reward']}, 長度={len(sample_traj_data['states'])}\n")

# 範例3：獲取所有獎勵值大於等於 45 的軌跡
print("--- 範例 3: 採樣獎勵 >= 45 的軌跡 ---")
high_reward_trajectories = sample_excellent_trajectories(method='threshold', threshold=45)
if high_reward_trajectories:
    print(f"所有高分軌跡的平均獎勵為: {np.mean([d['reward'] for d in high_reward_trajectories]):.2f}\n")


{'reward': 68, 'states': [tensor([34]), tensor([35]), tensor([36]), tensor([37]), tensor([38]), tensor([24]), tensor([10]), tensor([11]), tensor([12]), tensor([26]), tensor([40]), tensor([54]), tensor([68]), tensor([82]), tensor([81]), tensor([80]), tensor([66]), tensor([52]), tensor([51]), tensor([37]), tensor([36]), tensor([35]), tensor([34]), tensor([33]), tensor([32]), tensor([31]), tensor([30]), tensor([16]), tensor([2]), tensor([1]), tensor([0]), tensor([14]), tensor([28]), tensor([42]), tensor([56]), tensor([70]), tensor([71]), tensor([72]), tensor([58]), tensor([44])], 'actions': [1, 1, 1, 1, 4, 4, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 3, 4, 3, 3, 3, 3, 3, 3, 3, 4, 4, 3, 3, 2, 2, 2, 2, 2, 1, 1, 4, 4], 'times_selected': 361}
68
--- 範例 1: 採樣 Top 20 ---
方法: Top-N。從 2312 條軌跡中篩選出最好的 300 條。
其中最好的一條獎勵為: 68
最差的一條（在這20條中）獎勵為: 60

--- 範例 2: 採樣 Top 5% ---
方法: Top-P。從 2312 條軌跡中篩選出最好的前 5.0% (115 條)。
抽樣檢查最好的一條軌跡：獎勵=68, 長度=40

--- 範例 3: 採樣獎勵 >= 45 的軌跡 ---
方法: Threshold。從 2312 條軌跡中篩選出 804 條獎勵不低於 45

In [8]:

# Agent's policy
if params["alg"]["type"]=="M" or params["alg"]["type"]=="SRL":
    agent = agent_net(2, env.action_dim)
else:
    agent = agent_net(H-1, env.action_dim)
optim = torch.optim.Adam(agent.parameters(), lr=params["alg"]["lr"])

for t_eps in range(epochs):
    mat_action = []
    mat_state = []
    mat_return = []
    marginal_return = []
    mat_done = []
    
    # Used to store data for all trajectories within a batch (new)
    batch_trajectories = [[] for _ in range(params["common"]["batch_size"])]
    batch_log_probs = [[] for _ in range(params["common"]["batch_size"])]
    
    # print(t_eps)
    env.initialize()
    mat_state.append(env.state)
    init_state = env.state
    # print(torch.mean(env.weighted_traj_return([init_state])))
    # print(t_eps, " ", mat_state, " ", env.weighted_traj_return(mat_state))
    list_batch_state = []
    for h_iter in range(H-1):
        if params["alg"]["type"]=="M" or params["alg"]["type"]=="SRL":
            batch_state = mat_state[-1].reshape(-1, 1).float()
            # append time index to the state
            batch_state = torch.cat(
                [batch_state, h_iter*torch.ones_like(batch_state)], 1)
        else:
            batch_state = append_state(mat_state, H-1)
        action_prob = agent(batch_state)
        # action_prob = pi_h_s_a[h_iter, mat_state[-1]]
        # policy_dist = Categorical(torch.nn.Softmax()(action_prob))
        policy_dist = Categorical(action_prob)
        actions = policy_dist.sample()
        
        # TODO： Record log_deb for each trajectory in the batch (new)
        log_probs = policy_dist.log_prob(actions)
        for i in range(params["common"]["batch_size"]):
            batch_log_probs[i].append(log_probs[i])
        
        mat_action.append(actions)
        
        env.step(h_iter, actions)
        mat_state.append(env.state)  # s+1
        # print(t_eps, " ", mat_state, " ", env.weighted_traj_return(mat_state))
        # mat_return.append(env.batched_marginal_coverage(
        #     mat_state, [init_state]))
        
        # TODO：Record the status for each trajectory in the batch
        for i in range(params["common"]["batch_size"]):
            # The shape of env.state is (batch_size, 1)，, and we take the state of the i sample 
            batch_trajectories[i].append(env.state[i])
            
        mat_return.append(env.weighted_traj_return(mat_state, type = params["alg"]["type"]))
        if h_iter ==0:
            marginal_return.append(mat_return[h_iter])
        else:
            # if params["alg"]["type"]=="SRL":
            marginal_return.append(mat_return[h_iter])
            # else:
            # marginal_return.append(mat_return[h_iter] - mat_return[h_iter-1])
        list_batch_state.append(batch_state)
        # mat_return.append(env.weighted_traj_return(
        #     mat_state) - env.weighted_traj_return([init_state]))

    ###################
    # Compute gradients
    ###################
    # --- Gradient calculation and update stage ---
    all_advantages = []
    final_rewards = []
    # 1. Calculate the corresponding SUBPO advantage function for each trajectory in the batch
    for i in range(params["common"]["batch_size"]):
        # Construct a single trajectory, and add an initial state
        single_trajectory = [init_state[i]] + batch_trajectories[i]
        
        # 計算這條軌跡的最終獎勵，用於監控
        final_rewards.append(calculate_submodular_reward(single_trajectory, env))

        # 計算正確的 SUBPO 優勢函數
        advantages = compute_subpo_advantages(single_trajectory, env)
        all_advantages.extend(advantages[:-1]) 


    all_advantages = torch.tensor(all_advantages, dtype=torch.float32)

    # Convert log_debs from a list of lists to a flat tensor
    flat_log_probs = torch.cat([torch.stack(lp) for lp in batch_log_probs])

    # 標準化優勢函數
    if len(all_advantages) > 1:
        all_advantages = (all_advantages - all_advantages.mean()) / (all_advantages.std() + 1e-9)

    # 2. 計算損失函數
    ent_coef = params["alg"]["ent_coef"]
    # 移除熵的衰減，使用固定的係數以保持探索
    entropy_term = policy_dist.entropy().mean() 
    
    loss = -1 * (torch.mean(flat_log_probs * all_advantages) + ent_coef * entropy_term)

    # 3. 更新梯度
    optim.zero_grad()
    loss.backward()
    optim.step()

    # --- 打印日誌 ---
    obj_mean = torch.tensor(final_rewards).float().mean()
    obj_max = torch.tensor(final_rewards).float().max()
    obj_median = torch.tensor(final_rewards).float().median()
    obj_min = torch.tensor(final_rewards).float().min()

    if t_eps % 1 == 0:
        print(f"Epoch {t_eps} | Mean: {obj_mean:.2f} | Max: {obj_max:.2f} | Median: {obj_median:.2f} | Min: {obj_min:.2f} | Entropy: {entropy_term:.4f} | Loss: {loss.item():.4f}")


    
# wandb.finish()

Epoch 0 | Mean: 16.84 | Max: 32.00 | Median: 16.00 | Min: 7.00 | Entropy: 1.5830 | Loss: -0.0802
Epoch 1 | Mean: 19.65 | Max: 31.00 | Median: 21.00 | Min: 8.00 | Entropy: 1.5573 | Loss: -0.0649
Epoch 2 | Mean: 22.50 | Max: 34.00 | Median: 23.00 | Min: 10.00 | Entropy: 1.5602 | Loss: -0.0949
Epoch 3 | Mean: 23.14 | Max: 35.00 | Median: 23.00 | Min: 12.00 | Entropy: 1.5839 | Loss: -0.1523
Epoch 4 | Mean: 22.84 | Max: 32.00 | Median: 22.00 | Min: 17.00 | Entropy: 1.5485 | Loss: -0.2059
Epoch 5 | Mean: 23.82 | Max: 34.00 | Median: 24.00 | Min: 14.00 | Entropy: 1.5358 | Loss: -0.2533
Epoch 6 | Mean: 24.75 | Max: 33.00 | Median: 25.00 | Min: 18.00 | Entropy: 1.5549 | Loss: -0.2875
Epoch 7 | Mean: 25.57 | Max: 33.00 | Median: 26.00 | Min: 14.00 | Entropy: 1.5556 | Loss: -0.3120
Epoch 8 | Mean: 26.03 | Max: 35.00 | Median: 26.00 | Min: 14.00 | Entropy: 1.5302 | Loss: -0.3111
Epoch 9 | Mean: 25.78 | Max: 33.00 | Median: 26.00 | Min: 10.00 | Entropy: 1.5429 | Loss: -0.3103
Epoch 10 | Mean: 25.53

In [27]:
params["common"]["batch_size"]=10000
env = GridWorld(
    env_params=params["env"], common_params=params["common"], visu_params=params["visu"], env_file_path=env_load_path)
node_size = params["env"]["shape"]['x']*params["env"]["shape"]['y']
env.get_horizon_transition_matrix()
env.initialize()
init_state = env.state
mat_action = []
mat_state = []
mat_return = []
marginal_return = []
mat_state.append(env.state)
for h_iter in range(H-1):
    if params["alg"]["type"]=="M" or params["alg"]["type"]=="SRL":
        batch_state = env.state.reshape(-1, 1).float()
        # append time index to the state
        batch_state = torch.cat(
            [batch_state, h_iter*torch.ones_like(batch_state)], 1)
    else:
        batch_state = append_state(mat_state, H-1)
    action_prob = agent(batch_state)
    policy_dist = Categorical(action_prob)
    actions = policy_dist.sample()
    env.step(h_iter, actions)
    mat_state.append(env.state)  # s+1
    mat_action.append(actions)
    mat_return.append(env.weighted_traj_return(mat_state, type = params["alg"]["type"]))
    # print("Action ", actions, " state ", env.state," mat return ", mat_return[-1])
obj = env.weighted_traj_return(mat_state).float()
print( " mean ", obj.mean(), " max ",
          obj.max(), " median ", obj.median(), " min ", obj.min(), " ent ", policy_dist.entropy().mean().detach())
max_index = torch.argmax(obj)
print("Max index ", max_index)
for i in range(len(mat_state)-1):
    print("State ", i, " ", mat_state[i][max_index], " Action ", mat_action[i][max_index], " Return ", mat_return[i][max_index])

 mean  tensor(34.9254)  max  tensor(50.)  median  tensor(36.)  min  tensor(18.)  ent  tensor(1.5026)
Max index  tensor(3885)
State  0   tensor(34)  Action  tensor(1)  Return  tensor(6)
State  1   tensor(35)  Action  tensor(1)  Return  tensor(8)
State  2   tensor(36)  Action  tensor(1)  Return  tensor(10)
State  3   tensor(37)  Action  tensor(1)  Return  tensor(12)
State  4   tensor(38)  Action  tensor(1)  Return  tensor(14)
State  5   tensor(39)  Action  tensor(4)  Return  tensor(16)
State  6   tensor(25)  Action  tensor(1)  Return  tensor(18)
State  7   tensor(26)  Action  tensor(2)  Return  tensor(19)
State  8   tensor(40)  Action  tensor(2)  Return  tensor(21)
State  9   tensor(54)  Action  tensor(2)  Return  tensor(23)
State  10   tensor(68)  Action  tensor(2)  Return  tensor(25)
State  11   tensor(82)  Action  tensor(4)  Return  tensor(25)
State  12   tensor(68)  Action  tensor(3)  Return  tensor(27)
State  13   tensor(67)  Action  tensor(3)  Return  tensor(29)
State  14   tensor(