# GraspGPT 模型推理与生成

本notebook用于从checkpoint加载GraspGPT模型，加载sequence序列数据，并进行token生成预测。

## 1. 导入必要的库和模块

In [1]:
import os
import sys
import json
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import deepspeed
from deepspeed import comm as dist

# 设置路径并导入graspGPT模块
current_dir = os.getcwd()
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, 'graspGPT'))

from graspGPT.model.model import graspGPT
from graspGPT.model.utils import CfgNode as CN
from graspGPT.model.token_manager import get_token_manager, decode_sequence, encode_sequence
from graspGPT.model.parser_and_serializer import Serializer, Parser,Seq
from graspGPT.model.core import generate_amodal_sequence
import random

print("模块导入完成")

[2025-09-29 18:51:22,931] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/wuminye/miniconda3/envs/grasp/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/wuminye/miniconda3/envs/grasp/compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status


[2025-09-29 18:51:25,495] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
模块导入完成


## 2. 配置参数

In [2]:
# Paths used by the rest of the notebook
checkpoint_path = "output/checkpoints/amodal"  # change to the actual checkpoint path
sequence_file = "output/scene_0000_objects_merged_aligned_seq.pth"  # change to the .pth file produced by align_coords.py
deepspeed_config_path = "deepspeed_config.json"  # path of the DeepSpeed config file

# Generation parameters
max_new_tokens = 2000
temperature = 0.3
do_sample = True
top_k = None
num_sequences = 1
seed = 42

# Seed every RNG this notebook has imported (random, numpy, torch) so that
# sampling-based steps are repeatable, not only the torch ones.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Seed all visible CUDA devices, not just the current one.
    torch.cuda.manual_seed_all(seed)

print(f"Checkpoint路径: {checkpoint_path}")
print(f"序列文件路径: {sequence_file}")
print(f"生成参数: max_new_tokens={max_new_tokens}, temperature={temperature}")

Checkpoint路径: output/checkpoints/amodal
序列文件路径: output/scene_0000_objects_merged_aligned_seq.pth
生成参数: max_new_tokens=2000, temperature=0.3


## 3. 辅助函数定义

In [3]:
def parse_config_string(config_str):
    """Parse a multi-line ``key: value`` string into a dictionary.

    Every line that contains a colon is split at the first colon. The
    stripped text before it is the key; the stripped text after it is
    converted as follows:

    * ``true`` / ``false`` / ``none`` (case-insensitive) become ``True`` /
      ``False`` / ``None``;
    * text wrapped in ``(...)`` or ``[...]`` becomes the Python literal it
      spells, when it is a valid literal;
    * otherwise an ``int`` is tried, then a ``float`` (this also covers
      scientific notation such as ``1e-05``);
    * anything else is kept as the stripped string.

    Lines without a colon are ignored. A repeated key keeps its last value.

    Args:
        config_str: Text with one ``key: value`` pair per line.

    Returns:
        dict mapping the key strings to the converted values.
    """
    # Local import so the function does not rely on the file's import cell.
    import ast

    config_dict = {}
    for line in config_str.strip().split('\n'):
        if ':' not in line:
            continue

        key, value = (part.strip() for part in line.split(':', 1))
        lowered = value.lower()

        if lowered == 'true':
            value = True
        elif lowered == 'false':
            value = False
        elif lowered == 'none':
            value = None
        elif (value.startswith('(') and value.endswith(')')) or (
                value.startswith('[') and value.endswith(']')):
            # literal_eval only accepts literals, so text read from a
            # checkpoint can never execute code the way eval() could.
            try:
                value = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError,
                    MemoryError, RecursionError):
                pass  # not a valid literal: keep the raw string
        else:
            try:
                value = int(value)
            except ValueError:
                # Require a digit so words such as "inf" or "nan" stay strings.
                if any(ch.isdigit() for ch in value):
                    try:
                        value = float(value)
                    except ValueError:
                        pass  # not numeric: keep the raw string

        config_dict[key] = value

    return config_dict

def load_training_config(checkpoint_dir):
    """Load the training configuration that belongs to a checkpoint.

    Lookup order:

    1. ``training_state.json`` inside ``checkpoint_dir``, using its
       ``'config'`` entry. When that entry is a dict, string-valued sections
       are first parsed with ``parse_config_string``.
    2. ``config.json`` then ``training_config.json``, searched in
       ``checkpoint_dir`` and then in its parent directory.

    Args:
        checkpoint_dir: Directory of the checkpoint.

    Returns:
        The result of ``CN.from_dict`` on the configuration that was found.

    Raises:
        FileNotFoundError: If none of the candidate files provides a config.
    """
    state_file = os.path.join(checkpoint_dir, 'training_state.json')
    if os.path.exists(state_file):
        with open(state_file, 'r') as handle:
            state = json.load(handle)
        if 'config' in state:
            raw_config = state['config']
            if not isinstance(raw_config, dict):
                return CN.from_dict(raw_config)
            # Sections stored as text are turned back into dicts.
            return CN.from_dict({
                name: parse_config_string(section) if isinstance(section, str) else section
                for name, section in raw_config.items()
            })

    # Fallback: look for a standalone config file next to the checkpoint.
    candidates = (
        os.path.join(folder, filename)
        for folder in [checkpoint_dir, os.path.dirname(checkpoint_dir)]
        for filename in ['config.json', 'training_config.json']
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            with open(candidate, 'r') as handle:
                loaded = json.load(handle)
            return CN.from_dict(loaded)

    raise FileNotFoundError(f"未找到配置文件在 {checkpoint_dir} 或其父目录中")

print("辅助函数定义完成")

辅助函数定义完成


## 4. 第一部分：从checkpoint目录加载模型和参数

In [4]:
# Load the training configuration stored with the checkpoint
print("正在加载训练配置...")
config = load_training_config(checkpoint_path)
print("训练配置加载成功")

# Report the key model settings
model_type_label = getattr(config.model, 'model_type', 'custom')
print(f"模型类型: {model_type_label}")
print(f"词汇大小: {config.model.vocab_size}")
print(f"块大小: {config.model.block_size}")

# When a model_type preset is set, blank the explicit size fields that exist
# (the original note calls this satisfying the model's "XOR condition";
# presumably the model accepts either a preset or explicit sizes - TODO confirm).
if getattr(config.model, 'model_type', None):
    for size_field in ('n_layer', 'n_head', 'n_embd'):
        if hasattr(config.model, size_field):
            setattr(config.model, size_field, None)

正在加载训练配置...
训练配置加载成功
模型类型: gpt2
词汇大小: 147219
块大小: 6000


In [5]:
# Obtain the token manager that builds the vocabulary
print("正在获取token管理器...")
token_manager = get_token_manager()

# Default volume dimensions, used unless a data file provides 'volume_dims'
img_h, img_w, img_d = 80, 54, 34
# NOTE: this overrides whatever data path the training config carried.
config.dataset.data_path = 'output/precomputed_data/'
# Try to read the real volume dimensions from one precomputed data file
if hasattr(config.dataset, 'data_path') and config.dataset.data_path:
    import glob
    # glob order is filesystem dependent; sort so the sampled file is stable.
    data_files = sorted(glob.glob(os.path.join(config.dataset.data_path, "*.pth")))
    if data_files:
        sample_file = data_files[0]
        print(f"从样本文件获取体积维度: {os.path.basename(sample_file)}")
        raw_data = torch.load(sample_file, weights_only=False)
        # The file may hold one dict or a list/tuple of per-sample dicts (a
        # later cell indexes such a file as data[5]['raw_tokens']). A plain
        # `in` test on a list would silently miss the key, so look inside
        # the first sample in that case.
        sample_record = raw_data
        if isinstance(raw_data, (list, tuple)) and raw_data:
            sample_record = raw_data[0]
        if isinstance(sample_record, dict) and 'volume_dims' in sample_record:
            img_h, img_w, img_d = sample_record['volume_dims']
            print(f"从数据获取的体积维度: {img_h}x{img_w}x{img_d}")

token_mapping = token_manager.generate_mapping(img_h, img_w, img_d)
vocab_size = len(token_mapping)

# A vocabulary that differs from the one recorded in the checkpoint's config
# will not match the trained weights, so make the override visible.
checkpoint_vocab_size = getattr(config.model, 'vocab_size', None)
if checkpoint_vocab_size is not None and checkpoint_vocab_size != vocab_size:
    print(f"警告: 配置中的词汇大小({checkpoint_vocab_size})与token映射的大小({vocab_size})不一致，将使用后者")

# Update the vocabulary size in the model config
config.model.vocab_size = vocab_size

print(f"Token管理器词汇大小: {vocab_size}")
print(f"使用的体积维度: {img_h}x{img_w}x{img_d}")

正在获取token管理器...
从样本文件获取体积维度: precomputed_batch_63938_0.pth
Token管理器词汇大小: 147219
使用的体积维度: 80x54x34


In [6]:
# Build the model from the (possibly adjusted) model config
print("正在创建模型...")
model = graspGPT(config.model)
param_count = sum(map(torch.numel, model.parameters())) / 1e6
print(f"模型创建完成: {param_count:.2f}M 参数")

# Read the DeepSpeed configuration file
print("正在加载DeepSpeed配置...")
with open(deepspeed_config_path, 'r') as f:
    ds_config = json.load(f)

# Force batch size 1 and no gradient accumulation for inference
ds_config.update(
    train_batch_size=1,
    train_micro_batch_size_per_gpu=1,
    gradient_accumulation_steps=1,
)

bf16_section = ds_config.get("bf16", {})
if bf16_section.get("enabled", False):
    print("bf16已启用用于推理")

print("DeepSpeed配置加载完成")

正在创建模型...
Using Qwen2 model with RoPE position encoding


You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour


number of parameters: 226.36M
模型创建完成: 226.36M 参数
正在加载DeepSpeed配置...
bf16已启用用于推理
DeepSpeed配置加载完成


In [7]:
# Load the checkpoint
print(f"正在加载checkpoint: {checkpoint_path}")

# Resolve the checkpoint path. DeepSpeed's load_checkpoint() takes the parent
# directory plus a tag (the last path component). normpath drops a trailing
# separator, which would otherwise produce an empty tag.
if not os.path.exists(checkpoint_path):
    raise ValueError(f"Checkpoint路径不存在: {checkpoint_path}")
normalized_checkpoint_path = os.path.normpath(checkpoint_path)
parent_dir = os.path.dirname(normalized_checkpoint_path) or '.'
tag = os.path.basename(normalized_checkpoint_path)

# First try to load through DeepSpeed
try:
    # Initialise a DeepSpeed engine around the model
    model_engine, _, _, _ = deepspeed.initialize(
        model=model,
        config=ds_config,
        model_parameters=model.parameters()
    )

    # load_checkpoint() reports a failed load by returning None as the load
    # path instead of raising; treat that as an error so the PyTorch
    # fallback below runs instead of continuing with untrained weights.
    load_path, client_state = model_engine.load_checkpoint(parent_dir, tag=tag)
    if load_path is None:
        raise RuntimeError(f"DeepSpeed未能从 {parent_dir} 加载标签 '{tag}'")
    print(f"使用DeepSpeed成功加载checkpoint")

except Exception as e:
    print(f"DeepSpeed加载失败: {e}")
    print("回退到PyTorch加载...")

    # Fallback: plain PyTorch loading. Accept either a checkpoint directory
    # or a direct path to the model-states file.
    if os.path.isfile(checkpoint_path):
        checkpoint_file = checkpoint_path
    else:
        checkpoint_file = os.path.join(checkpoint_path, 'mp_rank_00_model_states.pt')
    if not os.path.exists(checkpoint_file):
        raise ValueError(f"找不到checkpoint文件: {checkpoint_file}") from e

    print(f"使用PyTorch从以下位置加载checkpoint: {checkpoint_file}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # weights_only=False matches the other torch.load calls in this notebook;
    # the checkpoint holds more than bare tensors. Only load trusted files.
    checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False)

    # DeepSpeed stores the weights under 'module'; otherwise treat the file
    # itself as the state dict.
    state_dict = checkpoint.get('module', checkpoint)
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False hides mismatches, so report them explicitly.
    if load_result.missing_keys:
        print(f"警告: checkpoint中缺少的参数: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"警告: checkpoint中多余的参数: {load_result.unexpected_keys}")

    if torch.cuda.is_available():
        model = model.cuda()

    print(f"使用PyTorch成功加载checkpoint")

    # Minimal stand-in exposing the parts of the DeepSpeed engine interface
    # that the rest of the notebook uses (.module, .local_rank, .eval(), call).
    class ModelWrapper:
        def __init__(self, model):
            self.module = model
            self.model = model
            self.local_rank = 0

        def eval(self):
            self.module.eval()
            return self

        def __call__(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    model_engine = ModelWrapper(model)

print("模型加载完成！")

正在加载checkpoint: output/checkpoints/amodal
[2025-09-29 18:51:29,553] [INFO] [logging.py:107:log_dist] [Rank -1] DeepSpeed info: version=0.17.5, git-hash=unknown, git-branch=unknown
[2025-09-29 18:51:29,553] [INFO] [comm.py:821:init_distributed] cdb=None
[2025-09-29 18:51:29,554] [INFO] [comm.py:836:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...
[2025-09-29 18:51:29,892] [INFO] [comm.py:891:mpi_discovery] Discovered MPI settings of world_rank=0, local_rank=0, world_size=1, master_addr=172.24.93.140, master_port=29500
[2025-09-29 18:51:29,893] [INFO] [comm.py:852:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2025-09-29 18:51:29,935] [INFO] [config.py:684:__init__] Config mesh_device None world_size = 1
[2025-09-29 18:51:30,210] [INFO] [engine.py:1356:_configure_distributed_model] ********** distributed groups summary **********
	 self.dp_world_size=1
	 self.mp_world_size=1
	 self.seq_dp_world_size=1
	 self

Using /home/wuminye/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/wuminye/.cache/torch_extensions/py310_cu118/fused_adam/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.
Time to load fused_adam op: 0.018658876419067383 seconds
[2025-09-29 18:51:30,370] [INFO] [logging.py:107:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adamw as basic optimizer
[2025-09-29 18:51:30,370] [INFO] [logging.py:107:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2025-09-29 18:51:30,373] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed Basic Optimizer = FusedAdam
[2025-09-29 18:51:30,373] [INFO] [utils.py:59:is_zero_supported_optimizer] Checking ZeRO support for optimizer=FusedAdam type=<class 'deepspeed.ops.adam.fused_adam.FusedAdam'>
[2025-09-29 18:51:30,374] [INFO] [logging.py:107:log_dist] [Rank 0] Creating torch.bfloat16 ZeRO stage 1 optimizer
[2025-09-29 18:51:30,375] [INFO] [stage_1_and_2.py:178:__init__] Reduce bucket size 500000000
[2025-09-29 18:51:30,375] [INFO] [stage_1_and_2.py:179:__init__] Allgather bucket size 500000000
[2025-09-29 18:51:30,376] [INFO] [stage_1_and_2.py:180:__init__] CPU O

Loading extension module fused_adam...


[2025-09-29 18:51:30,751] [INFO] [utils.py:781:see_memory_usage] Before initializing optimizer states
[2025-09-29 18:51:30,752] [INFO] [utils.py:782:see_memory_usage] MA 1.27 GB         Max_MA 1.69 GB         CA 1.69 GB         Max_CA 2 GB 
[2025-09-29 18:51:30,753] [INFO] [utils.py:789:see_memory_usage] CPU Virtual Memory:  used = 3.38 GB, percent = 13.8%
[2025-09-29 18:51:30,887] [INFO] [utils.py:781:see_memory_usage] After initializing optimizer states
[2025-09-29 18:51:30,888] [INFO] [utils.py:782:see_memory_usage] MA 1.27 GB         Max_MA 2.11 GB         CA 2.53 GB         Max_CA 3 GB 
[2025-09-29 18:51:30,888] [INFO] [utils.py:789:see_memory_usage] CPU Virtual Memory:  used = 3.38 GB, percent = 13.8%
[2025-09-29 18:51:30,889] [INFO] [stage_1_and_2.py:605:__init__] optimizer state initialized
[2025-09-29 18:51:30,997] [INFO] [utils.py:781:see_memory_usage] After initializing ZeRO optimizer
[2025-09-29 18:51:30,998] [INFO] [utils.py:782:see_memory_usage] MA 1.27 GB         Max_MA 

## 5. 第二部分：从外部pth文件加载sequence序列

In [8]:
# Load the sequence file
print(f"正在加载序列数据: {sequence_file}")

if not os.path.exists(sequence_file):
    raise FileNotFoundError(f"序列文件不存在: {sequence_file}")

sequence_data = torch.load(sequence_file, weights_only=False)

# Show the top-level structure of what was loaded
is_mapping = isinstance(sequence_data, dict)
print(f"序列数据键: {list(sequence_data.keys()) if is_mapping else 'not a dict'}")

# Pick the token sequence: the first of the known keys that is present,
# otherwise the dict's first key; non-dict data is used as-is.
if is_mapping:
    for candidate_key in ('tokens', 'sequence', 'token_ids'):
        if candidate_key in sequence_data:
            token_sequence = sequence_data[candidate_key]
            break
    else:
        key = list(sequence_data.keys())[0]
        token_sequence = sequence_data[key]
        print(f"使用键 '{key}' 作为token序列")
else:
    token_sequence = sequence_data



正在加载序列数据: output/scene_0000_objects_merged_aligned_seq.pth
序列数据键: ['seq', 'token_sequence', 'data_list', 'volume_dims', 'voxel_size', 'bbox_min', 'bbox_max', 'scene_name']
使用键 'seq' 作为token序列


# -------------- 备选：从训练数据（training data）中加载序列

In [17]:
# Alternative input: take one sample from a precomputed training batch file.
sequence_file = 'output/precomputed_data/precomputed_batch_63938_0.pth'
# Announce which file is being loaded
print(f"正在加载序列数据: {sequence_file}")

if not os.path.exists(sequence_file):
    raise FileNotFoundError(f"序列文件不存在: {sequence_file}")

# Load the file contents
sequence_data = torch.load(sequence_file, weights_only=False)

# Keep only the 'raw_tokens' entry of the element at index 5 (hard-coded).
sequence_data = sequence_data[5]['raw_tokens']

# Inspect the structure of the selected entry
print(f"序列数据键: {list(sequence_data.keys()) if isinstance(sequence_data, dict) else 'not a dict'}")


print(f"使用键 raw_tokens 作为token序列")

# Decode the ids with token_mapping, then parse the decoded tokens.
# After this, token_sequence is the parser's result (it exposes `.items`).
token_sequence = decode_sequence(sequence_data, token_mapping)
parser = Parser(token_sequence)
token_sequence = parser.parse()
print(len(token_sequence.items))


正在加载序列数据: output/precomputed_data/precomputed_batch_63938_0.pth
序列数据键: not a dict
使用键 raw_tokens 作为token序列
2


## 构造场景数据

In [18]:
# Build a reduced scene from the parsed sequence.
sbs= token_sequence.items
# Optional shuffling of the items, currently disabled:
#random.shuffle(sbs)
# Keep only the first three entries of the first item's `sbs` list.
# NOTE: this assigns on the parsed object itself, so token_sequence changes too.
sbs[0].sbs = sbs[0].sbs[:3]
# New sequence that contains only that first item
New_seq = Seq(items=[sbs[0]])
# Print the tags of the first two kept entries (needs at least two entries)
print(sbs[0].sbs[0].tag, sbs[0].sbs[1].tag)
# Serialize the reduced sequence into flat_tokens for the prompt cells below
flat_tokens = Serializer.serialize(New_seq)

object13 object40


## 6. 第三部分：在序列尾部加入新的指定tokens作为prompt，让model接下去预测

## Grasp prediction

In [22]:
# Encode the scene tokens and drop the final id
# (presumably the sequence terminator - TODO confirm).
scene_promt = encode_sequence(flat_tokens, token_mapping)[:-1]
print(len(scene_promt),scene_promt[-10:])

# Command tokens appended after the scene as the prompt; edit this list to
# issue different commands.
prompt_token_names = [
    'detectgrasp',  # detect-grasp command
    'grasp',        # grasp command
    #'object24',
]
additional_tokens = [token_mapping[name] for name in prompt_token_names]

print(f"要添加的tokens: {additional_tokens}")
# Print the names the ids were built from: they stay aligned with
# additional_tokens and need no scan over the whole vocabulary.
print(f"对应的token名称: {prompt_token_names}")

# Append the command tokens to the end of the scene tokens
prompt_sequence = scene_promt + additional_tokens
#prompt_sequence = scene_promt

print(f"添加prompt后的序列长度: {len(prompt_sequence)}")
print(f"完整prompt序列的最后20个tokens: {prompt_sequence[-20:]}")

要添加的tokens: [96, 97]
对应的token名称: ['detectgrasp', 'grasp']
添加prompt后的序列长度: 3409
完整prompt序列的最后20个tokens: [41892, 41924, 41925, 41926, 41958, 41959, 41960, 43727, 43728, 43761, 43762, 43795, 43796, 43829, 43830, 45598, 45632, 45666, 96, 97]


## Amodal

In [26]:
# Build the amodal token sequence from the flat scene tokens and volume dims.
tokens = generate_amodal_sequence(flat_tokens,(img_h, img_w, img_d))
# Keep everything up to and including the first "unlabel" token.
# (assumes `tokens` is a list; list.index raises ValueError if it is absent)
tokens = tokens[:tokens.index("unlabel")+1]
# Encode the truncated tokens; this replaces any earlier prompt_sequence.
prompt_sequence = encode_sequence(tokens, token_mapping)
#print(tokens)

## do prediction

In [12]:
# 准备输入张量
def prepare_input_from_tokens(token_ids, max_length=None):
    """Turn a list of token ids into a ``(1, length)`` long tensor.

    Args:
        token_ids: Non-empty list of token ids.
        max_length: Optional length limit. When it is truthy and the
            sequence is longer, only the last ``max_length`` ids are kept.

    Returns:
        ``torch.long`` tensor with a leading batch dimension of 1.

    Raises:
        ValueError: If ``token_ids`` is empty.
    """
    if not token_ids:
        raise ValueError("提供的token ID列表为空")

    batch = torch.tensor([token_ids], dtype=torch.long)

    # Truncate from the left so the most recent tokens survive.
    needs_trim = bool(max_length) and batch.size(1) > max_length
    return batch[:, -max_length:] if needs_trim else batch

# Build the model input from the prompt, limited to the model's block size.
input_ids = prepare_input_from_tokens(prompt_sequence, config.model.block_size)
# Duplicate the prompt along the batch dimension to get two identical rows.
# NOTE(review): the batch size 2 is hard-coded here; confirm whether the
# `num_sequences` parameter was meant to control it.
input_ids = input_ids.repeat(2,1)
# Prompt length in tokens (second dimension of the input tensor)
original_length = input_ids.size(1)

print(f"输入张量形状: {input_ids.shape}")
print(f"原始序列长度: {original_length}")

输入张量形状: torch.Size([2, 568])
原始序列长度: 568


In [24]:
# Assemble the generation settings. This cell raises the token budget by 3000
# and triples the temperature relative to the notebook-level parameters.
generation_config = dict(
    max_new_tokens=max_new_tokens + 3000,
    temperature=temperature * 3,
    do_sample=do_sample,
    top_k=top_k,
    eos_token_id=token_mapping.get('end'),
)

print(f"生成配置: {generation_config}")

生成配置: {'max_new_tokens': 5000, 'temperature': 0.8999999999999999, 'do_sample': True, 'top_k': None, 'eos_token_id': 98}


In [25]:
# 执行生成
def generate_with_model(model_engine, input_ids, generation_config):
    """Generate a continuation of ``input_ids`` with the wrapped model.

    Args:
        model_engine: Object exposing ``.module`` (the model, which must
            provide a ``generate`` method) and ``.eval()``. Both a DeepSpeed
            engine and the notebook's ``ModelWrapper`` fit.
        input_ids: Tensor of prompt token ids.
        generation_config: Dict of settings. Keys read, with defaults:
            ``max_new_tokens`` (50), ``temperature`` (1.0), ``do_sample``
            (True), ``top_k`` (None) and ``eos_token_id`` (None, forwarded
            to ``generate`` as ``end_token``).

    Returns:
        Whatever ``model_engine.module.generate`` returns.

    Raises:
        ValueError: If the wrapped model has no ``generate`` method.
    """
    module = model_engine.module
    if not hasattr(module, 'generate'):
        raise ValueError("模型没有generate方法")

    # Put the input on the device the weights actually live on. Building
    # 'cuda:{local_rank}' unconditionally breaks CPU runs, because the
    # fallback wrapper always reports local_rank = 0.
    first_param = next(module.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # No parameters to inspect: fall back to the engine's local rank.
        local_rank = getattr(model_engine, 'local_rank', None)
        if local_rank is not None and torch.cuda.is_available():
            device = f'cuda:{local_rank}'
        else:
            device = 'cpu'

    input_ids = input_ids.to(device)

    # Switch to evaluation mode before sampling
    model_engine.eval()

    with torch.no_grad():
        generated = module.generate(
            idx=input_ids,
            max_new_tokens=generation_config.get('max_new_tokens', 50),
            temperature=generation_config.get('temperature', 1.0),
            do_sample=generation_config.get('do_sample', True),
            top_k=generation_config.get('top_k', None),
            end_token=generation_config.get('eos_token_id', None)
        )

    return generated

# Run generation and measure the elapsed wall-clock time.
print("开始生成...")
start_time = time.time()

print("输入长度:",input_ids.size())
# Generate with the loaded model, the prepared prompt and the settings above
generated = generate_with_model(model_engine, input_ids, generation_config)

generation_time = time.time() - start_time

print(f"生成完成，耗时: {generation_time:.2f}秒")
# size(1) is the second dimension of the output; the full shape is printed too.
print(f"生成的序列长度: {generated.size(1)} tokens. {generated.size()}")

开始生成...
输入长度: torch.Size([2, 568])
生成完成，耗时: 204.77秒
生成的序列长度: 5568 tokens. torch.Size([2, 5568, 1])


In [27]:
from extract_sample_and_export import visualize_tokens

# Work on a CPU copy so the `generated` tensor itself is never modified.
all_scene = generated.detach().cpu().clone()
# Drop only the trailing singleton dimension. A bare squeeze() would also
# collapse the batch dimension when a single sequence was generated, which
# breaks the row indexing below.
if all_scene.dim() == 3:
    all_scene = all_scene.squeeze(-1)
# Force every sequence to finish with the end token. Indexing (rather than
# .get) fails with a clear KeyError if the mapping has no 'end' entry,
# instead of trying to write None into the tensor.
all_scene[:, -1] = token_mapping['end']
for i, scene_tokens in enumerate(all_scene):
    visualize_tokens(
        scene_tokens,
        token_mapping,
        volume_dims=(img_h, img_w, img_d),
        bbox_min=np.array([-0.3, -0.2, 0]),
        voxel_size=0.0075,
        output_dir=f'./output/tokens_visual/{i}',
    )

=== 开始可视化tokens ===
正在转换token ids到序列...
解码得到的tokens数量: 5568
前10个tokens: ['scene', 'incomplete', (2, 24, 5), (3, 27, 6), (4, 23, 5), (4, 28, 6), (4, 29, 6), (4, 30, 1), (4, 30, 5), (5, 23, 5)]
ParseError: Unexpected token 'unknow' after parsing SEQ at position 3127
解析成功，序列包含 3 个项目
体素信息: dims=(80, 54, 34), bbox_min=[-0.3 -0.2  0. ], voxel_size=0.0075
正在按类别提取点云...
Scene SB 'incomplete': 564 个点
Amodal SB 'unlabel': 2157 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GRASP 'incomplete': 3 个点
GR

In [23]:
# Decode the generated sequence at batch index 1 to token names and print it.
print(decode_sequence(generated[1].cpu().squeeze().numpy().tolist(),token_mapping))

['scene', 'incomplete', (2, 24, 5), (3, 27, 6), (4, 23, 5), (4, 28, 6), (4, 29, 6), (4, 30, 1), (4, 30, 5), (5, 23, 5), (5, 24, 6), (5, 26, 7), (5, 27, 7), (5, 29, 6), (5, 30, 3), (5, 30, 4), (6, 23, 5), (6, 27, 7), (6, 29, 6), (7, 23, 2), (7, 23, 3), (7, 23, 4), (7, 24, 6), (7, 25, 6), (7, 29, 6), (8, 24, 1), (8, 24, 2), (8, 24, 3), (8, 24, 4), (8, 24, 5), (8, 26, 6), (8, 28, 6), (8, 29, 6), (8, 30, 2), (8, 30, 3), (8, 30, 4), (9, 25, 1), (9, 25, 2), (9, 25, 3), (9, 25, 4), (9, 25, 5), (9, 26, 1), (9, 26, 5), (9, 28, 1), (9, 28, 5), (9, 29, 2), (9, 29, 4), (10, 27, 2), (10, 27, 3), (10, 28, 3), (10, 28, 4), (40, 15, 6), (41, 14, 6), (41, 15, 6), (41, 16, 6), (41, 18, 5), (41, 44, 4), (42, 14, 6), (42, 16, 6), (42, 17, 6), (42, 18, 6), (42, 19, 2), (42, 19, 3), (42, 19, 4), (42, 47, 5), (43, 12, 5), (43, 13, 6), (43, 15, 6), (43, 16, 6), (43, 17, 6), (43, 18, 6), (43, 19, 1), (43, 19, 2), (43, 19, 4), (43, 19, 5), (43, 42, 5), (43, 43, 6), (43, 44, 6), (43, 45, 6), (43, 46, 6), (43, 47

In [None]:
# Derive the newly generated part of the first sequence. `decoded_tokens` and
# `decoded_names` were referenced by this cell but never defined.
generated_ids = generated.detach().cpu()
if generated_ids.dim() == 3:
    # Drop the trailing singleton dimension of the generate() output
    generated_ids = generated_ids.squeeze(-1)
# The output starts with the prompt, so new tokens begin at original_length.
decoded_tokens = generated_ids[0, original_length:].tolist()
decoded_names = decode_sequence(decoded_tokens, token_mapping)

# Full sequence (input prompt + generated tokens)
full_sequence = input_ids[0].cpu().numpy().tolist() + decoded_tokens

print(f"完整序列长度: {len(full_sequence)}")
print(f"完整序列的最后30个tokens: {full_sequence[-30:]}")

# token_sequence may be a plain sequence or a parsed object exposing `.items`
# (an earlier cell prints len(token_sequence.items)); support both.
try:
    original_sequence_length = len(token_sequence)
except TypeError:
    original_sequence_length = len(token_sequence.items)

# Collect the results
result = {
    'original_sequence_length': original_sequence_length,
    'prompt_tokens': additional_tokens,
    'input_length': original_length,
    'generated_tokens': decoded_tokens,
    'generated_token_names': decoded_names,
    'total_length': len(full_sequence),
    'generation_time': generation_time,
    'generation_config': generation_config,
    'full_sequence': full_sequence
}

print("\n=== 生成结果总结 ===")
print(f"原始序列长度: {result['original_sequence_length']}")
print(f"添加的prompt tokens: {result['prompt_tokens']}")
print(f"输入长度: {result['input_length']}")
print(f"生成的新tokens: {result['generated_tokens']}")
print(f"生成时间: {result['generation_time']:.2f}秒")

## 7. 可选：可视化生成结果

In [None]:
# Optional: visualize the result (needs the extract_sample_and_export module)
try:
    # Also look for the module one directory above the working directory
    sys.path.append('../')
    from extract_sample_and_export import visualize_tokens

    # Visualization parameters (adjust to the actual data)
    volume_dims = (img_h, img_w, img_d)
    bbox_min = np.array([-0.3, -0.2, 0])  # adjust to the actual data
    voxel_size = 0.0075  # adjust to the actual data

    # Visualize the full sequence
    output_dir = "./output/generation_visual/notebook_result"

    visualize_tokens(
        tokens=full_sequence,
        token_mapping=token_mapping,
        volume_dims=volume_dims,
        bbox_min=bbox_min,
        voxel_size=voxel_size,
        output_dir=output_dir
    )

    print(f"可视化结果保存到: {output_dir}")

except ImportError:
    # The module is missing: skip visualization instead of failing
    print("可视化模块不可用，跳过可视化步骤")
except Exception as e:
    # Best effort: any other visualization error is reported, not raised
    print(f"可视化过程中出错: {e}")

## 8. 保存结果到文件

In [None]:
# Save the result dict as JSON in the current working directory
output_file = "generation_result.json"
with open(output_file, 'w') as f:
    json.dump(result, f, indent=2)

print(f"结果已保存到: {output_file}")

# Also save as a .pth file for later use; the full sequence is stored at
# the top level and again inside 'metadata' (as result['full_sequence']).
torch.save({
    'generated_sequence': full_sequence,
    'metadata': result
}, "generation_result.pth")

print("结果也已保存为generation_result.pth")
print("\n生成完成！")