# Growth model

### Import libraries

Import Dynamax

In [1]:
try:
    import dynamax
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/dynamax.git
    import dynamax

from dynamax.hidden_markov_model import PoissonHMM, HMM
from dynamax.parameters import ParameterProperties

Note: you may need to restart the kernel to use updated packages.


Import Jax and Tensorflow Probability

In [41]:
import numpy as np
import jax.numpy as jnp
import jax.random as jr
from jax import jit, lax, vmap, value_and_grad

import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

Import other stuff

In [35]:
import optax
from itertools import count
from functools import partial
from matplotlib import pylab as plt

### Fake data

In [61]:
true_rates = [40, 3, 20, 50]
true_durations = [10, 20, 5, 35]
keys = map(jr.PRNGKey, count())

emissions = jnp.concatenate(
    [
        jr.poisson(key, rate, (num_steps,))
        for (key, rate, num_steps) in zip(keys, true_rates, true_durations)
    ]
).astype(jnp.float32)

# PoissonHMM requires are least 1D emissions
emissions = emissions[:, None]


### Import data

In [48]:
def import_data_json():
    import pandas as pd
    # Specify the path to your JSON file
    file_path = '/Users/patricksweeney/growth/03_Product/04_Markov Growth Model/Activation test.json'
    # Use pandas to read the JSON file
    data = pd.read_json(file_path)
    return data

# Call the function to import your data
data = import_data_json()

data.info()



<class 'pandas.core.frame.DataFrame'>
RangeIndex: 197370 entries, 0 to 197369
Data columns (total 30 columns):
 #   Column                                     Non-Null Count   Dtype         
---  ------                                     --------------   -----         
 0   transcription_hours_total                  1926 non-null    float64       
 1   shared_object_project_count                466 non-null     float64       
 2   project_count                              980 non-null     float64       
 3   shared_object_project_category_count       466 non-null     float64       
 4   converted                                  197370 non-null  int64         
 5   page_user_count                            13187 non-null   float64       
 6   workspace_created_at                       197370 non-null  datetime64[ns]
 7   shared_object_tag_count                    466 non-null     float64       
 8   insight_count                              928 non-null     float64       
 9   mrr_

### Clean data

In [49]:
import pandas as pd

def replace_nans(data, exempt=['converted_at', 'converted']):
    columns_to_fill = [col for col in data.columns if col not in exempt]
    data[columns_to_fill] = data[columns_to_fill].fillna(0)
    return data

data = replace_nans(data)
data.head()


Unnamed: 0,transcription_hours_total,shared_object_project_count,project_count,shared_object_project_category_count,converted,page_user_count,workspace_created_at,shared_object_tag_count,insight_count,mrr_converted,...,page_user_hll,shared_object_insight_count,date,comment_count,shared_object_workspace_field_group_count,transcription_count,reel_created_count,shared_object_count,reel_viewed_count,activity_flag
0,0.0,0.0,0.0,0.0,0,1.0,2023-10-12,0.0,0.0,0.0,...,\x128b7fd6b11f8423a8d5ca,0.0,2023-10-12,0.0,0.0,0.0,0.0,0.0,0.0,True
1,0.0,0.0,0.0,0.0,0,0.0,2023-10-12,0.0,0.0,0.0,...,\x118b7f,0.0,2023-10-13,0.0,0.0,0.0,0.0,0.0,0.0,False
2,0.0,0.0,0.0,0.0,0,0.0,2023-10-12,0.0,0.0,0.0,...,\x118b7f,0.0,2023-10-14,0.0,0.0,0.0,0.0,0.0,0.0,False
3,0.0,0.0,0.0,0.0,0,0.0,2023-10-12,0.0,0.0,0.0,...,\x118b7f,0.0,2023-10-15,0.0,0.0,0.0,0.0,0.0,0.0,False
4,0.0,0.0,0.0,0.0,0,0.0,2023-10-12,0.0,0.0,0.0,...,\x118b7f,0.0,2023-10-16,0.0,0.0,0.0,0.0,0.0,0.0,False


### Fill missing dates

In [50]:
import pandas as pd

def fill_dates(data, date_col='date', freq='D', id_col='workspace_id'):
    # Check if the date column is present and convert it to datetime
    if date_col not in data.columns:
        print("Available columns in the DataFrame:", data.columns)
        raise KeyError(f"{date_col} column not found in the DataFrame.")
    
    # Convert the date column to datetime format
    data[date_col] = pd.to_datetime(data[date_col])
    
    # Ensure that the date column is not set as the index
    if data.index.name == date_col:
        data.reset_index(inplace=True)
    
    # Function to process each group
    def process_group(group):
        # Ensure group has date column in datetime format
        group[date_col] = pd.to_datetime(group[date_col])
        
        # Generate a complete date range for the group
        idx = pd.date_range(start=group[date_col].min(), end=group[date_col].max(), freq=freq)
        
        # Set date as the index temporarily for reindexing
        group.set_index(date_col, inplace=True)
        
        # Reindex the group to the complete date range, filling non-existent rows with NaNs
        group_reindexed = group.reindex(idx)
        
        # Reset index to bring back the date column
        group_reindexed.reset_index(inplace=True)
        
        # Rename the 'index' column back to date_col
        group_reindexed.rename(columns={'index': date_col}, inplace=True)
        
        # Forward fill non-varying columns, set varying columns to 0
        for col in group_reindexed.columns:
            if col != id_col and col != date_col:  # Skip the workspace_id and date columns for processing
                if len(group_reindexed[col].dropna().unique()) == 1:  # Non-varying column
                    group_reindexed[col] = group_reindexed[col].ffill().bfill()
                else:  # Varying column
                    group_reindexed[col] = group_reindexed[col].fillna(0)
        
        return group_reindexed

    # Apply processing to each group and combine the results
    filled_data = data.groupby(id_col, group_keys=False).apply(process_group)
    
    # Calculate and print the mean and standard deviation of observations per workspace_id
    observation_counts = filled_data.groupby(id_col).size()
    mean_observations = observation_counts.mean()
    std_observations = observation_counts.std()
    
    print(f"Mean number of observations per {id_col}: {mean_observations}")
    print(f"Standard deviation of observations per {id_col}: {std_observations}")

    return filled_data

# Example usage:
data = fill_dates(data, 'date', 'D', 'workspace_id')
data.head()


Mean number of observations per workspace_id: 30.0
Standard deviation of observations per workspace_id: 0.0


Unnamed: 0,date,transcription_hours_total,shared_object_project_count,project_count,shared_object_project_category_count,converted,page_user_count,workspace_created_at,shared_object_tag_count,insight_count,...,note_count,page_user_hll,shared_object_insight_count,comment_count,shared_object_workspace_field_group_count,transcription_count,reel_created_count,shared_object_count,reel_viewed_count,activity_flag
0,2023-10-10,0.0,0.0,0.0,0.0,0,1.0,2023-10-10,0.0,0.0,...,1.0,\x128b7fc4ef88c343d5402f,0.0,0.0,0.0,0.0,0.0,0.0,0.0,True
1,2023-10-11,0.0,0.0,0.0,0.0,0,0.0,2023-10-10,0.0,0.0,...,0.0,\x118b7f,0.0,0.0,0.0,0.0,0.0,0.0,0.0,False
2,2023-10-12,0.0,0.0,0.0,0.0,0,0.0,2023-10-10,0.0,0.0,...,0.0,\x118b7f,0.0,0.0,0.0,0.0,0.0,0.0,0.0,False
3,2023-10-13,0.0,0.0,0.0,0.0,0,0.0,2023-10-10,0.0,0.0,...,0.0,\x118b7f,0.0,0.0,0.0,0.0,0.0,0.0,0.0,False
4,2023-10-14,0.0,0.0,0.0,0.0,0,0.0,2023-10-10,0.0,0.0,...,0.0,\x118b7f,0.0,0.0,0.0,0.0,0.0,0.0,0.0,False


### HMM Fixed K

#### Setup

Build the HMM class with Baum-Welch algorithm for parameter learning.

In [72]:
class NonconjugatePoissonHMM(PoissonHMM):
    """A Poisson HMM with a nonconjugate prior.    
    """
    def __init__(self, num_states, emission_dim, 
                 emission_prior_loc=0.0,
                 emission_prior_scale=1.0):
        HMM.__init__(self,
            num_states)
        self.emission_dim = emission_dim
        self.emission_prior_loc = emission_prior_loc
        self.emission_prior_scale = emission_prior_scale
        
    def initialize(self, key, method="prior", initial_probs=None, transition_matrix=None, rates=None):
        key1, key2 = jr.split(key)
        params, props = HMM.initialize(self, key=key1, 
                                               method=method, 
                                               initial_probs=initial_probs, 
                                               transition_matrix=transition_matrix)
        
        if rates is None:
            prior = tfd.LogNormal(self.emission_prior_loc, self.emission_prior_scale)
            rates = prior.sample(seed=key2, sample_shape=(self.num_states, self.emission_dim))
            
        params['emissions'] = dict(rates=rates)
        props['emissions'] = dict(rates=ParameterProperties(constrainer=tfb.Softplus()))
        return params, props
        
    def log_prior(self, params):
        return tfd.LogNormal(self.emission_prior_loc, self.emission_prior_scale).log_prob(
            params["emissions"]["rates"]
        ).sum()
        
    # Default to the standard E and M steps rather than the conjugate updates
    # for the PoissonHMM with a gamma prior.
    def e_step(self, params, batch_emissions):
        return StandardHMM.e_step(self, params, batch_emissions)
    
    def m_step(self, params, param_props, batch_emissions, batch_posteriors, **batch_covariates):
        return StandardHMM.m_step(self, params, param_props, batch_emissions, batch_posteriors, **batch_covariates)

Build the latent states.

In [60]:
def build_latent_state(num_states, max_num_states, daily_change_prob):
    # Give probability 0 to states outside of the current model.
    def prob(s):
        return jnp.where(s < num_states + 1, 1 / num_states, 0.0)

    states = jnp.arange(1, max_num_states + 1)
    initial_state_probs = vmap(prob)(states)

    # Build a transition matrix that transitions only within the current
    # `num_states` states.
    def transition_prob(i, s):
        return jnp.where(
            (s <= num_states) & (i <= num_states) & (1 < num_states),
            jnp.where(s == i, 1 - daily_change_prob, daily_change_prob / (num_states - 1)),
            jnp.where(s == i, 1, 0),
        )

    transition_probs = vmap(transition_prob, in_axes=(None, 0))(states, states)

    return initial_state_probs, transition_probs

num_states = 2
daily_change_prob = 0.05

initial_state_probs, transition_probs = build_latent_state(num_states, num_states, daily_change_prob)
print("Initial state probs:\n{}".format(initial_state_probs))
print("Transition matrix:\n{}".format(transition_probs))

Initial state probs:
[0.5 0.5]
Transition matrix:
[[0.95 0.05]
 [0.05 0.95]]


#### Learning with gradient descent

Isolate single workspace_id

In [58]:
import pandas as pd
import matplotlib.pyplot as plt

def isolate_id(data, id='24a55ae6-f255-4b32-b631-c5d414cf0d4d'):
    # Filter for the specified id
    filtered_data = data[data['workspace_id'] == id]
    
    # Ensure 'date' column is in datetime format for proper sorting
    filtered_data['date'] = pd.to_datetime(filtered_data['date'])
    
    # Sort by 'date' ascending
    filtered_data.sort_values(by='date', inplace=True)
    
    # Drop all non-numeric columns
    numeric_data = filtered_data.select_dtypes(include='number')
    
    # Drop the 't' column if it exists in numeric_data
    if 't' in numeric_data.columns:
        numeric_data = numeric_data.drop(columns=['t', 'transcription_hours_total'])
    
    # # Plot time series of all numeric variables as subplots
    # if len(numeric_data.columns) > 0:  # Check if there's any numeric column left
    #     fig, axs = plt.subplots(len(numeric_data.columns), 1, figsize=(10, 5*len(numeric_data.columns)), squeeze=False)
        
    #     for i, col in enumerate(numeric_data.columns):
    #         axs[i, 0].plot(filtered_data['date'], numeric_data[col])  # Use 'date' column for x-axis
    #         axs[i, 0].set_title(col)
    #         axs[i, 0].set_xlabel('Date')
    #         axs[i, 0].set_ylabel('Value')
        
    #     plt.tight_layout()
    #     plt.show()
    # else:
    #     print("No numeric columns to plot after filtering.")

    return numeric_data

# Example usage:
# Assuming 'data' is your DataFrame and it contains 'workspace_id' and 'date' columns, along with a 't' column
numeric_data = isolate_id(data, id='24a55ae6-f255-4b32-b631-c5d414cf0d4d')

emissions = jnp.array(numeric_data.values)
emissions.shape




A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_data['date'] = pd.to_datetime(filtered_data['date'])
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_data.sort_values(by='date', inplace=True)


(30, 22)

In [73]:
# Define variable to represent the unknown log rates.
hmm = NonconjugatePoissonHMM(num_states, 
                             emission_dim=1, 
                             emission_prior_loc=jnp.log(emissions.mean()), 
                             emission_prior_scale=1.0)

# The optimization gets stuck in local optima. This key should find the right states.
params, param_props = hmm.initialize(jr.PRNGKey(1),
                                     initial_probs=initial_state_probs,
                                     transition_matrix=transition_probs)

# Freeze the initial distribution and transition matrix
param_props["initial"]["probs"].trainable = False
param_props["transitions"]["transition_matrix"].trainable = False

# Fit the model with SGD
optimizer = optax.adam(1e-1)
num_epochs = 1000
params, losses = hmm.fit_sgd(params,
                     param_props,
                     emissions,
                     optimizer=optimizer,
                     num_epochs=num_epochs)

TypeError: HMM.__init__() missing 3 required positional arguments: 'initial_component', 'transition_component', and 'emission_component'

In [None]:
plt.plot(losses)
plt.ylabel("Negative log marginal likelihood")
plt.xlabel("iteration")

print("Inferred rates: {}".format(params["emissions"]["rates"]))

### HMM Variable K