In [7]:
import torch
import torch.onnx
import numpy as np
import os
from collections import OrderedDict

# 首先需要重新定义网络结构（与您原始代码中的结构相同）
class PolicyNet(torch.nn.Module):
    """Two-layer MLP policy producing a probability distribution over actions.

    Args:
        state_dim: Number of features in each input state vector.
        hidden_dim: Width of the single hidden layer.
        action_dim: Number of outputs (one probability per action).
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # The attribute names fix the state_dict keys (fc1.*, fc2.*), so they
        # must stay as-is for saved checkpoints to load.
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return the softmax (taken over dim 1) of the network output for ``x``."""
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return torch.softmax(logits, dim=1)

class QValueNet(torch.nn.Module):
    """Two-layer MLP critic producing one unnormalised value per action.

    Args:
        state_dim: Number of features in each input state vector.
        hidden_dim: Width of the single hidden layer.
        action_dim: Number of outputs (one value per action).
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # The attribute names fix the state_dict keys (fc1.*, fc2.*), so they
        # must stay as-is for saved checkpoints to load.
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return the raw output of the second linear layer for ``x``."""
        hidden = torch.relu(self.fc1(x))
        q_values = self.fc2(hidden)
        return q_values

def convert_cql_model_to_onnx(model_path, output_dir="./onnx_models"):
    """Export the actor and both critics of a CQL checkpoint to ONNX files.

    Writes ``cql_actor.onnx``, ``cql_critic1.onnx`` and ``cql_critic2.onnx``
    into ``output_dir`` and then runs ``verify_onnx_models`` on them.

    The layer sizes are read from the shapes of the actor's saved weights, so
    checkpoints trained with other dimensions load as well. The previous
    fixed values (state_dim=39, hidden_dim=128, and ``action_dim`` from the
    checkpoint's hyperparameters, defaulting to 5) are used only when the
    expected weight entries are absent.

    Args:
        model_path: Path to the saved CQL checkpoint.
        output_dir: Directory the ONNX files are written to; created if
            it does not exist.

    Raises:
        KeyError: If the checkpoint lacks ``actor_state_dict``,
            ``critic_1_state_dict`` or ``critic_2_state_dict``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(model_path, map_location='cpu')
    hyperparams = checkpoint.get('hyperparameters', {})

    # A Linear layer stores its weight as (out_features, in_features), so the
    # saved tensors determine every dimension the networks need.
    actor_weights = checkpoint['actor_state_dict']
    fc1_weight = actor_weights.get('fc1.weight')
    fc2_weight = actor_weights.get('fc2.weight')
    if fc1_weight is not None:
        hidden_dim, state_dim = fc1_weight.shape
    else:
        state_dim, hidden_dim = 39, 128
    if fc2_weight is not None:
        action_dim = fc2_weight.shape[0]
    else:
        action_dim = hyperparams.get('action_dim', 5)

    print(f"网络参数: state_dim={state_dim}, hidden_dim={hidden_dim}, action_dim={action_dim}")

    actor = PolicyNet(state_dim, hidden_dim, action_dim)
    critic_1 = QValueNet(state_dim, hidden_dim, action_dim)
    critic_2 = QValueNet(state_dim, hidden_dim, action_dim)

    actor.load_state_dict(actor_weights)
    critic_1.load_state_dict(checkpoint['critic_1_state_dict'])
    critic_2.load_state_dict(checkpoint['critic_2_state_dict'])

    # Example input used for tracing; the batch axis is declared dynamic below.
    dummy_input = torch.randn(1, state_dim)

    # (log label, network, output file name, ONNX output tensor name)
    export_plan = (
        ("Actor", actor, "cql_actor.onnx", "action_probs"),
        ("Critic1", critic_1, "cql_critic1.onnx", "q_values"),
        ("Critic2", critic_2, "cql_critic2.onnx", "q_values"),
    )
    for label, network, file_name, output_name in export_plan:
        network.eval()
        print(f"正在转换{label}网络...")
        torch.onnx.export(
            network,
            dummy_input,
            os.path.join(output_dir, file_name),
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['state'],
            output_names=[output_name],
            dynamic_axes={
                'state': {0: 'batch_size'},
                output_name: {0: 'batch_size'},
            },
        )

    print(f"转换完成！ONNX文件保存在: {output_dir}")

    # Sanity-check the exported files with the same example input.
    verify_onnx_models(output_dir, dummy_input)

def verify_onnx_models(onnx_dir, test_input):
    """Check the exported actor/critic ONNX files and run one inference on each.

    For ``cql_actor.onnx``, ``cql_critic1.onnx`` and ``cql_critic2.onnx`` in
    ``onnx_dir`` this loads the model, runs the ONNX checker, executes it once
    with ONNX Runtime on the CPU and prints the shape of the first output.
    If ``onnx`` or ``onnxruntime`` cannot be imported, a warning is printed
    and verification is skipped.

    Args:
        onnx_dir: Directory containing the exported ONNX files.
        test_input: Torch tensor fed to each model as its first input.
    """
    # Only the imports are guarded: an ImportError raised later, from inside
    # checking or inference, must not be reported as "package not installed".
    try:
        import onnx
        import onnxruntime as ort
    except ImportError:
        print("警告: 没有安装onnx或onnxruntime，跳过验证步骤")
        print("可以运行: pip install onnx onnxruntime 来安装验证工具")
        return

    # .numpy() fails on tensors that track gradients or live off the CPU,
    # so detach and move to the CPU before converting.
    input_array = test_input.detach().cpu().numpy()

    for model_file in ('cql_actor.onnx', 'cql_critic1.onnx', 'cql_critic2.onnx'):
        model_path = os.path.join(onnx_dir, model_file)

        # Structural validation of the serialized graph.
        onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)

        # Name the provider explicitly: newer onnxruntime builds with extra
        # execution providers refuse to create a session without one.
        ort_session = ort.InferenceSession(
            model_path, providers=['CPUExecutionProvider']
        )
        input_name = ort_session.get_inputs()[0].name
        output = ort_session.run(None, {input_name: input_array})

        print(f"✓ {model_file} 验证成功，输出形状: {output[0].shape}")

# 创建一个简化的推理类，也可以转换为ONNX
class CQLActorInference(torch.nn.Module):
    """Inference wrapper that pairs the actor's probabilities with its top action.

    Args:
        state_dim: Number of features in each input state vector.
        hidden_dim: Width of the actor's hidden layer.
        action_dim: Number of actions the actor outputs probabilities for.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim)

    def forward(self, x):
        """Return ``(probs, action)``: actor output and its argmax over dim 1."""
        action_probs = self.actor(x)
        best_action = action_probs.argmax(dim=1)
        return action_probs, best_action

def convert_cql_inference_model(model_path, output_dir="./onnx_models"):
    """Export the actor wrapped in ``CQLActorInference`` to ``cql_inference.onnx``.

    The exported graph takes ``state`` and returns ``action_probs`` and
    ``predicted_action``; the batch axis of all three is dynamic.

    The layer sizes are read from the shapes of the actor's saved weights, so
    checkpoints trained with other dimensions load as well. The previous
    fixed values (state_dim=39, hidden_dim=128, and ``action_dim`` from the
    checkpoint's hyperparameters, defaulting to 5) are used only when the
    expected weight entries are absent.

    Args:
        model_path: Path to the saved CQL checkpoint.
        output_dir: Directory the ONNX file is written to; created if it
            does not exist.

    Raises:
        KeyError: If the checkpoint lacks ``actor_state_dict``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(model_path, map_location='cpu')
    hyperparams = checkpoint.get('hyperparameters', {})

    # A Linear layer stores its weight as (out_features, in_features), so the
    # saved tensors determine every dimension the network needs.
    actor_weights = checkpoint['actor_state_dict']
    fc1_weight = actor_weights.get('fc1.weight')
    fc2_weight = actor_weights.get('fc2.weight')
    if fc1_weight is not None:
        hidden_dim, state_dim = fc1_weight.shape
    else:
        state_dim, hidden_dim = 39, 128
    if fc2_weight is not None:
        action_dim = fc2_weight.shape[0]
    else:
        action_dim = hyperparams.get('action_dim', 5)

    inference_model = CQLActorInference(state_dim, hidden_dim, action_dim)
    inference_model.actor.load_state_dict(actor_weights)
    inference_model.eval()

    # Example input used for tracing; the batch axis is declared dynamic below.
    dummy_input = torch.randn(1, state_dim)

    print("正在转换推理模型...")
    torch.onnx.export(
        inference_model,
        dummy_input,
        os.path.join(output_dir, "cql_inference.onnx"),
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['state'],
        output_names=['action_probs', 'predicted_action'],
        dynamic_axes={
            'state': {0: 'batch_size'},
            'action_probs': {0: 'batch_size'},
            'predicted_action': {0: 'batch_size'},
        },
    )

    print("推理模型转换完成！")

if __name__ == "__main__":
    # Checkpoint to convert.
    model_path = "./saved_models/cql_model_64.pth"

    if not os.path.exists(model_path):
        # Nothing to convert: report the missing file and stop.
        print(f"错误: 模型文件不存在 {model_path}")
        print("请确保模型文件路径正确")
    else:
        print("开始转换CQL模型...")

        # Export the individual networks first, then the combined inference model.
        convert_cql_model_to_onnx(model_path)
        convert_cql_inference_model(model_path)

        # Summary of the files produced by the two calls above.
        summary_lines = (
            "\n所有转换完成！",
            "生成的文件:",
            "- cql_actor.onnx: Actor网络",
            "- cql_critic1.onnx: Critic1网络",
            "- cql_critic2.onnx: Critic2网络",
            "- cql_inference.onnx: 完整推理模型",
        )
        for summary_line in summary_lines:
            print(summary_line)

开始转换CQL模型...
网络参数: state_dim=39, hidden_dim=128, action_dim=5
正在转换Actor网络...
正在转换Critic1网络...
正在转换Critic2网络...
转换完成！ONNX文件保存在: ./onnx_models
✓ cql_actor.onnx 验证成功，输出形状: (1, 5)
✓ cql_critic1.onnx 验证成功，输出形状: (1, 5)
✓ cql_critic2.onnx 验证成功，输出形状: (1, 5)
正在转换推理模型...
推理模型转换完成！

所有转换完成！
生成的文件:
- cql_actor.onnx: Actor网络
- cql_critic1.onnx: Critic1网络
- cql_critic2.onnx: Critic2网络
- cql_inference.onnx: 完整推理模型
