In [9]:
#!/usr/bin/env python3
# coding: utf-8
"""
Enhanced BC model diagnostic script.
Features:
 - robust checkpoint loading and partial-parameter copying
 - detailed layer-by-layer statistics
 - auto-insert a proprio adapter linear when input dim mismatch detected (fixes mat1/mat2 shape error)
 - supports npz samples / env / fallback random obs
 - scaler support (Standardizer.load if available)
"""

import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path

# ---------- repo root detection (兼容 notebook/script) ----------
try:
    REPO_ROOT = Path(__file__).resolve().parents[3]
except NameError:
    REPO_ROOT = Path.cwd().resolve().parents[1]

if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

print(f"Using repo root: {REPO_ROOT}")

# ---------- try imports from project ----------
BCPolicy = None
FrankaGym = None
Standardizer = None
try:
    from wdy_file.useful_scripts.bc_pipeline import BCPolicy, Standardizer  # type: ignore
    Standardizer = Standardizer
except Exception as e:
    try:
        from wdy_file.useful_scripts.bc_pipeline import BCPolicy  # type: ignore
    except Exception as e2:
        BCPolicy = None
        print("警告：无法导入 BCPolicy:", e2)

try:
    from wdy_file.wdy_assemble_gym import FrankaGym
except Exception:
    FrankaGym = None

# ---------- checkpoint loader ----------
def load_checkpoint(ckpt_path, device):
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)
    state = None
    model_kwargs = {}
    if isinstance(ckpt, dict):
        # try common keys
        state = ckpt.get('model') or ckpt.get('model_state_dict') or ckpt.get('state_dict') or None
        model_kwargs = ckpt.get('model_kwargs', {}) or {}
    else:
        state = ckpt
    # strip "module." prefix if present
    if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
        new = {}
        for k,v in state.items():
            nk = k[len("module."): ] if k.startswith("module.") else k
            new[nk] = v
        state = new
    return state, model_kwargs, ckpt

# ---------- helper: summary of state dict vs model dict ----------
def compare_state_and_model(state_dict, model):
    ms = model.state_dict()
    matches = []
    mismatches = []
    for k_m, v_m in ms.items():
        if state_dict is None:
            mismatches.append((k_m, v_m.shape, None))
            continue
        v = state_dict.get(k_m, None)
        if v is None:
            mismatches.append((k_m, v_m.shape, None))
        else:
            if v.shape == v_m.shape:
                matches.append((k_m, v_m.shape))
            else:
                mismatches.append((k_m, v_m.shape, v.shape))
    return matches, mismatches

# ---------- build model and attempt robust load ----------
def build_model_from_kwargs(model_kwargs, state_dict, device):
    if BCPolicy is None:
        raise RuntimeError("BCPolicy class not available for constructing model.")
    if not model_kwargs:
        raise RuntimeError("checkpoint 缺少 model_kwargs，请提供训练时的 model 参数（input_dim/output_dim/hidden_fusion 等）")
    # map known keys
    input_dim = model_kwargs.get('input_dim') or model_kwargs.get('proprio_dim') or model_kwargs.get('proprio_input_dim') or None
    output_dim = model_kwargs.get('output_dim') or model_kwargs.get('action_dim') or 7
    hidden_fusion = model_kwargs.get('hidden_fusion', (512,512,256))

    if input_dim is None:
        print("警告：无法从 model_kwargs 推断 input_dim/proprio_dim，继续构造模型但后续可能需要 adapter。")
        input_dim = 10  # placeholder

    # construct model (try-catch)
    model = BCPolicy(input_dim=int(input_dim), hidden_fusion=hidden_fusion, output_dim=int(output_dim))
    model.to(device)

    if state_dict is not None:
        ms = model.state_dict()
        # exact matches
        load_keys = {k: v for k, v in state_dict.items() if (k in ms and v.shape == ms[k].shape)}
        if load_keys:
            ms.update(load_keys)
            model.load_state_dict(ms, strict=False)
            print(f"加载 {len(load_keys)} 个完全匹配的参数。")
        else:
            # partial copy
            partial = {}
            for k,v in state_dict.items():
                if k in ms:
                    v_m = ms[k]
                    if v.ndim == v_m.ndim:
                        mins = tuple(min(a,b) for a,b in zip(v.shape, v_m.shape))
                        if any(m==0 for m in mins):
                            continue
                        new = v_m.clone()
                        slices = tuple(slice(0,m) for m in mins)
                        new[slices] = v[slices]
                        partial[k] = new
            if partial:
                ms.update(partial)
                model.load_state_dict(ms, strict=False)
                print(f"部分拷贝 {len(partial)} 参数（重叠区域）。")
            else:
                print("未加载 checkpoint 参数（找不到匹配形状），返回随机初始化模型。")
    # show comparison
    if state_dict is not None:
        matches, mismatches = compare_state_and_model(state_dict, model)
        print(f"state/model 匹配项: {len(matches)}, 不匹配项: {len(mismatches)} (展示前10项)")
        for m in mismatches[:10]:
            print("  mismatch:", m)
    return model

# ---------- wrapper model that applies adapter to proprio input ----------
class ModelWithAdapter(nn.Module):
    def __init__(self, model, adapter: nn.Module, adapter_name="proprio_adapter"):
        super().__init__()
        self.model = model
        self.adapter = adapter
        self.adapter_name = adapter_name

    def forward(self, front, hand, proprio):
        # allow numpy / tensor
        if isinstance(proprio, np.ndarray):
            proprio = torch.tensor(proprio, dtype=torch.float32, device=next(self.model.parameters()).device)
        # if adapter exists, apply it
        if self.adapter is not None:
            # adapter expects shape (B, D)
            proprio = self.adapter(proprio)
        return self.model(front, hand, proprio)

# ---------- try to auto-wrap with adapter if forward fails due to matmul shape ----------
def wrap_model_with_adapter_if_needed(model, sample_front, sample_hand, sample_proprio, device):
    """
    Try a forward pass. If it throws a matmul shape mismatch that indicates
    proprio length != model expected; infer expected_in from first Linear layer
    (heuristic) and build adapter: nn.Linear(actual_proprio_dim, expected_in).
    Returns (model_maybe_wrapped, adapter_created_flag).
    """
    model = model.eval()
    sample_front = sample_front.to(device)
    sample_hand = sample_hand.to(device)
    sample_proprio = sample_proprio.to(device)
    try:
        with torch.no_grad():
            _ = model(sample_front, sample_hand, sample_proprio)
        return model, False
    except RuntimeError as e:
        msg = str(e)
        print("初次前向时捕获 RuntimeError:", msg)
        # try to detect matmul shape error
        if "mat1 and mat2 shapes cannot be multiplied" in msg or "mat1 and mat2" in msg:
            # infer actual proprio dim
            actual_dim = int(sample_proprio.shape[1])
            # heuristic: find a linear layer whose in_features is small (likely expects proprio)
            candidate_in = None
            for name, m in model.named_modules():
                if isinstance(m, nn.Linear):
                    in_f = m.weight.shape[1]
                    # prefer smaller in_f relative to actual_dim, or first found
                    if candidate_in is None or abs(in_f - actual_dim) < abs(candidate_in - actual_dim):
                        candidate_in = int(in_f)
                        candidate_name = name
            if candidate_in is None:
                print("无法找到线性层来推断期望的 proprio 维度，请手动检查模型结构。")
                raise
            print(f"检测到实际 proprio dim={actual_dim}，模型可能期望 dim={candidate_in}（推断于层 '{candidate_name}'）。")
            # create adapter
            adapter = nn.Linear(actual_dim, candidate_in)
            # init adapter weights sensibly (small)
            nn.init.xavier_uniform_(adapter.weight, gain=0.01)
            if adapter.bias is not None:
                nn.init.constant_(adapter.bias, 0.0)
            adapter = adapter.to(device)
            wrapped = ModelWithAdapter(model, adapter)
            try:
                with torch.no_grad():
                    _ = wrapped(sample_front, sample_hand, sample_proprio)
                print("已插入 proprio adapter 并成功前向。")
                return wrapped, True
            except Exception as e2:
                print("插入 adapter 后仍然失败，错误：", e2)
                raise
        else:
            # other runtime error: re-raise
            raise

# ---------- preprocessing obs (mostly from your code, enhanced) ----------
def preprocess_obs_for_model(obs, device, scaler=None, target_img_hw=(128,128)):
    # 取第一帧
    arr = obs['front_rgb'][0]  # shape (N,128,128,3)
    front_camera_rgb = torch.tensor(arr,dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2).to(device)
    arr = obs['wrist_rgb'][0]  # shape (N, 128,128,3)
    hand_camera_rgb = torch.tensor(arr,dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2).to(device)
    # 取 proprio obs
    # 1. 按训练时处理观测
    for k,v in obs.items():
        if k.endswith('contact_forces'):
            contact_forces = v

    proprio_vec = np.concatenate([
        obs['ee_pose'][0],
        obs['joint_states'][0],
        obs['joint_forces'][0],
        contact_forces[0]
    ],axis=0)
    proprio_vec = torch.tensor(proprio_vec, dtype=torch.float32).unsqueeze(0).to(device)
    proprio_scaled = scaler.transform(proprio_vec)
    return front_camera_rgb, hand_camera_rgb, proprio_scaled

# ---------- enhanced layer stats helper ----------
def print_model_stats(model):
    print("模型层统计（线性/conv/bn）:")
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear):
            w = m.weight.detach().cpu().numpy()
            b = m.bias.detach().cpu().numpy() if m.bias is not None else None
            print(f"  Linear {name}: weight {w.shape}, mean {w.mean():.4e}, min {w.min():.4e}, max {w.max():.4e}")
            if b is not None:
                print(f"    bias mean {b.mean():.4e}, min {b.min():.4e}, max {b.max():.4e}")
        elif isinstance(m, nn.Conv2d):
            w = m.weight.detach().cpu().numpy()
            print(f"  Conv2d {name}: weight {w.shape}, mean {w.mean():.4e}")
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            print(f"  BN {name}")

# ---------- main diagnostic pass ----------
def diagnose_once(model, obs, device, scaler=None, env=None):
    model.eval()
    front_t, hand_t, proprio_t = preprocess_obs_for_model(obs, device, scaler)
    # quick checks input shapes
    print("输入形状: front", tuple(front_t.shape), "hand", tuple(hand_t.shape), "proprio", tuple(proprio_t.shape))
    # check NaN/Inf in inputs
    for name, t in [('front', front_t), ('hand', hand_t), ('proprio', proprio_t)]:
        arr = t.cpu().numpy()
        print(f"  {name} NaN: {np.isnan(arr).any()}, Inf: {np.isinf(arr).any()}, mean {arr.mean():.4e}, min {arr.min():.4e}, max {arr.max():.4e}")
    # try forward (model may be wrapped already)
    try:
        with torch.no_grad():
            out = model(front_t, hand_t, proprio_t)
    except RuntimeError as e:
        print("首次前向抛出错误：", e)
        raise
    if isinstance(out, torch.Tensor):
        out_np = out.cpu().numpy()[0]
    else:
        out_np = np.asarray(out)[0]
    print("=== forward 输出统计 ===")
    print("raw pre-tanh min/max/mean:", float(out_np.min()), float(out_np.max()), float(out_np.mean()))
    after_tanh = np.tanh(out_np)
    print("post-tanh min/max/mean:", float(after_tanh.min()), float(after_tanh.max()), float(after_tanh.mean()))
    # print last linear if any
    last_linear = None
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear):
            last_linear = m
    if last_linear is not None:
        w = last_linear.weight.detach().cpu().numpy()
        b = last_linear.bias.detach().cpu().numpy() if last_linear.bias is not None else None
        print("last_linear weight mean/min/max:", float(w.mean()), float(w.min()), float(w.max()))
        if b is not None:
            print("last_linear bias mean/min/max:", float(b.mean()), float(b.min()), float(b.max()))
    # image enc stats
    if hasattr(model, "img_enc"):
        try:
            with torch.no_grad():
                f1 = model.img_enc(front_t)
                f2 = model.img_enc(hand_t)
            print("img feat f1 mean/min/max:", float(f1.mean()), float(f1.min()), float(f1.max()))
            print("img feat f2 mean/min/max:", float(f2.mean()), float(f2.min()), float(f2.max()))
        except Exception as e:
            print("尝试获取 img_enc 特征时出错：", e)
    print("proprio (raw) stats min/max/mean:", float(proprio_t.min()), float(proprio_t.max()), float(proprio_t.mean()))
    if env is not None:
        try:
            print("env.action_space low:", env.action_space.low, " high:", env.action_space.high)
            low = env.action_space.low
            high = env.action_space.high
        except Exception:
            low, high = -1.0, 1.0
    else:
        low, high = -1.0, 1.0
    # mapping examples
    action_a = low + (after_tanh + 1.0) * 0.5 * (high - low)
    max_sym = np.maximum(np.abs(low), np.abs(high))
    action_b = after_tanh * max_sym
    print("mapped action A (env-range):", action_a.astype(np.float32))
    print("mapped action B (sym-debug):", action_b.astype(np.float32))
    # check if outputs are all positive/constant
    if np.allclose(out_np, out_np[0]):
        print("警告：输出所有元素相等，可能最后一层偏置过大或模型 collapse。")
    if np.all(out_np >= 0):
        print("注意：原始输出全部为非负。请检查最后激活或输出尺度（是否需要 tanh/clip）。")
    # more model stats
    print_model_stats(model)
    return out_np, after_tanh, action_a, action_b

# ---------- main ----------
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", default="/home/wdy02/wdy_program/simulation_plus/IsaacLab/wdy_data/bc_data/bc_ckpts/ONE_PEG_IN_HOLE_60/best.pt", help="模型 checkpoint 路径")
    p.add_argument("--scaler", default="/home/wdy02/wdy_program/simulation_plus/IsaacLab/wdy_data/bc_data/bc_ckpts/ONE_PEG_IN_HOLE_60/scaler.pkl", help="可选 scaler 路径（若有 Standardizer.load）")
    p.add_argument("--use_env", action="store_true", help="从环境获取观测（需要 IsaacSim 环境可用）")
    p.add_argument("--sample_npz", default="/home/wdy02/software/isaacsim/wdy_data/bc_data/npz/ONE_PEG_IN_HOLE_60/demo_0.npz", help="可选：从 npz 文件读取观测样本（优先）")
    p.add_argument("--num_samples", type=int, default=1, help="要诊断的样本数")
    import sys as _sys
    if 'ipykernel' in _sys.modules:
        args, _ = p.parse_known_args()
    else:
        args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    state, model_kwargs, raw_ckpt = load_checkpoint(args.ckpt, device)
    model = build_model_from_kwargs(model_kwargs, state, device)

    # prepare a dummy sample for forward-check and possible adapter insertion
    dummy_front = torch.zeros(1,3,128,128, device=device)
    dummy_hand = torch.zeros(1,3,128,128, device=device)
    # dummy proprio constructed from model_kwargs if possible, else fallback
    proprio_len_guess = int(model_kwargs.get('input_dim') or model_kwargs.get('proprio_dim') or 10)
    dummy_proprio = torch.zeros(1, proprio_len_guess, device=device)

    # try to wrap model with adapter if needed (catches matmul mismatch)
    try:
        model, adapter_created = wrap_model_with_adapter_if_needed(model, dummy_front, dummy_hand, dummy_proprio, device)
        if adapter_created:
            print("注意：已自动插入 proprio adapter（临时解决维度不匹配）。建议修正 checkpoint.model_kwargs 或模型定义以匹配真实 proprio 维度。")
    except Exception as e:
        print("尝试自动插入 adapter 失败，继续但你可能需要手动检查模型/输入维度。错误：", e)

    scaler = None
    if args.scaler and Standardizer is not None:
        try:
            scaler = Standardizer.load(args.scaler)
            print("Loaded scaler from", args.scaler)
        except Exception as e:
            print("载入 scaler 失败:", e)
            scaler = None

    env = None
    if args.use_env:
        if FrankaGym is None:
            print("无法导入 FrankaGym，use_env 选项不可用。")
        else:
            env = FrankaGym(render=False)
            obs = env.reset()

    # load samples
    samples = []
    if args.sample_npz:
        if os.path.exists(args.sample_npz):
            d = np.load(args.sample_npz, allow_pickle=True)
            if 'obs' in d:
                samples = list(d['obs'])
            elif 'data' in d:
                samples = list(d['data'])
            elif 'observations' in d:
                samples = list(d['observations'])
            else:
                sample = {}
                for k in d.files:
                    sample[k] = d[k]
                samples = [sample]
            print(f"从 npz 读取到 {len(samples)} 个样本（优先使用）。")
        else:
            print("sample_npz 不存在:", args.sample_npz)

    for i in range(args.num_samples):
        if samples:
            obs = samples[i % len(samples)]
        elif env is not None:
            obs = env.get_observation()
        else:
            obs = {
                'agent_rgb': (np.random.rand(128,128,3)*255).astype(np.uint8),
                'hand_rgb': (np.random.rand(128,128,3)*255).astype(np.uint8),
                'ee_pose': np.zeros(7),
                'joint_positions': np.zeros(7),
                'joint_velocities': np.zeros(7),
            }
        print(f"\n--- sample {i+1} ---")
        try:
            diagnose_once(model, obs, device, scaler=scaler, env=env)
        except Exception as e:
            print("诊断 sample 时出错：", e)
            # continue to next sample

if __name__ == "__main__":
    main()


Using repo root: /home/wdy02/wdy_program/simulation_plus/IsaacLab
Using device: cuda


  ckpt = torch.load(ckpt_path, map_location=device)


加载 324 个完全匹配的参数。
state/model 匹配项: 324, 不匹配项: 6 (展示前10项)
  mismatch: ('fusion.2.weight', torch.Size([512, 512]), None)
  mismatch: ('fusion.2.bias', torch.Size([512]), None)
  mismatch: ('fusion.4.weight', torch.Size([256, 512]), None)
  mismatch: ('fusion.4.bias', torch.Size([256]), None)
  mismatch: ('fusion.6.weight', torch.Size([7, 256]), torch.Size([256, 512]))
  mismatch: ('fusion.6.bias', torch.Size([7]), torch.Size([256]))
Loaded scaler from /home/wdy02/wdy_program/simulation_plus/IsaacLab/wdy_data/bc_data/bc_ckpts/ONE_PEG_IN_HOLE_60/scaler.pkl
从 npz 读取到 1 个样本（优先使用）。

--- sample 1 ---
输入形状: front (1, 3, 128, 128) hand (1, 3, 128, 128) proprio (1, 78)
  front NaN: False, Inf: False, mean 1.4874e+02, min 0.0000e+00, max 2.5500e+02
  hand NaN: False, Inf: False, mean 1.7075e+02, min 1.0000e+00, max 2.5400e+02
  proprio NaN: False, Inf: False, mean 1.6170e-01, min -2.8698e+00, max 3.2277e+00
=== forward 输出统计 ===
raw pre-tanh min/max/mean: -0.04922686889767647 0.046364836394786835 0.