# Cohort Analysis

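Measure how customer retention evolves over time: group customers into cohorts by the month of their first purchase, then track what share of each cohort keeps buying in later months.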

## What is a Cohort?
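A group of customers who share a common characteristic, most often when they first purchased. In this notebook, a cohort is all customers whose first purchase fell in the same calendar month.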

## Goal
Visualize retention rates (e.g., what share of a cohort is still purchasing 3 months after its first order?).

In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Load the raw transactions. Every later cell needs `df`, so a missing file is
# fatal: fail here with a clear message instead of a NameError further down.
DATA_PATH = 'ecommerce_data.csv'
try:
    df = pd.read_csv(DATA_PATH)
except FileNotFoundError as err:
    raise FileNotFoundError(
        f"{DATA_PATH} not found. Generate data first!"
    ) from err

# Parsed outside the try so a date-format error is never mistaken for a
# missing file.
df['InvoiceDate'] = pd.to_datetime(df['InvoiceDate'])

## 1. Create Cohort Month

In [None]:
# Function to get month
def get_month(x): 
    return datetime(x.year, x.month, 1)

# Truncate each invoice timestamp to midnight on the 1st of its month: the
# same value get_month() yields, but the Period round-trip is vectorized
# (no Python call per row) and missing InvoiceDate values become NaT.
df['InvoiceMonth'] = df['InvoiceDate'].dt.to_period('M').dt.to_timestamp()

# A customer's cohort is the earliest month in which they purchased;
# transform('min') broadcasts that value back onto every one of their rows.
grouping = df.groupby('CustomerID')['InvoiceMonth']
df['CohortMonth'] = grouping.transform('min')

print(df[['CustomerID', 'InvoiceDate', 'CohortMonth']].head())

## 2. Calculate Cohort Index
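The number of calendar months since the customer's first purchase, counted from 1: purchases in the first-purchase month itself get index 1.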

In [None]:
def get_date_int(df, column):
    """Split a datetime column of ``df`` into its year and month parts.

    Args:
        df: DataFrame holding the column.
        column: Name of a datetime column (read through the ``.dt`` accessor).

    Returns:
        A ``(year, month)`` tuple of Series aligned with ``df``'s index.
    """
    accessor = df[column].dt
    return accessor.year, accessor.month

# Break the invoice month and the cohort (first-purchase) month into integer
# year / month components.
invoice_year, invoice_month = get_date_int(df, 'InvoiceMonth')
cohort_year, cohort_month = get_date_int(df, 'CohortMonth')

# Calendar distance between the two, split into its year and month parts.
years_diff = invoice_year - cohort_year
months_diff = invoice_month - cohort_month

# Total months elapsed, shifted to be 1-based: an invoice in the
# first-purchase month itself gets index 1.
df['CohortIndex'] = 1 + months_diff + 12 * years_diff
print(df['CohortIndex'].head())

## 3. Create Retention Matrix
Count the distinct active customers in each cohort at every cohort index.

In [None]:
# Distinct customers active in each (cohort month, cohort index) cell.
# SeriesGroupBy.nunique() is the built-in vectorized form of
# apply(pd.Series.nunique) and yields the same counts (NaN IDs excluded).
grouping = df.groupby(['CohortMonth', 'CohortIndex'])
cohort_data = grouping['CustomerID'].nunique()
cohort_data = cohort_data.reset_index()

# Rows: cohort months; columns: CohortIndex. Cohort/index combinations with
# no activity are left as NaN by pivot, not 0.
cohort_counts = cohort_data.pivot(index='CohortMonth', columns='CohortIndex', values='CustomerID')
print(cohort_counts.head())

## 4. Retention Rate Percentage

In [None]:
# Each cohort's size is its count in the first column (the lowest
# CohortIndex, i.e. the first-purchase month); dividing every row by it turns
# head-counts into retention fractions.
cohort_sizes = cohort_counts[cohort_counts.columns[0]]
retention = cohort_counts.div(cohort_sizes, axis='index')
# Label cohorts as 'YYYY-MM' on the heatmap's y-axis.
retention.index = retention.index.strftime('%Y-%m')

# The first column is 100% by construction; capping the colour scale at 50%
# presumably keeps differences between later months visible.
plt.figure(figsize=(12, 8))
plt.title('Retention Rates')
sns.heatmap(
    data=retention,
    annot=True,
    fmt='.0%',
    vmin=0.0,
    vmax=0.5,
    cmap='BuGn',
)
plt.show()