<a href="https://colab.research.google.com/github/psiudo/ooo/blob/main/v11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ==================================
# Cell 1: mount Google Drive
# ==================================
from google.colab import drive
drive.mount('/content/drive')

print("✅ Google Drive가 성공적으로 마운트되었습니다.")

Mounted at /content/drive
✅ Google Drive가 성공적으로 마운트되었습니다.


In [2]:
# ==================================
# Cell 2: install required libraries (PyTorch Geometric)
# ==================================
!pip install --upgrade torch-geometric
print("✅ 필수 라이브러리가 준비되었습니다.")

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m67.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
✅ 필수 라이브러리가 준비되었습니다.


In [3]:
# numba is required by `import numba` and the @numba.jit BFS helper in the next cell.
!pip install numba



In [None]:
########## v11 : HRL 트랜스포머 GNN 보상함수의 unambiguity ##########


# ==============================================================================
# 섹션 1: 설정 및 기본 유틸리티 (리팩토링 버전)
# ==============================================================================
import json
import torch.optim as optim
import collections
import random
import os
import torch
import logging
from torch_geometric.data import Data
import pickle
from torch_geometric.utils import to_dense_batch
import logging
import sys
from tqdm import tqdm
import torch.nn as nn
import numpy as np
from collections import defaultdict, deque
import multiprocessing as mp
import random

import numba
from numba.core import types
from numba.typed import List

def setup_logger():
    """Reset the root logger to a single stdout handler at INFO level.

    Every handler already attached to the root logger (e.g. by other
    libraries, or by a previous run of this cell) is detached *and closed*,
    so file-backed handlers do not leak their open streams, and log lines
    are emitted exactly once with our format.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Iterate over a copy: removeHandler mutates root_logger.handlers.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        # Release whatever the handler holds (e.g. a FileHandler's file).
        old_handler.close()

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
        )
    )
    root_logger.addHandler(handler)


class _ClassOrInstanceProperty:
    """Read-only property that also resolves when read on the class itself.

    A plain ``@property`` returns the property object for ``Cls.attr``. This
    file reads ``Config`` settings directly on the class (``Config.LOG_DIR``),
    so the getter is called with the class as ``self`` in that case and with
    the instance otherwise.
    """

    def __init__(self, fget):
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, obj, owner=None):
        return self.fget(owner if obj is None else obj)

    def __set__(self, obj, value):
        # Keep it a data descriptor so instances cannot shadow it (like @property).
        raise AttributeError("can't set attribute")


class Config:
    """All training hyper-parameters and path settings.

    Everything is a class attribute, so settings can be read either from the
    class (``Config.LOG_DIR``) or from an instance (``Config().LOG_DIR``).
    """

    # --- Paths ---
    DRIVE_PROJECT_ROOT = '/content/drive/MyDrive/OptiChallenge'
    PROBLEM_DIR  = os.path.join(DRIVE_PROJECT_ROOT, 'Exercise_Problems')
    LOG_DIR      = os.path.join(DRIVE_PROJECT_ROOT, 'v12_logs_refactored')
    MODEL_DIR    = os.path.join(DRIVE_PROJECT_ROOT, 'v12_models_refactored')

    # Expert demonstrations are expected as 'expert_probX_TYY.pkl' files
    # directly under EXPERT_DIR.
    EXPERT_DIR   = DRIVE_PROJECT_ROOT
    EXPERT_GLOB  = "expert_prob[1248]_T*.pkl"
    MAX_EXPERT_SAMPLES = 40_000

    @_ClassOrInstanceProperty
    def EXPERT_DATA_PATHS(self):
        """Sorted paths matching EXPERT_GLOB under EXPERT_DIR ([] plus a warning if none)."""
        import glob
        paths = sorted(glob.glob(os.path.join(self.EXPERT_DIR, self.EXPERT_GLOB)))
        if not paths:
            logging.warning(f"[Config] No expert pkl found under {self.EXPERT_DIR}")
        return paths

    # --- Training control ---
    TOTAL_MANAGER_STEPS = 500_000       # total Manager training steps
    CURRICULUM_STEPS = 2000            # curriculum steps that imitate the expert policy
    PRINT_INTERVAL_MANAGER_STEPS = 20  # log interval (in Manager steps)
    EVAL_INTERVAL_MANAGER_STEPS = 2000 # evaluation / checkpoint interval (in Manager steps)
    EVAL_EPISODES = 10                  # episodes per evaluation
    CURRICULUM_TRANSITION_STEP = 20000 # step at which the problem-difficulty curriculum switches

    # --- Imitation learning (Worker) ---
    IMITATION_LEARNING_EPOCHS = 200      # Worker imitation-learning epochs
    IMITATION_LR = 1e-4                 # Worker imitation-learning learning rate
    IMITATION_BATCH_SIZE = 512          # Worker imitation-learning batch size

    # --- Manager agent ---
    MANAGER_STATE_DIM = 6               # size of the Manager state vector
    MANAGER_ACTION_DIM = 5              # number of Manager actions (goals)
    MANAGER_LR = 3e-4                   # Manager learning rate
    MANAGER_GAMMA = 0.99                # Manager discount factor
    MANAGER_ENTROPY_COEF = 0.05         # Manager entropy bonus (encourages exploration)
    MANAGER_NUM_STEPS_PER_UPDATE = 512  # Manager rollout length per update

    # --- Reward shaping ---
    REPEAT_PENALTY = -1             # penalty for repeating the previous Manager action
    STEP_PENALTY_WEIGHT = 0.001       # time penalty per Worker step
    NO_PROGRESS_PENALTY = -2.0        # penalty when the Worker stops without progress
    SHAPING_REWARD_WEIGHT = 1.5       # strength of the potential-based shaping term
    NO_PROGRESS_LIMIT = 50            # Worker no-progress step limit (replaces a hard-coded 50)
    TIMEOUT_PENALTY = -10.0

    # --- Worker agent ---
    WORKER_LR = 3e-4                    # Worker learning rate
    WORKER_GAMMA = 0.95                 # Worker discount factor
    WORKER_ENTROPY_COEF = 0.005         # Worker entropy bonus
    WORKER_MAX_STEPS_PER_GOAL = 300     # max Worker steps per Manager goal
    WORKER_NUM_STEPS_PER_UPDATE = 1024

    # --- PPO (shared) ---
    PPO_UPDATE_EPOCHS = 4               # epochs per update
    PPO_NUM_MINIBATCHES = 8             # number of minibatches
    PPO_CLIP_COEF = 0.2                 # PPO clipping coefficient
    PPO_GAE_LAMBDA = 0.95               # GAE lambda
    PPO_VALUE_COEF = 1.0                # value-loss coefficient
    PPO_MAX_GRAD_NORM = 0.5             # gradient-clipping max L2 norm

    # --- Network architecture ---
    NODE_FEATURE_DIM = 4      # [is_occupied, dest_diff, blocking_count, is_relocatable]
    GNN_EMBED_DIM = 128       # base GNN embedding size
    GOAL_EMBED_DIM = 16       # goal embedding size
    # Alternative (not used): feed the Worker GNN's pooled output (mean_pool +
    # att_pool) to the Manager, i.e. MANAGER_STATE_DIM = GNN_EMBED_DIM * 2.

# --- Output directories ---
# Training logs and model checkpoints are written here; create them up front
# (no error if they already exist).
for _output_dir in (Config.LOG_DIR, Config.MODEL_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir



# 💡 --- 이 함수로 기존의 모든 get_shortest_path 관련 함수를 교체 --- 💡
@numba.jit(nopython=True)
def get_shortest_path(adj_list, start, end, num_nodes):
    """
    Breadth-first search for one fewest-edges path from ``start`` to ``end``,
    compiled in numba nopython mode.

    - Records each node's BFS parent in a NumPy array and walks it back from
      ``end`` to rebuild the path.
    - Uses only NumPy arrays and numba typed lists.

    Args:
        adj_list: per-node neighbour sequences indexed by node id
            (assumed: ``adj_list[u]`` iterates the int neighbours of ``u``;
            ShipEnv builds it as a list of ``numba.typed.List`` of int64).
        start, end: node ids in ``[0, num_nodes)``.
        num_nodes: number of nodes; sizes the parent / visited arrays.

    Returns:
        A numba typed list ``[start, ..., end]``; ``[start]`` when
        ``start == end``; an empty int64 typed list when ``end`` is not
        reachable. Among equal-length paths, the one returned depends on the
        neighbour order in ``adj_list``.
    """
    if start == end:
        # Build the result as a typed list so numba can type the return value.
        path = numba.typed.List()
        path.append(start)
        return path

    # BFS parent of every node; -1 means "no parent recorded".
    parents = np.full(num_nodes, -1, dtype=np.int64)

    # visited[u] becomes True as soon as u is enqueued.
    visited = np.zeros(num_nodes, dtype=np.bool_)

    # Typed list used as a FIFO queue (items are never popped; see ``head``).
    queue = numba.typed.List()

    queue.append(start)
    visited[start] = True
    head = 0 # index of the current front of the queue

    path_found = False
    while head < len(queue):
        current = queue[head]
        head += 1

        if current == end:
            path_found = True
            break

        for neighbor in adj_list[current]:
            if not visited[neighbor]:
                visited[neighbor] = True
                parents[neighbor] = current
                queue.append(neighbor)

    # Walk the parent links back from end to start.
    if path_found:
        path = numba.typed.List()
        curr = end
        while curr != -1:
            path.append(curr)
            curr = parents[curr]
        return path[::-1] # collected end -> start, so reverse it

    # end was never reached: empty int64 typed list.
    return numba.typed.List.empty_list(numba.int64)

# 💡 --- 교체 완료 --- 💡

class ShipEnv:
    """Simulate a car carrier: deck occupancy, worker moves and their costs.

    The deck is an undirected graph of ``N`` nodes. Node 0 is the gate: every
    load/unload path starts or ends there, and loading only ever targets
    spots ``1..N-1``. Shortest paths between all node pairs are precomputed
    once in ``__init__``.
    """

    def __init__(self, problem_data: dict, max_num_ports: int):
        """Build the deck graph and the car list from ``problem_data``.

        Args:
            problem_data: dict read with the keys ``'N'`` (node count),
                ``'P'`` (port count), ``'F'`` (fixed cost per move), ``'E'``
                (undirected edges ``[(u, v), ...]``) and ``'K'`` (demands
                ``[((origin, dest), quantity), ...]``).
            max_num_ports: length of the per-destination waiting-count vector
                in the global state features.
        """
        self.num_nodes = problem_data.get('N', 1)
        self.num_ports = problem_data.get('P', 1)
        self.fixed_cost = float(problem_data.get('F', 100))
        self.max_num_ports = max_num_ports
        self.load_fail_streak: int = 0

        # Numba-typed adjacency lists are kept as an attribute (so they can
        # still be passed to the jitted get_shortest_path); ``neighbors`` is a
        # plain-Python copy with the same neighbour order for the BFS below.
        adj_list = [numba.typed.List.empty_list(numba.int64) for _ in range(self.num_nodes)]
        neighbors: list[list[int]] = [[] for _ in range(self.num_nodes)]
        edge_list_for_tensor = []
        for u, v in problem_data.get('E', []):
            adj_list[u].append(v)
            adj_list[v].append(u)
            neighbors[u].append(v)
            neighbors[v].append(u)
            edge_list_for_tensor.extend([[u, v], [v, u]])
        self.adj_list = adj_list

        self.edge_index_tensor = torch.tensor(edge_list_for_tensor, dtype=torch.long).t().contiguous()

        # All-pairs shortest paths. One full BFS per source replaces one BFS per
        # (source, target) pair and yields exactly the same paths: a BFS from a
        # fixed source assigns the same parents whether or not it stops early
        # at the target.
        self.shortest_paths = {
            source: self._bfs_paths_from(neighbors, source)
            for source in range(self.num_nodes)
        }

        self.cars = []
        car_id_counter = 0
        for demand_idx, (demand, quantity) in enumerate(problem_data.get('K', [])):
            origin, dest = demand
            for _ in range(quantity):
                self.cars.append({'id': car_id_counter, 'demand_id': demand_idx, 'origin': origin, 'dest': dest})
                car_id_counter += 1
        self.total_cars = len(self.cars)
        self.reset()

    @staticmethod
    def _bfs_paths_from(neighbors: list[list[int]], source: int) -> dict[int, list[int]]:
        """Map every node to one fewest-edges path from ``source`` ([] if unreachable).

        Neighbours are expanded in list order (same tie-breaking as
        ``get_shortest_path``); ``source`` itself maps to ``[source]``.
        """
        paths: dict[int, list[int]] = {node: [] for node in range(len(neighbors))}
        paths[source] = [source]
        queue = deque([source])
        while queue:
            current = queue.popleft()
            for nb in neighbors[current]:
                if not paths[nb]:  # an empty path means "not discovered yet"
                    paths[nb] = paths[current] + [nb]
                    queue.append(nb)
        return paths

    def _get_or_compute_path(self, start: int, end: int) -> list | None:
        """Return the cached shortest path ``start -> end`` (None for unknown ids)."""
        return self.shortest_paths.get(start, {}).get(end, None)

    def _calculate_path_cost(self, path: list | None) -> float:
        """Cost of one move along ``path``: fixed cost F plus the edge count (0 for no move)."""
        if not path or len(path) <= 1:
            return 0.0
        return float(self.fixed_cost + (len(path) - 1))

    def reset(self) -> Data:
        """Clear all deck / cargo bookkeeping and return the initial state graph."""
        self.current_port: int = 0
        self.node_status: list[int] = [-1] * self.num_nodes   # car id per node, -1 = empty
        self.car_locations: dict[int, int] = {}               # car id -> node
        self.cars_on_board: set[int] = set()
        self.temporarily_unloaded_cars: set[int] = set()
        self.delivered_cars: set[int] = set()
        self.relocations_this_episode: int = 0
        self.last_car_action = {}
        # A new episode also resets the consecutive-load-failure counter.
        self.load_fail_streak: int = 0
        return self._get_state()

    def _get_state(self) -> Data:
        """Build the PyG graph for the current deck.

        Node features: ``[is_occupied, dest - current_port, occupied nodes on
        the path to the gate, 1 if not due at this port]`` (all zeros for the
        gate and for empty nodes). Global features: current port, number of
        cars waiting at this port, number of temporarily unloaded cars, then
        the waiting count per destination (length ``max_num_ports``).
        """
        node_features = []
        for i in range(self.num_nodes):
            car_id = self.node_status[i]
            if i == 0 or car_id == -1:
                node_features.append([0.0, 0.0, 0.0, 0.0])
                continue

            car = self.cars[car_id]
            dest_diff = float(car['dest'] - self.current_port)
            path_to_gate = self._get_or_compute_path(i, 0)
            # Occupied nodes on the way out, excluding node i itself.
            if path_to_gate:
                blocking_count = sum(
                    1 for n in set(path_to_gate[1:]) if self.node_status[n] != -1
                )
            else:
                blocking_count = 0
            is_relocatable = 1.0 if car['dest'] != self.current_port else 0.0
            node_features.append([1.0, dest_diff, float(blocking_count), is_relocatable])

        waiting_cars = [
            c for c in self.cars
            if c['origin'] == self.current_port
            and c['id'] not in self.cars_on_board
            and c['id'] not in self.temporarily_unloaded_cars
        ]
        waiting_dest_counts = [0.0] * self.max_num_ports
        for car in waiting_cars:
            if car['dest'] < self.max_num_ports:
                waiting_dest_counts[car['dest']] += 1.0
        global_features = [
            float(self.current_port),
            float(len(waiting_cars)),
            float(len(self.temporarily_unloaded_cars)),
        ] + waiting_dest_counts

        return Data(x=torch.tensor(node_features, dtype=torch.float),
                    edge_index=self.edge_index_tensor,
                    global_features=torch.tensor([global_features], dtype=torch.float))

    def get_legal_actions(self, for_worker: bool = False) -> list[tuple[str, int]]:
        """List ``(action_type, car_id)`` pairs allowed in the current state.

        LOAD (cars waiting at this port or temporarily unloaded) and
        RELOCATE_INTERNAL are offered only while some spot is free; UNLOAD is
        offered for every car on board; PROCEED_TO_NEXT_PORT is appended
        unless ``for_worker`` is True.
        """
        actions = []
        any_empty_spot = any(status == -1 for status in self.node_status[1:])
        if any_empty_spot:
            load_candidates = [
                c for c in self.cars
                if (c['origin'] == self.current_port
                    and c['id'] not in self.cars_on_board
                    and c['id'] not in self.temporarily_unloaded_cars)
                or c['id'] in self.temporarily_unloaded_cars
            ]
            for car in load_candidates:
                actions.append(('LOAD', car['id']))
            for car_id in self.cars_on_board:
                actions.append(('RELOCATE_INTERNAL', car_id))
        for car_id in self.cars_on_board:
            actions.append(('UNLOAD', car_id))
        if not for_worker:
            actions.append(('PROCEED_TO_NEXT_PORT', -1))
        return actions

    def _is_hard_blocker(self, node_idx: int) -> bool:
        """True if the car on ``node_idx`` must really be moved to pass through.

        Empty nodes, cars due at the current port (they leave soon) and cars
        fewer than 3 edges from the gate count as "soft" blockers (False).
        """
        cid = self.node_status[node_idx]
        if cid == -1:
            return False
        if self.cars[cid]['dest'] == self.current_port:
            return False
        path = self._get_or_compute_path(node_idx, 0)
        if not path:  # unreachable from the gate: treat as not a hard blocker
            return False
        gate_depth = len(path) - 1
        return gate_depth >= 3

    def _find_best_spot(self, car_id_to_load: int) -> tuple[int, list]:
        """Choose the free spot to load ``car_id_to_load`` into.

        Spots whose gate path crosses more than two hard blockers are skipped.
        The rest are scored (lower is better) by hard/soft blockers, cars
        leaving later than this car that would sit on the path, distance, and
        a small random tie-breaker.

        Returns:
            ``(spot, path_from_gate)`` or ``(-1, [])`` if no spot qualifies.
        """
        HARD_BLOCK_LIMIT = 2
        car_dest = self.cars[car_id_to_load]['dest']
        # Nodes held by cars that leave after the car being loaded.
        later_locations = {
            self.car_locations[cid] for cid in self.cars_on_board
            if self.cars[cid]['dest'] > car_dest
        }
        # Same value for every spot (kept for parity with the original score).
        dest_penalty = abs(car_dest - self.current_port) ** 0.5

        best = None
        for spot in range(1, self.num_nodes):
            if self.node_status[spot] != -1:
                continue
            path = self._get_or_compute_path(0, spot)
            if not path:
                continue

            path_set = set(path[1:])
            hard_blocks = sum(self._is_hard_blocker(n) for n in path_set)
            if hard_blocks > HARD_BLOCK_LIMIT:
                continue

            soft_blocks = sum(1 for n in path_set if self.node_status[n] != -1) - hard_blocks
            later_blocks = len(later_locations & path_set)
            dist = len(path) - 1

            score = (
                1.0 * hard_blocks   +   # blockers that must be moved
                0.3 * soft_blocks   +   # blockers that leave soon
                0.7 * later_blocks  +   # cars leaving later sitting on the path
                0.02 * dist         +   # push cars deeper
                0.01 * dest_penalty     # destination spread
            )
            score += random.uniform(0.0, 0.001)   # tiny tie-breaker

            cand = (score, spot, path)
            best = min(best, cand) if best else cand

        return (-1, []) if best is None else (best[1], best[2])

    def _find_best_internal_spot(self, car_id_to_move: int) -> tuple[int, list]:
        """Choose a free spot the car can be moved to directly on deck.

        Only spots whose path from the car's node has no occupied intermediate
        node qualify; among those the nearest wins (ties: lowest spot id).
        NOTE(review): an earlier "unblock gain" term was always 0 because
        intermediate nodes are required to be empty, so it is omitted.

        Returns:
            ``(spot, path)`` or ``(-1, [])`` if the car is not on deck or no
            spot qualifies.
        """
        start = self.car_locations.get(car_id_to_move)
        if start is None:
            return -1, []

        best_key, best_path = None, None
        for spot in range(1, self.num_nodes):
            if self.node_status[spot] != -1:
                continue
            path = self._get_or_compute_path(start, spot)
            if not path:
                continue
            if any(self.node_status[n] != -1 for n in path[1:-1]):
                continue
            key = (len(path) - 1, spot)
            if best_key is None or key < best_key:
                best_key, best_path = key, path

        if best_key is None:
            return -1, []
        return best_key[1], best_path

    def step(
        self,
        action: tuple[str, int]
    ) -> tuple[Data, float, float, bool]:
        """Apply one worker action to the ship.

        Args:
            action: ``(action_type, car_id)`` with ``action_type`` one of
                ``'LOAD'``, ``'UNLOAD'``, ``'RELOCATE_INTERNAL'``.
                ``car_id == -1`` is a dummy id that skips the repeat
                penalty. Other action types change nothing on deck (port
                changes are handled by the caller).

        Returns:
            ``(state, reward, cost, done)``: the new PyG graph, the scalar
            worker reward, the raw move cost (F + path length per move), and
            whether ``current_port >= num_ports``. Invalid or failed LOAD /
            UNLOAD actions return early with zero cost and ``done=False``.
        """
        act_type, cid = action
        cost = 0.0
        event_rew = 0.0

        # 0) Repeat / ping-pong penalties (skipped for the dummy id -1).
        if cid != -1:
            prev = self.last_car_action.get(cid)
            if prev == act_type:
                event_rew -= 0.5   # same action type on the same car twice in a row
            elif prev == 'UNLOAD' and act_type == 'LOAD':
                event_rew -= 5.0   # unloaded and immediately reloaded
            self.last_car_action[cid] = act_type

        # 1) LOAD
        if act_type == 'LOAD':
            if cid in self.cars_on_board or cid in self.delivered_cars:
                # Invalid: an on-board car would get a second spot (leaving a
                # stale id in node_status); a delivered car is finished.
                return self._get_state(), -1.0, 0.0, False

            tgt, path = self._find_best_spot(cid)
            if tgt == -1:
                # No spot: the penalty grows with consecutive failed loads.
                self.load_fail_streak += 1
                dyn_pen = -1.0 * (1 + 0.3 * self.load_fail_streak)
                return self._get_state(), dyn_pen, 0.0, False
            self.load_fail_streak = 0

            cost += self._calculate_path_cost(path)
            self.node_status[tgt] = cid
            self.car_locations[cid] = tgt
            self.cars_on_board.add(cid)

            if cid in self.temporarily_unloaded_cars:
                self.temporarily_unloaded_cars.remove(cid)
            else:
                event_rew += 0.1   # small bonus for a first-time load

        # 2) UNLOAD
        elif act_type == 'UNLOAD':
            start = self.car_locations.get(cid)
            if start is None:      # the car is not on deck
                return self._get_state(), -1.0, 0.0, False

            path = self._get_or_compute_path(start, 0)
            cost += self._calculate_path_cost(path)

            self.node_status[start] = -1
            self.cars_on_board.remove(cid)
            self.car_locations.pop(cid, None)

            if self.cars[cid]['dest'] == self.current_port:   # delivered
                self.delivered_cars.add(cid)
                event_rew += 1.0
            else:                                             # temporary unload
                self.temporarily_unloaded_cars.add(cid)
                self.relocations_this_episode += 1
                event_rew -= 0.1

        # 3) RELOCATE_INTERNAL
        elif act_type == 'RELOCATE_INTERNAL':
            tgt, path = self._find_best_internal_spot(cid)
            if tgt == -1:
                event_rew -= 0.2
            else:
                cost += self._calculate_path_cost(path)
                src = self.car_locations.get(cid)
                if src is not None:
                    self.node_status[src] = -1
                self.node_status[tgt] = cid
                self.car_locations[cid] = tgt
                self.relocations_this_episode += 1
                event_rew -= 0.02

        # 4) Only 0.1 x (cost / F) of the move cost is charged to the reward.
        if self.fixed_cost:
            reward = event_rew - (cost / self.fixed_cost) * 0.1
        else:
            reward = event_rew
        done = self.current_port >= self.num_ports

        return self._get_state(), reward, cost, done


# ==============================================================================
# 섹션 3: 계층적 환경 Wrapper (리팩토링 버전)
# ==============================================================================
import torch
import torch.nn as nn
from torch_geometric.data import Batch

class HierarchicalEnvWrapper:
    """Two-level (Manager / Worker) wrapper around ``ShipEnv``.

    The Manager picks one of five goals (``manager_action_map``). For every
    goal except ``PROCEED_TO_NEXT_PORT`` the Worker agent acts on the ship
    until the goal is met, a stop condition fires, or
    ``WORKER_MAX_STEPS_PER_GOAL`` env steps have been taken.
    """

    def __init__(self, problem_data: dict, max_num_ports: int,
                 worker_agent, config):
        self.problem_data     = problem_data
        self.max_num_ports    = max_num_ports
        self.worker_agent     = worker_agent
        self.config           = config
        self.last_manager_action = None

        self.ship_env = ShipEnv(problem_data, max_num_ports)

        # Manager steps spent at the current port.
        self.steps_on_port = 0

        self.manager_action_map = {
            0: 'CLEAR_BLOCKERS',
            1: 'FINISH_UNLOAD',
            2: 'FINISH_LOAD',
            3: 'CLEAR_TEMP',
            4: 'PROCEED_TO_NEXT_PORT'
        }
        self.goal_embedding = nn.Embedding(
            self.config.MANAGER_ACTION_DIM,
            self.config.GOAL_EMBED_DIM
        ).to(self.worker_agent.device)

    # ------------------------------------------------------------------ counts
    def _num_waiting_to_load(self) -> int:
        """Cars originating at the current port that are neither on board nor temporarily unloaded."""
        s = self.ship_env
        return sum(
            1 for c in s.cars
            if c['origin'] == s.current_port
            and c['id'] not in s.cars_on_board
            and c['id'] not in s.temporarily_unloaded_cars
        )

    def _num_due_to_unload(self) -> int:
        """On-board cars whose destination is the current port."""
        s = self.ship_env
        return sum(1 for cid in s.cars_on_board if s.cars[cid]['dest'] == s.current_port)

    def _any_path_to_gate_blocked(self) -> bool:
        """True if some on-board car has another car between it and the gate."""
        s = self.ship_env
        for cid in s.cars_on_board:
            node_idx = s.car_locations.get(cid)
            if node_idx is None:
                continue
            path = s._get_or_compute_path(0, node_idx)
            if not path:
                continue
            # path = [0, ..., node_idx]: exclude the car's own node, otherwise
            # every on-board car would count as blocking itself.
            if any(s.node_status[n] != -1 for n in path[1:-1]):
                return True
        return False

    # -------------------------------------------------------------- potential
    def _calculate_potential(self) -> float:
        """Potential Phi(s) for potential-based reward shaping.

        Minus a weighted count of remaining work at this port: cars waiting
        to load (x1.0), cars due to unload (x1.5, weighted higher on purpose)
        and temporarily unloaded cars (x1.0).
        """
        waiting_to_load = self._num_waiting_to_load()
        due_to_unload = self._num_due_to_unload()
        temp_unloaded = len(self.ship_env.temporarily_unloaded_cars)
        return - (1.0 * waiting_to_load + 1.5 * due_to_unload + 1.0 * temp_unloaded)

    def reset(self, prob_data=None):
        """Reset (or replace with ``prob_data``) the ship and return the Manager state."""
        if prob_data is not None:
            self.problem_data = prob_data
            self.ship_env     = ShipEnv(self.problem_data, self.max_num_ports)
        else:
            self.ship_env.reset()

        self.steps_on_port       = 0
        self.last_manager_action = None
        return self._get_manager_state()

    def _get_manager_state(self):
        """Return the 6-d Manager state tensor.

        ``[port / P, free-slot fraction, waiting / cars, due / cars,
        temp-unloaded / cars, mean(dest - port) of on-board cars / P]``.
        Denominators are clamped to >= 1 so a problem without cars does not
        divide by zero.
        """
        s = self.ship_env
        total_slots = s.num_nodes - 1 if s.num_nodes > 1 else 1
        total_cars = max(s.total_cars, 1)
        num_ports = max(s.num_ports, 1)

        port_norm = s.current_port / num_ports
        free_slots_norm = sum(1 for n in s.node_status[1:] if n == -1) / total_slots
        on_board_dests = [s.cars[cid]['dest'] for cid in s.cars_on_board]
        avg_dest_dist = (
            sum(d - s.current_port for d in on_board_dests) / len(on_board_dests)
            if on_board_dests else 0.0
        )

        return torch.tensor([
            port_norm,
            free_slots_norm,
            self._num_waiting_to_load() / total_cars,
            self._num_due_to_unload() / total_cars,
            len(s.temporarily_unloaded_cars) / total_cars,
            avg_dest_dist / num_ports
        ], dtype=torch.float)

    def _is_goal_achieved(self, goal_str):
        """Whether Manager goal ``goal_str`` is satisfied in the current state.

        CLEAR_BLOCKERS requires both "nothing left to unload here" and "no
        on-board car has another car on its path to the gate". Unknown goals
        (including PROCEED_TO_NEXT_PORT) return False.
        """
        if goal_str == 'FINISH_UNLOAD':
            return self._num_due_to_unload() == 0
        if goal_str == 'CLEAR_BLOCKERS':
            return self._num_due_to_unload() == 0 and not self._any_path_to_gate_blocked()
        if goal_str == 'FINISH_LOAD':
            return self._num_waiting_to_load() == 0
        if goal_str == 'CLEAR_TEMP':
            return len(self.ship_env.temporarily_unloaded_cars) == 0
        return False

    # ------------------------------------------------------------------ worker
    def _run_worker(self, goal_str, goal_embed, greedy_worker):
        """Roll out the Worker for one goal and, when training, run its PPO update.

        The rollout stops when the goal is met, no legal action exists, the
        episode ends, the same (action, car) pair repeats STREAK_LIMIT times,
        or (FINISH_LOAD only) NO_PROGRESS_LIMIT steps pass without a new car
        getting on board.

        Returns:
            ``(worker_steps, total_cost, total_worker_reward, no_progress_trigger)``
            where ``worker_steps`` counts env steps actually executed.
        """
        STREAK_LIMIT = 15   # max consecutive identical (action_type, car_id)
        env = self.ship_env

        worker_storage = None
        if not greedy_worker:
            worker_storage = PPOStorage(
                self.config.WORKER_NUM_STEPS_PER_UPDATE,
                (2,),                         # [action_type, car_id]
                self.worker_agent.device
            )

        worker_steps = 0
        total_cost = 0.0
        total_worker_reward = 0.0
        no_progress = 0
        no_progress_trigger = False
        same_action_streak = 0
        last_worker_key = None

        batch_graph = Batch.from_data_list([env._get_state()]).to(self.worker_agent.device)

        for _ in range(self.config.WORKER_MAX_STEPS_PER_GOAL):
            if same_action_streak >= STREAK_LIMIT:
                logging.info(f"[WRK] break — same action repeated {STREAK_LIMIT} times")
                no_progress_trigger = True
                break
            if goal_str == 'FINISH_LOAD' and no_progress >= self.config.NO_PROGRESS_LIMIT:
                no_progress_trigger = True
                break
            if self._is_goal_achieved(goal_str):
                break

            legal_actions = env.get_legal_actions(for_worker=True)
            if not legal_actions:
                break

            # The deck graph is static, so only node/global features are refreshed.
            tmp_state = env._get_state()
            batch_graph.x = tmp_state.x.to(batch_graph.x.device)
            batch_graph.global_features = tmp_state.global_features.to(batch_graph.global_features.device)

            action, at, logp, _ent, val = self.worker_agent.get_action_and_value(
                batch_graph, legal_actions, goal_embed, greedy=greedy_worker
            )
            action_type, car_id = action

            worker_key = (action_type, car_id)
            same_action_streak = same_action_streak + 1 if worker_key == last_worker_key else 1
            last_worker_key = worker_key

            # Flags captured before the env step (for intrinsic reward / progress).
            has_car = car_id != -1
            is_temp_before = has_car and car_id in env.temporarily_unloaded_cars
            is_due_before = has_car and env.cars[car_id]['dest'] == env.current_port
            was_on_board = has_car and car_id in env.cars_on_board

            _, worker_reward, move_cost, overall_done = env.step(action)
            worker_steps += 1
            total_cost += move_cost

            intrinsic_reward = 0.0
            if goal_str == 'FINISH_UNLOAD' and action_type == 'UNLOAD' and is_due_before:
                intrinsic_reward += 0.8
            elif goal_str == 'CLEAR_TEMP' and action_type == 'LOAD' and is_temp_before:
                intrinsic_reward += 0.4
            worker_reward += intrinsic_reward
            total_worker_reward += worker_reward

            # Progress = the car actually got on board. The reward sign cannot be
            # used: a successful LOAD nets 0.1 - 0.1 * (F + d) / F <= 0.
            load_succeeded = (
                action_type == 'LOAD' and not was_on_board and car_id in env.cars_on_board
            )
            if goal_str == 'FINISH_LOAD' and load_succeeded:
                no_progress = 0
            else:
                no_progress += 1

            if worker_storage is not None:
                worker_storage.add(tmp_state, at, logp, worker_reward, overall_done, val)

            if overall_done:
                break

        if worker_storage is not None and worker_storage.step > 0:
            self.worker_agent.update(worker_storage, goal_embed.detach())

        return worker_steps, total_cost, total_worker_reward, no_progress_trigger

    # ----------------------------------------------------------------- manager
    def step(self, manager_action_idx: int, *, greedy_worker: bool = False):
        """Execute one Manager action and return ``(next_state, reward, done, info)``.

        The Manager reward adds up: a timeout penalty (doubled for
        CLEAR_BLOCKERS), potential-based shaping, the port-move bonus/penalty,
        a port-stay limit penalty, -0.001 x move cost, -STEP_PENALTY_WEIGHT per
        worker step, REPEAT_PENALTY for repeating the previous action,
        NO_PROGRESS_PENALTY when the Worker was stopped early, and the
        end-of-episode bonus/penalty.
        """
        MAX_STEP_PER_PORT = 4000

        goal_str = self.manager_action_map[manager_action_idx]
        goal_embed = self.goal_embedding(
            torch.tensor([manager_action_idx], device=self.worker_agent.device)
        )
        potential_before = self._calculate_potential()
        self.steps_on_port += 1

        worker_steps = 0
        total_cost = 0.0
        total_worker_reward = 0.0
        no_progress_trigger = False
        if goal_str != 'PROCEED_TO_NEXT_PORT':
            (worker_steps, total_cost,
             total_worker_reward, no_progress_trigger) = self._run_worker(
                goal_str, goal_embed, greedy_worker)

        goal_done = self._is_goal_achieved(goal_str)
        event_reward = 0.0

        # The Worker used its whole step budget without reaching the goal.
        if worker_steps >= self.config.WORKER_MAX_STEPS_PER_GOAL and not goal_done:
            penalty = self.config.TIMEOUT_PENALTY
            if goal_str == 'CLEAR_BLOCKERS':
                penalty *= 2
            event_reward += penalty

        if goal_str == 'PROCEED_TO_NEXT_PORT':
            # NOTE(review): leaving is allowed while cars due at this port are
            # still on board; since ports only increase, those cars can no
            # longer be delivered and only count in the end-of-episode penalty.
            s = self.ship_env
            can_go = self._num_waiting_to_load() == 0 and not s.temporarily_unloaded_cars
            if can_go:
                s.current_port += 1
                self.steps_on_port = 0
                event_reward += 5.0
                goal_done = True
            else:
                event_reward -= 10.0
                goal_done = False

        # Shaping F = gamma * Phi(s') - Phi(s). Phi(s') is measured *after* a
        # possible port move: it must be the same state the next step's
        # potential_before is taken on, or the terms do not telescope.
        potential_after = self._calculate_potential()
        event_reward += self.config.SHAPING_REWARD_WEIGHT * (
            self.config.MANAGER_GAMMA * potential_after - potential_before
        )

        # Port-stay limit: penalty only (no forced move); the counter is capped
        # so the penalty repeats on every further step at this port.
        if self.steps_on_port > MAX_STEP_PER_PORT:
            event_reward += self.config.TIMEOUT_PENALTY * 5
            logging.info("[MGR] Port-stay limit exceeded — penalty applied")
            self.steps_on_port = MAX_STEP_PER_PORT

        manager_reward = event_reward - (total_cost * 0.001)
        manager_reward -= self.config.STEP_PENALTY_WEIGHT * worker_steps
        if self.last_manager_action == manager_action_idx:
            manager_reward += self.config.REPEAT_PENALTY
        if no_progress_trigger:
            manager_reward += self.config.NO_PROGRESS_PENALTY
        self.last_manager_action = manager_action_idx

        done = self.ship_env.current_port >= self.ship_env.num_ports
        if done:
            delivered = len(self.ship_env.delivered_cars)
            if delivered == self.ship_env.total_cars:
                manager_reward += 1000.0
            else:
                manager_reward -= (self.ship_env.total_cars - delivered) * 10.0

        next_state = self._get_manager_state()
        info = {
            'steps'              : worker_steps,
            'goal'               : goal_str,
            'success'            : goal_done,
            'cost'               : total_cost,
            'worker_total_reward': total_worker_reward,
        }
        return next_state, manager_reward, done, info




# ==============================================================================
# 섹션 4: 에이전트 및 네트워크 (리팩토링 버전)
# ==============================================================================
import torch.nn.functional as F
from torch.distributions import Categorical
from torch_geometric.data import Batch
from torch_geometric.nn import GlobalAttention, global_mean_pool

# --- PPO 알고리즘을 위한 경험 저장소 ---

class PPOStorage:
    """Fixed-capacity rollout buffer for PPO.

    Holds up to ``num_steps`` transitions (observation, action, log-prob,
    reward, done flag, value estimate and, optionally, the logit mask that was
    applied when the action was sampled) and computes GAE advantages / returns
    over the transitions that were actually recorded (``self.step`` of them).
    """

    def __init__(self, num_steps: int, action_shape: tuple, device: torch.device, state_shape: tuple = None, manager_action_dim: int = None):
        """
        Args:
            num_steps (int): Buffer capacity (maximum number of transitions).
            action_shape (tuple | int): Shape of a single action; an int ``k`` is
                treated as ``(k,)``.
            device (torch.device): Device on which the tensor buffers live.
            state_shape (tuple, optional): Shape of a single observation. When
                ``None``, observations are treated as graph objects and stored in
                a Python list instead of a tensor.
            manager_action_dim (int, optional): When given (non-zero), allocate a
                ``(num_steps, manager_action_dim)`` buffer for per-step logit masks.
        """
        self.num_steps = num_steps
        self.device = device
        self.step = 0
        self.is_graph_data = (state_shape is None)

        # Observation buffer: list for graph objects, dense tensor otherwise.
        if self.is_graph_data:
            self.obs = [None] * num_steps
        else:
            self.obs = torch.zeros((num_steps,) + state_shape, device=device)

        if isinstance(action_shape, int):
            action_shape = (action_shape,)
        self.actions = torch.zeros((num_steps,) + action_shape, device=device, dtype=torch.long)
        self.logprobs = torch.zeros(num_steps, device=device)
        self.rewards = torch.zeros(num_steps, device=device)
        self.dones = torch.zeros(num_steps, device=device)
        self.values = torch.zeros(num_steps, device=device)

        # Logit masks applied when each Manager action was sampled, so the PPO
        # update can rebuild the same masked distribution.
        if manager_action_dim:
            self.masks = torch.zeros((num_steps, manager_action_dim), device=device)
        else:
            self.masks = None

    def reset(self):
        """Rewind the write cursor; old contents are overwritten by later ``add`` calls."""
        self.step = 0

    def add(self, obs, action, logprob, reward, done, value, mask=None):
        """Store one transition at the current cursor position.

        Calls made after the buffer is full are ignored. Graph observations
        are moved to CPU (to save GPU memory); tensor observations are copied
        into the preallocated buffer. ``value`` and ``mask`` are detached.
        """
        if self.step >= self.num_steps:
            return

        if self.is_graph_data:
            self.obs[self.step] = obs.cpu()  # keep graph observations off the GPU
        else:
            self.obs[self.step].copy_(torch.as_tensor(obs, device=self.device))

        self.actions[self.step] = action
        self.logprobs[self.step] = logprob
        self.rewards[self.step] = float(reward)
        self.dones[self.step] = float(done)
        self.values[self.step] = value.detach()

        if self.masks is not None and mask is not None:
            self.masks[self.step] = mask.detach()

        self.step += 1

    def is_full(self) -> bool:
        """True once ``num_steps`` transitions have been stored."""
        return self.step >= self.num_steps

    def compute_returns_and_advantages(self, last_value: torch.Tensor, gamma: float, gae_lambda: float):
        """Compute GAE advantages and returns over the recorded transitions.

        Only the first ``self.step`` slots are used: the bootstrap ``last_value``
        is applied after the last *recorded* transition, and entries past it are
        left at zero advantage so stale data from a previous rollout never
        leaks in. Advantages of the recorded slice are normalized to zero mean /
        unit std when more than one sample exists (a single sample has no
        defined std and is left unnormalized).

        Args:
            last_value: bootstrap value after the last recorded transition; a
                0-dim or single-element tensor (e.g. shape ``[1]``).
            gamma: discount factor.
            gae_lambda: GAE lambda.

        Sets ``self.returns`` and ``self.advantages`` (both length ``num_steps``).
        """
        n = self.step
        advantages = torch.zeros_like(self.rewards)
        if n == 0:
            self.returns = self.values.clone()
            self.advantages = advantages
            return

        # Accept both 0-dim and [1]-shaped bootstrap values.
        last_value = torch.as_tensor(
            last_value, dtype=self.values.dtype, device=self.values.device
        ).reshape(-1)[0]

        last_gae = 0.0
        for t in reversed(range(n)):
            # dones[t] marks the transition at t as terminal: no bootstrapping past it.
            next_non_terminal = 1.0 - self.dones[t]
            next_value = last_value if t == n - 1 else self.values[t + 1]

            delta = self.rewards[t] + gamma * next_value * next_non_terminal - self.values[t]
            last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
            advantages[t] = last_gae

        # Returns use the raw (un-normalized) advantages.
        self.returns = advantages + self.values

        # Normalize only the recorded slice, for training stability.
        if n > 1:
            filled = advantages[:n]
            advantages[:n] = (filled - filled.mean()) / (filled.std() + 1e-8)
        self.advantages = advantages



# --- 신경망 모델 정의 ---
# WorkerNetwork 클래스 전체
class WorkerNetwork(nn.Module):
    """Policy and value network of the Worker agent.

    Pipeline: per-node linear projection -> learned positional embedding of
    each node's index within its own graph -> Transformer encoder over each
    graph's nodes (padding masked out) -> mean pooling + attention pooling ->
    concatenation with global features and the goal embedding -> shared MLP ->
    action-type / per-type car-id actor heads and a critic head.
    """
    def __init__(self, node_feature_size: int, global_feature_size: int, max_cars: int, num_nodes: int, config: Config):
        super().__init__()
        self.config = config
        embed_dim = self.config.GNN_EMBED_DIM
        self.node_input_proj = nn.Linear(node_feature_size, embed_dim)
        # One embedding per in-graph node position; graphs larger than
        # num_nodes would index out of range.
        self.positional_encoding = nn.Embedding(num_nodes, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=embed_dim*4, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
        self.att_pool = GlobalAttention(gate_nn=nn.Sequential(nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Linear(embed_dim // 2, 1)))

        mlp_input_dim = embed_dim * 2 + global_feature_size + self.config.GOAL_EMBED_DIM
        self.mlp = nn.Sequential(nn.Linear(mlp_input_dim, 512), nn.GELU(), nn.Linear(512, 256), nn.GELU())
        self.actor_type_head = nn.Linear(256, 3)
        self.actor_load_head = nn.Linear(256, max_cars)
        self.actor_unload_head = nn.Linear(256, max_cars)
        self.actor_relocate_head = nn.Linear(256, max_cars)
        self.critic_head = nn.Linear(256, 1)

    def forward(self, data, goal_embedding):
        """Run the network on a batch of graphs.

        Args:
            data: batched graph exposing ``x`` (node features),
                ``global_features`` and ``batch`` (node -> graph index; assumed
                sorted by graph, as in a torch_geometric ``Batch``).
            goal_embedding: goal embedding; when its first dimension differs
                from the number of graphs it is ``expand``-ed (so it must then
                have a single row). Its device is used for all computation.

        Returns:
            ``(type_logits [B, 3], load_logits [B, max_cars],
            unload_logits [B, max_cars], relocate_logits [B, max_cars],
            value [B], graph_emb [B, 2 * GNN_EMBED_DIM])``.
        """
        device = goal_embedding.device
        x = data.x.to(device)
        global_feats = data.global_features.to(device)
        batch_index = data.batch.to(device)

        node_embeddings = self.node_input_proj(x)

        # Index of every node inside its own graph (0 .. n_g - 1), vectorized
        # instead of building one arange per graph in Python.
        num_graphs = int(batch_index.max().item()) + 1
        nodes_per_graph = torch.bincount(batch_index, minlength=num_graphs)
        graph_start = torch.cumsum(nodes_per_graph, dim=0) - nodes_per_graph
        positions = torch.arange(x.size(0), device=device) - graph_start[batch_index]
        node_embeddings = node_embeddings + self.positional_encoding(positions)

        # Dense [B, max_nodes, D] layout, zero-padded; `valid` is True for real
        # nodes, so the Transformer's padding mask is its negation.
        dense_x, valid = to_dense_batch(node_embeddings, batch_index, batch_size=num_graphs)
        transformer_out = self.transformer_encoder(dense_x, src_key_padding_mask=~valid)
        # Back to the flat node layout, in the same node order as batch_index.
        transformer_out_flat = transformer_out[valid]

        graph_emb_mean = global_mean_pool(transformer_out_flat, batch_index)
        graph_emb_att = self.att_pool(transformer_out_flat, batch_index)
        graph_emb = torch.cat([graph_emb_mean, graph_emb_att], dim=1)

        if goal_embedding.shape[0] != graph_emb.shape[0]:
            goal_embedding = goal_embedding.expand(graph_emb.shape[0], -1)

        combined_features = torch.cat([graph_emb, global_feats, goal_embedding], dim=1)
        final_features = self.mlp(combined_features)

        # graph_emb is returned too so callers can reuse the pooled graph embedding.
        return (self.actor_type_head(final_features), self.actor_load_head(final_features),
                self.actor_unload_head(final_features), self.actor_relocate_head(final_features),
                self.critic_head(final_features).squeeze(-1), graph_emb)



class ManagerNetwork(nn.Module):
    """Actor-critic network for the Manager.

    The state vector is concatenated with an embedding of the previous
    Manager action, projected to a 128-d token, passed as a length-1 sequence
    through a Transformer encoder, and fed to a policy head and a value head.
    """

    def __init__(self, state_dim: int, action_dim: int, num_layers: int = 2, nhead: int = 4):
        super().__init__()
        hidden_dim = 128
        prev_action_dim = 32

        self.state_dim = state_dim
        self.action_dim = action_dim
        # Submodules are created in this exact order so parameter initialisation
        # and state_dict keys stay stable.
        self.action_embed = nn.Embedding(action_dim, prev_action_dim)
        self.input_proj = nn.Linear(state_dim + prev_action_dim, hidden_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=nhead,
                dim_feedforward=2 * hidden_dim,
                batch_first=True,
                activation='gelu',
            ),
            num_layers=num_layers,
        )
        self.actor_head = nn.Linear(hidden_dim, action_dim)
        self.critic_head = nn.Linear(hidden_dim, 1)

    def forward(self, state: torch.Tensor, prev_action_idx: torch.LongTensor):
        """Compute action logits and a state value.

        Args:
            state: ``[B, state_dim]`` state batch.
            prev_action_idx: ``[B]`` indices of the previous Manager actions.

        Returns:
            ``(logits [B, action_dim], value [B])``.
        """
        joined = torch.cat((state, self.action_embed(prev_action_idx)), dim=-1)
        token_seq = self.input_proj(joined)[:, None, :]      # [B, 1, 128]
        summary = self.encoder(token_seq)[:, 0, :]           # [B, 128]
        return self.actor_head(summary), self.critic_head(summary).squeeze(-1)

# --- 에이전트 클래스 정의 ---

class ManagerAgent:
    """High-level PPO agent that chooses the next goal for the Worker."""
    def __init__(self, config: Config, device: torch.device, env_wrapper: HierarchicalEnvWrapper):
        self.config = config
        self.device = device
        # The live environment wrapper is consulted when building action masks.
        self.env_wrapper = env_wrapper
        self.action_map = env_wrapper.manager_action_map
        # manager_action_map is iterated as (index, goal) pairs; invert it for goal -> index.
        self.type_to_idx = {v: k for k, v in self.env_wrapper.manager_action_map.items()}
        self.idx_to_type = {v: k for k, v in self.type_to_idx.items()}

        self.network = ManagerNetwork(config.MANAGER_STATE_DIM, config.MANAGER_ACTION_DIM).to(self.device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=config.MANAGER_LR, eps=1e-5)


    def _build_action_mask(self) -> torch.Tensor:
        """Return a ``[1, MANAGER_ACTION_DIM]`` additive logit mask for the current state.

        An entry is ``-1e9`` (softmax probability ~0) when
        ``env_wrapper._is_goal_achieved`` reports that goal as already achieved,
        and for ``'PROCEED_TO_NEXT_PORT'`` while cars whose origin is the current
        port are neither on board nor temporarily unloaded, or while any
        temporarily unloaded cars exist. All other entries are ``0``.
        """
        mask = torch.zeros(1, self.config.MANAGER_ACTION_DIM, device=self.device)
        s_env = self.env_wrapper.ship_env

        # 1. Mask goals that are already achieved.
        for idx, goal in self.env_wrapper.manager_action_map.items():
            if self.env_wrapper._is_goal_achieved(goal):
                mask[0, idx] = -1e9

        # 2. Leaving the port is not allowed while cars originating here still
        #    wait to be loaded, or while temporarily unloaded cars exist.
        waiting_cars = any(c['origin'] == s_env.current_port and c['id'] not in s_env.cars_on_board and c['id'] not in s_env.temporarily_unloaded_cars for c in s_env.cars)
        has_temp_unloaded = bool(s_env.temporarily_unloaded_cars)

        if waiting_cars or has_temp_unloaded:
            proceed_idx = self.type_to_idx['PROCEED_TO_NEXT_PORT']
            mask[0, proceed_idx] = -1e9

        return mask

    def get_action_and_value(
        self,
        state: torch.Tensor,
        legal_actions: list[tuple[str,int]],
        prev_action_idx: torch.LongTensor,
        greedy: bool = False
    ) -> tuple:
        """Choose a Manager goal for ``state``.

        The raw logits receive two additive masks: the state-dependent mask from
        ``_build_action_mask`` and a legal-action mask that keeps only goal types
        present in ``legal_actions``.

        Args:
            state: ``[state_dim]`` state tensor.
            legal_actions: list of ``(goal_type_str, arg)`` pairs, e.g.
                ``[("FINISH_LOAD", -1), ...]``; only the type strings are used.
            prev_action_idx: ``[1]``-shaped LongTensor with the previous Manager action.
            greedy: take the argmax instead of sampling.

        Returns:
            greedy=False: ``(action_tensor, logp, ent, value, mask)``
                - action_tensor: ``[1]``-shaped LongTensor with the sampled index.
                - logp / ent: log-probability of the sample and distribution entropy.
                - value: scalar state-value estimate.
                - mask: ``[action_dim]`` combined additive mask (state mask +
                  legal-action mask), exactly what was added to the logits, so the
                  PPO update can reproduce the sampling distribution.
            greedy=True: ``(action_idx: int, None, None, None, None)``.

        The network is switched to eval mode for the call and back to train mode
        on every exit path.
        """
        self.network.eval()
        try:
            with torch.no_grad():
                batch_state = state.unsqueeze(0).to(self.device)  # [1, state_dim]
                state_mask = self._build_action_mask()[0]        # [action_dim]

                logits, value = self.network(batch_state, prev_action_idx)

                # Keep only goal types that appear in legal_actions.
                allowed = {act for act, _ in legal_actions}
                legal_mask = torch.full_like(state_mask, -1e9)
                for idx, t_str in self.idx_to_type.items():
                    if t_str in allowed:
                        legal_mask[idx] = 0.0

                # Store the combined mask: the update rebuilds logits + mask, which
                # must equal the logits sampled from here.
                mask = state_mask + legal_mask
                masked_logits = logits[0] + mask

                if greedy:
                    return int(torch.argmax(masked_logits).item()), None, None, None, None

                dist = Categorical(logits=masked_logits)
                type_idx = dist.sample()  # 0-dim tensor
                logp = dist.log_prob(type_idx)
                ent = dist.entropy()
                action_tensor = type_idx.unsqueeze(0)  # [1] for PPOStorage
        finally:
            self.network.train()

        return action_tensor, logp, ent, value[0], mask


    def update(self, storage: PPOStorage) -> dict:
        """Run PPO epochs over the Manager's stored transitions.

        Only the first ``storage.step`` transitions are used. Stored masks are
        re-applied to the logits so log-probabilities are recomputed under the
        same masked distribution used during the rollout.

        Returns:
            Losses of the last minibatch (``policy_loss``, ``value_loss``,
            ``entropy``), or ``{}`` when the storage is empty.
        """
        batch_size = storage.step
        if batch_size == 0:
            return {}

        # Bootstrap value from the last recorded observation and its previous action.
        # NOTE(review): this is the value of the last *recorded* state, not of the
        # state reached after its action; confirm this is the intended bootstrap.
        with torch.no_grad():
            last_obs = storage.obs[batch_size - 1].unsqueeze(0).to(self.device)
            if batch_size > 1:
                last_prev_action = storage.actions[batch_size - 2].to(self.device)
            else:
                last_prev_action = torch.zeros(1, dtype=torch.long, device=self.device)
            _, last_value = self.network(last_obs, last_prev_action)

        storage.compute_returns_and_advantages(last_value, self.config.MANAGER_GAMMA, self.config.PPO_GAE_LAMBDA)

        b_obs = torch.stack(list(storage.obs[:batch_size])).to(self.device)
        b_actions = storage.actions[:batch_size].squeeze(-1).to(self.device)
        b_logprobs = storage.logprobs[:batch_size].to(self.device)
        b_returns = storage.returns[:batch_size].to(self.device)
        b_advantages = storage.advantages[:batch_size].to(self.device)
        if storage.masks is not None:
            b_masks = storage.masks[:batch_size].to(self.device)
        else:
            b_masks = torch.zeros(batch_size, self.config.MANAGER_ACTION_DIM, device=self.device)

        # Previous action for step i is the action stored at i - 1 (0 for the first step).
        # NOTE(review): this does not reset at episode boundaries (dones); confirm it
        # matches how prev_action_idx is fed during rollouts.
        b_prev_actions = torch.cat(
            (torch.zeros(1, dtype=torch.long, device=self.device), b_actions[:-1]),
            dim=0
        )

        minibatch_size = max(1, batch_size // self.config.PPO_NUM_MINIBATCHES)

        for _ in range(self.config.PPO_UPDATE_EPOCHS):
            perm_indices = np.random.permutation(batch_size)
            for start in range(0, batch_size, minibatch_size):
                mb_idx = perm_indices[start:start + minibatch_size]

                mb_states = b_obs[mb_idx]
                mb_actions = b_actions[mb_idx]
                mb_logprobs = b_logprobs[mb_idx]
                mb_returns = b_returns[mb_idx]
                mb_advantages = b_advantages[mb_idx]
                mb_masks = b_masks[mb_idx]
                mb_prev_actions = b_prev_actions[mb_idx]

                logits, new_values = self.network(mb_states, mb_prev_actions)

                # Re-apply the stored masks to reproduce the rollout distribution.
                dist = Categorical(logits=logits + mb_masks)
                new_logprobs = dist.log_prob(mb_actions)
                entropy = dist.entropy()

                ratio = (new_logprobs - mb_logprobs).exp()
                surr1 = mb_advantages * ratio
                surr2 = mb_advantages * torch.clamp(ratio, 1 - self.config.PPO_CLIP_COEF, 1 + self.config.PPO_CLIP_COEF)

                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = F.mse_loss(new_values, mb_returns)

                loss = (policy_loss +
                        self.config.PPO_VALUE_COEF * value_loss -
                        self.config.MANAGER_ENTROPY_COEF * entropy.mean())

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.config.PPO_MAX_GRAD_NORM)
                self.optimizer.step()

        return {
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss.item(),
            "entropy": entropy.mean().item()
        }

class WorkerAgent:
    """Low-level PPO agent that chooses a concrete (action type, car id) for the current goal."""
    def __init__(
        self,
        node_feature_size: int,
        global_feature_size: int,
        max_cars: int,
        max_nodes: int,
        config: Config,
        device: torch.device
    ):
        self.config = config
        self.device = device

        self.network = WorkerNetwork(
            node_feature_size,
            global_feature_size,
            max_cars,
            max_nodes,
            config
        ).to(self.device)

        self.optimizer = optim.Adam(
            self.network.parameters(),
            lr=config.WORKER_LR,
            eps=1e-5
        )

        # Action type <-> index mapping (matches the 3-way type head).
        self.type_to_idx = {'LOAD': 0, 'UNLOAD': 1, 'RELOCATE_INTERNAL': 2}
        self.idx_to_type = {v: k for k, v in self.type_to_idx.items()}

    @staticmethod
    def _keep_only(logits: torch.Tensor, indices) -> torch.Tensor:
        """Additive mask shaped like ``logits``: 0 at ``indices``, -1e9 elsewhere."""
        mask = torch.full_like(logits, -1e9)
        for idx in indices:
            mask[idx] = 0.0
        return mask

    def get_action_and_value(
                              self,
                              batch_graph: Batch,           # pre-batched graph, not a single Data
                              legal_actions: list,
                              goal_embedding: torch.Tensor,
                              greedy: bool = False
                            ) -> tuple:
        """Pick a Worker action for the current state and goal.

        Sampling is hierarchical: first the action type, then the car id from
        the head of the chosen type. Both levels are restricted to the entries
        of ``legal_actions`` (``(type_str, car_id)`` pairs) with ``-1e9``
        additive logit masks. Only the first graph of ``batch_graph`` is used.

        Returns:
            greedy=True:  ``(action_tuple, None, None, None, value)``
            greedy=False: ``(action_tuple, action_tensor, log_prob, entropy, value)``
            where ``action_tuple = (type_str, car_id)`` and ``action_tensor`` is
            ``[type_idx, car_id]``; log_prob / entropy sum the type and car levels.

        The network runs in eval mode and is switched back to train mode on
        every exit path, including the greedy early return.
        """
        self.network.eval()
        try:
            with torch.no_grad():
                type_logits, load_logits, unload_logits, relocate_logits, value, _ = \
                    self.network(batch_graph, goal_embedding)
                type_logits = type_logits[0]
                car_heads = {
                    'LOAD': load_logits[0],
                    'UNLOAD': unload_logits[0],
                    'RELOCATE_INTERNAL': relocate_logits[0],
                }

                # 1. Action type, restricted to types present in legal_actions.
                allowed_types = {act for act, _ in legal_actions}
                legal_type_idx = [idx for idx, t_str in self.idx_to_type.items() if t_str in allowed_types]
                masked_type_logits = type_logits + self._keep_only(type_logits, legal_type_idx)
                type_dist = Categorical(logits=masked_type_logits)
                type_idx = torch.argmax(masked_type_logits) if greedy else type_dist.sample()
                type_str = self.idx_to_type[int(type_idx.item())]

                # 2. Car id from the chosen type's head, restricted to its legal ids.
                car_idx_tensor = torch.tensor(-1, device=self.device, dtype=torch.long)
                car_dist = None
                if type_str in car_heads:
                    car_logits = car_heads[type_str]
                    legal_ids = {aid for (act, aid) in legal_actions if act == type_str}
                    masked_car_logits = car_logits + self._keep_only(car_logits, legal_ids)
                    car_dist = Categorical(logits=masked_car_logits)
                    car_idx_tensor = torch.argmax(masked_car_logits) if greedy else car_dist.sample()

                action_tuple = (type_str, int(car_idx_tensor.item()))

                if greedy:
                    return action_tuple, None, None, None, value.squeeze(0)

                action_tensor = torch.tensor([type_idx.item(), car_idx_tensor.item()], device=self.device)
                total_log_prob = type_dist.log_prob(type_idx)
                total_entropy = type_dist.entropy()
                if car_dist is not None:
                    total_log_prob = total_log_prob + car_dist.log_prob(car_idx_tensor)
                    total_entropy = total_entropy + car_dist.entropy()
        finally:
            self.network.train()

        return action_tuple, action_tensor, total_log_prob, total_entropy, value.squeeze(0)

    @staticmethod
    def _traj_success(traj):
        """True if the last step's state has ``meta['delivered'] == meta['total']``.

        Missing keys default to -1 / -2 respectively, so they count as failure.
        """
        meta = getattr(traj[-1][0], "meta", {})
        return meta.get("delivered", -1) == meta.get("total", -2)

    @staticmethod
    def _load_one(path):
        """Load one pickled trajectory; return it if successful, else log and return ``[]``."""
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, "rb") as f:
            t = pickle.load(f)
        if WorkerAgent._traj_success(t):
            return t
        logging.info(f"  · drop FAILED traj  → {os.path.basename(path)}")
        return []

    def pretrain_with_imitation(
        self,
        expert_data_paths: list[str],
        epochs: int,
        lr: float,
        batch_size: int,
    ):
        """Pre-train the Worker network by behaviour cloning on expert trajectories.

        Args:
            expert_data_paths: paths to ``*.pkl`` files, each a list of
                ``(state: Data, action: (type_str, car_id))`` pairs. Missing files
                are skipped with a warning.
            epochs: number of passes over the shuffled samples.
            lr: learning rate of a fresh Adam optimizer used only here.
            batch_size: samples per gradient step.

        Samples whose action type is not a Worker type are dropped, and
        ``config.MAX_EXPERT_SAMPLES`` (if set) caps the sample count. The loss
        is the action-type cross-entropy plus, for each type present in the
        batch, the mean car-id cross-entropy of that type's head. A zero goal
        embedding is used.
        """
        logging.info("[Phase 1] Starting Imitation Learning for Worker Agent…")

        # 1) Load trajectories.
        expert_pairs: list[tuple[Data, tuple[str, int]]] = []
        valid_types = set(self.type_to_idx.keys())

        for path in expert_data_paths:
            if not os.path.exists(path):
                logging.warning(f"  · Not found → {path}")
                continue
            # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
            with open(path, "rb") as f:
                traj = pickle.load(f)
            expert_pairs.extend((s, a) for s, a in traj if a[0] in valid_types)

        if not expert_pairs:
            logging.warning("  · No usable expert samples → skip.")
            return

        max_n = getattr(self.config, "MAX_EXPERT_SAMPLES", None)
        if max_n and len(expert_pairs) > max_n:
            random.shuffle(expert_pairs)
            expert_pairs = expert_pairs[:max_n]
        logging.info(f"  · Total Samples: {len(expert_pairs):,}")

        # 2) Optimizer and training loop.
        optim_ = optim.Adam(self.network.parameters(), lr=lr)
        self.network.train()

        for epoch in range(epochs):
            random.shuffle(expert_pairs)
            total_loss, nb = 0.0, 0

            for idx in range(0, len(expert_pairs), batch_size):
                batch = expert_pairs[idx: idx + batch_size]
                if not batch:
                    continue

                states, acts = zip(*batch)
                g = Batch.from_data_list(states).to(self.device)

                a_types = torch.as_tensor(
                    [self.type_to_idx[a[0]] for a in acts],
                    device=self.device, dtype=torch.long)
                a_cars = torch.as_tensor(
                    [a[1] for a in acts],
                    device=self.device, dtype=torch.long)

                dummy_goal = torch.zeros(
                    g.num_graphs, self.config.GOAL_EMBED_DIM,
                    device=self.device)

                # Force grad on even if the caller is inside torch.no_grad().
                with torch.set_grad_enabled(True):
                    t_logit, l_logit, u_logit, r_logit, _v, _ = \
                        self.network(g, dummy_goal)

                    loss_type = F.cross_entropy(t_logit, a_types)

                    loss_load = F.cross_entropy(l_logit, a_cars, reduction="none")
                    loss_unload = F.cross_entropy(u_logit, a_cars, reduction="none")
                    loss_reloc = F.cross_entropy(r_logit, a_cars, reduction="none")

                    m_load = (a_types == self.type_to_idx['LOAD'])
                    m_unload = (a_types == self.type_to_idx['UNLOAD'])
                    m_reloc = (a_types == self.type_to_idx['RELOCATE_INTERNAL'])

                    # Scalar zero accumulator (no grad of its own); each car head
                    # only contributes for samples of its own type.
                    loss_car = loss_type.new_zeros(())
                    if m_load.any().item():
                        loss_car = loss_car + loss_load[m_load].mean()
                    if m_unload.any().item():
                        loss_car = loss_car + loss_unload[m_unload].mean()
                    if m_reloc.any().item():
                        loss_car = loss_car + loss_reloc[m_reloc].mean()

                    loss = loss_type + loss_car

                    optim_.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(self.network.parameters(),
                                             self.config.PPO_MAX_GRAD_NORM)
                    optim_.step()

                total_loss += loss.item()
                nb += 1

            if nb and (epoch + 1) % 10 == 0:
                logging.info(f"    Epoch {epoch+1:4d}/{epochs}  "
                             f"avg-loss {total_loss/nb:.4f}")

        logging.info("[Phase 1] Imitation Learning Finished.")

    def evaluate_actions(self, states: list[Data], actions: torch.Tensor, goal_embedding: torch.Tensor) -> tuple:
        """Recompute log-probs, entropy and values of stored actions for the PPO update.

        Args:
            states: list of graph states.
            actions: ``[N, 2]`` tensor of ``[action_type_idx, car_id]`` rows.
            goal_embedding: goal embedding passed to the network.

        Returns:
            ``(log_probs [N], entropy [N], values [N])``. log_probs is the type
            log-prob plus the car log-prob from the head of the taken type;
            entropy is the type entropy plus the type-probability-weighted car
            entropies.

        NOTE(review): unlike ``get_action_and_value``, no legal-action masks are
        applied here (they are not stored), so the recomputed distribution can
        differ from the one used at sampling time and bias the PPO ratio.
        """
        batch_data = Batch.from_data_list(states).to(self.device)

        type_logits, load_logits, unload_logits, relocate_logits, values, graph_emb = \
            self.network(batch_data, goal_embedding)

        action_types = actions[:, 0]
        action_cars = actions[:, 1]

        type_dist = Categorical(logits=type_logits)
        load_dist = Categorical(logits=load_logits)
        unload_dist = Categorical(logits=unload_logits)
        relocate_dist = Categorical(logits=relocate_logits)

        log_probs_type = type_dist.log_prob(action_types)

        log_probs_load = load_dist.log_prob(action_cars)
        log_probs_unload = unload_dist.log_prob(action_cars)
        log_probs_relocate = relocate_dist.log_prob(action_cars)

        # Keep only the car log-prob of the head matching the taken type.
        mask_load = (action_types == self.type_to_idx['LOAD']).float()
        mask_unload = (action_types == self.type_to_idx['UNLOAD']).float()
        mask_relocate = (action_types == self.type_to_idx['RELOCATE_INTERNAL']).float()

        log_probs = log_probs_type \
                    + log_probs_load * mask_load \
                    + log_probs_unload * mask_unload \
                    + log_probs_relocate * mask_relocate

        type_probs = F.softmax(type_logits, dim=-1)
        entropy = (
            type_dist.entropy()
            + type_probs[:, self.type_to_idx['LOAD']] * load_dist.entropy()
            + type_probs[:, self.type_to_idx['UNLOAD']] * unload_dist.entropy()
            + type_probs[:, self.type_to_idx['RELOCATE_INTERNAL']] * relocate_dist.entropy()
        )

        return log_probs, entropy, values

    def update(self, storage: PPOStorage, goal_embedding: torch.Tensor) -> dict:
        """Run PPO epochs over the Worker's stored transitions.

        ``goal_embedding`` is used for every stored state. Returns the losses of
        the last minibatch, or ``{}`` when the storage is empty.
        """
        self.network.train()

        batch_size = storage.step
        # Check emptiness before indexing: with step == 0, obs[step - 1] is an unfilled slot.
        if batch_size == 0:
            return {}

        # Bootstrap value from the last stored state.
        # NOTE(review): this is the value of the last *recorded* state, not of the
        # state reached after its action; confirm this is the intended bootstrap.
        last_state_batch = Batch.from_data_list([storage.obs[batch_size - 1]]).to(self.device)
        with torch.no_grad():
            _, _, _, _, critic_value, _ = self.network(last_state_batch, goal_embedding)
            last_value = critic_value.squeeze(0)

        storage.compute_returns_and_advantages(last_value, self.config.WORKER_GAMMA, self.config.PPO_GAE_LAMBDA)

        # Graph observations stay as a CPU list; tensors move to the device.
        b_obs = storage.obs
        b_actions = storage.actions.to(self.device)
        b_logprobs = storage.logprobs.to(self.device)
        b_returns = storage.returns.to(self.device)
        b_advantages = storage.advantages.to(self.device)

        minibatch_size = max(1, batch_size // self.config.PPO_NUM_MINIBATCHES)

        for _ in range(self.config.PPO_UPDATE_EPOCHS):
            perm_indices = np.random.permutation(batch_size)
            for start in range(0, batch_size, minibatch_size):
                mb_idx = perm_indices[start:start + minibatch_size]

                mb_states = [b_obs[i] for i in mb_idx]
                mb_actions = b_actions[mb_idx]
                mb_logprobs = b_logprobs[mb_idx]
                mb_returns = b_returns[mb_idx]
                mb_advantages = b_advantages[mb_idx]

                new_logprobs, entropy, new_values = self.evaluate_actions(mb_states, mb_actions, goal_embedding)

                ratio = (new_logprobs - mb_logprobs).exp()
                surr1 = mb_advantages * ratio
                surr2 = mb_advantages * torch.clamp(ratio, 1 - self.config.PPO_CLIP_COEF, 1 + self.config.PPO_CLIP_COEF)

                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = F.mse_loss(new_values, mb_returns)

                loss = (policy_loss +
                        self.config.PPO_VALUE_COEF * value_loss -
                        self.config.WORKER_ENTROPY_COEF * entropy.mean())

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.config.PPO_MAX_GRAD_NORM)
                self.optimizer.step()

        return {
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss.item(),
            "entropy": entropy.mean().item()
        }

# ==============================================================================
# 섹션 5: 평가 및 메인 루프 (리팩토링 버전)
# ==============================================================================
import time
from torch.utils.tensorboard import SummaryWriter
import numpy as np


def evaluate_agent(manager_agent: ManagerAgent, worker_agent: WorkerAgent, problems: list[dict], max_num_ports: int, config: Config):
    """
    Greedily evaluate the current manager/worker agents and log summary metrics.

    Runs ``config.EVAL_EPISODES`` episodes. Each episode uses a problem drawn at
    random from ``problems``. A single HierarchicalEnvWrapper is built once and
    reused via ``reset(prob_data=...)``.

    For the duration of the evaluation:
    - both networks are put in eval mode;
    - the manager's ``env_wrapper`` is swapped for the evaluation env.

    Both are restored afterwards (train mode, original wrapper), even if an
    episode raises.

    Args:
        manager_agent: high-level agent; its ``env_wrapper`` and (if present)
            ``prev_action_idx`` are overwritten during evaluation.
        worker_agent: low-level agent handed to the evaluation env wrapper.
        problems: problem dicts to sample from (``random.choice`` needs it non-empty).
        max_num_ports: forwarded to HierarchicalEnvWrapper.
        config: provides ``EVAL_EPISODES``.

    Returns:
        dict with "success_rate" (fraction in [0, 1]), "avg_cost" and
        "avg_relocations" (means over evaluated episodes, 0.0 if none).
    """
    logging.info("=" * 20 + " AGENT EVALUATION START " + "=" * 20)

    num_episodes = config.EVAL_EPISODES
    total_success_count = 0
    total_costs = []
    total_relocations = []
    manager_action_counts = collections.defaultdict(int)

    original_env_wrapper = manager_agent.env_wrapper
    MAX_EVAL_MANAGER_STEPS = 500  # hard cap on manager steps per eval episode

    manager_agent.network.eval()
    worker_agent.network.eval()
    try:
        # Build the env once (with the first sampled problem); later episodes reuse it via reset().
        prob = random.choice(problems)
        h_env = HierarchicalEnvWrapper(prob, max_num_ports, worker_agent, config)
        manager_agent.env_wrapper = h_env

        for episode_num in range(num_episodes):
            if episode_num > 0:
                prob = random.choice(problems)
                h_env.reset(prob_data=prob)
            # Episode 0 was prepared by the constructor, later ones by reset();
            # read the manager state explicitly in both cases.
            manager_state = h_env._get_manager_state()

            if hasattr(manager_agent, 'prev_action_idx'):
                manager_agent.prev_action_idx.zero_()

            overall_done = False
            episode_cost = 0.0
            current_episode_steps = 0

            while not overall_done:
                # ">=" so that at most MAX_EVAL_MANAGER_STEPS steps are executed.
                if current_episode_steps >= MAX_EVAL_MANAGER_STEPS:
                    logging.warning(f"\nEval Episode [{episode_num+1}] reached max steps limit. Breaking loop.")
                    break

                with torch.no_grad():
                    legal_actions = h_env.ship_env.get_legal_actions(for_worker=False)
                    if hasattr(manager_agent, 'prev_action_idx'):
                        prev_action_idx = manager_agent.prev_action_idx
                    else:
                        prev_action_idx = torch.zeros(1, dtype=torch.long, device=manager_agent.device)

                    raw_action, _, _, _, _ = manager_agent.get_action_and_value(
                        manager_state.to(manager_agent.device), legal_actions, prev_action_idx, greedy=True
                    )

                # The training loop calls .item() on this value, i.e. it is a tensor.
                # A tensor hashes by identity, so it cannot key manager_action_map (int keys).
                # Convert to a plain int, also matching how h_env.step is called during training.
                manager_action_idx = int(raw_action)

                if hasattr(manager_agent, 'prev_action_idx'):
                    manager_agent.prev_action_idx = torch.tensor(
                        [manager_action_idx], dtype=torch.long, device=manager_agent.device
                    )

                current_goal_str = h_env.manager_action_map[manager_action_idx]
                manager_action_counts[current_goal_str] += 1
                print(f"\r  Eval Ep[{episode_num+1}/{config.EVAL_EPISODES}] Step[{current_episode_steps+1}]: Trying Goal -> {current_goal_str.ljust(25)}", end="")

                manager_state, _, overall_done, info = h_env.step(
                    manager_action_idx, greedy_worker=True
                )
                episode_cost += info.get('cost', 0.0)
                current_episode_steps += 1

            print()  # terminate the "\r" progress line

            ship_env = h_env.ship_env
            episode_success = len(ship_env.delivered_cars) == ship_env.total_cars
            if episode_success:
                total_success_count += 1
            total_costs.append(episode_cost)
            total_relocations.append(ship_env.relocations_this_episode)
            logging.info(f"  Eval Episode [{episode_num+1}/{config.EVAL_EPISODES}] Finished. Success: {episode_success}")
    finally:
        # Always hand the training env and train mode back, even if evaluation failed.
        manager_agent.env_wrapper = original_env_wrapper
        manager_agent.network.train()
        worker_agent.network.train()

    # Aggregate results (guard against EVAL_EPISODES == 0).
    success_rate = total_success_count / num_episodes if num_episodes > 0 else 0.0
    avg_cost = np.mean(total_costs) if total_costs else 0.0
    avg_relocations = np.mean(total_relocations) if total_relocations else 0.0

    logging.info("-" * 54)
    logging.info(f"[EVAL RESULT] Success Rate: {success_rate*100:.1f}%")
    logging.info(f"[EVAL RESULT] Average Cost: {avg_cost:.2f}")
    logging.info(f"[EVAL RESULT] Average Relocations: {avg_relocations:.2f}")

    total_actions = sum(manager_action_counts.values())
    if total_actions > 0:
        logging.info("[EVAL RESULT] Manager Action Distribution:")
        for action_idx in sorted(h_env.manager_action_map.keys()):
            action_str = h_env.manager_action_map[action_idx]
            count = manager_action_counts.get(action_str, 0)
            percentage = (count / total_actions) * 100
            logging.info(f"  - {action_str:<25s}: {count} times ({percentage:.1f}%)")

    logging.info("=" * 22 + " EVALUATION END " + "=" * 22 + "\n")

    return {
        "success_rate": success_rate,
        "avg_cost": avg_cost,
        "avg_relocations": avg_relocations
    }

if __name__ == '__main__':
    import re  # only needed for the easy-problem file-name filter below

    mp.set_start_method("spawn", force=True)  # multiprocessing: use the "spawn" start method
    # --- 1. Initial setup and environment configuration ---
    setup_logger()
    config = Config()
    writer = SummaryWriter(log_dir=config.LOG_DIR)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")

    # This notebook requires a GPU; stop early instead of training on CPU.
    assert torch.cuda.is_available(), "CUDA가 활성화되어 있지 않습니다!"
    print("✅ CUDA 활성화 확인")

    # Parse every *.json problem file exactly once, closing each handle.
    # The parsed dict is shared between the full set and the "easy" curriculum subset.
    # Easy problems are prob1 / prob2 / prob4. The pattern is matched against the
    # file name only and must not be followed by another digit, so prob10.json (or
    # a parent directory containing "prob1") is not counted as easy.
    easy_name_pattern = re.compile(r'prob[124](?!\d)')
    all_problems = []
    easy_problems = []
    for file_name in os.listdir(config.PROBLEM_DIR):
        path = os.path.join(config.PROBLEM_DIR, file_name)
        if not file_name.endswith('.json') or not os.path.isfile(path):
            continue
        with open(path, encoding='utf-8') as fh:
            problem = json.load(fh)
        all_problems.append(problem)
        if easy_name_pattern.search(file_name):
            easy_problems.append(problem)

    if not all_problems:
        logging.error(f"No problem files found in {config.PROBLEM_DIR}. Exiting.")
        sys.exit(1)  # non-zero status: this is an error exit
    if not easy_problems:
        logging.warning(f"No easy problems found. Starting with all problems.")
        easy_problems = all_problems

    # Size the networks for the largest problem in the set.
    max_cars = max(sum(q for _, q in p.get('K', [])) for p in all_problems)
    max_num_ports = max(p.get('P', 1) for p in all_problems)
    max_nodes = max(p.get('N', 1) for p in all_problems)
    node_feature_size = 4
    global_feature_size = 3 + max_num_ports

    # --- 2. Agent and environment initialisation ---
    worker_agent = WorkerAgent(node_feature_size, global_feature_size, max_cars, max_nodes, config, device)
    current_prob = random.choice(easy_problems)
    h_env = HierarchicalEnvWrapper(current_prob, max_num_ports, worker_agent, config)
    manager_agent = ManagerAgent(config, device, h_env)
    manager_agent.prev_action_idx = torch.zeros(1, dtype=torch.long, device=device)

    # --- 3. Worker pre-training (imitation learning) ---
    print("\nDEBUG: >>>>>>>>>> STEP 1: Calling pretrain_with_imitation NOW...")
    worker_agent.pretrain_with_imitation(
        config.EXPERT_DATA_PATHS, config.IMITATION_LEARNING_EPOCHS, config.IMITATION_LR, config.IMITATION_BATCH_SIZE
    )
    print("DEBUG: <<<<<<<<<< STEP 1: pretrain_with_imitation FINISHED.\n")

    # --- 4. Hierarchical RL main loop ---
    logging.info("\n[Phase 2] Starting Hierarchical Reinforcement Learning...")

    manager_state = h_env.reset(prob_data=current_prob)
    manager_storage = PPOStorage(config.MANAGER_NUM_STEPS_PER_UPDATE, (1,), device,
                                 state_shape=(config.MANAGER_STATE_DIM,),
                                 manager_action_dim=config.MANAGER_ACTION_DIM)

    episode_rewards, episode_costs, episode_manager_steps = 0, 0, 0
    start_time = time.time()

    # Per-interval statistics, reset after every periodic log line.
    interval_rewards = 0.0
    interval_goals = defaultdict(int)
    interval_successes = 0
    interval_costs = 0.0
    interval_worker_rewards = 0.0
    interval_worker_steps = 0

    try:
        for manager_step in range(1, config.TOTAL_MANAGER_STEPS + 1):
            is_curriculum_phase = manager_step < config.CURRICULUM_STEPS

            if is_curriculum_phase:
                # Scripted warm-up: choose randomly among the goal indices whose goal
                # is not yet achieved; fall back to action 4 when all are achieved.
                # Nothing is stored for PPO during this phase.
                possible_goals = []
                if not h_env._is_goal_achieved('FINISH_UNLOAD'):
                    possible_goals.append(1)
                if not h_env._is_goal_achieved('CLEAR_TEMP'):
                    possible_goals.append(3)
                if not h_env._is_goal_achieved('FINISH_LOAD'):
                    possible_goals.append(2)
                manager_action_idx = random.choice(possible_goals) if possible_goals else 4
                manager_action_tensor = torch.tensor([manager_action_idx], device=device)
                m_log_prob, m_value, m_mask = torch.tensor(0.0), torch.tensor(0.0), None
                manager_agent.prev_action_idx = manager_action_tensor.clone()
            else:
                legal_actions = h_env.ship_env.get_legal_actions(for_worker=False)
                prev_idx = manager_agent.prev_action_idx
                with torch.no_grad():
                    manager_action_tensor, m_log_prob, _, m_value, m_mask = manager_agent.get_action_and_value(
                        manager_state, legal_actions, prev_idx, greedy=False
                    )
                manager_action_idx = manager_action_tensor.item()
                manager_agent.prev_action_idx = manager_action_tensor.clone()

            next_manager_state, manager_reward, overall_done, worker_info = h_env.step(manager_action_idx)

            episode_rewards += manager_reward
            episode_costs += worker_info.get('cost', 0.0)
            episode_manager_steps += 1

            # Accumulate interval statistics every step.
            interval_rewards += manager_reward
            interval_goals[worker_info.get('goal', 'N/A')] += 1
            if worker_info.get('success', False):
                interval_successes += 1
            interval_costs += worker_info.get('cost', 0.0)
            # Relies on h_env.step reporting 'worker_total_reward' and 'steps' in its info dict.
            interval_worker_rewards += worker_info.get('worker_total_reward', 0.0)
            interval_worker_steps += worker_info.get('steps', 0)

            if not is_curriculum_phase:
                manager_storage.add(manager_state, manager_action_tensor, m_log_prob, manager_reward, overall_done, m_value, m_mask)

            manager_state = next_manager_state

            if overall_done:
                s = h_env.ship_env
                success_ratio = len(s.delivered_cars) / s.total_cars if s.total_cars > 0 else 0
                logging.info(f"EPISODE DONE (M-Step: {manager_step}) | Success: {success_ratio*100:.1f}% | Total Reward: {episode_rewards:.2f} | Total Cost: {episode_costs:.2f} | Length: {episode_manager_steps} steps")
                writer.add_scalar("Episode/TotalReward", episode_rewards, manager_step)
                writer.add_scalar("Episode/TotalCost", episode_costs, manager_step)
                writer.add_scalar("Episode/SuccessRatio", success_ratio, manager_step)
                writer.add_scalar("Episode/Length", episode_manager_steps, manager_step)

                # Problem curriculum: easy problems first, then the full set.
                if manager_step < config.CURRICULUM_TRANSITION_STEP:
                    current_prob = random.choice(easy_problems)
                else:
                    # Announce the switch only for the first episode that ends past the transition step.
                    if manager_step - episode_manager_steps < config.CURRICULUM_TRANSITION_STEP:
                        logging.info("="*20 + " SWITCHING TO FULL PROBLEM SET " + "="*20)
                    current_prob = random.choice(all_problems)

                manager_state = h_env.reset(prob_data=current_prob)
                manager_agent.prev_action_idx.zero_()
                episode_rewards, episode_costs, episode_manager_steps = 0, 0, 0

            if not is_curriculum_phase and manager_storage.is_full():
                loss_info = manager_agent.update(manager_storage)
                if loss_info:
                    writer.add_scalar("Train/Manager_PolicyLoss", loss_info["policy_loss"], manager_step)
                    writer.add_scalar("Train/Manager_ValueLoss", loss_info["value_loss"], manager_step)
                    writer.add_scalar("Train/Manager_Entropy", loss_info["entropy"], manager_step)
                manager_storage.reset()

            # Periodic progress log (a single line per interval).
            if manager_step % config.PRINT_INTERVAL_MANAGER_STEPS == 0:
                elapsed_time = time.time() - start_time
                steps_per_sec = config.PRINT_INTERVAL_MANAGER_STEPS / elapsed_time if elapsed_time > 0 else 0

                avg_reward = interval_rewards / config.PRINT_INTERVAL_MANAGER_STEPS
                interval_success_rate = (interval_successes / config.PRINT_INTERVAL_MANAGER_STEPS) * 100
                avg_cost_per_step = interval_costs / config.PRINT_INTERVAL_MANAGER_STEPS
                # Average worker reward per worker step over the interval.
                avg_worker_rew = interval_worker_rewards / interval_worker_steps if interval_worker_steps > 0 else 0

                goal_dist_str = ", ".join([f"{k.split('_')[-1]}:{v}" for k, v in sorted(interval_goals.items())])

                logging.info(
                    f"M-Step {manager_step:6d} | Avg M-Rew: {avg_reward:7.2f} | "
                    f"W-Rew: {avg_worker_rew:6.3f} | "
                    f"Success: {interval_success_rate:3.0f}% | Avg Cost: {avg_cost_per_step:8.1f} | "
                    f"Goals: [{goal_dist_str}] | SPS: {steps_per_sec:.2f}"
                )

                # Reset interval statistics for the next interval.
                interval_worker_rewards = 0.0
                interval_worker_steps = 0
                interval_rewards = 0.0
                interval_goals.clear()
                interval_successes = 0
                interval_costs = 0.0
                start_time = time.time()

            # Announce the end of the scripted-manager phase whenever it happens,
            # not only when it coincides with an evaluation step.
            if manager_step == config.CURRICULUM_STEPS:
                logging.info("="*20 + " CURRICULUM FINISHED " + "="*20)

            # Periodic evaluation and checkpointing.
            if manager_step % config.EVAL_INTERVAL_MANAGER_STEPS == 0:
                eval_results = evaluate_agent(manager_agent, worker_agent, all_problems, max_num_ports, config)
                writer.add_scalar("Eval/SuccessRate", eval_results["success_rate"] * 100.0, manager_step)
                writer.add_scalar("Eval/AvgCost", eval_results["avg_cost"], manager_step)
                writer.add_scalar("Eval/AvgRelocations", eval_results["avg_relocations"], manager_step)

                # torch.save does not create missing directories.
                os.makedirs(config.MODEL_DIR, exist_ok=True)
                torch.save(worker_agent.network.state_dict(), os.path.join(config.MODEL_DIR, f"worker_model_step_{manager_step}.pth"))
                torch.save(manager_agent.network.state_dict(), os.path.join(config.MODEL_DIR, f"manager_model_step_{manager_step}.pth"))
    finally:
        # Flush and close TensorBoard logs even if training is interrupted or crashes.
        writer.close()
    logging.info("--- V12 Refactored Training Completed ---")

2025-07-03 08:24:21 - INFO - Using device: cuda
✅ CUDA 활성화 확인





DEBUG: >>>>>>>>>> STEP 1: Calling pretrain_with_imitation NOW...
2025-07-03 08:26:43 - INFO - [Phase 1] Starting Imitation Learning for Worker Agent…
2025-07-03 08:26:45 - INFO -   · Total Samples: 5,420
2025-07-03 08:27:57 - INFO -     Epoch   10/200  avg-loss 11.2702
2025-07-03 08:29:10 - INFO -     Epoch   20/200  avg-loss 9.8362
2025-07-03 08:30:22 - INFO -     Epoch   30/200  avg-loss 8.7999
2025-07-03 08:31:34 - INFO -     Epoch   40/200  avg-loss 8.1010
2025-07-03 08:32:47 - INFO -     Epoch   50/200  avg-loss 7.6587
2025-07-03 08:33:59 - INFO -     Epoch   60/200  avg-loss 7.2841
2025-07-03 08:35:12 - INFO -     Epoch   70/200  avg-loss 7.0093
2025-07-03 08:36:24 - INFO -     Epoch   80/200  avg-loss 6.7883
2025-07-03 08:37:37 - INFO -     Epoch   90/200  avg-loss 6.5952
2025-07-03 08:38:50 - INFO -     Epoch  100/200  avg-loss 6.4202
2025-07-03 08:40:02 - INFO -     Epoch  110/200  avg-loss 6.2706
