In [None]:
# Physical GPU indices this process is allowed to use.
SELECTED_GPUS = [0,1]
# Environment setup
import os

# Restrict CUDA to the selected GPUs. This is done before TensorFlow is imported
# (presumably so TF only enumerates these devices); the distribution strategy
# below then addresses the visible devices as /gpu:0 .. /gpu:N-1.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])

import tensorflow as tf 

"""
https://github.com/tensorflow/tensorflow/issues/34415#issuecomment-895336269
https://stackoverflow.com/questions/59616436/how-to-reset-initialization-in-tensorflow-2
"""
# Cap TensorFlow's intra-op and inter-op CPU thread pools.
MAX_CPU_THREADS = 16
tf.config.threading.set_intra_op_parallelism_threads(MAX_CPU_THREADS)
tf.config.threading.set_inter_op_parallelism_threads(MAX_CPU_THREADS)

tf.get_logger().setLevel('INFO')

# Enable memory growth on every visible GPU so memory is allocated on demand
# rather than reserved up front.
GPUS = tf.config.experimental.list_physical_devices('GPU')
for gpu in GPUS:
    tf.config.experimental.set_memory_growth(gpu, True)

# Data-parallel strategy over the visible GPUs using NCCL all-reduce.
DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.NcclAllReduce(),
    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]
)

# Number of replicas the strategy keeps in sync (one per listed device).
NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync

print('Number of devices: {}'.format(NUM_GPUS))

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import json
import time
import random
from collections import deque
import psutil  # 用于监控系统内存
import gc      # 垃圾回收

# 导入基类
from dqn_split_basic import BaseDNNSplitter
from download_resnet50 import patch_dnn_splitter

import logging
import datetime

# 创建数据和结果目录
def ensure_directories():
    """Create the notebook's working directories if they are missing.

    The directories ``data``, ``models``, ``result`` and ``log`` are created
    relative to the current working directory; existing ones are left as-is.
    """
    for folder in ('data', 'models', 'result', 'log'):
        os.makedirs(folder, exist_ok=True)

# Initialise the working directories at import time of this cell
ensure_directories()


In [None]:
class DQNSplitTrainer(BaseDNNSplitter):
    """DNN split system that learns split decisions with a (double) DQN.

    Extends ``BaseDNNSplitter`` with the DQN training state. The DQN
    behaviour itself (``remember``, ``act``, ``replay``, ...) is written as
    free functions in later cells and bound onto this class afterwards.
    """

    def __init__(self, config):
        """Initialise the DQN trainer.

        Args:
            config: configuration dict with model, environment and algorithm
                parameters. Besides whatever ``BaseDNNSplitter`` consumes, the
                following optional keys are read here (defaults in brackets):
                ``logger`` [logger named after the class], ``memory_size``
                [2000], ``gamma`` [0.95], ``epsilon`` [1.0], ``epsilon_min``
                [0.01], ``epsilon_decay`` [0.995].
        """
        super().__init__(config)
        self.logger = config.get('logger', logging.getLogger(self.__class__.__name__))

        # Experience replay buffer; the oldest transitions drop out once full.
        self.memory = deque(maxlen=config.get('memory_size', 2000))

        # Core DQN hyper-parameters (overridable through config).
        self.gamma = config.get('gamma', 0.95)            # discount factor
        self.epsilon = config.get('epsilon', 1.0)         # exploration rate
        self.epsilon_min = config.get('epsilon_min', 0.01)
        self.epsilon_decay = config.get('epsilon_decay', 0.995)
        self.logger.info(f"DQN参数设置完成: γ={self.gamma}, ε初始值={self.epsilon}, ε最小值={self.epsilon_min}")

        # The online network must exist before the target network is synced from it.
        # build_dqn_model is not defined here; presumably provided by BaseDNNSplitter.
        self.logger.info("初始化主DQN模型...")
        self.model = self.build_dqn_model()
        self.logger.info("主DQN模型已初始化")

        self.logger.info("初始化DQN目标网络...")
        self.target_model = self.build_dqn_model()

        # Start both networks from identical weights.
        self.update_target_model()

        self.logger.info("DQN训练器初始化完成")

In [None]:
def remember(self, state, action, reward, next_state, done):
    """Append one transition to the replay buffer ``self.memory``.

    The transition is stored as ``(state, action, reward, next_state, done)``.
    Whenever the buffer length is a multiple of 100 the size is logged at
    DEBUG level.
    """
    transition = (state, action, reward, next_state, done)
    self.memory.append(transition)
    buffer_len = len(self.memory)
    if buffer_len % 100 == 0:
        self.logger.debug(f"记忆库大小: {buffer_len}")

def act(self, state):
    """Pick an action for ``state`` with an epsilon-greedy policy.

    With probability ``self.epsilon`` a uniformly random action index in
    ``[0, len(self.natural_bottlenecks) + 2)`` is returned (exploration);
    otherwise the index of the largest Q-value predicted by ``self.model`` is
    returned (exploitation).

    Args:
        state: model input for one state. It is passed straight to
            ``self.model.predict`` and the first output row is used, so a
            leading batch dimension of 1 is presumably expected.

    Returns:
        int: the chosen action index (a plain Python int, JSON-serialisable).
    """
    # Strict "<" so epsilon == 0 is always greedy: np.random.rand() may return
    # exactly 0.0, which "<=" would still treat as an exploration draw.
    if np.random.rand() < self.epsilon:
        # Action space: one action per natural bottleneck plus two extra actions.
        return random.randrange(len(self.natural_bottlenecks) + 2)

    # verbose=0 stops Keras from printing a progress bar for every single-state
    # prediction (matches the other predict calls in this notebook).
    q_values = self.model.predict(state, verbose=0)
    return int(np.argmax(q_values[0]))

def update_target_model(self):
    """Copy the online network's weights into the target network and log it."""
    online_weights = self.model.get_weights()
    self.target_model.set_weights(online_weights)
    self.logger.info("目标网络权重已更新")

In [None]:
def replay(self, batch_size, td_error_threshold=0.5):
    """Train the online network on a random minibatch using Double-DQN targets.

    For each sampled transition ``(state, action, reward, next_state, done)``
    the target is ``reward`` when ``done`` is truthy, otherwise
    ``reward + gamma * Q_target(next_state)[argmax_a Q_online(next_state)]``
    (the online network selects the action, the target network evaluates it).
    Only the Q-value of the taken action is overwritten. The network is then
    fitted on that single sample for 2 epochs if the TD error exceeds
    ``td_error_threshold``, otherwise for 1 epoch.

    Args:
        batch_size: number of transitions sampled from ``self.memory``.
        td_error_threshold: TD-error magnitude above which a sample gets an
            extra epoch of training (default 0.5).

    Returns:
        Mean first-epoch loss over the minibatch, or 0 when the buffer holds
        fewer than ``batch_size`` transitions.
    """
    if len(self.memory) < batch_size:
        self.logger.debug(f"记忆库样本不足，当前: {len(self.memory)}，需要: {batch_size}")
        return 0

    minibatch = random.sample(self.memory, batch_size)
    losses = []

    for state, action, reward, next_state, done in minibatch:
        target = reward
        if not done:
            # DDQN: online network picks the next action, target network scores it.
            next_action = np.argmax(self.model.predict(next_state, verbose=0)[0])
            target = reward + self.gamma * self.target_model.predict(next_state, verbose=0)[0][next_action]

        target_f = self.model.predict(state, verbose=0)
        original_q = target_f[0][action]

        td_error = abs(target - original_q)
        # Per-sample diagnostics go to DEBUG: at INFO this emitted one line per
        # sample, i.e. batch_size lines for every training step.
        self.logger.debug(f"TD误差: {td_error:.4f}, 原Q值: {original_q:.4f}, 目标Q值: {target:.4f}")

        target_f[0][action] = target

        # Large TD errors get an extra epoch of training.
        epochs = 2 if td_error > td_error_threshold else 1
        history = self.model.fit(state, target_f, epochs=epochs, verbose=0)
        losses.append(history.history['loss'][0])

    # Epsilon decay is handled by the training loop, not here.
    return np.mean(losses) if losses else 0

In [None]:
def _load_or_calculate_best_actions(self, max_bandwidth, fixed_bs):
    """加载或计算不同带宽下的最佳动作
    
    首先尝试从缓存文件加载，如果文件不存在则重新计算并保存
    
    Args:
        max_bandwidth: 最大带宽值(MBps)
        fixed_bs: 固定的批处理大小
        
    Returns:
        best_actions_dict: 不同带宽下的最佳动作字典
    """
    # 构建缓存文件路径
    cache_filename = f"best_actions_{self.model_name}_bs{fixed_bs}.json"
    cache_path = os.path.join('data', cache_filename)
    
    # 检查缓存文件是否存在
    if os.path.exists(cache_path):
        self.logger.info(f"发现缓存的最佳动作文件: {cache_filename}，正在加载...")
        try:
            with open(cache_path, 'r') as f:
                best_actions = json.load(f)
                
            # 将字符串键转换为整数键
            best_actions = {int(k): v for k, v in best_actions.items()}
            self.logger.info(f"成功加载缓存的最佳动作，共 {len(best_actions)} 个带宽点")
            return best_actions
        except Exception as e:
            self.logger.warning(f"加载缓存文件失败: {str(e)}，将重新计算最佳动作")
    else:
        self.logger.info(f"未找到缓存文件 {cache_filename}，将计算所有带宽点的最佳动作...")
    
    # 计算所有带宽点的最佳动作
    best_actions = {}
    
    self.logger.info(f"开始计算所有带宽点的最优动作...")
    # 预先计算所有带宽点
    for bw in range(1, max_bandwidth + 1):
        self.network_bandwidth = bw * 10**6  # 转换为Bps
        self.inference_time_cache = {}  # 清除缓存
        
        self.logger.info(f"计算带宽 {bw}MBps 的最优动作...")
        best_action, best_reward, all_rewards_dict = self._find_best_action_for_bandwidth(bw)
        best_actions[bw] = {
            'action': best_action,
            'reward': best_reward,
            'all_rewards': all_rewards_dict
        }
        self.logger.info(f"带宽 {bw}MBps 的最佳动作: {self.describe_action(best_action)}, 奖励: {best_reward:.4f}")
    
    # 保存计算结果
    self.logger.info(f"所有带宽点的最优动作计算完成，正在保存到 {cache_path}")
    try:
        # 确保data目录存在
        os.makedirs('data', exist_ok=True)
        
        # 保存为JSON文件
        with open(cache_path, 'w') as f:
            json.dump(best_actions, f, indent=2)
        self.logger.info(f"最佳动作数据已保存到 {cache_path}")
    except Exception as e:
        self.logger.error(f"保存最佳动作数据失败: {str(e)}")
    
    return best_actions

In [None]:
def _sample_bandwidth(bandwidth_ranges, weights, max_bandwidth):
    """Draw one integer training bandwidth (MBps).

    A range is chosen from ``bandwidth_ranges`` with probabilities ``weights``,
    a value is drawn uniformly from that inclusive range, and the result is
    clipped to ``max_bandwidth`` so it always has a precomputed best action.
    """
    range_idx = np.random.choice(len(bandwidth_ranges), p=weights)
    low, high = bandwidth_ranges[range_idx]
    bw = np.random.randint(low, high + 1)
    return int(min(bw, max_bandwidth))


def train_dqn_adaptive(self, fixed_bs=1, max_bandwidth=100, max_episodes=1000, target_accuracy=0.90,
                       min_episodes=100):
    """Train the Double-DQN split policy over randomly sampled bandwidths.

    Each episode does the following:
    - samples a bandwidth;
    - lets the agent choose one action (epsilon-greedy) and stores the transition;
    - runs one replay step and decays epsilon.

    Every 20 episodes, the greedy policy's accuracy against the precomputed
    best actions is measured. Whenever that accuracy improves, the weights are
    checkpointed, and the best checkpoint is restored when training ends.
    Batch size, bandwidth and epsilon are restored on exit, even on error.

    Bandwidth sampling: the four bandwidth ranges are sampled uniformly during
    the first 20% of episodes. After that, the range weights are refreshed
    every 100 episodes from ``_analyze_model_errors`` and blended 50/50 with
    the uniform weights. Ranges where the policy errs more are therefore
    sampled more often, but no range is ever dropped.

    Args:
        fixed_bs: batch size used for every episode (default 1).
        max_bandwidth: largest bandwidth in MBps (default 100); sampled
            bandwidths are clipped to it.
        max_episodes: maximum number of training episodes (default 1000).
        target_accuracy: early-stop threshold as a fraction (default 0.90).
            Training stops once the measured accuracy is at least
            ``target_accuracy * 100`` percent and ``min_episodes`` has been reached.
        min_episodes: episode index before which early stopping is not
            allowed (default 100).

    Returns:
        dict with keys 'fixed_bs', 'max_episodes', 'completed_episodes',
        'early_stopped', 'best_accuracy' (None if accuracy was never
        evaluated), 'rewards', 'losses' and 'accuracies'.
    """
    self.logger.info(f"开始改进版DDQN训练: 固定批处理大小={fixed_bs}, 训练轮数={max_episodes}, 目标准确率={target_accuracy:.2%}")

    all_rewards = []
    all_losses = []
    all_accuracies = []

    # Saved so they can be restored in the finally clause.
    original_bw = self.network_bandwidth
    original_bs = self.batch_size
    original_epsilon = self.epsilon

    # Bandwidth ranges in MBps (inclusive); shared boundaries belong to both.
    bandwidth_ranges = [
        (1, 10),    # low
        (10, 20),   # mid-low
        (20, 60),   # mid
        (60, 100)   # high
    ]
    range_weights = [0.25, 0.25, 0.25, 0.25]
    sampling_weights = list(range_weights)

    # Checkpoint selection is by accuracy. A single episode's reward depends on
    # the sampled bandwidth and the exploration draw, so it says little about
    # model quality.
    best_model_path = os.path.join('models', f'dqn_best_{self.model_name}_bs{fixed_bs}.h5')
    best_accuracy = None
    best_model_saved = False  # never reload a stale checkpoint from an earlier run
    early_stop = False

    self.logger.info(f"设置最小训练轮数: {min_episodes} (即使准确率达标也会继续训练到此轮数)")

    try:
        self.batch_size = fixed_bs
        self.logger.info(f"训练使用固定批大小: {fixed_bs}")

        # Inside the try so the original configuration is restored even if this fails.
        self.logger.info(f"准备最佳动作数据...")
        self.best_actions_per_bandwidth = self._load_or_calculate_best_actions(max_bandwidth, fixed_bs)
        self.logger.info(f"最佳动作数据准备完成，共 {len(self.best_actions_per_bandwidth)} 个带宽点")

        for episode in range(max_episodes):
            # After the uniform exploration phase, refresh the error-focused weights every 100 episodes.
            if episode >= max_episodes * 0.2 and episode > 0 and episode % 100 == 0:
                error_weights = self._analyze_model_errors()
                if any(error_weights):
                    blended = [0.5 * u + 0.5 * e for u, e in zip(range_weights, error_weights)]
                    total = sum(blended)
                    sampling_weights = [w / total for w in blended]
                else:
                    sampling_weights = list(range_weights)

            bw = _sample_bandwidth(bandwidth_ranges, sampling_weights, max_bandwidth)
            self.network_bandwidth = bw * 10**6  # MBps -> Bps
            self.inference_time_cache = {}  # cached timings depend on the bandwidth

            best_action = self.best_actions_per_bandwidth[bw]['action']
            best_reward = self.best_actions_per_bandwidth[bw]['reward']

            state = self.get_state()
            action = self.act(state)

            action_result = self.execute_action(action)
            reward = self.calculate_reward(action_result['result'])
            next_state = self.get_state()

            # NOTE(review): each episode is a single decision at a freshly sampled
            # bandwidth, yet the transition is stored as non-terminal, so targets
            # bootstrap from next_state. Kept as designed; confirm intent.
            self.remember(state, action, reward, next_state, False)

            current_loss = 0
            if len(self.memory) > 32:
                current_loss = self.replay(32)
                all_losses.append(current_loss)

            all_rewards.append(reward)
            self.logger.info(f"Episode {episode+1}/{max_episodes}, 当前探索率(epsilon): {self.epsilon:.4f}")
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            is_optimal = "是" if action == best_action else "否"
            gap = 0 if action == best_action else (best_reward - reward)
            self.logger.info(f"Episode {episode+1} 结果: 动作={action}({self.describe_action(action)}), "
                          f"奖励={reward:.4f}, 是最优动作: {is_optimal}, 差距: {gap:.4f}, "
                          f"损失={current_loss:.6f}")

            if episode % 5 == 0:
                self.update_target_model()
                self.logger.info(f"Episode {episode+1}: 目标网络已更新")

            if episode % 20 == 0 and episode > 0:
                accuracy = self._evaluate_model_accuracy()
                all_accuracies.append(accuracy)
                self.logger.info(f"Episode {episode+1}: 模型准确率: {accuracy:.2f}%")

                if best_accuracy is None or accuracy > best_accuracy:
                    best_accuracy = accuracy
                    self.model.save_weights(best_model_path)
                    best_model_saved = True
                    self.logger.info(f"发现更好的模型，准确率: {accuracy:.2f}%")

                if accuracy >= target_accuracy * 100 and episode >= min_episodes:
                    self.logger.info(f"达到目标准确率 {accuracy:.2f}% >= {target_accuracy*100:.2f}%，且已满足最小训练轮数，提前结束训练!")
                    early_stop = True
                    break

            if episode % 100 == 0:
                gc.collect()

        if best_model_saved:
            self.model.load_weights(best_model_path)
            self.logger.info(f"训练结束，已加载性能最佳模型")

        self._save_training_results(all_rewards, all_losses, max_episodes, fixed_bs)
        self._plot_training_progress(all_rewards, all_losses, fixed_bs)

    finally:
        self.network_bandwidth = original_bw
        self.batch_size = original_bs
        self.epsilon = original_epsilon
        # Cached timings belong to the last training bandwidth, not the restored one.
        self.inference_time_cache = {}

    return {
        'fixed_bs': fixed_bs,
        'max_episodes': max_episodes,
        'completed_episodes': len(all_rewards),
        'early_stopped': early_stop,
        'best_accuracy': best_accuracy,
        'rewards': all_rewards,
        'losses': all_losses,
        'accuracies': all_accuracies,
    }

In [None]:
def _find_best_action_for_bandwidth(self, bw):
    """针对特定带宽尝试所有可能的动作，找出最佳动作
    
    Args:
        bw: 带宽值(MBps)
        
    Returns:
        best_action: 最佳动作索引
        best_reward: 最佳奖励值
        all_rewards: 所有动作的奖励值字典
    """
    self.network_bandwidth = bw * 10**6  # 转换为Bps
    self.inference_time_cache = {}  # 清除缓存
    best_action = None
    best_reward = float('-inf')
    all_rewards = {}
    
    # 尝试所有可能的动作
    num_actions = len(self.natural_bottlenecks) + 2
    for action in range(num_actions):
        action_result = self.execute_action(action)
        reward = self.calculate_reward(action_result['result'])
        all_rewards[action] = {
            'reward': float(reward),
            'description': self.describe_action(action)
        }
        
        if reward > best_reward:
            best_reward = reward
            best_action = action
    
    return best_action, best_reward, all_rewards

def _priority_replay(self, minibatch):
    """带优先级的经验回放
    
    Args:
        minibatch: 经验样本批次
        
    Returns:
        训练损失
    """
    losses = []
    
    for state, action, reward, next_state, done in minibatch:
        target = reward
        if not done:
            # DDQN: 主网络选择动作，目标网络评估该动作
            a = np.argmax(self.model.predict(next_state, verbose=0)[0])
            target = reward + self.gamma * self.target_model.predict(next_state, verbose=0)[0][a]
        
        # 获取当前Q值估计
        target_f = self.model.predict(state, verbose=0)
        original_q = target_f[0][action]
        target_f[0][action] = target
        
        # 计算TD误差
        td_error = abs(target - original_q)
        
        # 根据TD误差调整学习率
        if td_error > 0.5:  # 大误差，更强调学习
            history = self.model.fit(state, target_f, epochs=2, verbose=0)
        else:
            history = self.model.fit(state, target_f, epochs=1, verbose=0)
        
        losses.append(history.history['loss'][0])
    
    return np.mean(losses) if losses else 0

def _analyze_model_errors(self):
    """分析模型在不同带宽区间的错误，返回应当重点关注的区间权重
    
    Returns:
        区间权重列表
    """
    bandwidth_ranges = [
        (1, 10),    # 低带宽区间
        (10, 20),   # 中低带宽区间
        (20, 60),   # 中带宽区间
        (60, 100)   # 高带宽区间
    ]
    
    # 默认平均权重
    if not hasattr(self, 'bandwidth_errors'):
        return [0.25, 0.25, 0.25, 0.25]
    
    # 计算每个区间的错误率
    error_rates = []
    for min_bw, max_bw in bandwidth_ranges:
        errors = 0
        total = 0
        for bw in range(min_bw, max_bw + 1):
            if bw in self.bandwidth_errors:
                errors += self.bandwidth_errors[bw]['errors']
                total += self.bandwidth_errors[bw]['total']
        
        if total > 0:
            error_rates.append(errors / total)
        else:
            error_rates.append(0.25)  # 默认值
    
    # 归一化错误率作为权重
    if sum(error_rates) > 0:
        weights = [rate / sum(error_rates) for rate in error_rates]
        return weights
    else:
        return [0.25, 0.25, 0.25, 0.25]  # 默认平均权重

def _evaluate_model_accuracy(self):
    """评估模型在各带宽下选择最佳动作的准确率
    
    Returns:
        准确率百分比
    """
    if not hasattr(self, 'best_actions_per_bandwidth') or not self.best_actions_per_bandwidth:
        return 0
    
    correct = 0
    total = 0
    
    # 初始化或重置带宽错误记录
    if not hasattr(self, 'bandwidth_errors'):
        self.bandwidth_errors = {}
    
    # 评估带宽采样点
    test_bandwidths = list(range(1, 101))
    temp_epsilon = self.epsilon
    self.epsilon = 0  # 禁用探索
    
    for bw in test_bandwidths:
        if bw in self.best_actions_per_bandwidth:
            best_action = self.best_actions_per_bandwidth[bw]['action']
            
            # 使用模型预测动作
            self.network_bandwidth = bw * 10**6
            self.inference_time_cache = {}  # 清除缓存
            state = self.get_state()
            predicted_action = self.act(state)
            
            # 记录准确性
            if predicted_action == best_action:
                correct += 1
            
            # 记录错误统计
            if bw not in self.bandwidth_errors:
                self.bandwidth_errors[bw] = {'errors': 0, 'total': 0}
            
            if predicted_action != best_action:
                self.bandwidth_errors[bw]['errors'] += 1
            self.bandwidth_errors[bw]['total'] += 1
            
            total += 1
    
    # 恢复探索率
    self.epsilon = temp_epsilon
    
    return (correct / total * 100) if total > 0 else 0

def _focused_training_for_problem_bandwidths(self):
    """针对问题带宽区间进行集中训练"""
    if not hasattr(self, 'bandwidth_errors') or not self.bandwidth_errors:
        return
    
    problem_bandwidths = []
    for bw, stats in self.bandwidth_errors.items():
        if stats['total'] > 0 and stats['errors'] / stats['total'] > 0.5:
            problem_bandwidths.append(bw)
    
    if not problem_bandwidths:
        return
    
    self.logger.info(f"对问题带宽进行集中训练: {problem_bandwidths}")
    
    # 保存当前状态
    original_epsilon = self.epsilon
    
    # 针对问题带宽进行集中训练
    for bw in problem_bandwidths:
        if bw not in self.best_actions_per_bandwidth:
            continue
        
        best_action = self.best_actions_per_bandwidth[bw]['action']
        self.network_bandwidth = bw * 10**6
        self.inference_time_cache = {}  # 清除缓存
        
        # 集中训练10次
        for _ in range(10):
            state = self.get_state()
            next_state = state  # 简化
            reward = self.best_actions_per_bandwidth[bw]['reward']
            
            # 直接使用目标值进行监督训练
            target_f = self.model.predict(state, verbose=0)
            target_f[0][best_action] = reward
            
            # 其他动作设为较低值
            for a in range(len(target_f[0])):
                if a != best_action:
                    target_f[0][a] = reward - 0.5
            
            # 加强训练
            self.model.fit(state, target_f, epochs=5, verbose=0)
    
    # 恢复探索率
    self.epsilon = original_epsilon

In [None]:
# 1. 修改 _save_training_results 函数
def _save_training_results(self, rewards, losses, episodes, fixed_bs=1):
    """保存训练结果到JSON文件"""
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    result = {
        'model_name': self.model_name,
        'batch_size': fixed_bs,
        'episodes': episodes,
        'completed_episodes': len(rewards),
        'final_epsilon': float(self.epsilon),
        'rewards': [float(r) for r in rewards],
        'losses': [float(l) if not np.isnan(l) else None for l in losses]
    }
    
    # 修改：文件名包含模型名称和批处理大小
    result_path = os.path.join('data', f'dqn_training_{self.model_name}_bs{fixed_bs}.json')
    with open(result_path, 'w') as f:
        json.dump(result, f, indent=2)
    
    # 修改：文件名包含模型名称和批处理大小
    final_model_path = os.path.join('models', f'dqn_{self.model_name}_bs{fixed_bs}_final.h5')
    self.model.save_weights(final_model_path)
    
    self.logger.info(f"训练结果已保存到 {result_path}")
    self.logger.info(f"最终模型已保存到 {final_model_path}")

In [None]:
def _plot_training_progress(self, rewards, losses, fixed_bs=1):
    """Plot raw per-episode rewards and replay losses (no smoothing) and save the figure.

    The figure is written to ``result/dqn_training_{model_name}_bs{fixed_bs}.png``.
    The loss panel is only drawn when ``losses`` is non-empty.

    Args:
        rewards: one reward per training episode.
        losses: one loss per replay step. When called from the training loop,
            replay only starts once the buffer is large enough, so index i is
            not episode i.
        fixed_bs: batch size used in the output file name.
    """
    self.logger.info("绘制训练过程曲线图")

    fig = plt.figure(figsize=(10, 5))
    try:
        reward_ax = fig.add_subplot(1, 2, 1)
        reward_ax.plot(np.arange(len(rewards)), rewards, 'b-', linewidth=1.5, label='Reward')
        reward_ax.set_title('DDQN Training Reward', fontsize=14)
        reward_ax.set_xlabel('Episode', fontsize=12)
        # Rewards are plotted raw, so the axis must not claim an average.
        reward_ax.set_ylabel('Reward', fontsize=12)
        reward_ax.grid(True, linestyle='--', alpha=0.7)
        reward_ax.legend()

        if losses:
            loss_ax = fig.add_subplot(1, 2, 2)
            loss_ax.plot(np.arange(len(losses)), losses, 'r-', linewidth=1.5, label='Loss')
            loss_ax.set_title('Training Loss', fontsize=14)
            loss_ax.set_xlabel('Replay step', fontsize=12)
            loss_ax.set_ylabel('Loss', fontsize=12)
            loss_ax.grid(True, linestyle='--', alpha=0.7)
            loss_ax.legend()

        fig.tight_layout()
        os.makedirs('result', exist_ok=True)
        plot_path = os.path.join('result', f"dqn_training_{self.model_name}_bs{fixed_bs}.png")
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if saving fails, to avoid leaking memory across runs.
        plt.close(fig)

    self.logger.info(f"训练过程图已保存到 {plot_path}")

In [None]:
def predict_optimal_split_dqn(self):
    """Predict the split strategy for the current environment without executing it.

    Queries ``self.model`` through ``predict_action_with_qvalues`` on the
    current state, then interprets the greedy action with
    ``interpret_action``.

    Returns:
        dict from ``interpret_action``, extended with:
        - 'q_values': list of per-action Q-values;
        - 'action_index': the chosen action as a plain int (JSON-serialisable).
    """
    state = self.get_state()

    action, q_values = self.predict_action_with_qvalues(state, self.model)
    # np.argmax yields a NumPy integer, which json.dump cannot serialise.
    action = int(action)

    # Interpret the action only; do not execute it.
    action_info = self.interpret_action(action)

    self.logger.info(f"DQN预测分割策略: {action_info['strategy']}")
    if action_info['strategy'] == 'split':
        bottleneck_name = action_info.get('bottleneck', 'unknown')
        self.logger.info(f"分割点: {bottleneck_name}")

    action_info['q_values'] = q_values.tolist() if hasattr(q_values, 'tolist') else q_values
    action_info['action_index'] = action

    return action_info

In [None]:
def predict_action_with_qvalues(self, state, model):
    """Run ``model`` on ``state`` and return the greedy action with its Q-values.

    Args:
        state: input passed straight to ``model.predict``; the first row of the
            prediction is used.
        model: object exposing a Keras-style ``predict(x, verbose=...)``.

    Returns:
        tuple ``(action, q_values)``: ``np.argmax`` of the first prediction row
        and that row itself.
    """
    batch_q = model.predict(state, verbose=0)
    q_values = batch_q[0]
    return np.argmax(q_values), q_values

In [None]:
def validate_trained_model(self, fixed_bs=1, max_bandwidth=100):
    """Run the greedy policy at every bandwidth and record what it achieves.

    For each bandwidth ``1..max_bandwidth`` (MBps), with exploration disabled
    and the batch size fixed to ``fixed_bs``, the chosen action is executed.
    Its reward and ``total_time`` latency are recorded. Bandwidth, batch size
    and epsilon are restored afterwards, even on error.

    The results are written to
    ``result/validation_results_{model_name}_bs{fixed_bs}.json``; the
    directory is created if missing.

    Args:
        fixed_bs: batch size used during validation.
        max_bandwidth: largest bandwidth (MBps) to validate.

    Returns:
        dict mapping bandwidth (int) to
        ``{'action', 'action_description', 'reward', 'latency'}``.
    """
    self.logger.info(f"Validating model performance for batch size={fixed_bs}")

    original_bw = self.network_bandwidth
    original_bs = self.batch_size
    original_epsilon = self.epsilon

    results = {}

    try:
        self.epsilon = 0.0  # greedy: act() always takes the highest-Q action
        self.batch_size = fixed_bs

        for bw in range(1, max_bandwidth + 1):
            self.network_bandwidth = bw * 10**6  # MBps -> Bps
            # Per-bandwidth trace at DEBUG; a summary line is logged at INFO below.
            self.logger.debug(f"带宽={bw}MBps")
            self.inference_time_cache = {}  # cached timings depend on the bandwidth

            state = self.get_state()
            action = self.act(state)

            action_result = self.execute_action(action)
            reward = self.calculate_reward(action_result['result'])
            latency = action_result['result'].get('total_time', 0)

            results[bw] = {
                'action': int(action),
                'action_description': self.describe_action(action),
                'reward': float(reward),
                'latency': float(latency)
            }

            if bw % 10 == 0 or bw == 1 or bw == max_bandwidth:
                self.logger.info(f"Validation bandwidth={bw}MBps: Action={self.describe_action(action)}, "
                               f"Reward={reward:.4f}, Latency={latency:.4f}s")

    finally:
        self.network_bandwidth = original_bw
        self.batch_size = original_bs
        self.epsilon = original_epsilon
        # Cached timings belong to the last validated bandwidth, not the restored one.
        self.inference_time_cache = {}

    os.makedirs('result', exist_ok=True)
    validation_filename = f"validation_results_{self.model_name}_bs{fixed_bs}.json"
    with open(os.path.join('result', validation_filename), 'w') as f:
        # JSON object keys must be strings.
        json_results = {str(k): v for k, v in results.items()}
        json.dump(json_results, f, indent=4)

    self.logger.info(f"Validation results saved as {validation_filename}")

    return results

In [None]:
# Attach the free functions defined in the cells above to DQNSplitTrainer as methods.
_TRAINER_METHODS = {
    'remember': remember,
    'act': act,
    'update_target_model': update_target_model,
    'replay': replay,
    'train_dqn_adaptive': train_dqn_adaptive,
    '_find_best_action_for_bandwidth': _find_best_action_for_bandwidth,
    '_priority_replay': _priority_replay,
    '_analyze_model_errors': _analyze_model_errors,
    '_evaluate_model_accuracy': _evaluate_model_accuracy,
    '_focused_training_for_problem_bandwidths': _focused_training_for_problem_bandwidths,
    '_save_training_results': _save_training_results,
    '_plot_training_progress': _plot_training_progress,
    'predict_optimal_split_dqn': predict_optimal_split_dqn,
    'predict_action_with_qvalues': predict_action_with_qvalues,
    'validate_trained_model': validate_trained_model,
    '_load_or_calculate_best_actions': _load_or_calculate_best_actions,
}
for _method_name, _method in _TRAINER_METHODS.items():
    setattr(DQNSplitTrainer, _method_name, _method)
# Drop the helper names so the module namespace matches explicit assignments.
del _TRAINER_METHODS, _method_name, _method

def main():
    """Train and validate the DQN split policy for one model configuration.

    Steps:
    - configure the 'DQNSplitTrainer' logger (a file under ``log/`` plus the console);
    - build a ``DQNSplitTrainer``;
    - run ``train_dqn_adaptive``, then ``validate_trained_model`` with the same
      fixed batch size.

    Returns:
        dict with 'training' (return value of ``train_dqn_adaptive``) and
        'validation' (return value of ``validate_trained_model``).

    Raises:
        Re-raises any exception from the pipeline after logging it.
    """
    # Model names previously used here: 'vit-l32', 'vgg-16', 'efficientnet-b4'.
    model_name = 'ResNet50'
    fixed_bs = 5  # batch size held fixed during training and validation
    if model_name == 'ResNet50':
        patch_dnn_splitter()

    # Configure logging: one file per model/batch size plus console output.
    os.makedirs('log', exist_ok=True)
    log_filename = os.path.join('log', f"dqn_training_{model_name}_bs{fixed_bs}.log")
    file_handler = logging.FileHandler(log_filename)
    console_handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    global logger  # module-level so other cells can reuse it
    logger = logging.getLogger('DQNSplitTrainer')
    logger.setLevel(logging.INFO)

    # Re-running the cell would otherwise stack handlers. Close the removed ones
    # so their log files are released instead of leaking file descriptors.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    try:
        logger.info("========== 开始运行DQN模型训练系统 ==========")

        config = {
            'model_name': model_name,
            'image_size': 224,
            'edge_nodes': [
                {
                    'name': 'edge1',
                    'device': '/GPU:1',
                },
            ],
            'cloud_device': '/GPU:0',
            'batch_size': list(range(1, 65)),  # candidate batch sizes
            'max_bandwidth': 128 * 10**6,  # max bandwidth: 128 MBps
            'min_bandwidth': 1 * 10**6,    # min bandwidth: 1 MBps
            'bandwidth_step': 1 * 10**6,    # bandwidth step: 1 MBps
        }

        trainer = DQNSplitTrainer(config)

        logger.info("\n=== 开始训练DQN模型 ===")
        training_results = trainer.train_dqn_adaptive(fixed_bs=fixed_bs, max_episodes=800)

        validation_results = trainer.validate_trained_model(fixed_bs=fixed_bs)

        logger.info("\n===== 所有训练过程完成 =====")

    except Exception as e:
        logger.error(f"执行过程中发生错误: {str(e)}", exc_info=True)
        raise

    return {'training': training_results, 'validation': validation_results}
    
# Entry point: run the training/validation pipeline when executed directly.
if __name__ == "__main__":
    results = main()
    print("DQN训练完成！")