# Damage model 
The original model is 

\begin{equation}\label{eq: original model}
y^*  = y \cdot I\{c \cdot y > l\} + \alpha \cdot y \cdot I\{c \cdot y < l\}
\end{equation}



# Smooth function

Note that we have the indicator function in the model we might consider the smooth function:
\begin{equation}\label{eq: trans function}
S(x;s) = \frac{1}{1 + \exp(-s \cdot x)}
\end{equation},
where $s$ is the smoothing hyper-parameter to control the smoothness. 

Then the model becomes 
$$
y^*  = y \cdot \frac{1}{1 + \exp(-s\cdot(cy-l))} + \alpha \cdot y \cdot \frac{1}{1 + \exp(-s\cdot(l-cy))}
$$

We have $y \geq y^*$.

The lumber have three groups:

- Group 1: $y <l$, $y^* < l$, i.e., $y^* <y < l$. The lumber pieces are broken blow the proof loading.
- Group 2: $y >l$, $y^* < l$, i.e. $y^*<l<y$. The lumber pieces are broken during the proof loading process. This groups we only knows how many pieces. 

update: this group should be $0<y^*<l<y$. So $F_y(h^{-1}(l)) - F_y(h^{-1}(\max(0,h^{-1}(l)))$
- Group 3: $y >l$, $y^* > l$, i.e. $l<y^*<y$. The lumber pieces survived in the proof-loading. And then we destruct them to test their strength.


# The PDF calculation 

Given $Y \sim N(\mu, \sigma^2)$, $Y^* = h(Y)$. Then the pdf of $Y^*$,
$$
f_{Y^*}(y^*) = f_{Y}(h^{-1}(y^*))|\frac{d}{dy^*}h^{-1}(y^*)|,
$$
where $f_Y()$ is the pdf of $Y$, i.e., normal. 

Following this, We need the numerical function of $h^{-1}(y^*)$, and its numerical gradient $\frac{d}{dy^*}h^{-1}(y^*)$. (The analytical form doesn't seem available.)

# The range of alpha

For the model 

\begin{equation}
y^*  = y \cdot I\{c \cdot y > l\} + \alpha \cdot y \cdot I\{c \cdot y < l\}.
\end{equation}

What happened if $\alpha < c$?

then $\alpha*y < c*y <l $. It means that all damaged pieces are censored. So we don't have damaged pieces. The remaining pieces in group 3 are all undamaged. Then we can only have the range of $\alpha$ but no specific estimate.

In bivarite dataset, we don't have this problem because we have $c*x$ and $\alpha*y$.

In [1]:
import jax
import jaxopt
import jax.numpy as jnp
import pyreadr
import projplot as pjp

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [2]:
@jax.jit

def indicator(x):
    return(jnp.select([x>0,x<=0],[1,0]))

def logit(x):
    return(jnp.log(x/(1-x)))


def expit(x):
    return 1/(1+jnp.exp(-x))

def exp_smooth(x):
    return(jnp.select(
    [x >0, x<=0],[jnp.exp(-1/(x+ 0.001)),0]))

def g_smooth(x):
    return(exp_smooth(x)/(exp_smooth(x) + exp_smooth(1-x)))

def sigmoid(x, s):
    # x = jnp.array(x)
    # a = jnp.array(a)
    return 0.5 * (jnp.tanh(x * s / 2) + 1)

def dmgmodel_ind(y,alpha,l,c):
    return(y*indicator(c*y-l) + alpha*y*indicator(l-c*y))


def dmgmodel_py(y,alpha,l,c,s):
    #return(y*jax.scipy.stats.norm.cdf(c*y-l) + alpha*y*jax.scipy.stats.norm.cdf(l-c*y))
    return(y*sigmoid(c*y-l,s) + alpha*y*sigmoid(l-c*y,s))

    #return(y*g_smooth(c*y-l) + alpha*y*g_smooth(l-c*y))

def dmgmodel_root_py(y,alpha,l,c,s,ystar):
    return(dmgmodel_py(y,alpha,l,c,s) - ystar)



def dmginverse_py(ystar,alpha,l,c,s):
    ystar = jnp.array(ystar)
    bisec = jaxopt.Bisection(
        optimality_fun=dmgmodel_root_py,
        lower = 0,
        upper = 10000,
        check_bracket = False)
    return(bisec.run(alpha = alpha,l = l, c= c ,s = s,ystar = ystar).params)

def dmginvgrad_py(ystar,alpha,l,c,s):
    grad_func = jax.grad(dmginverse_py,0)
    return(jnp.abs(grad_func(ystar,alpha,l,c,s)))

def dmglik_py(ystar,alpha,l,c,s,mu,sigma):
    y =  dmginverse_py(ystar,alpha,l,c,s)
    return(jax.scipy.stats.norm.logpdf(y,loc = mu,scale = sigma)+ 
           jnp.log(dmginvgrad_py(ystar,alpha,l,c,s))
          )


def dmglik_vmap(y_group,alpha,l,c,s,mu,sigma):
    y_group = jnp.array(y_group)
    lik = jax.vmap(lambda y_group: dmglik_py(ystar = y_group,
                                             alpha = alpha,l = l, c= c,s =s,mu = mu, sigma=sigma))(y_group)
    return(jnp.sum(lik))



change $s = 10$ to $s =1$ can make loglik of alpha and c more smooth.

# In the second case, $\alpha <c$

In [3]:
# the orignal sample size 
#N = 30000
N = 300
mu = 48
sigma = 19
l =  32
# alpha = 0.6
# c = 0.65
alpha = 0.65
c = 0.68
s = 10

N_marginal = 139





In [4]:
params_draw = jnp.zeros(shape = (100, 4))
@jax.jit
def negdmglik_jax(theta,y_obs_g1,y_obs_g2,y_obs_g3,y_obs_g4):
    mu = theta[0]
    sigma = theta[1]
    alpha = theta[2]
    c = theta[3]
    lik1 = jnp.sum(jax.scipy.stats.norm.logpdf(y_obs_g1,loc = mu, scale = sigma))
    #lik1 = dmglik_vmap(y_group = y_obs_g1,alpha = alpha,l = l, c = c,s = s, mu = mu, sigma = sigma)
#     lik2 = y_obs_g2*jnp.log(
#         jax.scipy.stats.norm.cdf(dmginverse_py(l,alpha,l,c,s), loc=mu, scale=sigma) - 
#         jax.scipy.stats.norm.cdf(l, loc=mu, scale=sigma)
#     )
    lik2 = y_obs_g2*jnp.log(
    jax.scipy.stats.norm.cdf(dmginverse_py(l,alpha,l,c,s), loc=mu, scale=sigma) - 
    jax.scipy.stats.norm.cdf(dmginverse_py(jnp.maximum(0.1,dmgmodel_py(l,alpha,l,c,s)),alpha,l,c,s), loc=mu, scale=sigma)
    )
    lik3 = dmglik_vmap(y_group = y_obs_g3,alpha = alpha,l = l, c = c,s = s, mu = mu, sigma = sigma)
    lik4 = jnp.sum(jax.scipy.stats.norm.logpdf(y_obs_g4,loc = mu, scale = sigma))

    return(-lik1 - lik2-lik3-lik4)
#params_draw

In [5]:
# data generation 
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, num=N)


for ii in range(100):

    y = sigma*jax.random.normal(subkeys[ii], shape=(N, )) + mu

    #y = y[y>0]

    # g1
    y_obs_g1  = y[y<l]


    # g23_star
    g23 = y[y>l]
    g23_star = jax.vmap(lambda y: dmgmodel_py(y,alpha,l,c,s))(g23)


    # g3
    y_obs_g3 = g23_star[g23_star > l]

    # g2
    y_obs_g2 = N - len(y_obs_g1) - len(y_obs_g3)

    #y_obs_g2 = jnp.shape(y)[0] - len(y_obs_g1) - len(y_obs_g3)

    # g4, is all marginal data
    y_obs_g4 = sigma*jax.random.normal(subkeys[0], shape=(N_marginal, )) + mu
    
    theta0 = jnp.array([mu,sigma,alpha,c])

    solver = jaxopt.ScipyMinimize(method = "Nelder-Mead",fun=negdmglik_jax)
    res = solver.run(theta0,y_obs_g1,y_obs_g2,y_obs_g3,y_obs_g4)

    # solver = jaxopt.BFGS(fun=negdmglik_jax)
    # res = solver.run(theta0)

    params_draw = params_draw.at[ii,:].set(res.params)
 

SyntaxError: invalid syntax (3911772289.py, line 6)

In [None]:
# # negdmglik_jax(theta0)
# #dmglik_vmap(y_group = y_obs_g3,alpha = alpha,l = l, c = c,s = s, mu = mu, sigma = sigma)

# #dmglik_py(y_obs_g3[0][0],alpha,l,c,s,mu,sigma)

# #dmglik_vmap(y_obs_g3,alpha,l,c,s,mu,sigma)

# s= 10
# l = l 
# y_obs_g1 = y_obs_g1
# y_obs_g2 = y_obs_g2
# y_obs_g3 = y_obs_g3
# y_obs_g4 = y_obs_g4
# @jax.jit
# def negdmglik_jax(theta,y_obs_g1,y_obs_g2,y_obs_g3,y_obs_g4):
#     mu = theta[0]
#     sigma = theta[1]
#     alpha = theta[2]
#     c = theta[3]
#     lik1 = jnp.sum(jax.scipy.stats.norm.logpdf(y_obs_g1,loc = mu, scale = sigma))
#     #lik1 = dmglik_vmap(y_group = y_obs_g1,alpha = alpha,l = l, c = c,s = s, mu = mu, sigma = sigma)
# #     lik2 = y_obs_g2*jnp.log(
# #         jax.scipy.stats.norm.cdf(dmginverse_py(l,alpha,l,c,s), loc=mu, scale=sigma) - 
# #         jax.scipy.stats.norm.cdf(l, loc=mu, scale=sigma)
# #     )
#     lik2 = y_obs_g2*jnp.log(
#     jax.scipy.stats.norm.cdf(dmginverse_py(l,alpha,l,c,s), loc=mu, scale=sigma) - 
#     jax.scipy.stats.norm.cdf(dmginverse_py(jnp.maximum(0.1,dmgmodel_py(l,alpha,l,c,s)),alpha,l,c,s), loc=mu, scale=sigma)
#     )
#     lik3 = dmglik_vmap(y_group = y_obs_g3,alpha = alpha,l = l, c = c,s = s, mu = mu, sigma = sigma)
#     lik4 = jnp.sum(jax.scipy.stats.norm.logpdf(y_obs_g4,loc = mu, scale = sigma))

#     return(-lik1 - lik2-lik3-lik4)

# theta0 = jnp.array([mu,sigma,alpha,c])



# # negdmglik_jax(theta0)

In [None]:
# theta0 = jnp.array([mu,sigma,alpha,c])

# solver = jaxopt.ScipyMinimize(method = "Nelder-Mead",fun=negdmglik_jax)
# res = solver.run(theta0,y_obs_g1,y_obs_g2,y_obs_g3,y_obs_g4)

# # solver = jaxopt.BFGS(fun=negdmglik_jax)
# # res = solver.run(theta0)

# res.params,res.state
# theta0