In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
# Global plotting defaults for every figure in this notebook
plt.rcParams.update({"figure.figsize": (8, 8), "font.size": 14})

# Fixed global seed: the sampling cells below draw from np.random, so this
# makes Restart & Run All reproduce the printed outputs.
np.random.seed(42)

# The perceptron

$${\bf y} = \operatorname{sgn} (F {\bf x} + {\boldsymbol \xi}), \qquad \xi_\mu \sim \mathcal{N}(0, \sigma^2)$$

$$\left\{
\begin{aligned}
    &P_0(x_i) = \frac12 \, \delta(x_i - 1) + \frac12 \, \delta (x_i + 1), \\
    &P(y_\mu \mid {\bf x}) = \frac{1}{2} \operatorname{erfc} \Big( -\frac{y_\mu}{\sqrt{2 \sigma^2}} \, {\bf F}_\mu \cdot {\bf x} \Big).
\end{aligned}
\right.
$$

In [2]:
def sample_instance(size_x, rows_to_columns, var_noise, rng=None):
    """Samples F from P(F) and {x, y} from P(x, y | F).

    Draws a Rademacher vector x0 (entries +1/-1 with equal probability), a
    matrix F with i.i.d. N(0, 1/size_x) entries, and labels
    y = sgn(F x0 + noise) with noise ~ N(0, var_noise).

    Parameters
    ----------
    size_x : int
        Dimension N of x0.
    rows_to_columns : float
        Ratio alpha = M / N; the number of labels is M = ceil(alpha * N).
    var_noise : float
        Variance of the Gaussian noise added before taking the sign. Must be
        non-negative.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. ``None`` (default) uses the global
        ``np.random`` state, so results follow ``np.random.seed``.

    Returns
    -------
    x0 : ndarray, shape (size_x,)
        Entries in {+1, -1}.
    F : ndarray, shape (M, size_x)
    y : ndarray, shape (M,)
        ``np.sign`` of the noisy projections (+1/-1 almost surely).

    Raises
    ------
    ValueError
        If ``var_noise`` is negative (its square root would be NaN).
    """
    if var_noise < 0:
        raise ValueError("var_noise must be non-negative")
    if rng is None:
        # Module-level functions draw from the global legacy RandomState
        rng = np.random

    # Some pre-processing
    size_y = int(np.ceil(rows_to_columns * size_x))

    # Sample x from P_0(x)
    x0 = rng.choice([+1, -1], size_x)

    # Generate F and y = sgn(Fx + noise). standard_normal(shape) draws the
    # same stream as the legacy randn(*shape), so rng=None is unchanged.
    F = rng.standard_normal((size_y, size_x)) / np.sqrt(size_x)
    noise = np.sqrt(var_noise) * rng.standard_normal(size_y)
    y = np.sign(F.dot(x0) + noise)

    return x0, F, y

In [3]:
x, F, y = sample_instance(size_x=2000, rows_to_columns=1.6, var_noise=1e-10)  # N = 2000, alpha = 1.6, near-noiseless labels

In [4]:
def iterate_gamp(F, y, var_noise,
                 x0=None, max_iter=100, tol=1e-7, verbose=1):
    """Iterates GAMP to solve y = sign(Fx), w/ x Rademacher.

    Relies on ``channel(y, w, v, var_noise) -> (g, g')`` and
    ``prior(A, B) -> (a, c)`` being defined in the notebook.

    Parameters
    ----------
    F : ndarray, shape (M, N)
        Sensing matrix.
    y : ndarray, shape (M,)
        Observed labels.
    var_noise : float
        Noise variance, forwarded to ``channel``.
    x0 : ndarray, shape (N,), optional
        Ground truth. When given, the MSE of the estimate is recorded each
        iteration and the loop also stops once the MSE is exactly zero.
    max_iter : int
        Maximum number of iterations.
    tol : float
        Stop when the mean absolute change of the estimate ``a`` drops below it.
    verbose : int
        If truthy, print one status line per iteration.

    Returns
    -------
    mses : ndarray, shape (max_iter,)
        MSE at each iteration (identically zero when ``x0`` is None).
        Entries after the stopping iteration are left at zero.
    """
    sqrF = F * F
    size_y, size_x = F.shape

    # Estimate a of x and its variance c; output-side function g
    a = np.zeros(size_x)
    c = np.ones(size_x)
    g = np.zeros(size_y)

    mses = np.zeros(max_iter)
    for t in range(max_iter):
        # Iterate w and v (w carries the -v * g correction from the previous
        # step), then compute g and g'
        v = sqrF.dot(c)
        w = F.dot(a) - v * g
        g, dg = channel(y, w, v, var_noise)

        # Iterate A and B, and compute a and c. prior returns fresh arrays,
        # so keeping a reference to the previous a is enough.
        a_old = a
        A = -sqrF.T.dot(dg)
        B = F.T.dot(g) + A * a_old
        a, c = prior(A, B)

        # Compute metrics
        diff = np.mean(np.abs(a - a_old))
        if x0 is not None:
            mses[t] = np.mean((a - x0) ** 2)

        # Print iteration status on screen
        if verbose:
            print("t = %d, diff = %g; mse = %g" % (t, diff, mses[t]))

        # Check for convergence. Exact recovery is only meaningful with ground
        # truth: without x0 the MSE is always 0 and must not stop the loop.
        if diff < tol or (x0 is not None and mses[t] == 0):
            break

    return mses

In [5]:
from scipy.special import erfcx

def prior(A, B):
    """Compute f and f' for Rademacher prior.

    Returns a = tanh(B) and c = da/dB = 1 - tanh(B)**2. The quadratic field
    A is accepted for interface compatibility but is not used: with
    x = +1/-1, x**2 == 1 and the A-term is constant.
    """
    mean = np.tanh(B)
    return mean, 1 - mean ** 2

def channel(y, w, v, var_noise):
    """Compute g and g' for probit channel.

    With Z = erfc(phi) / 2 and phi = -y w / sqrt(2 (v + var_noise)),
    g = d log Z / dw and dg = dg / dw. Uses erfcx(t) = exp(t**2) erfc(t), so
    the exp(-phi**2) factors cancel analytically instead of under/overflowing.
    """
    total_var = v + var_noise
    phi = -y * w / np.sqrt(2 * total_var)
    g = 2 * y / (np.sqrt(2 * np.pi * total_var) * erfcx(phi))
    dg = -g * (w / total_var + g)
    return g, dg

In [6]:
mses_gamp = iterate_gamp(F, y, 1e-10, x0=x, max_iter=100);  # track MSE against the sampled x

t = 0, diff = 0.689799; mse = 0.428197
t = 1, diff = 0.241056; mse = 0.279662
t = 2, diff = 0.127886; mse = 0.212067
t = 3, diff = 0.0913546; mse = 0.165681
t = 4, diff = 0.0688579; mse = 0.130487
t = 5, diff = 0.0605754; mse = 0.103494
t = 6, diff = 0.0498199; mse = 0.0818637
t = 7, diff = 0.0414233; mse = 0.0613063
t = 8, diff = 0.0365946; mse = 0.0408108
t = 9, diff = 0.02819; mse = 0.0216691
t = 10, diff = 0.0211583; mse = 0.00749102
t = 11, diff = 0.00924465; mse = 1.8211e-05
t = 12, diff = 0.000203339; mse = 0


## State evolution

In [7]:
from scipy.integrate import quad
from scipy.special import erfc

def iterate_se(rows_to_columns, var_noise, max_iter=100, tol=1e-7, verbose=1,
               v_min=1e-5, z_max=10.0):
    """Iterates state evolution associated to AMP implementation above.

    Starting from v_0 = 1, iterates the scalar recursion

        A_t     = -alpha * E_z[ sum_{y=+-1} P(y | z) g'(y; sqrt(1 - v_t) z, v_t) ],
        v_{t+1} =  E_z[ 1 - tanh(A_t + sqrt(A_t) z)**2 ],

    with z ~ N(0, 1), alpha = rows_to_columns, and
    P(y | z) = erfc(-y sqrt((1 - v_t) / (2 (var_noise + v_t))) z) / 2.
    g' is taken from the notebook's ``channel`` function. Gaussian
    expectations are computed with ``scipy.integrate.quad`` on [-z_max, z_max].

    Parameters
    ----------
    rows_to_columns : float
        Ratio alpha = M / N.
    var_noise : float
        Noise variance, forwarded to ``channel``.
    max_iter : int
        Length of the returned array; at most max_iter - 1 updates are run.
    tol : float
        Stop when |v_{t+1} - v_t| drops below this.
    verbose : int
        If truthy, print one status line per update.
    v_min : float
        Stop once v falls below this value (default 1e-5).
    z_max : float
        Half-width of the integration interval for the Gaussian averages
        (default 10).

    Returns
    -------
    v : ndarray, shape (max_iter,)
        v[0] = 1 followed by one value per completed update; entries after the
        stopping iteration are left at zero.
    """

    def gauss(z):
        # Standard normal density
        return np.exp(-z ** 2 / 2) / np.sqrt(2 * np.pi)

    def prior_integrand(A):
        # Integrand of v_{t+1}: Gaussian-weighted 1 - tanh^2
        return lambda z: gauss(z) * (1 - np.tanh(A + np.sqrt(A) * z) ** 2)

    def channel_integrand(v):
        # Integrand of -A_t / alpha: both labels, weighted by their likelihood
        slope = np.sqrt(.5 * (1 - v) / (var_noise + v))
        scale = np.sqrt(1 - v)

        def integrand(z):
            w = scale * z
            return gauss(z) * .5 * (
                erfc(-slope * z) * channel(+1, w, v, var_noise)[1]
                + erfc(+slope * z) * channel(-1, w, v, var_noise)[1])

        return integrand

    v = np.zeros(max_iter)
    v[0] = 1

    for t in range(max_iter - 1):
        A = -rows_to_columns * quad(channel_integrand(v[t]), -z_max, z_max)[0]
        v[t + 1] = quad(prior_integrand(A), -z_max, z_max)[0]

        diff = np.abs(v[t + 1] - v[t])
        if verbose:
            print("t = %d, diff = %g; v = %g" % (t, diff, v[t + 1]))

        if diff < tol or v[t + 1] < v_min:
            break

    return v

In [8]:
mses_se = iterate_se(1.492, 1e-10, max_iter=500);  # SE variances v_t, used as the MSE prediction

t = 0, diff = 0.534175; v = 0.465825
t = 1, diff = 0.144424; v = 0.321401
t = 2, diff = 0.0634054; v = 0.257995
t = 3, diff = 0.0349774; v = 0.223018
t = 4, diff = 0.0219967; v = 0.201021
t = 5, diff = 0.0150531; v = 0.185968
t = 6, diff = 0.0109252; v = 0.175043
t = 7, diff = 0.00827912; v = 0.166764
t = 8, diff = 0.00648431; v = 0.160279
t = 9, diff = 0.00521232; v = 0.155067
t = 10, diff = 0.00427869; v = 0.150788
t = 11, diff = 0.00357352; v = 0.147215
t = 12, diff = 0.00302807; v = 0.144187
t = 13, diff = 0.00259759; v = 0.141589
t = 14, diff = 0.00225194; v = 0.139337
t = 15, diff = 0.00197026; v = 0.137367
t = 16, diff = 0.00173768; v = 0.135629
t = 17, diff = 0.00154345; v = 0.134086
t = 18, diff = 0.00137958; v = 0.132706
t = 19, diff = 0.00124006; v = 0.131466
t = 20, diff = 0.00112031; v = 0.130346
t = 21, diff = 0.00101675; v = 0.129329
t = 22, diff = 0.0009266; v = 0.128403
t = 23, diff = 0.000847643; v = 0.127555
t = 24, diff = 0.000778099; v = 0.126777
t = 25, diff = 0.0