# Retention
### Get a 12 month cohort analysis in a plotly heatmap from a Stripe csv

First of all, we open the csv and convert it to a dataframe. Next, we try to match unique user IDs and timestamps for the cohort analysis to work. Then we do a lot of magic for making cohort analyses, as copy and pasted from a tutorial. Finally we get an output plot, which we render inside the jupyter notebook, and export as a json that can be rendered in glint.

If the supplied CSV doesn't contain **user ids** and **timestamps** that can be parsed, this notebook will fail.

In [398]:
import pandas as pd
import numpy as np
import json
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# stay offline for plotly -> fixme to ignore?
init_notebook_mode(connected=True)

In [371]:
# using sample data from https://github.com/thepag/stripe-csv-audit/blob/master/examples/sample.csv
file_name = "../data/sample.csv"

In [394]:
# read csv, if that doesn't work, coerce with latin-1 encoding
try:
    df = pd.read_csv(file_name)
except:
    df = pd.read_csv(file_name, encoding='latin-1')

In [372]:
def identify_user_timestamp_columns(df):
    """Takes in a dataframe, and tries to identify user ids
    and timestamps automatically. Returns the names of columns 
    that appear to match the criteria.
    
    NOTE: would be amazeballs to put in something like Sherlock
    or Sato at this stage to facilitate process.
    
    Args:
        df (DataFrame): an input DataFrame
    
    Returns:
        user_id_column (str): likely user id column
        timestamp_column (str): likely user id column"""

    # lists of common names for columns
    user_columns = ['user', 'client', 'userid', 'clientid', 'user_id', 'client_id', 'client.email', 'customer id']
    timestamp_columns = ['timestamp', 'state.openTimestamp', 'datetime', 'created (utc)']

    # create empty vars
    user_id_column = ""
    timestamp_column = ""

    for column in list(df.columns):
        if column.lower() in user_columns:
            user_id_column = column
        elif column.lower() in timestamp_columns:
            timestamp_column = column

    print(f'Identified "{user_id_column}" as user col, "{timestamp_column}" as timestamp col')
    
    return user_id_column, timestamp_column

In [373]:
user_id_column, timestamp_column = identify_user_timestamp_columns(df)

Identified "Customer ID" as user col, "Created (UTC)" as timestamp col


In [374]:
# clean up timestamps and parse out month
df['native_timestamp'] = pd.to_datetime(df[timestamp_column], errors='coerce')
df['month'] = df['native_timestamp'].dt.strftime('%Y-%m')

In [375]:
# create retention dataframe
retention = pd.DataFrame({'UserId': df[user_id_column], 'OrderId': df.index, 'OrderPeriod': df['month']})

# set the index to customer ids
retention.set_index('UserId', inplace=True)

# find the first time each customer ordered to get the cohort group
retention['CohortGroup'] = retention.groupby(level=0)['OrderPeriod'].min()
retention.reset_index(inplace=True)

# group by cohort and order period
grouped = retention.groupby(['CohortGroup', 'OrderPeriod'])

# count the unique users, orders, and total revenue per Group + Period
cohorts = grouped.agg({'UserId': pd.Series.nunique,
                       'OrderId': pd.Series.nunique})

# make the column names more meaningful
cohorts.rename(columns={'UserId': 'TotalUsers',
                        'OrderId': 'TotalOrders'}, inplace=True)

In [376]:
def cohort_period(df):
    """
    Creates a `CohortPeriod` column, which is the Nth period based on the user's first purchase.
    
    Example
    -------
    Say you want to get the 3rd month for every user:
        df.sort(['UserId', 'OrderTime', inplace=True)
        df = df.groupby('UserId').apply(cohort_period)
        df[df.CohortPeriod == 3]
    """
    df['CohortPeriod'] = np.arange(len(df)) + 1
    return df

cohorts = cohorts.groupby(level=0).apply(cohort_period)

In [377]:
# reindex the DataFrame
cohorts.reset_index(inplace=True)
cohorts.set_index(['CohortGroup', 'CohortPeriod'], inplace=True)

# create a Series holding the total size of each CohortGroup
cohort_group_size = cohorts['TotalUsers'].groupby(level=0).first()

# unstack cohort
user_retention = cohorts['TotalUsers'].unstack(0).divide(cohort_group_size, axis=1)

KeyError: "None of ['CohortGroup', 'CohortPeriod'] are in the columns"

In [382]:
# take last 12 months for clean plot
last_12 = user_retention.T.iloc[-12:]

# make it descending for nicer plotly viz
last_12 = last_12.sort_index(ascending=False)

# remove empty columns
last_12 = last_12.dropna(how='all', axis=1)

AttributeError: 'str' object has no attribute 'T'

In [383]:
# multiply by 100 to get the percentage
last_12[last_12.select_dtypes(include=['number']).columns] *= 100

# round down to 2 decimal places
last_12 = np.round(last_12, decimals=2)

# drop first column because it's 100%
last_12 = last_12.drop(1,1)

KeyError: '[1] not found in axis'

In [396]:
# create a fig for rendering inside of jupyter notebook
colorscale = [[0, 'rgb(255, 255, 255)'],[0.5, 'rgb(227, 0, 6)'], [1, 'rgb(208, 2, 100)']]
fig = go.Figure(data=go.Heatmap(z=last_12, y=last_12.index, colorscale=colorscale))
fig['layout']['yaxis']['autorange'] = "reversed"
fig['layout']['plot_bgcolor'] = 'rgba(0,0,0,0)'

fig.update_layout(
    title='12 month user retention',
    xaxis_title="Cohort period",
    yaxis_title="Cohort group",
    margin=dict(
        pad=10
    ),
    font=dict(
        family="-apple-system, BlinkMacSystemFont, 'Segoe UI', 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'Helvetica Neue', Helvetica, Arial, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol';",
        size=12,
        color="#7f7f7f"
    )
)

fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='#EEEEEE')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='#EEEEEE')

fig.show()

ValueError: 
    Invalid value of type 'builtins.builtin_function_or_method' received for the 'y' property of heatmap
        Received value: <built-in method index of str object at 0x107f453b0>

    The 'y' property is an array that may be specified as a tuple,
    list, numpy array, or pandas Series

In [393]:
# convert to json
output_json = fig.to_json()

# apply tweaks by unconverting from jason
# (if you don't do this you get an ndarray error)
output = json.loads(output_json)
output['layout']['template']['data']['table'][0]['cells']['fill']['color'] = "#FFFFFF"
output['layout']['template']['layout']['yaxis']['autorange'] = "reversed"

# export for plotly component in glint
print(json.dumps(output))

AttributeError: 'str' object has no attribute 'to_json'