# Cohort Retention Story (v1.2)

## Executive Decision This Supports
Which first-order product families should be prioritized for retention tests (replenishment + returns mitigation), based on M2 logo retention and net revenue proxy?


## Frozen Definitions and Limits
- Driver: `first_product_family` only
- Horizon: `H=6` (`months_since_first` in 0..6 full grid)
- Cohort universe: non-guest customers with >=1 valid purchase
- Net retention proxy: `sum(net_revenue_proxy_total_t) / sum(gross_revenue_valid_t0)`
- Chart 2 excludes cohorts with baseline denominator = 0


In [None]:
from pathlib import Path
import json
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# --- Locate the repository and make `src/` importable -----------------------
# When the kernel's working directory is named 'notebooks', the repository
# root is taken to be its parent; otherwise the working directory itself.
_cwd = Path.cwd().resolve()
repo_root = _cwd.parent if _cwd.name == 'notebooks' else _cwd

src_root = repo_root / 'src'
_src_root_str = str(src_root)
if _src_root_str not in sys.path:
    sys.path.insert(0, _src_root_str)

# Project policy constants; importable only after the sys.path tweak above.
from retention.policies import (
    HORIZON_H,
    MIN_COHORT_N,
    MIN_PLOT_COHORT_N,
    OBSERVED_ONLY,
    RIGHT_CENSOR_MODE,
)

# --- Required input artifacts -----------------------------------------------
_data_dir = repo_root / 'data_processed'
p1 = _data_dir / 'chart1_logo_retention_heatmap.csv'
p2 = _data_dir / 'chart2_net_proxy_curves.csv'
p3 = _data_dir / 'chart3_m2_by_family.csv'
p4 = _data_dir / 'appendix_top_products_in_chart3_targets.csv'
pg = _data_dir / 'gate_a.json'
pscope = _data_dir / 'scope_receipts.json'
coverage_path = repo_root / 'docs' / 'DRIVER_COVERAGE_REPORT.md'

# Fail fast on the first artifact (in the order listed) that is not on disk.
_first_missing = next(
    (pth for pth in (p1, p2, p3, p4, pg, pscope, coverage_path) if not pth.exists()),
    None,
)
if _first_missing is not None:
    raise FileNotFoundError(f'Missing required artifact: {_first_missing}')

# --- Load tables and JSON receipts ------------------------------------------
chart1, chart2, chart3, appendix_top_products = (
    pd.read_csv(csv_path) for csv_path in (p1, p2, p3, p4)
)
gate_a, scope_receipts = (
    json.loads(json_path.read_text(encoding='ascii')) for json_path in (pg, pscope)
)

# The last raw data date must parse; it is reduced to a 'YYYY-MM' period label.
raw_max_date = scope_receipts.get('raw_max_date')
raw_max_ts = pd.to_datetime(raw_max_date, errors='coerce')
if pd.isna(raw_max_ts):
    raise ValueError(f'Invalid raw_max_date in scope receipts: {raw_max_date}')
max_observed_month = str(raw_max_ts.to_period('M'))

# --- Coverage percentages scraped from the driver coverage report ------------
# Each value stays None when its line is not found in the report text.
coverage_text = coverage_path.read_text(encoding='ascii')
m1 = re.search(r'% gross revenue mapped to non-Other families:\s*([0-9]+\.?[0-9]*)%', coverage_text)
m2 = re.search(r'% customers with non-Other first_product_family:\s*([0-9]+\.?[0-9]*)%', coverage_text)
gross_non_other_pct = None if m1 is None else float(m1.group(1))
customer_non_other_pct = None if m2 is None else float(m2.group(1))

# Quick row-count receipt for the three chart tables.
_row_counts = pd.DataFrame({
    'table': ['chart1', 'chart2_curves', 'chart3'],
    'rows': [len(chart1), len(chart2), len(chart3)]
})
print(_row_counts)


## Chart 1 - Logo Retention Heatmap
Each cell is a **logo retention rate** for one `(cohort_month, months_since_first)` pair.


In [None]:
# Reshape chart1 into a cohort_month x months_since_first grid of logo
# retention, forcing every month 0..HORIZON_H to exist as a column so that
# absent months render as blank (NaN) cells.
heat = chart1.copy()
heat['months_since_first'] = heat['months_since_first'].astype(int)
_month_axis = list(range(HORIZON_H + 1))
pivot = heat.pivot(index='cohort_month', columns='months_since_first', values='logo_retention')
pivot = pivot.reindex(columns=_month_axis)
pivot = pivot.sort_index()

# Month-0 customer counts per cohort, used to annotate the y tick labels.
_is_month0 = heat['months_since_first'] == 0
n0 = heat.loc[_is_month0, ['cohort_month', 'n_customers']].drop_duplicates()
n0_map = {
    cohort_label: cohort_size
    for cohort_label, cohort_size in zip(n0['cohort_month'].astype(str), n0['n_customers'].astype(int))
}
# Cohorts without a month-0 row are labelled with n0=0.
_row_labels = [f"{cm} (n0={n0_map.get(str(cm), 0)})" for cm in pivot.index.tolist()]

fig, ax = plt.subplots(figsize=(10, 6))
img = ax.imshow(pivot.values, aspect='auto', cmap='Blues', vmin=0, vmax=1)

ax.set_xticks(_month_axis)
ax.set_xticklabels(list(map(str, _month_axis)))
ax.set_yticks(range(len(pivot.index)))
ax.set_yticklabels(_row_labels)

ax.set_xlabel('months_since_first')
ax.set_ylabel('cohort_month')
ax.set_title('Chart 1: Logo Retention Heatmap')

cbar = fig.colorbar(img, ax=ax)
cbar.set_label('logo retention rate')

plt.tight_layout()
plt.show()


**Chart 1 caption**
- Cell value = **logo retention rate** (`mean(is_retained_logo)`) at that cohort-month position.
- Missing cells are **right-censored** or suppressed; **missing != 0** (a blank cell is not a 0% retention rate).
- Dataset scope: both sheets (`Year 2009-2010` + `Year 2010-2011`), global range `2009-12` to `2011-12`.


## Chart 2: Net retention proxy curves (3 cohorts max)
**Formula:** `net_retention_proxy(c,t) = sum(net_revenue_proxy_total for cohort c at month t) / denom_month0_gross_valid(c)`.

Denominator guard: cohorts require `denom_month0_gross_valid > 0` and `n_customers_m0 >= MIN_COHORT_N`.
Right-censoring: unobserved months are masked in chart as `NA` (not treated as `0`).


In [None]:
from IPython.display import display

# Chart 2 cell: validate the plotted cohort selection against
# scope_receipts.json, draw the net-retention-proxy curves, and display a
# per-cohort summary table (M2 logo retention, M2/M6 net proxy).
curves = chart2.copy()
curves['months_since_first'] = curves['months_since_first'].astype(int)
for _flag_col in ('is_observed', 'eligible_cohort', 'selected_for_plot'):
    curves[_flag_col] = curves[_flag_col].astype(bool)

# Policy thresholds: values recorded in the receipts win over the imported
# policy constants, which act as fallbacks.
chart2_policy = scope_receipts.get('chart2_policy', {})
policy_min_cohort_n = int(chart2_policy.get('MIN_COHORT_N', MIN_COHORT_N))
policy_min_plot_cohort_n = int(chart2_policy.get('MIN_PLOT_COHORT_N', MIN_PLOT_COHORT_N))
plot_pool_count = int(scope_receipts.get('chart2_plot_pool_count', 0))
used_fallback = bool(scope_receipts.get('chart2_used_fallback', False))
selected_from_receipts = [str(x) for x in scope_receipts.get('chart2_selected_cohorts', [])]
if not selected_from_receipts:
    raise ValueError('Chart 2 selection metadata missing selected cohorts in scope_receipts.json')

plot_df = curves[curves['eligible_cohort'] & curves['selected_for_plot']].copy()
# Right-censoring: blank out unobserved months once, here, so the chart AND
# the summary table below both treat them as NA rather than as real values
# (the memo cell applies the same mask before computing its medians).
plot_df.loc[~plot_df['is_observed'], 'net_retention_proxy'] = np.nan

cohorts = sorted(plot_df['cohort_month'].astype(str).unique().tolist())
# The empty-selection check runs first: because the receipts list is known to
# be non-empty at this point, an empty table selection would otherwise always
# be reported as "drift" and this more specific message could never fire.
if not cohorts:
    raise ValueError('Chart 2 contract violation: no eligible cohorts selected for plotting')
if cohorts != sorted(selected_from_receipts):
    raise ValueError(
        'Chart 2 selection drift: notebook table selected cohorts do not match scope_receipts '
        f"(table={cohorts}, receipts={sorted(selected_from_receipts)})"
    )
if len(cohorts) > 3:
    raise ValueError(f'Chart 2 contract violation: selected cohorts > 3 ({len(cohorts)})')

print('Selection Note')
print(
    f'Eligibility: n0 >= MIN_COHORT_N ({policy_min_cohort_n}), '
    'denom_month0_gross_valid > 0, cohort_period <= max_observed_month - H'
)
print('Ranking signal: M2 logo retention')
print('Selection: bottom/mid/top by M2 logo retention')
print(f'Plot-size preference: n0 >= MIN_PLOT_COHORT_N ({policy_min_plot_cohort_n})')
print(
    f'MIN_COHORT_N={policy_min_cohort_n}, MIN_PLOT_COHORT_N={policy_min_plot_cohort_n}, '
    f'used_fallback={used_fallback}, plot_pool_count={plot_pool_count}'
)

fig, ax = plt.subplots(figsize=(10, 6))
# String view of the cohort label, computed once instead of per loop pass.
_cohort_labels = plot_df['cohort_month'].astype(str)
for cohort in cohorts:
    g = plot_df[_cohort_labels == cohort].sort_values('months_since_first', kind='stable')
    # Month-0 cohort size is read from the first row of the sorted group.
    cohort_n0 = int(g['n_customers_m0'].iloc[0])
    # NaN (masked) points leave gaps in the line instead of dropping to zero.
    ax.plot(
        g['months_since_first'],
        g['net_retention_proxy'],
        marker='o',
        linewidth=2,
        label=f'{cohort} (n0={cohort_n0})',
    )

# Reference line at 1.0 = month-0 gross baseline.
ax.axhline(1.0, color='gray', linestyle='--', linewidth=1)
ax.set_xticks(list(range(HORIZON_H + 1)))
ax.set_xlabel('months_since_first')
ax.set_ylabel('net_retention_proxy')
ax.set_title('Chart 2: Net retention proxy curves (3 cohorts max)')
ax.legend(title='cohort_month', frameon=False)
plt.tight_layout()
plt.show()

# Non-chart table: selected cohort metadata + M2/M6 metrics.
logo_m2 = chart1[chart1['months_since_first'] == 2][['cohort_month', 'logo_retention']].copy()
logo_m2 = logo_m2.rename(columns={'logo_retention': 'm2_logo_retention'})

curves_m2 = plot_df[plot_df['months_since_first'] == 2][['cohort_month', 'net_retention_proxy']].copy()
curves_m2 = curves_m2.rename(columns={'net_retention_proxy': 'm2_net_retention_proxy'})
curves_m6 = plot_df[plot_df['months_since_first'] == 6][['cohort_month', 'net_retention_proxy']].copy()
curves_m6 = curves_m6.rename(columns={'net_retention_proxy': 'm6_net_retention_proxy'})

# One row per selected cohort (its month-0 row), enriched by left joins so a
# cohort with a missing metric still appears, with NaN in that column.
base = plot_df[plot_df['months_since_first'] == 0][['cohort_month', 'n_customers_m0', 'denom_month0_gross_valid']].copy()
summary = (
    base.merge(logo_m2, on='cohort_month', how='left')
        .merge(curves_m2, on='cohort_month', how='left')
        .merge(curves_m6, on='cohort_month', how='left')
        .rename(columns={
            'n_customers_m0': 'n0',
        })
        .sort_values('m2_logo_retention', kind='stable')
        .reset_index(drop=True)
)
summary['plot_floor_pass'] = summary['n0'] >= policy_min_plot_cohort_n
summary = summary[
    [
        'cohort_month',
        'n0',
        'denom_month0_gross_valid',
        'm2_logo_retention',
        'm2_net_retention_proxy',
        'm6_net_retention_proxy',
        'plot_floor_pass',
    ]
]
print(f'Selected cohorts table (MIN_COHORT_N={policy_min_cohort_n}, MIN_PLOT_COHORT_N={policy_min_plot_cohort_n})')
display(summary)



## Chart 3: M2 retention by first_product_family
Top 8 families + Other with sample size labels (decision-oriented sort).


In [None]:
# Chart 3 cell: dumbbell plot of M2 logo retention vs M2 net proxy retention
# per family, followed by a text callout of the largest negative gaps.
d = chart3.copy()
# chart3 is already sorted worst-opportunity first by the table builder.
families = d['family_group'].astype(str).tolist()
y = np.arange(len(families))
logo = d['m2_logo_retention'].astype(float).values
netp = d['m2_net_proxy_retention'].astype(float).values
n = d['n_customers'].astype(int).values

fig, ax = plt.subplots(figsize=(11, 6))
for y_pos, logo_val, net_val, n_val in zip(y, logo, netp, n):
    # Connector between the two metrics, drawn beneath the markers (zorder=1).
    ax.plot([logo_val, net_val], [y_pos, y_pos], color='#999999', linewidth=2, zorder=1)
    # Sample-size label just right of the larger of the two values.
    ax.text(max(logo_val, net_val) + 0.01, y_pos, f'n={int(n_val)}', va='center', fontsize=9)
ax.scatter(logo, y, color='#4C78A8', label='M2 logo retention', zorder=2)
ax.scatter(netp, y, color='#F58518', label='M2 net proxy retention', zorder=2)

ax.set_yticks(y)
ax.set_yticklabels(families)
ax.invert_yaxis()

# X-limits follow the data. The net proxy is a revenue ratio, so it can fall
# outside [0, 1]; a fixed (0, 1.05) window would silently hide such points.
# When every finite value lies in [0, 1] the limits are exactly (0, 1.05).
_plotted = np.concatenate([logo, netp])
_plotted = _plotted[np.isfinite(_plotted)]
x_min, x_max = 0.0, 1.05
if _plotted.size:
    _lowest = float(_plotted.min())
    _highest = float(_plotted.max())
    if _lowest < 0.0:
        x_min = _lowest - 0.05
    x_max = max(x_max, _highest + 0.05)
ax.set_xlim(x_min, x_max)

ax.set_xlabel('retention (ratio)')
ax.set_title('M2 Retention: Logo vs Net Proxy by first_product_family (Top 8 + Other)')
ax.legend(frameon=False)
plt.tight_layout()
plt.show()

# Non-chart callout: top refund-drag families by gap (net - logo).
callout = d.copy()
callout['net_minus_logo'] = callout['m2_net_proxy_retention'] - callout['m2_logo_retention']
callout = callout[callout['family_group'].astype(str) != 'Other'].copy()
# Most negative gap first; ties broken by larger n_customers, then family name.
callout = callout.sort_values(['net_minus_logo', 'n_customers', 'family_group'], ascending=[True, False, True], kind='stable')
refund_drag = callout[callout['net_minus_logo'] < 0].head(2).copy()
if refund_drag.empty:
    print('Refund-drag callout: no families with net proxy below logo retention at M2.')
else:
    parts = [
        f"{family} ({gap*100:.1f}pp)"
        for family, gap in zip(refund_drag['family_group'], refund_drag['net_minus_logo'])
    ]
    print('Refund-drag signal (top 2 by net minus logo at M2): ' + '; '.join(parts))


In [None]:
# Write memo_numbers.json (curve contract for Chart 2).
def _finite_float_or_none(value):
    """Return ``value`` as a float, or None when it is NaN or infinite.

    json.dumps would otherwise emit the non-standard tokens NaN/Infinity,
    which strict JSON parsers reject. None (JSON null) matches how the
    Chart 2 medians in the same payload signal "not available".
    """
    number = float(value)
    return number if np.isfinite(number) else None


curves = chart2.copy()
for _flag_col in ('eligible_cohort', 'selected_for_plot', 'is_observed'):
    curves[_flag_col] = curves[_flag_col].astype(bool)

plot_df = curves[curves['eligible_cohort'] & curves['selected_for_plot']].copy()
# Unobserved (right-censored) months must not contribute to the medians.
plot_df.loc[~plot_df['is_observed'], 'net_retention_proxy'] = np.nan
selected_cohorts = sorted(plot_df['cohort_month'].astype(str).unique().tolist())

m2 = plot_df[plot_df['months_since_first'] == 2]['net_retention_proxy'].dropna()
m6 = plot_df[plot_df['months_since_first'] == 6]['net_retention_proxy'].dropna()

# Rank families by M2 logo retention (ascending); ties go to the larger
# family, then alphabetically. 'Other' is excluded from worst/best unless it
# is the only family present.
fam_sorted = chart3.copy().sort_values(['m2_logo_retention', 'n_customers', 'family_group'], ascending=[True, False, True], kind='stable')
if fam_sorted.empty:
    # Explicit message instead of an opaque IndexError from .iloc[0] below.
    raise ValueError('Chart 3 contract violation: chart3 table has no family rows')
# astype(str) keeps this filter consistent with the Chart 3 callout filter.
fam_no_other = fam_sorted[fam_sorted['family_group'].astype(str) != 'Other'].reset_index(drop=True)
_ranked = fam_no_other if len(fam_no_other) else fam_sorted
worst_fam = _ranked.iloc[0]
best_fam = _ranked.iloc[-1]

memo_numbers = {
    'gate_a_pct_valid_nonpositive_net': float(gate_a['gate_a_pct_valid_nonpositive_net']),
    'gate_a_trigger_fired': bool(gate_a['trigger_fired']),
    'gate_b_gross_non_other_pct': gross_non_other_pct,
    'gate_b_customer_non_other_pct': customer_non_other_pct,
    'chart2_selected_cohorts': selected_cohorts,
    'chart2_selected_count': int(len(selected_cohorts)),
    'chart2_m2_net_proxy_median': float(m2.median()) if len(m2) else None,
    'chart2_m6_net_proxy_median': float(m6.median()) if len(m6) else None,
    'chart3_worst_family': str(worst_fam['family_group']),
    'chart3_worst_family_m2_logo': _finite_float_or_none(worst_fam['m2_logo_retention']),
    'chart3_worst_family_m2_net_proxy': _finite_float_or_none(worst_fam['m2_net_proxy_retention']),
    'chart3_worst_family_n': int(worst_fam['n_customers']),
    'chart3_best_family': str(best_fam['family_group']),
    'chart3_best_family_m2_logo': _finite_float_or_none(best_fam['m2_logo_retention']),
    'chart3_best_family_m2_net_proxy': _finite_float_or_none(best_fam['m2_net_proxy_retention']),
    'chart3_best_family_n': int(best_fam['n_customers']),
}

out_path = repo_root / 'data_processed' / 'memo_numbers.json'
# json.dumps escapes non-ASCII by default, so the ASCII encoding is safe here.
out_path.write_text(json.dumps(memo_numbers, indent=2), encoding='ascii')
print(f'Wrote {out_path}')


## Appendix: Top Products Driving Net Proxy
How to use this table:
- Start with rank_priority target families (1..3), then scan SKUs with most negative `delta_net_proxy_vs_family_pp` for returns/credits mitigation.
- Pair those with high `share_of_family_m0_gross_pct` to prioritize high-impact fixes first.
- Read `m2_logo_retention_observed` and `m2_net_retention_proxy_observed` together: weak repeat purchasing (low logo retention) and value drag (low net proxy) call for different plays.


In [None]:
from IPython.display import display

# Appendix cell: show one table per first_product_family, each ordered by
# month-0 gross (largest first) with sku as the tie-breaker.
appendix = appendix_top_products.copy()
appendix['first_product_family'] = appendix['first_product_family'].astype(str)

# Columns shown for every family, in display order.
_display_cols = [
    'sku',
    'description',
    'n_customers_m0',
    'm0_gross_valid',
    'm2_logo_retention_observed',
    'm2_net_retention_proxy_observed',
    'delta_net_proxy_vs_family_pp',
    'share_of_family_m0_gross_pct',
]

if not appendix.empty:
    # Families in order of first appearance in the CSV.
    families = list(dict.fromkeys(appendix['first_product_family']))
    print('Source: data_processed/appendix_top_products_in_chart3_targets.csv')
    for fam in families:
        _in_family = appendix['first_product_family'] == fam
        fam_df = appendix.loc[_in_family].sort_values(
            by=['m0_gross_valid', 'sku'],
            ascending=[False, True],
            kind='stable',
        )
        show = fam_df.loc[:, _display_cols].copy()
        print(f'\nFamily: {fam} (rows={len(show)})')
        display(show)
else:
    print('appendix_top_products_in_chart3_targets.csv is empty under current policy filters.')


## Interpretation Guardrails
- Descriptive analytics only; no causal claims.
- Wholesale-like sensitivity remains a QA table check (Gate C), not a chart.
- Exactly 3 charts are included in this notebook.
