In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import time

# Run configuration: compute device and hyperparameters shared by the code below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEQ_LEN = 24        # input sequence (look-back window) length
PRED_LEN = 1        # prediction length (number of future steps per sample)
BATCH_SIZE = 512
num_epochs = 10
CNN_CHANNELS = 32   # number of CNN output channels

# Data preprocessing
def load_data(csv_path, seq_len=SEQ_LEN, split_ratio=(0.8, 0.2)):
    """Read the 'OT' column of a CSV and build sliding-window tensors.

    The series is standardised with a ``StandardScaler`` fitted on the first
    ``int(len(data) * split_ratio[0])`` rows only. Each sample consists of
    ``seq_len`` consecutive values as input and the following ``PRED_LEN``
    values as target. The same row index is then reused to cut the list of
    windows into a train part and a test part.

    Args:
        csv_path: Path of a CSV file with a 'date' column (parsed as dates)
            and an 'OT' column.
        seq_len: Number of time steps in every input window.
        split_ratio: Only ``split_ratio[0]`` (train fraction) is read;
            the second entry is currently unused.

    Returns:
        ``((X_train, y_train), (X_test, y_test), scaler)`` where the X tensors
        have shape (N, 1, seq_len, 1), the y tensors have shape
        (N, PRED_LEN, 1), all tensors are float32 on ``device``, and
        ``scaler`` is the fitted ``StandardScaler``.
    """
    frame = pd.read_csv(csv_path, parse_dates=['date'])
    data = frame[['OT']].values.reshape(-1, 1)  # shape (n_samples, 1)

    # Fit the scaler on the training rows only, then scale the whole series.
    train_split = int(len(data) * split_ratio[0])
    scaler = StandardScaler()
    scaler.fit(data[:train_split])
    data_scaled = scaler.transform(data)

    # Sliding windows: input = seq_len steps, target = the next PRED_LEN steps.
    n_windows = len(data_scaled) - seq_len - PRED_LEN + 1
    starts = range(n_windows)
    inputs = [data_scaled[s:s + seq_len] for s in starts]
    targets = [data_scaled[s + seq_len:s + seq_len + PRED_LEN] for s in starts]

    # Insert a channel axis: (samples, seq_len, channel, feature).
    X = np.array(inputs)[:, :, np.newaxis, :]
    y = np.array(targets)

    def _as_input(array):
        # (N, L, C, F) -> (N, C, L, F), the layout expected by Conv2d.
        return torch.FloatTensor(array).permute(0, 2, 1, 3).to(device)

    def _as_target(array):
        return torch.FloatTensor(array).to(device)

    train_pair = (_as_input(X[:train_split]), _as_target(y[:train_split]))
    test_pair = (_as_input(X[train_split:]), _as_target(y[train_split:]))

    return train_pair, test_pair, scaler

# Dataset class
class TimeSeriesDataset(Dataset):
    """Index-aligned pairing of input windows ``X`` with targets ``y``."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        """Number of samples, taken from the length of ``X``."""
        n_samples = len(self.X)
        return n_samples

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with 'input' and 'target'."""
        sample = {}
        sample['input'] = self.X[idx]
        sample['target'] = self.y[idx]
        return sample

# CNN+LSTM model
class CNNLSTM(nn.Module):
    """Convolution along the time axis, then an LSTM, then a linear head.

    Args:
        input_channels: Channel count of the input tensor.
        cnn_channels: Output channels of the convolution (LSTM input size).
        lstm_hidden: Hidden size of the single-layer LSTM.
        output_size: Number of values predicted per sample.
    """

    def __init__(self, input_channels=1, cnn_channels=CNN_CHANNELS,
                 lstm_hidden=50, output_size=1):
        super().__init__()
        # Kernel (3, 1) with padding (1, 0) keeps the sequence length unchanged.
        self.cnn = nn.Conv2d(
            input_channels,
            cnn_channels,
            kernel_size=(3, 1),
            padding=(1, 0),
        )
        self.lstm = nn.LSTM(
            cnn_channels,
            lstm_hidden,
            num_layers=1,
            batch_first=True,
        )
        self.fc = nn.Linear(lstm_hidden, output_size)

    def forward(self, x):
        """Map x of shape (B, C_in, L, 1) to predictions of shape (B, output_size)."""
        features = self.cnn(x)  # (B, C, L, 1)
        # Drop the trailing feature axis and put time before channels: (B, L, C).
        sequence = features.squeeze(-1).transpose(1, 2)
        hidden_states, _ = self.lstm(sequence)
        last_step = hidden_states[:, -1]  # output of the final time step
        return self.fc(last_step)

# CNN+MinLSTM model
class CNNMinLSTM(nn.Module):
    """Convolution along the time axis feeding a ``MinLSTM`` sequence model.

    Args:
        input_channels: Channel count of the input tensor.
        cnn_channels: Output channels of the convolution (MinLSTM input size).
        minilstm_hidden: Hidden size handed to ``MinLSTM``.
        output_size: Number of values predicted per sample.
    """

    def __init__(self, input_channels=1, cnn_channels=CNN_CHANNELS,
                 minilstm_hidden=50, output_size=1):
        super().__init__()
        # Kernel (3, 1) with padding (1, 0) keeps the sequence length unchanged.
        self.cnn = nn.Conv2d(
            input_channels,
            cnn_channels,
            kernel_size=(3, 1),
            padding=(1, 0),
        )
        self.minilstm = MinLSTM(cnn_channels, minilstm_hidden, output_size)

    def forward(self, x):
        """Map x of shape (B, C_in, L, 1) to the MinLSTM output."""
        features = self.cnn(x)  # (B, C, L, 1)
        # Drop the trailing feature axis and put time before channels: (B, L, C).
        sequence = features.squeeze(-1).transpose(1, 2)
        return self.minilstm(sequence)

# MinLSTM model definition
class MinLSTM(nn.Module):
    """Minimal LSTM evaluated over a whole sequence with a log-space scan.

    One bias-free linear layer produces, for every time step, the forget-gate
    pre-activation, the input-gate pre-activation and the candidate state.
    The two gates are normalised so that they sum to one, and the recurrence

        h_t = f_t * h_{t-1} + i_t * g(candidate_t),   h_0 = g(0) = 0.5

    is evaluated for all steps at once in log space. The hidden state of the
    last step is projected to ``output_size`` values.

    Args:
        input_size: Feature size of every time step of the input.
        hidden_size: Size of the recurrent state.
        output_size: Size of the final prediction.
        device: Optional device for the layer parameters.
        dtype: Optional dtype for the layer parameters.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # A single projection yields forget gate, input gate and candidate.
        self.linear = nn.Linear(input_size, hidden_size * 3, bias=False, **factory_kwargs)
        self.output_layer = nn.Linear(hidden_size, output_size, **factory_kwargs)

    def forward(self, x):
        """Run the recurrence over ``x`` and project the final hidden state.

        Args:
            x: Tensor of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, output_size).
        """
        batch_size, _, _ = x.shape
        h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
        f, i, h = torch.chunk(self.linear(x), chunks=3, dim=-1)

        # diff = log(sigmoid(i) / sigmoid(f)); the two lines below are the logs
        # of the gates normalised so that f_t + i_t == 1.
        diff = F.softplus(-f) - F.softplus(-i)
        log_f = -F.softplus(diff)
        log_i = -F.softplus(-diff)

        log_h_0 = self.log_g(h_prev)
        log_tilde_h = self.log_g(h)

        # Coefficients stay (batch, seq_len, hidden) so that the cumulative sum
        # inside the scan runs along the time axis.
        log_val = torch.cat([log_h_0.unsqueeze(1), log_i + log_tilde_h], dim=1)
        h_t = self.parallel_scan_log(log_f, log_val)
        return self.output_layer(h_t[:, -1, :])

    def parallel_scan_log(self, log_coeffs, log_values):
        """Solve ``h_t = a_t * h_{t-1} + b_t`` for every t in log space.

        Args:
            log_coeffs: ``log(a_t)``, shape (batch, seq_len, hidden). A tensor
                with an extra singleton axis at position 1 is also accepted.
            log_values: ``log(h_0)`` followed by ``log(b_t)``, shape
                (batch, seq_len + 1, hidden).

        Returns:
            Tensor of shape (batch, seq_len + 1, hidden); index 0 along the
            time axis is ``h_0`` and index t is ``h_t``.
        """
        if log_coeffs.dim() == log_values.dim() + 1:
            # Older call sites passed (batch, 1, seq_len, hidden); remove the
            # extra axis so that dim=1 is the time axis.
            log_coeffs = log_coeffs.squeeze(1)
        # a_star[t] = sum_{k<=t} log(a_k), with a leading 0 for the initial state.
        a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 1, 0))
        log_h0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=1)
        log_h = a_star + log_h0_plus_b_star
        return torch.exp(log_h)

    def g(self, x):
        """Positive activation: ``x + 0.5`` for ``x >= 0``, else ``sigmoid(x)``."""
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    def log_g(self, x):
        """Logarithm of :meth:`g`, computed without forming ``g(x)`` first."""
        return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))

# Training and evaluation routine
def train_and_evaluate(model, train_dataloader, test_dataloader,
                      loss_fn, optimizer, num_epochs=10, scheduler=None,
                      scaler=None):
    """Train ``model``, then report its test MSE in the original data units.

    Args:
        model: Network mapping ``batch['input']`` to predictions.
        train_dataloader: Iterable of dicts with 'input' and 'target' tensors.
        test_dataloader: Same format, used once after training.
        loss_fn: Loss called as ``loss_fn(outputs, targets)``.
        optimizer: Optimizer over the parameters of ``model``.
        num_epochs: Number of passes over ``train_dataloader``.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
        scaler: Object with ``inverse_transform`` used to undo the target
            scaling before the MSE is computed. When omitted, a module-level
            variable named ``scaler`` is used, which is what earlier versions
            of this function relied on implicitly.

    Returns:
        ``(train_losses, mse, total_time, avg_memory)``: the per-step training
        losses, the test MSE after inverse scaling, the training wall-clock
        time in seconds, and the mean per-step peak CUDA memory in MiB
        (0 when CUDA is unavailable).

    Raises:
        ValueError: If no scaler is passed and no module-level ``scaler``
            exists. Raised before training starts so that a long run is not
            wasted.
    """
    if scaler is None:
        scaler = globals().get('scaler')
    if scaler is None:
        raise ValueError(
            "train_and_evaluate needs a fitted scaler: pass scaler=... or "
            "define a module-level 'scaler' before calling it"
        )

    use_cuda = torch.cuda.is_available()
    start_time = time.time()
    train_losses = []
    memory_usage = []

    for epoch in range(num_epochs):
        model.train()
        training_loss = 0.0
        steps = 0

        for idx, batch in enumerate(train_dataloader):
            if use_cuda:
                # Reset so that the peak read below belongs to this step only.
                torch.cuda.reset_peak_memory_stats()

            inputs = batch['input']
            targets = batch['target'].squeeze(-1)  # drop the trailing feature axis

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            loss.backward()
            optimizer.step()

            if use_cuda:
                memory_usage.append(torch.cuda.max_memory_allocated())

            step_loss = loss.item()
            training_loss += step_loss
            train_losses.append(step_loss)
            steps += 1

            if idx % 100 == 0:
                print(f'Epoch: {epoch}, Step: {idx}, Train Loss: {training_loss/(idx+1):.4f}')

        # Divide by the batches actually seen; an empty loader yields NaN
        # instead of a ZeroDivisionError.
        avg_train_loss = training_loss / steps if steps else float('nan')
        print(f'Epoch: {epoch} => Avg Train Loss: {avg_train_loss:.4f}')

        if scheduler is not None:
            scheduler.step()

    total_time = time.time() - start_time
    avg_memory = sum(memory_usage)/len(memory_usage)/(1024**2) if memory_usage else 0

    model.eval()
    pred_batches, true_batches = [], []
    with torch.no_grad():
        for batch in test_dataloader:
            inputs = batch['input']
            targets = batch['target'].squeeze(-1)
            outputs = model(inputs)
            pred_batches.append(outputs.cpu().numpy())
            true_batches.append(targets.cpu().numpy())

    # Undo the standardisation so the MSE is expressed in the original units.
    y_pred = scaler.inverse_transform(np.concatenate(pred_batches))
    y_true = scaler.inverse_transform(np.concatenate(true_batches))
    mse = mean_squared_error(y_true, y_pred)

    print(f'Test MSE: {mse:.4f}')

    return train_losses, mse, total_time, avg_memory

if __name__ == "__main__":
    # Load the data (replace the path with the real dataset location)
    train_pair, test_pair, scaler = load_data(r'/root/hh/ETT-small/ETTh1.csv')
    X_train, y_train = train_pair
    X_test, y_test = test_pair

    # Build the data loaders
    train_dataset = TimeSeriesDataset(X_train, y_train)
    test_dataset = TimeSeriesDataset(X_test, y_test)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    def _init_weights(net):
        """Xavier-uniform for every 'weight' parameter, zeros for every 'bias'."""
        for pname, tensor in net.named_parameters():
            if 'weight' in pname:
                nn.init.xavier_uniform_(tensor)
            elif 'bias' in pname:
                nn.init.constant_(tensor, 0.0)

    def _label_bars(write_text, bar_container, template):
        """Write each bar's height above it using ``template.format(height)``."""
        for rect in bar_container:
            top = rect.get_height()
            write_text(rect.get_x() + rect.get_width()/2., top,
                       template.format(top), ha='center', va='bottom')

    # Result store
    results = {}

    # Train the CNN+LSTM model
    print("\n=== Training CNN+LSTM ===")
    cnn_lstm = CNNLSTM().to(device)
    _init_weights(cnn_lstm)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(cnn_lstm.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    cnn_lstm_losses, cnn_lstm_mse, cnn_lstm_time, cnn_lstm_mem = train_and_evaluate(
        cnn_lstm, train_loader, test_loader, loss_fn, optimizer, num_epochs, scheduler
    )
    results['CNN+LSTM'] = dict(
        mse=cnn_lstm_mse,
        time=cnn_lstm_time,
        memory=cnn_lstm_mem,
        losses=cnn_lstm_losses,
    )

    # Train the CNN+MinLSTM model
    print("\n=== Training CNN+MinLSTM ===")
    cnn_minilstm = CNNMinLSTM().to(device)
    _init_weights(cnn_minilstm)

    optimizer = torch.optim.Adam(cnn_minilstm.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    cnn_minilstm_losses, cnn_minilstm_mse, cnn_minilstm_time, cnn_minilstm_mem = train_and_evaluate(
        cnn_minilstm, train_loader, test_loader, loss_fn, optimizer, num_epochs, scheduler
    )
    results['CNN+MinLSTM'] = dict(
        mse=cnn_minilstm_mse,
        time=cnn_minilstm_time,
        memory=cnn_minilstm_mem,
        losses=cnn_minilstm_losses,
    )

    # Visual comparison of the results
    plt.figure(figsize=(18, 5))
    models = list(results)

    # 1. Training-loss curves
    plt.subplot(1, 3, 1)
    for tag in ('CNN+LSTM', 'CNN+MinLSTM'):
        plt.plot(results[tag]['losses'], label=f'{tag} Loss', alpha=0.7)
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.title('Training Loss Comparison')
    plt.legend()
    plt.grid(True)

    # 2. Test MSE
    plt.subplot(1, 3, 2)
    mse_values = [results[tag]['mse'] for tag in models]
    bars_mse = plt.bar(models, mse_values, color=['blue', 'orange'])
    _label_bars(plt.text, bars_mse, '{:.4f}')
    plt.title('Test MSE Comparison')
    plt.ylabel('MSE')

    # 3. Training time and memory on twin y-axes
    plt.subplot(1, 3, 3)
    ax1 = plt.gca()
    positions = np.arange(len(models))
    bar_width = 0.35
    time_values = [results[tag]['time'] for tag in models]
    mem_values = [results[tag]['memory'] for tag in models]

    bars_time = ax1.bar(positions - bar_width/2, time_values, bar_width,
                        label='Training Time (s)', color='blue')
    ax2 = ax1.twinx()
    bars_mem = ax2.bar(positions + bar_width/2, mem_values, bar_width,
                       label='Memory Usage (MB)', color='orange')

    _label_bars(ax1.text, bars_time, '{:.2f}s')
    _label_bars(ax2.text, bars_mem, '{:.2f}MB')

    ax1.set_xlabel('Models')
    ax1.set_ylabel('Time (s)')
    ax2.set_ylabel('Memory (MB)')
    ax1.set_xticks(positions)
    ax1.set_xticklabels(models)
    ax1.legend(loc='upper left')
    ax2.legend(loc='upper right')

    plt.tight_layout()
    plt.show()

    # Print the detailed comparison table
    print("\n=== 性能比较汇总 ===")
    print(f"{'模型':<12} {'MSE':<10} {'训练时间(秒)':<12} {'内存使用(MB)'}")
    print("-" * 45)
    for tag, stats in results.items():
        print(f"{tag:<12} {stats['mse']:<10.4f} {stats['time']:<12.2f} {stats['memory']:.2f}")

In [None]:
# Visual comparison of the results
plt.figure(figsize=(12, 6))

# 1. Training-loss curves with each model's test MSE as a dashed horizontal line
plt.subplot(1, 1, 1)
for tag in ('CNN+LSTM', 'CNN+MinLSTM'):
    plt.plot(results[tag]['losses'], label=f'{tag} Loss', alpha=0.7)
for tag, line_color in (('CNN+LSTM', 'blue'), ('CNN+MinLSTM', 'orange')):
    test_mse = results[tag]['mse']
    plt.axhline(y=test_mse, color=line_color, linestyle='--',
                label=f'{tag} MSE: {test_mse:.4f}')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training Loss and Test MSE Comparison')
plt.legend()
plt.grid(True)
plt.show()

models = list(results)
time_values = [results[tag]['time'] for tag in models]
mem_values = [results[tag]['memory'] for tag in models]

# 2. Training time and 3. memory consumption, one bar chart each
for values, unit, chart_title, y_label in (
    (time_values, 's', 'Training Time Comparison', 'Time (Seconds)'),
    (mem_values, 'MB', 'Memory Usage Comparison', 'Memory (MB)'),
):
    plt.figure(figsize=(8, 5))
    bars = plt.bar(models, values, color=['blue', 'orange'])

    # Write every bar's value just above it.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{height:.2f}{unit}', ha='center', va='bottom')

    plt.title(chart_title)
    plt.ylabel(y_label)
    plt.grid(axis='y', alpha=0.3)
    plt.show()

In [None]:
# Visual comparison of the results
# 1. Training-loss curves (first 100 steps)
plt.figure(figsize=(12, 6))
for tag in ('CNN+LSTM', 'CNN+MinLSTM'):
    plt.plot(results[tag]['losses'][:100], label=f'{tag} Loss', alpha=0.7)
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training Loss Comparison (First 100 Steps)')
plt.legend()
plt.grid(True)
plt.show()

models = list(results)
mse_values = [results[tag]['mse'] for tag in models]
time_values = [results[tag]['time'] for tag in models]
mem_values = [results[tag]['memory'] for tag in models]

# 2. Test MSE, 3. training time and 4. memory consumption, one bar chart each
for values, label_template, chart_title, y_label in (
    (mse_values, '{:.4f}', 'Test MSE Comparison', 'MSE'),
    (time_values, '{:.2f}s', 'Training Time Comparison', 'Time (Seconds)'),
    (mem_values, '{:.2f}MB', 'Memory Usage Comparison', 'Memory (MB)'),
):
    plt.figure(figsize=(8, 5))
    bars = plt.bar(models, values, color=['blue', 'orange'])

    # Write every bar's value just above it.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 label_template.format(height), ha='center', va='bottom')

    plt.title(chart_title)
    plt.ylabel(y_label)
    plt.grid(axis='y', alpha=0.3)
    plt.show()