In [1]:
from typing import Dict
import numpy as np
import pandas as pd
import scipy.stats as sps

# Data generation

In [2]:
class GenData:

  def __init__(
      self,
      beta=None, 
      base_rate=0.25, 
      cens_prop=0.2, 
      n_covar=3, 
      n_obs=100,
      tau=10,
  ) -> None:
    """Generative parameters.
    
    Args:
      beta: Log hazard ratios.
      base_rate: Exponential distribution base rate.
      cens_prop: Expected proportion of observations that are censored.
      n_covar: Number of covariates. Replaced by len(beta) if beta is provided.
      n_obs: Number of observations.
      tau: Truncation time.
    
    """
    if beta is None:
      beta = np.zeros((n_covar,))
      n_covar = len(beta)
    assert len(beta) == n_covar, "Length of beta must match number of covariates."
    self.beta = beta
    self.base_rate = base_rate
    self.cens_prop = cens_prop
    self.n_covar = n_covar
    self.n_obs = n_obs
    self.tau = tau
  
  def get_batch(self) -> Dict[str, np.ndarray]:
    
    # Generate covariates.
    x = np.random.normal(loc=0.0, scale=1.0, size=(self.n_obs, self.n_covar))

    # Generate event time.
    rate = self.base_rate * np.exp(np.matmul(x, self.beta))
    event_time = np.random.exponential(scale=1.0 / rate)

    # Generate censoring time.
    cens_rate = self.cens_prop / (1 - self.cens_prop) * rate 
    cens_time = np.random.exponential(scale=1.0 / cens_rate)
    
    # Censored data. 
    cens_time[cens_time > self.tau] = self.tau
    event_cens = np.stack((event_time, cens_time), axis=1)
    time = np.min(event_cens, axis=1)
    assert len(time) == self.n_obs
    status = np.where(event_time <= cens_time, 1.0, 0.0)
    
    # Output.
    return {
        "beta": self.beta,
        "cens_time": cens_time,
        "even_time": event_time,
        "rate": rate,
        "status": status,
        "time": time,
        "x": x,
    }


In [3]:
np.random.seed(201)
data_generator = GenData()
data = data_generator.get_batch()

In [4]:
time = data.get("time")
status = data.get("status")
print(time[:5])
print(status[:5])

[0.61086877 1.32996653 0.10600733 0.32355925 3.20244274]
[1. 1. 0. 1. 1.]


# Cox model

## Loss

In [5]:
class CoxLoss:

  def __init__(
      self,
      status: np.ndarray,
      time: np.ndarray,
      x: np.ndarray
  ) -> None:
    """Cox model data.
    
    Args:
      status: Event status, 1 for observed, 0 for censored.
      time: Observation time.
      x: (n_obs, n_covar) array of covariates.
    
    """

    # Sorting. 
    df_time = pd.DataFrame({"time": time, "status": status})
    df_x = pd.DataFrame(x)
    df = pd.concat([df_time, df_x], axis=1)
    df = df.sort_values(by=["time", "status"], ascending=[True, False])

    # Cache.
    self.df = df[["time", "status"]]
    self.n_obs, self.n_covar = x.shape
    self.x = np.array(df.drop(columns=["time", "status"]))


  @staticmethod
  def is_pd(x: np.ndarray) -> bool:
    """Check that matrix is positive definite."""
    return np.all(np.linalg.eigvals(x) > 0)
  

  def log_lik(self, beta: np.ndarray) -> float:
    """Calculate partial log likelihood."""
    df = self.df
    x = self.x
    df["risk"] = np.exp(np.matmul(x, beta))

    log_lik = 0.0
    for idx in range(self.n_obs):
      if df.status[idx] == 0.0:
        continue
      log_lik += np.log(df.risk[idx]) - np.log(np.sum(df.risk[idx:]))
    
    return log_lik


  def grad(self, beta: np.ndarray) -> np.ndarray:
    """Calculate gradient of the partial log likelihood."""
    df = self.df
    x = self.x
    df["risk"] = np.exp(np.matmul(x, beta))

    score = np.zeros((self.n_covar, ))
    for idx in range(self.n_obs):
      if df.status[idx] == 0.0:
        continue
      score += x[idx, :] - np.average(x[idx:, :], weights=df.risk[idx:], axis=0)
    return score


  def info(self, beta: np.ndarray) -> np.ndarray:
    """Calculate information matrix of the partial log likelihood."""
    df = self.df
    x = self.x
    df["risk"] = np.exp(np.matmul(x, beta))

    # Outer product tensor: (n_obs, n_covar, n_covar).
    x2 = x[:, np.newaxis, :] * x[:, :, np.newaxis]

    info = np.zeros((self.n_covar, self.n_covar))
    for idx in range(self.n_obs):
      if df.status[idx] == 0.0:
        continue
      avg = np.average(x[idx:, :], weights=df.risk[idx:], axis=0)
      info += np.average(x2[idx:, :, :], weights=df.risk[idx:], axis=0) - \
        np.outer(avg, avg)
    return info


  def inference(self, beta: np.ndarray, t1e=0.05) -> pd.DataFrame:
    """Tabulate parameter estimates, confidence intervals, and p-values.

    Note:
      Confidence intervals for beta are on the log hazard ratio scale.
    
    Args:
      beta: Parameter estimation.
      t1e: Type I error (for confidence intervals).
    
    """
    info = self.info(beta)
    if not CoxLoss.is_pd(info):
      print("Information matrix is not positive definite.")
    
    # Calculate standard errors
    inv_info = np.linalg.pinv(info)
    inv_info[inv_info < 0] = np.inf
    se = np.sqrt(np.diagonal(inv_info))

    # Calculate confidence intervals.
    z = sps.norm.ppf(1.0 - 0.5 * t1e)
    lower = beta - z * se
    upper = beta + z * se

    # Calculate p-value.
    p = 2 * sps.norm.sf(np.abs(beta / se))
    return pd.DataFrame({
        "beta": beta,
        "se": se,
        "lower": lower,
        "upper": upper,
        "p": p,
    })

In [6]:
beta = np.zeros((3, ))
loss_fn = CoxLoss(status=data["status"], time=data["time"], x=data["x"])

log_lik = loss_fn.log_lik(beta=beta)
print(f"Partial log likelihood: {log_lik:.3f}")

score = loss_fn.grad(beta=beta)
print("\nScore:\n", score)

info = loss_fn.info(beta=beta)
print("\nInformation:\n", info)

Partial log likelihood: -284.399

Score:
 [  4.05806144  -7.33756249 -15.02580012]

Information:
 [[65.01923434 -2.10372762  8.89499086]
 [-2.10372762 56.95119146 12.68413163]
 [ 8.89499086 12.68413163 59.58131175]]


## Estimation

In [7]:
class CoxEst:

  def __init__(
      self, 
      eps=1e-8, 
      init_beta=None, 
      max_iter=10,
      verbose=True
  ) -> None:
    """Fitting parameters.
    
    Args:
      eps: Minimum improvement in log likelihood.
      init_beta: Initial beta. 
      max_iter: Maximum number of iterations. 
      verbose: Report fitting progress? 

    """
    self.eps = eps
    self.init_beta = init_beta
    self.max_iter = max_iter
    self.verbose = verbose
  
  def fit(self, loss: CoxLoss) -> np.ndarray:
    
    # Initialize.
    beta = self.init_beta
    if beta is None:
      beta = np.zeros((loss.n_covar, ))
    else:
      assert len(beta) == loss.n_covar, "Length of initial beta does not match the number of covariates."

    # Newton-Raphson.
    loglik = loss.log_lik(beta)
    if self.verbose:
      print(f"Init loglik: {loglik:.3f}")

    for step in range(self.max_iter):
      grad = loss.grad(beta)
      info = loss.info(beta)
      beta_next = beta + np.linalg.solve(info, grad)

      loglik_next = loss.log_lik(beta_next)
      delta = loglik_next - loglik
      if self.verbose:
        print(f"Next loglik: {loglik_next:.3f}, delta: {delta:.3f}")
      if delta > self.eps:
        beta = beta_next
        loglik = loglik_next
      else:
        break
    
    return beta

In [8]:
fit_fn = CoxEst()
beta = fit_fn.fit(loss_fn)

Init loglik: -284.399
Next loglik: -283.818, delta: 0.581
Next loglik: -283.818, delta: 0.000
Next loglik: -283.818, delta: 0.000
Next loglik: -283.818, delta: -0.000


## Inference

In [9]:
results = loss_fn.inference(beta)
print(results)

       beta        se     lower     upper         p
0  0.091310  0.124382 -0.152474  0.335094  0.462880
1 -0.069019  0.131899 -0.327537  0.189498  0.600783
2 -0.254278  0.134428 -0.517753  0.009197  0.058551


# Overall example

In [10]:
np.random.seed(102)
data_generator = GenData(beta=np.zeros((3, )))
data = data_generator.get_batch()
loss_fn = CoxLoss(status=data["status"], time=data["time"], x=data["x"])
fit_fn = CoxEst(verbose=False)
beta = fit_fn.fit(loss_fn)
results = loss_fn.inference(beta)
print(results)

       beta        se     lower     upper         p
0 -0.106282  0.111093 -0.324020  0.111456  0.338721
1 -0.002975  0.118685 -0.235594  0.229644  0.980000
2  0.062114  0.118126 -0.169409  0.293637  0.599009


In [11]:
np.random.seed(103)
data_generator = GenData(beta=np.ones((3, )))
data = data_generator.get_batch()
loss_fn = CoxLoss(status=data["status"], time=data["time"], x=data["x"])
fit_fn = CoxEst(verbose=False)
beta = fit_fn.fit(loss_fn)
results = loss_fn.inference(beta)
print(results)

       beta        se     lower     upper             p
0  0.737528  0.167481  0.409272  1.065784  1.064399e-05
1  0.698765  0.164892  0.375584  1.021947  2.257962e-05
2  0.808253  0.161768  0.491194  1.125313  5.841691e-07
