In [None]:
import copy
import gc
import math
import warnings
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_wavelets import DWT1D, IDWT1D
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from thop import profile, clever_format
from torch.utils.data import DataLoader, TensorDataset, random_split

# Reversible instance normalization layer (improved version)
class RevIN(nn.Module):
    """Reversible instance normalization computed along dimension 1.

    ``forward(x, 'norm')`` stores per-sample statistics of ``x`` (reduced over
    dim 1) and returns the normalized tensor. ``forward(x, 'denorm')`` applies
    the inverse transform using the statistics stored by the most recent
    ``'norm'`` call, so it must be preceded by one.

    Args:
        num_features: Size of the last dimension; used to shape the affine
            parameters as ``(1, 1, num_features)``.
        eps: Constant added to the variance before the square root.
        affine: If True, learn a per-feature scale and shift.
        subtract_last: If True, centre on the last step along dim 1 instead
            of on the mean.
    """

    def __init__(self, num_features, eps=1e-5, affine=True, subtract_last=False):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last

        if self.affine:
            self._init_params()

    def _init_params(self):
        """Create the learnable scale (ones) and shift (zeros)."""
        self.affine_weight = nn.Parameter(torch.ones(1, 1, self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(1, 1, self.num_features))

    def forward(self, x, mode):
        """Normalize (``mode='norm'``) or de-normalize (``mode='denorm'``) ``x``.

        Raises:
            ValueError: If ``mode`` is neither ``'norm'`` nor ``'denorm'``.
        """
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        # A typo in `mode` used to return the input untouched, which silently
        # skipped the normalization.
        raise ValueError(f"mode must be 'norm' or 'denorm', got {mode!r}")

    def _get_statistics(self, x):
        """Store the detached centre and standard deviation of ``x`` over dim 1."""
        dim2reduce = 1
        if self.subtract_last:
            self.last = x[:, -1:].detach()
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # eps is folded into the variance here, so stdev is strictly positive.
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
        ).detach()

    def _normalize(self, x):
        """Centre and scale ``x`` with the stored statistics, then apply the affine map."""
        centre = self.last if self.subtract_last else self.mean
        # Divide by exactly the value _denormalize multiplies by; adding eps a
        # second time here made the round trip inexact.
        x = (x - centre) / self.stdev
        if self.affine:
            x = x * self.affine_weight + self.affine_bias
        return x

    def _denormalize(self, x):
        """Invert ``_normalize`` using the stored statistics."""
        if self.affine:
            # eps**2 keeps the division guarded while leaving the inverse
            # numerically exact for ordinary weights.
            x = (x - self.affine_bias) / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        centre = self.last if self.subtract_last else self.mean
        return x + centre

# Improved adaptive wavelet decomposition module
class AdaptiveWaveletDecomposition(nn.Module):
    """Wavelet decomposition with learnable soft thresholds on detail coefficients.

    ``transform`` runs a ``DWT1D`` of depth ``level``, passes every coefficient
    band through its own depthwise convolution stack and soft-thresholds the
    detail bands with per-channel learnable thresholds. ``inv_transform``
    reconstructs a signal with ``IDWT1D``.

    Both operations are best-effort: if the wavelet transform fails, a
    ``RuntimeWarning`` is emitted and the undecomposed input is returned.

    Args:
        input_length: Length of the input sequence.
        pred_length: Length of the prediction sequence.
        wavelet_name: Wavelet passed to ``DWT1D`` / ``IDWT1D``.
        level: Number of decomposition levels.
        channel: Number of channels of the transformed tensors.
        device: Device used only for the dry run that measures coefficient
            lengths at construction time.
        no_decomposition: If True, ``transform`` and ``inv_transform`` are
            pass-throughs.
    """

    def __init__(self, input_length, pred_length, wavelet_name='db4', level=3,
                 channel=1, device='cpu', no_decomposition=False):
        super().__init__()
        self.wavelet_name = wavelet_name
        self.level = level
        self.channel = channel
        self.device = device
        self.input_length = input_length
        self.pred_length = pred_length
        self.no_decomposition = no_decomposition

        # Learnable threshold parameters, shape [channel, level]
        if not no_decomposition:
            self.thresholds = nn.Parameter(torch.randn(channel, level) * 0.1)
            # One depthwise processing network per coefficient band
            # (index 0: approximation, 1..level: details).
            self.coeff_processor = nn.ModuleList([
                nn.Sequential(
                    nn.Conv1d(channel, channel, kernel_size=3, padding=1, groups=channel),
                    nn.ReLU(),
                    nn.Conv1d(channel, channel, kernel_size=3, padding=1, groups=channel)
                ) for _ in range(level + 1)
            ])
        else:
            self.thresholds = None

        # Coefficient lengths after decomposition
        self.input_w_dim, self.pred_w_dim = self._calculate_dimensions()

    def _coeff_lengths(self, dwt, length):
        """Return [approx_len, *detail_lens] produced by ``dwt`` for a signal of ``length``."""
        dummy = torch.ones(1, self.channel, length).to(self.device)
        with torch.no_grad():
            approx, details = dwt(dummy)
        return [approx.shape[-1]] + [d.shape[-1] for d in details]

    def _calculate_dimensions(self):
        """Measure coefficient lengths for the input and prediction sequences.

        Returns:
            ``(input_dims, pred_dims)``; each is ``[length]`` when decomposition
            is disabled or the dry run fails.
        """
        if self.no_decomposition:
            return [self.input_length], [self.pred_length]

        try:
            dwt = DWT1D(wave=self.wavelet_name, J=self.level).to(self.device)
            input_dims = self._coeff_lengths(dwt, self.input_length)
            pred_dims = self._coeff_lengths(dwt, self.pred_length)
        except Exception as exc:  # best-effort fallback, but no longer silent
            warnings.warn(
                f"Wavelet dimension calculation failed ({exc!r}); "
                "falling back to undecomposed lengths.",
                RuntimeWarning,
                stacklevel=2,
            )
            return [self.input_length], [self.pred_length]
        return input_dims, pred_dims

    def transform(self, x):
        """Decompose ``x`` into processed approximation and detail coefficients.

        Args:
            x: Tensor fed to ``DWT1D``; the processing convolutions expect
                ``channel`` channels in dim 1.

        Returns:
            ``(cA, cD_list)``. On failure, or when decomposition is disabled,
            ``(x, [])``.
        """
        if self.no_decomposition:
            return x, []

        try:
            # Follow the input's device: self.device is fixed at construction
            # and goes stale once the model is moved with .to().
            dwt = DWT1D(wave=self.wavelet_name, J=self.level).to(x.device)
            cA, cD_list = dwt(x)

            # Process the approximation coefficients
            cA_processed = self.coeff_processor[0](cA)

            # Soft-threshold each detail band with its learnable threshold,
            # then run it through its own processing network.
            cD_thresholded = []
            for i in range(self.level):
                cD_i = cD_list[i]
                th = torch.sigmoid(self.thresholds[:, i]).view(1, self.channel, 1)
                cD_i = torch.sign(cD_i) * torch.clamp(torch.abs(cD_i) - th, min=0)
                cD_thresholded.append(self.coeff_processor[i + 1](cD_i))
        except Exception as exc:  # best-effort fallback, but no longer silent
            warnings.warn(
                f"Wavelet transform failed ({exc!r}); returning the input undecomposed.",
                RuntimeWarning,
                stacklevel=2,
            )
            return x, []
        return cA_processed, cD_thresholded

    def inv_transform(self, yA, yD):
        """Reconstruct a signal from approximation ``yA`` and detail list ``yD``.

        Returns ``yA`` unchanged when decomposition is disabled or the inverse
        transform fails.
        """
        if self.no_decomposition:
            return yA

        try:
            idwt = IDWT1D(wave=self.wavelet_name).to(yA.device)
            return idwt((yA, yD))
        except Exception as exc:  # best-effort fallback, but no longer silent
            warnings.warn(
                f"Inverse wavelet transform failed ({exc!r}); "
                "returning the approximation coefficients.",
                RuntimeWarning,
                stacklevel=2,
            )
            return yA

# Improved high-frequency path: dynamic spike attention
class SpikeAttention(nn.Module):
    """Multi-head attention whose output is kept only at detected spike positions.

    A convolutional detector scores every time step; steps whose mean score
    exceeds ``spike_threshold`` act as queries. Their attention output is
    added to the residual, every other step contributes zero before the
    residual/LayerNorm. A feed-forward sub-layer with its own LayerNorm follows.

    Each sample attends only to its own time steps. A sample without any
    detected spike falls back to using all of its steps as queries.

    Args:
        embed_dim: Feature size of the input (last dimension).
        num_heads: Requested number of heads; lowered to the largest value
            not above it that divides ``embed_dim``.
        spike_threshold: Threshold on the mean detector score.
        dropout: Dropout probability for attention weights and the FFN.
    """

    def __init__(self, embed_dim, num_heads=4, spike_threshold=0.3, dropout=0.1):
        super().__init__()
        self.spike_threshold = spike_threshold
        self.embed_dim = embed_dim

        # Adjust the head count so that it divides embed_dim
        if embed_dim % num_heads != 0:
            for n in range(num_heads, 0, -1):
                if embed_dim % n == 0:
                    num_heads = n
                    break
        self.num_heads = max(1, num_heads)
        self.head_dim = embed_dim // self.num_heads

        # Spike detector (scores in (0, 1) through the final sigmoid)
        self.spike_detector = nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim * 2, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(embed_dim * 2, embed_dim, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        # Attention projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

        # Layer normalization
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim)
        )

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape [batch, seq, embed] into [batch, heads, seq, head_dim]."""
        return t.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply spike-gated self-attention followed by the FFN.

        Args:
            x: Tensor of shape ``[batch, seq_len, embed_dim]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len, _ = x.shape
        residual = x

        # Spike detection. The scores are only used through the comparison
        # below, so this forward pass sends no gradient to the detector.
        spike_map = self.spike_detector(x.permute(0, 2, 1)).permute(0, 2, 1)
        spike_mask = spike_map.mean(dim=-1) > self.spike_threshold  # [batch, seq]

        # Samples without any spike use every position as a query. Deciding
        # this per sample keeps the result independent of the other samples
        # in the batch.
        has_spike = spike_mask.any(dim=1, keepdim=True)
        query_mask = spike_mask | ~has_spike

        # Batched multi-head attention. Keys and values stay inside their own
        # sample; flattening them across the batch let queries attend to
        # other samples.
        q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        k = self._split_heads(self.k_proj(x), batch_size, seq_len)
        v = self._split_heads(self.v_proj(x), batch_size, seq_len)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = self.dropout(torch.softmax(attn_scores, dim=-1))
        context = torch.matmul(attn_weights, v)  # [batch, heads, seq, head_dim]
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)

        # Keep the attention output only at query (spike) positions
        attended = self.out_proj(context) * query_mask.unsqueeze(-1).to(x.dtype)

        # Residual connection and layer normalization
        output = self.norm1(attended + residual)

        # Feed-forward network
        output = self.norm2(output + self.ffn(output))

        return output

# Improved multi-period sparse attention module
class CycleAttention(nn.Module):
    """Attention over a sequence folded along several candidate periods.

    For every period ``p`` with at least two full cycles in the input, the
    sequence is folded to ``[batch, num_cycles, p, features]`` and three
    attention branches are evaluated: within each cycle (intra), across
    cycles at a fixed position (inter), and a second across-cycle branch with
    its own weights (phase). The branches are mixed with softmax-normalized
    learnable weights, unfolded, and the periods are combined with the
    softmax of ``period_weights``. A residual connection and a feed-forward
    sub-layer follow.

    Args:
        embed_dim: Feature size of the input (last dimension).
        num_heads: Requested number of heads; lowered to the largest value
            not above it that divides ``embed_dim``.
        period_list: Candidate periods, in time steps.
        dropout: Dropout probability.
    """

    def __init__(self, embed_dim, num_heads=4, period_list=(24, 168), dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim

        # Adjust the head count so that it divides embed_dim
        if embed_dim % num_heads != 0:
            for n in range(num_heads, 0, -1):
                if embed_dim % n == 0:
                    num_heads = n
                    break
        self.num_heads = max(1, num_heads)
        self.head_dim = embed_dim // self.num_heads

        # Private copy so that instances never share (or mutate) the caller's list
        self.period_list = list(period_list)

        # Learnable period weights
        self.period_weights = nn.Parameter(torch.ones(len(self.period_list)))

        # Period embedding
        self.period_embedding = nn.Embedding(len(self.period_list), embed_dim)

        # Three attention mechanisms
        # 1. Intra-period attention
        self.intra_period_attention = nn.MultiheadAttention(
            embed_dim, self.num_heads, dropout=dropout, batch_first=True
        )

        # 2. Inter-period attention
        self.inter_period_attention = nn.MultiheadAttention(
            embed_dim, self.num_heads, dropout=dropout, batch_first=True
        )

        # 3. Phase attention
        self.phase_attention = nn.MultiheadAttention(
            embed_dim, self.num_heads, dropout=dropout, batch_first=True
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.norm_out = nn.LayerNorm(embed_dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim)
        )

        # Fusion weights of the three branches
        self.intra_weight = nn.Parameter(torch.tensor(1.0))
        self.inter_weight = nn.Parameter(torch.tensor(1.0))
        self.phase_weight = nn.Parameter(torch.tensor(1.0))

        self.dropout = nn.Dropout(dropout)

    def fold_tensor(self, x, period):
        """Fold ``[batch, seq, features]`` into ``[batch, num_cycles, period, features]``.

        Trailing steps that do not fill a complete cycle are dropped.

        Returns:
            ``(folded, num_cycles)``.
        """
        batch_size, seq_len, features = x.shape
        num_cycles = seq_len // period

        # Truncate to complete cycles
        x_truncated = x[:, :num_cycles * period, :]
        folded = x_truncated.reshape(batch_size, num_cycles, period, features)

        return folded, num_cycles

    def unfold_tensor(self, folded, target_seq_len):
        """Flatten a folded tensor back to ``[batch, seq, features]``.

        The result is zero-padded at the end up to ``target_seq_len``.
        """
        batch_size, num_cycles, period, features = folded.shape
        unfolded = folded.reshape(batch_size, num_cycles * period, features)

        missing = target_seq_len - unfolded.size(1)
        if missing > 0:
            # new_zeros inherits dtype and device from the data
            padding = unfolded.new_zeros(batch_size, missing, features)
            unfolded = torch.cat([unfolded, padding], dim=1)

        return unfolded

    def _attend_across_cycles(self, attention, norm, folded):
        """Self-attend over the cycle axis for every position within the period."""
        batch_size, num_cycles, period, features = folded.shape

        # [batch * period, num_cycles, features]
        by_position = folded.permute(0, 2, 1, 3).reshape(
            batch_size * period, num_cycles, features
        )
        attended, _ = attention(by_position, by_position, by_position)

        # Back to [batch, num_cycles, period, features]
        attended = attended.reshape(
            batch_size, period, num_cycles, features
        ).permute(0, 2, 1, 3)
        return norm(attended)

    def compute_intra_period_attention(self, folded):
        """Self-attention among the steps of each individual cycle."""
        batch_size, num_cycles, period, features = folded.shape

        # [batch * num_cycles, period, features]
        intra_input = folded.reshape(batch_size * num_cycles, period, features)
        intra_output, _ = self.intra_period_attention(
            intra_input, intra_input, intra_input
        )

        intra_output = intra_output.reshape(batch_size, num_cycles, period, features)
        return self.norm1(intra_output)

    def compute_inter_period_attention(self, folded):
        """Self-attention across cycles at each position (inter-period branch)."""
        return self._attend_across_cycles(self.inter_period_attention, self.norm2, folded)

    def compute_phase_attention(self, folded):
        """Self-attention across cycles at each phase (phase branch).

        All phases are evaluated in one batched call instead of one call per
        phase; outputs match in eval mode, only the order in which dropout
        masks are drawn during training differs.
        """
        return self._attend_across_cycles(self.phase_attention, self.norm3, folded)

    def forward(self, x):
        """Apply multi-period attention to ``x`` of shape ``[batch, seq_len, embed_dim]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len, embed_dim = x.shape
        residual = x

        period_weights = torch.softmax(self.period_weights, dim=0)
        # The branch mixing weights do not depend on the period
        branch_weights = torch.softmax(
            torch.stack([self.intra_weight, self.inter_weight, self.phase_weight]),
            dim=0
        )

        weighted_outputs = []
        for idx, period in enumerate(self.period_list):
            if seq_len < period * 2:  # at least two complete cycles are required
                continue

            # 1. Add the period embedding
            period_emb = self.period_embedding(torch.tensor(idx, device=x.device))
            x_with_period = x + period_emb.view(1, 1, -1)

            # 2. Fold along the period
            folded, _ = self.fold_tensor(x_with_period, period)

            # 3. Evaluate and mix the three attention branches
            cycle_output = (
                branch_weights[0] * self.compute_intra_period_attention(folded) +
                branch_weights[1] * self.compute_inter_period_attention(folded) +
                branch_weights[2] * self.compute_phase_attention(folded)
            )

            # 4. Unfold and weight with this period's own weight. Indexing by
            #    idx keeps weights aligned when an earlier period is skipped.
            output_2d = self.unfold_tensor(cycle_output, seq_len)
            weighted_outputs.append(period_weights[idx] * output_2d)

        if weighted_outputs:
            combined_output = self.norm_out(sum(weighted_outputs))
        else:
            # No usable period: the attention sub-layer contributes nothing, so
            # the residual below passes the input through once (not doubled).
            combined_output = torch.zeros_like(x)

        # Residual connection
        output = combined_output + residual

        # Feed-forward network
        output = output + self.dropout(self.ffn(output))

        return output

# Fixed positional encoding
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding for ``[batch, seq, d_model]`` inputs.

    Even feature indices hold sines and odd indices hold cosines of
    ``position * 10000 ** (-2i / d_model)``. Odd ``d_model`` is supported:
    the cosine half simply has one column fewer than the sine half.

    Args:
        d_model: Feature size of the inputs.
        max_len: Number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        steps = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(steps * freqs)
        # There are d_model // 2 odd columns: all frequencies when d_model is
        # even, all but the last one when it is odd.
        table[:, 1::2] = torch.cos(steps * freqs[:d_model // 2])

        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

# Fixed wavelet reconstruction + dimension-wise fusion
class MultiScaleFeatureFusion(nn.Module):
    """Mix all (time, feature) entries with an MLP and project to the forecast length.

    The input is flattened so that the MLP can model dependencies across
    both time steps and features. A single linear layer, shared by all
    features, then maps the time axis from ``input_length`` to ``pred_length``.

    Args:
        input_length: Number of time steps of the input.
        pred_length: Number of time steps of the output.
        num_features: Number of features (last dimension).
        wavelet_level: Stored on the instance; not used by this module.
        dropout: Dropout probability inside the fusion MLP.
    """

    def __init__(self, input_length, pred_length, num_features, wavelet_level=3, dropout=0.1):
        super().__init__()
        self.input_length = input_length
        self.pred_length = pred_length
        self.num_features = num_features
        self.wavelet_level = wavelet_level

        # The fusion MLP works on the flattened [input_length * num_features] vector
        flattened_size = input_length * num_features

        # Dimension-wise fusion layer (MLP)
        self.dim_fusion_mlp = nn.Sequential(
            # First MLP stage
            nn.Linear(flattened_size, flattened_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            # Second MLP stage
            nn.Linear(flattened_size * 2, flattened_size),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Maps the time axis to the prediction length
        self.output_adjust = nn.Linear(input_length, pred_length)

    def forward(self, reconstructed_features):
        """Fuse the reconstructed features and produce the forecast.

        Args:
            reconstructed_features: Tensor of shape
                ``[batch, input_length, num_features]``.

        Returns:
            Tensor of shape ``[batch, pred_length, num_features]``.
        """
        batch_size = reconstructed_features.size(0)

        # Flatten so the MLP sees every (time, feature) pair at once
        flattened = reconstructed_features.reshape(batch_size, -1)
        fused = self.dim_fusion_mlp(flattened)
        reshaped = fused.reshape(batch_size, self.input_length, self.num_features)

        # nn.Linear acts on the last axis, so moving time there applies the
        # shared projection to every feature in one call instead of looping
        # over features in Python.
        adjusted = self.output_adjust(reshaped.transpose(1, 2))  # [batch, features, pred]

        return adjusted.transpose(1, 2).contiguous()

# Time-series forecasting model
class TimeSeriesForecaster(nn.Module):
    """Forecaster combining RevIN, wavelet decomposition and two attention paths.

    The forward pass normalizes the input, projects it and adds positional
    encodings, decomposes it with ``AdaptiveWaveletDecomposition``, sends the
    approximation band through ``CycleAttention`` and every detail band
    through ``SpikeAttention``, reconstructs the signal, fuses it with
    ``MultiScaleFeatureFusion`` and finally de-normalizes the result.

    Args:
        input_length: Number of input time steps.
        pred_length: Number of forecast time steps.
        num_features: Number of features per time step.
        wavelet_name: Wavelet handed to the decomposition module.
        wavelet_level: Number of decomposition levels.
        spike_heads: Head count requested for the spike attention.
        cycle_heads: Head count requested for the cycle attention.
        dropout: Dropout probability passed to the sub-modules.
        device: Device handed to the decomposition module.
    """

    def __init__(self, input_length, pred_length, num_features,
                 wavelet_name='db4', wavelet_level=3,
                 spike_heads=4, cycle_heads=4, dropout=0.1, device='cpu'):
        super().__init__()
        self.device = device
        self.num_features = num_features
        self.pred_length = pred_length
        self.input_length = input_length

        # Sub-modules are created in a fixed order so that parameter
        # registration and random initialization stay reproducible.
        # Reversible instance normalization
        self.revin = RevIN(num_features)
        # Positional encoding
        self.pos_encoder = PositionalEncoding(num_features)
        # Input projection
        self.input_projection = nn.Linear(num_features, num_features)
        # Adaptive wavelet decomposition
        self.wavelet = AdaptiveWaveletDecomposition(
            input_length, pred_length, wavelet_name, wavelet_level,
            num_features, device
        )
        # High-frequency path: dynamic spike attention
        self.spike_attention = SpikeAttention(
            embed_dim=num_features, num_heads=spike_heads, dropout=dropout
        )
        # Low-frequency path: multi-period sparse attention
        self.cycle_attention = CycleAttention(
            embed_dim=num_features, num_heads=cycle_heads, dropout=dropout
        )
        # Feature fusion
        self.feature_fusion = MultiScaleFeatureFusion(
            input_length=input_length,
            pred_length=pred_length,
            num_features=num_features,
            wavelet_level=wavelet_level,
            dropout=dropout
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize Linear weights (zero bias) and Kaiming-initialize Conv1d weights."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
                continue
            if isinstance(layer, nn.Conv1d):
                nn.init.kaiming_uniform_(layer.weight)

    def forward(self, x, mode='train'):
        """Forecast from ``x``.

        Args:
            x: Input tensor; the code treats it as ``[batch, seq_len, features]``.
            mode: Accepted for compatibility; not read by this method.

        Returns:
            The de-normalized output of the fusion layer.
        """
        # Reversible instance normalization, projection and positional encoding
        normalized = self.revin(x, 'norm')
        encoded = self.pos_encoder(self.input_projection(normalized))

        # Wavelet decomposition works channel-first: [batch, channels, seq_len]
        approx, details = self.wavelet.transform(encoded.permute(0, 2, 1))

        # Low-frequency band -> cycle attention (sequence-first in, channel-first out)
        approx = self.cycle_attention(approx.permute(0, 2, 1)).permute(0, 2, 1)

        # High-frequency bands -> spike attention, one band at a time
        details = [
            self.spike_attention(band.permute(0, 2, 1)).permute(0, 2, 1)
            for band in details
        ]

        # Wavelet reconstruction, then back to [batch, seq_len, features]
        rebuilt = self.wavelet.inv_transform(approx, details)
        sequence_first = rebuilt.permute(0, 2, 1).contiguous()

        # Dimension-wise fusion followed by de-normalization
        forecast = self.feature_fusion(sequence_first)
        return self.revin(forecast, 'denorm')

# Improved data loading and preprocessing functions
def load_etth_data(file_path):
    """Read an ETT-style CSV and append cyclic encodings of its ``date`` column.

    For hour (period 24), day of month (31), month (12) and weekday (7) a
    sine and a cosine column are appended, in that order. The ``date``
    column itself is removed.

    Args:
        file_path: Path of a CSV file that contains a ``date`` column.

    Returns:
        The remaining columns followed by the eight cyclic columns, as the
        array returned by ``DataFrame.values``.
    """
    frame = pd.read_csv(file_path)
    stamps = pd.to_datetime(frame['date'])

    # (column prefix, calendar component, cycle length)
    calendar_parts = (
        ('hour', stamps.dt.hour, 24),
        ('day', stamps.dt.day, 31),
        ('month', stamps.dt.month, 12),
        ('weekday', stamps.dt.weekday, 7),
    )
    for prefix, component, cycle in calendar_parts:
        angle = 2 * np.pi * component / cycle
        frame[f'{prefix}_sin'] = np.sin(angle)
        frame[f'{prefix}_cos'] = np.cos(angle)

    # The raw timestamp is not part of the returned array
    return frame.drop('date', axis=1).values

def preprocess_data(data, train_ratio=0.7, val_ratio=0.1, test_ratio=0.2):
    """Split ``data`` chronologically and standardize each column.

    The ratios are renormalized to sum to one. Every column gets its own
    ``StandardScaler`` fitted on the training split only and applied to all
    three splits. The input array is left untouched.

    Args:
        data: 2-D array of shape ``[time, features]``.
        train_ratio: Share of rows in the training split.
        val_ratio: Share of rows in the validation split.
        test_ratio: Share of rows in the test split; the test split receives
            all rows left after the first two.

    Returns:
        ``(train_data, val_data, test_data, scalers)`` where ``scalers`` has
        one fitted ``StandardScaler`` per column.
    """
    # Make sure the ratios sum to 1
    total_ratio = train_ratio + val_ratio + test_ratio
    train_ratio /= total_ratio
    val_ratio /= total_ratio

    # Chronological split boundaries
    data = np.asarray(data)
    train_end = int(len(data) * train_ratio)
    val_end = train_end + int(len(data) * val_ratio)

    # Slices are views of `data`; copy them so the in-place scaling below does
    # not overwrite the caller's array. Non-float input is promoted to float64
    # so that scaled values are not truncated.
    dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    train_data = data[:train_end].astype(dtype, copy=True)
    val_data = data[train_end:val_end].astype(dtype, copy=True)
    test_data = data[val_end:].astype(dtype, copy=True)

    # Standardize column by column, fitting on the training split only
    scalers = []
    for i in range(data.shape[1]):
        scaler = StandardScaler()
        train_data[:, i] = scaler.fit_transform(train_data[:, i].reshape(-1, 1)).flatten()
        val_data[:, i] = scaler.transform(val_data[:, i].reshape(-1, 1)).flatten()
        test_data[:, i] = scaler.transform(test_data[:, i].reshape(-1, 1)).flatten()
        scalers.append(scaler)

    return train_data, val_data, test_data, scalers

def create_inout_sequences(data, input_len, output_len, stride=1):
    """Cut ``data`` into sliding (input, target) windows.

    Each window starts ``stride`` rows after the previous one; the target
    consists of the ``output_len`` rows that directly follow the
    ``input_len`` input rows.

    Returns:
        ``(X, y)`` as NumPy arrays built from the lists of input and target
        windows.
    """
    window = input_len + output_len
    starts = range(0, len(data) - window + 1, stride)

    inputs = [data[s:s + input_len, :] for s in starts]
    targets = [data[s + input_len:s + window, :] for s in starts]
    return np.array(inputs), np.array(targets)

# Improved training and evaluation utilities
class EarlyStopping:
    """Signal a stop once the validation loss has stalled for ``patience`` calls.

    A call counts as an improvement only if the loss is at most
    ``best_score - delta``; it then becomes the new best and resets the
    counter. Any other call increments the counter, and ``early_stop`` turns
    True as soon as the counter reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        delta: Minimum decrease that counts as an improvement.
    """

    def __init__(self, patience=10, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record the latest validation loss."""
        # The very first value only seeds the reference score
        if self.best_score is None:
            self.best_score = val_loss
            return

        if val_loss > self.best_score - self.delta:
            # Not enough progress
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # Sufficient improvement: new reference, start counting afresh
        self.best_score = val_loss
        self.counter = 0

def train_epoch(model, data_loader, optimizer, criterion, device):
    """Run one optimization pass over ``data_loader``.

    Gradients are clipped to a total norm of 1.0 before each optimizer step.

    Returns:
        The mean of the per-batch loss values.
    """
    model.train()
    running_loss = 0.0

    for step, (inputs, targets) in enumerate(data_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()

        # Periodic memory cleanup
        if step % 10 == 0:
            torch.cuda.empty_cache()
            gc.collect()

    return running_loss / len(data_loader)

def evaluate(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` and report MSE, MAE and RMSE.

    The averages weight each batch by its number of samples, so a smaller
    final batch does not bias them. This assumes ``criterion`` returns a
    per-batch mean, as ``nn.MSELoss()`` does by default.

    Returns:
        ``(total_mse, total_mae, avg_mse, avg_mae, rmse)``: the first two are
        the plain sums of the per-batch values, the averages are
        sample-weighted means and ``rmse`` is computed over all elements.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()
    total_mse = 0.0
    total_mae = 0.0
    weighted_mse = 0.0
    weighted_mae = 0.0
    num_samples = 0
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            outputs = model(batch_x)

            all_outputs.append(outputs.cpu().numpy())
            all_targets.append(batch_y.cpu().numpy())

            batch_mse = criterion(outputs, batch_y).item()
            batch_mae = torch.mean(torch.abs(outputs - batch_y)).item()
            batch_len = batch_y.size(0)

            total_mse += batch_mse
            total_mae += batch_mae
            weighted_mse += batch_mse * batch_len
            weighted_mae += batch_mae * batch_len
            num_samples += batch_len

    if num_samples == 0:
        raise ValueError("data_loader yielded no samples to evaluate")

    avg_mse = weighted_mse / num_samples
    avg_mae = weighted_mae / num_samples

    # RMSE over every element, accumulated in float64
    all_outputs = np.concatenate(all_outputs)
    all_targets = np.concatenate(all_targets)
    rmse = np.sqrt(np.mean((all_targets - all_outputs) ** 2, dtype=np.float64))

    return total_mse, total_mae, avg_mse, avg_mae, rmse

# Learning-rate scheduler factory
def get_lr_scheduler(optimizer, scheduler_type='cosine', num_epochs=50):
    """Build the learning-rate scheduler selected by ``scheduler_type``.

    ``'cosine'`` gives ``CosineAnnealingLR`` with ``T_max=num_epochs``,
    ``'step'`` gives ``StepLR`` (step 20, gamma 0.5) and ``'plateau'`` gives
    ``ReduceLROnPlateau`` (patience 5, factor 0.5). Any other value yields a
    ``LambdaLR`` that keeps the learning rate constant.
    """
    sched = optim.lr_scheduler
    factories = (
        ('cosine', lambda: sched.CosineAnnealingLR(optimizer, T_max=num_epochs)),
        ('step', lambda: sched.StepLR(optimizer, step_size=20, gamma=0.5)),
        ('plateau', lambda: sched.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)),
    )
    for name, build in factories:
        if scheduler_type == name:
            return build()

    # Unknown type: constant multiplier of 1.0
    return sched.LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)

# Main training function
def main():
    """Train, validate and test ``TimeSeriesForecaster`` on ``../data/ETTm2.csv``.

    Steps: load and standardize the data, build sliding-window datasets,
    report model complexity with thop, train with AdamW, a plateau
    scheduler and early stopping, then evaluate the best checkpoint on the
    test split. Writes ``best_time_series_forecaster_paper.pth`` and
    ``final_time_series_forecaster_paper.pth`` to the working directory.
    """
    # Hyperparameters
    input_length = 96
    pred_length = 96
    batch_size = 32  # small batch size to limit memory use
    num_epochs = 50
    learning_rate = 0.001
    dropout = 0.1
    patience = 15
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the data
    data = load_etth_data('../data/ETTm2.csv')
    print(f"Data shape: {data.shape}")

    # Preprocess the data
    train_data, val_data, test_data, scalers = preprocess_data(data)
    print(f"Train shape: {train_data.shape}, Val shape: {val_data.shape}, Test shape: {test_data.shape}")

    # Build the input/output windows
    X_train, y_train = create_inout_sequences(train_data, input_length, pred_length, stride=1)
    X_val, y_val = create_inout_sequences(val_data, input_length, pred_length, stride=1)
    X_test, y_test = create_inout_sequences(test_data, input_length, pred_length, stride=1)

    print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

    # Convert to PyTorch tensors (kept on the CPU; batches are moved later)
    X_train = torch.FloatTensor(X_train)
    y_train = torch.FloatTensor(y_train)
    X_val = torch.FloatTensor(X_val)
    y_val = torch.FloatTensor(y_val)
    X_test = torch.FloatTensor(X_test)
    y_test = torch.FloatTensor(y_test)

    # Data loaders. Pinned memory only helps host-to-GPU copies, so it is
    # enabled only when training on CUDA.
    pin_memory = device.type == 'cuda'

    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                              pin_memory=pin_memory, num_workers=2)

    val_dataset = TensorDataset(X_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            pin_memory=pin_memory, num_workers=2)

    test_dataset = TensorDataset(X_test, y_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             pin_memory=pin_memory, num_workers=2)

    # Build the model
    num_features = data.shape[1]
    model = TimeSeriesForecaster(
        input_length=input_length,
        pred_length=pred_length,
        num_features=num_features,
        dropout=dropout,
        device=device
    ).to(device)

    # ==================== Model GFLOPs and parameter count ====================
    print("\n" + "="*50)
    print("计算模型复杂度:")
    print("="*50)

    # Dummy input of shape [batch_size, seq_len, features]
    batch_size_flops = 1  # profile with a batch size of 1
    dummy_input = torch.randn(batch_size_flops, input_length, num_features).to(device)

    try:
        # thop attaches `total_ops` / `total_params` buffers to the modules it
        # profiles. Profiling a throwaway copy keeps them out of the state
        # dicts saved below, so the checkpoints load into a fresh model.
        flops, params = profile(copy.deepcopy(model), inputs=(dummy_input,), verbose=False)

        # Convert to GFLOPs
        gflops = flops / 1e9

        # Human-readable formatting
        flops_formatted, params_formatted = clever_format([flops, params], "%.3f")

        print(f"模型总参数量: {params_formatted}")
        print(f"模型计算量 (FLOPs): {flops_formatted}")
        print(f"单次前向传播计算量: {gflops:.4f} GFLOPs")
        print(f"输入形状: {dummy_input.shape}")
        print(f"输出形状: {model(dummy_input).shape}")
    except Exception as e:
        print(f"计算FLOPs时出错: {e}")
        print("请确保已安装thop库: pip install thop")

    # ==================== End of complexity report ====================

    # Loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = get_lr_scheduler(optimizer, 'plateau', num_epochs)

    # Early stopping
    early_stopping = EarlyStopping(patience=patience, delta=0.001)

    # Training loop
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
        train_losses.append(train_loss)

        # Validation
        val_mse, val_mae, avg_val_mse, avg_val_mae, val_rmse = evaluate(model, val_loader, criterion, device)
        val_losses.append(avg_val_mse)

        # Learning-rate scheduling (the plateau scheduler needs the metric)
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(avg_val_mse)
        else:
            scheduler.step()

        # Early-stopping bookkeeping
        early_stopping(avg_val_mse)

        if (epoch + 1) % 10 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Epoch [{epoch+1}/{num_epochs}], LR: {current_lr:.6f}, '
                  f'Train Loss: {train_loss:.6f}, Val MSE: {avg_val_mse:.6f}, '
                  f'Val MAE: {avg_val_mae:.6f}, Val RMSE: {val_rmse:.6f}')

        # Keep the best model seen so far
        if avg_val_mse < best_val_loss:
            best_val_loss = avg_val_mse
            torch.save(model.state_dict(), 'best_time_series_forecaster_paper.pth')

        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    # Reload the best checkpoint onto the active device for testing
    model.load_state_dict(
        torch.load('best_time_series_forecaster_paper.pth', map_location=device)
    )

    # Test
    test_mse, test_mae, avg_test_mse, avg_test_mae, test_rmse = evaluate(model, test_loader, criterion, device)
    print(f'Test Results - MSE: {avg_test_mse:.6f}, MAE: {avg_test_mae:.6f}, RMSE: {test_rmse:.6f}')

    # Save the final model
    torch.save(model.state_dict(), 'final_time_series_forecaster_paper.pth')
    print("Model saved successfully")
