In [1]:
from datetime import timedelta
from typing import Iterable
import pandas as pd
import numpy as np
from scipy.stats import bernoulli
import datetime
from scipy.optimize import fmin

# Define Utility For Caching

In [2]:
class cached_property(object):
    """
    Descriptor that evaluates a method once per instance and caches the result.

    On first access through an instance the wrapped function is called and its
    return value is written into the instance ``__dict__`` under the
    function's own name.

    This is a direct copy-paste of Django's cached property from
    https://github.com/django/django/blob/2456ffa42c33d63b54579eae0f5b9cf2a8cd3714/django/utils/functional.py#L38-50
    """
    def __init__(self, func):
        # The function whose result gets cached on each instance
        self.func = func

    def __get__(self, instance, type=None):
        # Accessed through an instance: compute, stash on the instance, return
        if instance is not None:
            value = self.func(instance)
            instance.__dict__[self.func.__name__] = value
            return value

        # Accessed through the class itself: hand back the descriptor object
        return self


# Define a class for generating seasonal survival data

In [15]:
class Data:
    """
    This class provides a (really inefficient) way of generating seasonal
    survival data.

    Every sample is given a random start date and is then simulated one day at
    a time.  On each simulated day a censoring draw is made first, followed by
    an event draw whose probability depends on the quarter that day falls in.
    """
    # Column order of the frame returned by get_frame
    _COLUMNS = ['date', 'duration', 'event_occured']

    def __init__(
        self,
        q1_tau: float = 10,
        q2_tau: float = 20,
        q3_tau: float = 30,
        q4_tau: float = 40,
        censoring_tau: float = 60,
    ):
        """
        Instantiate with the exponential time constants for each quarter's
        hazard rate

        Args:
            q1_tau: Time constant for quarter 1; the daily event probability
                used by the simulation is 1 / q1_tau
            q2_tau: Same, for quarter 2
            q3_tau: Same, for quarter 3
            q4_tau: Same, for quarter 4
            censoring_tau: Time constant for censoring; the daily censoring
                probability used by the simulation is 1 / censoring_tau
        """
        # Save the censoring rate
        self._censoring_tau = censoring_tau

        # Save the exponential time constants for each quarter
        self.tau_for_quarter = {
            1: q1_tau,
            2: q2_tau,
            3: q3_tau,
            4: q4_tau,
        }

        # Records produced by the most recent simulation run
        self.records = []

    @cached_property
    def _df_hazard(self):
        """
        This is just a lookup table that registers the exponential time constant
        for the hazard rate of each day.
        """
        # Create a dataframe with daily rows spanning 2019 through 2029
        df = pd.DataFrame({'dates': pd.date_range('1/1/2019', '12/31/2029')})

        # Populate the time constants for each day from the lookup dict
        df['tau'] = df.dates.dt.quarter.map(self.tau_for_quarter)

        # Index on date so that accessing this dataframe as a lookup table is fast
        df.index = pd.Index(df.dates)
        df.index.name = None

        # Return the frame
        return df

    def _initialize(self):
        """
        This initializes to empty data
        """
        self.records = []

    def _generate_events(self, num_samples: int, seed: int = None):
        """
        Generate seasonal survival records into self.records

        Args:
            num_samples: Number of records to simulate
            seed: Optional seed.  When given, all draws come from a private
                np.random.RandomState(seed); when None, NumPy's global random
                state is used, so np.random.seed(...) still controls the run.
        """
        # Define Min/Max dates that will fall well within the time bounds of _df_hazard
        min_date = pd.Timestamp('1/1/2019')
        max_date = pd.Timestamp('1/1/2026')

        # Will use this for generating random dates
        max_days = (max_date - min_date).days

        # Pull the lookup table once instead of on every simulated day
        hazard_table = self._df_hazard
        tau_by_day = hazard_table['tau']

        # This is the latest date of _df_hazard which sets loop limits below
        latest_date = hazard_table.dates.iloc[-1]

        # I will model censoring as an exponential process as well with this time constant
        censor_prob = 1 / self._censoring_tau

        # Source of randomness: a private generator when seeded, else the global state
        rng = None if seed is None else np.random.RandomState(seed)
        start_day_source = np.random if rng is None else rng

        one_day = datetime.timedelta(days=1)

        # Make sure the data list is initialized to empty
        self._initialize()

        # Generate num_samples records.
        for _ in range(num_samples):
            # Randomly pull a number of days from the valid range
            days = start_day_source.randint(0, max_days)

            # Start observing this sample at the randomly generated start date
            start_date = pd.Timestamp(min_date + datetime.timedelta(days=int(days)))
            date = start_date

            # Now just step forward in time for every valid day in the simulation
            while date < latest_date:
                # Use the day's time constant to compute an event probability
                event_prob = 1. / tau_by_day.loc[date]

                # Randomly draw to see if this record should be censored today.
                # Calling bernoulli.rvs directly avoids building a frozen
                # distribution object for every single draw.
                if bernoulli.rvs(censor_prob, random_state=rng):
                    # If the record is censored, handle that and move on to the next record
                    self._handle_event(start_date, date, 0)
                    break

                # Randomly draw to see if the event occurred today
                if bernoulli.rvs(event_prob, random_state=rng):
                    # If the event occurred, handle that and move on to the next record
                    self._handle_event(start_date, date, 1)
                    break

                # If no censoring and no event, try again tomorrow.
                date += one_day
            else:
                # The record outlived the hazard table.  Rather than dropping
                # it silently, record it as censored on the last simulated day
                # so that every sample yields exactly one row.
                last_day = max(start_date, date - one_day)
                self._handle_event(start_date, last_day, 0)

    def _handle_event(self, start_date: pd.Timestamp, date: pd.Timestamp, event_occured: int):
        """
        Append one record for a sample that started at start_date and ended
        (by event or by censoring) at date.

        Args:
            start_date: First day the sample was observed
            date: Day on which the event or the censoring happened
            event_occured: 1 if the event was observed, 0 if censored
        """
        # Compute the survival time for this record
        duration = (date - start_date).days

        # Create a record
        rec = {
            'date': date,
            'duration': duration,
            'event_occured': event_occured
        }

        # Add the record to the record list
        self.records.append(rec)

    def get_frame(self, num_samples: int = 10, seed: int = None):
        """
        Get a pandas dataframe of seasonal survival data.

        Args:
            num_samples: Number of records to simulate
            seed: Optional seed for a reproducible simulation; None uses
                NumPy's global random state

        Returns a dataframe like the following:

                    date  duration  event_occured
            0 2019-10-26         9              0
            1 2019-11-22         4              0
            2 2020-02-24         2              0
        """
        self._generate_events(num_samples, seed=seed)

        # Passing the columns keeps the schema even when there are no records
        return pd.DataFrame(self.records, columns=self._COLUMNS)
                    

# Define a class for fitting seasonal survival

In [16]:
class Seasonal:
    """
    A class for fitting seasonal survival.

    The model has one time constant per calendar quarter; the hazard on any
    day is 1 / tau of the quarter that day falls in.  The four time constants
    are found by minimizing the negative log-likelihood with scipy's fmin.
    """

    def __init__(
        self,
        durations: Iterable[float],
        dates: Iterable[pd.Timestamp],
        observed: Iterable[int]
    ):
        """
        Args:
            durations: The durations of the observations, in days
            dates: The date at which event or censoring occurred
            observed: 1 if observed, 0 if censored

        """
        # Save off the constructor args for later use
        self.durations = durations
        self.dates = dates
        self.observed = observed

        # For every observation, we create a dict of {quarter: days_in_quarter, ...}
        # This will be a list of those dicts corresponding to every element in durations
        self._days_in_quarter = None

        # Quarter (1-4) of every end date; filled lazily on first likelihood call
        self._event_quarters = None

    def _get_days_in_quarter(self, durations: Iterable[float], dates: Iterable[pd.Timestamp]):
        """
        This populates the days-in-quarter for every record

        Returns:
            A list with one {quarter: number_of_days} dict per record, counting
            every day from (end date - duration) through the end date inclusive.
        """
        dates = pd.Series(dates)
        days_in_quarter = []
        for duration, end_date in zip(durations, dates):
            # pd.to_timedelta also accepts numpy scalars, which
            # datetime.timedelta rejects for integer dtypes
            start_date = end_date - pd.to_timedelta(duration, unit='D')

            # Quarter of every day the record spans, then count days per quarter
            quarters = pd.Series(pd.date_range(start_date, end_date).quarter)
            days_in_quarter.append(quarters.value_counts().to_dict())
        return days_in_quarter

    def _hazard(self, params: np.ndarray, day: pd.Timestamp):
        """
        This computes the hazard at a single date.  Ideally we wouldn't need
        this, and could just use autograd to compute it.  I couldn't figure out
        how to make that work though, so I just hard coded this.
        """
        day = pd.Timestamp(day)
        return 1. / params[day.quarter - 1]

    def _cumulative_hazard(self, params: np.ndarray):
        """
        Uses the params to generate the cumulative hazard experienced by all records

        Args:
            params: The four time constants, indexed by quarter - 1

        Returns:
            A float64 array with one cumulative hazard per record
        """
        # This will be filled with cumulative hazards
        cum_hazards = []

        # If the lookup for days-in-quarter hasn't been generated, do so now.
        if self._days_in_quarter is None:
            self._days_in_quarter = self._get_days_in_quarter(self.durations, self.dates)

        # Loop over all records of the days-in-quarter list
        for diq in self._days_in_quarter:
            # Initialize the cumulative hazard to zero
            cum_haz = 0.

            # Now loop over all quarters spanned by this record
            for quarter, days_in_quarter in diq.items():

                # Add the cumulative hazard for that quarter
                cum_haz += days_in_quarter / params[quarter - 1]

            # Add this cumulative hazard to the output list
            cum_hazards.append(cum_haz)

        # Return an array of cumulative hazards for each record
        return np.array(cum_hazards, dtype=np.float64)

    def _neg_log_likelihood(self, params: np.ndarray):
        """
        Compute the negative log-likelihood of these params

        Each record contributes observed * log(hazard at its end date) minus
        its cumulative hazard.
        """
        params = np.asarray(params, dtype=np.float64)

        # Get the cumulative hazards for all records
        cum_hazards = self._cumulative_hazard(params)

        # The quarter of each end date does not depend on params, so work it
        # out once instead of on every optimizer step
        if self._event_quarters is None:
            self._event_quarters = np.array(
                [pd.Timestamp(day).quarter for day in self.dates], dtype=int
            )

        # Get the hazards for all records
        hazards = 1. / params[self._event_quarters - 1]

        # Coerce so that lists, arrays and Series all behave the same way
        observed = np.asarray(self.observed, dtype=np.float64)

        # Compute the log-likelihood for each record
        d_log_likelihood = observed * np.log(hazards) - cum_hazards

        # Return the negative of the integrated log-likelihood
        return - np.sum(d_log_likelihood)

    def fit(self, maxiter: int = 1000, disp: bool = True):
        """
        Runs a fit of the data and returns a pandas series of the best fit

        Args:
            maxiter: Maximum number of iterations handed to fmin
            disp: Whether fmin prints its convergence message

        Returns:
            A series indexed tau_for_q_1 .. tau_for_q_4 holding the fitted
            time constants
        """
        params = np.array([1., 1., 1., 1.])

        # Minimize this instance's own likelihood (not that of whatever global
        # variable happens to be called `seasonal`)
        res = fmin(func=self._neg_log_likelihood, x0=params, maxiter=maxiter, disp=disp)
        res = pd.Series(res, index=[f'tau_for_q_{nn}' for nn in range(1, 5)])
        return res

# Run the simulation

In [23]:
# Seed NumPy's global random state so the simulated data, and therefore the
# fitted values, are the same on every Restart & Run All.  The recorded
# output of this cell has to be regenerated once after adding the seed.
np.random.seed(42)

# Generate some seasonal survival data with these rates
data = Data(q1_tau=10, q2_tau=20, q3_tau=30, q4_tau=40, censoring_tau=60)
df = data.get_frame(num_samples=500)

# Fit the data and display results
seasonal = Seasonal(df.duration, df.date, df.event_occured)
res = seasonal.fit()
print('\n')
display(res)

Optimization terminated successfully.
         Current function value: 1454.094310
         Iterations: 200
         Function evaluations: 341




tau_for_q_1    10.656485
tau_for_q_2    18.831441
tau_for_q_3    32.891915
tau_for_q_4    38.591581
dtype: float64