<a href="https://colab.research.google.com/github/xx529/ML-Algorithm-Implementation/blob/main/Proximal%20Algorithms%20-%20L1%20Regularization/Proximal%20Algorithms%20-%20L1%20Regularization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Proximal Algorithms - L1 Regularization

# Algorithm explanation

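1. Goal: fit the logistic-regression parameters $\theta$ with L1 regularization by minimizing the mean cross-entropy loss plus an L1 penalty:

$$\underset{\theta}{\text{min}}\; -\frac{1}{N}\sum_{i=1}^N\left[y_i\log\sigma(x_i^T\theta)+(1-y_i)\log\left(1-\sigma(x_i^T\theta)\right)\right] + \lambda||\theta||_1$$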


2. Outline of the algorithm

    1. Solve the problem with the proximal operator, defined as
        
        $$\text{prox}_f(v) = \underset{x}{\text{argmin}}(f(x)+\frac{1}{2}||x-v||^2_2)$$
        
    2. For a closed, proper, convex $f$, the proximal operator has a useful fixed-point property:
        
        $$ x^* = \underset{x}{\text{argmin}}\, f(x) \quad\text{if and only if}\quad x^* = \text{prox}_f(x^*)$$
        
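    3. Suppose an objective $z$ is the sum of a differentiable function $f$ and a non-differentiable function $g$. The proximal gradient method minimizes $z$ with the following update, where $\lambda^k$ is the step size at iteration $k$: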
        
        $$z(x) = f(x) + g(x)$$
        
        $$x^{k+1} := \text{prox}_{\lambda^kg}(x^k - \lambda^k \nabla f(x^k))$$

    4. Iterating this update (for convex $f$ and $g$ and a sufficiently small step size) solves
        $$\theta^*=\underset{\theta}{\text{argmin}}f(\theta)+g(\theta)$$
        
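    5. In L1-regularized parameter estimation, the variable being optimized is the parameter vector $\theta$, and $\lambda^k$ is the learning rate (step size). Each iteration takes a gradient step on the smooth loss $f$ and then applies the proximal operator of $g$:
        
        $$w:=\theta^k-\lambda^k\nabla f(\theta^k)$$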
        
        $$\theta^{k+1} = \text{prox}_{{\lambda^k}g}(w)$$

    6. Setting $g(\theta)=\lambda||\theta||_1$, where $\lambda$ is the L1 penalty coefficient (`penalty` in the code), and substituting into the definition of the proximal operator gives
        
        $$\theta^{k+1}=\text{prox}_{\lambda^kg}(w)=\underset{\theta}{\text{argmin}}(\lambda^k \lambda||\theta||_1+\frac{1}{2}||\theta-w||_2^2)$$
        
    7. This problem separates across coordinates, and its solution is the soft-thresholding operator. Here $\lambda$ is the L1 penalty coefficient, $\lambda^k$ is the step size, and $[\theta^{k+1}]_i$ is the $i$-th entry of the vector $\theta^{k+1}$:
        
        $$[\theta^{k+1}]_i=\begin{cases} w_i-\lambda\lambda^k, & \text{if } w_i>\lambda\lambda^k\\ 0, & \text{if } -\lambda\lambda^k\le w_i\le\lambda\lambda^k\\ w_i+\lambda\lambda^k, & \text{if } w_i<-\lambda\lambda^k\end{cases}$$

3. A detailed proof is given in the attachment.
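As a quick numerical check of the soft-thresholding formula above, the sketch below compares the closed form against a brute-force minimization of $t|x|+\frac{1}{2}(x-w_i)^2$ over a fine grid, where $t$ stands for $\lambda\lambda^k$. It assumes only `jax.numpy`; the helper names `soft_threshold_closed_form` and `prox_l1_brute_force`, and the test values of `t` and `w`, are hypothetical choices for illustration.

```python
import jax.numpy as jnp

def soft_threshold_closed_form(w, t):
    # Piecewise formula: w_i - t if w_i > t, w_i + t if w_i < -t, else 0
    return jnp.where(w > t, w - t, jnp.where(w < -t, w + t, 0.0))

def prox_l1_brute_force(w, t, grid):
    # For each w_i, evaluate t*|x| + 0.5*(x - w_i)^2 on a grid of x values
    # and return the grid point with the smallest objective
    objective = t * jnp.abs(grid)[None, :] + 0.5 * (grid[None, :] - w[:, None]) ** 2
    return grid[jnp.argmin(objective, axis=1)]

t = 0.5                                    # plays the role of lambda * lambda^k
w = jnp.array([2.0, 0.3, -0.2, -1.5, 0.5])
grid = jnp.linspace(-3.0, 3.0, 60001)      # grid spacing 1e-4

closed = soft_threshold_closed_form(w, t)
brute = prox_l1_brute_force(w, t, grid)
print(closed)
print(brute)
```

Because the objective separates by coordinate, a one-dimensional grid search per entry is enough; the two printed vectors should agree up to roughly the grid spacing, and every entry with $|w_i|\le t$ lands exactly on 0 in the closed form.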

# Test Data

```python
import jax
import jax.numpy as jnp
import jax.nn as jnn
from jax import grad, vmap
```

```python
key = jax.random.PRNGKey(0)

# Sparse ground-truth parameters: only features 0, 1 and 5 matter
true_theta = jnp.array([2.0, 2.0, 0.0, 0.0, 0.0, 5.0])
x = jax.random.normal(key, (1000, len(true_theta)))
theta = jax.random.normal(key, (len(true_theta), ))
# Labels: 1 where sigmoid(x @ true_theta) >= 0.5, i.e. where the linear predictor is non-negative
y = (jnn.sigmoid(jnp.dot(x, true_theta)) >= 0.5).astype(jnp.float32)

print('generated theta', true_theta)
print('initial theta', theta)
```

```
generated theta [2. 2. 0. 0. 0. 5.]
initial theta [ 0.18784384 -1.2833426   0.6494181   1.2490593   0.24447003 -0.11744965]
```


# Code Implementation

## Normal Version

```python
def predict(theta, x):
    # P(y = 1 | x) for each row of x
    return jnn.sigmoid(jnp.dot(x, theta))

def get_lr_grad(theta, x, y_true):
    # Gradient of the mean cross-entropy: X^T (sigmoid(X theta) - y) / N
    y_pre = predict(theta, x)
    return jnp.dot(x.T, y_pre - y_true) / x.shape[0]

def get_lr_loss(theta, x, y_true):
    # Mean cross-entropy loss, without the L1 term
    y_pre = predict(theta, x)
    return -jnp.mean(y_true * jnp.log(y_pre) + (1 - y_true) * jnp.log(1 - y_pre))


def soft_threshold(vec_theta, threshold):
    # Proximal operator of threshold * ||.||_1, applied entry by entry
    new_vec_theta = []

    for theta in vec_theta:
        if theta > threshold:
            new_vec_theta.append(theta - threshold)
        elif theta < -threshold:
            new_vec_theta.append(theta + threshold)
        else:
            new_vec_theta.append(0.0)

    return jnp.array(new_vec_theta)


def proximal_method_update(theta, x, y_true, lr, penalty):
    # Gradient step on the smooth loss, then the L1 proximal step
    w = theta - lr * get_lr_grad(theta, x, y_true)
    return soft_threshold(w, lr * penalty)


def model_train(theta, x, y_true, lr, penalty, esp, max_iter):
    # Stop when the change in the regularized loss falls below esp,
    # or after max_iter iterations
    current_iter = 1
    current_esp = 1
    current_loss = jnp.inf

    while max_iter > current_iter and current_esp > esp:

        new_loss = get_lr_loss(theta, x, y_true) + penalty * jnp.linalg.norm(theta, 1)
        new_theta = proximal_method_update(theta, x, y_true, lr, penalty)
        current_esp = jnp.abs(new_loss - current_loss)

        theta = new_theta
        current_loss = new_loss
        current_iter += 1

        # print('{}  loss: {:.4f} theta: {}'.format(current_iter-1, current_loss, theta))

    return theta
```

```python
%%time
lr = 0.05
penalty = 0.1
esp = 1e-4
max_iter = 3000

proximal_method_theta = model_train(theta, x, y, lr, penalty, esp, max_iter)
proximal_method_theta
```

```
DeviceArray([0.20016596, 0.12244455, 0.        , 0.        , 0.        ,
             1.0061318 ], dtype=float32)
```

## Jax Version

```python
import functools

def jax_loss(theta, x, y_true):
    # Mean cross-entropy loss, without the L1 term
    y_pre = jnn.sigmoid(jnp.dot(x, theta))
    return -jnp.mean(y_true * jnp.log(y_pre) + (1 - y_true) * jnp.log(1 - y_pre))

def jax_lr_grad(theta, x, y_true):
    # Gradient of the mean cross-entropy: X^T (sigmoid(X theta) - y) / N
    y_pre = jnn.sigmoid(jnp.dot(x, theta))
    return jnp.dot(x.T, -(y_true - y_pre)) / x.shape[0]

# vmap maps over the entries of w (axis 0); the threshold is shared
@functools.partial(vmap, in_axes=(0, None))
def jax_soft_threshold(w, threshold):
    # Scalar soft-thresholding written with lax.cond; under vmap the
    # conditions are evaluated entry by entry
    return jax.lax.cond(
        w > threshold,
        lambda _: w - threshold,
        lambda _: jax.lax.cond(
            w < -threshold,
            lambda _: w + threshold,
            lambda _: jnp.zeros_like(w),
            None,
        ),
        None,
    )

def jax_proximal_method_update(theta, x, y_true, lr, penalty):
    # Gradient step on the smooth loss, then the L1 proximal step
    w = theta - lr * jax_lr_grad(theta, x, y_true)
    return jax_soft_threshold(w, lr * penalty)

def jax_model_train(theta, x, y_true, lr, penalty, esp, max_iter):
    current_iter = 1
    current_esp = 1
    current_loss = jnp.inf

    while max_iter > current_iter and current_esp > esp:

        new_loss = jax_loss(theta, x, y_true) + penalty * jnp.linalg.norm(theta, 1)
        new_theta = jax_proximal_method_update(theta, x, y_true, lr, penalty)
        current_esp = jnp.abs(new_loss - current_loss)

        theta = new_theta
        current_loss = new_loss
        current_iter += 1
        # print('{}  loss: {:.4f} theta: {}'.format(current_iter-1, current_loss, theta))

    return theta
```
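The cell above imports `grad` but derives the gradient by hand, and runs each update eagerly. A minimal sketch of an alternative, assuming a recent JAX release: take the gradient of the smooth loss with `jax.grad`, replace the nested `lax.cond` with `jnp.where`, and compile the whole proximal step with `jax.jit`. The names `smooth_loss`, `prox_grad_step`, `x_demo`, `y_demo` and `theta_demo` are hypothetical, and the data are regenerated with split keys so the block runs on its own, so its numbers will not match the cells above.

```python
import jax
import jax.numpy as jnp
import jax.nn as jnn

def smooth_loss(theta, x, y_true):
    # Mean cross-entropy written with log_sigmoid, which stays finite when
    # sigmoid saturates (log(1 - sigmoid(z)) == log_sigmoid(-z))
    logits = x @ theta
    return -jnp.mean(y_true * jnn.log_sigmoid(logits)
                     + (1 - y_true) * jnn.log_sigmoid(-logits))

@jax.jit
def prox_grad_step(theta, x, y_true, lr, penalty):
    # Gradient step on the smooth loss, with the gradient from autodiff
    w = theta - lr * jax.grad(smooth_loss)(theta, x, y_true)
    t = lr * penalty
    # Soft-thresholding applied to the whole vector at once
    return jnp.where(w > t, w - t, jnp.where(w < -t, w + t, 0.0))

# Hypothetical stand-alone data, generated like the Test Data cell but with split keys
key_x, key_theta = jax.random.split(jax.random.PRNGKey(0))
true_theta_demo = jnp.array([2.0, 2.0, 0.0, 0.0, 0.0, 5.0])
x_demo = jax.random.normal(key_x, (1000, 6))
y_demo = (jnn.sigmoid(x_demo @ true_theta_demo) >= 0.5).astype(jnp.float32)
theta_demo = jax.random.normal(key_theta, (6,))

# The autodiff gradient should agree with the hand-derived X^T (sigmoid(X theta) - y) / N
manual_grad = x_demo.T @ (jnn.sigmoid(x_demo @ theta_demo) - y_demo) / x_demo.shape[0]
auto_grad = jax.grad(smooth_loss)(theta_demo, x_demo, y_demo)
print(jnp.max(jnp.abs(manual_grad - auto_grad)))

# Same step size and penalty as the cells above, fixed iteration count
for _ in range(3000):
    theta_demo = prox_grad_step(theta_demo, x_demo, y_demo, 0.05, 0.1)
print(theta_demo)
```

`jnp.where` evaluates both branches and selects per entry, which is what `vmap` turns the nested `lax.cond` into anyway, so the vectorized form mainly removes the per-scalar boilerplate.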

```python
%%time
lr = 0.05
penalty = 0.1
esp = 1e-4
max_iter = 100000

jax_proximal_method_theta = jax_model_train(theta, x, y, lr, penalty, esp, max_iter)
jax_proximal_method_theta
```

```
DeviceArray([0.20016596, 0.12244455, 0.        , 0.        , 0.        ,
             1.0061318 ], dtype=float32)
```
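The fixed-point property from the algorithm outline gives a convergence check that does not rely on the loss-change stopping rule: at a minimizer, one more proximal step leaves $\theta$ unchanged. For the L1 penalty this is equivalent to $\nabla_i f(\theta^*)=-\lambda\,\text{sign}(\theta^*_i)$ when $\theta^*_i\neq0$ and $|\nabla_i f(\theta^*)|\le\lambda$ when $\theta^*_i=0$. The sketch below checks both conditions after a fixed number of iterations run inside `jax.lax.fori_loop`; the helper names, the regenerated data, and the iteration count of 20000 are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import jax.nn as jnn

LR, PENALTY, N_ITERS = 0.05, 0.1, 20000  # assumed settings for this sketch

def grad_f(theta, x, y_true):
    # Gradient of the mean cross-entropy: X^T (sigmoid(X theta) - y) / N
    return x.T @ (jnn.sigmoid(x @ theta) - y_true) / x.shape[0]

def prox_step(theta, x, y_true):
    # One proximal gradient step: gradient step, then soft-thresholding
    w = theta - LR * grad_f(theta, x, y_true)
    t = LR * PENALTY
    return jnp.where(w > t, w - t, jnp.where(w < -t, w + t, 0.0))

@jax.jit
def run_fixed_iters(theta, x, y_true):
    # Run N_ITERS steps without early stopping, compiled as a single loop
    return jax.lax.fori_loop(0, N_ITERS, lambda _, th: prox_step(th, x, y_true), theta)

# Hypothetical stand-alone data, generated like the Test Data cell but with split keys
key_x, key_theta = jax.random.split(jax.random.PRNGKey(0))
x_demo = jax.random.normal(key_x, (1000, 6))
true_theta_demo = jnp.array([2.0, 2.0, 0.0, 0.0, 0.0, 5.0])
y_demo = (jnn.sigmoid(x_demo @ true_theta_demo) >= 0.5).astype(jnp.float32)
theta_hat = run_fixed_iters(jax.random.normal(key_theta, (6,)), x_demo, y_demo)

# 1) Fixed-point residual: theta* should equal prox(theta* - LR * grad f(theta*))
residual = jnp.max(jnp.abs(prox_step(theta_hat, x_demo, y_demo) - theta_hat))

# 2) L1 optimality conditions, checked coordinate by coordinate
g = grad_f(theta_hat, x_demo, y_demo)
kkt_gap = jnp.where(theta_hat != 0,
                    jnp.abs(g + PENALTY * jnp.sign(theta_hat)),  # nonzero: g_i = -penalty * sign
                    jnp.maximum(jnp.abs(g) - PENALTY, 0.0))       # zero: |g_i| <= penalty
print(theta_hat, residual, jnp.max(kkt_gap))
```

Both printed gaps should be near zero at convergence; a large residual would suggest that the stopping rule fired before the iterates reached the minimizer.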