In [8]:
def cohort_analysis(df, idCol, dateCol, valCol, agg = 'nunique', *, plot=True):
    """Build monthly cohort tables from row-level data and optionally plot them.

    Each id is assigned to the calendar month of its earliest ``dateCol``
    value (its *cohort*). ``valCol`` is then aggregated with ``agg`` for every
    (cohort, activity month) pair and laid out as a table with one row per
    cohort and one column per number of months elapsed since the cohort month.

    Args:
        df: DataFrame holding at least ``idCol``, ``dateCol`` and ``valCol``.
            It is not modified; the function works on a copy.
        idCol: Name of the column identifying the entity being tracked.
        dateCol: Name of the date column. It must be datetime-like, because
            the ``.dt`` accessor is used on it.
        valCol: Name of the column to aggregate. May be the same column as
            ``idCol`` (e.g. to count distinct ids per period).
        agg: Aggregation passed to pandas named aggregation
            (default ``'nunique'``).
        plot: Keyword-only. When true (the default) two heatmaps are drawn
            with matplotlib/seaborn: the period-0 value per cohort and the
            retention matrix. When false no plotting library is imported.

    Returns:
        A tuple ``(cohort_pivot, retention_matrix)``. ``cohort_pivot`` holds
        the aggregated values indexed by cohort with period numbers as
        columns; ``retention_matrix`` is ``cohort_pivot`` with every row
        divided by that row's first column (the period-0 value).
    """
    # De-duplicate the selection: when valCol == idCol, selecting the same
    # label twice yields duplicate columns and groupby(idCol) then fails.
    # The explicit copy keeps the column assignments below from targeting a
    # slice of the caller's frame.
    selected = list(dict.fromkeys([idCol, dateCol, valCol]))
    data = df[selected].copy()

    # Month of each row, and month of the first row seen for that id.
    data['order_month'] = data[dateCol].dt.to_period('M')
    data['cohort'] = (
        data.groupby(idCol)[dateCol]
        .transform('min')
        .dt.to_period('M')
    )

    # One aggregated value per (cohort, activity month).
    df_cohort = (
        data.groupby(['cohort', 'order_month'])
        .agg(aggregate_value=(valCol, agg))
        .reset_index(drop=False)
    )

    # Whole months elapsed between the cohort month and the activity month.
    order_month = df_cohort['order_month']
    cohort_month = df_cohort['cohort']
    df_cohort['period_number'] = (
        (order_month.dt.year - cohort_month.dt.year) * 12
        + (order_month.dt.month - cohort_month.dt.month)
    )

    # Every (cohort, period_number) pair occurs once after the grouping above,
    # so pivot_table's default mean aggregation leaves the values unchanged.
    cohort_pivot = df_cohort.pivot_table(
        index='cohort',
        columns='period_number',
        values='aggregate_value',
    )

    # The first column is period 0, i.e. each cohort's starting value.
    cohort_size = cohort_pivot.iloc[:, 0]
    retention_matrix = cohort_pivot.divide(cohort_size, axis=0)

    if plot:
        _plot_cohort_heatmaps(cohort_size, retention_matrix)

    return cohort_pivot, retention_matrix


def _plot_cohort_heatmaps(cohort_size, retention_matrix):
    """Draw the period-0 column and the retention matrix as side-by-side heatmaps."""
    # Imported here so that cohort_analysis(..., plot=False) does not need a
    # plotting stack to be installed.
    import matplotlib.pyplot as plt
    import seaborn as sns

    with sns.axes_style("dark"):
        fig, ax = plt.subplots(
            1, 2,
            figsize=(12, 8),
            sharey=True,
            gridspec_kw={'width_ratios': [1, 11]},
        )

        # Retention matrix (right panel); cells with no data are masked out.
        sns.heatmap(
            retention_matrix,
            mask=retention_matrix.isnull(),
            annot=True,
            fmt='.0%',
            cmap='RdYlGn',
            ax=ax[1],
        )
        ax[1].set_title('Monthly Cohort Analysis', fontsize=16)
        ax[1].set(xlabel='# of periods', ylabel='')

        # Period-0 value per cohort (narrow left panel).
        cohort_size_df = cohort_size.to_frame('cohort_size')
        sns.heatmap(
            cohort_size_df,
            annot=True,
            cbar=False,
            fmt='g',
            cmap='RdYlGn',
            ax=ax[0],
        )

        fig.tight_layout()