In [None]:
# default_exp losses

# Losses
> All the losses used in SA.

In [None]:
# export
from abc import ABC, abstractmethod
from typing import Callable, Tuple
import torch

Suppose that we have:
$$
t_i = \mu + \xi_i
$$
and $\xi_i\sim p(\xi_i|\theta)$. Then $t_i|\mu\sim p_\mu(t_i|\theta)$ where $p_\mu(t_i|\theta) = p(t_i-\mu|\theta)$ is simply the distribution $p(\cdot|\theta)$ shifted to the right by $\mu$.

If the observation is censored ($e_i=0$), we only know that $t_i < \mu + \xi_i$, since the 'death' offset $\xi_i$ is not observed.

Therefore we may write the likelihood of a single observation $(t_i, e_i)$ as
$$
\begin{aligned}
p(t_i, e_i|\mu) =& \left(p(t_i-\mu)\right)^{e_i} \left(\int_{t_i}^\infty p(t-\mu) dt\right)^{1-e_i}\\
\log p(t_i, e_i|\mu) =& e_i \log p(t_i-\mu) + (1 - e_i) \log \left(1 - \int_{-\infty}^{t_i} p(t-\mu) dt \right)
\end{aligned}
$$

In [None]:
# export
class Loss(ABC):
    """Abstract base class for loss callables.

    Subclasses implement ``__call__`` as a ``@staticmethod`` whose first
    argument is the event tensor, followed by the loss-specific tensors.
    """

    # Declared static so the abstract signature matches the concrete
    # implementations: without it, the first parameter (`event`) would receive
    # the instance when invoked through an object.
    @staticmethod
    @abstractmethod
    def __call__(event: torch.Tensor, *args: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        parameters:
        - event: event tensor passed as the first argument of every loss
        - *args: additional tensors required by the concrete loss

        returns: the loss value (the base implementation returns None).
        """

In [None]:
# export
# Type alias for a loss callable: takes three tensors and returns a tensor.
LossType = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]

In [None]:
# export
class AFTLoss(Loss):
    """Negative mean log-likelihood with a censoring term.

    Each element contributes ``log_pdf`` weighted by ``event`` and
    ``log_icdf`` weighted by ``1 - event``.
    """

    @staticmethod
    def __call__(event:torch.Tensor, log_pdf: torch.Tensor, log_icdf: torch.Tensor) -> torch.Tensor:
        """Return ``-mean(event * log_pdf + (1 - event) * log_icdf)``.

        parameters:
        - event: weight of the ``log_pdf`` term (presumably 1 if the event was
          observed and 0 if censored -- confirm against callers)
        - log_pdf: term used where ``event`` is 1
        - log_icdf: term used where ``event`` is 0
        """
        observed_part = event * log_pdf
        censored_part = (1 - event) * log_icdf
        return -(observed_part + censored_part).mean()

In [None]:
# Smoke test of AFTLoss on N random samples. No seed is set in this cell, so
# the displayed value changes from run to run.
N = 5
event = torch.randint(0, 2, (N,))  # random integer event indicators in {0, 1}
log_pdf = torch.randn((N,))
log_cdf = -torch.rand((N,))  # negated uniform draws, so every value is <= 0

aft_loss = AFTLoss()
# `log_cdf` is passed positionally as the `log_icdf` argument of AFTLoss.
aft_loss(event, log_pdf, log_cdf)

tensor(-0.0400)

In [None]:
# export
def _aft_loss(
    log_pdf: torch.Tensor, log_cdf: torch.Tensor, e: torch.Tensor
) -> torch.Tensor:
    lik = e * log_pdf + (1 - e) * log_cdf
    return -lik.mean()


def aft_loss(log_prob, e):
    """Unpack ``log_prob`` into its two components and evaluate ``_aft_loss``.

    parameters:
    - log_prob: pair ``(log_pdf, log_cdf)``
    - e: event indicator forwarded to ``_aft_loss``
    """
    first_term, second_term = log_prob
    return _aft_loss(log_pdf=first_term, log_cdf=second_term, e=e)

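We use the following negative log-likelihood as the loss function to fit our model, where $d_i$ is the event indicator (the `event` / `e` argument in the code below). See [here](./SAT#Likelihood-Function) for theory.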
$$
-\log L = \sum_{i=1}^N \left[\Lambda(t_i) - d_i \log \lambda(t_i)\right]
$$

In [None]:
# export
class HazardLoss(Loss):
    """Negative mean log-likelihood built from a log hazard and a cumulative hazard."""

    @staticmethod
    def __call__(event: torch.Tensor, logλ: torch.Tensor, Λ: torch.Tensor) -> torch.Tensor:
        """Return ``-mean(event * logλ - Λ)``.

        parameters:
        - event: weight applied to the log hazard term
        - logλ: log hazard
        - Λ: cumulative hazard
        """
        event_part = event * logλ
        return -(event_part - Λ).mean()

In [None]:
# export
def _hazard_loss(logλ: torch.Tensor, Λ: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
    log_lik = e * logλ - Λ
    return -log_lik.mean()


def hazard_loss(
    hazard: Tuple[torch.Tensor, torch.Tensor], e: torch.Tensor
) -> torch.Tensor:
    """
    Unpack the hazard pair and evaluate `_hazard_loss` on it.

    parameters:
    - hazard: pair of (log hazard, cumulative hazard)
    - e: torch.Tensor of 1 if the death event occurred and 0 otherwise
    """
    log_hazard, cumulative_hazard = hazard
    return _hazard_loss(log_hazard, cumulative_hazard, e)

In [None]:
# hide
# Import the single name this cell uses instead of `*`, so the namespace stays
# clean and the origin of `notebook2script` is explicit.
from nbdev.export import notebook2script

# Run nbdev's notebook-to-script export; the cell output lists each notebook converted.
notebook2script()

Converted 00_index.ipynb.
Converted 10_SAT.ipynb.
Converted 20_KaplanMeier.ipynb.
Converted 30_overall_model.ipynb.
Converted 50_hazard.ipynb.
Converted 55_hazard.PiecewiseHazard.ipynb.
Converted 59_hazard.Cox.ipynb.
Converted 60_AFT_models.ipynb.
Converted 65_AFT_error_distributions.ipynb.
Converted 80_data.ipynb.
Converted 90_model.ipynb.
Converted 95_Losses.ipynb.
