## Imports and Utility Functions

---




In [None]:
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
# tf.reset_default_graph()
# !pip install tf-nightly
import tensorflow as tf
tf.random.set_seed(1234)
import numpy as np
np.random.seed(123)
from tqdm import tqdm
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc
from collections import defaultdict
from google.colab import files


In [None]:
class AverageMeter(object):
    """Tracks the latest value, a running mean, and an exponential moving average.

    Args:
      alpha: smoothing factor of the moving average; every update computes
        moving_avg = alpha * moving_avg + (1 - alpha) * val. With alpha=0 the
        moving average simply follows the latest value.
    """

    def __init__(self, alpha=0.9):
        self.alpha = alpha
        self.reset()

    def reset(self):
        """Clears all statistics, including the moving average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        # None means "no update since the last reset"; the next update seeds the EMA.
        self.moving_avg = None

    def update(self, val, n=1):
        """Records `val` as observed `n` times and refreshes the averages."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if self.moving_avg is None:
            self.moving_avg = val
        else:
            self.moving_avg = self.moving_avg * self.alpha + val * (1 - self.alpha)

## Objective Function for Maximization

---

$ (1/D) \mathbb{E}_{q_\eta(b)}\left[\|b - t\|^2\right], \quad b^1,b^2,\dots,b^K \sim q_{\eta}(b) := \mathrm{Bernoulli}(\theta), \quad \theta = 1 / (1 + e^{-\eta}), \quad \eta \in \mathbb{R}^D$

In [None]:
# Problem configuration: K samples per step, D Bernoulli dimensions, and a
# target value t shared by every dimension.
K = 2
D = 200
t = 0.499
target = np.full((1, D), t, dtype=np.float64)  # shape [1, D]
# Random D x D matrix; it does not appear to be referenced by the code below.
W = tf.constant(np.random.normal(0.0, 0.01, size=[D, D]).astype(np.float64))
print(f"Target is {target}")

# @tf.function
def exact_grad(theta, t=None):
    """Analytic gradient of the objective with respect to the logits eta.

    With f(b) = (1/D) * sum_d (b_d - t_d)^2 and b_d ~ Bernoulli(theta_d),
    E[f] = (1/D) * sum_d [theta_d (1 - t_d)^2 + (1 - theta_d) t_d^2]. Using
    d theta / d eta = theta (1 - theta):

        dE[f] / d eta_d = theta_d (1 - theta_d) ((1 - t_d)^2 - t_d^2) / D.

    The 1/D factor matches `loss_func`, which averages over the D dimensions.
    The result is therefore directly comparable with the `eta_grads` returned
    by the estimators.

    Args:
      theta: Bernoulli means, shape [1, D] (NumPy array or tensor).
      t: optional target, broadcastable to `theta`; defaults to the global `target`.

    Returns:
      The exact per-coordinate gradient, with the same shape as `theta`.
    """
    tgt = target if t is None else t
    num_dims = theta.shape[-1]
    return (1. - theta) * theta * ((1. - tgt) ** 2 - tgt ** 2) / num_dims

# @tf.function
def loss_func(b, t):
    """Mean squared difference between `b` and `t`, averaged over axis 1.

    For a [K, D] batch `b`, this returns one value per row (shape [K]).
    """
    diff = b - t
    return tf.reduce_mean(diff ** 2, axis=1)

# @tf.function
def safe_log(x, eps=1e-8):
    """Natural log of `x` after clamping it into [eps, 1], so zeros never give -inf."""
    clamped = tf.clip_by_value(x, eps, 1.0)
    return tf.math.log(clamped)

# @tf.function
def reparameterize_q(eta, noise):
    """Logistic reparameterization z = eta + log(noise) - log(1 - noise).

    `noise` should hold Uniform(0, 1) draws. Then z > 0 with probability
    sigmoid(eta), so thresholding z at zero gives Bernoulli(sigmoid(eta)) samples.
    """
    shifted = eta + safe_log(noise)
    return shifted - safe_log(1 - noise)

## Gradient Estimators

---



### Standard REINFORCE



In [None]:
# @tf.function
def reinforce(eta, theta, b, K, baseline=None, baseline_ema=None, dry_run=None, baseline_weight=0.0):
    """Score-function (REINFORCE) estimator of d E[f] / d eta, with an optional EMA baseline.

    The batch-mean loss is tracked in `baseline` / `baseline_ema`. The moving
    average, multiplied by `baseline_weight`, is subtracted from f(b). The
    default weight 0.0 gives plain REINFORCE without a baseline, which is how
    `main` in this notebook calls it.

    Args:
      eta: logits, shape [1, D] (unused; kept for a uniform estimator signature).
      theta: sigmoid(eta), shape [1, D].
      b: Bernoulli samples, shape [K, D].
      K: number of samples (unused).
      baseline: tf.Variable that receives the latest batch-mean loss, or None.
      baseline_ema: tf.train.ExponentialMovingAverage tracking `baseline`, or None.
        If either of the two is None, no baseline is used or updated.
      dry_run: if truthy, the baseline state is left untouched.
      baseline_weight: multiplier on the moving-average baseline.

    Returns:
      (loss, eta_grads): mean loss over samples and a [1, D] gradient estimate.
    """
    f_b = loss_func(b, target)  # [K]
    loss = tf.reduce_mean(f_b)
    has_baseline = baseline is not None and baseline_ema is not None

    centered_f = f_b
    if has_baseline:
        # average() returns None until apply() has created the shadow variable.
        fb_moving_avg = baseline_ema.average(baseline)
        if fb_moving_avg is not None:
            centered_f = f_b - baseline_weight * fb_moving_avg

    dlog_q = b - theta  # d log q(b) / d eta, [K, D]
    eta_grads = tf.reduce_mean(centered_f[:, None] * dlog_q, axis=0, keepdims=True)

    if has_baseline and not dry_run:
        baseline.assign(tf.reduce_mean(f_b))
        baseline_ema.apply([baseline])

    return loss, eta_grads

### REINFORCE Leave-One-Out (RLOO)

In [None]:
# @tf.function
def reinforce_loo(eta, theta, b, K):
    """REINFORCE leave-one-out (RLOO) estimator of d E[f] / d eta.

    Each sample's baseline is the mean loss of the other K - 1 samples.

    Args:
      eta: logits, shape [1, D] (unused; kept for a uniform estimator signature).
      theta: sigmoid(eta), shape [1, D].
      b: Bernoulli samples, shape [K, D].
      K: number of samples; must be at least 2.

    Returns:
      (loss, eta_grads): mean loss over samples and a [1, D] gradient estimate.

    Raises:
      NotImplementedError: if K < 2.
    """
    if K < 2:
        raise NotImplementedError("Leave-one-out requires K > 1.")
    f_samples = loss_func(b, target)  # [K]
    mean_loss = tf.reduce_mean(f_samples)
    # Total loss of all *other* samples, for every k.
    others_total = tf.reduce_sum(f_samples, axis=0) - f_samples  # [K]
    centered = f_samples - others_total / (K - 1)
    score = b - theta  # d log q(b) / d eta, [K, D]
    weighted = tf.expand_dims(centered, 1) * score
    grad_estimate = tf.reduce_mean(weighted, axis=0, keepdims=True)
    return mean_loss, grad_estimate

### MuProp

In [None]:
# @tf.function
def muprop(eta, theta, b, K, eta_mu=None, alpha=None):
    """MuProp estimator: REINFORCE with a first-order Taylor control variate at mu = theta.

    The per-sample baseline is f(mu) + f'(mu) . (b - mu). The gradient of its
    expectation (with mu held fixed), f'(mu) * theta * (1 - theta), is added
    back analytically. `eta`, `K`, `eta_mu` and `alpha` are unused.

    Returns:
      (loss, eta_grads): mean loss over samples and a [1, D] gradient estimate.
    """
    f_b = loss_func(b, target)  # [K]
    loss = tf.reduce_mean(f_b)

    mu = theta
    with tf.GradientTape() as tape:
        tape.watch(mu)
        f_mu = loss_func(mu, target)  # [1]
        f_mu_total = tf.reduce_sum(f_mu)
    df_mu = tape.gradient(f_mu_total, mu)  # [1, D]

    analytic_term = df_mu * theta * (1. - theta)
    score = b - theta  # d log q(b) / d eta, [K, D]
    taylor_baseline = f_mu + tf.reduce_sum(df_mu * (b - mu), axis=-1)  # [K]
    residual = (f_b - taylor_baseline)[:, None]
    sampled_term = tf.reduce_mean(residual * score, axis=0, keepdims=True)
    return loss, sampled_term + analytic_term

### DisARM

In [None]:
def disarm(eta, theta, u, b, K):
    """DisARM gradient estimator built from K // 2 antithetic pairs.

    Only the first K // 2 rows of `u` and `b` are used. Each row of `b` is
    paired with b_anti = 1[u < theta], computed from the same uniforms. The
    per-pair gradient is

        0.5 * (f(b) - f(b_anti)) * (-1)^b_anti * 1[b != b_anti] * sigmoid(|eta|),

    averaged over pairs. `loss` averages f over both members of every pair.

    Returns:
      (loss, eta_grads) with eta_grads of shape [1, D].

    Raises:
      RuntimeError: if K is odd.
    """
    if K % 2 != 0:
        raise RuntimeError("DisARM requires K % 2 = 0.")
    half = K // 2
    u_half = u[:half, :]  # [K // 2, D]
    b_half = b[:half, :]  # [K // 2, D]
    b_anti = tf.cast(u_half < theta, tf.float64)  # [K // 2, D]

    f_b = loss_func(b_half, target)  # [K // 2]
    f_b_anti = loss_func(b_anti, target)  # [K // 2]
    loss = 0.5 * tf.reduce_mean(f_b) + 0.5 * tf.reduce_mean(f_b_anti)

    sign = (-1.) ** b_anti
    disagree = tf.cast(b_half != b_anti, tf.float64)
    dlog_q = sign * disagree * tf.sigmoid(tf.abs(eta))  # [K // 2, D]
    pair_diff = (f_b - f_b_anti)[:, None]
    eta_grads = 0.5 * tf.reduce_mean(pair_diff * dlog_q, axis=0, keepdims=True)
    return loss, eta_grads

### Double Control Variates

In [None]:
def _require_two_samples(K):
    """Raises NotImplementedError unless at least two samples are available."""
    if K < 2:
        raise NotImplementedError("Leave-one-out requires K > 1.")


def _loss_and_sample_grads(b):
    """Returns (mean loss, per-sample losses f_b [K], per-sample gradients df/db [K, D])."""
    with tf.GradientTape() as tape:
        tape.watch(b)
        f_b = loss_func(b, target)
        f_b_sum = tf.reduce_sum(f_b)
    # loss_func reduces each row separately, so the gradient of the sum is the
    # stack of per-row gradients.
    grad_b = tape.gradient(f_b_sum, b)
    return tf.reduce_mean(f_b), f_b, grad_b


def _cross_taylor_terms(grad_b, b, theta, alpha):
    """Returns (alpha * f'(b_1) . (b_0 - theta), alpha * f'(b_0) . (b_1 - theta))."""
    cv0 = alpha * tf.reduce_sum(grad_b[1, :] * (b[0, :] - theta))
    cv1 = alpha * tf.reduce_sum(grad_b[0, :] * (b[1, :] - theta))
    return cv0, cv1


def _taylor_correction(scale, grad_b, theta):
    """Returns scale * 0.5 * (f'(b_0) + f'(b_1)) * theta * (1 - theta), shape [1, D]."""
    grad_avg = 0.5 * (grad_b[0, :] + grad_b[1, :])
    return scale * grad_avg * (theta * (1. - theta))


# @tf.function
def double_control_variate(eta, theta, b, K, alpha):
    """Double control variate estimator: two-sample RLOO plus cross-sample Taylor terms.

    Each sample's loss is augmented with a linear control variate built from
    the *other* sample's gradient, alpha * f'(b_j) . (b_k - theta). The
    expectation of these terms, alpha * mean_k f'(b_k) * theta * (1 - theta),
    is subtracted again. For independent samples this keeps the estimator
    unbiased.

    Only the first two rows of `b` enter the gradient estimate. `loss`
    averages over all K rows. `eta` is unused.

    Args:
      theta: sigmoid(eta), shape [1, D].
      b: Bernoulli samples, shape [K, D] with K >= 2.
      K: number of samples.
      alpha: scalar control-variate weight (may be a trainable tf.Variable).

    Returns:
      (loss, eta_grads) with eta_grads of shape [1, D].

    Raises:
      NotImplementedError: if K < 2.
    """
    _require_two_samples(K)
    loss, f_b, grad_b = _loss_and_sample_grads(b)
    cv0, cv1 = _cross_taylor_terms(grad_b, b, theta, alpha)
    dlog_q = b - theta
    global_corr = _taylor_correction(alpha, grad_b, theta)
    diffs = f_b[0] + cv0 - (f_b[1] + cv1)
    eta_grads = 0.5 * (diffs * dlog_q[0, :] - diffs * dlog_q[1, :]) - global_corr
    return loss, eta_grads


# The following two auxiliary methods are only used to produce figure 5 in the appendix.
def double_control_variate_onlybxk(eta, theta, b, K, alpha):
    """Ablation: sample k gets only alpha * f'(b_j) . (b_k - theta), plus the analytic correction.

    Same arguments, returns and raises as `double_control_variate`.
    """
    _require_two_samples(K)
    loss, f_b, grad_b = _loss_and_sample_grads(b)
    cv0, cv1 = _cross_taylor_terms(grad_b, b, theta, alpha)
    dlog_q = b - theta
    global_corr = _taylor_correction(alpha, grad_b, theta)
    diffs0 = f_b[0] + cv0 - f_b[1]
    diffs1 = f_b[1] + cv1 - f_b[0]
    eta_grads = 0.5 * (diffs0 * dlog_q[0, :] + diffs1 * dlog_q[1, :]) - global_corr
    return loss, eta_grads


def double_control_variate_onlybxj(eta, theta, b, K, alpha):
    """Ablation: sample k gets only alpha * f'(b_k) . (b_j - theta).

    That term has zero mean over the independent sample b_j, so no analytic
    correction is applied. Same arguments, returns and raises as
    `double_control_variate`.
    """
    _require_two_samples(K)
    loss, f_b, grad_b = _loss_and_sample_grads(b)
    cv0, cv1 = _cross_taylor_terms(grad_b, b, theta, alpha)
    dlog_q = b - theta
    diffs0 = f_b[0] - (f_b[1] + cv1)
    diffs1 = f_b[1] - (f_b[0] + cv0)
    # Zero correction, kept (scaled by 0.) so eta_grads broadcasts to shape [1, D].
    global_corr = _taylor_correction(0., grad_b, theta)
    eta_grads = 0.5 * (diffs0 * dlog_q[0, :] + diffs1 * dlog_q[1, :]) - global_corr
    return loss, eta_grads

### Exact Mean Baseline (R*)

In [None]:
# @tf.function
def exact_mean_control_variate(eta, theta, b, K):
    """REINFORCE using the exact mean of f as its baseline (R* in the plots).

    For this objective E_q[f] = mean_d [theta_d (1 - t_d)^2 + (1 - theta_d) t_d^2]
    is available in closed form. The baseline does not depend on the samples,
    so the estimator is valid for any K >= 1. `eta` and `K` are unused.

    Args:
      theta: sigmoid(eta), shape [1, D].
      b: Bernoulli samples, shape [K, D].

    Returns:
      (loss, eta_grads): mean loss over samples and a [1, D] gradient estimate.
    """
    f_b = loss_func(b, target)  # [K]
    loss = tf.reduce_mean(f_b)
    exact_mean = tf.reduce_mean(theta * ((1.0 - target) ** 2) + (1 - theta) * (target ** 2))
    dlog_q = b - theta  # d log q(b) / d eta, [K, D]
    eta_grads = tf.reduce_mean((f_b[:, None] - exact_mean) * dlog_q, axis=0, keepdims=True)
    return loss, eta_grads

## Training

---


In [None]:
def main(estimator="reinforce_loo", eps=-1, lr=0.01, iters=2000, record_every=10, variance_samples=2000):
    """Trains the Bernoulli logits `eta` with one gradient estimator and records diagnostics.

    Every `record_every`-th iteration still performs a training step. On those
    iterations it also records:
      - the mean theta and the loss;
      - the first-coordinate estimated and exact gradients;
      - `alpha`;
      - the estimator's gradient variance at the current `eta`, measured from
        `variance_samples` dry-run draws.
    Recorded series are smoothed with AverageMeter moving averages.

    Args:
      estimator: name of the gradient estimator (see the dispatch in `train_one_step`).
      eps: unused (reserved for a Fisher-damping search; `epsilons` is returned empty).
      lr: RMSprop learning rate for `eta`.
      iters: number of training iterations.
      record_every: record diagnostics every this many iterations.
      variance_samples: number of dry-run gradient draws per variance measurement.

    Returns:
      (tv, thetas, losses, variances, epsilons, eta_grad_vals, exact_grad_vals, alphas).
      `tv` is the mean theta at the last recorded step, or None if nothing was recorded.
    """
    # Estimators whose control-variate weight `alpha` is adapted online by
    # minimizing the mean squared gradient estimate.
    alpha_estimators = {
        "double_control_variate": double_control_variate,
        "double_control_variate_onlybxk": double_control_variate_onlybxk,
        "double_control_variate_onlybxj": double_control_variate_onlybxj,
    }

    @tf.function
    def train_one_step(eta, inf_opt, hyper_opt, estimator="reinforce_loo", eta_mu=None, baseline=None, baseline_ema=None, alpha=None, control_nn=None, dry_run=False):
        """Draws K samples and evaluates one estimator.

        With `dry_run`, only eta_grads is returned and no state is updated.
        """
        theta = tf.sigmoid(eta)
        u = tf.random.uniform([K, D], dtype=tf.float64)
        z = reparameterize_q(eta, u)  # z(u)
        b = tf.cast(tf.stop_gradient(z > 0), dtype=tf.float64)

        if estimator == "reinforce_loo":
            loss, eta_grads = reinforce_loo(eta, theta, b, K)
        elif estimator == "reinforce":
            loss, eta_grads = reinforce(eta, theta, b, K, baseline, baseline_ema, dry_run=dry_run)
        elif estimator == "muprop":
            loss, eta_grads = muprop(eta, theta, b, K, eta_mu=None)
        elif estimator == "disarm":
            loss, eta_grads = disarm(eta, theta, u, b, K)
        elif estimator == "exact_mean_control_variate":
            loss, eta_grads = exact_mean_control_variate(eta, theta, b, K)
        elif estimator in alpha_estimators:
            with tf.GradientTape() as tape:
                tape.watch(alpha)
                loss, eta_grads = alpha_estimators[estimator](eta, theta, b, K, alpha)
                variance_loss = tf.reduce_mean(tf.square(eta_grads))
            if not dry_run:
                alpha_grads = tape.gradient(variance_loss, alpha)
                hyper_opt.apply_gradients([(alpha_grads, alpha)])
        else:
            raise NotImplementedError()

        if dry_run:
            return eta_grads

        exact_grads = exact_grad(theta)
        # Negated: the optimizer minimizes, so this takes an ascent step on E[f].
        inf_opt.apply_gradients([(-eta_grads, eta)])
        return loss, theta, eta_grads, exact_grads

    eta = tf.Variable(
        [[0.0 for _ in range(D)]],
        trainable=True,
        name='eta',
        dtype=tf.float64
    )
    # Passed through to train_one_step; no estimator currently reads it.
    eta_mu = tf.Variable(
        [[0.0 for _ in range(D)]],
        trainable=True,
        name='eta_mu',
        dtype=tf.float64
    )
    alpha = tf.Variable(
        0.,
        trainable=True,
        name="alpha",
        dtype=tf.float64
    )
    baseline = tf.Variable(initial_value=0., dtype=tf.float64)
    baseline_ema = tf.train.ExponentialMovingAverage(0.6)
    inf_opt = tf.keras.optimizers.RMSprop(lr)
    hyper_opt = tf.keras.optimizers.RMSprop(0.0005)
    step_kwargs = dict(eta_mu=eta_mu, baseline=baseline, baseline_ema=baseline_ema,
                       alpha=alpha, control_nn=None)

    meters = [AverageMeter(alpha=0.6) for _ in range(4)]  # theta, loss, alpha; index 3 unused
    thetas = []
    losses = []
    alphas = []
    variances = defaultdict(list)
    epsilons = []
    eta_grad_vals = []
    exact_grad_vals = []
    compare_basket = [estimator]
    # alpha=0: each variance meter's moving average is just the latest measurement.
    var_meters = {est: AverageMeter(alpha=0) for est in compare_basket}
    tv = None

    for step in tqdm(range(iters)):
        if (step + 1) % record_every != 0:
            train_one_step(eta, inf_opt, hyper_opt, estimator=estimator, **step_kwargs)
            continue

        loss_value, theta_value, eta_grad_val, exact_grad_val = train_one_step(
            eta, inf_opt, hyper_opt, estimator=estimator, **step_kwargs)

        tv = tf.reduce_mean(tf.abs(theta_value)).numpy()
        meters[0].update(tv)
        meters[1].update(loss_value.numpy())
        thetas.append(meters[0].moving_avg)
        losses.append(meters[1].moving_avg)
        eta_grad_vals.append(eta_grad_val.numpy()[0, 0])
        exact_grad_vals.append(exact_grad_val.numpy()[0, 0])

        meters[2].update(alpha.numpy())
        alphas.append(meters[2].moving_avg)

        for est in compare_basket:
            grads = [
                train_one_step(eta, inf_opt, hyper_opt, estimator=est, dry_run=True,
                               **step_kwargs).numpy()[0]
                for _ in range(variance_samples)
            ]
            # Per-coordinate variance across draws, averaged over the D coordinates.
            v = np.mean(np.std(grads, axis=0) ** 2)
            var_meters[est].update(v)
            variances[est].append(var_meters[est].moving_avg)

    return tv, thetas, losses, variances, epsilons, eta_grad_vals, exact_grad_vals, alphas


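For damping in Fisher control variates, search over $\epsilon \in [10^{-6}, 10^3]$. (No Fisher control variate is run in this notebook; the `eps` argument of `main` is currently unused.)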

In [None]:
# Train once with every estimator and keep each run's diagnostics, keyed by estimator name.
eps_list = []
results = {}
estimators_to_run = [
    "muprop",
    "reinforce",
    "reinforce_loo",
    "exact_mean_control_variate",
    "double_control_variate",
    "disarm",
]
for method in estimators_to_run:
    print(f"\n{method}...\n")
    results[method] = main(method)

## Plots

---



### Variance

In [None]:
display_eps_list = []
font_size = 18
# Apply the base style first: style.use('default') resets every rc setting made before it.
matplotlib.style.use('default')
# LaTeX text rendering needs a LaTeX installation (absent on stock Colab), so it is opt-in.
USE_TEX = False
if USE_TEX:
    rc('text', usetex=True)
    rc('font', family='serif')
plt.rc('font', size=font_size)         # controls default text sizes
plt.rc('axes', titlesize=font_size)    # fontsize of the axes title
plt.rc('axes', labelsize=font_size)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=font_size)   # fontsize of the tick labels
plt.rc('ytick', labelsize=font_size)   # fontsize of the tick labels
plt.rc('legend', fontsize=font_size)   # legend fontsize
plt.rc('figure', titlesize=font_size)  # fontsize of the figure title

# Display names (LaTeX-safe, so they also work with USE_TEX = True).
variance_labels = {
    "reinforce_loo": "RLOO",
    "double_control_variate": "Double CV",
    "disarm": "DisARM",
    "exact_mean_control_variate": "R$^*$",
    "reinforce": "Reinforce",
    "muprop": "MuProp",
}
plt.figure()
for method in ["reinforce_loo", "double_control_variate", "disarm", "exact_mean_control_variate", "reinforce", "muprop"]:
    variances = results[method][3]
    plt.plot(variances[method], label=variance_labels[method])

plt.legend()
plt.ylim([0., 1.3e-9])
# One point is recorded every 10 training steps.
plt.xticks([50, 100, 150, 200], ['500', '1000', '1500', '2000'])
plt.ylabel("Gradient Variance")
plt.xlabel("Step")
path = "log_var_toy_D{}_p0{}.pdf".format(D, target[0][0])
plt.savefig(path, dpi=300, bbox_inches='tight')
files.download(path)

### Objective

In [None]:
# One legend entry per plotted method (the previous explicit list mislabeled Reinforce/MuProp).
objective_labels = {
    "reinforce_loo": "RLOO",
    "double_control_variate": "Double CV",
    "disarm": "DisARM",
    "exact_mean_control_variate": "R$^*$",
    "reinforce": "Reinforce",
    "muprop": "MuProp",
}
plt.figure()
all_losses = []
for method in ["reinforce_loo", "double_control_variate", "disarm", "exact_mean_control_variate", "reinforce", "muprop"]:
    tv, thetas, losses, variances, epsilons, eta_grad_vals, exact_grad_vals, alphas = results[method]
    plt.plot(losses, label=objective_labels[method])
    all_losses.extend(losses)

plt.legend()
# One point is recorded every 10 training steps.
plt.xticks([50, 100, 150, 200], ['500', '1000', '1500', '2000'])
# Overall loss range across all plotted methods.
print(np.max(all_losses), np.min(all_losses))
plt.ylabel("Average $f(x)$")
plt.xlabel("Step")
path = "loss_toy_D{}_p0{}.pdf".format(D, target[0][0])
plt.savefig(path, dpi=300, bbox_inches='tight')
files.download(path)

### Average sigmas

In [None]:
sigma_labels = {
    "reinforce_loo": "RLOO",
    "double_control_variate": "Double CV",
    "disarm": "DisARM",
    "exact_mean_control_variate": "R$^*$",
}
plt.figure()
for method in ["reinforce_loo", "double_control_variate", "disarm", "exact_mean_control_variate"]:
    tv, thetas, losses, variances, epsilons, eta_grad_vals, exact_grad_vals, alphas = results[method]
    # Report only the final smoothed value instead of dumping the whole series.
    print(method, thetas[-1] if thetas else None)
    plt.plot(thetas, label=sigma_labels[method])

plt.legend()
# One point is recorded every 10 training steps.
plt.xticks([50, 100, 150, 200], ['500', '1000', '1500', '2000'])
# Raw string: "\s" is an invalid escape sequence in a normal string literal.
plt.ylabel(r"Average $\sigma(\eta_i)$")
plt.xlabel("Step")
path = "average_mu_D{}_p0{}.pdf".format(D, target[0][0])
plt.savefig(path, dpi=300, bbox_inches='tight')
files.download(path)