# In [None]:
# ============================================
# 0. Imports & Colab setup
# ============================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.cluster import KMeans
from scipy.stats import spearmanr

# Default figure size (width, height in inches) for every plot created below.
plt.rcParams["figure.figsize"] = (12, 6)

# ============================================
# 1. Load data from local parquet files
# ============================================

# NOTE(review): absolute, machine-specific Windows paths on a local E: drive
# (not a mounted cloud drive) - adjust before running on another machine.
metric_data = "E:\\5 Code\\2025_cu_qmim\\data\\price_metrics.parquet"
std_data    = "E:\\5 Code\\2025_cu_qmim\\data\\factors_std.parquet"

px_all      = pd.read_parquet(metric_data)
factors_std = pd.read_parquet(std_data)

# Use PX_LAST only (columns are a MultiIndex with a level named "metric")
px = px_all.xs("PX_LAST", axis=1, level="metric")

# Monthly prices: last available observation within each calendar month.
# NOTE(review): the "M" alias is deprecated in recent pandas releases in
# favour of "ME" - confirm against the installed pandas version.
px_m = px.resample("M").last()

# ============================================
# 2. Factor set for K-means
# ============================================
# Names of the factor metrics fed to K-means. get_trade_setup() keeps only the
# factors_std columns whose "metric" level is in this list; membership is what
# matters, the order of the entries is not used.
keep_metrics = [
    # VALUE
    "PE_RATIO", "PX_TO_BOOK_RATIO", "PX_TO_SALES_RATIO",
    "CURRENT_EV_TO_T12M_EBITDA", "FREE_CASH_FLOW_YIELD",
    "EQY_DVD_YLD_12M",

    # QUALITY
    "EBITDA_MARGIN", "GROSS_MARGIN", "OPER_MARGIN",
    "PROF_MARGIN", "RETURN_ON_ASSET",

    # LEVERAGE
    "TOT_DEBT_TO_EBITDA", "TOT_DEBT_TO_TOT_EQY",

    # SIZE
    "CURRENT_MARKET_CAP_SHARE_CLASS",

    # RISK
    "BETA_ADJ_OVERRIDABLE", "VOLATILITY_30D", "VOLATILITY_90D",
    "VOLATILITY_180D", "VOLATILITY_360D",

    # TAIL RISK
    "RET_SKEW_30D", "RET_SKEW_90D", "RET_SKEW_180D", "RET_SKEW_360D",
    "RET_KURT_30D", "RET_KURT_180D", "RET_KURT_360D", "RET_KURT_90D",

    # LIQUIDITY / MOM
    "TURNOVER", "RET_30D",
]

# ============================================
# 3. Helpers: cleaning & EWMA
# ============================================
def _winsorize_row(row, lower, upper):
    """Clip one cross-section to its own [lower, upper] quantile band.

    A row with no valid observations is handed back untouched.
    """
    if not row.notna().any():
        return row
    bounds = row.quantile([lower, upper])
    return row.clip(lower=bounds.iloc[0], upper=bounds.iloc[1])


def clean_data(df, lower=0.01, upper=0.99):
    """Winsorize and z-score every row (cross-section) of ``df``.

    Each row is first clipped to its ``lower``/``upper`` quantiles, then
    demeaned and divided by its sample standard deviation. Rows whose
    standard deviation is zero come back as NaN rather than +/-inf.
    """
    clipped = df.apply(lambda r: _winsorize_row(r, lower, upper), axis=1)

    # A zero dispersion row would divide by zero; turn it into NaN instead.
    scale = clipped.std(axis=1).replace(0, np.nan)
    centered = clipped.sub(clipped.mean(axis=1), axis=0)
    return centered.div(scale, axis=0)

def ewma(factors_kmeans, lambda_=0.94):
    """Return an exponentially smoothed copy of the factor panel.

    Every column is smoothed independently down the rows with the recursion
    ``y[t] = lambda_ * y[t-1] + (1 - lambda_) * x[t]`` (``adjust=False``),
    seeded with the first observation.

    Parameters
    ----------
    factors_kmeans : pandas.DataFrame
        Factor values, one row per date. The column labels are carried over
        unchanged to the result.
    lambda_ : float, default 0.94
        Decay factor; the smoothing weight on the newest observation is
        ``1 - lambda_``.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same index and columns. The input is NOT
        modified (the previous implementation overwrote it in place).
    """
    alpha = 1 - lambda_
    # ewm() already works column by column, so there is no need to loop over
    # metrics and write the pieces back into the caller's frame.
    return factors_kmeans.ewm(alpha=alpha, adjust=False, min_periods=1).mean()

# ============================================
# 4. Build momentum, factors, trade dates
# ============================================
def get_trade_setup(px_m, factors_std, gap=6, start="2011-01-31",
                    end="2020-12-31", metrics=None):
    """Build forward returns, the momentum signal, smoothed factors and dates.

    Parameters
    ----------
    px_m : pandas.DataFrame
        Prices, one row per period and one column per stock (the caller
        passes month-end prices).
    factors_std : pandas.DataFrame
        Factor panel whose columns carry a level named "metric".
    gap : int, default 6
        Length of the momentum window in rows, and also the number of rows
        the finished momentum is lagged by (see below).
    start, end : date-like or None
        Inclusive bounds of the sample window; ``None`` leaves that side
        unbounded. Defaults reproduce the original hard-coded window.
    metrics : iterable of str or None
        Metric names to keep for clustering; ``None`` uses the module-level
        ``keep_metrics`` list.

    Returns
    -------
    tuple
        ``(rebalance_dates, factors_kmeans_m, mom_z, fwd_ret_m)`` where

        * ``fwd_ret_m[t] = log P[t+1] - log P[t]``
        * ``mom_z[t]`` is the cross-sectionally winsorized and z-scored value
          of ``log P[t-gap-1] - log P[t-2*gap-1]``: the raw window
          ``log P[t-1] - log P[t-gap-1]`` is computed first and then shifted
          by a further ``gap`` rows.
        * ``factors_kmeans_m`` is the forward-filled, EWMA-smoothed subset of
          ``factors_std``.
        * ``rebalance_dates`` are the dates inside the sample window that
          have at least one momentum value and appear in both the factor and
          forward-return indexes.
    """
    if metrics is None:
        metrics = keep_metrics

    # Take the log once; every return below is a difference of shifted logs.
    log_px = np.log(px_m)

    # Forward 1-period log returns
    fwd_ret_m = log_px.shift(-1) - log_px

    # Momentum window that skips the most recent period:
    # mom_raw[t] = log(P[t-1]) - log(P[t-gap-1])
    mom_raw = log_px.shift(1) - log_px.shift(gap + 1)

    # Lag the whole signal by another `gap` rows, then clean cross-sectionally
    mom_z = clean_data(mom_raw.shift(gap))

    # Subset factor metrics for K-means, fill gaps forward, then smooth
    metric_level = factors_std.columns.get_level_values("metric")
    factors_kmeans = factors_std.loc[:, metric_level.isin(metrics)].ffill()
    factors_kmeans_m = ewma(factors_kmeans)

    # Dates where all needed pieces exist
    rebalance_dates = mom_z.dropna(how="all").index
    rebalance_dates = rebalance_dates.intersection(factors_kmeans_m.index)
    rebalance_dates = rebalance_dates.intersection(fwd_ret_m.index)

    # Sample window (inclusive on both sides)
    if start is not None:
        rebalance_dates = rebalance_dates[rebalance_dates >= start]
    if end is not None:
        rebalance_dates = rebalance_dates[rebalance_dates <= end]

    return rebalance_dates, factors_kmeans_m, mom_z, fwd_ret_m

# Run the setup with a 4-period gap (overrides the function default of 6).
gap = 4
rebalance_dates, factors_kmeans_m, mom_z, fwd_ret_m = get_trade_setup(
    px_m, factors_std, gap=gap
)

# ============================================
# 5. K-means clusters (K = 30)
# ============================================
def get_clusters(rebalance_dates, factors_kmeans_m, K=30, min_coverage=0.7,
                 n_init=50, random_state=0):
    """Run K-means on the factor cross-section of every rebalance date.

    Parameters
    ----------
    rebalance_dates : iterable
        Dates to cluster on; each must be a row label of ``factors_kmeans_m``.
    factors_kmeans_m : pandas.DataFrame
        Factor panel with (metric, stock) MultiIndex columns.
    K : int, default 30
        Number of clusters.
    min_coverage : float, default 0.7
        Minimum fraction of metrics a stock must have (non-NaN) to be
        clustered. The required count is rounded UP, so 0.7 really means
        "at least 70%" (the previous floor let 20 of 29 metrics, i.e. 69%,
        through).
    n_init, random_state : int
        Passed straight to ``KMeans``.

    Returns
    -------
    dict
        ``{date: Series}`` mapping each clustered date to a Series of integer
        labels indexed by stock. Dates with fewer than ``K`` eligible stocks
        are left out.
    """
    clusters_dict = {}
    for t in rebalance_dates:
        # One row of the panel -> table with stocks as rows, metrics as columns
        X_t = factors_kmeans_m.loc[t].unstack("metric")

        # Ceiling of coverage * n_metrics; the tiny epsilon keeps float noise
        # (e.g. 7.000000000000001) from bumping an exact product up by one.
        min_valid = int(np.ceil(min_coverage * X_t.shape[1] - 1e-9))

        # Remaining gaps are filled with 0 before clustering.
        X_t = X_t.dropna(axis=0, thresh=min_valid).fillna(0)

        if X_t.shape[0] < K:
            # not enough stocks to form K clusters, skip this date
            continue

        km = KMeans(n_clusters=K, n_init=n_init, random_state=random_state)
        labels = km.fit_predict(X_t.values)

        clusters_dict[t] = pd.Series(labels, index=X_t.index, name="cluster")

    return clusters_dict

# Number of clusters used on every rebalance date.
K = 30
clusters_dict = get_clusters(rebalance_dates, factors_kmeans_m, K=K)

# Keep only dates where we actually got clusters (get_clusters skips dates
# with fewer than K eligible stocks), in chronological order.
rebalance_dates = pd.Index(clusters_dict.keys()).sort_values()

# ============================================
# 6. Clustered momentum signal (equal-weight LS within cluster)
# ============================================
def build_cluster_mom_signal_equal(rebalance_dates, clusters_dict, mom_z,
                                   top_per=0.2, bottom_per=0.2):
    """Equal-weight long/short momentum signal formed inside each cluster.

    For every date and every cluster with at least five stocks that have a
    momentum value, the ``top_per`` fraction with the highest momentum is
    marked +1 and the ``bottom_per`` fraction with the lowest is marked -1
    (at least one name per side). Everything else is 0.

    Returns a DataFrame indexed by date (index name "date") with one column
    per column of ``mom_z``.
    """
    universe = mom_z.columns
    rows = []

    for t in rebalance_dates:
        labels = clusters_dict[t]
        mom_today = mom_z.loc[t]
        positions = pd.Series(0.0, index=universe, name=t)

        for label in labels.unique():
            members = labels.index[(labels == label).to_numpy()]
            ranked = mom_today.reindex(members).dropna()
            size = len(ranked)
            if size < 5:
                # too few names to rank meaningfully
                continue

            longs = ranked.nlargest(max(1, int(size * top_per))).index
            shorts = ranked.nsmallest(max(1, int(size * bottom_per))).index

            # Shorts are written last, so they win if the two sets overlap.
            positions.loc[longs] = 1.0
            positions.loc[shorts] = -1.0

        rows.append(positions)

    out = pd.DataFrame(rows)
    out.index.name = "date"
    return out

# Equal-weight clustered signal: top / bottom 20% of each cluster.
cluster_mom_sig_equal = build_cluster_mom_signal_equal(
    rebalance_dates, clusters_dict, mom_z, top_per=0.2, bottom_per=0.2
)

# ============================================
# 7. Clustered momentum signal (momentum-weighted LS within cluster)
# ============================================
def build_cluster_mom_signal_weighted(rebalance_dates, clusters_dict, mom_z,
                                      top_per=0.2, bottom_per=0.2):
    """Momentum-weighted long/short signal formed inside each cluster.

    Within each cluster that has at least five stocks with a momentum value:

    - select the ``top_per`` / ``bottom_per`` fractions by momentum
      (at least one name per side);
    - long weights are proportional to POSITIVE momentum among the top names
      and sum to +1;
    - short weights are proportional to the size of NEGATIVE momentum among
      the bottom names (more negative -> larger short) and sum to -1.

    A selected name whose momentum has the wrong sign for its leg gets zero
    weight, so a long weight is never negative and a short weight is never
    positive. If a leg has no name with the right sign, the whole leg is
    zero for that cluster. The final portfolio is normalized later anyway.

    Returns a DataFrame indexed by date (index name "date") with one column
    per column of ``mom_z``.
    """
    all_stocks = mom_z.columns
    signal_list = []

    for t in rebalance_dates:
        clusters_t = clusters_dict[t]
        mom_t = mom_z.loc[t]

        w_t = pd.Series(0.0, index=all_stocks, name=t)

        for cl in clusters_t.unique():
            in_cl = clusters_t[clusters_t == cl].index
            mom_cl = mom_t.reindex(in_cl).dropna()
            if len(mom_cl) < 5:
                continue

            n_top = max(1, int(len(mom_cl) * top_per))
            n_bot = max(1, int(len(mom_cl) * bottom_per))

            # Only same-signed momentum contributes to a leg. Without the
            # clip, a cluster whose best names still have negative z-momentum
            # would produce negative "long" weights.
            long_raw = mom_cl.nlargest(n_top).clip(lower=0.0)
            short_raw = (-mom_cl.nsmallest(n_bot)).clip(lower=0.0)

            long_total = long_raw.sum()
            short_total = short_raw.sum()

            if long_total > 0:
                w_long = long_raw / long_total
            else:
                w_long = long_raw * 0.0

            if short_total > 0:
                w_short = -(short_raw / short_total)
            else:
                w_short = short_raw * 0.0

            # Label-wise sum keeps the index unique even if the two legs
            # share a name (possible when top_per + bottom_per > 1).
            w_cluster = w_long.add(w_short, fill_value=0.0)

            # Assign to overall vector
            w_t.loc[w_cluster.index] += w_cluster

        signal_list.append(w_t)

    signal_df = pd.DataFrame(signal_list)
    signal_df.index.name = "date"
    return signal_df

# Momentum-weighted clustered signal: top / bottom 20% of each cluster.
cluster_mom_sig_weighted = build_cluster_mom_signal_weighted(
    rebalance_dates, clusters_dict, mom_z, top_per=0.2, bottom_per=0.2
)

# ============================================
# 8. Basic momentum signal (no clustering, equal-weight LS)
# ============================================
def build_basic_mom_signal(mom_df, dates, long_q=0.8, short_q=0.2):
    """Plain cross-sectional momentum signal without clustering.

    On each date the momentum row is converted to percentile ranks; names
    ranked at or above ``long_q`` get +1, names at or below ``short_q`` get
    -1, and everything else (including names without a momentum value) gets
    0. Returns a DataFrame indexed by date (index name "date").
    """
    universe = mom_df.columns
    rows = []

    for dt in dates:
        pct_rank = mom_df.loc[dt].rank(pct=True)

        # First matching condition wins, so the short test is listed first:
        # a name satisfying both thresholds ends up short.
        values = np.select(
            [(pct_rank <= short_q).to_numpy(), (pct_rank >= long_q).to_numpy()],
            [-1.0, 1.0],
            default=0.0,
        )
        rows.append(pd.Series(values, index=universe, name=dt))

    out = pd.DataFrame(rows)
    out.index.name = "date"
    return out

# Benchmark signal on the same dates as the clustered signals:
# long the top 20% by percentile rank, short the bottom 20%.
basic_mom_sig = build_basic_mom_signal(mom_z, rebalance_dates,
                                       long_q=0.8, short_q=0.2)

# ============================================
# 9. IC series for all three strategies
# ============================================
def calc_ic_series(signal_df, fwd_ret_df, dates):
    """Spearman rank IC between signal and forward return, date by date.

    Parameters
    ----------
    signal_df, fwd_ret_df : pandas.DataFrame
        Date-indexed frames with one column per stock. Columns are matched by
        LABEL, so the two frames may list stocks in a different order or
        cover different universes; only stocks present in both are used.
    dates : iterable
        Dates to evaluate; each must be a row label of both frames.

    Returns
    -------
    pandas.Series
        IC per date, indexed by ``dates``. The value is NaN when fewer than
        five stocks have both a signal and a return, or when either side is
        constant (rank correlation is undefined there).
    """
    ic_list = []
    for dt in dates:
        # Align on stock labels first; spearmanr pairs values by position.
        sig, ret = signal_df.loc[dt].align(fwd_ret_df.loc[dt], join="inner")

        valid = sig.notna() & ret.notna()
        sig, ret = sig[valid], ret[valid]

        if len(sig) < 5 or sig.nunique() < 2 or ret.nunique() < 2:
            ic = np.nan
        else:
            ic, _ = spearmanr(sig, ret)
        ic_list.append(ic)

    return pd.Series(ic_list, index=dates)

# One IC series per strategy, all evaluated against the same forward returns
# on the same rebalance dates.
ic_basic    = calc_ic_series(basic_mom_sig,              fwd_ret_m, rebalance_dates)
ic_cluster  = calc_ic_series(cluster_mom_sig_equal,      fwd_ret_m, rebalance_dates)
ic_clust_wt = calc_ic_series(cluster_mom_sig_weighted,   fwd_ret_m, rebalance_dates)

# ============================================
# 10. Convert signals to daily returns & cum returns
# ============================================
def signal_to_daily_returns(signal_df, px):
    """Turn rebalance-date signals into a daily portfolio return series.

    Parameters
    ----------
    signal_df : pandas.DataFrame
        index = rebalance dates (sorted), values = weights or signals.
    px : pandas.DataFrame
        Daily prices, one column per stock.

    Returns
    -------
    pandas.Series
        Daily portfolio returns for the days strictly after the first
        rebalance date.

    Notes
    -----
    Weights set on rebalance date ``t`` start earning returns on the first
    price date AFTER ``t``. Previously they were also applied to the return
    of day ``t`` itself (close ``t-1`` to close ``t``), which needs the
    weights one day before they are known.
    """
    daily_ret = px.pct_change()

    # Normalize each rebalance date: sum |w_i| = 1. A date with no positions
    # becomes an all-NaN row, i.e. flat until the next rebalance.
    gross = signal_df.abs().sum(axis=1).replace(0, np.nan)
    w = signal_df.div(gross, axis=0)

    # Forward-fill on the combined calendar, then lag by one step so that
    # each day holds the latest weights dated strictly before it. Using the
    # combined calendar means a rebalance label that is not a price date
    # (e.g. a weekend month-end) does not cost an extra day of delay.
    timeline = daily_ret.index.union(w.index)
    w_daily = (
        w.reindex(timeline, method="ffill")
        .shift(1)
        .reindex(daily_ret.index)
        .fillna(0)
    )

    # Portfolio daily returns (missing stock returns contribute zero)
    port_ret = (w_daily * daily_ret).sum(axis=1)

    # Start after the first rebalance date
    first_reb = signal_df.index.min()
    return port_ret[port_ret.index > first_reb]

# Daily return series for each strategy, computed from the daily price panel.
ret_basic    = signal_to_daily_returns(basic_mom_sig,            px)
ret_cluster  = signal_to_daily_returns(cluster_mom_sig_equal,    px)
ret_clust_wt = signal_to_daily_returns(cluster_mom_sig_weighted, px)

# Compound the daily returns into growth of one unit of capital.
cum_basic    = (1 + ret_basic).cumprod()
cum_cluster  = (1 + ret_cluster).cumprod()
cum_clust_wt = (1 + ret_clust_wt).cumprod()

# ============================================
# 11. Plot IC over time (three lines)
# ============================================
# Legend labels are built from the live `gap` and `K` values so they cannot
# drift out of sync with the parameters actually used above.
plt.figure()
plt.plot(ic_basic.index,    ic_basic.values,
         label=f"Basic Momentum (gap={gap})", alpha=0.7)
plt.plot(ic_cluster.index,  ic_cluster.values,
         label=f"Clustered Equal-Weight (K={K})", alpha=0.7)
plt.plot(ic_clust_wt.index, ic_clust_wt.values,
         label=f"Clustered Momentum-Weighted (K={K})", alpha=0.7)
plt.axhline(0, linestyle="--")  # zero line: above = signal ranked returns correctly
plt.title("Information Coefficient Over Time\nBasic vs Clustered (Equal vs Momentum-Weighted)")
plt.xlabel("Date")
plt.ylabel("Spearman Rank IC")
plt.grid(True)
plt.legend()
plt.show()

# Optional: rolling mean of the IC over 12 consecutive rebalance dates.
# rolling(12) uses the default min_periods, so the first 11 points and any
# window containing a NaN IC are NaN.
ic_basic_roll    = ic_basic.rolling(12).mean()
ic_cluster_roll  = ic_cluster.rolling(12).mean()
ic_clust_wt_roll = ic_clust_wt.rolling(12).mean()

plt.figure()
plt.plot(ic_basic_roll.index,    ic_basic_roll.values,
         label="Basic Momentum 12M Rolling IC", linewidth=2)
plt.plot(ic_cluster_roll.index,  ic_cluster_roll.values,
         label="Clustered Equal-Weight 12M Rolling IC", linewidth=2)
plt.plot(ic_clust_wt_roll.index, ic_clust_wt_roll.values,
         label="Clustered Momentum-Weighted 12M Rolling IC", linewidth=2)
plt.axhline(0, linestyle="--")
plt.title("12-Month Rolling IC\nBasic vs Clustered (Equal vs Momentum-Weighted)")
plt.xlabel("Date")
plt.ylabel("Rolling Spearman IC")
plt.grid(True)
plt.legend()
plt.show()

# ============================================
# 12. Plot cumulative returns (three lines)
# ============================================
# Legend labels are built from the live `gap` and `K` values so they cannot
# drift out of sync with the parameters actually used above.
plt.figure()
plt.plot(cum_basic.index,    cum_basic.values,
         label=f"Basic Momentum (gap={gap})")
plt.plot(cum_cluster.index,  cum_cluster.values,
         label=f"Clustered Equal-Weight (K={K})")
plt.plot(cum_clust_wt.index, cum_clust_wt.values,
         label=f"Clustered Momentum-Weighted (K={K})")
plt.title("Cumulative Returns\nBasic vs Clustered (Equal vs Momentum-Weighted)")
plt.xlabel("Date")
plt.ylabel("Cumulative Return (× initial capital)")
plt.grid(True)
plt.legend()
plt.show()

# Quick sanity checks: average IC per strategy, and how closely the three IC
# series move together (pairwise correlation across rebalance dates).
print("Mean IC - Basic              :", ic_basic.mean())
print("Mean IC - Cluster Equal      :", ic_cluster.mean())
print("Mean IC - Cluster Weighted   :", ic_clust_wt.mean())
print("IC corr (Basic, Cluster Eq)  :", ic_basic.corr(ic_cluster))
print("IC corr (Basic, Cluster Wt)  :", ic_basic.corr(ic_clust_wt))
print("IC corr (Cluster Eq, Wt)     :", ic_cluster.corr(ic_clust_wt))