In [1]:
#!/usr/bin/env python3
import argparse
import hashlib
from pathlib import Path

import torch
import matplotlib.pyplot as plt
import networkx as nx

from torch_geometric.data import Data
from torch_geometric.utils import to_networkx, subgraph, k_hop_subgraph


# ─────────────────────────────────────────────
# 해시 함수 (프로젝트와 동일)
# ─────────────────────────────────────────────
def _hash(key: str) -> str:
    return hashlib.md5(key.encode()).hexdigest() + ".pt"


# ─────────────────────────────────────────────
# PyG Data 로더
# ─────────────────────────────────────────────
def _load_pyg_data(pt_path: Path) -> Data:
    # PyTorch 2.6+ 대응 (weights_only=False)
    obj = torch.load(pt_path, map_location="cpu", weights_only=False)
    if isinstance(obj, Data):
        return obj
    if isinstance(obj, dict):
        for k in ["data", "graph", "g", "G"]:
            if k in obj and isinstance(obj[k], Data):
                return obj[k]
    raise ValueError(f"PyG Data를 {pt_path}에서 찾지 못했습니다.")


# ─────────────────────────────────────────────
# 큰 그래프 → 상위 차수 노드만 남김
# ─────────────────────────────────────────────
def _degree_topk_subgraph(data: Data, max_nodes: int) -> Data:
    if data.num_nodes <= max_nodes:
        return data
    G = to_networkx(data, to_undirected=True)
    deg = dict(G.degree())
    keep = sorted(deg.keys(), key=lambda n: deg[n], reverse=True)[:max_nodes]
    keep = torch.tensor(keep, dtype=torch.long)
    ei, _ = subgraph(keep, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)

    sub = Data()
    sub.edge_index = ei
    if hasattr(data, "x") and data.x is not None and data.x.numel() > 0:
        sub.x = data.x[keep]
    if hasattr(data, "edge_type"):
        pass
    if hasattr(data, "node_meta"):
        sub.node_meta = data.node_meta
    if hasattr(data, "utt_meta"):
        sub.utt_meta = data.utt_meta
    return sub


# ─────────────────────────────────────────────
# k-hop 서브그래프
# ─────────────────────────────────────────────
def _khop_subgraph(data: Data, center: int, khop: int) -> Data:
    subset, edge_index2, _, edge_mask = k_hop_subgraph(
        center, khop, data.edge_index,
        relabel_nodes=True, num_nodes=data.num_nodes, return_edge_mask=True
    )
    sub = Data()
    sub.edge_index = edge_index2
    if hasattr(data, "x") and data.x is not None and data.x.numel() > 0:
        sub.x = data.x[subset]
    if hasattr(data, "edge_type") and data.edge_type is not None and edge_mask is not None:
        try:
            sub.edge_type = data.edge_type[edge_mask]
        except Exception:
            pass
    if hasattr(data, "node_meta"):
        sub.node_meta = data.node_meta
    if hasattr(data, "utt_meta"):
        sub.utt_meta = data.utt_meta
    return sub


# ─────────────────────────────────────────────
# 시각화 함수
# ─────────────────────────────────────────────
def visualize_graph(data: Data, out_path: Path, layout="spring", seed=42, show_labels=False):
    G = to_networkx(data, to_undirected=True)

    # 레이아웃
    if layout == "spring":
        pos = nx.spring_layout(G, seed=seed)
    elif layout == "kamada":
        pos = nx.kamada_kawai_layout(G)
    else:
        pos = nx.spectral_layout(G)

    plt.figure(figsize=(12, 9), dpi=250)

    # ───────────────
    # ① 노드 타입별 색상
    # ───────────────
    node_colors = []
    node_labels = {}
    idx = 0

    text_n = data.node_meta.get("text_nodes", 0) if hasattr(data, "node_meta") else 0
    know_n = data.node_meta.get("knowledge_nodes", 0) if hasattr(data, "node_meta") else 0
    video_n = data.node_meta.get("video_nodes", 0) if hasattr(data, "node_meta") else 0
    audio_n = data.node_meta.get("audio_nodes", 0) if hasattr(data, "node_meta") else 0

    for _ in range(text_n):
        node_colors.append("#1f77b4")  # 파랑
        node_labels[idx] = "T"
        idx += 1
    for _ in range(know_n):
        node_colors.append("#ff7f0e")  # 주황
        node_labels[idx] = "K"
        idx += 1
    for _ in range(video_n):
        node_colors.append("#2ca02c")  # 초록
        node_labels[idx] = "V"
        idx += 1
    for _ in range(audio_n):
        node_colors.append("#d62728")  # 빨강
        node_labels[idx] = "A"
        idx += 1

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=120,
                           edgecolors="black", linewidths=0.4, alpha=0.9)

    # ───────────────
    # ② 엣지 타입별 색상
    # ───────────────
    legend_handles = []
    if hasattr(data, "edge_type") and data.edge_type is not None:
        et = data.edge_type.cpu().tolist()
        edges = list(zip(data.edge_index[0].tolist(), data.edge_index[1].tolist()))
        edge_type_map = {
            0: ("#999999", "text↔text"),
            1: ("#2ca02c", "video↔video"),
            2: ("#d62728", "audio↔audio"),
            3: ("#17becf", "text↔video"),
            4: ("#9467bd", "text↔audio"),
            5: ("#ff7f0e", "utterance link"),
        }
        for t, (color, label) in edge_type_map.items():
            mask = [i for i, tt in enumerate(et) if tt == t]
            if not mask:
                continue
            edges_t = [(edges[i][0], edges[i][1]) for i in mask]
            nx.draw_networkx_edges(G, pos, edgelist=edges_t,
                                   edge_color=color, width=1.0, alpha=0.6)
            from matplotlib.lines import Line2D
            legend_handles.append(Line2D([0], [0], color=color, lw=2, label=label))
    else:
        nx.draw_networkx_edges(G, pos, edge_color="#cccccc", width=0.6)

    # ───────────────
    # ③ 라벨
    # ───────────────
    if show_labels:
        nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=6)

    if legend_handles:
        plt.legend(handles=legend_handles, loc="best", fontsize=8)

    plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print(f"[OK] 저장 완료: {out_path}")


# ─────────────────────────────────────────────
# 메인
# ─────────────────────────────────────────────
def main():
    ap = argparse.ArgumentParser(description="캐시된 비디오 그래프 시각화")
    ap.add_argument("--video-id", type=str, default="-Ac3MwMrIVU_0424", help="비디오 아이디")
    ap.add_argument("--cache-dir", type=str, default="cache", help="그래프 캐시 폴더")
    ap.add_argument("--out", type=str, default=None, help="출력 이미지 경로 (기본: graph_<video_id>.png)")
    ap.add_argument("--layout", type=str, default="spring", choices=["spring", "kamada", "spectral"])
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--show-labels", action="store_true", help="노드 타입 라벨 표시")
    ap.add_argument("--max-nodes", type=int, default=None, help="큰 그래프라면 상위 차수 노드만 남겨 그리기")
    ap.add_argument("--khop", type=int, default=None, help="k-hop 서브그래프")
    ap.add_argument("--center", type=int, default=None, help="k-hop 중심 노드 id")
    args = ap.parse_args()

    key = f"video_graph::{args.video_id}"
    filename = _hash(key)
    pt_path = Path(args.cache_dir) / filename
    out_path = Path(args.out) if args.out else Path(f"graph_{args.video_id}.png")

    print(f"[INFO] video_id={args.video_id}")
    print(f"[INFO] hash key='{key}'  →  file='{filename}'")
    print(f"[INFO] cache path: {pt_path}")

    if not pt_path.exists():
        raise FileNotFoundError(f"그래프 파일이 없습니다: {pt_path}\n"
                                f"※ 먼저 build_and_cache_graphs()를 실행해 캐시를 생성했는지 확인하세요.")

    data = _load_pyg_data(pt_path)

    # 서브그래프 옵션 적용
    if args.khop is not None and args.center is not None:
        data = _khop_subgraph(data, center=args.center, khop=args.khop)
    elif args.max_nodes is not None:
        data = _degree_topk_subgraph(data, max_nodes=args.max_nodes)

    visualize_graph(
        data,
        out_path=out_path,
        layout=args.layout,
        seed=args.seed,
        show_labels=args.show_labels
    )


if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm
usage: ipykernel_launcher.py [-h] [--video-id VIDEO_ID]
                             [--cache-dir CACHE_DIR] [--out OUT]
                             [--layout {spring,kamada,spectral}] [--seed SEED]
                             [--show-labels] [--max-nodes MAX_NODES]
                             [--khop KHOP] [--center CENTER]
ipykernel_launcher.py: error: unrecognized arguments: --f=/run/user/1000/jupyter/runtime/kernel-v387a7d3b1072dd52851d04b5dcb936c1196d3198f.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [1]:
import pandas as pd
import sqlite3
import torch
import hashlib
import json
import re
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple
from torch_geometric.data import Data
from transformers import BertTokenizerFast, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch.nn as nn
import os
import sys
from pyvis.network import Network
import networkx as nx
import matplotlib.pyplot as plt  # 사용 안 해도 pyvis가 내부에서 필요할 수 있어요
from pathlib import Path
import math
import sqlite3
import torch
from torch_geometric.utils import to_networkx

# GPU 사용 여부 확인 및 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 해시 함수 정의
def _hash(key: str) -> str:
    return hashlib.md5(key.encode()).hexdigest() + ".pt"

# Tensor를 지정된 디바이스로 옮기는 헬퍼 함수
def _to_device(data: Dict, device: torch.device) -> Dict:
    return {k: v.to(device) for k, v in data.items()}

# VRAM 로깅 함수
def log_vram(stage: str, device: torch.device):
    if device.type == 'cuda':
        torch.cuda.synchronize()
        dev_id = device.index if isinstance(device, torch.device) else torch.cuda.current_device()
        alloc = torch.cuda.memory_allocated(dev_id) / 1024**2
        reserved = torch.cuda.memory_reserved(dev_id) / 1024**2
        peak_a = torch.cuda.max_memory_allocated(dev_id) / 1024**2
        peak_r = torch.cuda.max_memory_reserved(dev_id) / 1024**2
        wasted = reserved - alloc
        print(f"[{stage:15s}] alloc: {alloc:6.1f} MB | reserved: {reserved:6.1f} MB | "
              f"peak_alloc: {peak_a:6.1f} MB | peak_reserved: {peak_r:6.1f} MB | wasted: {wasted:6.1f} MB")
        torch.cuda.reset_peak_memory_stats(dev_id)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# -*- coding: utf-8 -*-
import os
import re
import sys
import json
import sqlite3
import hashlib
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import pandas as pd
from torch_geometric.data import Data
from transformers import (
    BertTokenizerFast,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

# ─────────────────────────────────────────────
# 환경
# ─────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[Graph builder] Using device: {device}")

def _hash(key: str) -> str:
    return hashlib.md5(key.encode()).hexdigest() + ".pt"

def _to_device(data: Dict, device: torch.device) -> Dict:
    return {k: v.to(device) for k, v in data.items()}

def log_vram(stage: str, device: torch.device):
    if device.type == "cuda":
        torch.cuda.synchronize()
        dev_id = device.index if isinstance(device, torch.device) else torch.cuda.current_device()
        alloc = torch.cuda.memory_allocated(dev_id) / 1024**2
        reserved = torch.cuda.memory_reserved(dev_id) / 1024**2
        peak_a = torch.cuda.max_memory_allocated(dev_id) / 1024**2
        peak_r = torch.cuda.max_memory_reserved(dev_id) / 1024**2
        wasted = reserved - alloc
        print(f"[{stage:15s}] alloc: {alloc:6.1f} MB | reserved: {reserved:6.1f} MB | "
              f"peak_alloc: {peak_a:6.1f} MB | peak_reserved: {peak_r:6.1f} MB | wasted: {wasted:6.1f} MB")
        torch.cuda.reset_peak_memory_stats(dev_id)

# ─────────────────────────────────────────────
# 엣지 타입 (★ t_k / k_k 포함)
# ─────────────────────────────────────────────
EDGE_TYPE: Dict[str, int] = {
    "t_t": 0,  # text  ↔ text
    "v_v": 1,  # video ↔ video
    "a_a": 2,  # audio ↔ audio
    "t_v": 3,  # text  ↔ video
    "t_a": 4,  # text  ↔ audio
    "utt": 5,  # utterance ↔ utterance (for merging)
    "t_k": 6,  # ★ text(언급 토큰) ↔ knowledge(관계)
    "k_k": 7,  # ★ knowledge(관계) ↔ knowledge(테일 엔티티)
}

def _add_bidir(
    src: int,
    dst: int,
    etype: int,
    edge_src: List[int],
    edge_dst: List[int],
    edge_type: List[int],
):
    edge_src.extend([src, dst])
    edge_dst.extend([dst, src])
    edge_type.extend([etype, etype])

# ─────────────────────────────────────────────
# 외부 지식 모델 (라벨/특성 조회)
# ─────────────────────────────────────────────
class ExternalFinancialKnowledgeModel:
    def __init__(self,
                 wiki_db="data/wikidata_revisions.db",
                 speech_db="data/speech_segments.db"):
        self.wiki = sqlite3.connect(wiki_db); self.wiki.row_factory = sqlite3.Row
        self.speech = sqlite3.connect(speech_db); self.speech.row_factory = sqlite3.Row

        self.target_entities = self._load_target_entities()

        label_df = pd.read_sql(
            "SELECT DISTINCT qid, value AS label FROM labels WHERE lang='en';",
            self.wiki,
        )
        label_df = label_df[label_df["label"].str.lower().isin([n.lower() for n in self.target_entities])]
        label_df["id_int"] = pd.factorize(label_df["qid"])[0]

        self.qid2int = dict(zip(label_df["qid"], label_df["id_int"]))
        self.int2qid = {v: k for k, v in self.qid2int.items()}
        self.label2qid = {l.lower(): q for q, l in zip(label_df["qid"], label_df["label"])}

        def _name_to_pattern(name: str) -> str:
            name = name.strip().lower()
            if not name:
                return ""
            parts = re.split(r"\s+", name)
            parts = [re.escape(p) for p in parts if p]
            if not parts:
                return ""
            return r"\b" + r"\s+".join(parts) + r"\b"

        safe = [_name_to_pattern(n) for n in self.target_entities]
        safe = [p for p in safe if p]
        safe.sort(key=len, reverse=True)
        self._pat = re.compile("|".join(safe), flags=re.I)

    def _load_target_entities(self) -> List[str]:
        df = pd.read_sql("SELECT persons_found FROM video_metadata;", self.speech)
        names = set()
        for js in df["persons_found"]:
            if js:
                names.update(json.loads(js).keys())
        return list(names)

    def identify_entities(self, text: str) -> List[str]:
        t = re.sub(r"[^\w\s]", "", text.lower())
        return list({m.strip() for m in self._pat.findall(t)})

    def entities_to_id(self, ents: List[str]) -> List[int]:
        return [self.qid2int[self.label2qid[e.lower()]]
                for e in ents if e.lower() in self.label2qid]

    def qid_to_label(self, qid: str) -> str:
        row = self.wiki.execute(
            "SELECT value FROM labels WHERE qid=? AND lang='en' LIMIT 1",
            (qid,),
        ).fetchone()
        return row[0] if row else qid

    def pid_to_label(self, pid: str) -> str:
        row = self.wiki.execute(
            "SELECT label FROM property_labels WHERE pid=? LIMIT 1",
            (pid,),
        ).fetchone()
        return row[0] if row else pid



    @lru_cache(maxsize=32)
    def _graph_until(self, time_iso: str) -> Data:
        sql = (
            "SELECT c.qid subj, c.property pid, c.value_qid obj "
            "FROM claims c JOIN revisions r USING(qid,revision_id) "
            "WHERE r.timestamp<=?"
        )
        df = pd.read_sql(sql, self.wiki, params=(time_iso,))
        df = df[df["subj"].isin(self.qid2int) & df["obj"].isin(self.qid2int)]
        if df.empty:
            return Data()

        src = torch.tensor(df["subj"].map(self.qid2int).to_numpy(), dtype=torch.long)
        dst = torch.tensor(df["obj"].map(self.qid2int).to_numpy(), dtype=torch.long)

        # ★ factorize 하되, 고유 PID 목록(uniques)을 vocab으로 함께 보관
        codes, uniques = pd.factorize(df["pid"])
        rel = torch.tensor(codes, dtype=torch.long).view(-1, 1)

        G = Data(edge_index=torch.stack([src, dst]), edge_attr=rel)
        # ★ 진짜 PID 문자열 리스트 보관
        G.pid_vocab = [str(p) for p in uniques.tolist()]
        return G

    def acquire_related_external_knowledge(
        self, text: str, time_iso: str, add_reverse=True, add_self_loop=True
    ) -> Tuple[List[int], Data]:
        ids = self.entities_to_id(self.identify_entities(text))
        G = self._graph_until(time_iso)
        if not ids or G.edge_index.numel() == 0:
            return ids, Data()

        mask = (torch.isin(G.edge_index[0], torch.tensor(ids)) |
                torch.isin(G.edge_index[1], torch.tensor(ids)))
        ei, ea = G.edge_index[:, mask], G.edge_attr[mask]

        if add_reverse:
            ei = torch.cat([ei, ei.flip(0)], 1)
            ea = torch.cat([ea, ea], 0)
        if add_self_loop:
            loops = torch.tensor(ids, dtype=torch.long, device=ei.device)
            ei = torch.cat([ei, loops.unsqueeze(0).repeat(2, 1)], 1)
            ea = torch.cat([ea, torch.full((len(loops), 1), -1, dtype=torch.long, device=ei.device)], 0)

        sub = Data(edge_index=ei, edge_attr=ea)
        # ★ 하위 서브그래프에도 vocab 전달
        if hasattr(G, "pid_vocab"):
            sub.pid_vocab = G.pid_vocab
        return ids, sub

    def pid_to_label(self, pid: str) -> str:
        row = self.wiki.execute(
            "SELECT label FROM property_labels WHERE pid=? LIMIT 1",
            (pid,),
        ).fetchone()
        return row[0] if row else pid


# ─────────────────────────────────────────────
# 텍스트 인코더 (BERT 오프셋 ↔ GLM 히든 매핑)
# ─────────────────────────────────────────────
def build_cross_map(wp_offsets: List[Tuple[int, int]],
                    glm_offsets: List[Tuple[int, int]]) -> List[List[int]]:
    mapping = [[] for _ in wp_offsets]
    p = 0
    for i, (ws, we) in enumerate(wp_offsets):
        while p < len(glm_offsets) and glm_offsets[p][1] <= ws:
            p += 1
        q = p
        while q < len(glm_offsets) and glm_offsets[q][0] < we:
            mapping[i].append(q)
            q += 1
    return mapping

class TextFeatureExtractor:
    def __init__(self):
        self.wp_tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.glm_tok: AutoTokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm2-6b", trust_remote_code=True
        )
        self.glm = AutoModelForCausalLM.from_pretrained(
            "THUDM/chatglm2-6b",
            trust_remote_code=True,
            torch_dtype=torch.float16,   # ★ FP16
            output_hidden_states=True,
            device_map="auto"
        )
        self.proj_down = nn.Linear(4096, 768, bias=False).to(device).half()
        self.proj_up = nn.Linear(768, 4096, bias=False).to(device).half()
        self.norm = nn.LayerNorm(768).to(device).half()
        for m in (self.proj_down, self.proj_up):
            nn.init.xavier_uniform_(m.weight)

    def _manual_offsets(self, text: str, toks: List[str]):
        norm = text.lower()
        p, off = 0, []
        for t in toks:
            tc = t.lstrip(" ")
            tc = tc if tc else " "
            j = norm.find(tc, p)
            j = j if j != -1 else p
            off.append((j, j + len(tc)))
            p = j + len(tc)
        return off

    @torch.no_grad()
    def encode(
        self,
        utterance_text: str,
        knowledge_triples: Optional[List[Tuple[str, str, str]]] = None,  # (h_label, r_label, t_label)
        anchor_entities: Optional[List[str]] = None,                      # (옵션) QIDs or labels
        knowledge_text_override: Optional[str] = None,                    # ★ 우리가 만든 문자열 그대로 사용
    ):
        # ── knowledge_text 구성 ─────────────────────────────────────
        if knowledge_text_override is not None:
            knowledge_text = knowledge_text_override
        else:
            know_parts = []
            if anchor_entities:
                know_parts.append(" [ENT] ".join(map(str, anchor_entities)))
            if knowledge_triples:
                tmp = []
                for h, r, t in knowledge_triples:
                    tmp.append(f"{h} [R] {r} [T] {t} [TRI] ")
                know_parts.append("".join(tmp).strip())
            knowledge_text = " ".join(know_parts) if know_parts else None

        # BERT: WP 토큰/오프셋/타입
        wp = self.wp_tok(
            utterance_text,
            text_pair=knowledge_text,
            return_offsets_mapping=True,
            return_token_type_ids=True,
            return_tensors="pt",
        )
        wp = _to_device(wp, self.glm.device)
        wp_tokens = self.wp_tok.convert_ids_to_tokens(wp["input_ids"][0])
        wp_offsets = wp["offset_mapping"][0].tolist()  # ★ 각 WP 토큰의 (start,end)
        token_types = wp["token_type_ids"][0].tolist()
        sep_idx = [i for i, tok in enumerate(wp_tokens) if tok == "[SEP]"]

        # knowledge 블록(WP 인덱스 범위들)
        knowledge_blocks: List[Tuple[int, int]] = []
        in_block, start = False, None
        for i, (tt, tok) in enumerate(zip(token_types, wp_tokens)):
            if tt == 1 and tok != "[SEP]":
                if not in_block:
                    start, in_block = i, True
            else:
                if in_block:
                    knowledge_blocks.append((start, i))
                    in_block = False
        if in_block:
            knowledge_blocks.append((start, len(wp_tokens)))

        # GLM: 합쳐진 문자열 기반 임베딩
        merged = utterance_text + (" [SEP] " + knowledge_text if knowledge_text else "")
        glm_ids = self.glm_tok.encode(merged, add_special_tokens=False)
        glm_enc = {
            "input_ids": torch.tensor([glm_ids], device=self.glm.device),
            "attention_mask": torch.ones(1, len(glm_ids), dtype=torch.long, device=self.glm.device),
        }
        glm_tokens = self.glm_tok.convert_ids_to_tokens(glm_ids)
        try:
            glm_offsets = self.glm_tok(
                merged, add_special_tokens=False, return_offsets_mapping=True
            )["offset_mapping"]
        except Exception:
            glm_offsets = self._manual_offsets(merged, glm_tokens)

        hidden4096 = self.glm(**glm_enc).hidden_states[-1][0]  # [L,4096]
        hid768 = self.norm(self.proj_down(hidden4096))         # [L,768]

        # WP ↔ GLM 토큰 정렬 매핑
        map_wp2glm = build_cross_map(wp_offsets, glm_offsets)
        max_idx = hid768.size(0)
        wp_emb = torch.stack([
            (
                hid768[torch.tensor(valid, dtype=torch.long, device=hid768.device)].mean(0)
                if (valid := [i for i in ids if i < max_idx])
                else torch.zeros(768, device=hid768.device)
            )
            for ids in map_wp2glm
        ])  # [N_wp,768]

        meta = {
            "wp_tokens": wp_tokens,
            "glm_tokens": glm_tokens,
            "map_wp2glm": map_wp2glm,
            "sep": sep_idx,
            "knowledge_blocks": knowledge_blocks,
            "wp_offsets": wp_offsets,  # ★ 추가
        }
        return wp_emb, meta

    def cleanup(self):
        print("TextFeatureExtractor 메모리 해제 중...")
        if hasattr(self, "glm"): del self.glm
        if hasattr(self, "proj_down"): del self.proj_down
        if hasattr(self, "proj_up"): del self.proj_up
        torch.cuda.empty_cache()
        print("TextFeatureExtractor 메모리 해제 완료")

# ─────────────────────────────────────────────
# 유틸: 언급 문자열이 덮는 WP 토큰 찾기(발화 구간)
# ─────────────────────────────────────────────
def _find_token_indices_for_phrase(
    text: str,
    phrase: str,
    wp_offsets: List[Tuple[int, int]],
    text_range: Tuple[int, int],  # (wp_start, wp_end) in WP index
) -> List[int]:
    t_norm = text.lower()
    p_norm = phrase.lower().strip()
    if not p_norm:
        return []
    spans = []
    start = 0
    while True:
        j = t_norm.find(p_norm, start)
        if j == -1:
            break
        spans.append((j, j + len(p_norm)))
        start = j + len(p_norm)

    wp_start, wp_end = text_range
    hits = []
    for (s, e) in spans:
        for i in range(wp_start, wp_end):
            ws, we = wp_offsets[i]
            if we > s and ws < e:  # overlap
                hits.append(i)
    return sorted(set(hits))

# ─────────────────────────────────────────────
# 그래프 빌더 (브리징 방식)
# ─────────────────────────────────────────────
class GraphBuilder:
    """발화 그래프 + 외부 지식 브리징(관계/테일을 라벨 임베딩으로)"""
    MAX_KG = 10  # 외부지식 최대 부착 개수

    def __init__(self, *, merge_anchor: bool = False):
        self.merge_anchor = merge_anchor
        self.text_enc = TextFeatureExtractor()
        self.ekm = ExternalFinancialKnowledgeModel()

    def __del__(self):
        if hasattr(self, "text_enc"):
            self.text_enc.cleanup()
            del self.text_enc
        print("GraphBuilder 메모리 해제 완료")
        log_vram("del", device)

    # knowledge_text와 각 트리플의 (rel_span, tail_span) 생성
    @staticmethod
    def _build_knowledge_text_and_spans(
        triples_lbl: List[Tuple[str, str, str]]
    ) -> Tuple[str, List[Tuple[Tuple[int, int], Tuple[int, int]]]]:
        """
        triples_lbl: [(head_label, rel_label, tail_label), ...]
        return:
          knowledge_text: "h [R] r [T] t [TRI] h [R] r [T] t [TRI] ..."
          spans: [ ((rel_start,rel_end),(tail_start,tail_end)), ... ]  (char offsets in knowledge_text)
        """
        parts: List[str] = []
        spans: List[Tuple[Tuple[int, int], Tuple[int, int]]] = []
        pos = 0
        for (h, r, t) in triples_lbl:
            parts.append(h); pos += len(h)
            sepR = " [R] "; parts.append(sepR); pos += len(sepR)

            rel_start = pos
            parts.append(r); pos += len(r)
            rel_end = pos

            sepT = " [T] "; parts.append(sepT); pos += len(sepT)

            tail_start = pos
            parts.append(t); pos += len(t)
            tail_end = pos

            sepTri = " [TRI] "; parts.append(sepTri); pos += len(sepTri)

            spans.append(((rel_start, rel_end), (tail_start, tail_end)))

        return "".join(parts).strip(), spans

    @staticmethod
    def _charspan_to_wp_indices(
        span: Tuple[int, int],
        wp_offsets: List[Tuple[int, int]],
        wp_block: Tuple[int, int],  # knowledge block in WP index (start,end)
    ) -> List[int]:
        """knowledge_text의 char span → 해당하는 '전체 시퀀스'의 WP 인덱스 목록"""
        s, e = span
        b0, b1 = wp_block
        hit = []
        # wp_offsets[b0:b1]는 knowledge_text 기준 오프셋임(토크나이저가 pair로 처리했기 때문)
        for i in range(b0, b1):
            ws, we = wp_offsets[i]
            if we > s and ws < e:  # overlap
                hit.append(i)
        return hit

    def build(
        self,
        utterance_text: str,
        time_iso: str,
        video_emb: Optional[torch.Tensor] = None,
        audio_emb: Optional[torch.Tensor] = None,
    ) -> Data:
        # ── (1) 외부 지식 트리플(QID/PID) 수집 ───────────────────
        ent_ids, ek_sub = self.ekm.acquire_related_external_knowledge(utterance_text, time_iso)
        triples_qpt: List[Tuple[str, str, str]] = []  # (h_qid, pid, t_qid)

        if getattr(ek_sub, "edge_index", None) is not None and ek_sub.edge_index.numel() > 0:
            src_list = ek_sub.edge_index[0].tolist()
            dst_list = ek_sub.edge_index[1].tolist()
            rel_list = ek_sub.edge_attr.tolist()  # [E] 정수 코드
            pid_vocab = getattr(ek_sub, "pid_vocab", None)
            if not pid_vocab:
                raise RuntimeError("pid_vocab이 누락되었습니다. ExternalFinancialKnowledgeModel._graph_until() 수정 필요.")
            for s_idx, d_idx, code in zip(src_list, dst_list, rel_list):
                pid_str = pid_vocab[int(code)]  # 예: "P106"
                triples_qpt.append(
                    (self.ekm.int2qid[int(s_idx)], pid_str, self.ekm.int2qid[int(d_idx)])
                )

        # QID/PID → 라벨로 치환
        triples_lbl: List[Tuple[str, str, str]] = []
        triples_keep: List[Tuple[str, str, str]] = []  # (h_qid, pid, t_qid) 동일 순서
        for (hq, pid, tq) in triples_qpt:
            h_label = self.ekm.qid_to_label(hq)
            r_label = self.ekm.pid_to_label(pid)
            t_label = self.ekm.qid_to_label(tq)
            triples_lbl.append((h_label, r_label, t_label))
            triples_keep.append((hq, pid, tq))

        # 최대 MAX_KG만 사용
        if len(triples_lbl) > self.MAX_KG:
            triples_lbl = triples_lbl[: self.MAX_KG]
            triples_keep = triples_keep[: self.MAX_KG]

        # ── (2) knowledge_text + 각 트리플의 (rel_span, tail_span) 생성 ─
        knowledge_text, spans_per_triple = self._build_knowledge_text_and_spans(triples_lbl)

        # ── (3) 하나의 패스로 인코딩(발화 + 지식) ─────────────────
        wp_emb, meta = self.text_enc.encode(
            utterance_text=utterance_text,
            knowledge_triples=None,                 # override 사용
            anchor_entities=[self.ekm.int2qid[i] for i in ent_ids],
            knowledge_text_override=knowledge_text, # ★ 우리가 만든 문자열 그대로 사용
        )
        if wp_emb.is_cuda:
            wp_emb = wp_emb.cpu()

        hs = wp_emb                         # [N_wp, 768]
        D = hs.size(1)
        sep0 = meta["sep"][0] if meta["sep"] else -1
        wp_offsets = meta.get("wp_offsets", [])
        kb = meta.get("knowledge_blocks", [])
        kb_range = kb[0] if kb else None    # 대개 knowledge_text 전체가 하나의 블록

        node_feats: List[torch.Tensor] = []
        edge_src: List[int] = []
        edge_dst: List[int] = []
        edge_type: List[int] = []
        node_types: List[int] = []  # 0=text, 1=knowledge, 2=video, 3=audio

        # ── (4) 텍스트 토큰 노드(발화 구간) ────────────────────────
        text_nodes: List[int] = []
        if sep0 > 1:
            text_start = len(node_feats)
            node_feats.extend([hs[i] for i in range(1, sep0)])  # [CLS]=0 제외
            text_nodes = list(range(text_start, text_start + (sep0 - 1)))
            node_types.extend([0] * (sep0 - 1))

        # 인접 텍스트 토큰 연결
        for i in range(len(text_nodes) - 1):
            _add_bidir(text_nodes[i], text_nodes[i + 1], EDGE_TYPE["t_t"], edge_src, edge_dst, edge_type)

        # ── (5) 비디오/오디오 노드 ─────────────────────────────────
        v_idx = -1; a_idx = -1
        if video_emb is not None and video_emb.numel():
            v_idx = len(node_feats)
            node_feats.append(video_emb.to(hs.dtype).squeeze(0))
            node_types.append(2)
            for t in text_nodes:
                _add_bidir(t, v_idx, EDGE_TYPE["t_v"], edge_src, edge_dst, edge_type)
        if audio_emb is not None and audio_emb.numel():
            a_idx = len(node_feats)
            node_feats.append(audio_emb.to(hs.dtype).squeeze(0))
            node_types.append(3)
            for t in text_nodes:
                _add_bidir(t, a_idx, EDGE_TYPE["t_a"], edge_src, edge_dst, edge_type)

        # ── (6) 브리징: 헤드(문장 언급 토큰) ↔ 관계 ↔ 테일 ──────────
        # 6-1) 문장 내 언급(identify_entities) → QID
        mentions = self.ekm.identify_entities(utterance_text)  # ["taylor swift", ...]
        mention2qid: Dict[str, str] = {}
        for m in mentions:
            q = self.ekm.label2qid.get(m.lower())
            if q:
                mention2qid[m] = q

        # 6-2) mention 토큰 앵커 (WP 인덱스)
        text_range = (1, sep0) if sep0 > 1 else (1, 1)
        mention_anchor_wp: Dict[str, List[int]] = {}
        for m in mention2qid.keys():
            wp_idxs = _find_token_indices_for_phrase(utterance_text, m, wp_offsets, text_range)
            if wp_idxs:
                mention_anchor_wp[m] = wp_idxs

        # 6-3) 트리플 순회: 관계/테일 임베딩 만들고 연결
        tail_qid_to_node: Dict[str, int] = {}
        knowledge_nodes_added = 0

        for k, ((h_qid, pid, t_qid), (h_label, r_label, t_label)) in enumerate(zip(triples_keep, triples_lbl)):
            if kb_range is None:
                break

            # 관계/테일의 WP 인덱스(전체 시퀀스 기준) 회수
            (rel_span, tail_span) = spans_per_triple[k]
            rel_wp = self._charspan_to_wp_indices(rel_span, wp_offsets, kb_range)
            tail_wp = self._charspan_to_wp_indices(tail_span, wp_offsets, kb_range)
            if not rel_wp or not tail_wp:
                continue

            # 관계/테일 임베딩(평균)
            rel_vec = hs[torch.tensor(rel_wp)].mean(dim=0)
            if t_qid in tail_qid_to_node:
                tail_node = tail_qid_to_node[t_qid]
            else:
                tail_vec = hs[torch.tensor(tail_wp)].mean(dim=0)
                tail_node = len(node_feats)
                node_feats.append(tail_vec)
                node_types.append(1)  # knowledge
                tail_qid_to_node[t_qid] = tail_node
                knowledge_nodes_added += 1

            rel_node = len(node_feats)
            node_feats.append(rel_vec)
            node_types.append(1)  # knowledge
            knowledge_nodes_added += 1

            # 헤드 앵커(문장 토큰들) 찾기: QID 매칭 우선, 없으면 라벨 문자열로 fallback
            head_anchor_nodes: List[int] = []
            for m, mq in mention2qid.items():
                if mq == h_qid and m in mention_anchor_wp:
                    for ti in mention_anchor_wp[m]:
                        li = ti - 1
                        if 0 <= li < len(text_nodes):
                            head_anchor_nodes.append(text_nodes[li])
            if not head_anchor_nodes:
                wp_idxs = _find_token_indices_for_phrase(utterance_text, h_label, wp_offsets, text_range)
                for ti in wp_idxs:
                    li = ti - 1
                    if 0 <= li < len(text_nodes):
                        head_anchor_nodes.append(text_nodes[li])

            # 엣지 연결: (헤드 토큰들) ↔ (관계), (관계) ↔ (테일)
            for hnode in sorted(set(head_anchor_nodes)):
                _add_bidir(hnode, rel_node, EDGE_TYPE["t_k"], edge_src, edge_dst, edge_type)
            _add_bidir(rel_node, tail_node, EDGE_TYPE["k_k"], edge_src, edge_dst, edge_type)

        # ── (7) PyG Data 구성 ──────────────────────────────────────
        x = torch.stack(node_feats) if node_feats else torch.empty(0, D)
        edge_index = (
            torch.tensor([edge_src, edge_dst], dtype=torch.long)
            if edge_src else torch.empty(2, 0, dtype=torch.long)
        )
        edge_type_t = (
            torch.tensor(edge_type, dtype=torch.long)
            if edge_type else torch.empty(0, dtype=torch.long)
        )

        data = Data(x=x, edge_index=edge_index, edge_type=edge_type_t)
        data.node_meta = {
            "text_nodes": len(text_nodes),
            "knowledge_nodes": knowledge_nodes_added,  # 관계+테일
            "video_nodes": 1 if v_idx != -1 else 0,
            "audio_nodes": 1 if a_idx != -1 else 0,
            "triples": len(triples_lbl),
        }
        data.utt_meta = {
            "first_text_node": text_nodes[0] if text_nodes else -1,
            "last_text_node": text_nodes[-1] if text_nodes else -1,
        }

        data.node_type     = torch.tensor(node_types, dtype=torch.int8)
        data.idx_text      = torch.tensor(text_nodes, dtype=torch.long) if text_nodes else torch.empty(0, dtype=torch.long)
        data.idx_knowledge = torch.tensor([i for i, t in enumerate(node_types) if t == 1], dtype=torch.long)
        data.idx_video     = torch.tensor([v_idx], dtype=torch.long) if v_idx != -1 else torch.empty(0, dtype=torch.long)
        data.idx_audio     = torch.tensor([a_idx], dtype=torch.long) if a_idx != -1 else torch.empty(0, dtype=torch.long)
        return data

# ─────────────────────────────────────────────
# 그래프 병합
# ─────────────────────────────────────────────
def merge_graph(prev_graph: Optional[Data], current_graph: Data) -> Data:
    if prev_graph is None or prev_graph.x.numel() == 0:
        return current_graph

    # 노드 피처 병합
    x_merged = torch.cat([prev_graph.x, current_graph.x], dim=0)

    # 엣지 병합(오프셋)
    num_prev_nodes = prev_graph.x.size(0)
    edge_index_current_offset = current_graph.edge_index + num_prev_nodes
    edge_index_merged = torch.cat([prev_graph.edge_index, edge_index_current_offset], dim=1)
    edge_type_merged  = torch.cat([prev_graph.edge_type, current_graph.edge_type])

    # utterance 링크(양방향)
    if prev_graph.utt_meta["last_text_node"] != -1 and current_graph.utt_meta["first_text_node"] != -1:
        prev_last_node  = prev_graph.utt_meta["last_text_node"]
        curr_first_node = current_graph.utt_meta["first_text_node"] + num_prev_nodes
        utt_edge_index = torch.tensor([[prev_last_node, curr_first_node],
                                       [curr_first_node, prev_last_node]], dtype=torch.long)
        utt_edge_type  = torch.tensor([EDGE_TYPE["utt"], EDGE_TYPE["utt"]], dtype=torch.long)
        edge_index_merged = torch.cat([edge_index_merged, utt_edge_index], dim=1)
        edge_type_merged  = torch.cat([edge_type_merged,  utt_edge_type])

    # 비디오/오디오 크로스-utterance 연결
    def _idx_list(g: Data, attr: str, type_id: Optional[int], allow_meta_fallback: bool, offset: int) -> List[int]:
        t = getattr(g, attr, None)
        if t is not None and hasattr(t, "numel") and t.numel() > 0:
            return (t + offset).tolist()
        if type_id is not None and hasattr(g, "node_type") and g.node_type is not None and g.node_type.numel() > 0:
            idxs = torch.nonzero(g.node_type.to(torch.long) == int(type_id), as_tuple=True)[0]
            if idxs.numel() > 0:
                return (idxs + offset).tolist()
        if allow_meta_fallback and hasattr(g, "node_meta"):
            text_n = int(g.node_meta.get("text_nodes", 0))
            know_n = int(g.node_meta.get("knowledge_nodes", 0))
            has_v  = int(g.node_meta.get("video_nodes", 0)) == 1
            has_a  = int(g.node_meta.get("audio_nodes", 0)) == 1
            base = text_n + know_n
            if attr == "idx_video" and has_v:
                return [offset + base]
            if attr == "idx_audio" and has_a:
                return [offset + base + (1 if has_v else 0)]
        return []

    prev_v_list = _idx_list(prev_graph, "idx_video", 2, allow_meta_fallback=False, offset=0)
    prev_a_list = _idx_list(prev_graph, "idx_audio", 3, allow_meta_fallback=False, offset=0)
    curr_v_list = _idx_list(current_graph, "idx_video", 2, allow_meta_fallback=True,  offset=num_prev_nodes)
    curr_a_list = _idx_list(current_graph, "idx_audio", 3, allow_meta_fallback=True,  offset=num_prev_nodes)

    if prev_v_list and curr_v_list:
        pv = prev_v_list[-1]
        vv_ei = torch.tensor([[pv] * len(curr_v_list), curr_v_list], dtype=torch.long)
        vv_ei = torch.cat([vv_ei, vv_ei.flip(0)], dim=1)
        vv_et = torch.full((vv_ei.size(1),), EDGE_TYPE["v_v"], dtype=torch.long)
        edge_index_merged = torch.cat([edge_index_merged, vv_ei], dim=1)
        edge_type_merged  = torch.cat([edge_type_merged,  vv_et])

    if prev_a_list and curr_a_list:
        pa = prev_a_list[-1]
        aa_ei = torch.tensor([[pa] * len(curr_a_list), curr_a_list], dtype=torch.long)
        aa_ei = torch.cat([aa_ei, aa_ei.flip(0)], dim=1)
        aa_et = torch.full((aa_ei.size(1),), EDGE_TYPE["a_a"], dtype=torch.long)
        edge_index_merged = torch.cat([edge_index_merged, aa_ei], dim=1)
        edge_type_merged  = torch.cat([edge_type_merged,  aa_et])

    # 메타 병합
    node_meta_merged = {k: prev_graph.node_meta.get(k, 0) + current_graph.node_meta.get(k, 0)
                        for k in set(prev_graph.node_meta.keys()) | set(current_graph.node_meta.keys())}
    utt_meta_merged = {
        "first_text_node": prev_graph.utt_meta["first_text_node"],
        "last_text_node":  current_graph.utt_meta["last_text_node"] + num_prev_nodes,
    }

    merged_graph = Data(x=x_merged, edge_index=edge_index_merged, edge_type=edge_type_merged)
    merged_graph.node_meta = node_meta_merged
    merged_graph.utt_meta  = utt_meta_merged

    # node_type / idx_* 병합
    if hasattr(prev_graph, "node_type") or hasattr(current_graph, "node_type"):
        nt_prev = getattr(prev_graph, "node_type", None)
        nt_curr = getattr(current_graph, "node_type", None)
        if nt_prev is not None and nt_curr is not None:
            merged_graph.node_type = torch.cat([nt_prev, nt_curr], dim=0)
        elif nt_prev is not None:
            merged_graph.node_type = nt_prev
        elif nt_curr is not None:
            merged_graph.node_type = nt_curr

    def _cat_idx(prev_t: Optional[torch.Tensor], curr_t: Optional[torch.Tensor], offset: int):
        if prev_t is None and curr_t is None:
            return None
        if prev_t is None:
            return (curr_t + offset) if (curr_t is not None and curr_t.numel()) else curr_t
        if curr_t is None or not curr_t.numel():
            return prev_t
        return torch.cat([prev_t, curr_t + offset])

    merged_graph.idx_text      = _cat_idx(getattr(prev_graph, "idx_text", None),      getattr(current_graph, "idx_text", None),      num_prev_nodes)
    merged_graph.idx_knowledge = _cat_idx(getattr(prev_graph, "idx_knowledge", None), getattr(current_graph, "idx_knowledge", None), num_prev_nodes)
    merged_graph.idx_video     = _cat_idx(getattr(prev_graph, "idx_video", None),     getattr(current_graph, "idx_video", None),     num_prev_nodes)
    merged_graph.idx_audio     = _cat_idx(getattr(prev_graph, "idx_audio", None),     getattr(current_graph, "idx_audio", None),     num_prev_nodes)

    return merged_graph


[Graph builder] Using device: cuda


In [3]:
# -*- coding: utf-8 -*-
"""
Graph builder with external-knowledge bridging (text<->relation<->tail)
- Keeps REAL PID strings via pid_vocab
- Concatenate utterance + knowledge text for a single-pass encoding
- 16-bit model load (no 8-bit quantization)
- Edge types include t_k (text↔knowledge) and k_k (knowledge↔knowledge)
- Utilities for visualization and summaries
"""
import os
import re
import sys
import json
import sqlite3
import hashlib
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import pandas as pd

from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

import networkx as nx
import matplotlib.pyplot as plt

from transformers import (
    BertTokenizerFast,
    AutoTokenizer,
    AutoModelForCausalLM,
    # BitsAndBytesConfig,  # not used (we load at 16-bit now)
)

# ─────────────────────────────────────────────
# 환경
# ─────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[Graph builder] Using device: {device}")

def _hash(key: str) -> str:
    return hashlib.md5(key.encode()).hexdigest() + ".pt"

def _to_device(data: Dict, device: torch.device) -> Dict:
    return {k: v.to(device) for k, v in data.items()}

def log_vram(stage: str, device: torch.device):
    if device.type == "cuda":
        torch.cuda.synchronize()
        dev_id = device.index if isinstance(device, torch.device) else torch.cuda.current_device()
        alloc = torch.cuda.memory_allocated(dev_id) / 1024**2
        reserved = torch.cuda.memory_reserved(dev_id) / 1024**2
        peak_a = torch.cuda.max_memory_allocated(dev_id) / 1024**2
        peak_r = torch.cuda.max_memory_reserved(dev_id) / 1024**2
        wasted = reserved - alloc
        print(f"[{stage:15s}] alloc: {alloc:6.1f} MB | reserved: {reserved:6.1f} MB | "
              f"peak_alloc: {peak_a:6.1f} MB | peak_reserved: {peak_r:6.1f} MB | wasted: {wasted:6.1f} MB")
        torch.cuda.reset_peak_memory_stats(dev_id)

# ─────────────────────────────────────────────
# 엣지 타입 (★ t_k, k_k 포함)
# ─────────────────────────────────────────────
EDGE_TYPE: Dict[str, int] = {
    "t_t": 0,  # text  ↔ text
    "v_v": 1,  # video ↔ video
    "a_a": 2,  # audio ↔ audio
    "t_v": 3,  # text  ↔ video
    "t_a": 4,  # text  ↔ audio
    "utt": 5,  # utterance ↔ utterance (for merging)
    "t_k": 6,  # text(언급 토큰) ↔ knowledge(관계)
    "k_k": 7,  # knowledge(관계) ↔ knowledge(테일 엔티티)
}

def _add_bidir(
    src: int,
    dst: int,
    etype: int,
    edge_src: List[int],
    edge_dst: List[int],
    edge_type: List[int],
):
    edge_src.extend([src, dst])
    edge_dst.extend([dst, src])
    edge_type.extend([etype, etype])

# ─────────────────────────────────────────────
# 외부 지식 모델 (라벨/특성 조회)
# ─────────────────────────────────────────────
class ExternalFinancialKnowledgeModel:
    def __init__(self,
                 wiki_db="data/wikidata_revisions.db",
                 speech_db="data/speech_segments.db"):
        self.wiki = sqlite3.connect(wiki_db); self.wiki.row_factory = sqlite3.Row
        self.speech = sqlite3.connect(speech_db); self.speech.row_factory = sqlite3.Row

        self.target_entities = self._load_target_entities()

        label_df = pd.read_sql(
            "SELECT DISTINCT qid, value AS label FROM labels WHERE lang='en';",
            self.wiki,
        )
        label_df = label_df[label_df["label"].str.lower().isin([n.lower() for n in self.target_entities])]
        label_df["id_int"] = pd.factorize(label_df["qid"])[0]

        self.qid2int = dict(zip(label_df["qid"], label_df["id_int"]))
        self.int2qid = {v: k for k, v in self.qid2int.items()}
        self.label2qid = {l.lower(): q for q, l in zip(label_df["qid"], label_df["label"])}

        def _name_to_pattern(name: str) -> str:
            name = name.strip().lower()
            if not name:
                return ""
            parts = re.split(r"\s+", name)
            parts = [re.escape(p) for p in parts if p]
            if not parts:
                return ""
            return r"\b" + r"\s+".join(parts) + r"\b"

        safe = [_name_to_pattern(n) for n in self.target_entities]
        safe = [p for p in safe if p]
        safe.sort(key=len, reverse=True)
        self._pat = re.compile("|".join(safe), flags=re.I)

    def _load_target_entities(self) -> List[str]:
        df = pd.read_sql("SELECT persons_found FROM video_metadata;", self.speech)
        names = set()
        for js in df["persons_found"]:
            if js:
                names.update(json.loads(js).keys())
        return list(names)

    def identify_entities(self, text: str) -> List[str]:
        t = re.sub(r"[^\w\s]", "", text.lower())
        return list({m.strip() for m in self._pat.findall(t)})

    def entities_to_id(self, ents: List[str]) -> List[int]:
        return [self.qid2int[self.label2qid[e.lower()]]
                for e in ents if e.lower() in self.label2qid]

    def qid_to_label(self, qid: str) -> str:
        row = self.wiki.execute(
            "SELECT value FROM labels WHERE qid=? AND lang='en' LIMIT 1",
            (qid,),
        ).fetchone()
        return row[0] if row else qid

    def pid_to_label(self, pid: str) -> str:
        row = self.wiki.execute(
            "SELECT label FROM property_labels WHERE pid=? LIMIT 1",
            (pid,),
        ).fetchone()
        return row[0] if row else pid

    @lru_cache(maxsize=32)
    def _graph_until(self, time_iso: str) -> Data:
        sql = (
            "SELECT c.qid subj, c.property pid, c.value_qid obj "
            "FROM claims c JOIN revisions r USING(qid,revision_id) "
            "WHERE r.timestamp<=?"
        )
        df = pd.read_sql(sql, self.wiki, params=(time_iso,))
        df = df[df["subj"].isin(self.qid2int) & df["obj"].isin(self.qid2int)]
        if df.empty:
            return Data()

        src = torch.tensor(df["subj"].map(self.qid2int).to_numpy(), dtype=torch.long)
        dst = torch.tensor(df["obj"].map(self.qid2int).to_numpy(), dtype=torch.long)

        # ★ factorize 하되, 고유 PID 목록을 vocab으로 보관
        codes, uniques = pd.factorize(df["pid"])
        rel = torch.tensor(codes, dtype=torch.long).view(-1, 1)

        G = Data(edge_index=torch.stack([src, dst]), edge_attr=rel)
        G.pid_vocab = [str(p) for p in uniques.tolist()]  # REAL PID strings, e.g., "P31"
        return G

    def acquire_related_external_knowledge(
        self, text: str, time_iso: str, add_reverse=True, add_self_loop=True
    ) -> Tuple[List[int], Data]:
        ids = self.entities_to_id(self.identify_entities(text))
        G = self._graph_until(time_iso)
        if not ids or G.edge_index.numel() == 0:
            return ids, Data()

        mask = (torch.isin(G.edge_index[0], torch.tensor(ids)) |
                torch.isin(G.edge_index[1], torch.tensor(ids)))
        ei, ea = G.edge_index[:, mask], G.edge_attr[mask]

        if add_reverse:
            ei = torch.cat([ei, ei.flip(0)], 1)
            ea = torch.cat([ea, ea], 0)
        if add_self_loop:
            loops = torch.tensor(ids, dtype=torch.long, device=ei.device)
            ei = torch.cat([ei, loops.unsqueeze(0).repeat(2, 1)], 1)
            ea = torch.cat([ea, torch.full((len(loops), 1), -1, dtype=torch.long, device=ei.device)], 0)

        sub = Data(edge_index=ei, edge_attr=ea)
        if hasattr(G, "pid_vocab"):
            sub.pid_vocab = G.pid_vocab  # ★ REAL pid vocab 전달
        return ids, sub

# ─────────────────────────────────────────────
# 텍스트 인코더 (BERT 오프셋 ↔ GLM 히든 매핑)
# ─────────────────────────────────────────────
def build_cross_map(wp_offsets: List[Tuple[int, int]],
                    glm_offsets: List[Tuple[int, int]]) -> List[List[int]]:
    mapping = [[] for _ in wp_offsets]
    p = 0
    for i, (ws, we) in enumerate(wp_offsets):
        while p < len(glm_offsets) and glm_offsets[p][1] <= ws:
            p += 1
        q = p
        while q < len(glm_offsets) and glm_offsets[q][0] < we:
            mapping[i].append(q)
            q += 1
    return mapping

class TextFeatureExtractor:
    def __init__(self):
        self.wp_tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.glm_tok: AutoTokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm2-6b", trust_remote_code=True
        )
        # ★ 16-bit 로드 (no 8-bit quantization)
        self.glm = AutoModelForCausalLM.from_pretrained(
            "THUDM/chatglm2-6b",
            trust_remote_code=True,
            torch_dtype=torch.float16,
            output_hidden_states=True,
            device_map="auto",
        )
        self.proj_down = nn.Linear(4096, 768, bias=False).to(device).half()
        self.proj_up = nn.Linear(768, 4096, bias=False).to(device).half()
        self.norm = nn.LayerNorm(768).to(device).half()
        for m in (self.proj_down, self.proj_up):
            nn.init.xavier_uniform_(m.weight)

    def _manual_offsets(self, text: str, toks: List[str]):
        norm = text.lower()
        p, off = 0, []
        for t in toks:
            tc = t.lstrip(" ")
            tc = tc if tc else " "
            j = norm.find(tc, p)
            j = j if j != -1 else p
            off.append((j, j + len(tc)))
            p = j + len(tc)
        return off

    @torch.no_grad()
    def encode(
        self,
        utterance_text: str,
        knowledge_triples: Optional[List[Tuple[str, str, str]]] = None,  # not used (we override)
        anchor_entities: Optional[List[str]] = None,
        knowledge_text_override: Optional[str] = None,
    ):
        # knowledge_text 구성 (override 우선)
        if knowledge_text_override is not None:
            knowledge_text = knowledge_text_override
        else:
            know_parts = []
            if anchor_entities:
                know_parts.append(" [ENT] ".join(map(str, anchor_entities)))
            if knowledge_triples:
                tmp = []
                for h, r, t in knowledge_triples:
                    tmp.append(f"{h} [R] {r} [T] {t} [TRI] ")
                know_parts.append("".join(tmp).strip())
            knowledge_text = " ".join(know_parts) if know_parts else None

        # BERT: WP 토큰/오프셋/타입
        wp = self.wp_tok(
            utterance_text,
            text_pair=knowledge_text,
            return_offsets_mapping=True,
            return_token_type_ids=True,
            return_tensors="pt",
        )
        wp = _to_device(wp, self.glm.device)
        wp_tokens = self.wp_tok.convert_ids_to_tokens(wp["input_ids"][0])
        wp_offsets = wp["offset_mapping"][0].tolist()
        token_types = wp["token_type_ids"][0].tolist()
        sep_idx = [i for i, tok in enumerate(wp_tokens) if tok == "[SEP]"]

        # knowledge 블록(WP 인덱스 범위들)
        knowledge_blocks: List[Tuple[int, int]] = []
        in_block, start = False, None
        for i, (tt, tok) in enumerate(zip(token_types, wp_tokens)):
            if tt == 1 and tok != "[SEP]":
                if not in_block:
                    start, in_block = i, True
            else:
                if in_block:
                    knowledge_blocks.append((start, i))
                    in_block = False
        if in_block:
            knowledge_blocks.append((start, len(wp_tokens)))

        # GLM: 합쳐진 문자열 기반 임베딩
        merged = utterance_text + (" [SEP] " + knowledge_text if knowledge_text else "")
        glm_ids = self.glm_tok.encode(merged, add_special_tokens=False)
        glm_enc = {
            "input_ids": torch.tensor([glm_ids], device=self.glm.device),
            "attention_mask": torch.ones(1, len(glm_ids), dtype=torch.long, device=self.glm.device),
        }
        glm_tokens = self.glm_tok.convert_ids_to_tokens(glm_ids)
        try:
            glm_offsets = self.glm_tok(
                merged, add_special_tokens=False, return_offsets_mapping=True
            )["offset_mapping"]
        except Exception:
            glm_offsets = self._manual_offsets(merged, glm_tokens)

        hidden4096 = self.glm(**glm_enc).hidden_states[-1][0]  # [L,4096]
        hid768 = self.norm(self.proj_down(hidden4096))         # [L,768]

        # WP ↔ GLM 토큰 정렬 매핑
        map_wp2glm = build_cross_map(wp_offsets, glm_offsets)
        max_idx = hid768.size(0)
        wp_emb = torch.stack([
            (
                hid768[torch.tensor(valid, dtype=torch.long, device=hid768.device)].mean(0)
                if (valid := [i for i in ids if i < max_idx])
                else torch.zeros(768, device=hid768.device)
            )
            for ids in map_wp2glm
        ])  # [N_wp,768]

        meta = {
            "wp_tokens": wp_tokens,
            "glm_tokens": glm_tokens,
            "map_wp2glm": map_wp2glm,
            "sep": sep_idx,
            "knowledge_blocks": knowledge_blocks,
            "wp_offsets": wp_offsets,
        }
        return wp_emb, meta

    def cleanup(self):
        print("TextFeatureExtractor 메모리 해제 중...")
        if hasattr(self, "glm"): del self.glm
        if hasattr(self, "proj_down"): del self.proj_down
        if hasattr(self, "proj_up"): del self.proj_up
        torch.cuda.empty_cache()
        print("TextFeatureExtractor 메모리 해제 완료")

# ─────────────────────────────────────────────
# 유틸: 언급 문자열이 덮는 WP 토큰 찾기(발화 구간)
# ─────────────────────────────────────────────
def _find_token_indices_for_phrase(
    text: str,
    phrase: str,
    wp_offsets: List[Tuple[int, int]],
    text_range: Tuple[int, int],  # (wp_start, wp_end) in WP index
) -> List[int]:
    t_norm = text.lower()
    p_norm = phrase.lower().strip()
    if not p_norm:
        return []
    spans = []
    start = 0
    while True:
        j = t_norm.find(p_norm, start)
        if j == -1:
            break
        spans.append((j, j + len(p_norm)))
        start = j + len(p_norm)

    wp_start, wp_end = text_range
    hits = []
    for (s, e) in spans:
        for i in range(wp_start, wp_end):
            ws, we = wp_offsets[i]
            if we > s and ws < e:  # overlap
                hits.append(i)
    return sorted(set(hits))

# ─────────────────────────────────────────────
# GraphBuilder (브리징 방식: text↔relation↔tail)
# ─────────────────────────────────────────────
class GraphBuilder:
    """발화 그래프 + 외부 지식 브리징(관계/테일을 라벨 임베딩으로)"""
    MAX_KG = 5  # 외부지식 최대 부착 개수

    def __init__(self, *, merge_anchor: bool = False):
        self.merge_anchor = merge_anchor
        self.text_enc = TextFeatureExtractor()
        self.ekm = ExternalFinancialKnowledgeModel()

    def __del__(self):
        if hasattr(self, "text_enc"):
            self.text_enc.cleanup()
            del self.text_enc
        print("GraphBuilder 메모리 해제 완료")
        log_vram("del", device)

    @staticmethod
    def _build_knowledge_text_and_spans(
        triples_lbl: List[Tuple[str, str, str]]
    ) -> Tuple[str, List[Tuple[Tuple[int, int], Tuple[int, int]]]]:
        """
        triples_lbl: [(head_label, rel_label, tail_label), ...]
        return:
          knowledge_text: "h [R] r [T] t [TRI] h [R] r [T] t [TRI] ..."
          spans: [ ((rel_start,rel_end),(tail_start,tail_end)), ... ]  (char offsets in knowledge_text)
        """
        parts: List[str] = []
        spans: List[Tuple[Tuple[int, int], Tuple[int, int]]] = []
        pos = 0
        for (h, r, t) in triples_lbl:
            parts.append(h); pos += len(h)
            sepR = " [R] "; parts.append(sepR); pos += len(sepR)

            rel_start = pos
            parts.append(r); pos += len(r)
            rel_end = pos

            sepT = " [T] "; parts.append(sepT); pos += len(sepT)

            tail_start = pos
            parts.append(t); pos += len(t)
            tail_end = pos

            sepTri = " [TRI] "; parts.append(sepTri); pos += len(sepTri)

            spans.append(((rel_start, rel_end), (tail_start, tail_end)))

        return "".join(parts).strip(), spans

    @staticmethod
    def _charspan_to_wp_indices(
        span: Tuple[int, int],
        wp_offsets: List[Tuple[int, int]],
        wp_block: Tuple[int, int],  # knowledge block in WP index (start,end)
    ) -> List[int]:
        """knowledge_text의 char span → 해당하는 '전체 시퀀스'의 WP 인덱스 목록"""
        s, e = span
        b0, b1 = wp_block
        hit = []
        for i in range(b0, b1):
            ws, we = wp_offsets[i]
            if we > s and ws < e:  # overlap
                hit.append(i)
        return hit

    def build(
        self,
        utterance_text: str,
        time_iso: str,
        video_emb: Optional[torch.Tensor] = None,
        audio_emb: Optional[torch.Tensor] = None,
    ) -> Data:
        # (1) 외부지식 수집
        ent_ids, ek_sub = self.ekm.acquire_related_external_knowledge(utterance_text, time_iso)

        triples_qpt: List[Tuple[str, str, str]] = []  # (h_qid, pid, t_qid) with REAL pid
        if ek_sub.num_edges:
            pid_vocab = getattr(ek_sub, "pid_vocab", None)
            if pid_vocab is None:
                print("[WARN] ek_sub.pid_vocab 이 없습니다. PID 매핑이 불가합니다.")
            for s, d, (code,) in zip(*ek_sub.edge_index, ek_sub.edge_attr):
                pid = pid_vocab[code] if pid_vocab is not None else f"P{int(code)}"
                triples_qpt.append(
                    (self.ekm.int2qid[s.item()], pid, self.ekm.int2qid[d.item()])
                )

        # (2) QID/PID → 라벨
        triples_lbl: List[Tuple[str, str, str]] = []
        triples_keep: List[Tuple[str, str, str]] = []
        for (hq, pid, tq) in triples_qpt:
            h_label = self.ekm.qid_to_label(hq)
            r_label = self.ekm.pid_to_label(pid)
            t_label = self.ekm.qid_to_label(tq)
            triples_lbl.append((h_label, r_label, t_label))
            triples_keep.append((hq, pid, tq))

        if len(triples_lbl) > self.MAX_KG:
            triples_lbl = triples_lbl[: self.MAX_KG]
            triples_keep = triples_keep[: self.MAX_KG]

        # (3) knowledge_text + 스팬
        knowledge_text, spans_per_triple = self._build_knowledge_text_and_spans(triples_lbl)

        # (4) 인코딩(하나의 패스)
        wp_emb, meta = self.text_enc.encode(
            utterance_text=utterance_text,
            knowledge_triples=None,  # override 사용
            anchor_entities=[self.ekm.int2qid[i] for i in ent_ids],
            knowledge_text_override=knowledge_text if triples_lbl else None,
        )
        if wp_emb.is_cuda:
            wp_emb = wp_emb.cpu()

        hs = wp_emb; D = hs.size(1)
        sep0 = meta["sep"][0] if meta["sep"] else -1
        wp_offsets = meta.get("wp_offsets", [])
        kb = meta.get("knowledge_blocks", [])
        kb_range = kb[0] if kb else None

        node_feats: List[torch.Tensor] = []
        edge_src: List[int] = []
        edge_dst: List[int] = []
        edge_type: List[int] = []
        node_types: List[int] = []  # 0=text, 1=knowledge, 2=video, 3=audio

        # (5) 텍스트 토큰
        text_nodes: List[int] = []
        if sep0 > 1:
            text_start = len(node_feats)
            node_feats.extend([hs[i] for i in range(1, sep0)])  # [CLS]=0 제외
            text_nodes = list(range(text_start, text_start + (sep0 - 1)))
            node_types.extend([0] * (sep0 - 1))
            for i in range(len(text_nodes) - 1):
                _add_bidir(text_nodes[i], text_nodes[i+1], EDGE_TYPE["t_t"], edge_src, edge_dst, edge_type)

        # (6) 비/오
        v_idx = -1; a_idx = -1
        if video_emb is not None and video_emb.numel():
            v_idx = len(node_feats)
            node_feats.append(video_emb.to(hs.dtype).squeeze(0))
            node_types.append(2)
            for t in text_nodes:
                _add_bidir(t, v_idx, EDGE_TYPE["t_v"], edge_src, edge_dst, edge_type)
        if audio_emb is not None and audio_emb.numel():
            a_idx = len(node_feats)
            node_feats.append(audio_emb.to(hs.dtype).squeeze(0))
            node_types.append(3)
            for t in text_nodes:
                _add_bidir(t, a_idx, EDGE_TYPE["t_a"], edge_src, edge_dst, edge_type)

        # (7) 브리징
        mentions = self.ekm.identify_entities(utterance_text)
        mention2qid = {m: self.ekm.label2qid[m.lower()] for m in mentions if m.lower() in self.ekm.label2qid}

        text_range = (1, sep0) if sep0 > 1 else (1, 1)
        mention_anchor_wp: Dict[str, List[int]] = {}
        for m in mention2qid.keys():
            wp_idxs = _find_token_indices_for_phrase(utterance_text, m, wp_offsets, text_range)
            if wp_idxs:
                mention_anchor_wp[m] = wp_idxs

        tail_qid_to_node: Dict[str, int] = {}
        knowledge_nodes_added = 0

        # 디버그 카운터
        dbg_rel_edges = 0
        dbg_tail_edges = 0
        dbg_trip_ok = 0
        dbg_trip_skip = 0

        for k, ((h_qid, pid, t_qid), (h_label, r_label, t_label)) in enumerate(zip(triples_keep, triples_lbl)):
            if kb_range is None:
                dbg_trip_skip += 1
                continue

            rel_span, tail_span = spans_per_triple[k]
            rel_wp = self._charspan_to_wp_indices(rel_span, wp_offsets, kb_range)
            tail_wp = self._charspan_to_wp_indices(tail_span, wp_offsets, kb_range)

            if not rel_wp or not tail_wp:
                dbg_trip_skip += 1
                continue

            rel_vec = hs[torch.tensor(rel_wp)].mean(dim=0)
            if t_qid in tail_qid_to_node:
                tail_node = tail_qid_to_node[t_qid]
            else:
                tail_vec = hs[torch.tensor(tail_wp)].mean(dim=0)
                tail_node = len(node_feats)
                node_feats.append(tail_vec)
                node_types.append(1)
                tail_qid_to_node[t_qid] = tail_node
                knowledge_nodes_added += 1

            rel_node = len(node_feats)
            node_feats.append(rel_vec)
            node_types.append(1)
            knowledge_nodes_added += 1

            # 헤드 앵커 탐색
            head_anchor_nodes: List[int] = []
            for m, mq in mention2qid.items():
                if mq == h_qid and m in mention_anchor_wp:
                    for ti in mention_anchor_wp[m]:
                        li = ti - 1
                        if 0 <= li < len(text_nodes):
                            head_anchor_nodes.append(text_nodes[li])
            if not head_anchor_nodes:
                # 라벨 fallback
                wp_idxs = _find_token_indices_for_phrase(utterance_text, h_label, wp_offsets, text_range)
                for ti in wp_idxs:
                    li = ti - 1
                    if 0 <= li < len(text_nodes):
                        head_anchor_nodes.append(text_nodes[li])

            # 연결
            for hnode in sorted(set(head_anchor_nodes)):
                _add_bidir(hnode, rel_node, EDGE_TYPE["t_k"], edge_src, edge_dst, edge_type)
                dbg_rel_edges += 2
            _add_bidir(rel_node, tail_node, EDGE_TYPE["k_k"], edge_src, edge_dst, edge_type)
            dbg_tail_edges += 2
            dbg_trip_ok += 1

        if dbg_trip_ok == 0 and len(triples_lbl) > 0:
            print(f"[KG] triples={len(triples_lbl)} 있었지만, 스팬 정렬 실패로 브리징 0개 (skip={dbg_trip_skip})")
        else:
            print(f"[KG] bridged_triples={dbg_trip_ok}  t_k_edges={dbg_rel_edges}  k_k_edges={dbg_tail_edges}")

        # (8) Data
        x = torch.stack(node_feats) if node_feats else torch.empty(0, D)
        edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long) if edge_src else torch.empty(2, 0, dtype=torch.long)
        edge_type_t = torch.tensor(edge_type, dtype=torch.long) if edge_type else torch.empty(0, dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, edge_type=edge_type_t)
        data.node_meta = {
            "text_nodes": len(text_nodes),
            "knowledge_nodes": knowledge_nodes_added,
            "video_nodes": 1 if v_idx != -1 else 0,
            "audio_nodes": 1 if a_idx != -1 else 0,
            "triples": len(triples_lbl),
        }
        data.utt_meta = {
            "first_text_node": text_nodes[0] if text_nodes else -1,
            "last_text_node": text_nodes[-1] if text_nodes else -1,
        }
        data.node_type     = torch.tensor(node_types, dtype=torch.int8)
        data.idx_text      = torch.tensor(text_nodes, dtype=torch.long) if text_nodes else torch.empty(0, dtype=torch.long)
        data.idx_knowledge = torch.tensor([i for i, t in enumerate(node_types) if t == 1], dtype=torch.long)
        data.idx_video     = torch.tensor([v_idx], dtype=torch.long) if v_idx != -1 else torch.empty(0, dtype=torch.long)
        data.idx_audio     = torch.tensor([a_idx], dtype=torch.long) if a_idx != -1 else torch.empty(0, dtype=torch.long)

        # 디버그 정보
        data.debug_kg = {
            "triples_lbl": triples_lbl,
            "kb_range": kb_range,
            "mentions": mentions,
            "mention_anchor_wp": mention_anchor_wp,
            "bridged_ok": dbg_trip_ok,
            "bridged_skip": dbg_trip_skip,
        }
        return data

# ─────────────────────────────────────────────
# 그래프 병합 (비디오/오디오 크로스 연결 포함)
# ─────────────────────────────────────────────
def merge_graph(prev_graph: Optional[Data], current_graph: Data) -> Data:
    if prev_graph is None or prev_graph.x.numel() == 0:
        return current_graph

    # 노드 피처 병합
    x_merged = torch.cat([prev_graph.x, current_graph.x], dim=0)

    # 엣지 병합(오프셋)
    num_prev_nodes = prev_graph.x.size(0)
    edge_index_current_offset = current_graph.edge_index + num_prev_nodes
    edge_index_merged = torch.cat([prev_graph.edge_index, edge_index_current_offset], dim=1)
    edge_type_merged  = torch.cat([prev_graph.edge_type, current_graph.edge_type])

    # utterance 링크(양방향)
    if prev_graph.utt_meta["last_text_node"] != -1 and current_graph.utt_meta["first_text_node"] != -1:
        prev_last_node  = prev_graph.utt_meta["last_text_node"]
        curr_first_node = current_graph.utt_meta["first_text_node"] + num_prev_nodes
        utt_edge_index = torch.tensor([[prev_last_node, curr_first_node],
                                       [curr_first_node, prev_last_node]], dtype=torch.long)
        utt_edge_type  = torch.tensor([EDGE_TYPE["utt"], EDGE_TYPE["utt"]], dtype=torch.long)
        edge_index_merged = torch.cat([edge_index_merged, utt_edge_index], dim=1)
        edge_type_merged  = torch.cat([edge_type_merged,  utt_edge_type])

    # 비디오/오디오 크로스-utterance 연결
    def _idx_list(g: Data, attr: str, type_id: Optional[int], allow_meta_fallback: bool, offset: int) -> List[int]:
        t = getattr(g, attr, None)
        if t is not None and hasattr(t, "numel") and t.numel() > 0:
            return (t + offset).tolist()

        if type_id is not None and hasattr(g, "node_type") and g.node_type is not None and g.node_type.numel() > 0:
            idxs = torch.nonzero(g.node_type.to(torch.long) == int(type_id), as_tuple=True)[0]
            if idxs.numel() > 0:
                return (idxs + offset).tolist()

        if allow_meta_fallback and hasattr(g, "node_meta"):
            text_n = int(g.node_meta.get("text_nodes", 0))
            know_n = int(g.node_meta.get("knowledge_nodes", 0))
            has_v  = int(g.node_meta.get("video_nodes", 0)) == 1
            has_a  = int(g.node_meta.get("audio_nodes", 0)) == 1
            base = text_n + know_n
            if attr == "idx_video" and has_v:
                return [offset + base]
            if attr == "idx_audio" and has_a:
                return [offset + base + (1 if has_v else 0)]
        return []

    prev_v_list = _idx_list(prev_graph, "idx_video", 2, allow_meta_fallback=False, offset=0)
    prev_a_list = _idx_list(prev_graph, "idx_audio", 3, allow_meta_fallback=False, offset=0)

    curr_v_list = _idx_list(current_graph, "idx_video", 2, allow_meta_fallback=True,  offset=num_prev_nodes)
    curr_a_list = _idx_list(current_graph, "idx_audio", 3, allow_meta_fallback=True,  offset=num_prev_nodes)

    if prev_v_list and curr_v_list:
        pv = prev_v_list[-1]
        vv_ei = torch.tensor([[pv] * len(curr_v_list), curr_v_list], dtype=torch.long)
        vv_ei = torch.cat([vv_ei, vv_ei.flip(0)], dim=1)
        vv_et = torch.full((vv_ei.size(1),), EDGE_TYPE["v_v"], dtype=torch.long)
        edge_index_merged = torch.cat([edge_index_merged, vv_ei], dim=1)
        edge_type_merged  = torch.cat([edge_type_merged,  vv_et])

    if prev_a_list and curr_a_list:
        pa = prev_a_list[-1]
        aa_ei = torch.tensor([[pa] * len(curr_a_list), curr_a_list], dtype=torch.long)
        aa_ei = torch.cat([aa_ei, aa_ei.flip(0)], dim=1)
        aa_et = torch.full((aa_ei.size(1),), EDGE_TYPE["a_a"], dtype=torch.long)
        edge_index_merged = torch.cat([edge_index_merged, aa_ei], dim=1)
        edge_type_merged  = torch.cat([edge_type_merged,  aa_et])

    # 메타 병합
    node_meta_merged = {k: prev_graph.node_meta.get(k, 0) + current_graph.node_meta.get(k, 0)
                        for k in set(prev_graph.node_meta.keys()) | set(current_graph.node_meta.keys())}
    utt_meta_merged = {
        "first_text_node": prev_graph.utt_meta["first_text_node"],
        "last_text_node":  current_graph.utt_meta["last_text_node"] + num_prev_nodes,
    }

    merged_graph = Data(x=x_merged, edge_index=edge_index_merged, edge_type=edge_type_merged)
    merged_graph.node_meta = node_meta_merged
    merged_graph.utt_meta  = utt_meta_merged

    # node_type / idx_* 병합
    if hasattr(prev_graph, "node_type") or hasattr(current_graph, "node_type"):
        nt_prev = getattr(prev_graph, "node_type", None)
        nt_curr = getattr(current_graph, "node_type", None)
        if nt_prev is not None and nt_curr is not None:
            merged_graph.node_type = torch.cat([nt_prev, nt_curr], dim=0)
        elif nt_prev is not None:
            merged_graph.node_type = nt_prev
        elif nt_curr is not None:
            merged_graph.node_type = nt_curr

    def _cat_idx(prev_t: Optional[torch.Tensor], curr_t: Optional[torch.Tensor], offset: int):
        if prev_t is None and curr_t is None:
            return None
        if prev_t is None:
            return (curr_t + offset) if (curr_t is not None and curr_t.numel()) else curr_t
        if curr_t is None or not curr_t.numel():
            return prev_t
        return torch.cat([prev_t, curr_t + offset])

    merged_graph.idx_text      = _cat_idx(getattr(prev_graph, "idx_text", None),      getattr(current_graph, "idx_text", None),      num_prev_nodes)
    merged_graph.idx_knowledge = _cat_idx(getattr(prev_graph, "idx_knowledge", None), getattr(current_graph, "idx_knowledge", None), num_prev_nodes)
    merged_graph.idx_video     = _cat_idx(getattr(prev_graph, "idx_video", None),     getattr(current_graph, "idx_video", None),     num_prev_nodes)
    merged_graph.idx_audio     = _cat_idx(getattr(prev_graph, "idx_audio", None),     getattr(current_graph, "idx_audio", None),     num_prev_nodes)

    return merged_graph

# ─────────────────────────────────────────────
# 요약 & 시각화
# ─────────────────────────────────────────────
def print_graph_summary(prefix: str, data: Data):
    n = int(data.x.size(0)) if hasattr(data, "x") and data.x is not None else 0
    e = int(data.edge_index.size(1)) if hasattr(data, "edge_index") else 0
    text_n = data.node_meta.get("text_nodes", 0) if hasattr(data, "node_meta") else 0
    know_n = data.node_meta.get("knowledge_nodes", 0) if hasattr(data, "node_meta") else 0
    video_n = data.node_meta.get("video_nodes", 0) if hasattr(data, "node_meta") else 0
    audio_n = data.node_meta.get("audio_nodes", 0) if hasattr(data, "node_meta") else 0

    et_cnt = {}
    if hasattr(data, "edge_type") and data.edge_type is not None and data.edge_type.numel():
        for t in data.edge_type.cpu().tolist():
            et_cnt[t] = et_cnt.get(t, 0) + 1

    def _c(tid): return et_cnt.get(tid, 0)
    print(
        f"[{prefix}] nodes={n}, edges={e} | "
        f"text={text_n}, knowledge={know_n}, video={video_n}, audio={audio_n} | "
        f"t_t={_c(0)}, v_v={_c(1)}, a_a={_c(2)}, t_v={_c(3)}, t_a={_c(4)}, utt={_c(5)}, "
        f"t_k={_c(6)}, k_k={_c(7)}"
    )

def visualize_graph(data: Data, out_path: Path, layout="spring", seed=42, label_mode=None):
    G = to_networkx(data, to_undirected=True)
    pos = nx.spring_layout(G, seed=seed) if layout=="spring" else (
        nx.kamada_kawai_layout(G) if layout=="kamada" else nx.spectral_layout(G)
    )
    plt.figure(figsize=(12,9), dpi=250)

    # node colors (node_type 우선)
    if hasattr(data, "node_type") and data.node_type is not None and data.node_type.numel():
        nts = data.node_type.cpu().tolist()
        color_map = {0:"#1f77b4", 1:"#ff7f0e", 2:"#2ca02c", 3:"#d62728"}  # T/K/V/A
        node_colors = [color_map.get(t, "#9e9e9e") for t in nts]
    else:
        node_colors = []
        for cnt,color in [
            (data.node_meta.get("text_nodes",0),"#1f77b4"),
            (data.node_meta.get("knowledge_nodes",0),"#ff7f0e"),
            (data.node_meta.get("video_nodes",0),"#2ca02c"),
            (data.node_meta.get("audio_nodes",0),"#d62728"),
        ]:
            node_colors.extend([color]*cnt)

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=120, edgecolors="black", linewidths=0.4, alpha=0.9)

    # edge colors (★ t_k, k_k 포함)
    legend_handles = []
    if hasattr(data, "edge_type") and data.edge_type is not None and data.edge_index.numel() > 0:
        et = data.edge_type.cpu().tolist()
        edges = list(zip(data.edge_index[0].tolist(), data.edge_index[1].tolist()))
        edge_type_map = {
            0: ("#999999", "text↔text"),
            1: ("#2ca02c", "video↔video"),
            2: ("#d62728", "audio↔audio"),
            3: ("#17becf", "text↔video"),
            4: ("#9467bd", "text↔audio"),
            5: ("#ff7f0e", "utterance link"),
            6: ("#e377c2", "text↔knowledge"),   # ★
            7: ("#7f7f7f", "knowledge↔knowledge"),  # ★
        }
        from matplotlib.lines import Line2D
        for t,(color,label) in edge_type_map.items():
            mask = [i for i,tt in enumerate(et) if tt==t]
            if not mask: continue
            edges_t = [(edges[i][0], edges[i][1]) for i in mask]
            nx.draw_networkx_edges(G, pos, edgelist=edges_t, edge_color=color, width=1.0, alpha=0.7)
            legend_handles.append(Line2D([0],[0], color=color, lw=2, label=label))
    elif G.number_of_edges() > 0:
        nx.draw_networkx_edges(G, pos, edge_color="#cccccc", width=0.6)

    # labels
    if label_mode is not None:
        idx_labels = {i:str(i) for i in range(G.number_of_nodes())}
        if hasattr(data, "node_type") and data.node_type is not None and data.node_type.numel():
            typemap = {0:"T",1:"K",2:"V",3:"A"}
            type_labels = {i: typemap.get(int(data.node_type[i]), "") for i in range(len(data.node_type))}
        else:
            type_labels = {}
        if label_mode=="index": labels = idx_labels
        elif label_mode=="type": labels = type_labels
        elif label_mode=="both": labels = {i: f"{idx_labels[i]}({type_labels.get(i,'')})" for i in idx_labels}
        else: labels=None
        if labels:
            nx.draw_networkx_labels(G, pos, labels=labels, font_size=7, font_color="white",
                                    bbox=dict(boxstyle="round,pad=0.15", fc="black", ec="none", alpha=0.6))

    if legend_handles:
        plt.legend(handles=legend_handles, loc="upper right", fontsize=8, frameon=False)

    plt.axis("off"); plt.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, bbox_inches="tight", pad_inches=0.05); plt.close()
    print(f"saved: {out_path}")


[Graph builder] Using device: cuda


In [4]:
# ============================================
# 단계별 실행(한 영상: 68UYBhIte3U)
# ============================================
video_id = "68UYBhIte3U"

# ▶ 모달 임베딩 캐시 경로: 한 곳으로 통일해서 사용하세요.
#   (기존에 임베딩을 'cache/'에 저장해뒀다면 cache 로, 'debug/'라면 debug 로 바꾸세요)
cache_dir = Path("cache")        # ← 필요시 "debug"로 변경
out_dir   = Path("viz_mpl")      # 그림 저장 폴더
out_dir.mkdir(exist_ok=True)

# DB 연결 및 업로드 시각
speech_db_path = "data/speech_segments.db"
conn = sqlite3.connect(speech_db_path); conn.row_factory = sqlite3.Row
meta = conn.execute("SELECT published_date FROM video_metadata WHERE video_id=?", (video_id,)).fetchone()
assert meta is not None, f"{video_id} 의 published_date가 없습니다."
upload_time = meta["published_date"]

# 발화 로드(총 6개)
rows = conn.execute(
    "SELECT segment_id, script FROM speech_segments WHERE video_id=? ORDER BY start_time",
    (video_id,)
).fetchall()
print(f"video_id={video_id} | segments={len(rows)}")

# GraphBuilder 1회 생성 (node_type/idx_*를 채우도록 사전에 수정되어 있어야 함)
graph_builder = GraphBuilder()

video_graph = None
for i, r in enumerate(rows, start=1):
    seg_id = r["segment_id"]
    text   = (r["script"] or "").strip()
    if not text:
        print(f" - skip empty text: seg={seg_id}")
        continue

    print(f"\n[Step {i}] seg_id={seg_id}")

    # 모달 임베딩 로드(있으면)
    v_path = cache_dir / _hash(f"vid::{seg_id}")
    a_path = cache_dir / _hash(f"aud::{seg_id}")
    v_emb = torch.load(v_path) if v_path.exists() else None
    a_emb = torch.load(a_path) if a_path.exists() else None

    # 1) 서브그래프 구축
    subg = graph_builder.build(
        utterance_text=text,
        time_iso=upload_time,
        video_emb=v_emb,
        audio_emb=a_emb
    )

    # 서브그래프 요약+시각화 (node_type 기반)
    print_graph_summary(f"Subgraph {i}", subg)
    visualize_graph(
        subg,
        out_dir / f"{video_id}_sub_{i:02d}_{seg_id}.png",
        layout="spring",
        seed=42,
        label_mode="index"   # 노드 인덱스 표기
    )

    # 2) 병합
    video_graph = merge_graph(video_graph, subg)

    # 병합 결과 요약+시각화 (node_type 기반)
    print_graph_summary(f"Merged {i}", video_graph)
    visualize_graph(
        video_graph,
        out_dir / f"{video_id}_merged_{i:02d}.png",
        layout="spring",
        seed=42,
        label_mode="both"    # 인덱스+타입 예: 12(T)
    )

conn.close()
print("\n완료: PNG가 저장되었습니다 →", out_dir.resolve())


video_id=68UYBhIte3U | segments=5


Loading checkpoint shards: 100%|██████████| 7/7 [00:07<00:00,  1.12s/it]



[Step 1] seg_id=68UYBhIte3U_001
[KG] bridged_triples=2  t_k_edges=8  k_k_edges=4
[Subgraph 1] nodes=26, edges=130 | text=20, knowledge=4, video=1, audio=1 | t_t=38, v_v=0, a_a=0, t_v=40, t_a=40, utt=0, t_k=8, k_k=4
saved: viz_mpl/68UYBhIte3U_sub_01_68UYBhIte3U_001.png
[Merged 1] nodes=26, edges=130 | text=20, knowledge=4, video=1, audio=1 | t_t=38, v_v=0, a_a=0, t_v=40, t_a=40, utt=0, t_k=8, k_k=4
saved: viz_mpl/68UYBhIte3U_merged_01.png

[Step 2] seg_id=68UYBhIte3U_002
[KG] bridged_triples=0  t_k_edges=0  k_k_edges=0
[Subgraph 2] nodes=14, edges=70 | text=12, knowledge=0, video=1, audio=1 | t_t=22, v_v=0, a_a=0, t_v=24, t_a=24, utt=0, t_k=0, k_k=0
saved: viz_mpl/68UYBhIte3U_sub_02_68UYBhIte3U_002.png
[Merged 2] nodes=40, edges=206 | text=32, knowledge=4, video=2, audio=2 | t_t=60, v_v=2, a_a=2, t_v=64, t_a=64, utt=2, t_k=8, k_k=4
saved: viz_mpl/68UYBhIte3U_merged_02.png

[Step 3] seg_id=68UYBhIte3U_003
[KG] bridged_triples=0  t_k_edges=0  k_k_edges=0
[Subgraph 3] nodes=25, edges=136 

In [5]:
# move_final_graphs.py
import hashlib
import shutil
from pathlib import Path
import pandas as pd

def _hash(key: str) -> str:
    return hashlib.md5(key.encode()).hexdigest() + ".pt"

def _unique_target(path: Path) -> Path:
    """remove/에 같은 파일명이 있으면 _1, _2 ...를 붙여 충돌 방지"""
    if not path.exists():
        return path
    stem, suffix = path.stem, path.suffix
    i = 1
    while True:
        candidate = path.with_name(f"{stem}_{i}{suffix}")
        if not candidate.exists():
            return candidate
        i += 1

def move_final_graphs(ready_csv="data/ready_videos.csv", cache_dir="cache", remove_dir="remove"):
    cache_dir = Path(cache_dir)
    remove_dir = Path(remove_dir)
    remove_dir.mkdir(exist_ok=True)

    # ready_videos.csv에서 video_id 목록 로드
    df = pd.read_csv(ready_csv)
    video_ids = df["video_id"].dropna().astype(str).tolist()

    moved, missing = 0, 0
    for vid in video_ids:
        fname = _hash(f"video_graph::{vid}")  # ← 최종 그래프 파일 이름 규칙
        src = cache_dir / fname
        if src.exists():
            dst = _unique_target(remove_dir / fname)
            print(f"Moving {src} -> {dst}")
            shutil.move(str(src), str(dst))
            moved += 1
        else:
            # 없을 수도 있으니 조용히 카운트만
            missing += 1

    print(f"\n완료: 이동 {moved}개, 누락 {missing}개 (목표 폴더: {remove_dir.resolve()})")

if __name__ == "__main__":
    move_final_graphs()


Moving cache/168d0ecc2ebee6bcfeb61942b4b99cfc.pt -> remove/168d0ecc2ebee6bcfeb61942b4b99cfc.pt
Moving cache/dab882c492672c56e40ed91ae9524150.pt -> remove/dab882c492672c56e40ed91ae9524150.pt
Moving cache/54bedc1f9fc1ab9135adec864bcf8699.pt -> remove/54bedc1f9fc1ab9135adec864bcf8699.pt
Moving cache/f0a7a99d90c1e15f421c5be1fb125ff6.pt -> remove/f0a7a99d90c1e15f421c5be1fb125ff6.pt
Moving cache/aeba414a8d8befb28b91b547219efb6e.pt -> remove/aeba414a8d8befb28b91b547219efb6e.pt
Moving cache/554bebb841ea2bddab79d7f60651d664.pt -> remove/554bebb841ea2bddab79d7f60651d664.pt
Moving cache/a54c536d7b4eb0f1699d116a82890b90.pt -> remove/a54c536d7b4eb0f1699d116a82890b90.pt
Moving cache/971dcb5404066c95d1a5cf357d364db6.pt -> remove/971dcb5404066c95d1a5cf357d364db6.pt
Moving cache/0d1273454e78546f988452d56e30fc0a.pt -> remove/0d1273454e78546f988452d56e30fc0a.pt
Moving cache/fb38b44decee333ca9e65ae98381ef63.pt -> remove/fb38b44decee333ca9e65ae98381ef63.pt
Moving cache/85b9e19908c1d1b9370b0fce05257d9e.pt -