In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from scipy.fft import fft, fftfreq
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import os
import common_io
import pandas as pd
from tqdm import tqdm
from scipy.signal import find_peaks, savgol_filter
from statsmodels.tsa.seasonal import STL
from sklearn.linear_model import LinearRegression
from pandas.tseries.holiday import USFederalHolidayCalendar  # 示例用美国节日，需替换为中国节日

# Point the ODPS client configuration at the workspace config file
# (presumably read by common_io when it opens ODPS tables — confirm).
os.environ.update(ODPS_CONFIG_FILE_PATH='/mnt/workspace/HaochengZhang/odps_config.ini')
# Torch device handle; not referenced by the rest of the visible code.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import json
import pandas as pd
from pandas.tseries.holiday import AbstractHolidayCalendar, Holiday  # 添加缺失的导入

# Load the holiday configuration. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on open()'s locale-dependent default encoding
# (the file presumably contains non-ASCII holiday names).
with open('/mnt/workspace/HaochengZhang/周期检测/节假日.json', encoding='utf-8') as f:
    holiday_config = json.load(f)


def generate_holiday_rules(config):
    """Expand a holiday configuration into one pandas ``Holiday`` rule per date.

    Args:
        config: mapping of ``year_str -> {holiday_name: [start, end]}`` where
            ``start``/``end`` are parseable by ``pd.to_datetime`` and the range
            is inclusive on both ends. (Shape inferred from how this function
            indexes it — confirm against the JSON file.)

    Returns:
        list of ``Holiday`` rules, each pinned to exactly one calendar date.
    """
    rules = []
    for year_str, holidays in config.items():
        year = int(year_str)
        for name, dates in holidays.items():
            start = pd.to_datetime(dates[0])
            end = pd.to_datetime(dates[1])
            # Every day from start to end inclusive (empty if end < start).
            for current_date in pd.date_range(start, end, freq='D'):
                rules.append(
                    Holiday(
                        f"{name}_{year}_{current_date.day}",
                        # Pin the rule to this specific year: without `year`,
                        # pandas treats a month/day rule as recurring every
                        # year, so e.g. one year's Spring Festival dates would
                        # be marked as holidays in every other year as well.
                        year=current_date.year,
                        month=current_date.month,
                        day=current_date.day,
                    )
                )
    return rules

# Custom calendar built on pandas' AbstractHolidayCalendar.
class ChinaHolidayCalendar(AbstractHolidayCalendar):
    """Holiday calendar whose rules come from the loaded ``holiday_config``.

    The rule list is produced once, when this class body executes, by
    ``generate_holiday_rules`` applied to the module-level ``holiday_config``.
    """

    rules = generate_holiday_rules(config=holiday_config)

# Parameter configuration.
class Config:
    """Tunable parameters for preprocessing and period detection.

    NOTE(review): preprocess_ts reindexes each series onto a daily
    ``pd.date_range`` and detect_periods measures periods as ``1 / fftfreq``
    (samples), so MIN_PERIOD / MAX_PERIOD are effectively counted in days,
    although they were originally described as weeks — confirm the unit.
    """

    # --- Period detection ---
    # Shortest accepted period (originally described as ~half a year in weeks).
    MIN_PERIOD = 26
    # Longest accepted period (originally described as "1 year", although
    # 80 weeks exceeds 52 — confirm intent).
    MAX_PERIOD = 80
    # Number of periods reported per term.
    TOP_N = 2
    # Fraction of the gap between the 30th-percentile power and the maximum
    # power that a spectral peak must clear (see detect_periods).
    SNR_THRESHOLD = 0.85

    # --- Preprocessing ---
    # Rolling-mean window, used only when STL decomposition fails.
    TREND_WINDOW = 13
    # Savitzky-Golay smoothing window length.
    SMOOTH_WINDOW = 5
    # Minimum series length required before detection (8 weeks of days).
    MIN_DATA_DAYS = 56

    # --- Holidays ---
    # Days before and after each holiday that are masked and interpolated.
    HOLIDAY_RANGE = 3


cfg = Config()

# Data loading (original logic preserved).
def load_test_data(reader):
    """Read every row from *reader* and keep only the detection input columns.

    Args:
        reader: table reader exposing ``get_row_count()``, ``read(n)`` and
            ``get_schema()`` (presumably a ``common_io.table.TableReader``).

    Returns:
        DataFrame restricted to ``se_ipvuv_1w``, ``term``, ``week_encode``
        and ``ds``, in that order.
    """
    row_total = reader.get_row_count()
    raw_rows = reader.read(row_total)
    table_schema = reader.get_schema()
    # Element 0 of each schema entry is used as the column name.
    names = [table_schema[pos][0] for pos in range(len(table_schema))]
    wanted = ['se_ipvuv_1w', 'term', 'week_encode', 'ds']
    frame = pd.DataFrame(raw_rows, columns=names)
    return frame[wanted]
# Preprocessing for a single term's series.
def preprocess_ts(df_term, cfg):
    """Turn one term's raw rows into a smoothed, standardized daily series.

    Steps: reindex onto a continuous daily calendar, mask and interpolate the
    window around each holiday from ChinaHolidayCalendar, fill remaining gaps,
    decompose with STL (falling back to a rolling mean if STL raises), smooth
    the trend with a Savitzky-Golay filter and standardize it.

    Args:
        df_term: rows for a single term with columns ``ds`` (datetimes),
            ``term``, ``week_encode`` and ``se_ipvuv_1w``.
        cfg: configuration object (see ``Config``).

    Returns:
        tuple ``(ts, df_full)``: ``ts`` is the standardized 1-D numpy array
        derived from the trend; ``df_full`` is the reindexed frame with the
        added ``is_holiday``, ``trend``, ``seasonal`` and ``resid`` columns.
    """
    df_sorted = df_term.sort_values('ds')
    full_dates = pd.date_range(start=df_sorted['ds'].min(), end=df_sorted['ds'].max())

    # Reindex onto every calendar day; rename_axis guarantees the restored
    # column is called 'ds' regardless of how reindex names the new index.
    df_full = (
        df_sorted.set_index('ds')
        .reindex(full_dates)
        .rename_axis('ds')
        .reset_index()
    )
    # Carry the non-numeric columns onto the inserted days.
    for col in ('term', 'week_encode'):
        df_full[col] = df_full[col].ffill().bfill()

    # Holidays that fall inside the observed date span.
    cal = ChinaHolidayCalendar()
    holidays = cal.holidays(start=df_full['ds'].min(), end=df_full['ds'].max())

    # Flag every day within HOLIDAY_RANGE days (inclusive) of a holiday.
    reach = pd.Timedelta(days=cfg.HOLIDAY_RANGE)
    near_holiday = pd.Series(False, index=df_full.index)
    for day in holidays:
        near_holiday |= df_full['ds'].between(day - reach, day + reach)
    df_full['is_holiday'] = near_holiday.astype(int)

    # Treat holiday-affected and missing days as gaps: interpolate linearly,
    # then pad the edges (linear interpolation leaves leading gaps unfilled).
    values = df_full['se_ipvuv_1w'].mask(df_full['is_holiday'] == 1)
    df_full['se_ipvuv_1w'] = values.interpolate(method='linear').ffill().bfill()

    # ---- Trend decomposition ----
    series = df_full['se_ipvuv_1w']
    try:
        res = STL(series, period=52, seasonal=13).fit()
    except Exception:
        # STL could not be fitted: fall back to a rolling-mean trend. The
        # residual is zero because trend + seasonal reproduces the series.
        trend = series.rolling(window=cfg.TREND_WINDOW, min_periods=1).mean()
        df_full['trend'] = trend
        df_full['seasonal'] = series - trend
        df_full['resid'] = 0.0
    else:
        df_full['trend'] = res.trend
        df_full['seasonal'] = res.seasonal
        df_full['resid'] = res.resid

    # ---- Denoising ----
    # Period detection runs on the trend component (the seasonal component
    # was an earlier alternative).
    ts = df_full['trend'].values
    # savgol_filter requires window_length <= len(ts); leave very short series
    # unsmoothed so batch_detect_periods can reject them via MIN_DATA_DAYS
    # instead of surfacing a ValueError.
    if len(ts) >= cfg.SMOOTH_WINDOW:
        ts = savgol_filter(ts, window_length=cfg.SMOOTH_WINDOW, polyorder=2)

    # Standardize to zero mean / unit variance before spectral analysis.
    scaler = StandardScaler()
    return scaler.fit_transform(ts.reshape(-1, 1)).flatten(), df_full

# Spectral period detection.
def detect_periods(ts, cfg):
    """Detect up to ``cfg.TOP_N`` dominant periods of *ts* from its power spectrum.

    The series is multiplied by a Hann window and transformed with an FFT.
    Spectral peaks qualify when their period (in samples) lies within
    ``[cfg.MIN_PERIOD, cfg.MAX_PERIOD]`` and their power clears an adaptive
    threshold. Among peaks closer than 2 samples in period, only the stronger
    one is kept.

    Args:
        ts: 1-D numeric sequence (e.g. the standardized output of preprocess_ts).
        cfg: object exposing MIN_PERIOD, MAX_PERIOD, TOP_N and SNR_THRESHOLD.

    Returns:
        list of length ``cfg.TOP_N``: integer periods ordered by descending
        spectral power, padded with NaN when fewer are found.
    """
    top_n = cfg.TOP_N
    n = len(ts)
    # Require at least two full cycles of the longest allowed period.
    if n < 2 * cfg.MAX_PERIOD:
        return [np.nan] * top_n

    # Hann window to reduce spectral leakage.
    windowed = np.asarray(ts, dtype=float) * np.hanning(n)

    # Power spectrum of the positive frequencies, excluding the DC bin.
    spectrum = fft(windowed)
    freqs = fftfreq(n)[1:n // 2]
    power = np.abs(spectrum[1:n // 2]) ** 2
    periods = 1.0 / freqs

    # Restrict to the configured period band.
    in_band = (periods >= cfg.MIN_PERIOD) & (periods <= cfg.MAX_PERIOD)
    band_periods = periods[in_band]
    band_power = power[in_band]
    if band_power.size == 0:
        return [np.nan] * top_n

    # Adaptive threshold: a fraction SNR_THRESHOLD of the way from the
    # 30th-percentile power ("noise floor") up to the maximum power.
    noise_floor = np.percentile(band_power, 30)
    threshold = noise_floor + cfg.SNR_THRESHOLD * (band_power.max() - noise_floor)

    # `distance` is measured in frequency bins and must be >= 1 for find_peaks.
    peaks, _ = find_peaks(band_power, height=threshold,
                          distance=max(1, cfg.MIN_PERIOD // 2))

    # Rank peaks strongest-first (stable, so ties keep frequency order), then
    # drop any period within 2 samples of an already accepted, stronger one.
    ranked = peaks[np.argsort(-band_power[peaks], kind='stable')]
    chosen = []
    for idx in ranked:
        candidate = band_periods[idx]
        if all(abs(candidate - kept) >= 2 for kept in chosen):
            chosen.append(candidate)
            if len(chosen) >= top_n:
                break

    chosen = chosen[:top_n]
    result = np.round(chosen).astype(int).tolist() if chosen else []
    # Pad to TOP_N with NaN.
    return result + [np.nan] * (top_n - len(result))

# Batch processing over all terms.
def _validate_periods(ts, periods, cfg, min_corr=0.3):
    """Keep in-range candidate periods whose lag autocorrelation exceeds *min_corr*.

    A candidate is kept only if it lies within [MIN_PERIOD, MAX_PERIOD]
    (NaN never does), the series is longer than two lags, and the Pearson
    correlation between the series and itself shifted by the period is above
    *min_corr*. The result is padded with NaN to ``cfg.TOP_N`` entries.
    """
    kept = []
    for p in periods:
        if not (cfg.MIN_PERIOD <= p <= cfg.MAX_PERIOD):
            continue
        lag = int(p)
        if len(ts) > 2 * lag:
            corr = np.corrcoef(ts[:-lag], ts[lag:])[0, 1]
            if corr > min_corr:
                kept.append(p)
    return kept + [np.nan] * (cfg.TOP_N - len(kept))


def batch_detect_periods(df, cfg):
    """Run preprocessing and period detection for every term in *df*.

    Args:
        df: rows for many terms with the columns consumed by preprocess_ts
            (``term``, ``ds``, ``se_ipvuv_1w``, ``week_encode``).
        cfg: configuration object (see ``Config``).

    Returns:
        DataFrame with one row per term and columns ``term``, ``data_days``
        (NaN when the term was too short or failed) and ``period_1`` ..
        ``period_{TOP_N}`` (NaN where no validated period was found). An
        empty *df* yields an empty frame with the same columns.
    """
    results = []
    for term, group in tqdm(df.groupby('term'), desc='Processing Terms'):
        try:
            ts, _ = preprocess_ts(group, cfg)
            if len(ts) < cfg.MIN_DATA_DAYS:
                results.append({'term': term, 'periods': [np.nan] * cfg.TOP_N})
                continue

            periods = _validate_periods(ts, detect_periods(ts, cfg), cfg)
            results.append({
                'term': term,
                'periods': periods,
                'data_days': len(ts),
            })
        except Exception as e:
            # Best-effort per term: report the failure and record NaN periods.
            print(f"Error processing {term}: {str(e)}")
            results.append({'term': term, 'periods': [np.nan] * cfg.TOP_N})

    period_cols = [f'period_{i+1}' for i in range(cfg.TOP_N)]
    # Explicit columns keep the output schema stable even when `results` is
    # empty or no term reached the 'data_days' branch.
    result_df = pd.DataFrame(results, columns=['term', 'periods', 'data_days'])
    # Expand the per-term period lists into one column per period.
    result_df[period_cols] = pd.DataFrame(
        result_df['periods'].tolist(), index=result_df.index, columns=period_cols
    )
    return result_df.drop(columns=['periods'])

# Main flow.
if __name__ == "__main__":
    # Open the ODPS table partition to analyse (single slice, single thread).
    reader = common_io.table.TableReader(
        "odps://new_retail_algo/tables/457291_opp_query_atlas_splited_result_to_be_detected_formatted_alpha/write_ds=20250404",
        selected_cols="",
        excluded_cols="",
        slice_id=0,
        slice_count=1,
        num_threads=1,
        capacity=2048
    )

    # Load all rows; close the reader even if loading fails.
    try:
        df = load_test_data(reader)
    finally:
        reader.close()
    df['ds'] = pd.to_datetime(df['ds'], format='%Y%m%d')

    # Run period detection per term.
    period_df = batch_detect_periods(df, cfg)

    # Attach each term's periods back onto its rows.
    result_df = df.merge(period_df, on='term', how='left')

    # Keep rows whose term has at least one validated period.
    period_cols = [f'period_{i+1}' for i in range(cfg.TOP_N)]
    valid_mask = result_df[period_cols].notna().any(axis=1)
    filtered_df = result_df[valid_mask].copy()

    # Summary of the results.
    print("\n检测结果统计：")
    print(f"总term数：{len(result_df['term'].unique())}")
    print(f"有效term数：{len(filtered_df['term'].unique())}")
    print("\n典型周期分布：")
    # NOTE(review): describe() runs over merged rows, so terms with more rows
    # weigh more in these statistics; use period_df for per-term statistics.
    print(filtered_df[period_cols].describe())
    print(filtered_df['term'].unique())
