<a href="https://colab.research.google.com/github/xx529/Algorithm/blob/main/Tobit%20Model%20-%20%20Censored%20Normal%20Regression/Tobit%20Model%20-%20%20Censored%20Normal%20Regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tobit Model -  Censored Normal Regression

**关于算法的故事**

在任一给定年份，有相当数量家庭的医疗保险费用支出为0

因此，虽然年度家庭医疗保险费用支出的总体分布散布于一个很大的正数范围内，但在数字0上却相当集中

# Algorithm

**定义**

1. 因变量 $y$ 在正值上大致是连续分布，但也包含一部分以正概率取值为0的观测，对于 $y$ 如下定义，$y^*$ 是一个隐藏的变量，它与0共同决定了观测 $y$

$$y=\left\{
\begin{aligned}
y^* ,& y^* \geqslant 0 \\
- ,& y^* < 0 
\end{aligned}
\right.
$$

2. 隐藏变量 $y^*$ 由一个参数为 $\beta$ 的线性方程和 $\epsilon$，定义如下
2
$$y^* = x^T\beta + \epsilon, \text{ 其中 } \epsilon \sim N(0, \sigma^2)$$
$$$$

3. 需要优化的参数就是 $\beta$ 和 $\sigma$

$$\hat{\theta}= (\hat{\beta}, \hat{\sigma}^2)$$

**目标函数**

1. 定义一个 indicator variable $d$，取 $L=0$ 如下:

$$d=\left\{
\begin{aligned}
1 ,& y > L\\
0 ,& y = L 
\end{aligned}
\right.$$

2. 由定义式可知 $y^* \sim (x^T\beta, \sigma^2)$

$$y^* = x^T\beta + \epsilon \rightarrow y^* \sim (x^T\beta, \sigma^2)$$

3. 当 $y^*<L$，用 CDF概率分布函数，当 $y^*>L$ 可观测时候用 PDF概率密度函数，这里取 $L=0$

$$f(y)=\left\{
\begin{aligned} F^*(0)=1-\Phi(\frac{x^T\beta}{\sigma}), d=0\\
\frac{1}{\sqrt{2\pi\sigma^2}}exp(-\frac{1}{2\sigma^2}(y-x^T\beta)^2), d=1
\end{aligned}
\right.
$$

4. 综合上面两个分段函数就有如下形式，也就是我们通过 MLE 最大化的函数

$$f(y) =[\frac{1}{\sqrt{2\pi\sigma^2}}exp(-\frac{1}{2\sigma^2}(y-x^T\beta)^2)]^d[1-\Phi(\frac{x^T\beta}{\sigma})]^{1-d}$$

**参数梯度**

对以上公式取 $log$ 之后的参数梯度，其中 $\phi_i=\phi(\frac{x_i^T\beta}{\sigma})$，$\Phi_i=\Phi(\frac{x_i^T\beta}{\sigma})$

$$\frac{\partial ln L_N}{\partial \beta}=\sum^N_{i=1}\frac{1}{\sigma^2}(d_i(y_i-x^T\beta)-(1-d_i)\frac{\sigma\phi_i}{1-\Phi_i})x_i$$


$$\frac{\partial ln L_N}{\partial \sigma^2}=\sum^N_{i=1}[d_i(-\frac{1}{2\sigma^2}+\frac{(y_i-x^T\beta)^2}{2\sigma^4})+(1-d_i)\frac{\phi_ix_i^T\beta}{1-\Phi_i} · \frac{1}{2\sigma^3}]$$

**优化方法**

[Proximal Methods](https://github.com/xx529/Algorithm/blob/main/Proximal%20Algorithms%20-%20L1%20Regularization/Proximal%20Algorithms%20-%20L1%20Regularization.ipynb)

# Implementation

In [245]:
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
from jax.lax import cond
import numpy as np

## Data

In [246]:
key = jax.random.PRNGKey(199)
batch = 10000
input_dim = 5
L = 0

true_beta = jnp.array([2.0, 1.0, -2.0, 0.0, 4.0])
true_sigma = 2.0
epsilon = jax.random.normal(key, (batch, )) * (true_sigma**2)

x = jax.random.normal(key, (batch, input_dim))
y_star = jnp.dot(x, true_beta) + epsilon
d = (y_star > L).astype(jnp.float32)
y =  d * y_star

beta = jax.random.normal(key, (input_dim, ))
sigma = jax.random.normal(key)

print('epsilon mean', epsilon.mean())
print('epsilon std', epsilon.std())
print('y_star std', y_star.std())
print('x_dot_b mean', jnp.dot(x, true_beta).mean())
print('y_star mean', y_star.mean())

epsilon mean 0.0054394654
epsilon std 3.9888053
y_star std 6.449384
x_dot_b mean 0.0039303103
y_star mean 0.009369775


In [247]:
y

DeviceArray([-0.       ,  2.393379 , 13.202604 , ...,  1.0419834,
              4.944159 , -0.       ], dtype=float32)

In [248]:
y_star

DeviceArray([-9.858387 ,  2.393379 , 13.202604 , ...,  1.0419834,
              4.944159 , -7.051079 ], dtype=float32)

## Normal Version

In [258]:
def normal_cdf(x):
    return jax.scipy.stats.norm.cdf(x, loc=0, scale=1)

def normal_pdf(x):
    return jax.scipy.stats.norm.pdf(x, loc=0, scale=1)

def linear(x, beta):
    return jnp.dot(x, beta)

def log_mle(x, beta, sigma, y_true):
    d = (y_star > L).astype(jnp.float32)
    temp1 = d * jnp.log(1 / (jnp.sqrt(2 * jnp.pi) * sigma)) + (-1 / (2 * sigma**2)) * (y_true - linear(x, beta))**2
    temp2 = (1 - d) * jnp.log(1 - normal_cdf(linear(x, beta) / sigma) + 0.001)
    return jnp.mean(temp1 + temp2)

def tobit_model_grad(x, y_true, beta, sigma):
    x_dot_beta = linear(x, beta)
    residual = y_true - x_dot_beta
    sigma_square = sigma**2
    cdf = normal_cdf(x_dot_beta / sigma)
    pdf = normal_pdf(x_dot_beta / sigma)
    d = (y_true > L).astype(jnp.float32)
    
    d_beta_temp = d*residual - (1-d) * sigma * pdf / (1 - cdf + 0.01)
    d_beta = - (1 / sigma_square) * jnp.dot(d_beta_temp, x) / x.shape[0]
    
    d_sigma2_temp_1 = d * (-1 / (2 * sigma_square) + (residual**2) / (2 * sigma**4))    
    d_sigma2_temp_2 = 1 / (2 * sigma**3 ) * (1 - d) * (pdf * x_dot_beta) / (1 - cdf + 0.01)
    d_sigma2 = - jnp.mean(d_sigma2_temp_1 + d_sigma2_temp_2)
    
    return d_beta, d_sigma2
    

def tobit_model_train(x, y_true, beta, sigma, lr, esp, max_iter):
    current_mle, current_esp, current_iter = 1, 1, 1

    while current_iter < max_iter and current_esp > esp:
        d_beta, d_sigma2 = jit(tobit_model_grad)(x, y, beta, sigma)

        beta = beta - lr * d_beta
        sigma = sigma - lr * jnp.sqrt(d_sigma2)

        mle = log_mle(x, beta, sigma, y_true)
        current_esp = jnp.abs(mle - current_mle)
        current_mle = mle
        current_iter += 1

        if current_iter % 100 == 0:
            print(beta, sigma, mle)

    return beta, sigma

In [259]:
%%time
lr = 1e-3
esp = 1e-6
max_iter = 30000

opt_beta, opt_sigma = tobit_model_train(x, y, beta, sigma, lr, esp, max_iter)

1
CPU times: user 493 ms, sys: 7.26 ms, total: 500 ms
Wall time: 351 ms


In [251]:
print(opt_beta, opt_sigma)
print(true_beta, true_sigma)

[ 3.041456    0.81613755 -1.6294001  -1.0765293   0.56011343] nan
[ 2.  1. -2.  0.  4.] 2.0


## JAX Version

In [223]:
@jax.partial(vmap, in_axes=(0, 0, None, None, None))
def select_grad(x, y_true, beta, sigma, threshold):
    return jax.lax.cond(
        y_true > threshold,
        lambda _: ( # > L
            - 1 / (sigma**2) * jnp.dot((y_true - linear(x, beta)), x),
            - (- 1 / (2*sigma**2) + (y_true - linear(x, beta))**2 / (2 * sigma**4))
        ), 
        lambda _: ( # < L
             1 / (sigma**2) * jnp.dot((sigma * normal_pdf(linear(x, beta) / sigma) / (1 - normal_cdf(linear(x, beta) / sigma) + 0.01)), x),
            - 1 / (2 * sigma**3) * normal_pdf(linear(x, beta) / sigma) * linear(x, beta) / (1 - normal_cdf(linear(x, beta) / sigma) + 0.01)
        ), 
        operand = None
    )

In [219]:

a, b = jit(select_grad)(x, y, beta, sigma, 0)
jnp.mean(a, axis=0), jnp.mean(b)

(DeviceArray([ 0.40317717, -0.02721015,  0.10851616, -0.33432332,
              -1.1207985 ], dtype=float32),
 DeviceArray(-3.938704, dtype=float32))

In [242]:
lr = 1e-3
esp = 1e-6
max_iter = 30000
# d_beta
# d_sigma2

# x, y_true, beta, sigma, lr, max_iter
# val = x, y_true, beta, sigma, lr, current_iter

def cond_fun(val):
    return val[5] < max_iter

def body_fun(val):
    x, y_true, beta, sigma, lr, current_iter = val
    d_beta, d_sigma2 = select_grad(x, y_true, beta, sigma, 0)
    new_beta = beta - lr * jnp.mean(d_beta, axis=0)
    new_sigma = sigma - lr * jnp.mean(jnp.sqrt(d_sigma2))



    current_iter += 1



    return (x, y_true, new_beta, new_sigma, lr, current_iter)


def jax_model_train(init_val):
    return jax.lax.while_loop(cond_fun, body_fun, init_val)

In [236]:
a = (1,2,3,4,4,6)
*_, b = a

In [244]:
cond_fun((1,2,3,4,5,5))

True

In [238]:
_

6