In [1]:
# %% [markdown]
# # LLaVA Custom Vision Encoder / Projector Inference Notebook
# 사용자 정의 Reg‑Gated ViT‑L/14 + 맞춤 Projector 를 Vicuna‑7B 기반 LLaVA 모델에 적용하여
# – 모델 로드, 이미지 추론, Cross‑Attention 시각화 – 를 수행합니다.

# %% ---------------------------------------------------------------------------
# 0. Environment setup & common libraries
# ---------------------------------------------------------------------------
# The CUDA-related environment variables are set BEFORE `torch` is imported so
# that they are already in place when the CUDA runtime is initialised.
import os, sys, warnings, shutil
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["CUDA_HOME"] = "/usr/local/cuda-12.4"
# Prepend the system / CUDA library directories, keeping any existing value.
os.environ["LD_LIBRARY_PATH"] = (
    "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.4/lib64:"
    + os.environ.get("LD_LIBRARY_PATH", "")
)

import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, cv2, math
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
import requests

from transformers import (
    AutoConfig, AutoTokenizer, BitsAndBytesConfig, TextStreamer
)
from llava.constants import (
    IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images, tokenizer_image_token, get_model_name_from_path
)
from llava.model.builder import load_pretrained_model

# Silence noisy library warnings and disable the flash scaled-dot-product kernel.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
torch.backends.cuda.enable_flash_sdp(False)

# Add the LLaVA project root (the notebook's working directory) to sys.path if it is missing.
# NOTE(review): the `llava.*` imports above run before this insertion, so they
# must already be importable by other means — confirm this ordering is intended.
nb_root = os.getcwd()
if nb_root not in sys.path:
    sys.path.insert(0, nb_root)
    print("Added to sys.path:", nb_root)

# Custom (register-gated) CLIP vision transformer
from INFERclipregXGATED.model import VisionTransformer as CustomVisionTransformer
print("✔ CustomVisionTransformer imported.")

  from .autonotebook import tqdm as notebook_tqdm


[2025-05-13 08:52:36,363] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Added to sys.path: /home/ubuntu/Projects/regllava
✔ CustomVisionTransformer imported.


In [2]:
# %% ---------------------------------------------------------------------------
# 1. User configuration
# ---------------------------------------------------------------------------
# --- Checkpoints -------------------------------------------------------------
MODEL_BASE_LLM = "lmsys/vicuna-7b-v1.5"
MODEL_PATH_LLAVA_CONFIG_AND_PROJECTOR = "./llava-v1.5-7b-local"
CUSTOM_PROJECTOR_FILENAME = "mm_projector_2epoch.bin"
CUSTOM_VISION_ENCODER_WEIGHTS_PATH = "./models/ViT-L-14-REG-GATED-balanced-ckpt12.safetensors"

# --- Input -------------------------------------------------------------------
IMAGE_FILE_TO_PROCESS = "data/car.jpg"
# NOTE(review): the prompt contains a stray "the" ("Describe the where ...");
# it is kept verbatim here because it is fed to the model as-is.
USER_PROMPT = "Describe the where the main object is located in the image."

# --- Runtime -----------------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
LOAD_8BIT = False
LOAD_4BIT = False

# --- Generation --------------------------------------------------------------
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.2
# Conversation template key (e.g. "vicuna_v1"); None means "infer it from the
# model name".
CONV_MODE = None

# --- Vision encoder / projector ----------------------------------------------
VISION_ENCODER_CONFIG = {
    "image_resolution": 224,
    "patch_size": 14,
    "width": 1024,
    "layers": 24,
    "heads": 16,
    "output_dim": 1024,
    "num_registers": 4,
}
mm_vision_select_layer_val = -2
mm_projector_type_val = "mlp2x_gelu"

In [3]:
# %% ---------------------------------------------------------------------------
# 2. 유틸리티 함수
# ---------------------------------------------------------------------------
def load_image(path_or_url: str, timeout: float = 30.0) -> "Image.Image":
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        path_or_url: Filesystem path, or a URL starting with ``http://`` /
            ``https://``.
        timeout: Seconds to wait for the HTTP request before giving up. Only
            used for URLs. Without a timeout a stalled server would block the
            notebook kernel indefinitely.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        FileNotFoundError: If a local path does not exist.
        requests.HTTPError: If the URL responds with an error status.
    """
    if path_or_url.startswith(("http://", "https://")):
        resp = requests.get(path_or_url, timeout=timeout)
        resp.raise_for_status()
        # `convert` returns a fully decoded copy, so the source image can be
        # closed as soon as the conversion is done.
        with Image.open(BytesIO(resp.content)) as img:
            return img.convert("RGB")
    if not os.path.exists(path_or_url):
        raise FileNotFoundError(path_or_url)
    with Image.open(path_or_url) as img:
        return img.convert("RGB")

def overlay_heatmap(hm: np.ndarray, base: Image.Image) -> np.ndarray:
    """Blend a heat-map over an image and return the result as an RGB array.

    Args:
        hm: 2-D array of scores; it is resized to the image size and min-max
            normalised before colouring.
        base: Image to draw on (converted via ``np.array`` and treated as RGB).

    Returns:
        RGB array of the image blended 60/40 with the JET-coloured heat-map.
    """
    bgr = cv2.cvtColor(np.array(base), cv2.COLOR_RGB2BGR)
    h, w = bgr.shape[:2]
    # cv2.resize takes (width, height), not (height, width).
    hm = cv2.resize(hm.astype(np.float32), (w, h), interpolation=cv2.INTER_LINEAR)
    # `ndarray.ptp()` was removed in NumPy 2.0; the function form works on both
    # old and new NumPy. The epsilon keeps a constant map from dividing by zero.
    hm = (hm - hm.min()) / (np.ptp(hm) + 1e-8)
    hm_cm = cv2.applyColorMap((hm * 255).astype(np.uint8), cv2.COLORMAP_JET)
    mix  = cv2.addWeighted(bgr, 0.6, hm_cm, 0.4, 0)
    return cv2.cvtColor(mix, cv2.COLOR_BGR2RGB)

def infer_conv_mode(model_name: str) -> str:
    """Pick a conversation-template key from a model name.

    The name is lower-cased and checked against an ordered list of substrings;
    the first match wins, so the more specific "v1.6-34b" is tested before the
    generic "v1". Names matching nothing fall back to "llava_v0".
    """
    lowered = model_name.lower()
    ordered_rules = (
        ("llama-2", "llava_llama_2"),
        ("mistral", "mistral_instruct"),
        ("v1.6-34b", "chatml_direct"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
    )
    for needle, template_key in ordered_rules:
        if needle in lowered:
            return template_key
    return "llava_v0"

In [None]:
# %% ---------------------------------------------------------------------------
# 3. Model loading (uses the locally modified builder.py)
# ---------------------------------------------------------------------------
print("▶ Loading model...")
disable_torch_init()

model_arch_name = get_model_name_from_path(MODEL_PATH_LLAVA_CONFIG_AND_PROJECTOR)
tokenizer, model, image_processor, _ctx_len = load_pretrained_model(
    MODEL_PATH_LLAVA_CONFIG_AND_PROJECTOR,
    model_base=MODEL_BASE_LLM,
    model_name=model_arch_name,
    load_8bit=LOAD_8BIT, load_4bit=LOAD_4BIT,
    device=DEVICE,
    device_map={ "": DEVICE },          # pin every sub-module to ONE device
    attn_implementation="eager",
    torch_dtype=torch.float32
)
print("✅ Model loaded.")

# Give the tokenizer a pad token distinct from EOS; the embedding matrix is
# resized to make room for the new token.
if tokenizer.pad_token_id == tokenizer.eos_token_id:
    tokenizer.add_special_tokens({'pad_token': '<pad>'})
    model.resize_token_embeddings(len(tokenizer))

# ---------------------------------------------------------------------------
# 3-1. Single target device for all modules and input tensors
# ---------------------------------------------------------------------------
# Vision tower output -> mm_projector -> LLM, so all three modules and their
# inputs must live on the same device to avoid device/dtype errors.
# Resolved once, after any embedding resize (usually cuda:0).
MAIN_DEV = next(model.parameters()).device

proj_path = os.path.join(MODEL_PATH_LLAVA_CONFIG_AND_PROJECTOR,
                         CUSTOM_PROJECTOR_FILENAME)
assert os.path.exists(proj_path), f"Projector file not found: {proj_path}"
# Read onto the CPU first. `weights_only=True` restricts unpickling to tensors
# and plain containers, so a tampered .bin file cannot execute arbitrary code.
# NOTE(review): assumes the checkpoint is a plain dict of tensors — confirm.
raw_state = torch.load(proj_path, map_location="cpu", weights_only=True)

def extract_mm_projector(sd):
    """Return only the mm_projector entries of a state dict, with short keys.

    Keys that do not contain "mm_projector" are dropped. For the rest,
    everything up to and including the first "mm_projector." is removed
    (e.g. "model.mm_projector.0.weight" -> "0.weight"), and a leading
    "module." (DataParallel prefix) left after that is removed as well.
    """
    marker = "mm_projector."
    dp_prefix = "module."

    def _short_name(full_key):
        _, found, tail = full_key.partition(marker)
        short = tail if found else full_key
        if short.startswith(dp_prefix):
            short = short[len(dp_prefix):]
        return short

    return {
        _short_name(key): tensor
        for key, tensor in sd.items()
        if "mm_projector" in key
    }

proj_state = extract_mm_projector(raw_state)
if not proj_state:
    raise ValueError("mm_projector keys가 체크포인트에 없습니다!")

# --- Projector: place on the target device/dtype, then load the weights ------
mm_projector = model.get_model().mm_projector
mm_projector.to(device=MAIN_DEV, dtype=model.dtype)
mm_projector.load_state_dict(proj_state, strict=True)
print("✅  mm_projector weights loaded:", len(proj_state), "tensors")

# --- Vision tower wrapper ----------------------------------------------------
vt = model.get_vision_tower()
# 1) Force-enable the Reg-Gated branch on the wrapper.
vt.is_custom_reg_gated_clip = True
# 2) Write the desired feature layer / projector type into the model config.
model.config.mm_vision_select_layer = mm_vision_select_layer_val
model.config.mm_projector_type      = mm_projector_type_val

# ==== Debug: show the CustomVisionTransformer constructor signature ====
import inspect
sig = inspect.signature(CustomVisionTransformer.__init__)
print("Debug: CustomVisionTransformer.__init__ signature:", sig)
print(CustomVisionTransformer.__init__.__doc__)
# ==== End debug ====

# --- Build the custom Reg-Gated ViT-L/14 encoder -----------------------------
from safetensors.torch import load_file as load_safetensors
# VISION_ENCODER_CONFIG uses `image_resolution`; the constructor expects
# `input_resolution`, so the keys are mapped explicitly.
custom_vt = CustomVisionTransformer(
    input_resolution = VISION_ENCODER_CONFIG['image_resolution'],
    patch_size       = VISION_ENCODER_CONFIG['patch_size'],
    width            = VISION_ENCODER_CONFIG['width'],
    layers           = VISION_ENCODER_CONFIG['layers'],
    heads            = VISION_ENCODER_CONFIG['heads'],
    output_dim       = VISION_ENCODER_CONFIG['output_dim'],
    num_registers    = VISION_ENCODER_CONFIG['num_registers'],
)

# --- Filter the checkpoint down to what the encoder can accept ---------------
state_v = load_safetensors(CUSTOM_VISION_ENCODER_WEIGHTS_PATH)
from collections import OrderedDict
# Built once: calling state_dict() inside the loop would rebuild the whole
# mapping for every checkpoint key.
target_state = custom_vt.state_dict()
filtered_state = OrderedDict()
skipped_keys = []
for k, v in state_v.items():
    # Strip a leading 'visual.' prefix if the checkpoint has one.
    name = k[len("visual."):] if k.startswith("visual.") else k
    expected = target_state.get(name)
    # Keep only tensors that exist in the model with exactly the same shape.
    if expected is not None and expected.shape == v.shape:
        filtered_state[name] = v
    else:
        skipped_keys.append(k)
# strict=False: parameters absent from `filtered_state` keep their initial values.
missing_keys, unexpected_keys = custom_vt.load_state_dict(filtered_state, strict=False)
print(f"✅ Custom vision encoder loaded: {len(filtered_state)} tensors")
if skipped_keys:
    # Report instead of dropping silently: a name or shape mismatch here means
    # part of the checkpoint was NOT applied.
    print("   → skipped checkpoint keys (name/shape mismatch):", skipped_keys)
if missing_keys:
    print("   → missing keys:", missing_keys)
if unexpected_keys:
    print("   → unexpected keys (ignored):", unexpected_keys)

custom_vt.to(device=MAIN_DEV, dtype=model.dtype)
# A freshly constructed nn.Module starts in training mode; switch to eval so
# any train-only layers (e.g. dropout) are inactive during inference.
custom_vt.eval()
# Replace the encoder held by the LLaVA vision-tower wrapper.
vt.vision_tower = custom_vt
print("✅ Custom vision encoder loaded:", CUSTOM_VISION_ENCODER_WEIGHTS_PATH)

# Make sure the wrapper itself sits on the same device/dtype as well.
vt.to(device=MAIN_DEV, dtype=model.dtype)

# ---------------------------------------------------------------------------

# Select the conversation template.
conv_key = CONV_MODE or infer_conv_mode(model_arch_name)
if conv_key not in conv_templates:
    raise ValueError(f"Unknown conversation mode '{conv_key}'")
conv = conv_templates[conv_key].copy()
roles = ("user", "assistant") if "mpt" in model_arch_name.lower() else conv.roles
print("Conversation mode →", conv_key)

▶ Loading model...
Attempting to load model with 8-bit quantization (LOAD_8BIT=True)
Loading LLaVA from base model...


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 17225.07it/s]


/home/ubuntu/Projects/regllava/INFERclipregXGATED 에서 INFERclipregXGATED 모듈을 성공적으로 임포트했습니다.


AttributeError: 'Parameter' object has no attribute 'SCB'

In [None]:
# %% ---------------------------------------------------------------------------
# 4. Image loading & preprocessing (+ vision tower / projector sanity checks)
# ---------------------------------------------------------------------------
# BaseModelOutputWithPooling is needed for the output-type check below.
from transformers.modeling_outputs import BaseModelOutputWithPooling
import torch
import torch.nn as nn


def _module_dtype(module):
    """Return `module.dtype` if present, else the dtype of its first parameter.

    Falls back to the strings "N/A" (module is None) or "No Params".
    """
    if module is None:
        return "N/A"
    if hasattr(module, "dtype"):
        return module.dtype
    try:
        return next(module.parameters()).dtype
    except StopIteration:
        return "No Params"


def _first_weight_dtype(module):
    """Return the dtype of the first weight of a projector-like module.

    Handles an nn.Sequential (first layer's weight) and a single layer with a
    `weight` attribute; otherwise returns "N/A" / "Unknown structure".
    """
    if module is None:
        return "N/A"
    if isinstance(module, nn.Sequential) and len(module) > 0 and hasattr(module[0], "weight"):
        return module[0].weight.dtype
    if hasattr(module, "weight"):
        return module.weight.dtype
    return "Unknown structure"


pil_img = load_image(IMAGE_FILE_TO_PROCESS)
img_size = pil_img.size
print("Loaded image:", img_size)

img_tensor = process_images([pil_img], image_processor, model.config)
# The later cells pass `images` as list[Tensor], so always normalise to a list.
if not isinstance(img_tensor, list):
    img_tensor = [img_tensor]
# Follow the dtype the model was loaded with instead of forcing float32 here.
img_tensor = [t.to(MAIN_DEV, dtype=model.dtype) for t in img_tensor]
print(f"Debug: Initial img_tensor dtype set to model.dtype: {model.dtype}")

# ==== Debug section start ====
vt = model.get_vision_tower()
print("Debug: Vision tower type:", type(vt))
print(f"Debug: Initial model.dtype: {model.dtype}")
vt_initial_dtype = _module_dtype(vt)
print(f"Debug: Initial vision_tower.dtype: {vt_initial_dtype}")
proj_initial_dtype = _first_weight_dtype(getattr(model.get_model(), "mm_projector", None))
print(f"Debug: Initial mm_projector.dtype (first layer weight): {proj_initial_dtype}")

# Token layout parameters come from the configuration instead of being
# hard-coded, so they stay correct if the encoder config changes.
select_layer = getattr(model.config, "mm_vision_select_layer", -2)
num_registers = VISION_ENCODER_CONFIG["num_registers"]
num_patches = (VISION_ENCODER_CONFIG["image_resolution"] // VISION_ENCODER_CONFIG["patch_size"]) ** 2

with torch.no_grad():
    # The encoder takes one batch tensor: stack [C, H, W] items, concatenate
    # items that already carry a batch dimension.
    first_item = img_tensor[0]
    if first_item.dim() == 3:
        images_tensor = torch.stack(img_tensor, dim=0)
    else:
        images_tensor = torch.cat(img_tensor, dim=0)
    print(f"Debug: images_tensor shape: {images_tensor.shape}, dtype: {images_tensor.dtype}")

    # --- Debug #1: inner vision_tower module ---
    if hasattr(vt, "vision_tower"):
        print("Debug: inner vision_tower type:", type(vt.vision_tower))
        print(f"Debug: inner vision_tower dtype: {_module_dtype(vt.vision_tower)}")
    else:
        print("Debug: vision_tower has no .vision_tower attr")

    # --- Debug #2: config values ---
    print("Debug: config.mm_vision_select_layer =", getattr(model.config, "mm_vision_select_layer", None))
    print("Debug: config.mm_projector_type      =", getattr(model.config, "mm_projector_type", None))

    # --- Call the custom ViT directly, with the input cast to its dtype ---
    print(f"Debug: dtype of vt_local.vision_tower before call: {getattr(vt.vision_tower, 'dtype', 'N/A')}")
    print("Debug: Calling custom vision tower directly...")
    try:
        target_vt_dtype = next(vt.vision_tower.parameters()).dtype
    except StopIteration:
        # No parameters: fall back to the wrapper dtype if it is a real dtype.
        target_vt_dtype = vt_initial_dtype if isinstance(vt_initial_dtype, torch.dtype) else torch.float32
    print(f"Debug: Matching input tensor dtype to target_vt_dtype: {target_vt_dtype}")
    outputs_vt_direct = vt.vision_tower(images_tensor.to(target_vt_dtype), output_hidden_states=True)

    if not isinstance(outputs_vt_direct, BaseModelOutputWithPooling) or not hasattr(outputs_vt_direct, 'last_hidden_state'):
        raise TypeError(f"Expected BaseModelOutputWithPooling, but got {type(outputs_vt_direct)}")

    vt_hidden_states_direct = outputs_vt_direct.last_hidden_state
    print(f"Debug (direct vt): Extracted last_hidden_state shape: {tuple(vt_hidden_states_direct.shape)}, dtype: {vt_hidden_states_direct.dtype}")

    # --- Intermediate-layer features ---
    hidden_states_tuple = outputs_vt_direct.hidden_states
    selected_layer_features = None
    selected_patch_features = None

    n_states = len(hidden_states_tuple) if isinstance(hidden_states_tuple, tuple) else 0
    # Index with the configured layer so the selection stays correct whatever
    # the number of returned hidden states is.
    if -n_states <= select_layer < n_states:
        selected_layer_features = hidden_states_tuple[select_layer]
        print(f"Debug (direct vt): Extracted layer {select_layer} features shape: {selected_layer_features.shape}, dtype: {selected_layer_features.dtype}")
        # NOTE(review): assumes the token order is [CLS, registers..., patches...] — confirm against the encoder.
        start_index = 1 + num_registers
        end_index = start_index + num_patches
        if selected_layer_features.shape[1] >= end_index:
            selected_patch_features = selected_layer_features[:, start_index:end_index]
            print(f"Debug (direct vt): Sliced layer {select_layer} patch features shape: {selected_patch_features.shape}, dtype: {selected_patch_features.dtype}")
            print(f"  Checking sliced patch feature values:")
            # Cast to float32 so the statistics work for any input dtype.
            patch_f32 = selected_patch_features.float()
            has_nan = torch.isnan(patch_f32).any()
            has_inf = torch.isinf(patch_f32).any()
            print(f"    Has NaN: {has_nan}")
            print(f"    Has Inf: {has_inf}")
            if not has_nan and not has_inf:
                print(f"    Min: {patch_f32.min().item():.4f}")
                print(f"    Max: {patch_f32.max().item():.4f}")
                print(f"    Mean: {patch_f32.mean().item():.4f}")
                print(f"    Std: {patch_f32.std().item():.4f}")
        else:
            print("Warning: Not enough tokens in selected layer feature to slice patches.")
    else:
        print(f"Warning: Could not extract layer {select_layer} features for detailed check. hidden_states_tuple is None or length {n_states if hidden_states_tuple is not None else 'N/A'}")

    # --- Projector smoke test, forced to float32 ---
    # Best-effort diagnostics: a failure here is reported but must not stop
    # the notebook, hence the broad except.
    try:
        projector = model.get_model().mm_projector
        if projector is None:
            print("Error: mm_projector is None.")
        else:
            # nn.Module.to() works in place: this casts the model's own
            # projector, it does not create a copy.
            projector_f32 = projector.to(torch.float32)
            print(f"Debug: Projector explicitly cast to dtype: {_first_weight_dtype(projector_f32)}")

            # Input preference: sliced patches > full selected layer > last hidden state.
            candidates = (
                ("Sliced layer -2 patches", selected_patch_features, ""),
                ("Full layer -2", selected_layer_features, " (patch slice failed or not available)"),
                ("Last hidden state", vt_hidden_states_direct, " (layer -2 extraction failed or not available)"),
            )
            proj_input_tensor = None
            input_source = "None"
            for source_name, candidate, note in candidates:
                if candidate is not None:
                    proj_input_tensor = candidate
                    input_source = source_name
                    print(f"Debug: Using '{input_source}' for projector test{note}.")
                    break
            else:
                print("Error: No valid input tensor found for projector test.")

            if proj_input_tensor is not None:
                proj_input_tensor_f32 = proj_input_tensor.to(torch.float32)
                print(f"Debug: Projector input tensor (from '{input_source}') explicitly cast to dtype: {proj_input_tensor_f32.dtype}")

                proj_out_debug = projector_f32(proj_input_tensor_f32)
                print(f"Debug: Projector output dtype: {proj_out_debug.dtype}")
                proj_out_debug_float = proj_out_debug.float()
                # NaN-tolerant statistics so a few bad values do not hide the rest.
                print(
                    f"Debug: Projector output (float32 forced) | "
                    f"Input source: '{input_source}', Input shape: {proj_input_tensor_f32.shape}, Output shape: {proj_out_debug_float.shape} | "
                    f"sum={torch.nansum(proj_out_debug_float).item():.4f}, "
                    f"mean={torch.nanmean(proj_out_debug_float).item():.4f}, "
                    f"std={torch.nan_to_num(proj_out_debug_float, nan=0.0).std().item():.4f}"
                )
                inf_mask = torch.isinf(proj_out_debug_float)
                has_inf_out = inf_mask.any()
                has_nan_out = torch.isnan(proj_out_debug_float).any()
                print(f"  Output has Inf: {has_inf_out}")
                print(f"  Output has NaN: {has_nan_out}")
                if not has_inf_out and not has_nan_out:
                    try:
                        print(f"  Output Min: {proj_out_debug_float.min().item():.4f}")
                        print(f"  Output Max: {proj_out_debug_float.max().item():.4f}")
                    except RuntimeError as e_minmax:  # e.g. empty tensor
                        print(f"  Could not calculate Min/Max for output: {e_minmax}")
                if has_inf_out:
                    # Rank-agnostic index listing (no assumption of a 3-D output).
                    inf_positions = inf_mask.nonzero()
                    first_few = [tuple(pos) for pos in inf_positions[:5].tolist()]
                    print(f"    Inf found at {inf_positions.size(0)} locations. First few indices: {first_few}")
                    try:
                        # Input vector that fed the first Inf (all but the last output index).
                        lead_index = tuple(inf_positions[0][:-1].tolist())
                        problematic_input_sample = proj_input_tensor_f32[lead_index].detach().cpu().numpy()
                        print(f"    Sample input vector near first Inf location (approx): mean={problematic_input_sample.mean():.4f}, std={problematic_input_sample.std():.4f}, min={problematic_input_sample.min():.4f}, max={problematic_input_sample.max():.4f}")
                    except Exception as e_inf_input:
                        print(f"    Could not get input sample near Inf: {e_inf_input}")
            else:
                print("Debug: Skipping projector debug check due to missing input tensor.")
    except Exception as e:
        print(f"Error during projector debug check: {e}")
# ==== Debug section end ====

# Cast the whole model to float32 to rule out dtype mismatches (a no-op when
# the model was already loaded with torch_dtype=torch.float32).
print(f"\nCasting the entire model to float32...")
model = model.to(torch.float32)

print(f"\nVerifying final model dtypes...")
print(f"Debug: Final model.dtype: {model.dtype}")
vt_final_dtype = _module_dtype(model.get_vision_tower())
print(f"Debug: Final Vision_tower dtype: {vt_final_dtype}")
proj_final_dtype = _first_weight_dtype(getattr(model.get_model(), "mm_projector", None))
print(f"Debug: Final mm_projector dtype (first layer weight): {proj_final_dtype}")

In [None]:
# %% ---------------------------------------------------------------------------
# 5. Prompt construction
# ---------------------------------------------------------------------------
if model.config.mm_use_im_start_end:
    prompt_user = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{USER_PROMPT}"
else:
    prompt_user = f"{DEFAULT_IMAGE_TOKEN}\n{USER_PROMPT}"

# Work on a copy of the template: appending to the shared `conv` would add the
# same turn again every time this cell is re-run, corrupting the prompt.
prompt_conv = conv.copy()
prompt_conv.append_message(roles[0], prompt_user)
prompt_conv.append_message(roles[1], None)   # empty assistant turn = generation slot
full_prompt = prompt_conv.get_prompt()

print("--- Prompt to tokenizer ---")
print(full_prompt)

input_ids = tokenizer_image_token(
    full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).to(model.device)

# Guard: the <image> placeholder must have been replaced by IMAGE_TOKEN_INDEX.
assert (input_ids == IMAGE_TOKEN_INDEX).sum().item() > 0, \
       "IMAGE_TOKEN_INDEX (-200) 가 input_ids 에 없습니다. <image> 토큰 치환 실패!"

attention_mask = torch.ones_like(input_ids)

In [None]:
# %% ---------------------------------------------------------------------------
# 6. (Optional) Attention heat-maps
# ---------------------------------------------------------------------------
VISUALIZE = True
if VISUALIZE:
    # ----------------------------------------------------------------------
    # 6-A.  Patch -> Patch heat-map (last LLM layer, heads summed)
    # ----------------------------------------------------------------------
    with torch.no_grad():
        outs_vt = model(
            input_ids=input_ids,
            images=img_tensor,
            image_sizes=[img_size],
            output_attentions=True,
            return_dict=True,
        )
    attn_vt = outs_vt.attentions[-1].sum(dim=1)[0].cpu()       # (seq, seq)

    # Position of the single <image> placeholder in the un-expanded prompt.
    img_tok_pos = (input_ids[0] == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0].item()
    # The placeholder is expanded into the image patches, so
    #   expanded_len = prompt_len - 1 + n_patches.
    # (Subtracting only `img_tok_pos` would wrongly count the text tokens that
    # follow the image as patches.)
    patch_len = attn_vt.size(1) - (input_ids.size(1) - 1)
    if patch_len <= 0:
        raise ValueError(f"Could not locate image patches in the attention map (patch_len={patch_len}).")
    patch_end = img_tok_pos + patch_len
    # Query = LAST patch: with causal attention an early patch can only attend
    # to the patches before it, so its row would be almost entirely zero.
    patch_vec = attn_vt[patch_end - 1, img_tok_pos:patch_end]

    L = patch_vec.shape[0]
    grid = math.ceil(math.sqrt(L))
    if grid * grid != L:
        # Non-square patch count: resample to the next square length.
        patch_vec = F.interpolate(
            patch_vec.float()[None,None],
            size=grid * grid,
            mode="linear",
            align_corners=False
        )[0,0].to(patch_vec.dtype)
    patch_map = patch_vec.float().reshape(grid, grid)
    # PIL size is (width, height); interpolate expects (height, width).
    heatA = F.interpolate(
        patch_map[None,None],
        size=pil_img.size[::-1],
        mode="bilinear",
        align_corners=False
    )[0,0].cpu().numpy()
    plt.figure(figsize=(6,6))
    plt.title("Patch → Patch attention")
    plt.imshow(overlay_heatmap(heatA, pil_img)); plt.axis("off"); plt.show()

In [None]:
# %% ---------------------------------------------------------------------------
# 7. Text generation
# ---------------------------------------------------------------------------
print(f"\n[{roles[0]}] {USER_PROMPT}")
# No trailing newline, so the streamed answer continues on the same line.
print(f"[{roles[1]}] ", end="", flush=True)

# Explicit None check: a pad id of 0 is valid but falsy, so `pad or eos` would
# silently discard it.
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = tokenizer.eos_token_id

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(
    inputs=input_ids,
    attention_mask=attention_mask,
    images=img_tensor,
    image_sizes=[img_size],
    do_sample=TEMPERATURE > 0,      # greedy decoding when TEMPERATURE is 0
    temperature=TEMPERATURE,
    max_new_tokens=MAX_NEW_TOKENS,
    streamer=streamer,
    use_cache=True,
    pad_token_id=pad_token_id,
    output_attentions=True,
    return_dict_in_generate=True
)

with torch.inference_mode():
    out_gen = model.generate(**gen_kwargs)

In [None]:
# ---------------------------------------------------------------------------
# 8. (Optional) Text -> Patch heat-map over the prompt and the generated tokens
# ---------------------------------------------------------------------------
if VISUALIZE:
    prompt_len = input_ids.size(1)
    image_pos = (input_ids[0] == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0].item()

    # Rebuild "prompt (with <image> placeholder) + generated tokens".
    # NOTE(review): whether `out_gen.sequences` echoes the prompt depends on the
    # generate() implementation in use — confirm. If the sequences start with
    # the prompt tokens that precede <image>, the prompt is treated as echoed
    # and sliced off as before; otherwise every returned token is treated as
    # newly generated instead of being cut away.
    sequences = out_gen.sequences.to(input_ids.device)
    echoes_prompt = (
        sequences.size(1) >= prompt_len
        and torch.equal(sequences[:, :image_pos], input_ids[:, :image_pos])
    )
    gen_only   = sequences[:, prompt_len:] if echoes_prompt else sequences
    seqs_full  = torch.cat([input_ids, gen_only], dim=1)

    with torch.no_grad():
        outs_full = model(
            input_ids=seqs_full,
            images=img_tensor,
            image_sizes=[img_size],
            output_attentions=True,
            return_dict=True,
        )

    full_attn = outs_full.attentions[-1].sum(dim=1)[0].cpu()   # (seq, seq)

    # The <image> placeholder expands into the patches, so
    #   expanded_len = len(seqs_full) - 1 + n_patches.
    # Computed locally so this cell does not depend on values from another cell.
    n_patches = full_attn.size(1) - (seqs_full.size(1) - 1)
    if n_patches <= 0:
        raise ValueError(f"Could not locate image patches in the attention map (n_patches={n_patches}).")
    patch_start = image_pos
    patch_stop  = image_pos + n_patches

    # Rows: every text token after the image (rest of the prompt + generation).
    # Columns: the image patches.
    txt2img_vec = full_attn[patch_stop:, patch_start:patch_stop].mean(dim=0)
    # Averaging over zero rows yields NaN; map it to 0 so plotting still works.
    txt2img_vec = torch.nan_to_num(txt2img_vec, nan=0.0)

    L   = txt2img_vec.shape[0]
    grid = math.ceil(math.sqrt(L))
    if grid * grid != L:
        # Non-square patch count: resample to the next square length.
        txt2img_vec = F.interpolate(
            txt2img_vec.float()[None,None],
            size=grid * grid,
            mode="linear",
            align_corners=False
        )[0,0].to(txt2img_vec.dtype)
    txt_map = txt2img_vec.float().reshape(grid, grid)
    # PIL size is (width, height); interpolate expects (height, width).
    heatB = F.interpolate(
        txt_map[None,None],
        size=pil_img.size[::-1],
        mode="bilinear",
        align_corners=False
    )[0,0].cpu().numpy()
    plt.figure(figsize=(6,6))
    plt.title("Text → Patch attention (prompt + generation)")
    plt.imshow(overlay_heatmap(heatB, pil_img)); plt.axis("off"); plt.show()

# %% [markdown]
# ---
# **Inference & Visualization Complete.**
