# Counterparty **Re‑clustering** Workflow
This notebook demonstrates an end‑to‑end workflow for:
1. **Rolling‑window re‑clustering** of equities trading counterparties.
2. **Evaluating cluster quality** (silhouette, Davies‑Bouldin, etc.).
3. **Tracking feature importance drift** over time.
4. Identifying **emerging clusters** and intra‑cluster characteristics.

It uses the `constrained_clustering` utility module we developed.
Feel free to tweak window sizes, feature sets, or visual styles.

## 0. Setup & dependencies
Install any missing packages (run once):

In [None]:
## Uncomment if needed (%pip installs into the running kernel's environment)
# %pip install polars pandas scikit-learn xgboost matplotlib nbformat tqdm

## 1. Imports & data

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import xgboost as xgb
from tqdm.auto import tqdm

import constrained_clustering as cc

# Adjust the glob to your parquet location
TRADE_GLOB = Path('/data/equities/trades_*.parquet')

# The glob pattern is handed to the loader as a plain string.
trades = cc.load_trade_history(str(TRADE_GLOB))

# Fail fast: window generation and clustering below need at least one trade,
# and an empty load would otherwise surface later as a less obvious error.
assert not trades.is_empty(), f'No trades loaded from {TRADE_GLOB}'
print(f'Trades loaded: {trades.shape[0]:,}')

## 2. Generate rolling windows

In [None]:
# 90‑day windows stepped every 30 days
windows = cc.generate_time_windows(trades, window_days=90, step_days=30)

# The summary line indexes the first and last window, so an empty result would
# raise a bare IndexError; report the actual problem instead.
if len(windows) == 0:
    raise ValueError('No time windows generated; check the trade date range against window_days/step_days')
print(f'{len(windows)} windows spanning {windows[0][0].date()} – {windows[-1][1].date()}')

## 3. Choose *k* (number of clusters) via silhouette

In [None]:
# Sweep candidate cluster counts: for every non-empty window, cluster with each k
# and record the silhouette score.
k_range = range(4, 13)  # test 4–12 clusters
silhouette_by_k = {k: [] for k in k_range}
scored_windows = []  # positions in `windows` that actually contained trades

for win_idx, (start, end) in enumerate(tqdm(windows, desc='Windows')):
    # Half-open interval: start <= ts < end
    win_trades = trades.filter((pl.col('ts') >= start) & (pl.col('ts') < end))
    if win_trades.is_empty():
        continue
    scored_windows.append(win_idx)
    # Only the feature matrix is used here; the other two return values are ignored.
    X, _, _ = cc.build_counterparty_features(win_trades)
    for k in k_range:
        labels, _ = cc.cluster_counterparties(X, n_clusters=k, size_min=5)
        sil = cc.evaluate_clustering(X, labels)['silhouette']
        silhouette_by_k[k].append(sil)

# Rows = scored windows, columns = k. The index holds each window's real position
# in `windows`, so it stays correct when empty windows were skipped above.
sil_df = pd.DataFrame(silhouette_by_k,
                      index=pd.Index(scored_windows, name='window_idx'))
sil_df.head()

### 3.1 Silhouette distribution across *k*

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
ks = list(k_range)
# One box per k, pooling that k's silhouette scores across windows (NaNs dropped).
ax.boxplot([sil_df[k].dropna() for k in ks], showfliers=False)
# Label the boxes through the Axes instead of boxplot's `labels=` keyword, which
# Matplotlib 3.9 deprecated in favour of `tick_labels=`; this form works on
# both sides of that rename.
ax.set_xticklabels([str(k) for k in ks])
ax.set_xlabel('Number of clusters (k)')
ax.set_ylabel('Silhouette score')
ax.set_title('Silhouette distribution over windows')
plt.tight_layout()

## 4. Run full pipeline with selected *k*

In [None]:
# Pick the k whose silhouette, averaged over all scored windows, is highest.
mean_silhouette = sil_df.mean()
K_SELECTED = int(mean_silhouette.idxmax())
print(f'Choosing k = {K_SELECTED}')

# Re-cluster every window with the chosen k and the size/share constraints.
pipeline_kwargs = dict(n_clusters=K_SELECTED, size_min=5, max_share=0.20)
results = cc.cluster_over_time(trades, windows, **pipeline_kwargs)

## 5. Cluster stability metrics

In [None]:
# Stability metrics computed from the windowed clustering results;
# show the first five rows as a preview.
stability = cc.evaluate_stability_over_time(results)
stability.head(5)

### 5.1 Plot Adjusted Rand index over time

In [None]:
# Adjusted Rand index per window, plotted against the window end date.
fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(stability['window_end'], stability['adjusted_rand'], marker='o')
ax.set_xlabel('Window end date')
ax.set_ylabel('Adjusted Rand index')
ax.set_title('Cross‑window cluster stability')
# Rotate the date labels through the Axes rather than the pyplot state machine,
# so the setting is tied to this figure and also applies to ticks created later.
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()

## 6. Feature importance drift (XGBoost)

In [None]:
# For each window, fit a small XGBoost classifier that predicts the cluster label
# from the aggregate features, and record its feature importances.
feat_cols = ['avg_pnl', 'std_pnl', 'tot_ntl', 'trade_count', 'avg_participation']
fi_over_time = {c: [] for c in feat_cols}

for res in tqdm(results, desc='Windows (XGB)'):
    X = res.agg.select(feat_cols).to_numpy()
    # XGBClassifier requires class labels to be exactly 0..n_classes-1. Re-encode
    # the cluster ids so a window whose ids have gaps (or do not start at 0) still
    # fits; the importances do not depend on which integer names a class.
    _, y = np.unique(np.asarray(res.labels).ravel(), return_inverse=True)
    model = xgb.XGBClassifier(
        max_depth=3,
        n_estimators=100,
        learning_rate=0.1,
        importance_type='gain',  # explicit, to match the "Gain importance" plot label
        random_state=0,          # fixed seed so importances are reproducible across runs
    )
    model.fit(X, y)
    importances = model.feature_importances_
    for c, imp in zip(feat_cols, importances):
        fi_over_time[c].append(imp)

# One row per window (in `results` order), one column per feature.
fi_df = pd.DataFrame(fi_over_time)
fi_df.head()

### 6.1 Plot top‑3 feature importance trends

In [None]:
# Rank features by mean importance across windows and keep the three strongest.
mean_importance = fi_df.mean()
top_feats = mean_importance.sort_values(ascending=False).index[:3]

# One line per selected feature, x-axis = row position in fi_df.
fig, ax = plt.subplots(figsize=(9, 4))
for feat in top_feats:
    ax.plot(fi_df.index, fi_df[feat], marker='o', label=feat)
ax.set(
    xlabel='Window index',
    ylabel='Gain importance',
    title='Top‑3 feature importance drift',
)
ax.legend()
plt.tight_layout()

## 7. Intra‑cluster feature distributions (latest window)

In [None]:
# Histogram grid for the most recent window: one row per cluster, one column per feature.
latest = results[-1]
clusters = np.unique(latest.labels)
hist_cols = ['avg_pnl', 'std_pnl', 'tot_ntl']

# squeeze=False keeps `axes` two-dimensional even when there is only one cluster;
# without it plt.subplots returns a 1-D array and axes[i, j] raises IndexError.
fig, axes = plt.subplots(len(clusters), len(hist_cols),
                         figsize=(9, 3*len(clusters)), sharex=False, squeeze=False)
for i, k in enumerate(clusters):
    # Boolean mask selecting the rows of `agg` assigned to cluster k
    mask = latest.labels == k
    subset = latest.agg.filter(mask)
    for j, col in enumerate(hist_cols):
        axes[i, j].hist(subset[col], bins=30)
    axes[i, 0].set_ylabel(f'Cluster {k}')

# Column titles go on the top row only.
for j, col in enumerate(hist_cols):
    axes[0, j].set_title(col)
plt.tight_layout()

## 8. Emerging clusters
Identify clusters whose counterparty count has grown by more than 50 % since the previous window. (Growth in notional is not computed in the cell below.)

In [None]:
# Column of `summary` used to match clusters between consecutive windows.
# NOTE(review): keying on 'cluster_notional' pairs clusters only when their
# notional values are exactly equal in both windows, so `common` will usually be
# empty and nothing gets flagged. This looks unintended -- confirm which column of
# `summary` holds the cluster id and set CLUSTER_KEY to it. Also confirm that
# cluster ids are comparable across windows.
CLUSTER_KEY = 'cluster_notional'
GROWTH_THRESHOLD = 0.5  # flag clusters whose counterparty count grew by more than 50 %

emerging = []
for prev, curr in zip(results[:-1], results[1:]):
    prev_df = prev.summary.to_pandas().set_index(CLUSTER_KEY).sort_index()
    curr_df = curr.summary.to_pandas().set_index(CLUSTER_KEY).sort_index()
    # Only keys present in both windows can be compared.
    common = prev_df.index.intersection(curr_df.index)
    prev_n = prev_df.loc[common, 'n_counterparties']
    curr_n = curr_df.loc[common, 'n_counterparties']
    # Relative change in counterparty count versus the previous window
    growth = (curr_n - prev_n) / prev_n
    big = growth[growth > GROWTH_THRESHOLD]
    if not big.empty:
        emerging.append({'window_end': curr.end, 'clusters': big.index.tolist(), 'growth': big.values.tolist()})

# Explicit columns keep the schema visible even when nothing was flagged.
pd.DataFrame(emerging, columns=['window_end', 'clusters', 'growth'])

---
### Next steps
* Tune `size_min`, `max_share`, or window specs.
* Push summary tables into your reporting layer (Dash, Superset, etc.).
* Automate re‑clustering on a schedule with your simulator back‑tests.