In [None]:
# 添加以下指標:
# vwap標準差, 均值價, 傅立葉轉換, 卡爾曼濾波, 標準差波動率
# 協方差, 斜率, 動量, 價格一階變化, 價格二階變化, macd, macd的sin波轉換
# 對數收益率, 收益變化率, 每三根為一組他們形態上的(歐幾里得距離, 曼哈頓距離, cosine), 以及其形態上映射到常態分佈的數值
# OLS, ARIMA, GARCH, VECM, VAR, SVAR, XGBoost, LightGBM, ..., 所有欄位兩兩之間的 correlation, StandardScaler, MinMaxScaler, Box-Cox, Yeo-Johnson

In [1]:
from smartmoneyconcepts import smc
import pandas as pd
import statsmodels.api as sm
from datetime import datetime, timedelta, time
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from scipy.fft import fft, fftfreq
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import mplfinance as mpf
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from statsmodels.tsa.stattools import adfuller, coint
from statsmodels.tsa.vector_ar.vecm import coint_johansen
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.stattools import acf, pacf
from collections import defaultdict
import math
from arch.unitroot import KPSS
from hurst import compute_Hc
import seaborn as sns
from scipy.integrate import quad
from scipy.spatial.distance import cityblock, euclidean, cosine
from scipy.stats import gaussian_kde, shapiro, probplot, skew, kurtosis, norm, jarque_bera, anderson, normaltest

def prepare_data(type_of_data, data_name, base_dir='../index_data'):
    """Load one market-data CSV and add a parsed ``ts`` timestamp column.

    The file is read from ``{base_dir}/{type_of_data}/{data_name}.csv``.
    Column names are lower-cased *before* the timestamp column is looked up,
    so headers such as ``TS`` or ``Datetime`` are accepted.

    The first path component of *type_of_data* selects where the timestamp
    comes from:

    * ``'shioaji'``: the ``ts`` column is parsed in place;
    * anything else: the ``datetime`` column is parsed into ``ts`` (the
      ``datetime`` column itself is kept).

    Args:
        type_of_data: sub-directory under *base_dir*, e.g. ``'shioaji/test'``.
        data_name: file name without the ``.csv`` extension.
        base_dir: root data directory (defaults to the original hard-coded
            ``'../index_data'``).

    Returns:
        pd.DataFrame with lower-case column names and a datetime ``ts`` column.
    """
    source = type_of_data.split('/')[0]
    frame = pd.read_csv(f'{base_dir}/{type_of_data}/{data_name}.csv')

    # Normalise headers first so the column lookup below is case-insensitive.
    frame = frame.rename(columns=lambda name: name.lower())

    raw_ts_column = 'ts' if source == 'shioaji' else 'datetime'
    frame['ts'] = pd.to_datetime(frame[raw_ts_column])
    return frame

def aggregate_pv_list(pv_lists):
    """Merge a Series of price-volume ladders into a single ladder.

    Each non-NaN entry is expected to be a list of single-entry dicts
    ``{price: volume}``. Entries that are not lists, and ladder levels that are
    not dicts, are ignored.

    Returns:
        ``np.nan`` when every entry is NaN (including an empty Series);
        otherwise a list of ``{price: volume}`` dicts. Prices are converted to
        float, volumes are summed as floats, and levels are ordered by price
        from highest to lowest.
    """
    if pv_lists.isna().all():
        return np.nan

    volume_at_price = defaultdict(float)
    for ladder in pv_lists.dropna():
        if not isinstance(ladder, list):
            continue
        for level in ladder:
            if not isinstance(level, dict):
                continue
            for price, qty in level.items():
                # float() on both sides so int/str/float keys collapse together.
                volume_at_price[float(price)] += float(qty)

    prices_desc = sorted(volume_at_price, reverse=True)
    return [{price: volume_at_price[price]} for price in prices_desc]

def analyze_tick_types(tick_type_series, volume_series):
    """Return the net signed trade volume for one time bucket.

    Each tick's volume is multiplied by its direction (``tick_type``). Positive
    products count as aggressive-buy (outer) volume and negative products as
    aggressive-sell (inner) volume. The result is their sum, i.e. buy minus sell.
    Products that are neither > 0 nor < 0 (zero, or NaN) are ignored.

    Args:
        tick_type_series: per-tick direction values (``pre_process`` supplies
            +1 / -1 / 0).
        volume_series: per-tick traded volume, aligned with *tick_type_series*.
    """
    buy_side = 0
    sell_side = 0
    for direction, qty in zip(tick_type_series, volume_series):
        signed = direction * qty
        if signed > 0:
            buy_side += signed
        elif signed < 0:
            # Kept negative on purpose: the return value is a net figure.
            sell_side += signed
    return buy_side + sell_side

def calculate_liquidity_factors(bid_list, ask_list, bid_vol_list, ask_vol_list):
    """Summarise one bucket of quotes into liquidity features.

    Returns a Series with:

    * ``avg_spread``: mean of ``ask - bid`` over the paired quotes;
    * ``avg_obi``: mean order-book imbalance ``bid_vol / (bid_vol + ask_vol)``,
      where a pair whose total volume is not positive counts as 0.5.

    Each value is NaN when there are no pairs. Both are NaN if the inputs
    raise ``ValueError``/``TypeError`` (e.g. a non-iterable argument).
    """
    try:
        spreads = [ask - bid for ask, bid in zip(ask_list, bid_list)]

        imbalances = []
        for bid_qty, ask_qty in zip(bid_vol_list, ask_vol_list):
            depth = bid_qty + ask_qty
            imbalances.append(bid_qty / depth if depth > 0 else 0.5)

        features = {
            'avg_spread': np.mean(spreads) if spreads else np.nan,
            'avg_obi': np.mean(imbalances) if imbalances else np.nan,
        }
    except (ValueError, TypeError):
        features = {'avg_spread': np.nan, 'avg_obi': np.nan}

    return pd.Series(features)

def calculate_price_factors(close_list, volume_list):
    """Compute volume-weighted price sums and a price-volume ladder for one bucket.

    Returns a Series with:

    * ``sum_price_volume``: sum of ``close * volume`` (both cast to float);
    * ``sum_price_sq_volume``: sum of ``close ** 2 * volume``;
    * ``pv_list``: list of ``{close: total_volume}`` dicts ordered by close
      from highest to lowest. Keys are the raw close values and volumes are
      floats.

    All three are NaN if a value cannot be converted or the inputs are
    malformed (``ValueError``, ``ZeroDivisionError``, ``TypeError``,
    ``IndexError``).
    """
    try:
        weighted_price = sum(float(c) * float(v) for c, v in zip(close_list, volume_list))
        weighted_price_sq = sum(float(c) ** 2 * float(v) for c, v in zip(close_list, volume_list))

        volume_at_close = defaultdict(float)
        for c, v in zip(close_list, volume_list):
            volume_at_close[c] += float(v)
        ladder = [{c: volume_at_close[c]} for c in sorted(volume_at_close, reverse=True)]
    except (ValueError, ZeroDivisionError, TypeError, IndexError):
        weighted_price = weighted_price_sq = ladder = np.nan

    return pd.Series({
        'sum_price_volume': weighted_price,
        'sum_price_sq_volume': weighted_price_sq,
        'pv_list': ladder,
    })

def calculate_capital_factors(bid_list, ask_list, bid_vol_list, ask_vol_list):
    """Compute net quoted notional and net quoted volume for one bucket.

    Returns a Series with:

    * ``monetary_delta``: ``sum(bid * bid_vol) - sum(ask * ask_vol)``, used for
      CVD-style aggregation;
    * ``volume_delta``: ``sum(bid_vol) - sum(ask_vol)``.

    Both are NaN if the inputs raise ``ValueError``, ``ZeroDivisionError`` or
    ``TypeError``.
    """
    try:
        bid_notional = sum(price * qty for price, qty in zip(bid_list, bid_vol_list))
        ask_notional = sum(price * qty for price, qty in zip(ask_list, ask_vol_list))
        payload = {
            'monetary_delta': bid_notional - ask_notional,
            'volume_delta': sum(bid_vol_list) - sum(ask_vol_list),
        }
    except (ValueError, ZeroDivisionError, TypeError):
        payload = {'monetary_delta': np.nan, 'volume_delta': np.nan}

    return pd.Series(payload)

def _session_time_mask(ts, data_type):
    """Return a boolean mask marking which timestamps in *ts* lie in the *data_type* session.

    All bounds are inclusive:

    * ``'s_day'``: 09:00:00-13:30:00;
    * ``'f_day'``: 08:45:00-13:45:00;
    * ``'f_night'``: the f_day window plus 15:00:00 through midnight to 05:00:00.

    Raises:
        ValueError: if *data_type* is not one of the three values above.
    """
    clock = ts.dt.time
    futures_day = clock.between(pd.to_datetime('08:45:00').time(), pd.to_datetime('13:45:00').time())
    if data_type == 's_day':
        return clock.between(pd.to_datetime('09:00:00').time(), pd.to_datetime('13:30:00').time())
    if data_type == 'f_day':
        return futures_day
    if data_type == 'f_night':
        # The night session crosses midnight, so it is the union of two clock ranges.
        night = (clock >= pd.to_datetime('15:00:00').time()) | (clock <= pd.to_datetime('05:00:00').time())
        return futures_day | night
    raise ValueError(f"unknown data_type {data_type!r}; expected 's_day', 'f_day' or 'f_night'")


def pre_process(df1, data_type, date_start, date_end):
    """Turn raw ticks into per-second features and per-minute aggregates.

    Steps:

    1. Group ticks into 1-second buckets. Each bucket gets the last close,
       total volume, sorted non-zero bid/ask tuples and per-tick lists.
    2. Derive per-second features: ``flow_imbalance``
       (``analyze_tick_types``), ``avg_spread``/``avg_obi``
       (``calculate_liquidity_factors``), ``sum_price_volume``/
       ``sum_price_sq_volume``/``pv_list`` (``calculate_price_factors``) and
       ``monetary_delta``/``volume_delta`` (``calculate_capital_factors``).
    3. Left-join onto a 1-second grid spanning ``[date_start, date_end]`` that
       is restricted to the session selected by *data_type*. Rows containing
       any NaN are dropped.
    4. Resample to 1-minute aggregates (``extra_df``). Their labels are moved
       forward by one minute to line up with the 1-minute K-bars.
    5. Drop seconds whose bid or ask tuple is empty (every quote was 0, as at
       limit-up / limit-down).

    Args:
        df1: tick DataFrame with columns ``ts``, ``close``, ``volume``,
            ``bid_price``, ``ask_price``, ``bid_volume``, ``ask_volume`` and
            ``tick_type``. It is not modified.
        data_type: ``'s_day'`` (09:00-13:30), ``'f_day'`` (08:45-13:45) or
            ``'f_night'`` (08:45-13:45 plus 15:00-05:00 overnight). Bounds are
            inclusive.
        date_start, date_end: bounds of the 1-second grid.

    Returns:
        ``(per_second, per_minute)``. ``per_second`` is indexed by ``ts``;
        both have NaN rows removed.

    Raises:
        ValueError: if *data_type* is not recognised. This is checked before
            any work is done.
    """
    if data_type not in ('s_day', 'f_day', 'f_night'):
        raise ValueError(f"unknown data_type {data_type!r}; expected 's_day', 'f_day' or 'f_night'")

    # Parse timestamps into a separate Series so the caller's frame is not mutated.
    second = pd.to_datetime(df1['ts']).dt.floor('s')

    # 1) One row per second.
    per_second = df1.groupby(second).agg(
        close=('close', 'last'),
        volume=('volume', 'sum'),
        close_list=('close', list),
        volume_list=('volume', list),
        # Sorted non-zero quotes; an empty tuple means every quote in that second was 0.
        bid_price=('bid_price', lambda x: tuple(sorted(filter(lambda price: price != 0, x)))),
        ask_price=('ask_price', lambda x: tuple(sorted(filter(lambda price: price != 0, x)))),
        bid_list=('bid_price', list),
        ask_list=('ask_price', list),
        bid_vol_list=('bid_volume', list),
        ask_vol_list=('ask_volume', list),
        # Raw tick_type code -> trade direction: 1 -> +1, 2 -> -1, anything else -> 0.
        tick_type=('tick_type', lambda x: [1 if t == 1 else -1 if t == 2 else 0 for t in x]),
    ).reset_index()

    # 2) Per-second features.
    per_second['flow_imbalance'] = per_second.apply(
        lambda row: analyze_tick_types(row['tick_type'], row['volume_list']), axis=1
    )
    liquidity_features = per_second.apply(
        lambda row: calculate_liquidity_factors(
            row['bid_list'], row['ask_list'], row['bid_vol_list'], row['ask_vol_list']
        ),
        axis=1,
    )
    price_features = per_second.apply(
        lambda row: calculate_price_factors(row['close_list'], row['volume_list']), axis=1
    )
    capital_features = per_second.apply(
        lambda row: calculate_capital_factors(
            row['bid_list'], row['ask_list'], row['bid_vol_list'], row['ask_vol_list']
        ),
        axis=1,
    )
    per_second = pd.concat(
        [per_second, liquidity_features, price_features, capital_features], axis=1
    )
    # The raw per-tick lists are only needed to build the features above.
    per_second = per_second.drop(
        columns=['tick_type', 'volume_list', 'bid_list', 'ask_list',
                 'bid_vol_list', 'ask_vol_list', 'close_list']
    )

    # 3) Align to the in-session 1-second grid; a left join keeps every valid second.
    full_time_df = pd.DataFrame(pd.date_range(date_start, date_end, freq='s'), columns=['ts'])
    valid_time = full_time_df[_session_time_mask(full_time_df['ts'], data_type)]
    mer_ori_data = pd.merge(valid_time, per_second, on='ts', how='left')
    mer_ori_data.set_index('ts', inplace=True)
    mer_ori_data = mer_ori_data.dropna()

    # 4) Per-minute aggregates.
    # NOTE(review): these are built before step 5, so limit-up/down seconds
    # still contribute here -- confirm that is intended.
    extra_df = mer_ori_data.resample('1min').agg({
        'flow_imbalance': 'sum',
        'avg_spread': 'mean',
        'avg_obi': 'mean',
        'sum_price_volume': 'sum',
        'sum_price_sq_volume': 'sum',
        'monetary_delta': 'sum',
        'volume_delta': 'sum',
        'pv_list': aggregate_pv_list,
    })
    # Shift labels forward one minute to match the 1-minute K-bars they are merged with.
    extra_df.index = extra_df.index + pd.Timedelta(minutes=1)

    # 5) Drop seconds with no non-zero bid or ask quote (limit-up / limit-down).
    mer_ori_data = mer_ori_data[
        (mer_ori_data['bid_price'].map(len) > 0) & (mer_ori_data['ask_price'].map(len) > 0)
    ]

    return mer_ori_data.dropna(), extra_df.dropna()

def analyze_relationship(series1, series2, significance_level=0.05):
    """Report the Pearson correlation and cointegration of two series.

    Correlation strength is labelled by ``|r|``: > 0.8 high, > 0.5 moderate,
    > 0.3 low, otherwise "almost none" (this includes a NaN correlation).
    Cointegration uses statsmodels ``coint``. The pair counts as cointegrated
    when its p-value is <= *significance_level*.

    Prints a two-line summary (in Chinese) and returns
    ``(correlation, correlation_strength, p_value, cointegration_status)``.
    """
    correlation = np.corrcoef(series1, series2)[0, 1]

    strength_bands = (
        (0.8, "高度相關"),
        (0.5, "中度相關"),
        (0.3, "低度相關"),
    )
    correlation_strength = next(
        (label for floor, label in strength_bands if abs(correlation) > floor),
        "幾乎無相關",
    )

    _, p_value, _ = coint(series1, series2)
    if p_value <= significance_level:
        cointegration_status = "存在共整合關係"
    else:
        cointegration_status = "不存在共整合關係"

    print(f"相關係數: {correlation:.4f}，{correlation_strength}")
    print(f"共整合檢定的 p 值: {p_value:.4f}，{cointegration_status}")

    return correlation, correlation_strength, p_value, cointegration_status

def convert_ohlcv(df, freq=60):
    """Resample 1-minute K-bars into ``freq``-minute bars, one trading session at a time.

    Each row of *df* is assigned to a session:

    * day: 08:45-13:45 (inclusive), anchored at 08:45 of the same date;
    * night: 15:00 onwards, anchored at 15:00 of the same date. Rows at or
      before 05:00 are anchored at 15:00 of the previous date.
    * anything else is discarded.

    Missing minutes inside a session are filled with flat bars (OHLC = last
    close, volume 0), but only once an earlier real bar exists. All labels are
    then moved one minute earlier. Bars are cut into consecutive
    ``freq``-minute windows starting at the session anchor.

    Each output bar carries:

    * OHLCV;
    * sums of ``flow_imbalance``, ``sum_price_volume``,
      ``sum_price_sq_volume``, ``monetary_delta`` and ``volume_delta``;
    * means of ``avg_spread`` and ``avg_obi``;
    * the merged ``pv_list``;
    * a ``complete`` flag;
    * running per-session ``acc_vol``, ``acc_price_volume`` and
      ``acc_price_sq_volume``.

    Finally the whole frame is shifted down one row, so each row holds the
    previous bar's values (also across session boundaries). Rows with NaN are
    then dropped.

    Args:
        df: 1-minute bars indexed by timestamp, with the columns listed above.
            It is not modified.
        freq: bar length in minutes.

    Returns:
        DataFrame indexed by bar start time (also kept as a ``ts`` column).
        It is empty when no input row falls inside a session.
    """
    if df.empty:
        return pd.DataFrame()

    # Work on a copy so the session helper columns are not written into the caller's frame.
    df = df.copy()

    # Parse the session boundaries once instead of on every row.
    day_open = datetime.strptime("08:45", "%H:%M").time()
    day_close = datetime.strptime("13:45", "%H:%M").time()
    night_open = datetime.strptime("15:00", "%H:%M").time()
    night_close = datetime.strptime("05:00", "%H:%M").time()

    def classify_session(ts):
        # Return [session_type, session_start] for one timestamp.
        clock = ts.time()
        if day_open <= clock <= day_close:
            return pd.Series(["day", datetime.combine(ts.date(), day_open)])
        if clock >= night_open:
            return pd.Series(["night", datetime.combine(ts.date(), night_open)])
        if clock <= night_close:
            # After midnight: part of the night session that opened the previous evening.
            return pd.Series(["night", datetime.combine((ts - timedelta(days=1)).date(), night_open)])
        return pd.Series(["other", pd.NaT])

    df[["session_type", "session_start"]] = df.index.to_series().apply(classify_session)
    df = df[df["session_type"].isin(["day", "night"])]

    def fill_missing_minutes(df_session, session_start, session_type):
        # Insert flat, zero-volume bars for minutes missing inside the session.
        if session_type == "day":
            start_time = datetime.combine(session_start.date(), day_open)
            end_time = datetime.combine(session_start.date(), day_close)
        else:  # the night session ends on the next calendar day
            start_time = datetime.combine(session_start.date(), night_open)
            end_time = datetime.combine(session_start.date() + timedelta(days=1), night_close)

        full_time_index = pd.date_range(start=start_time, end=end_time, freq="1min")
        existing_times = df_session.index
        missing_times = [t for t in full_time_index if t not in existing_times]
        if not missing_times:
            return df_session

        missing_data = []
        last_valid_row = None
        for t in missing_times:
            prev_time = t - timedelta(minutes=1)
            if prev_time in existing_times:
                last_valid_row = df_session.loc[prev_time]
            # missing_times is ascending, so a run of gaps keeps carrying the last
            # real bar; minutes before the first real bar are left unfilled.
            if last_valid_row is not None:
                last_close = last_valid_row["close"]
                missing_data.append({
                    "ts": t,
                    "open": last_close,
                    "high": last_close,
                    "low": last_close,
                    "close": last_close,
                    "volume": 0,
                    "amount": 0,
                    "complete": True,
                    "session_type": session_type,
                    "session_start": session_start,
                })

        if missing_data:
            missing_df = pd.DataFrame(missing_data).set_index("ts")
            df_session = pd.concat([df_session, missing_df]).sort_index()
        return df_session

    df_filled = [
        fill_missing_minutes(session_data, session_start, session_data["session_type"].iloc[0])
        for session_start, session_data in df.groupby("session_start")
    ]
    if df_filled:
        df = pd.concat(df_filled).sort_index()

    # Move every label one minute earlier before bucketing (presumably converting
    # end-of-minute labels to start-of-minute -- confirm against the data source).
    df.index = df.index - pd.Timedelta(minutes=1)

    window = timedelta(minutes=freq)
    result = []

    for session_start, session_data in df.groupby("session_start"):
        current_time = session_start
        max_time = session_data.index.max()
        session_result = []

        # `<=` so a window that starts exactly on the last bar is still emitted.
        while current_time <= max_time:
            next_time = current_time + window
            window_data = session_data[(session_data.index >= current_time) & (session_data.index < next_time)]

            if not window_data.empty:
                session_result.append({
                    "ts": current_time,
                    "open": window_data["open"].iloc[0],
                    "high": window_data["high"].max(),
                    "low": window_data["low"].min(),
                    "close": window_data["close"].iloc[-1],
                    "volume": window_data["volume"].sum(),
                    "flow_imbalance": window_data['flow_imbalance'].sum(),
                    "avg_spread": window_data['avg_spread'].mean(),
                    "avg_obi": window_data['avg_obi'].mean(),
                    "sum_price_volume": window_data['sum_price_volume'].sum(),
                    "sum_price_sq_volume": window_data['sum_price_sq_volume'].sum(),
                    "monetary_delta": window_data['monetary_delta'].sum(),
                    "volume_delta": window_data['volume_delta'].sum(),
                    "pv_list": aggregate_pv_list(window_data['pv_list']),
                    # Complete when the window reaches its final minute.
                    "complete": window_data.index[-1] >= next_time - timedelta(minutes=1),
                })

            current_time = next_time

        # Running totals restart at each session.
        if session_result:
            session_df = pd.DataFrame(session_result)
            session_df["acc_vol"] = session_df["volume"].cumsum()
            session_df["acc_price_volume"] = session_df["sum_price_volume"].cumsum()
            session_df["acc_price_sq_volume"] = session_df["sum_price_sq_volume"].cumsum()
            result.extend(session_df.to_dict('records'))

    if not result:
        return pd.DataFrame()

    agg_df = pd.DataFrame(result)
    agg_df.set_index("ts", inplace=True, drop=False)
    # Each row carries the previous bar's values; the first row becomes NaN and is dropped.
    return agg_df.shift(1).dropna()

def combine_daily_k_bars(df):
    """Collapse intraday K-bars into one bar per calendar date.

    Per date: first open, max high, min low, last close and summed volume.
    ``complete`` is forced to ``True``. The result is indexed by the date at
    midnight (index name ``ts``). It is then shifted down one row, so each
    date holds the *previous* day's bar, and the resulting NaN first row is
    dropped.

    The input is not modified. It must have ``open``, ``high``, ``low``,
    ``close``, ``volume`` and ``complete`` columns.
    """
    bars = df.copy()
    bars.index = pd.to_datetime(bars.index)
    bars['date'] = bars.index.date

    rules = {
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
        'complete': lambda _: True,
    }
    daily = bars.groupby('date').agg(rules).reset_index()

    daily['ts'] = pd.to_datetime(daily['date'])
    daily = daily.set_index('ts').drop('date', axis=1)

    return daily.shift(1).dropna()

def process_multiple_datasets(dataset1, dataset2, expensive_commodity, cheap_commodity):
    """Load and align tick-level and 1-minute data for a pair of contracts.

    Args:
        dataset1: iterable of ``(type_of_data, data_name, date_start, date_end,
            data_type)`` tick files. Each is routed by whether *data_name*
            starts with *expensive_commodity* or *cheap_commodity*.
        dataset2: iterable of ``(type_of_data, data_name)`` 1-minute K-bar
            files. They are routed by the prefixes ``expensive_commodity + 'k'``
            and ``cheap_commodity + 'k'``.
        expensive_commodity, cheap_commodity: contract name prefixes.

    Returns:
        ``(df, df3, df4)``:

        * ``df``: per-second tick features of both contracts, inner-joined on
          timestamp, with columns suffixed ``_df1``/``_df2`` and NaN rows
          dropped;
        * ``df3``: 1-minute bars of the expensive contract, inner-joined with
          its per-minute tick aggregates;
        * ``df4``: the same for the cheap contract.

    Every input is collapsed to one row per timestamp before merging, so
    overlapping files cannot multiply rows in the joins.
    """
    def _concat(frames):
        # Stack per-file frames, or return an empty frame when there are none.
        return pd.concat(frames) if frames else pd.DataFrame()

    def _last_per_timestamp(frame):
        # Collapse duplicate timestamps to one row holding each column's last
        # non-null value (groupby also sorts the index).
        return frame if frame.empty else frame.groupby(frame.index).last()

    df1_list, df2_list, df3_list, df4_list, extra_list1, extra_list2 = [], [], [], [], [], []

    # Tick files -> per-second features and per-minute aggregates.
    for type_of_data, data_name, date_start, date_end, data_type in dataset1:
        df = prepare_data(type_of_data, data_name)
        df, extra_data = pre_process(df, data_type, date_start, date_end)
        if data_name.startswith(expensive_commodity):
            df1_list.append(df)
            extra_list1.append(extra_data)
        elif data_name.startswith(cheap_commodity):
            df2_list.append(df)
            extra_list2.append(extra_data)

    # Shioaji 1-minute K-bar files.
    for type_of_data, data_name in dataset2:
        df = prepare_data(type_of_data, data_name)
        df = df.set_index('ts')
        if data_name.startswith(expensive_commodity + 'k'):
            df3_list.append(df)
        elif data_name.startswith(cheap_commodity + 'k'):
            df4_list.append(df)

    df1 = _last_per_timestamp(_concat(df1_list))
    df2 = _last_per_timestamp(_concat(df2_list))
    df3 = _last_per_timestamp(_concat(df3_list))
    df4 = _last_per_timestamp(_concat(df4_list))
    # Deduplicated like the others; duplicates here would multiply rows in the merges below.
    extra_data1 = _last_per_timestamp(_concat(extra_list1))
    extra_data2 = _last_per_timestamp(_concat(extra_list2))

    df = pd.merge(df1, df2, left_index=True, right_index=True, how='inner', suffixes=('_df1', '_df2'))
    df = df.sort_index().dropna()

    if not df3.empty:
        df3 = pd.merge(df3.sort_index().dropna(), extra_data1, left_index=True, right_index=True, how="inner")
    if not df4.empty:
        df4 = pd.merge(df4.sort_index().dropna(), extra_data2, left_index=True, right_index=True, how="inner")

    return df, df3, df4

# Tick-level inputs: (source folder, file name, start, end, session type).
dataset1 = [
    ('shioaji/test', product, '2024-06-03 08:45:00', '2024-06-11 13:45:00', 'f_day')
    for product in ('FXFR1', 'ZFFR1')
]

# Shioaji 1-minute K-bar inputs: the tick file name with a 'k' suffix.
dataset2 = [('shioaji/test', f'{product}k') for product in ('FXFR1', 'ZFFR1')]

# Argument order: expensive contract, then cheap contract.
df, df3, df4 = process_multiple_datasets(dataset1, dataset2, 'FXFR1', 'ZFFR1')

print("Index:", df.head(3).index.tolist())
print("Columns:", df.columns.tolist())

[1;33mThank you for using SmartMoneyConcepts! ⭐ Please show your support by giving a star on the GitHub repository: [4;34mhttps://github.com/joshyattridge/smart-money-concepts[0m
Index: [Timestamp('2024-06-03 08:46:12'), Timestamp('2024-06-03 08:49:41'), Timestamp('2024-06-03 08:50:06')]
Columns: ['close_df1', 'volume_df1', 'bid_price_df1', 'ask_price_df1', 'flow_imbalance_df1', 'avg_spread_df1', 'avg_obi_df1', 'sum_price_volume_df1', 'sum_price_sq_volume_df1', 'pv_list_df1', 'monetary_delta_df1', 'volume_delta_df1', 'close_df2', 'volume_df2', 'bid_price_df2', 'ask_price_df2', 'flow_imbalance_df2', 'avg_spread_df2', 'avg_obi_df2', 'sum_price_volume_df2', 'sum_price_sq_volume_df2', 'pv_list_df2', 'monetary_delta_df2', 'volume_delta_df2']


In [2]:
def calculate_rsi(series, period=14):
    """Compute Wilder's Relative Strength Index.

    The first average gain/loss is the simple mean of the first *period*
    price changes (positions 1..period; position 0 has no previous price).
    The first RSI value is reported at position *period*. Later averages use
    Wilder smoothing ``avg = (avg * (period - 1) + current) / period``. When
    the average loss is 0 the RSI is 100.

    Args:
        series: price series.
        period: look-back length.

    Returns:
        pd.Series of RSI values on ``series.index[period:]``. It is empty when
        the series has no more than *period* points.
    """
    if len(series) <= period:
        return pd.Series([], index=series.index[:0], dtype=float)

    delta = series.diff()
    gain = delta.where(delta > 0, 0.0)
    loss = -delta.where(delta < 0, 0.0)

    def to_rsi(avg_gain, avg_loss):
        rs = avg_gain / avg_loss if avg_loss != 0 else float('inf')
        return 100 - (100 / (1 + rs))

    # Seed from the first `period` real changes. delta[0] is undefined (it was
    # coerced to 0 above), so including it would bias the seed toward 0.
    avg_gain = gain.iloc[1:period + 1].mean()
    avg_loss = loss.iloc[1:period + 1].mean()
    rsi_values = [np.nan] * period + [to_rsi(avg_gain, avg_loss)]

    for i in range(period + 1, len(series)):
        avg_gain = (avg_gain * (period - 1) + gain.iloc[i]) / period
        avg_loss = (avg_loss * (period - 1) + loss.iloc[i]) / period
        rsi_values.append(to_rsi(avg_gain, avg_loss))

    return pd.Series(rsi_values, index=series.index).dropna()

def calculate_bollinger_bands(series, window=20, num_std=2):
    """Compute Bollinger Bands from a rolling mean and standard deviation.

    Returns a dict:

    * ``'middle'``: rolling mean;
    * ``'upper'`` / ``'lower'``: middle ± ``num_std`` × rolling std;
    * ``'std'``: rolling standard deviation (pandas default, ddof=1).

    The first ``window - 1`` positions of every series are NaN.
    """
    roller = series.rolling(window=window)
    centre = roller.mean()
    spread = roller.std()
    offset = spread * num_std

    return {
        'middle': centre,
        'upper': centre + offset,
        'lower': centre - offset,
        'std': spread,
    }

def calculate_atr(df, window=14):
    """Compute the Average True Range as a simple moving average of the true range.

    The true range of a bar is the largest of ``high - low``,
    ``|high - prev_close|`` and ``|low - prev_close|``. NaNs are skipped, so
    the first bar's TR is just ``high - low``. The input is not modified.

    Args:
        df: DataFrame with ``high``, ``low`` and ``close`` columns.
        window: rolling window length.

    Returns:
        pd.Series named ``'tr'``, NaN for the first ``window - 1`` rows.

    Raises:
        ValueError: if any of the required columns is missing.
    """
    required = ['high', 'low', 'close']
    if any(col not in df.columns for col in required):
        raise ValueError("DataFrame must contain 'high', 'low', and 'close' columns")

    prev_close = df['close'].shift(1)
    candidates = pd.concat(
        [
            df['high'] - df['low'],
            (df['high'] - prev_close).abs(),
            (df['low'] - prev_close).abs(),
        ],
        axis=1,
    )
    true_range = candidates.max(axis=1).rename('tr')
    return true_range.rolling(window=window).mean()

def calculate_ema(df_close, span=5, type='ema'):
    """Return a moving average of *df_close*.

    Args:
        df_close: price series.
        span: EMA span when ``type='ema'``, or the window length when
            ``type='sma'``.
        type: ``'ema'`` for an exponential MA (``adjust=False``, the
            recursive form) or ``'sma'`` for a simple rolling mean. The name
            shadows the builtin but is kept for callers passing it by keyword.

    Raises:
        ValueError: if *type* is neither ``'ema'`` nor ``'sma'``. Silently
            returning ``None`` would only fail later, far from the cause.
    """
    if type == 'ema':
        return df_close.ewm(span=span, adjust=False).mean()
    if type == 'sma':
        return df_close.rolling(window=span).mean()
    raise ValueError(f"type must be 'ema' or 'sma', got {type!r}")
    
def calculate_bias_ratio(df_close, ma_period, type='sma'):
    """Compute the bias ratio of a price series against its moving average.

    ``bias = (price - MA) / MA * 100``, in percent.

    Args:
        df_close (pd.Series): closing prices.
        ma_period (int): moving-average length in bars (the span for EMA).
        type (str): ``'sma'`` for a simple rolling mean, or ``'ema'`` for an
            exponential MA with ``adjust=False``.

    Returns:
        pd.Series: bias in percent. Positions where it is NaN (e.g. before the
        SMA window is full) are filled with 0, not removed.

    Raises:
        ValueError: if *type* is neither ``'sma'`` nor ``'ema'``.
    """
    if type == 'sma':
        ma = df_close.rolling(window=ma_period).mean()
    elif type == 'ema':
        ma = df_close.ewm(span=ma_period, adjust=False).mean()
    else:
        raise ValueError(f"type must be 'sma' or 'ema', got {type!r}")

    bias_ratio = (df_close - ma) / ma * 100
    return bias_ratio.fillna(0)

def calculate_ov(df_volume, span, type='sma'):
    """Compute the volume oscillator: short moving average of volume minus the long one.

    The short average uses ``span // 2`` and the long one uses ``span``.

    Args:
        df_volume: volume series.
        span: long window length / EMA span.
        type: ``'sma'`` for rolling means, or ``'ema'`` for exponential means
            with ``adjust=False``.

    Returns:
        pd.Series of ``short_ma - long_ma``.

    Raises:
        ValueError: if *type* is neither ``'sma'`` nor ``'ema'``.
    """
    short_span = span // 2
    if type == 'sma':
        short_ma = df_volume.rolling(window=short_span).mean()
        long_ma = df_volume.rolling(window=span).mean()
    elif type == 'ema':
        short_ma = df_volume.ewm(span=short_span, adjust=False).mean()
        long_ma = df_volume.ewm(span=span, adjust=False).mean()
    else:
        raise ValueError(f"type must be 'sma' or 'ema', got {type!r}")

    return short_ma - long_ma

def calculate_obv(df):
    """Compute On-Balance Volume.

    OBV starts at 0 on the first row. Each later row adds that row's volume
    if the close rose, subtracts it if the close fell, and carries the
    previous value if the close was unchanged (or the comparison is NaN).

    Args:
        df: DataFrame with ``close`` and ``volume`` columns.

    Returns:
        pd.Series aligned to ``df.index``. It is empty for an empty frame.
    """
    if len(df) == 0:
        return pd.Series([], index=df.index, dtype=float)

    closes = df['close'].to_numpy()
    volumes = df['volume'].to_numpy()

    obv = [0]
    # Walk arrays instead of per-row .iloc lookups; pairs are (previous close, close, volume).
    for prev_close, close, volume in zip(closes[:-1], closes[1:], volumes[1:]):
        if close > prev_close:
            obv.append(obv[-1] + volume)
        elif close < prev_close:
            obv.append(obv[-1] - volume)
        else:
            obv.append(obv[-1])
    return pd.Series(obv, index=df.index)

def align_time_series(df1, df2):
    """Restrict two time-indexed DataFrames to the timestamps they share.

    A non-``DatetimeIndex`` index is converted with ``pd.to_datetime`` *in
    place*, so the caller's frame is modified.

    Args:
        df1 (pd.DataFrame): first time series.
        df2 (pd.DataFrame): second time series.

    Returns:
        tuple: ``(df1 rows on shared timestamps, df2 rows on shared timestamps)``.

    Raises:
        ValueError: if the two indexes have no timestamp in common.
    """
    for frame in (df1, df2):
        if not isinstance(frame.index, pd.DatetimeIndex):
            frame.index = pd.to_datetime(frame.index)

    shared = df1.index.intersection(df2.index)
    if shared.empty:
        raise ValueError("No common timestamps found between df1 and df2")

    return df1.loc[shared], df2.loc[shared]

def calculate_smc_indicators(df, swing_length, **params):
    """Compute Smart-Money-Concepts indicators with the ``smartmoneyconcepts`` library.

    Args:
        df: OHLC(V) DataFrame passed to the ``smc`` functions. Its index is
            converted *in place* to a ``DatetimeIndex``, so the caller's frame
            is modified.
        swing_length: window passed to ``smc.swing_highs_lows``.
        **params: optional settings. Only ``close_break`` (default ``True``)
            is read; it is forwarded to ``smc.bos_choch``.

    Returns:
        dict with keys ``'FVG'``, ``'OB'``, ``'Liquidity'``, ``'BOS'``,
        ``'CHOCH'``, ``'Swing'`` and ``'PHL'`` (``smc.check_high_low``). Each
        value is the corresponding ``smc`` result relabelled with ``df.index``.
        Rows containing any NaN are dropped. ``'BOS'`` and ``'CHOCH'`` keep only
        their own signal column plus ``Level`` and ``BrokenIndex``.
    """
    df.index = pd.to_datetime(df.index)

    def _on_df_index(frame):
        # Relabel an smc result with df's timestamps (assumes it is row-aligned with df).
        frame.index = df.index
        return frame

    swing = _on_df_index(smc.swing_highs_lows(df, swing_length=swing_length))

    fvg = smc.fvg(df, True)
    # Positional mitigation index -> timestamp; out-of-range/NaN positions become None
    # (and are then removed by dropna below).
    # NOTE(review): if smc encodes "not mitigated" as 0, this maps it to df.index[0] -- confirm.
    fvg['MitigatedIndex'] = fvg['MitigatedIndex'].apply(
        lambda pos: df.index[int(pos)] if 0 <= pos < len(df) else None
    )
    fvg = _on_df_index(fvg)

    order_blocks = _on_df_index(smc.ob(df, swing))
    liquidity = _on_df_index(smc.liquidity(df, swing))
    structure = _on_df_index(smc.bos_choch(df, swing, close_break=params.get('close_break', True)))
    prev_high_low = _on_df_index(smc.check_high_low(df))

    bos = structure[structure['BOS'].notna()][['BOS', 'Level', 'BrokenIndex']]
    choch = structure[structure['CHOCH'].notna()][['CHOCH', 'Level', 'BrokenIndex']]

    return {
        'FVG': fvg.dropna(),
        'OB': order_blocks.dropna(),
        'Liquidity': liquidity.dropna(),
        'BOS': bos.dropna(),
        'CHOCH': choch.dropna(),
        'Swing': swing.dropna(),
        'PHL': prev_high_low.dropna(),
    }

def detect_ma_compression(df, short_period=5, long_period=20, lookback=10):
    """Count recent EMA crossovers as a moving-average compression signal.

    A golden cross at bar t means the short EMA was <= the long EMA at t-1 and
    is > it at t. A death cross is the mirror image. Both EMAs use
    ``adjust=False``.

    Returns:
        DataFrame (same index as *df*) with integer columns
        ``golden_cross_count_{lookback}`` and ``death_cross_count_{lookback}``.
        Each holds the number of crosses in the trailing *lookback* bars; the
        first ``lookback - 1`` rows, where the window is incomplete, are 0.
    """
    close = df['close']
    fast = close.ewm(span=short_period, adjust=False).mean()
    slow = close.ewm(span=long_period, adjust=False).mean()
    fast_prev, slow_prev = fast.shift(1), slow.shift(1)

    crossed_up = (fast_prev <= slow_prev) & (fast > slow)
    crossed_down = (fast_prev >= slow_prev) & (fast < slow)

    golden_col = f'golden_cross_count_{lookback}'
    death_col = f'death_cross_count_{lookback}'

    # Rolling sums of booleans give the number of crosses in each window.
    counts = crossed_up.rolling(window=lookback).sum().to_frame(golden_col)
    counts[death_col] = crossed_down.rolling(window=lookback).sum()

    return counts.fillna(0).astype(int)

def _add_pairwise_diffs(df, periods, value_col, diff_col):
    """Add ``df[value_col(a)] - df[value_col(b)]`` as ``diff_col(a, b)`` for every pair a listed before b.

    ``value_col`` / ``diff_col`` are ``str.format`` templates with one and two ``{}``
    slots respectively (e.g. ``'rsi_{}'`` and ``'rsi_{}_minus_{}'``). ``df`` is
    modified in place; a list with fewer than two periods adds nothing.
    """
    from itertools import combinations

    # combinations() yields pairs in the same order as nested "i, then i+1.." loops.
    for short_p, long_p in combinations(periods, 2):
        df[diff_col.format(short_p, long_p)] = df[value_col.format(short_p)] - df[value_col.format(long_p)]


def add_indicators(
    df_input,
    ma_period=None,
    rsi_period=None,
    bias_period=None,
    bb_period=None,
    atr_period=None,
    ov_period=None,
    ma_compress_configs=None,
    swings=None
    ) -> pd.DataFrame:
    """Return a copy of ``df_input`` with technical, SMC and EMA-cross feature columns added.

    Each ``*_period`` argument is a list of window lengths. When a list has more
    than one entry, a ``..._<a>_minus_<b>`` difference column is also added for
    every pair where ``a`` is listed before ``b``. ``None`` selects the defaults
    below (these used to be mutable default arguments).

    Args:
        df_input: OHLCV DataFrame; ``close`` and ``volume`` are read here, other
            columns by ``calculate_atr`` / ``calculate_obv`` /
            ``calculate_smc_indicators``. It is not modified.
        ma_period: SMA/EMA spans, default ``[5, 10, 20]``.
        rsi_period: RSI periods, default ``[6, 12, 24]``.
        bias_period: bias-ratio periods, default ``[5, 10]``.
        bb_period: Bollinger windows (``num_std=2``), default ``[20]``.
        atr_period: ATR windows, default ``[7, 14]``.
        ov_period: ``calculate_ov`` spans, default ``[10, 20]``.
        ma_compress_configs: dicts with ``short`` / ``long`` / ``lookback`` keys,
            each passed to ``detect_ma_compression``; default
            ``[{'short': 5, 'long': 20, 'lookback': 10}]``.
        swings: swing lengths for ``calculate_smc_indicators``, default ``[50]``.

    Returns:
        A new DataFrame. Missing values in the SMC columns are filled with
        ``'N/A'``; every other remaining NaN is replaced by the string ``'NA'``.
    """
    from itertools import combinations

    ma_period = [5, 10, 20] if ma_period is None else ma_period
    rsi_period = [6, 12, 24] if rsi_period is None else rsi_period
    bias_period = [5, 10] if bias_period is None else bias_period
    bb_period = [20] if bb_period is None else bb_period
    atr_period = [7, 14] if atr_period is None else atr_period
    ov_period = [10, 20] if ov_period is None else ov_period
    if ma_compress_configs is None:
        ma_compress_configs = [{'short': 5, 'long': 20, 'lookback': 10}]
    swings = [50] if swings is None else swings

    df = df_input.copy()

    # Volume oscillator: SMA variant, its differences, then EMA variant and its differences.
    for period in ov_period:
        df[f'ov_{period}_sma'] = calculate_ov(df['volume'], span=period, type='sma')
    _add_pairwise_diffs(df, ov_period, 'ov_{}_sma', 'ov_sma_{}_minus_{}')
    for period in ov_period:
        df[f'ov_{period}_ema'] = calculate_ov(df['volume'], span=period, type='ema')
    _add_pairwise_diffs(df, ov_period, 'ov_{}_ema', 'ov_ema_{}_minus_{}')

    # Bias ratio (distance of close from its moving average), SMA then EMA.
    for period in bias_period:
        df[f'bias_{period}_sma'] = calculate_bias_ratio(df['close'], ma_period=period, type='sma')
    _add_pairwise_diffs(df, bias_period, 'bias_{}_sma', 'bias_sma_{}_minus_{}')
    for period in bias_period:
        df[f'bias_{period}_ema'] = calculate_bias_ratio(df['close'], ma_period=period, type='ema')
    _add_pairwise_diffs(df, bias_period, 'bias_{}_ema', 'bias_ema_{}_minus_{}')

    # Moving averages; calculate_ema produces both kinds via its `type` argument.
    for period in ma_period:
        df[f'sma_{period}'] = calculate_ema(df['close'], span=period, type='sma')
    _add_pairwise_diffs(df, ma_period, 'sma_{}', 'sma_{}_minus_{}')
    for period in ma_period:
        df[f'ema_{period}'] = calculate_ema(df['close'], span=period, type='ema')
    _add_pairwise_diffs(df, ma_period, 'ema_{}', 'ema_{}_minus_{}')

    # RSI
    for period in rsi_period:
        df[f'rsi_{period}'] = calculate_rsi(df['close'], period=period)
    _add_pairwise_diffs(df, rsi_period, 'rsi_{}', 'rsi_{}_minus_{}')

    # Bollinger bands: four components per window.
    bb_parts = ('upper', 'middle', 'lower', 'std')
    for window in bb_period:
        bb = calculate_bollinger_bands(df['close'], window=window, num_std=2)
        for part in bb_parts:
            df[f'bb_{part}_{window}'] = bb[part]
    # Differences are emitted pair by pair (all four parts per pair) to keep the column order.
    for short_window, long_window in combinations(bb_period, 2):
        for part in bb_parts:
            df[f'bb_{part}_{short_window}_minus_{long_window}'] = (
                df[f'bb_{part}_{short_window}'] - df[f'bb_{part}_{long_window}']
            )

    # ATR
    for window in atr_period:
        df[f'atr_{window}'] = calculate_atr(df, window=window)
    _add_pairwise_diffs(df, atr_period, 'atr_{}', 'atr_{}_minus_{}')

    # On-balance volume
    df['obv'] = calculate_obv(df)

    # Smart-money-concept indicators, one prefixed column group per swing length.
    for swing in swings:
        # Named smc_result so it does not shadow the imported `smc` module.
        smc_result = calculate_smc_indicators(df, swing)
        prefix = f"swing_{swing}_"
        new_columns = []

        for indicator_name, indicator_data in smc_result.items():
            if not isinstance(indicator_data, pd.DataFrame):
                indicator_data = pd.DataFrame(indicator_data, index=df.index)

            # Prefix every column to avoid clashes between indicators and swing lengths.
            indicator_data.columns = [f"{prefix}{indicator_name}_{col}" for col in indicator_data.columns]
            new_columns.extend(indicator_data.columns)

            # Index-aligned join onto the feature frame.
            df = pd.concat([df, indicator_data], axis=1)

        df[new_columns] = df[new_columns].fillna('N/A')

    # EMA crossover counts for each requested configuration.
    for params in ma_compress_configs:
        cross_stats = detect_ma_compression(
            df,
            short_period=params['short'],
            long_period=params['long'],
            lookback=params['lookback'],
        )
        df = pd.concat([df, cross_stats], axis=1)

    df = df.replace(np.nan, 'NA')

    return df

### 資料整理與轉換

##### code1 vs code2

In [3]:
# Build K_time-bar OHLCV frames for both products (df3 / df4 are defined outside this view),
# align the pair with align_time_series, then attach the add_indicators feature columns.
K_time = 5  # bar size handed to convert_ohlcv (presumably minutes — TODO confirm)
df1_k = convert_ohlcv(df3, K_time)
df2_k = convert_ohlcv(df4, K_time)
df1_k, df2_k = align_time_series(df1_k, df2_k) 
df1_k, df2_k = add_indicators(df1_k), add_indicators(df2_k)

#### 繪製走勢圖

In [None]:
# Close-price series of the two aligned products.
seriesA = df1_k['close']
seriesB = df2_k['close']

# Aliases passed to the plotting calls in this cell.
param1 = seriesA
param2 = seriesB

def plot_interactive_two_trend(df_series_1, df_series_2):
    """Overlay two series on one Plotly chart after min-max scaling each to [0, 1].

    Each series is scaled on its own range, so the chart compares shape rather
    than price level. The x axis is categorical, so gaps between timestamps are
    not drawn.

    Args:
        df_series_1: pandas Series drawn in blue as "A (Normalized)".
        df_series_2: pandas Series drawn in red as "B (Normalized)".
    """
    def scaled(series):
        # A fresh fit per series maps it onto its own [min, max].
        return MinMaxScaler().fit_transform(series.values.reshape(-1, 1)).flatten()

    fig = go.Figure()
    lines = (
        (df_series_1, 'A (Normalized)', 'blue'),
        (df_series_2, 'B (Normalized)', 'red'),
    )
    for series, label, colour in lines:
        fig.add_trace(go.Scatter(
            x=series.index,
            y=scaled(series),
            mode='lines',
            name=label,
            line=dict(color=colour),
        ))

    fig.update_layout(
        title='A vs B (Normalized)',
        xaxis_title='時間',
        yaxis_title='正規化數值',
        legend_title='商品',
        xaxis=dict(type="category"),  # treat the index as categories, not a continuous time axis
    )

    fig.show()

# Min-max-scaled overlay of the two close series.
plot_interactive_two_trend(param1, param2)

def plot_interactive_spread(df_series_1, df_series_2):
    """Plot the spread between two z-scored series together with its mean.

    Each series is standardised independently (``StandardScaler``), then
    ``B_norm - A_norm`` is drawn in blue with a dashed red line at its mean.
    The x axis uses ``df_series_1``'s index as categories.

    Args:
        df_series_1: pandas Series A.
        df_series_2: pandas Series B; assumed to have the same length as A.
    """
    z_a = StandardScaler().fit_transform(df_series_1.values.reshape(-1, 1))
    z_b = StandardScaler().fit_transform(df_series_2.values.reshape(-1, 1))
    spread = z_b - z_a
    x_axis = df_series_1.index

    fig = go.Figure([
        go.Scatter(
            x=x_axis,
            y=spread.flatten(),
            mode='lines',
            name='Spread (B_norm - A_norm)',
            line=dict(color='blue'),
        ),
        go.Scatter(
            x=x_axis,
            y=[spread.mean()] * len(df_series_1),
            mode='lines',
            name='Mean',
            line=dict(color='red', dash='dash'),
        ),
    ])

    fig.update_layout(
        title='Normalized Spread (B - A)',
        xaxis_title='時間',
        yaxis_title='價差',
        legend_title='系列',
        xaxis=dict(type="category"),
    )

    fig.show()

# Spread of the standardised close series (B - A) with its mean line.
plot_interactive_spread(param1, param2)

def plot_interactive_candlesticks(df1, df2, name1='商品A', name2='商品B'):
    """Show candlesticks and volume for two products in a 4-row Plotly figure.

    Rows 1-2 hold product 1 (candles, blue volume bars); rows 3-4 hold product 2
    (candles, red volume bars). All rows share the x axis, which is categorical
    so that gaps between timestamps are not drawn.

    Args:
        df1, df2: DataFrames with open/high/low/close/volume columns.
        name1, name2: labels used in subplot titles and axis titles.
    """
    fig = make_subplots(
        rows=4, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
        subplot_titles=(f"{name1} K線", f"{name1} 成交量", f"{name2} K線", f"{name2} 成交量")
    )

    # Each product occupies two stacked rows: candles on top, volume underneath.
    panels = ((df1, name1, 'blue', 1), (df2, name2, 'red', 3))
    for frame, label, bar_colour, price_row in panels:
        fig.add_trace(go.Candlestick(
            x=frame.index,
            open=frame['open'],
            high=frame['high'],
            low=frame['low'],
            close=frame['close'],
            name=f"{label} K線",
        ), row=price_row, col=1)
        fig.add_trace(go.Bar(
            x=frame.index,
            y=frame['volume'],
            name=f"{label} 成交量",
            marker_color=bar_colour,
            opacity=0.5,
        ), row=price_row + 1, col=1)

    fig.update_layout(
        title_text='商品 K 線 + 成交量 對比圖',
        height=1000,
        showlegend=False,
        # Category x axes avoid padding the chart with non-trading time.
        **{axis: dict(type="category") for axis in ('xaxis', 'xaxis2', 'xaxis3', 'xaxis4')},
        yaxis_title=f"{name1} 價格",
        yaxis2_title=f"{name1} 成交量",
        yaxis3_title=f"{name2} 價格",
        yaxis4_title=f"{name2} 成交量",
    )

    # Candlestick traces add a range slider by default; hide it on both price rows.
    for price_row in (1, 3):
        fig.update_xaxes(rangeslider_visible=False, row=price_row, col=1)

    fig.show()

# Side-by-side candlestick + volume view of both products.
plot_interactive_candlesticks(df1_k, df2_k, name1='第一個商品', name2='第二個商品')

def plot_jointplot_from_series(series1, series2, kind="scatter", title="sns_plot", show_stats=True, **kwargs):
    """Draw a seaborn joint plot of two series and print a correlation/regression report.

    All statistics (Pearson, Spearman, covariance, OLS line, p-value) are computed
    on the same sample: the index-aligned pairs in which both values are numeric.
    Previously ``Series.corr`` aligned on the index and dropped NaN while
    ``np.cov`` / ``linregress`` used the raw positional data, so the numbers could
    describe different samples.

    Args:
        series1: pandas Series for the x axis.
        series2: pandas Series for the y axis.
        kind: seaborn ``jointplot`` kind ('scatter', 'reg', 'hex', 'kde', 'hist').
        title: figure suptitle.
        show_stats: when True, overlay a statistics box on the joint axes.
        **kwargs: forwarded to ``sns.jointplot``; ``color`` defaults to "blue".

    Returns:
        The seaborn ``JointGrid`` returned by ``sns.jointplot``.
    """
    import seaborn as sns
    from scipy import stats

    kwargs.setdefault("color", "blue")

    jp = sns.jointplot(x=series1, y=series2, kind=kind, **kwargs)
    jp.fig.suptitle(title)

    # Align on the index, turn non-numeric placeholders (this file writes 'NA' for
    # missing values) into NaN, and keep only complete pairs. keys= avoids a
    # column-name clash when both series share a name such as 'close'.
    paired = pd.concat(
        [pd.to_numeric(series1, errors='coerce'), pd.to_numeric(series2, errors='coerce')],
        axis=1,
        keys=['x', 'y'],
    ).dropna()
    x, y = paired['x'], paired['y']

    pearson_r = x.corr(y, method='pearson')
    spearman_r = x.corr(y, method='spearman')
    covariance = np.cov(x, y)[0, 1]
    slope, intercept, r_value, p_value, _ = stats.linregress(x, y)

    print("\n" + "="*50)
    print("【變量分析報告】")
    print(f"變量1: {series1.name if series1.name else 'Series1'}")
    print(f"變量2: {series2.name if series2.name else 'Series2'}")
    print("-"*50)

    def judge_correlation(r, method_name):
        """Print ``r`` with its sign and a coarse strength bucket based on |r|."""
        abs_r = abs(r)
        if abs_r >= 0.8:
            strength = "極強相關"
        elif abs_r >= 0.6:
            strength = "強相關"
        elif abs_r >= 0.4:
            strength = "中等相關"
        elif abs_r >= 0.2:
            strength = "弱相關"
        else:
            strength = "幾乎無相關"

        direction = "正" if r > 0 else "負" if r < 0 else "無"
        print(f"{method_name:>10}: {r:.4f} ({direction}相關, {strength})")

    print("\n【相關性分析】")
    judge_correlation(pearson_r, "Pearson")
    judge_correlation(spearman_r, "Spearman")
    print(f"{'協方差':>10}: {covariance:.4f}")

    print("\n【迴歸分析】")
    print(f"{'迴歸方程':>10}: y = {slope:.4f}x + {intercept:.4f}")
    print(f"{'R平方值':>10}: {r_value**2:.4f}")

    # Two-sided test of a non-zero slope at the 5% level.
    significance = "顯著" if p_value < 0.05 else "不顯著"
    print(f"{'p-value':>10}: {p_value:.4f} ({significance}, α=0.05)")

    print("="*50 + "\n")

    if show_stats:
        stats_text = (f"Pearson r = {pearson_r:.2f}\n"
                      f"Spearman ρ = {spearman_r:.2f}\n"
                      f"y = {slope:.2f}x + {intercept:.2f}\n"
                      f"p = {p_value:.3f}")

        jp.ax_joint.text(0.05, 0.95, stats_text,
                         transform=jp.ax_joint.transAxes,
                         verticalalignment='top',
                         bbox=dict(facecolor='white', alpha=0.8))

    plt.tight_layout()
    return jp

# plot_jointplot_from_series(series1=param1, series2=param2, kind="reg")

#### Neo4j

##### 創建

In [None]:
from neo4j import GraphDatabase

# Connection settings for the local Neo4j server used by create_database below.
# NOTE(review): credentials are hard-coded; consider reading them from the environment.
URI = "bolt://localhost:7687"
AUTH = ("neo4j", "test1234")

def create_database(uri, user, password, db_name):
    """Create the Neo4j database ``db_name`` unless it already exists.

    Administration commands (``SHOW DATABASES`` / ``CREATE DATABASE``) run in a
    session bound to the ``system`` database. This is a best-effort notebook
    helper: any error is printed instead of raised.

    Args:
        uri: Bolt/Neo4j URI of the server.
        user: user name.
        password: password.
        db_name: database to create. It is interpolated into the Cypher text
            (database names cannot be query parameters), so it must be a
            trusted, valid Neo4j database name.
    """
    # `with` closes the driver once we are done, whatever happens.
    with GraphDatabase.driver(uri, auth=(user, password)) as driver:
        try:
            driver.verify_connectivity()
            print("成功連接到 Neo4j。")

            # The session is also managed by `with`. The old `finally: session.close()`
            # raised UnboundLocalError when verify_connectivity() failed before the
            # session existed, hiding the real connection error.
            with driver.session(database="system") as session:
                check_query = f"SHOW DATABASES WHERE name = '{db_name}'"
                result = session.run(check_query)
                if result.single():
                    print(f"資料庫 '{db_name}' 已經存在。")
                    return

                # IF NOT EXISTS turns a concurrent creation into a no-op instead of an error.
                create_query = f"CREATE DATABASE {db_name} IF NOT EXISTS"

                print(f"正在執行查詢: {create_query}")
                session.run(create_query)
                print(f"成功發送創建資料庫 '{db_name}' 的請求。")
                print("請注意：資料庫的啟動可能需要一些時間。")

        except Exception as e:
            print(f"發生錯誤: {e}")

# --- Name of the database to create ---
new_database_name = "mydatabase"
create_database(URI, AUTH[0], AUTH[1], new_database_name)

##### 插入

In [16]:
from neo4j import GraphDatabase
from langchain_ollama.embeddings import OllamaEmbeddings
from neo4j.time import DateTime
import json

# Neo4j connection settings.
# NOTE(review): credentials are hard-coded in the notebook.
uri = "bolt://localhost:7687"
user = "neo4j"
password = "test1234"
EMBEDDING_NAME = "mxbai-embed-large:latest"
# Ollama endpoint for embeddings; note this is port 11435, not Ollama's usual 11434.
OLLAMA_BASE_URL = "http://localhost:11435"

# Module-level driver used by insert_data_to_neo4j.
driver = GraphDatabase.driver(uri, auth=(user, password))

def prepare_data_for_batch(data, product_id, embed):
    """Turn a bar/factor DataFrame into the ``rows`` parameter of the UNWIND insert query.

    Each returned record is ``{'product': {...}, 'kline': {...}, 'factors': [...]}``.
    The frame's index is exposed as an ``index`` column (it must support
    ``to_pydatetime()``, presumably a bar timestamp) and, together with ``ts``, is
    converted to ``neo4j.time.DateTime``. Every column other than
    index/ts/open/high/low/close/volume is treated as a factor.

    Embeddings are requested in three batches: one per factor name, one per bar
    (from its space-joined index/ts/OHLCV text) and one for ``"Product <id>"``.
    Factor values that are ``None`` or an empty list/dict are skipped; non-empty
    lists (of dicts) are JSON-encoded with their keys stringified.

    Args:
        data: DataFrame with ts/open/high/low/close/volume plus factor columns.
        product_id: product identifier copied into every record.
        embed: object exposing ``embed_documents(list[str])`` and ``embed_query(str)``.

    Returns:
        A list with one record dict per row of ``data``.
    """
    def to_neo4j_datetime(stamp):
        # pandas Timestamp -> datetime -> neo4j DateTime, which the driver sends natively.
        return DateTime.from_native(stamp.to_pydatetime())

    frame = data.copy().reset_index(names='index')
    for time_col in ('index', 'ts'):
        frame[time_col] = frame[time_col].apply(to_neo4j_datetime)

    base_columns = {'index', 'ts', 'open', 'high', 'low', 'close', 'volume'}
    factor_names = [col for col in frame.columns if col not in base_columns]

    # One embedding per factor *name*, reused for every row.
    print("開始批次處理 Factor embeddings...")
    factor_vectors = dict(zip(factor_names, embed.embed_documents(factor_names)))
    print("✅ Factor embeddings 快取完成！")

    print("開始批次處理 KLine embeddings...")
    rows = list(frame.itertuples(index=False))
    kline_vectors = embed.embed_documents([
        f"{r.index} {r.ts} {r.open} {r.high} {r.low} {r.close} {r.volume}"
        for r in rows
    ])
    print("✅ KLine embeddings 批次處理完成！")

    product_vector = embed.embed_query(f"Product {product_id}")

    batch = []
    for pos, row in enumerate(rows):
        factor_entries = []
        for name in factor_names:
            # Factor names are column names, i.e. namedtuple attributes of the row.
            value = getattr(row, name)
            if value is None or (isinstance(value, (list, dict)) and not value):
                continue
            if isinstance(value, list):
                # Neo4j properties cannot hold lists of maps, so store them as JSON text.
                value = json.dumps([{str(k): v for k, v in item.items()} for item in value])
            factor_entries.append({
                'name': name,
                'embedding': factor_vectors[name],
                'index': row.index,
                'ts': row.ts,
                'description': f"Factor {name} with value {value}",
                'value': value,
            })

        batch.append({
            'product': {
                'id': product_id,
                'embedding': product_vector,
                'description': f"Product node for {product_id}",
            },
            'kline': {
                'index': row.index,
                'ts': row.ts,
                'open': row.open,
                'high': row.high,
                'low': row.low,
                'close': row.close,
                'volume': row.volume,
                'embedding': kline_vectors[pos],
                'description': f"KLine at {row.index} with open: {row.open}, high: {row.high}, low: {row.low}, close: {row.close}, volume: {row.volume}",
            },
            'factors': factor_entries,
        })

    return batch

def insert_data_to_neo4j(data, product_id, embed, database_name='neo4j'):
    """Replace all KLine/Factor data of ``product_id`` in Neo4j with the rows of ``data``.

    Steps:
      1. Delete the product's KLine nodes and the Factor nodes attached to them.
      2. Build the batch payload with ``prepare_data_for_batch`` (computes embeddings).
      3. Write everything with a single ``UNWIND`` query, which also links the
         product's ``TimePoint_<product_id>`` nodes with ``NEXT`` relationships
         in time order.

    Uses the module-level ``driver``. Requires the APOC plugin
    (``apoc.merge.node``, ``apoc.date.parse``).

    Args:
        data: DataFrame of bars plus factor columns (see ``prepare_data_for_batch``).
        product_id: product identifier; also used to build the TimePoint label.
        embed: embeddings client exposing ``embed_documents`` / ``embed_query``.
        database_name: target Neo4j database.
    """
    delete_query = """
        MATCH (p:Product {id: $product_id})-[:HAS_KLINE]->(k:KLine)
        OPTIONAL MATCH (k)-[:HAS_FACTOR]->(f)
        DETACH DELETE k, f
    """

    if data.empty:
        # This branch previously used `f` without binding it (a Cypher error) and
        # ignored `database_name`; it now runs the same delete as the main path.
        print(f"數據為空，僅執行刪除 Product '{product_id}' 的舊數據操作。")
        with driver.session(database=database_name) as session:
            session.run(delete_query, product_id=product_id)
        return

    # --- 1. Delete the product's previous data ---
    with driver.session(database=database_name) as session:
        session.run(delete_query, product_id=product_id)
        print(f"已安全刪除 Product '{product_id}' 的所有舊數據。")

    # --- 2. Build the batch payload ---
    records_list = prepare_data_for_batch(data, product_id, embed)

    # --- 3. One UNWIND query for the whole batch ---
    cypher_unwind = """
    UNWIND $rows AS row
    MERGE (p:Product {id: row.product.id})
    SET
      p.embedding = row.product.embedding,
      p.description = row.product.description
    WITH p, row
    // Product-specific time point (label built dynamically via APOC)
    CALL apoc.merge.node(
        [ "TimePoint_" + row.product.id ],
        {ts: row.kline.index},
        {minute: apoc.date.parse(toString(row.kline.index), "ms", "yyyy-MM-dd\\'T\\'HH:mm:ss")}
    ) YIELD node AS t

    // Product -> KLine. Merging the whole pattern scopes the KLine to this product,
    // so an identical bar stored for another product is never reused or later deleted.
    MERGE (p)-[:HAS_KLINE {ts: row.kline.index}]->(k:KLine {
        index: row.kline.index,
        ts: row.kline.ts,
        open: row.kline.open,
        high: row.kline.high,
        low: row.kline.low,
        close: row.kline.close,
        volume: row.kline.volume,
        embedding: row.kline.embedding,
        description: row.kline.description
    })

    // KLine -> TimePoint
    MERGE (k)-[:AT_TIME]->(t)

    // Factors, likewise scoped to their KLine. A bare MERGE on the Factor node shared
    // one node between products whenever name, value, index and ts coincided, so
    // re-inserting one product deleted the factors of the other.
    FOREACH (factor IN row.factors |
        MERGE (k)-[:HAS_FACTOR {name: factor.name, ts: factor.index}]->(f:Factor {
            name: factor.name,
            value: factor.value,
            embedding: factor.embedding,
            description: factor.description,
            index: factor.index,
            ts: factor.ts
        })
        MERGE (f)-[:AT_TIME]->(t)
    )

    // Chain this product's time points in time order. WITH DISTINCT collapses the
    // UNWIND rows to one; without it every time point was collected once per row,
    // which produced self-loop NEXT relationships with duration 0.
    WITH DISTINCT $product_id AS pid
    MATCH (p:Product {id: pid})-[:HAS_KLINE]->(k:KLine)-[:AT_TIME]->(t)
    WITH DISTINCT t ORDER BY t.ts ASC
    WITH collect(t) AS time_list
    UNWIND range(0, size(time_list)-2) AS i
    WITH time_list[i] AS t1, time_list[i+1] AS t2
    MERGE (t1)-[:NEXT {
        duration: apoc.date.parse(toString(t2.ts), 'ms', "yyyy-MM-dd\\'T\\'HH:mm:ss")
                - apoc.date.parse(toString(t1.ts), 'ms', "yyyy-MM-dd\\'T\\'HH:mm:ss")
    }]->(t2)
    """

    with driver.session(database=database_name) as session:
        session.run(cypher_unwind, rows=records_list, product_id=product_id)

    print(f"已成功透過批次插入 Product '{product_id}' 的 {len(records_list)} 筆 KLine 數據及其因子。")

database_name = 'neo4j'
# One embeddings client serves both products (it used to be rebuilt identically for the second call).
embed = OllamaEmbeddings(model=EMBEDDING_NAME, base_url=OLLAMA_BASE_URL, keep_alive=0)
try:
    insert_data_to_neo4j(df1_k, 'FXFR', embed, database_name=database_name)
    insert_data_to_neo4j(df2_k, 'ZFFR', embed, database_name=database_name)
finally:
    # Release the module-level driver even when an insert fails.
    driver.close()

已安全刪除 Product 'FXFR' 的所有舊數據。
開始批次處理 Factor embeddings...
✅ Factor embeddings 快取完成！
開始批次處理 KLine embeddings...
✅ KLine embeddings 批次處理完成！
已成功透過批次插入 Product 'FXFR' 的 346 筆 KLine 數據及其因子。
已安全刪除 Product 'ZFFR' 的所有舊數據。
開始批次處理 Factor embeddings...
✅ Factor embeddings 快取完成！
開始批次處理 KLine embeddings...
✅ KLine embeddings 批次處理完成！
已成功透過批次插入 Product 'ZFFR' 的 346 筆 KLine 數據及其因子。


##### 刪除

In [1]:
from neo4j import GraphDatabase

uri = "bolt://localhost:7687"
user = "neo4j"
password = "test1234"

database_name = 'neo4j'


def clear_neo4j(tx):
    """Delete every node of the target database together with all its relationships."""
    # DETACH DELETE removes a node's relationships first, then the node itself.
    tx.run("MATCH (n) DETACH DELETE n")


# The driver context manager closes the driver even if the delete fails
# (previously a failure skipped driver.close()).
with GraphDatabase.driver(uri, auth=(user, password)) as driver:
    with driver.session(database=database_name) as session:
        session.execute_write(clear_neo4j)
        print("已經清空 Neo4j 資料庫！")


已經清空 Neo4j 資料庫！


##### 查看

In [2]:
from neo4j import GraphDatabase

# Connection settings and driver used by the inspection helper in this cell.
uri = "bolt://localhost:7687"
user = "neo4j"
password = "test1234"
driver = GraphDatabase.driver(uri, auth=(user, password))

# --- Inspection helper ---
def check_data_in_neo4j(driver, sample_size=3, database='neo4j'):
    """Print node counts per label group and up to ``sample_size`` sample nodes per label.

    For ``KLine`` samples, the name/value of every Factor reached via
    ``HAS_FACTOR`` is printed too. Only the first label of each label group is
    used for sampling; nodes without labels are reported as "Unknown".

    Args:
        driver: an open ``neo4j`` driver.
        sample_size: number of sample nodes printed per label.
        database: name of the database to inspect.
    """

    def get_node_counts(tx):
        # Grouped by the full label list, so multi-label nodes form their own group.
        result = tx.run("MATCH (n) RETURN labels(n) AS Label, count(n) AS Count")
        return [record.data() for record in result]

    def get_sample_kline_with_factors(tx, limit):
        # LIMIT is applied before OPTIONAL MATCH so only the sampled KLines are
        # aggregated; previously every KLine's factors were collected first and
        # all but `limit` rows were then thrown away.
        query = """
        MATCH (k:KLine)
        WITH k LIMIT $limit
        OPTIONAL MATCH (k)-[:HAS_FACTOR]->(f:Factor)
        RETURN k, collect(
            CASE WHEN f IS NOT NULL
                 THEN {name: f.name, value: f.value}
                 ELSE NULL
            END
        ) AS factors
        """
        result = tx.run(query, limit=limit)
        # Defensive: drop any NULL entries from the factor lists.
        return [(record["k"], [f for f in record["factors"] if f is not None]) for record in result]

    def get_sample_nodes(tx, label, limit):
        if label == "KLine":
            return get_sample_kline_with_factors(tx, limit)
        # Labels cannot be query parameters; backtick-quote so unusual names stay valid Cypher.
        quoted_label = label.replace("`", "``")
        result = tx.run(f"MATCH (n:`{quoted_label}`) RETURN n LIMIT $limit", limit=limit)
        return [(record["n"], []) for record in result]

    with driver.session(database=database) as session:
        counts = session.execute_read(get_node_counts)
        print("--- 資料庫節點統計 ---")
        if not counts:
            print("資料庫是空的！")
        else:
            for record in counts:
                label = record["Label"][0] if record["Label"] else "Unknown"
                count = record["Count"]
                print(f"\n\n標籤: {label}, 數量: {count}")
                samples = session.execute_read(get_sample_nodes, label, sample_size)
                for i, (node, factors) in enumerate(samples, 1):
                    print(f"  範例 {i}:")
                    print(f"    節點屬性: {dict(node)}")
                    if factors:
                        print("    關聯的 Factor:")
                        for factor in factors:
                            print(f"      - 名稱: {factor['name']}, 值: {factor['value']}")
                    else:
                        print("    關聯的 Factor: 無")
        print("--------------------")

database_name = 'neo4j'
check_data_in_neo4j(driver, database=database_name)

# Release the driver opened at the top of this cell.
driver.close()

--- 資料庫節點統計 ---
資料庫是空的！
--------------------


##### 向量索引

In [30]:
from neo4j import GraphDatabase

# Connection settings for the index-management helpers in this cell.
uri = "bolt://localhost:7687"
user = "neo4j"
password = "test1234"

# Module-level driver shared by the helpers below (never closed in this cell).
driver = GraphDatabase.driver(uri, auth=(user, password))

def create_factor_index(dim=1024):
    """Create the cosine vector index ``factor_name_embeddings`` on ``Factor.embedding``.

    ``IF NOT EXISTS`` turns re-running the cell into a no-op instead of an error
    when the index is already present. The configuration of an existing index
    is not checked.

    Args:
        dim: vector dimension written into the index configuration.
    """
    with driver.session() as session:
        session.run(f"""
        CREATE VECTOR INDEX factor_name_embeddings IF NOT EXISTS
        FOR (f:Factor) ON (f.embedding)
        OPTIONS {{
          indexConfig: {{
            `vector.dimensions`: {dim},
            `vector.similarity_function`: 'cosine'
          }}
        }}
        """)
    print("✅ Factor 向量索引建立完成")

def create_product_index(dim=1024):
    """Create the cosine vector index ``product_embeddings`` on ``Product.embedding``.

    ``IF NOT EXISTS`` turns re-running the cell into a no-op instead of an error
    when the index is already present. The configuration of an existing index
    is not checked.

    Args:
        dim: vector dimension written into the index configuration.
    """
    with driver.session() as session:
        session.run(f"""
        CREATE VECTOR INDEX product_embeddings IF NOT EXISTS
        FOR (p:Product) ON (p.embedding)
        OPTIONS {{
          indexConfig: {{
            `vector.dimensions`: {dim},
            `vector.similarity_function`: 'cosine'
          }}
        }}
        """)
    print("✅ Product 向量索引建立完成")
    
def create_kline_index(dim=1024):
    """Create the cosine vector index ``kline_embeddings`` on ``KLine.embedding``.

    ``IF NOT EXISTS`` turns re-running the cell into a no-op instead of an error
    when the index is already present. The configuration of an existing index
    is not checked.

    Args:
        dim: vector dimension written into the index configuration.
    """
    with driver.session() as session:
        session.run(f"""
        CREATE VECTOR INDEX kline_embeddings IF NOT EXISTS
        FOR (k:KLine) ON (k.embedding)
        OPTIONS {{
            indexConfig: {{
                `vector.dimensions`: {dim},
                `vector.similarity_function`: 'cosine'
            }}
        }}
        """)
    print("✅ KLine 向量索引建立完成")

def list_vector_indexes():
    """Print and return the vector indexes defined on the Factor, Product and KLine labels.

    The KLine label was previously missing from the filter, so the index created
    by ``create_kline_index`` never showed up here.

    Returns:
        A list of dicts with keys ``name``, ``type``, ``entityType``,
        ``labelsOrTypes`` and ``properties``.
    """
    with driver.session() as session:
        result = session.run("""
        SHOW INDEXES YIELD name, type, entityType, labelsOrTypes, properties
        WHERE type = 'VECTOR' AND labelsOrTypes IN [['Factor'], ['Product'], ['KLine']]
        RETURN *
        """)
        indexes = [r.data() for r in result]

    print("📌 目前資料庫建立過的向量索引：")
    for idx in indexes:
        print(idx)

    return indexes

def drop_vector_index(index_name: str):
    """Drop the index called ``index_name`` and print a confirmation.

    The statement has no ``IF EXISTS``, so a missing index is reported as an
    error by Neo4j. ``index_name`` is interpolated into the Cypher text and must
    be trusted.
    """
    statement = f"DROP INDEX {index_name}"
    with driver.session() as session:
        session.run(statement)
    print(f"🗑️ 已刪除索引: {index_name}")

def check_indexes():
    """Print name, labels and properties of every VECTOR index, then a completion line."""
    with driver.session() as session:
        for index_row in session.run("SHOW INDEXES WHERE type = 'VECTOR'"):
            print(index_row["name"], index_row["labelsOrTypes"], index_row["properties"])
    print("✅ 索引檢查完成")

# --- Manual toggles: uncomment the index operations you need ---
# create_factor_index()
# create_product_index()
# create_kline_index()
# list_vector_indexes()
check_indexes()
# drop_vector_index("factor_name_embeddings")
# drop_vector_index("product_embeddings")
# drop_vector_index("kline_embeddings")

factor_name_embeddings ['Factor'] ['embedding']
kline_embeddings ['KLine'] ['embedding']
product_embeddings ['Product'] ['embedding']
✅ 索引檢查完成


#### LLM
```
用「精確過濾」回答「精確問題」 (時間、價格、數值) => Cypher精準查詢法
用「向量搜索」回答「模糊/概念/相似性問題」 (形態、趨勢、感覺) => 知識圖譜向量搜索法
```

In [4]:
import requests

# --- 輔助函式：手動卸載模型 ---
def unload_ollama_model(model_name: str, base_url: str = "http://localhost:11434"):
    """Ask an Ollama server to evict ``model_name`` from memory.

    Sends ``POST {base_url}/api/generate`` with an empty prompt and a top-level
    ``keep_alive`` of 0 (not inside ``options``). Request failures, including an
    HTTP error status, are printed instead of raised.

    Note: the default port 11434 differs from the 11435 used by the other cells
    of this notebook; pass ``base_url`` explicitly to target that server.
    """
    print(f"\n--- 正在請求從 GPU 卸載模型: {model_name} ---")
    request_body = {
        "model": model_name,
        "prompt": "",      # nothing to generate; the request only carries keep_alive
        "keep_alive": 0,   # 0 asks the server to unload immediately
    }
    try:
        reply = requests.post(f"{base_url}/api/generate", json=request_body)
        reply.raise_for_status()
    except requests.exceptions.RequestException as err:
        print(f"❌ 卸載模型時發生錯誤: {err}")
        print("請確認 Ollama 服務正在運行，且 base_url 設定正確（預設為 http://localhost:11434）。")
    else:
        print(f"✅ 成功發送卸載請求。")

def _test_ollama_connection(base_url: str):
    """Probe an Ollama server with ``GET /api/tags`` and print a diagnostic report.

    Meant to be called after an Ollama ``ResponseError``. It prints:
    - whether the server answered with HTTP 200, and the model names it lists;
    - otherwise the status code and the first 200 characters of the body;
    - the CUDA devices torch sees in *this* Python process, which are not
      necessarily the GPUs the Ollama server uses.

    Any exception raised while probing is caught and printed. Importing torch
    happens first and is not guarded.

    Args:
        base_url: server root such as ``"http://localhost:11434"``; a single
            trailing slash is removed.
    """
    import torch

    def report_gpus():
        if not torch.cuda.is_available():
            print("\n   🔹 CUDA GPU 狀態: 無可用 GPU")
            return
        print("\n   🔹 CUDA GPU 狀態:")
        print(f"      - GPU 數量: {torch.cuda.device_count()}")
        for gpu in range(torch.cuda.device_count()):
            print(f"      - GPU {gpu}: {torch.cuda.get_device_name(gpu)}, Memory Allocated: {torch.cuda.memory_allocated(gpu)/1e6:.1f} MB, Cached: {torch.cuda.memory_reserved(gpu)/1e6:.1f} MB")

    print("\n   [連線測試] 偵測到 ResponseError，正在執行獨立的連線測試...")
    root = base_url[:-1] if base_url.endswith('/') else base_url
    # /api/tags is a lightweight GET that only lists installed models.
    tags_url = f"{root}/api/tags"

    try:
        response = requests.get(tags_url, timeout=5)

        if response.status_code != 200:
            print(f"   ❌ [連線測試] 失敗！伺服器回應錯誤。")
            print(f"      URL: {tags_url}")
            print(f"      HTTP 狀態碼: {response.status_code}")
            print(f"      回應內容: {response.text[:200]}")
        else:
            print(f"   ✅ [連線測試] 成功！與 {root} 的連線正常。")
            print("      Ollama 服務有回應。錯誤可能與模型名稱或請求內容有關。")
            try:
                listed = [entry.get("name") for entry in response.json().get("models", [])]
                print(f"      偵測到可用模型: {listed}")
            except requests.exceptions.JSONDecodeError:
                print("      回應內容不是有效的 JSON 格式。")

        report_gpus()

    except requests.exceptions.ConnectionError:
        print(f"   ❌ [連線測試] 失敗！無法連接到 {root}。")
        print("      請檢查主機名稱和 Port 是否正確，以及防火牆設定。")

    except requests.exceptions.Timeout:
        print(f"   ❌ [連線測試] 失敗！請求超時。")
        print(f"      伺服器在 5 秒內沒有回應，可能已無回應或網路延遲過高。")

    except Exception as probe_error:
        print(f"   ❌ [連線測試] 發生未預期的錯誤。")
        print(f"      錯誤類型: {type(probe_error).__name__}")
        print(f"      詳細資訊: {probe_error}")


##### 知識圖譜向量搜索法

In [None]:
from typing import Tuple, Dict
from langchain_ollama import ChatOllama
from langchain_neo4j import Neo4jGraph
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Neo4jVector
from langchain_core.messages import HumanMessage, SystemMessage
from graph_retriever.strategies import Eager
from langchain_graph_retriever import GraphRetriever
from langchain_core.documents import Document
from typing import List, Dict, Tuple, Dict
import re, json

import os

# --- Model and database settings ---
# Connection settings can be overridden through environment variables of the
# same name so credentials need not be committed with the notebook; the
# literals are the previous hard-coded values and remain the defaults.
MODEL_NAME = "phi4:latest"
EMBEDDING_NAME = "mxbai-embed-large:latest"
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11435")
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "test1234")

# --- Model initialisation ---
embed = OllamaEmbeddings(model=EMBEDDING_NAME, base_url=OLLAMA_BASE_URL, keep_alive=0)
graph = Neo4jGraph(url=NEO4J_URI, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)

# --- Vector-index retrievers ---
# node label -> (name of the existing vector index,
#                node properties copied into each hit's Document metadata)
_VECTOR_INDEX_SPECS = {
    "Product": ("product_embeddings", ("id", "description")),
    "Factor": ("factor_embeddings", ("name", "value", "ts")),
    "KLine": ("kline_embeddings", ("index", "ts", "close", "volume")),
}


def _open_vector_store(label, index_name, metadata_fields):
    """Attach to the existing Neo4j vector index ``index_name`` over ``label`` nodes.

    Every hit returns the node's ``description`` as text, the similarity
    score, and a metadata map ``{field: node.field, ...}`` built from
    ``metadata_fields``.
    """
    metadata_map = ", ".join(f"{field}: node.{field}" for field in metadata_fields)
    # Keep the leading newline the original triple-quoted queries had —
    # presumably Neo4jVector appends this text right after its own search clause.
    retrieval_query = f"\nRETURN node.description AS text, score, {{{metadata_map}}} AS metadata\n"
    return Neo4jVector.from_existing_index(
        embedding=embed,
        url=NEO4J_URI,
        username=NEO4J_USERNAME,
        password=NEO4J_PASSWORD,
        index_name=index_name,
        node_label=label,
        embedding_node_property="embedding",
        text_node_property="description",
        retrieval_query=retrieval_query,
    )


# Built in the same order as before: Product, Factor, KLine.
vector_stores = {
    label: _open_vector_store(label, index_name, fields)
    for label, (index_name, fields) in _VECTOR_INDEX_SPECS.items()
}
# Individual names kept for code that references them directly.
product_vector_store = vector_stores["Product"]
factor_vector_store = vector_stores["Factor"]
kline_vector_store = vector_stores["KLine"]

def parse_query_to_json(question):
    """Parse a natural-language (Chinese) question into a structured query dict via the LLM.

    Args:
        question: The user's question text.

    Returns:
        A dict that always has the keys ``"variables"``, ``"factors"`` and
        ``"k_bar_variables"`` (lists), ``"time_range"`` (a dict with non-empty
        ``"start"``/``"end"`` strings, or ``None``) and ``"analysis_type"``
        (str). Keys the LLM omits or returns malformed are filled from the
        empty defaults. On any failure (LLM call, no JSON object in the reply,
        invalid JSON) a message is printed and the all-empty defaults are
        returned.
    """
    # Defined per call: the LLM is created with keep_alive=0.
    llm = ChatOllama(base_url=OLLAMA_BASE_URL, model=MODEL_NAME, temperature=0, keep_alive=0)

    system_prompt = """
        你是問題解析專家。請將以下中文問題解析為 JSON 結構。
        輸出時**僅輸出 JSON**，不要附加任何文字、解釋或標記。
        輸出格式：
        {
            "variables": [],  // 列表，問句中所有提到的商品，如 ["productA", "productB"]
            "time_range": {"start": "YYYY-MM-DD HH:MM:SS", "end": "YYYY-MM-DD HH:MM:SS"},  // 如果只有時間，假設今天是2025-09-28
            "factors": [],  // 列表，所有提到的因子，如 ["flow_imbalance", "bid_ask"]
            "k_bar_variables": [],  // 列表，K棒相關變量，支持多個，對應商品有幾組就有幾個，如 ["K棒價格 for productA", "K棒價格 for productB"]
            "analysis_type": ""  // 如 "連動走勢"、"反向走勢"、"滯後分析" 或 "value query"。判斷關係類型，如連動、反向、滯後等。如果不明確，填 "value query"。
        }

        要求：
        1. variables: 提取所有問句中的商品名稱（括號內英文優先）。
        2. time_range: 轉成完整格式。如果只有時間如9:00~9:40，假設日期為2025-09-28，並補齊秒為00。
        3. factors: 提取所有因子名稱，支持多個。
        4. k_bar_variables: 如果提到K棒或價格走勢，填入相關描述列表，對應每個商品（如有多個商品）。
        5. analysis_type: 判斷分析類型，如 "連動走勢"、"反向走勢"、"滯後分析" 等，支持連動、反向、滯後等關係。如果不明確，填 "value query"。

        範例輸入：商品A (productA), 商品B (productB) 之間, 在9:00~9:40分之間, 他們的factor中的flow_imbalance和我的商品的K棒價格會有連動的走勢？
        範例輸出：{"variables": ["productA", "productB"], "time_range": {"start": "2025-09-28 09:00:00", "end": "2025-09-28 09:40:00"}, "factors": ["flow_imbalance"], "k_bar_variables": ["K棒價格"], "analysis_type": "連動走勢"}
    """

    # System prompt + user question.
    messages = [SystemMessage(content=system_prompt), HumanMessage(content=f"問題：{question}")]

    # Start from the defaults so every expected key exists even if the LLM omits some.
    data = {"variables": [], "factors": [], "time_range": None, "k_bar_variables": [], "analysis_type": ""}
    try:
        raw_response = llm.invoke(messages)
        text = raw_response.content if hasattr(raw_response, "content") else str(raw_response)
        # Greedy first-"{" to last-"}" match (DOTALL spans newlines), so any
        # chatter or code fences around the JSON object are ignored.
        m = re.search(r"\{.*\}", text, re.DOTALL)
        if m:
            parsed = json.loads(m.group(0))
            if isinstance(parsed, dict):
                data.update(parsed)
    except Exception as e:
        # Boundary handler: LLM/transport/JSON failures degrade to the empty defaults.
        print(f"解析失敗：{e}，使用預設空值")

    # List-valued fields: a bare string becomes a one-element list, any other
    # non-list value becomes [] (these are passed to Cypher `IN $param`).
    for key in ("variables", "factors", "k_bar_variables"):
        value = data[key]
        if isinstance(value, str):
            data[key] = [value] if value else []
        elif not isinstance(value, list):
            data[key] = []

    # Downstream retrieval indexes time_range["start"] / ["end"]; keep the
    # range only when both bounds are present.
    time_range = data["time_range"]
    if not (isinstance(time_range, dict) and time_range.get("start") and time_range.get("end")):
        data["time_range"] = None

    if not isinstance(data["analysis_type"], str):
        data["analysis_type"] = ""

    return data

def retrieve_products(variables, vector_stores, limit=5):
    """Look up Product nodes.

    With explicit product ids, returns their id and description through an
    exact Cypher match. With none, falls back to a similarity search on the
    Product vector index using a fixed query text, returning ``limit`` hits.

    NOTE(review): ``neo4j_driver`` is not defined in this cell; presumably a
    ``neo4j.Driver`` created elsewhere in the notebook — confirm.
    """
    if not variables:
        # Nothing named explicitly: generic semantic search instead.
        return vector_stores["Product"].similarity_search("找相似商品", k=limit)

    cypher = """
        MATCH (p:Product)
        WHERE p.id IN $variables
        RETURN p.id AS id, p.description AS description
        """
    params = {"variables": variables}
    return neo4j_driver.execute_query(cypher, params, database_="neo4j")

def retrieve_factors(variables, factors, time_range, vector_stores, limit=10):
    """Fetch factor values for the given products and factor names inside a time window.

    When factor names are given AND ``time_range`` is a dict with non-empty
    ``"start"``/``"end"``, runs an exact Cypher query over
    Product-[:HAS_KLINE]->KLine-[:HAS_FACTOR]->Factor. Each row also carries
    ``product`` (the owning Product id), so values of different products can
    be told apart. Otherwise it falls back to a similarity search on the
    Factor vector index with a fixed query text (the same fallback rule as
    ``retrieve_klines``).

    Args:
        variables: Product ids matched against ``Product.id``.
        factors: Factor names matched against ``Factor.name``.
        time_range: dict with "start"/"end" timestamps ("YYYY-MM-DD HH:MM:SS"
            or ISO-8601), or None.
        vector_stores: mapping whose "Factor" entry provides ``similarity_search``.
        limit: number of hits for the vector-search fallback.

    Returns:
        The driver's ``execute_query`` result on the Cypher path, otherwise
        whatever ``similarity_search`` returns.

    NOTE(review): ``neo4j_driver`` is not defined in this cell; presumably a
    ``neo4j.Driver`` created elsewhere — confirm.
    """
    has_window = (
        isinstance(time_range, dict)
        and bool(time_range.get("start"))
        and bool(time_range.get("end"))
    )
    if factors and has_window:
        query = """
        MATCH (p:Product)-[:HAS_KLINE]->(k:KLine)-[:HAS_FACTOR]->(f:Factor)
        WHERE p.id IN $variables
          AND f.name IN $factors
          AND k.ts >= datetime($start) AND k.ts <= datetime($end)
        RETURN p.id AS product, f.name AS name, f.value AS value, f.ts AS ts
        ORDER BY f.ts ASC, product ASC
        """
        return neo4j_driver.execute_query(query, {
            "variables": variables,
            "factors": factors,
            # Cypher's datetime() parses ISO-8601, which separates date and
            # time with "T"; the question parser emits "YYYY-MM-DD HH:MM:SS".
            "start": str(time_range["start"]).replace(" ", "T", 1),
            "end": str(time_range["end"]).replace(" ", "T", 1),
        }, database_="neo4j")
    return vector_stores["Factor"].similarity_search("找相似因子", k=limit)

def retrieve_klines(variables, time_range, vector_stores, limit=50):
    """Fetch K-line (candlestick) rows for the given products inside a time window.

    With product ids AND a ``time_range`` dict holding non-empty
    ``"start"``/``"end"``, runs an exact Cypher query. Each row includes
    ``product`` (the owning Product id), so bars of different products in the
    same window can be told apart. Otherwise it falls back to a similarity
    search on the KLine vector index with a fixed query text.

    Args:
        variables: Product ids matched against ``Product.id``.
        time_range: dict with "start"/"end" timestamps ("YYYY-MM-DD HH:MM:SS"
            or ISO-8601), or None.
        vector_stores: mapping whose "KLine" entry provides ``similarity_search``.
        limit: number of hits for the vector-search fallback.

    Returns:
        The driver's ``execute_query`` result on the Cypher path, otherwise
        whatever ``similarity_search`` returns.

    NOTE(review): ``neo4j_driver`` is not defined in this cell; presumably a
    ``neo4j.Driver`` created elsewhere — confirm.
    """
    has_window = (
        isinstance(time_range, dict)
        and bool(time_range.get("start"))
        and bool(time_range.get("end"))
    )
    if variables and has_window:
        query = """
        MATCH (p:Product)-[:HAS_KLINE]->(k:KLine)
        WHERE p.id IN $variables
          AND k.ts >= datetime($start) AND k.ts <= datetime($end)
        RETURN p.id AS product, k.index AS index, k.ts AS ts, k.open AS open, k.close AS close, k.volume AS volume
        ORDER BY k.ts ASC, product ASC
        """
        return neo4j_driver.execute_query(query, {
            "variables": variables,
            # Cypher's datetime() parses ISO-8601, which separates date and
            # time with "T"; the question parser emits "YYYY-MM-DD HH:MM:SS".
            "start": str(time_range["start"]).replace(" ", "T", 1),
            "end": str(time_range["end"]).replace(" ", "T", 1),
        }, database_="neo4j")
    return vector_stores["KLine"].similarity_search("找相似K棒", k=limit)


def _rows_for_prompt(result):
    """Render one retrieval result as compact JSON text for the LLM prompt.

    Handles the shapes the retrieve_* helpers return:
    - a neo4j driver result exposing ``.records`` (e.g. EagerResult): only
      the record data is kept, not the summary object;
    - a list of langchain ``Document``s: each becomes its metadata plus
      ``"text"`` (the page content).
    Anything else is passed through unchanged. Values JSON cannot encode
    (e.g. neo4j temporal types) are rendered with ``str()``.
    """
    records = getattr(result, "records", None)
    if records is not None:
        rows = [record.data() for record in records]
    elif isinstance(result, list):
        rows = [
            {**item.metadata, "text": item.page_content} if isinstance(item, Document) else item
            for item in result
        ]
    else:
        rows = result
    return json.dumps(rows, ensure_ascii=False, default=str)


def graphrag_query(question: str, vector_stores: Dict[str, 'Neo4jVector'], limit: int = 10, database_name='neo4j') -> Tuple[str, Dict]:
    """Answer ``question`` by structured parsing, Neo4j retrieval and LLM analysis.

    Steps:
    1. ``parse_query_to_json`` turns the question into products, factors, a
       time range and an analysis type.
    2. ``retrieve_products`` / ``retrieve_factors`` / ``retrieve_klines``
       fetch the matching data, using vector search as the fallback.
    3. The LLM judges whether the requested relationship holds.

    Args:
        question: The user's question.
        vector_stores: mapping "Product"/"Factor"/"KLine" -> vector store,
            used by the retrieve_* fallbacks.
        limit: currently unused; kept for interface compatibility.
        database_name: currently unused (the retrieve_* helpers use
            "neo4j"); kept for interface compatibility.

    Returns:
        ``(answer, raw_results)``. ``raw_results`` is a dict with keys
        "products", "factors" and "klines" holding the unmodified retrieval
        results.
    """
    print("--- 步驟 1: 解析使用者問題 ---")
    query_struct = parse_query_to_json(question)
    variables = query_struct.get("variables", [])
    factors = query_struct.get("factors", [])
    time_range = query_struct.get("time_range", None)
    analysis_type = query_struct.get("analysis_type", "")
    print(f"✅ 問題解析完成: {query_struct}")

    print("--- 步驟 2: Neo4j 檢索 ---")
    # Dict literal evaluates in order: products, factors, klines.
    raw_results = {
        "products": retrieve_products(variables, vector_stores),
        "factors": retrieve_factors(variables, factors, time_range, vector_stores),
        "klines": retrieve_klines(variables, time_range, vector_stores),
    }

    print("--- 步驟 3: LLM 分析 ---")
    llm = ChatOllama(base_url=OLLAMA_BASE_URL, model=MODEL_NAME, temperature=0, keep_alive=0)
    # Feed plain row data rather than the driver/Document object reprs, which
    # carry summary objects and memory addresses that only waste context.
    analysis_prompt = f"""
    使用以下資料，判斷 {variables or '相關商品'} 與 {factors or '所有因子'} 
    在 {time_range or '全時段'} 的 {analysis_type} 是否成立。
    
    商品: {_rows_for_prompt(raw_results['products'])}
    KLine: {_rows_for_prompt(raw_results['klines'])}
    Factor: {_rows_for_prompt(raw_results['factors'])}
    """
    answer = llm.invoke([HumanMessage(content=analysis_prompt)]).content

    return answer, raw_results

# Example run of the GraphRAG pipeline. The `finally` clause unloads the
# Ollama model even if parsing, retrieval or generation raises.
database_name = 'neo4j'
question = "商品FXFR和商品ZFFR之間, 在9:00~9:40分之間, 他們的factor中的flow_imbalance和我的商品的K棒價格會有連動的走勢？"
try:
    answer, raw_results = graphrag_query(question, vector_stores, database_name=database_name)
    print(answer)
finally:
    # Unload the model once the final answer has been produced (or on error).
    unload_ollama_model(model_name=MODEL_NAME, base_url=OLLAMA_BASE_URL)


--- 步驟 1: 解析使用者問題 ---
{'variables': ['FXFR', 'ZFFR'], 'time_range': {'start': '2025-09-28 09:00:00', 'end': '2025-09-28 09:40:00'}, 'factors': ['flow_imbalance'], 'k_bar_variables': ['K棒價格 for FXFR', 'K棒價格 for ZFFR'], 'analysis_type': '連動走勢'}
✅ 問題解析完成: {'variables': ['FXFR', 'ZFFR'], 'time_range': {'start': '2025-09-28 09:00:00', 'end': '2025-09-28 09:40:00'}, 'factors': ['flow_imbalance'], 'k_bar_variables': ['K棒價格 for FXFR', 'K棒價格 for ZFFR'], 'analysis_type': '連動走勢'}
--- 步驟 2: 開始動態圖遍歷檢索 ---

--- 正在請求從 GPU 卸載模型: phi4:latest ---
✅ 成功發送卸載請求。


ValueError: Expected adapter or supported vector store, but got langchain_community.vectorstores.neo4j_vector.Neo4jVector

##### Cypher 精準查詢法

In [None]:
from langchain_community.llms.ollama import Ollama
from ollama import ResponseError
from langchain_community.graphs import Neo4jGraph

import os

# --- Model and database settings ---
# Connection settings can be overridden through environment variables of the
# same name so credentials need not be committed with the notebook; the
# literals are the previous hard-coded values and remain the defaults.
MODEL_NAME = "mixtral:8x7b"
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11435")
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "test1234")

# --- Model initialisation ---
llm = Ollama(base_url=OLLAMA_BASE_URL, model=MODEL_NAME, temperature=0)
graph = Neo4jGraph(
    url=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD
)

# ========================
# Step 1: have the LLM generate Cypher
# ========================
CYPHER_PROMPT = """
你是一個 Neo4j Cypher 專家。根據以下資料庫 schema，將用戶的問題轉換為可執行的 Cypher 查詢。

Schema:
{schema}

用戶問題:
{question}

只輸出可執行的 Cypher，不要解釋。
"""

# NOTE(review): LLMChain and PromptTemplate are not imported in this cell;
# presumably they are imported elsewhere in the notebook — confirm.
cypher_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(input_variables=["schema", "question"], template=CYPHER_PROMPT)
)

def generate_cypher(question: str) -> str:
    """Ask the LLM to translate ``question`` into a Cypher query for the current graph schema.

    Args:
        question: The user's natural-language question.

    Returns:
        The generated Cypher, stripped of surrounding whitespace. If the model
        wrapped it in a Markdown code fence (e.g. ```cypher ... ```), only the
        fenced body is returned.

    Raises:
        KeyError: if the chain output has no "text" entry.
    """
    import re

    # get_schema may be a property rather than a method, so it is not called.
    schema = graph.get_schema
    generated = cypher_chain.invoke({
        "schema": schema,
        "question": question
    })
    # invoke() returns a dict; the generated text is under "text".
    cypher = generated['text'].strip()
    # Models often fence their output despite the "no explanation" prompt; the
    # backticks would make Neo4j reject the query, so keep only the body.
    fenced = re.search(r"```(?:[\w-]*[ \t]*\r?\n)?(.*?)```", cypher, re.DOTALL)
    if fenced:
        cypher = fenced.group(1).strip()
    return cypher

# ========================
# Step 2: 執行 Cypher 查詢
# ========================
def run_cypher_query(cypher: str):
    """Execute ``cypher`` against the Neo4j graph.

    Returns:
        The rows from ``graph.query`` on success. Failures are not raised:
        any exception is turned into a string prefixed with
        "Cypher 執行錯誤: " so the caller can decide what to do.
    """
    try:
        rows = graph.query(cypher)
    except Exception as err:
        return f"Cypher 執行錯誤: {err}"
    return rows

# ========================
# Step 3: 根據查詢結果生成最終回答
# ========================
# Prompt for the final step: turn the raw query result into a natural-language
# answer to the user's question.
RESULT_PROMPT = """
你是一個金融數據分析專家。
用戶問題:
{question}

查詢結果:
{query_result}

請用自然語言生成對用戶問題的回答。
"""

# Chain that fills RESULT_PROMPT and sends it to `llm`; callers read the
# generated text from the "text" key of its invoke() output.
result_chain = LLMChain(
    prompt=PromptTemplate(
        template=RESULT_PROMPT,
        input_variables=["question", "query_result"],
    ),
    llm=llm,
)

def generate_final_answer(question: str, query_result) -> str:
    """Ask the LLM to phrase ``query_result`` as an answer to ``question``.

    Args:
        question: The user's original question.
        query_result: What ``run_cypher_query`` returned (rows or an error string).

    Returns:
        The chain's "text" output with surrounding whitespace stripped.
    """
    payload = {"question": question, "query_result": query_result}
    response = result_chain.invoke(payload)
    return response["text"].strip()

# ========================
# Step 4: 完整流程封裝
# ========================
def _report_cypher_failure(detail, error_type=None) -> str:
    """Print diagnostics for a failed Cypher execution and return the user-facing failure message."""
    print(f"\n❌ 錯誤：資料庫執行 Cypher 查詢時失敗！")
    if error_type is not None:
        print(f"   錯誤類型: {error_type}")
    print(f"   詳細資訊: {detail}")
    print("   --- 可能原因與檢查點 ---")
    print("   1. LLM 生成的 Cypher 查詢可能存在語法錯誤。")
    print("   2. 請檢查上面日誌中生成的 Cypher 語句。")
    return "處理失敗：資料庫查詢錯誤。"


def graph_cypher_qa(question: str):
    """Answer ``question`` end to end: LLM -> Cypher -> Neo4j -> LLM.

    Prints the generated Cypher and the raw query result along the way.

    Returns:
        The natural-language answer, or a "處理失敗：..." message when a step
        fails (diagnostics are printed; the handled errors do not propagate).
    """
    # Pre-bound so the error handlers can report how far the pipeline got.
    cypher = None
    query_result = None
    try:
        # 1. Generate Cypher.
        cypher = generate_cypher(question)
        print(f"\033[32m[Generated Cypher]\n{cypher}\033[0m")

        # 2. Run it against Neo4j.
        query_result = run_cypher_query(cypher)
        print(f"\033[34m[Query Result]\n{query_result}\033[0m")

        # run_cypher_query reports failures as a "Cypher 執行錯誤: ..." string
        # instead of raising; stop here rather than asking the LLM to narrate
        # an error message as if it were data.
        if isinstance(query_result, str) and query_result.startswith("Cypher 執行錯誤"):
            return _report_cypher_failure(query_result)

        # 3. Produce the final natural-language answer.
        return generate_final_answer(question, query_result)

    except ResponseError as e:
        print("\n❌ 錯誤：與 Ollama LLM 的通訊失敗！")
        print(f"   錯誤類型: {type(e).__name__}")
        print(f"   詳細資訊: {e}")

        if 'OLLAMA_BASE_URL' in globals():
            _test_ollama_connection(OLLAMA_BASE_URL)

        else:
            print("\n   [連線測試] 未找到 OLLAMA_BASE_URL 全域變數，無法執行連線測試。")

        return "處理失敗：無法連接到語言模型。"

    except KeyError as e:
        print(f"\n❌ 錯誤：在 chain 的返回結果中找不到必要的鍵: {e}！")
        print(f"   錯誤類型: {type(e).__name__}")
        # The raw chain output is not in scope here; show the pipeline state.
        print(f"   已生成的 Cypher: {cypher}")
        print(f"   查詢結果: {query_result}")
        print("   --- 可能原因與檢查點 ---")
        print("   1. 這通常意味著 chain 在某個環節執行失敗，但沒有拋出異常，導致輸出格式不完整。")
        print("   2. 請檢查您的 chain 配置，特別是 Prompt 是否有問題。")
        return "處理失敗：返回結果格式不正確。"

    # NOTE(review): CypherSyntaxError is not imported in this cell (presumably
    # neo4j.exceptions.CypherSyntaxError, imported elsewhere) — confirm; if it
    # is undefined, evaluating this clause raises NameError.
    except CypherSyntaxError as e:
        return _report_cypher_failure(e, type(e).__name__)

    except Exception as e:
        print(f"\n❌ 發生未預期的錯誤！")
        print(f"   錯誤類型: {type(e).__name__}")
        print(f"   詳細資訊: {e}")
        return "處理失敗：發生未知錯誤。"

try:
    question_test = "列出所有 KLine 節點的 ts"
    # graph_cypher_qa only prints the Cypher and the raw rows itself; print its
    # return value so the final answer (or failure message) is shown.
    print(graph_cypher_qa(question_test))

    # print("=== 問題 1 ===")
    # question_1 = "在index時間為 2024-06-03 08:50:00，flow_imbalance 的數值是多少？"
    # print(graph_cypher_qa(question_1))

    # print("\n=== 問題 2 ===")
    # question_2 = "2023年9月1日上午10點30分到11點，商品A與B的flow_imbalance是否和K棒有連動關係？"
    # graphrag_query (GraphRAG cell above) returns (answer, raw_results):
    # print(graphrag_query(question_2, vector_stores)[0])

finally:
    # Unload the model once the final answer has been produced (or on error).
    unload_ollama_model(model_name=MODEL_NAME, base_url=OLLAMA_BASE_URL)
