In [4]:
# !pip install arviz --user

In [9]:
# !pip install pymc --user

In [2]:
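# !pip install pymc.sampling_jax --user  # NOTE: not a standalone pip package; `pymc.sampling_jax` is a module shipped inside pymc and needs jax (plus a JAX sampler backend such as numpyro) installed.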

In [1]:
# !pip install matplotlib --upgrade

In [4]:
# !pip install jax --user

In [6]:
# import arviz as az
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mtick
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
# import pymc.sampling_jax
import seaborn as sns
from scipy.special import expit
from sklearn.preprocessing import LabelEncoder, StandardScaler


# Notebook-wide matplotlib defaults, applied once so every figure shares the same look.
plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [12, 7]  # default figure size (width, height) in inches
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"  # white background instead of the style's default

# Reload imported modules automatically before each cell runs, so edits to
# local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render inline figures at high ("retina") resolution.
%config InlineBackend.figure_format = "retina"

In [7]:
# !pip install jaxlib  --user

In [8]:
class CohortDataGenerator:
    """Simulate a synthetic monthly cohort-retention dataset.

    Each cohort is identified by the first day of a month. For every cohort
    the generator draws a cohort size, builds one row per (cohort, period)
    pair with ``period >= cohort``, computes a "true" retention probability
    from a fixed formula, and samples the number of active users from a
    binomial distribution.

    Parameters
    ----------
    rng
        NumPy random generator used for all random draws.
    start_cohort
        Date (anything ``pd.period_range`` accepts as ``start``) of the first
        cohort. It is converted to a monthly period, so the label is the
        first day of that month.
    n_cohorts
        Number of consecutive monthly cohorts (and periods) to generate.
    user_base
        Scale factor for the cohort sizes.
    """

    def __init__(
        self,
        rng: np.random.Generator,
        start_cohort: str,
        n_cohorts: int,
        user_base: int = 10_000,
    ) -> None:
        self.rng = rng
        self.start_cohort = start_cohort
        self.n_cohorts = n_cohorts
        self.user_base = user_base

    def _generate_cohort_labels(self) -> pd.DatetimeIndex:
        """Return ``n_cohorts`` consecutive month-start timestamps.

        The sequence starts at the month containing ``self.start_cohort``.
        """
        # Use the configured start date instead of a hard-coded one, so the
        # ``start_cohort`` constructor argument actually takes effect.
        return pd.period_range(
            start=self.start_cohort, periods=self.n_cohorts, freq="M"
        ).to_timestamp()

    def _generate_cohort_sizes(self) -> npt.NDArray[np.int_]:
        """Draw one integer cohort size per cohort.

        Sizes are ``user_base * trend * noise`` where ``trend`` grows linearly
        from ``1 / n_cohorts`` to ``1`` and ``noise`` is Gamma(shape=1, scale=1).
        """
        ones = np.ones(shape=self.n_cohorts)
        # Linear ramp 1/n, 2/n, ..., 1: later cohorts are larger on average.
        trend = ones.cumsum() / ones.sum()
        return (
            (
                self.user_base
                * trend
                * self.rng.gamma(shape=1, scale=1, size=self.n_cohorts)
            )
            .round()
            .astype(int)
        )

    def _generate_dataset_base(self) -> pd.DataFrame:
        """Build the (cohort, period) grid with cohort sizes and age features.

        Returns a frame with columns ``cohort``, ``n_users``, ``period``,
        ``age`` (days from the cohort start to the latest period) and
        ``cohort_age`` (days from the cohort start to the row's period).
        Rows where the period precedes the cohort are dropped; the original
        cross-join index is kept.
        """
        cohorts = self._generate_cohort_labels()
        n_users = self._generate_cohort_sizes()
        data_df = pd.merge(
            left=pd.DataFrame(data={"cohort": cohorts, "n_users": n_users}),
            right=pd.DataFrame(data={"period": cohorts}),
            how="cross",
        )
        data_df["age"] = (data_df["period"].max() - data_df["cohort"]).dt.days
        data_df["cohort_age"] = (data_df["period"] - data_df["cohort"]).dt.days
        # Explicit copy: downstream steps add columns to this frame, and
        # assigning into a filtered view can trigger SettingWithCopyWarning.
        return data_df.query("cohort_age >= 0").copy()

    def _generate_retention_rates(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """Add the latent retention signal to ``data_df`` and return it.

        Adds ``retention_true_mu`` (logit scale) and ``retention_true``
        (probability, via ``expit``). The frame is modified in place.
        """
        day_of_year = data_df["period"].dt.dayofyear
        data_df["retention_true_mu"] = (
            # Decay with time since the cohort started, relative to its age.
            -data_df["cohort_age"] / (data_df["age"] + 1)
            # Seasonality: one cycle per year plus three cycles per year.
            + 0.8 * np.cos(2 * np.pi * day_of_year / 365)
            + 0.5 * np.sin(2 * 3 * np.pi * day_of_year / 365)
            # Older cohorts get a lower baseline.
            - 0.5 * np.log1p(data_df["age"])
            + 1.0
        )
        data_df["retention_true"] = expit(data_df["retention_true_mu"])
        return data_df

    def _generate_user_history(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """Sample ``n_active_users`` for every row and return ``data_df``.

        Counts are Binomial(``n_users``, ``retention_true``) draws. The frame
        is modified in place.
        """
        data_df["n_active_users"] = self.rng.binomial(
            n=data_df["n_users"], p=data_df["retention_true"]
        )
        # In a cohort's own first period (cohort_age == 0) every user is
        # active, so the sampled value is overridden with the cohort size.
        data_df["n_active_users"] = np.where(
            data_df["cohort_age"] == 0, data_df["n_users"], data_df["n_active_users"]
        )
        return data_df

    def run(
        self,
    ) -> pd.DataFrame:
        """Generate the full dataset: base grid, retention rates, active users."""
        return (
            self._generate_dataset_base()
            .pipe(self._generate_retention_rates)
            .pipe(self._generate_user_history)
        )

In [9]:
# Reproducible seed: the sum of the character codes of the word "retention".
seed: int = sum(ord(char) for char in "retention")
rng: np.random.Generator = np.random.default_rng(seed=seed)

# Simulation settings.
start_cohort: str = "2020-01-01"
n_cohorts: int = 48

cohort_generator = CohortDataGenerator(
    rng=rng,
    start_cohort=start_cohort,
    n_cohorts=n_cohorts,
)
data_df = cohort_generator.run()

# Observed retention: share of each cohort's users that are active in a period.
data_df["retention"] = data_df["n_active_users"].div(data_df["n_users"])

data_df.head()

Unnamed: 0,cohort,n_users,period,age,cohort_age,retention_true_mu,retention_true,n_active_users,retention
0,2020-01-01,150,2020-01-01,1430,0,-1.807373,0.140956,150,1.0
1,2020-01-01,150,2020-02-01,1430,31,-1.474736,0.186224,25,0.166667
2,2020-01-01,150,2020-03-01,1430,60,-2.281286,0.092685,13,0.086667
3,2020-01-01,150,2020-04-01,1430,91,-3.20661,0.038918,6,0.04
4,2020-01-01,150,2020-05-01,1430,121,-3.112983,0.042575,2,0.013333


In [10]:
# Show the last rows of the generated dataset (the most recent cohorts).
data_df.tail()

Unnamed: 0,cohort,n_users,period,age,cohort_age,retention_true_mu,retention_true,n_active_users,retention
2206,2023-10-01,10891,2023-11-01,61,31,-1.175181,0.23592,2541,0.233312
2207,2023-10-01,10891,2023-12-01,61,61,-1.851651,0.135679,1487,0.136535
2254,2023-11-01,4769,2023-11-01,30,0,-0.328608,0.418579,4769,1.0
2255,2023-11-01,4769,2023-12-01,30,30,-1.488948,0.18408,861,0.180541
2303,2023-12-01,4535,2023-12-01,0,0,1.195787,0.767775,4535,1.0
