# Tobit Model: Censored Normal Regression

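**The story behind the algorithm**

In any given year, a considerable number of households spend nothing (exactly 0) on health insurance.

So although annual household health-insurance spending is spread over a wide range of positive values, its distribution has a large concentration at exactly 0.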

# Algorithm

**Definition**

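1. The dependent variable $y$ is roughly continuous over positive values, but it also takes the value 0 with positive probability. It is defined through a latent (unobserved) variable $y^*$, which together with the threshold 0 determines the observed $y$: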

$$y=\left\{
\begin{aligned}
y^* ,& y^* \geqslant 0 \\
0 ,& y^* < 0 
\end{aligned}
\right.
$$

2. The latent variable $y^*$ is a linear function of $x$ with parameter vector $\beta$, plus a normally distributed error $\epsilon$:

$$y^* = x^T\beta + \epsilon, \text{ where } \epsilon \sim N(0, \sigma^2)$$

3. The parameters to be estimated are $\beta$ and $\sigma$ (equivalently $\sigma^2$):

$$\hat{\theta}= (\hat{\beta}, \hat{\sigma}^2)$$
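To make the censoring mechanism concrete, here is a minimal JAX sketch that draws the latent $y^*$ and censors it at 0 as in the definition. The helper name `simulate_tobit`, the sample size and the values of `beta` and `sigma` are arbitrary illustrative assumptions, not part of the original notebook. It also compares the share of exact zeros with the model-implied average $P(y^* < 0) = 1 - \Phi(x^T\beta/\sigma)$, i.e. the point mass at 0 from the health-insurance example.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def simulate_tobit(key, x, beta, sigma):
    """Latent y* = x^T beta + eps with eps ~ N(0, sigma^2); observed y is y* censored at 0."""
    eps = sigma * jax.random.normal(key, (x.shape[0],))
    y_star = x @ beta + eps
    y = jnp.where(y_star >= 0, y_star, 0.0)   # y = y* if y* >= 0, else 0
    return y, y_star

# hypothetical example values
key_x, key_eps = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (20000, 2))
beta = jnp.array([0.5, -1.0])
sigma = 1.5
y, y_star = simulate_tobit(key_eps, x, beta, sigma)

# empirical share of exact zeros vs. the model-implied average P(y* < 0)
frac_zero = jnp.mean((y == 0).astype(jnp.float32))
expected_zero = jnp.mean(1 - norm.cdf(x @ beta / sigma))
print(frac_zero, expected_zero)
```

The two printed numbers should agree up to sampling noise, which is what justifies mixing a probability (for the zeros) with a density (for the positive values) in the objective.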

**Objective function**

1. Define an indicator variable $d$ for the censoring threshold $L$, here $L=0$:

$$d=\left\{
\begin{aligned}
1 ,& y > L\\
0 ,& y = L 
\end{aligned}
\right.$$

2. From the definition, $y^*$ is normally distributed: $y^* \sim N(x^T\beta, \sigma^2)$

$$y^* = x^T\beta + \epsilon \rightarrow y^* \sim N(x^T\beta, \sigma^2)$$

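3. A censored observation ($y^* < L$) contributes the probability of falling below $L$, computed from the normal CDF; an uncensored observation ($y^* > L$, so $y$ is observed) contributes the normal density (PDF). Here $L=0$: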

$$f(y)=\left\{
\begin{aligned} F^*(0)=1-\Phi(\frac{x^T\beta}{\sigma}), d=0\\
\frac{1}{\sqrt{2\pi\sigma^2}}exp(-\frac{1}{2\sigma^2}(y-x^T\beta)^2), d=1
\end{aligned}
\right.
$$

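4. Combining the two cases gives a single expression for each observation's contribution; the product of these terms over all observations is the likelihood that MLE maximizes: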

$$f(y) =[\frac{1}{\sqrt{2\pi\sigma^2}}exp(-\frac{1}{2\sigma^2}(y-x^T\beta)^2)]^d[1-\Phi(\frac{x^T\beta}{\sigma})]^{1-d}$$
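A minimal sketch of the resulting log-likelihood $\ln L_N = \sum_i \ln f(y_i)$ in JAX, assuming censoring at a threshold `L` (default 0). `tobit_loglik` is a hypothetical helper name. It works in log space via `jax.scipy.stats.norm.logpdf` and `norm.logcdf`, and writes the censored term as $\Phi((L - x^T\beta)/\sigma)$, which for $L=0$ equals $1-\Phi(x^T\beta/\sigma)$ but avoids subtracting from 1.

```python
import math
import jax.numpy as jnp
from jax.scipy.stats import norm

def tobit_loglik(beta, sigma, x, y, L=0.0):
    """Sum over i of d_i * ln(normal density of y_i) + (1 - d_i) * ln P(y_i* < L)."""
    mu = x @ beta
    log_observed = norm.logpdf(y, loc=mu, scale=sigma)   # d = 1 term
    log_censored = norm.logcdf((L - mu) / sigma)         # d = 0 term; L = 0 gives ln(1 - Phi(mu / sigma))
    return jnp.sum(jnp.where(y > L, log_observed, log_censored))

# one observed point (y = 2) and one censored point (y = 0), both with x^T beta = 1 and sigma = 1
x = jnp.array([[1.0], [1.0]])
y = jnp.array([2.0, 0.0])
val = tobit_loglik(jnp.array([1.0]), 1.0, x, y)
print(val)
```

For this two-point example the observed term is $-\tfrac12\ln 2\pi - \tfrac12$ and the censored term is $\ln(1-\Phi(1))$, so the total is about $-3.26$.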

**Parameter gradients**

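Taking the logarithm of the likelihood above, $\ln L_N = \sum_{i=1}^N \ln f(y_i)$, gives the following parameter gradients, where $\phi_i=\phi(\frac{x_i^T\beta}{\sigma})$ and $\Phi_i=\Phi(\frac{x_i^T\beta}{\sigma})$: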

$$\frac{\partial \ln L_N}{\partial \beta}=\sum^N_{i=1}\frac{1}{\sigma^2}\left(d_i(y_i-x_i^T\beta)-(1-d_i)\frac{\sigma\phi_i}{1-\Phi_i}\right)x_i$$

$$\frac{\partial \ln L_N}{\partial \sigma^2}=\sum^N_{i=1}\left[d_i\left(-\frac{1}{2\sigma^2}+\frac{(y_i-x_i^T\beta)^2}{2\sigma^4}\right)+(1-d_i)\frac{\phi_i\, x_i^T\beta}{1-\Phi_i} \cdot \frac{1}{2\sigma^3}\right]$$
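One way to check the two gradient formulas is to compare them with automatic differentiation of $\ln L_N$. The sketch below is a hedged sanity check rather than part of the original derivation: it parameterizes the log-likelihood by $\sigma^2$ (as in the formulas), evaluates both versions at an arbitrary point on a small simulated sample, and prints them side by side. The names `loglik` and `analytic_grads` and all numeric values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def loglik(beta, sigma2, x, y):
    # ln L_N with censoring at 0, parameterized by sigma^2
    sigma = jnp.sqrt(sigma2)
    mu = x @ beta
    return jnp.sum(jnp.where(y > 0, norm.logpdf(y, loc=mu, scale=sigma), norm.logcdf(-mu / sigma)))

def analytic_grads(beta, sigma2, x, y):
    # the closed-form gradients, with phi_i and Phi_i evaluated at x_i^T beta / sigma
    sigma = jnp.sqrt(sigma2)
    d = (y > 0).astype(x.dtype)
    mu = x @ beta
    phi, Phi = norm.pdf(mu / sigma), norm.cdf(mu / sigma)
    g_beta = x.T @ ((d * (y - mu) - (1 - d) * sigma * phi / (1 - Phi)) / sigma2)
    g_sigma2 = jnp.sum(d * (-1 / (2 * sigma2) + (y - mu) ** 2 / (2 * sigma2 ** 2))
                       + (1 - d) * phi * mu / (1 - Phi) / (2 * sigma ** 3))
    return g_beta, g_sigma2

# small synthetic censored sample; beta and sigma2 are arbitrary evaluation points
k_x, k_e = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(k_x, (200, 3))
y = jnp.maximum(x @ jnp.array([1.0, -0.5, 0.3]) + jax.random.normal(k_e, (200,)), 0.0)
beta, sigma2 = jnp.array([0.8, -0.2, 0.1]), 1.3

auto_beta, auto_sigma2 = jax.grad(loglik, argnums=(0, 1))(beta, sigma2, x, y)
an_beta, an_sigma2 = analytic_grads(beta, sigma2, x, y)
print(auto_beta, an_beta)
print(auto_sigma2, an_sigma2)
```

If the formulas are right, each pair of printed values agrees up to float32 rounding.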

# Implementation

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap
import numpy as np
```

## Data

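```python
key = jax.random.PRNGKey(199)
# independent subkeys, so epsilon, x and the initial beta are not drawn from the same random stream
key_eps, key_x, key_beta = jax.random.split(key, 3)
batch = 10000
input_dim = 5
L = 0  # censoring threshold

true_beta = jnp.array([2.0, 2.0, 3.0, 3.0, 4.0])
true_sigma = 2.0
# epsilon ~ N(0, sigma^2): scale a standard normal by sigma (the std), not by sigma**2
epsilon = jax.random.normal(key_eps, (batch, )) * true_sigma

x = jax.random.normal(key_x, (batch, input_dim))
y_star = jnp.dot(x, true_beta) + epsilon   # latent variable, defined before d uses it
d = (y_star > L).astype(jnp.float32)       # 1 = observed, 0 = censored
y = d * y_star                             # observed y: y* if y* > L, else 0 (L = 0 here)

# initial guesses; sigma is a standard deviation, so start it at a positive value
beta = jax.random.normal(key_beta, (input_dim, ))
sigma = jnp.array(1.0)

print('epsilon mean', epsilon.mean())
print('epsilon std', epsilon.std())
print('y_star std', y_star.std())
print('x_dot_b mean', jnp.dot(x, true_beta).mean())
print('y_star mean', y_star.mean())
```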

```python
y
```

```python
y_star
```

## Function

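```python
from jax.scipy.stats import norm


def cdf(x):
    # standard normal CDF, Phi
    return norm.cdf(x, loc=0, scale=1)

def pdf(x):
    # standard normal PDF, phi
    return norm.pdf(x, loc=0, scale=1)

def linear(x, beta):
    # x_i^T beta for every row of x
    return jnp.dot(x, beta)


def tobit_model_grad(x, y_true, beta, sigma, L=0.0):
    """Gradients of the mean *negative* log-likelihood w.r.t. beta and sigma^2."""
    sigma = jnp.abs(sigma)                       # sigma is a standard deviation
    d = (y_true > L).astype(x.dtype)             # 1 = observed, 0 = censored
    x_dot_beta = linear(x, beta)
    residual = y_true - x_dot_beta
    z = x_dot_beta / sigma                       # phi_i and Phi_i are evaluated at x_i^T beta / sigma
    # inverse Mills ratio phi(z) / (1 - Phi(z)) = phi(z) / Phi(-z), computed in log space for stability
    mills = jnp.exp(norm.logpdf(z) - norm.logcdf(-z))

    # d lnL / d beta, averaged over observations
    d_beta_temp = d * residual - (1 - d) * sigma * mills
    d_beta = (1 / sigma**2) * jnp.dot(d_beta_temp, x) / x.shape[0]

    # d lnL / d sigma^2, averaged over observations (elementwise, not a dot product)
    d_sigma2_temp_1 = d * (-1 / (2 * sigma**2) + residual**2 / (2 * sigma**4))
    d_sigma2_temp_2 = (1 - d) * mills * x_dot_beta / (2 * sigma**3)
    d_sigma2 = jnp.mean(d_sigma2_temp_1 + d_sigma2_temp_2)

    # negate both so that `param - lr * grad` is gradient descent on -lnL
    return -d_beta, -d_sigma2


def tobit_model_train(x, y_true, beta, sigma, lr, esp, max_iter, L=0.0):
    """Gradient descent on -lnL; stops after max_iter steps or once the largest update is below esp."""
    for _ in range(max_iter):
        d_beta, d_sigma2 = tobit_model_grad(x, y_true, beta, sigma, L)
        new_beta = beta - lr * d_beta
        # step on sigma^2, clip it to stay positive, then take the positive square root
        new_sigma = jnp.sqrt(jnp.maximum(sigma**2 - lr * d_sigma2, 1e-8))
        step = jnp.maximum(jnp.max(jnp.abs(new_beta - beta)), jnp.abs(new_sigma - sigma))
        beta, sigma = new_beta, new_sigma
        if step < esp:
            break
    return beta, sigma


tobit_model_grad(x, y, beta, sigma)
```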

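```python
lr = 1e-4  # small step size; raise lr or the iteration count if beta is still far from true_beta
for i in range(1000):
    d_beta, d_sigma2 = tobit_model_grad(x, y, beta, sigma)
    beta = beta - lr * d_beta
    # gradient step on sigma^2, clipped to stay positive before the square root
    sigma = jnp.sqrt(jnp.maximum(sigma**2 - lr * d_sigma2, 1e-8))
    # print(beta, sigma)
```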
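As an alternative to the hand-written gradient descent above, the same MLE can be obtained by handing the negative log-likelihood to a quasi-Newton optimizer. This sketch swaps in BFGS via `jax.scipy.optimize.minimize` and lets JAX compute the gradients. Optimizing $\log\sigma$ instead of $\sigma$ (to keep $\sigma$ positive) and starting from OLS are assumptions made here for convenience; `neg_mean_loglik` is a hypothetical helper name. The true parameters match the Data cell, but the random draws differ, and the data are named `x_sim`/`y_sim` so the notebook's `x` and `y` are not overwritten.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax.scipy.stats import norm

def neg_mean_loglik(params, x, y, L=0.0):
    # params = [beta_1, ..., beta_k, log sigma]
    beta, sigma = params[:-1], jnp.exp(params[-1])
    mu = x @ beta
    ll = jnp.where(y > L,
                   norm.logpdf(y, loc=mu, scale=sigma),   # observed: normal log-density
                   norm.logcdf((L - mu) / sigma))          # censored: ln P(y* < L)
    return -jnp.mean(ll)

# simulated data with the same true parameters as the Data cell
k_x, k_eps = jax.random.split(jax.random.PRNGKey(199))
true_beta = jnp.array([2.0, 2.0, 3.0, 3.0, 4.0])
true_sigma = 2.0
x_sim = jax.random.normal(k_x, (10000, 5))
y_sim = jnp.maximum(x_sim @ true_beta + true_sigma * jax.random.normal(k_eps, (10000,)), 0.0)

# OLS on the censored data is biased, but it is a reasonable starting point
beta0 = jnp.linalg.lstsq(x_sim, y_sim)[0]
log_sigma0 = jnp.log(jnp.std(y_sim - x_sim @ beta0))
res = minimize(neg_mean_loglik, jnp.append(beta0, log_sigma0), args=(x_sim, y_sim), method="BFGS")

beta_hat, sigma_hat = res.x[:-1], jnp.exp(res.x[-1])
print(beta_hat, sigma_hat)
```

With 10,000 observations the printed estimates should be close to `true_beta` and `true_sigma`; the log-sigma parameterization is a common design choice because it turns a constrained problem into an unconstrained one.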

```python
# estimated vs. true parameters
print('beta ', beta, ' true_beta ', true_beta)
print('sigma', sigma, ' true_sigma', true_sigma)
```

```python
jnp.dot(x, beta)
```

```python
y_star
```

```python
cdf(linear(x, beta) / sigma).shape
```