In [1]:
from config import SEQUENCE_LENGTH, TRAIN_RATIO, VAL_RATIO, TEST_RATIO, SOLAR_DATA_PATH, COSMIC_DATA_PATH
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime, timedelta
from pathlib import Path

In [2]:
# Column names used throughout the notebook
SOLAR_PARAMETERS = [
    'HMF',
    'wind_speed',
    'HCS_tilt',
    'polarity',
    'SSN',
    'daily_OSF',
]
HELIUM_FLUX_COL = 'helium_flux m^-2sr^-1s^-1GV^-1'

# Rigidity bin edges in GV
RIGIDITY_BIN_EDGES = [
    1.71, 1.92, 2.15, 2.4, 2.67, 2.97, 3.29, 3.64, 4.02, 4.43,
    4.88, 5.37, 5.9, 6.47, 7.09, 7.76, 8.48, 9.26, 10.1,
]
# Every edge except the last one; these are the values looked up in the
# 'rigidity_min GV' column of the cosmic-ray table.
RIGIDITY_VALUES = RIGIDITY_BIN_EDGES[:len(RIGIDITY_BIN_EDGES) - 1]

# Load the solar and cosmic-ray tables, then reshape the cosmic-ray table so that
# every requested rigidity gets its own helium-flux column keyed by date.

# Solar data
solar_data = pd.read_csv(SOLAR_DATA_PATH)
solar_data['date'] = pd.to_datetime(solar_data['date'])

# Cosmic-ray data
cosmic_data = pd.read_csv(COSMIC_DATA_PATH)
cosmic_data['date YYYY-MM-DD'] = pd.to_datetime(cosmic_data['date YYYY-MM-DD'])

# Select the rows of each requested rigidity and rename the flux column after it
cosmic_multi_rigidity = []
for rigidity in RIGIDITY_VALUES:
    # Tolerant comparison: floats parsed from CSV are not guaranteed to be
    # bit-identical to the Python literals (depends on the parser's float precision).
    rigidity_mask = np.isclose(cosmic_data['rigidity_min GV'].to_numpy(dtype=float), rigidity)
    rigidity_data = cosmic_data.loc[rigidity_mask, ['date YYYY-MM-DD', HELIUM_FLUX_COL]]
    if len(rigidity_data) > 0:
        cosmic_multi_rigidity.append(
            rigidity_data.rename(columns={HELIUM_FLUX_COL: f'helium_{rigidity}GV'})
        )
    else:
        print(f"警告: 刚度 {rigidity} GV 没有数据")

# Fail with the explicit error before touching the list, so an empty result is
# reported as "no rigidity data" rather than as an IndexError.
if not cosmic_multi_rigidity:
    raise ValueError("没有找到任何刚度数据")

# Outer-join all per-rigidity frames on the date so no observation day is lost
cosmic_data = cosmic_multi_rigidity[0]
for rigidity_frame in cosmic_multi_rigidity[1:]:
    cosmic_data = cosmic_data.merge(rigidity_frame, on='date YYYY-MM-DD', how='outer')

# Debug dump of the merged (not yet interpolated) table
cosmic_data.to_csv('test.csv')

# Helium-flux columns that are actually present after the merge
helium_flux_cols = [f'helium_{rigidity}GV' for rigidity in RIGIDITY_VALUES if f'helium_{rigidity}GV' in cosmic_data.columns]
print(f"成功加载 {len(helium_flux_cols)} 个刚度的数据: {[col.split('_')[-1] for col in helium_flux_cols]}")

# --- Alignment diagnostics: date coverage, gaps and duplicate dates ---
print("=== Data Alignment Debug ===")

# Date span covered by each table
print(f"Solar data range: {solar_data['date'].min()} to {solar_data['date'].max()}")
print(f"Cosmic data range: {cosmic_data['date YYYY-MM-DD'].min()} to {cosmic_data['date YYYY-MM-DD'].max()}")

# Row counts
print(f"Total solar days: {len(solar_data)}")
print(f"Total cosmic days before interpolation: {len(cosmic_data)}")

# Calendar days inside each table's own span that have no row
solar_dates = pd.to_datetime(solar_data['date'])
cosmic_dates = pd.to_datetime(cosmic_data['date YYYY-MM-DD'])
full_solar = pd.date_range(solar_dates.min(), solar_dates.max(), freq='D')
full_cosmic = pd.date_range(cosmic_dates.min(), cosmic_dates.max(), freq='D')
missing_solar = set(full_solar).difference(solar_dates)
missing_cosmic = set(full_cosmic).difference(cosmic_dates)

print(f"Missing solar days: {len(missing_solar)}")
if len(missing_solar) > 0:
    first5 = sorted(missing_solar)[:5]
    print(f"First 5 missing solar dates: {first5}")

print(f"Missing cosmic days: {len(missing_cosmic)}")
if len(missing_cosmic) > 0:
    first5_cos = sorted(missing_cosmic)[:5]
    print(f"First 5 missing cosmic dates: {first5_cos}")

# Rows whose date occurs more than once (every occurrence is kept in the result)
solar_dups = solar_data.loc[solar_data.duplicated(subset='date', keep=False)]
cosmic_dups = cosmic_data.loc[cosmic_data.duplicated(subset='date YYYY-MM-DD', keep=False)]
print(f"Duplicate solar dates: {len(solar_dups)}")
print(f"Duplicate cosmic dates: {len(cosmic_dups)}")

##########################################################################
# --- Fill calendar gaps in the cosmic-ray table by linear interpolation ---
print("Interpolating cosmic data...")
# Every calendar day between the first and the last cosmic-ray date
full_range = pd.date_range(start=cosmic_dates.min(), end=cosmic_dates.max(), freq='D')
# Reindexing on the full range inserts an all-NaN row for each missing day
cosmic_data = cosmic_data.set_index('date YYYY-MM-DD').reindex(full_range)
# Linearly interpolate every helium-flux column that is present
for col in helium_flux_cols:
    if col in cosmic_data.columns:
        cosmic_data[col] = cosmic_data[col].interpolate(method='linear')
# Turn the date index back into an ordinary column under its original name
cosmic_data = cosmic_data.rename_axis('date YYYY-MM-DD').reset_index()
# Save only after the dates are a column again: with index=False the index is not
# written, so saving while the dates were still the index dropped them from the file.
cosmic_data.to_csv("interpolated_cosmic_data.csv", index=False)
print(f"Total cosmic days after interpolation: {len(cosmic_data)}")
print("宇宙线数据已按日线性插值补全完成。")

# Final summary
print("Final data summary:")
print(
    f"Solar data: {len(solar_data)} days ({solar_data['date'].min()} - {solar_data['date'].max()})"
)
print(
    f"Cosmic data: {len(cosmic_data)} days "
    f"({cosmic_data['date YYYY-MM-DD'].min()} - {cosmic_data['date YYYY-MM-DD'].max()})"
)

成功加载 18 个刚度的数据: ['1.71GV', '1.92GV', '2.15GV', '2.4GV', '2.67GV', '2.97GV', '3.29GV', '3.64GV', '4.02GV', '4.43GV', '4.88GV', '5.37GV', '5.9GV', '6.47GV', '7.09GV', '7.76GV', '8.48GV', '9.26GV']
=== Data Alignment Debug ===
Solar data range: 1985-01-01 00:00:00 to 2024-12-25 00:00:00
Cosmic data range: 2011-05-20 00:00:00 to 2019-10-29 00:00:00
Total solar days: 14604
Total cosmic days before interpolation: 2824
Missing solar days: 0
Missing cosmic days: 261
First 5 missing cosmic dates: [Timestamp('2011-06-04 00:00:00'), Timestamp('2012-11-01 00:00:00'), Timestamp('2013-09-10 00:00:00'), Timestamp('2013-09-11 00:00:00'), Timestamp('2013-10-21 00:00:00')]
Duplicate solar dates: 0
Duplicate cosmic dates: 0
Interpolating cosmic data...
Total cosmic days after interpolation: 3085
宇宙线数据已按日线性插值补全完成。
Final data summary:
Solar data: 14604 days (1985-01-01 00:00:00 - 2024-12-25 00:00:00)
Cosmic data: 3085 days (2011-05-20 00:00:00 - 2019-10-29 00:00:00)


In [4]:
# Window length, in rows of the daily cosmic-ray table, used for each sample.
sequence_length = SEQUENCE_LENGTH

# Sample layout:
#   input  - the past `sequence_length` days of
#            [len(SOLAR_PARAMETERS) solar parameters + one helium flux per rigidity]
#   output - the helium flux of every rigidity on the following day

print(f"\n=== 创建 {sequence_length} 天序列（太阳参数+多刚度宇宙线流强） ===")

# Keep only the helium-flux columns that exist in the cosmic-ray table
helium_flux_cols = []
for rigidity in RIGIDITY_VALUES:
    column_name = f'helium_{rigidity}GV'
    if column_name in cosmic_data.columns:
        helium_flux_cols.append(column_name)

# Feature order: solar parameters first, then the helium-flux columns
features = SOLAR_PARAMETERS + helium_flux_cols
print(f"输入特征: {features}")


=== 创建 180 天序列（太阳参数+多刚度宇宙线流强） ===
输入特征: ['HMF', 'wind_speed', 'HCS_tilt', 'polarity', 'SSN', 'daily_OSF', 'helium_1.71GV', 'helium_1.92GV', 'helium_2.15GV', 'helium_2.4GV', 'helium_2.67GV', 'helium_2.97GV', 'helium_3.29GV', 'helium_3.64GV', 'helium_4.02GV', 'helium_4.43GV', 'helium_4.88GV', 'helium_5.37GV', 'helium_5.9GV', 'helium_6.47GV', 'helium_7.09GV', 'helium_7.76GV', 'helium_8.48GV', 'helium_9.26GV']


In [None]:
# Build sliding-window samples. Each input is `sequence_length` consecutive rows of
# [solar parameters + helium fluxes]; the target is the helium fluxes of the next row.
# X and y are produced directly as float64 arrays.

# Sort by date so consecutive rows form a correct sliding window
cosmic_data = cosmic_data.sort_values('date YYYY-MM-DD').reset_index(drop=True)
cosmic_day = cosmic_data['date YYYY-MM-DD']

# Align the solar parameters to the cosmic-ray dates once, instead of scanning the
# whole solar table for every (window, day) pair. A day is usable only when exactly
# one solar row carries its date, so missing and repeated dates are excluded.
solar_by_date = (
    solar_data
    .dropna(subset=['date'])
    .drop_duplicates(subset='date', keep=False)
    .set_index('date')[SOLAR_PARAMETERS]
)
has_solar = cosmic_day.isin(solar_by_date.index).to_numpy()

# One row per cosmic-ray day; days without a usable solar row hold NaN in the solar
# part, but such rows never end up in an accepted window.
helium_matrix = cosmic_data[helium_flux_cols].to_numpy(dtype=np.float64)
solar_matrix = solar_by_date.reindex(cosmic_day).to_numpy(dtype=np.float64)
feature_matrix = np.hstack([solar_matrix, helium_matrix])

# A window starting at idx is accepted when all of its rows have solar data.
# The running count of usable days gives each window's count by subtraction.
n_windows = max(len(cosmic_data) - sequence_length, 0)
valid_before = np.concatenate(([0], np.cumsum(has_solar)))
window_ok = (
    valid_before[sequence_length:sequence_length + n_windows] - valid_before[:n_windows]
) == sequence_length

# Count successes/failures and report the first three of each, in window order
successful_alignments = 0
failed_alignments = 0
for idx, ok in enumerate(window_ok):
    if ok:
        successful_alignments += 1
        if successful_alignments <= 3:
            input_start = cosmic_day.iloc[idx]
            input_end = cosmic_day.iloc[idx + sequence_length - 1]
            output_date = cosmic_day.iloc[idx + sequence_length]
            print(f"\n样例 {successful_alignments}: 输入 {input_start} 到 {input_end}, 输出 {output_date}")
            print(f"  输入形状: {(sequence_length, feature_matrix.shape[1])}")
            print(f"  输出形状: {len(helium_flux_cols)} 个刚度")
    else:
        failed_alignments += 1
        if failed_alignments <= 3:
            # Number of usable days before the first day without solar data
            days_found = int(np.argmin(has_solar[idx:idx + sequence_length]))
            print(f"失败样例 {failed_alignments}: 只找到 {days_found} 天数据，需要 {sequence_length} 天")

# Gather all accepted windows at once: row indices have shape (samples, sequence_length)
window_starts = np.flatnonzero(window_ok)
window_rows = window_starts[:, None] + np.arange(sequence_length)
X = feature_matrix[window_rows]
# The target row is the one immediately after each window
y = helium_matrix[window_starts + sequence_length]
dates = cosmic_day.iloc[window_starts + sequence_length].tolist()

print(f"\n=== 序列创建结果 ===")
print(f"成功对齐: {successful_alignments} 个样例")
print(f"失败对齐: {failed_alignments} 个样例")
print(f"最终数据形状:")
print(f"  X: {X.shape} (样例数, 时间步数, 特征数)")
print(f"  y: {y.shape} (样例数, 刚度数)")
print(f"  特征顺序: {features}")
print(f"  刚度输出顺序: {helium_flux_cols}")

# Data-quality report: NaN counts plus per-column summary statistics of X and y
print(f"\n=== 数据质量检查 ===")
if X.shape[0] == 0:
    print("没有数据可检查")
else:
    # Work on float64 copies so the NaN checks and nan-aware statistics apply
    X_array = np.array(X, dtype=np.float64)
    y_array = np.array(y, dtype=np.float64)

    print(f"X 中的 NaN 数量: {np.isnan(X_array).sum()}")
    print(f"y 中的 NaN 数量: {np.isnan(y_array).sum()}")

    # One line per input feature, pooled over all samples and time steps
    print(f"\nX 统计 (所有特征):")
    for i in range(len(features)):
        feature = features[i]
        feature_data = X_array[:, :, i].flatten()
        if np.all(np.isnan(feature_data)):
            print(f"  {feature}: 全部为NaN")
            continue
        print(f"  {feature}: 均值={np.nanmean(feature_data):.4f}, 标准差={np.nanstd(feature_data):.4f}, 范围=[{np.nanmin(feature_data):.2f}, {np.nanmax(feature_data):.2f}]")

    # One line per target column, pooled over all samples
    print(f"\ny 统计 (所有刚度):")
    for i in range(len(helium_flux_cols)):
        rigidity = helium_flux_cols[i]
        rigidity_data = y_array[:, i].flatten()
        if np.all(np.isnan(rigidity_data)):
            print(f"  {rigidity}: 全部为NaN")
            continue
        print(f"  {rigidity}: 均值={np.nanmean(rigidity_data):.4f}, 标准差={np.nanstd(rigidity_data):.4f}, 范围=[{np.nanmin(rigidity_data):.2f}, {np.nanmax(rigidity_data):.2f}]")


=== 创建 180 天序列（太阳参数+多刚度宇宙线流强） ===
输入特征: ['HMF', 'wind_speed', 'HCS_tilt', 'polarity', 'SSN', 'daily_OSF', 'helium_1.71GV', 'helium_1.92GV', 'helium_2.15GV', 'helium_2.4GV', 'helium_2.67GV', 'helium_2.97GV', 'helium_3.29GV', 'helium_3.64GV', 'helium_4.02GV', 'helium_4.43GV', 'helium_4.88GV', 'helium_5.37GV', 'helium_5.9GV', 'helium_6.47GV', 'helium_7.09GV', 'helium_7.76GV', 'helium_8.48GV', 'helium_9.26GV']

样例 1: 输入 2011-05-20 00:00:00 到 2011-11-15 00:00:00, 输出 2011-11-16 00:00:00
  输入形状: (180, 24)
  输出形状: 18 个刚度

样例 2: 输入 2011-05-21 00:00:00 到 2011-11-16 00:00:00, 输出 2011-11-17 00:00:00
  输入形状: (180, 24)
  输出形状: 18 个刚度

样例 3: 输入 2011-05-22 00:00:00 到 2011-11-17 00:00:00, 输出 2011-11-18 00:00:00
  输入形状: (180, 24)
  输出形状: 18 个刚度

=== 序列创建结果 ===
成功对齐: 2905 个样例
失败对齐: 0 个样例
最终数据形状:
  X: (2905, 180, 24) (样例数, 时间步数, 特征数)
  y: (2905, 18) (样例数, 刚度数)
  特征顺序: ['HMF', 'wind_speed', 'HCS_tilt', 'polarity', 'SSN', 'daily_OSF', 'helium_1.71GV', 'helium_1.92GV', 'helium_2.15GV', 'helium_2.4GV', 'heli