# Paid retention sBG (Shifted Beta-Geometric)


## Fake data

In [5]:
import pandas as pd
import numpy as np
from faker import Faker

def create_data(num_customers=1000):
    fake = Faker()
    np.random.seed(42)
    
    customer_data = []
    
    for _ in range(num_customers):
        customer_id = fake.uuid4()
        start_date = fake.date_between(start_date='-4y', end_date='today')
        paid_conversion_date = fake.date_between(start_date=start_date, end_date='today')
        churn_date = fake.date_between(start_date=paid_conversion_date, end_date='today') if np.random.rand() > 0.5 else None
        
        dim_1 = fake.random_int(min=1, max=5)
        dim_2 = fake.random_element(elements=('A', 'B', 'C', 'D', 'E'))
        dim_3 = fake.random_number(digits=2)
        
        customer_data.append([
            customer_id,
            paid_conversion_date,
            churn_date,
            dim_1,
            dim_2,
            dim_3
        ])
    
    columns = ['customer_id', 'paid_conversion_date', 'churn_date', 'dim_1', 'dim_2', 'dim_3']
    return pd.DataFrame(customer_data, columns=columns)

# Generate the data
data = create_data()
data

Unnamed: 0,customer_id,paid_conversion_date,churn_date,dim_1,dim_2,dim_3
0,51114ad3-b445-4ac7-a58c-179d9d96dedd,2021-08-17,,3,C,98
1,1e798f6b-115f-4aba-b888-7bb313bfc579,2024-01-12,2024-02-17,5,D,60
2,c0e9712b-3e55-4e5e-afe0-be74792d3503,2024-02-03,2024-05-07,1,A,46
3,e58c8e0b-e84a-47d3-bdda-4fbb16e64b34,2024-05-17,2024-06-08,3,D,94
4,5a23b1be-65fc-4bb7-9b78-66cf70c93f3a,2024-05-05,,4,B,33
...,...,...,...,...,...,...
995,9ce5e8ce-2550-45bb-bce0-deb1dfe6e8f5,2024-06-24,,3,D,79
996,59de8361-ae72-4a69-a884-594f9a113b7c,2024-04-18,2024-05-24,2,D,11
997,8e31f4d9-d257-488c-8f76-17d3920872b9,2022-10-29,,1,E,72
998,25d45bf4-c99e-48a7-a67f-e5cb61184650,2024-04-09,2024-05-12,4,D,92


## Data preprocessing

In [19]:
import pandas as pd
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import matplotlib.pyplot as plt

def preprocess_data(data):
    data['paid_conversion_date'] = pd.to_datetime(data['paid_conversion_date'])
    data['churn_date'] = pd.to_datetime(data['churn_date'], errors='coerce')

    # Assume period is in months
    data['period'] = ((data['churn_date'].fillna(pd.Timestamp('today')) - data['paid_conversion_date']).dt.days // 30).astype(int)

    # Number of customers at each period
    period_max = data['period'].max()
    customers_remaining = [(data['period'] >= t).sum() for t in range(period_max + 1)]
    churned = [customers_remaining[t] - customers_remaining[t+1] if t < period_max else customers_remaining[t] for t in range(period_max + 1)]
    
    churn_data = pd.DataFrame({
        'period': range(period_max + 1),
        'customers_remaining': customers_remaining,
        'churned': churned
    })
    
    data_dict = {
        "customers_remaining": torch.tensor(customers_remaining, dtype=torch.float),
        "churned": torch.tensor(churned, dtype=torch.float)
    }
    
    return churn_data, data_dict

churn_data, data_dict = preprocess_data(data)

# Display the first few rows of the churn_data DataFrame
churn_data

Unnamed: 0,period,customers_remaining,churned
0,0,1000,180
1,1,820,110
2,2,710,75
3,3,635,54
4,4,581,50
5,5,531,37
6,6,494,42
7,7,452,30
8,8,422,28
9,9,394,29


## sBG model

- Customers renew or cancel their contracts based on a fixed probability (coin flip model).
- This probability varies across customers and follows a beta distribution.
- The model assumes customer behavior remains consistent over time.

In [15]:
def sBG_model(data_dict):
    customers_remaining = data_dict["customers_remaining"]
    churned = data_dict["churned"]
    
    # Priors for alpha and beta
    alpha = pyro.sample('alpha', dist.Gamma(1.0, 1.0))
    beta = pyro.sample('beta', dist.Gamma(1.0, 1.0))

    # Survivor function S(t) for periods 0 to n_periods
    S = [torch.tensor(1.0)] + [torch.exp(dist.Beta(alpha, beta + t).log_prob(torch.tensor(0.0))) for t in range(1, len(customers_remaining))]
    
    # Churn probability for each period
    P = torch.tensor([S[t-1] - S[t] for t in range(1, len(customers_remaining))], dtype=torch.float)

    # Observation of churned customers
    with pyro.plate('data', len(churned)):
        pyro.sample('obs', dist.Binomial(total_count=customers_remaining[:-1], probs=P), obs=churned)

# Inference using NUTS
nuts_kernel = NUTS(sBG_model)
mcmc = MCMC(nuts_kernel, num_samples=2000, warmup_steps=500)
mcmc.run(data_dict)

# Extract samples
posterior_samples = mcmc.get_samples()
alpha_post = posterior_samples['alpha']
beta_post = posterior_samples['beta']

# Plot posterior distributions
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.hist(alpha_post.numpy(), bins=50, alpha=0.7, label='alpha')
plt.legend()

plt.subplot(1, 2, 2)
plt.hist(beta_post.numpy(), bins=50, alpha=0.7, label='beta')
plt.legend()
plt.show()

# Summarize the posterior distributions
print("Posterior summary:")
print(f"Alpha: Mean={alpha_post.mean():.2f}, Std={alpha_post.std():.2f}")
print(f"Beta: Mean={beta_post.mean():.2f}, Std={beta_post.std():.2f}")

Warmup:   0%|          | 0/2500 [00:00, ?it/s]

ValueError: Shape mismatch inside plate('data') at site obs dim -1, 36 vs 35
Trace Shapes:     
 Param Sites:     
Sample Sites:     
   alpha dist    |
        value    |
    beta dist    |
        value    |
    data dist    |
        value 36 |