In [None]:
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
import gin
from tqdm.auto import tqdm
import pandas as pd
import seaborn as sns

In [None]:
# Interactive mode lets cells that define @gin.configurable functions be
# re-run in the notebook; without it gin refuses to register a configurable
# whose name already exists.
gin.enter_interactive_mode()

# Regression result with $l_2$ regularization and $l_1$ fit loss

We solve $L=\|z-z_0\|_1+\mu\|z-1\|^2_2\to\min\limits_{z}$. The first term mimics the "common to model" loss term from the Tournesol loss, and the second term models the "common to 1" loss.

The loss is a sum over coordinates, so each coordinate can be minimized independently. The subdifferential of the 1-dimensional loss function is
$$
\partial L=\begin{cases}
1+2\mu(z-1),&z>z_0\\
-1+2\mu(z-1),&z<z_0\\
[-1,1]+2\mu(z-1),&z=z_0
\end{cases}
$$

The optimality condition $0\in \partial L$ gives us

$$
z=\begin{cases}
z_0,&\text{if } z_0\in 1+\frac{1}{2\mu}[-1,1]\\
1-\frac{1}{2\mu},&\text{if } z_0<1-\frac{1}{2\mu}\\
1+\frac{1}{2\mu},&\text{if } z_0>1+\frac{1}{2\mu}
\end{cases}
$$

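This is equivalent to clamping $z_0$ to the interval $\left[1-\frac{1}{2\mu},\ 1+\frac{1}{2\mu}\right]$:
$$
z=\max\left(1-\frac{1}{2\mu}, \min\left[1+\frac{1}{2\mu}, z_0\right]\right)=\operatorname{clamp}\left(z_0,\ \text{min}=1-\frac{1}{2\mu},\ \text{max}=1+\frac{1}{2\mu}\right)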
$$

In [None]:
# Problem size: number of coordinates of z and z0
n = 100
# Fit targets: n evenly spaced values covering [-10, 10]
z0 = np.linspace(start=-10, stop=10, num=n)

In [None]:
# Visualise the fit targets: value of z0 for each coordinate index.
plt.plot(z0)
plt.xlabel('Dimension')
plt.ylabel('Value z_0')
plt.title('Fit targets z_0')
# Ending on plt.show() keeps the Text(...) repr of the last call out of the
# cell output.
plt.show()

In [None]:
def get_z():
    """Create the optimisation variable.

    Returns:
        A new tf.Variable holding ``n`` zeros, where ``n`` is the
        notebook-level dimensionality constant.
    """
    return tf.Variable(tf.zeros(n))

In [None]:
@gin.configurable
def loss(mu, z, z0):
    """Evaluate |z - z0|_1 + mu * |z - 1|_2^2 together with its two parts.

    Args:
        mu: weight applied to the quadratic term.
        z: current estimate.
        z0: fit target.

    Returns:
        dict with keys 'fit' (l1 distance to z0), 'reg' (unweighted squared
        l2 distance to 1) and 'total' (fit + mu * reg).
    """
    l1_to_target = tf.reduce_sum(tf.abs(z - z0))
    sq_to_one = tf.reduce_sum(tf.square(z - 1))
    return dict(
        fit=l1_to_target,
        reg=sq_to_one,
        total=l1_to_target + mu * sq_to_one,
    )

In [None]:
def opt_step(z, opt):
    """Run one gradient step on ``z`` and report the losses.

    ``loss`` is called with ``z`` only, so its other arguments (``mu`` and
    ``z0``) must already be bound through gin.

    Args:
        z: trainable tf.Variable; updated in place by the optimizer.
        opt: optimizer exposing ``apply_gradients``.

    Returns:
        dict of numpy values: the entries returned by ``loss`` (evaluated
        before the update) plus 'grad_norm', the l2 norm of the gradient.
    """
    z_vars = [z]

    with tf.GradientTape() as tape:
        losses = loss(z=z)
        loss_total = losses['total']

    grads = tape.gradient(loss_total, z_vars)
    opt.apply_gradients(zip(grads, z_vars))

    # global_norm accepts a list of tensors directly; tf.linalg.norm on the
    # list only worked by implicitly stacking it into a single tensor.
    losses['grad_norm'] = tf.linalg.global_norm(grads)

    return {name: value.numpy() for name, value in losses.items()}

In [None]:
def plot_z_z0(z, z0, mu, show=True):
    """Draw z as a function of z0 on the current axes.

    Args:
        z: values for the vertical axis.
        z0: values for the horizontal axis.
        mu: shown in the title, rounded to two decimals.
        show: when truthy, call plt.show() after drawing.
    """
    axes = plt.gca()
    axes.plot(z0, z)
    axes.set_xlabel('z0')
    axes.set_ylabel('z')
    axes.set_title(f'z vs z0, mu={round(mu, 2)}')
    if show:
        plt.show()

In [None]:
def experiment(z0, mu=0, epochs=25000):
    """Minimise the loss for one value of mu and plot the outcome.

    Binds ``loss.z0`` and ``loss.mu`` through gin (the bindings stay in place
    after the call), runs ``epochs`` Adam steps starting from ``get_z()``,
    then plots every recorded quantity on a log scale and the final z
    against z0.

    Args:
        z0: fit targets; must have the same shape as the variable returned
            by ``get_z()``.
        mu: weight of the regularisation term.
        epochs: number of optimizer steps.

    Returns:
        dict with 'z' (final values as a numpy array) and 'losses'
        (DataFrame with one row per step).

    Raises:
        ValueError: if z0 and the trainable variable differ in shape.
    """
    # learnable parameter
    z = get_z()

    # get_z() sizes z from the notebook-level n, not from z0: check up front
    # instead of failing with a shape error during training or plotting.
    z_shape = tuple(z.shape.as_list())
    if np.shape(z0) != z_shape:
        raise ValueError(
            f'z0 has shape {np.shape(z0)}, expected {z_shape} to match z')

    opt = tf.optimizers.Adam()
    gin.bind_parameter('loss.z0', z0)
    gin.bind_parameter('loss.mu', mu)

    losses = [opt_step(z, opt) for _ in tqdm(range(epochs))]
    df = pd.DataFrame(losses)

    # one log-scale panel per recorded quantity
    plt.figure(figsize=(13, 5))
    n_plots = len(df.columns)
    for i, col in enumerate(sorted(df.columns), 1):
        plt.subplot(1, n_plots, i)
        plt.plot(df[col])
        plt.yscale('log')
        plt.title(col)
    plt.show()

    # plot_z_z0 shows the figure itself (show=True by default), so no extra
    # plt.show() is needed here.
    plot_z_z0(z.numpy(), z0, mu)

    return {'z': np.array(z.numpy()), 'losses': df}

In [None]:
# Quick smoke test: a handful of steps to check the pipeline runs end to end.
# The trailing semicolon keeps the returned dict (the full z array and the
# loss table) out of the cell output.
experiment(z0, mu=0.1, epochs=10);

## Trying different values of $\mu$

In [None]:
# Regularisation strengths to sweep: 10 log-spaced values from 0.02 to 1
mus = np.logspace(start=np.log10(0.02), stop=np.log10(1), num=10)
# Container for the per-mu experiment outputs
results = []

In [None]:
# Display the mu values used in the sweep
mus

In [None]:
# Build the list in a single expression so that re-running this cell replaces
# the results instead of appending a second sweep (which would misalign
# results with mus in the plots below).
results = [experiment(z0, mu=mu, epochs=25000) for mu in tqdm(mus)]

In [None]:
# Learned z as a function of z0, one curve per regularisation strength
for mu, result in zip(mus, results):
    # short labels: the raw logspace values print with ~17 digits
    plt.plot(z0, result['z'], label=f'mu={mu:.3g}')
plt.xlabel('z0')
plt.ylabel('z')
plt.title('Learned z vs z0 for different mu')
plt.legend()
plt.show()

In [None]:
# Rows: one per mu (same order as mus); columns: coordinates of z
ax = sns.heatmap(np.array([r['z'] for r in results]),
                 yticklabels=[f'{mu:.3g}' for mu in mus])
ax.set_xlabel('Dimension')
ax.set_ylabel('mu')
ax.set_title('Learned z for each mu')
plt.show()

In [None]:
# Compare the optimisation result with the closed-form solution
# z = clamp(z0, 1 - 1/(2 mu), 1 + 1/(2 mu)) derived at the top of the notebook.
# Solid lines: learned z; dashed lines of the same colour: closed form.
for mu, result in zip(mus, results):
    learned, = plt.plot(z0, result['z'], label=f'mu={mu:.3g}')
    half_width = 1 / (2 * mu)
    plt.plot(z0, np.clip(z0, 1 - half_width, 1 + half_width),
             linestyle='--', color=learned.get_color())
plt.xlabel('z0')
plt.ylabel('z')
plt.title('Learned z (solid) vs closed-form clamp (dashed)')
plt.legend()
plt.show()

In [None]:
# Leftover interactive help lookup for sns.heatmap removed: it is not part of
# the analysis and only dumps the docstring on Run All.