# Convergence and gradients of the REINFORCE vs. Gumbel-Softmax algorithms

Let's consider a simple optimization problem with the following loss:
$$\mathcal{L} = \mathbb{E}_{q(y)} \left( \sum_k (y_k - t_k)^2 \right) $$
where $t$ is a $K$-dimensional target variable and $y$ are categorical (discrete, one-hot encoded) samples drawn from $q$. We search for the $q$ that minimizes the objective.
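
The loss has a closed form that makes the optimum explicit. A short derivation, using only that $y$ is one-hot (so $y_k^2 = y_k$ and $\sum_k y_k = 1$) and writing $q_j = q(y = e_j)$, a notation introduced here for illustration:

$$\sum_k (y_k - t_k)^2 = \sum_k y_k - 2\sum_k y_k t_k + \sum_k t_k^2 = 1 - 2 t_j + \|t\|_2^2 \quad \text{for } y = e_j$$
$$\mathcal{L} = \sum_j q_j \left(1 - 2 t_j + \|t\|_2^2\right) = 1 + \|t\|_2^2 - 2 \sum_j q_j t_j$$

The objective is linear in $q$, so it is minimized by placing all mass on $\arg\max_j t_j$, with minimum value $1 + \|t\|_2^2 - 2\max_j t_j$. For a target such as $t = (0.2, 0.7, 0.1)$ this gives $1 + 0.54 - 1.4 = 0.14$, so the expected loss cannot reach zero unless $t$ is itself one-hot.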

In [1]:
import tensorflow as tf
import tensorflow_probability as tfp

In [2]:
tf.__version__, tfp.__version__

('2.5.0', '0.12.2')

In [3]:
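# Discretization via straight-through: the forward pass returns a hard
# (rounded or one-hot) sample; the backward pass uses the gradient of the soft sample.

def st(y):
    K = y.shape[-1]
    if K == 1:
        y_hard = tf.round(y)  # Bernoulli case: round to 0 or 1
    else:
        y_hard = tf.one_hot(tf.argmax(y, -1), K, dtype=y.dtype)  # categorical case
    # stop_gradient hides (y_hard - y) from autodiff, so gradients flow through y only
    return tf.stop_gradient(y_hard - y) + y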
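
The `st` function above can be read as the following identity, written here as a sketch. The symbols $\mathrm{sg}$ (stop-gradient) and $h$ are notation introduced for illustration, not names from the code:

$$\mathrm{st}(y) = y + \mathrm{sg}\big(h(y) - y\big), \qquad h(y) = \begin{cases} \mathrm{round}(y) & K = 1 \\ \mathrm{onehot}(\arg\max_k y_k) & K > 1 \end{cases}$$
$$\text{forward: } \mathrm{st}(y) = h(y), \qquad \text{backward: } \frac{\partial\, \mathrm{st}(y)}{\partial y} := I$$

Downstream computations therefore see discrete values, while gradients are propagated as if the soft sample $y$ had been used. This mismatch between the forward and backward pass is one source of bias in the resulting gradient estimate.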

### Problem specification

In [4]:
# We can run the experiment for a 1D Categorical variable or for a Bernoulli variable:

## Categorical
target = tf.constant([0.2, 0.7, 0.1])
RDist = tfp.distributions.OneHotCategorical # reference distribution
Dist = tfp.distributions.RelaxedOneHotCategorical # relaxed distribution

## Bernoulli
# target = tf.constant([0.9])
# RDist = tfp.distributions.Bernoulli  # reference distribution
# Dist = tfp.distributions.RelaxedBernoulli  # relaxed distribution

## REINFORCE

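REINFORCE converges and allocates all the probability mass at the most likely value, i.e. the category with the largest target entry. Note that the loss printed below is the score-function surrogate $\sum_k (y_k - t_k)^2 \cdot \log q(y)$, which tends to zero as $q$ becomes deterministic. The expected loss itself is bounded below by $0.14$ for the categorical target used here.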
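
To make the estimator concrete, the following is a minimal pure-Python sketch (no TensorFlow) that compares the exact gradient of the expected loss with respect to the logits against a score-function Monte Carlo estimate. The helper names `softmax`, `outcome_losses`, `exact_gradient` and `reinforce_gradient` are hypothetical and not part of the notebook; the sketch assumes $q$ is a softmax over the logits, as in the loop below.

```python
import math
import random

def softmax(logits):
    # Numerically stable softmax over a list of floats
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def outcome_losses(target):
    # Loss of each one-hot outcome e_j: sum_k (e_j[k] - t_k)^2 = 1 - 2 t_j + ||t||^2
    sq = sum(t * t for t in target)
    return [1.0 - 2.0 * t + sq for t in target]

def exact_gradient(logits, target):
    # d/d logit_i of sum_j q_j f_j equals q_i * (f_i - E_q[f])
    q = softmax(logits)
    f = outcome_losses(target)
    expected = sum(qj * fj for qj, fj in zip(q, f))
    return [qi * (fi - expected) for qi, fi in zip(q, f)]

def reinforce_gradient(logits, target, n_samples, rng):
    # Monte Carlo average of f(y) * grad log q(y), with y sampled from q
    q = softmax(logits)
    f = outcome_losses(target)
    K = len(logits)
    grad = [0.0] * K
    for _ in range(n_samples):
        j = rng.choices(range(K), weights=q)[0]
        for i in range(K):
            # d log q_j / d logit_i = 1[i == j] - q_i
            grad[i] += f[j] * ((1.0 if i == j else 0.0) - q[i])
    return [g / n_samples for g in grad]

rng = random.Random(0)
logits = [1.0, 1.0, 1.0]
target = [0.2, 0.7, 0.1]
exact = exact_gradient(logits, target)
estimate = reinforce_gradient(logits, target, 20000, rng)
print("exact   :", [round(g, 4) for g in exact])
print("estimate:", [round(g, 4) for g in estimate])
```

At uniform logits the exact gradient is approximately `[0.089, -0.244, 0.156]`. The estimate is unbiased, so it fluctuates around these values, with noise that shrinks as the number of samples grows.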

In [5]:
logits = tf.Variable([1.0]*len(target))

In [6]:
eta = 0.01  # learning rate
for i in range(10000):

    with tf.GradientTape() as tape:
        d = RDist(logits=logits)
        y = d.sample(1000)  # discrete samples; no gradient flows through them
        # Score-function (REINFORCE) surrogate: f(y) * log q(y)
        loss1 = tf.reduce_sum((tf.cast(y, target.dtype) - target)**2, -1) * d.log_prob(y)
        loss = tf.reduce_mean(loss1)

    grad = tape.gradient(loss, logits)
    if i % 500 == 0:
        print(f"{i}. loss={loss:.4f} grad={grad} "
              f"distribution={tf.reduce_mean(tf.cast(RDist(logits=logits).sample(10000), tf.float32), 0)}")

    logits.assign(logits - eta * grad)  # plain gradient-descent step

0. loss=-0.9532 grad=[ 0.10523976 -0.24215983  0.13692033] distribution=[0.334 0.336 0.33 ]
500. loss=-0.6878 grad=[ 0.09391441 -0.21017459  0.11626166] distribution=[0.1321 0.7617 0.1062]
1000. loss=-0.2996 grad=[ 0.0473575  -0.07348796  0.02613076] distribution=[0.0517 0.9022 0.0461]
1500. loss=-0.3034 grad=[ 0.03656086 -0.07109018  0.03452918] distribution=[0.0332 0.9394 0.0274]
2000. loss=-0.1985 grad=[ 0.01989121 -0.04195623  0.022065  ] distribution=[0.0235 0.9597 0.0168]
2500. loss=-0.1576 grad=[ 0.01646709 -0.03144655  0.01497903] distribution=[0.0167 0.9698 0.0135]
3000. loss=-0.1076 grad=[ 0.00807082 -0.01964078  0.01157041] distribution=[0.0147 0.9733 0.012 ]
3500. loss=-0.1447 grad=[ 0.00607743 -0.02726551  0.02118774] distribution=[0.0124 0.9761 0.0115]
4000. loss=-0.1250 grad=[ 0.01095055 -0.02302097  0.01207027] distribution=[0.0096 0.9823 0.0081]
4500. loss=-0.1082 grad=[ 0.01117846 -0.01941768  0.00823947] distribution=[0.0088 0.9842 0.007 ]
5000. loss=-0.1118 grad=[ 0

## Gumbel softmax

The Gumbel-Softmax gradient is biased, so at a fixed temperature the algorithm does not converge to the minimum loss. The bias can be removed by annealing the temperature hyperparameter down to 0.
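
To illustrate the role of the temperature, here is a minimal pure-Python sketch of the standard Gumbel-Softmax sampling step: add Gumbel(0, 1) noise to the logits, divide by the temperature, and apply a softmax. The function names `sample_gumbel` and `relaxed_one_hot_sample` are hypothetical, and this is not a reproduction of the TFP implementation.

```python
import math
import random

def sample_gumbel(rng):
    # Gumbel(0, 1) noise via inverse transform: -log(-log(U)), U ~ Uniform(0, 1)
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def relaxed_one_hot_sample(logits, temperature, rng):
    # Perturb the logits with Gumbel noise, scale by 1/temperature, then softmax
    noisy = [(l + sample_gumbel(rng)) / temperature for l in logits]
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    z = sum(exps)
    return [e / z for e in exps]

rng = random.Random(0)
logits = [1.0, 1.0, 1.0]
for temperature in (5.0, 0.5, 0.01):
    y = relaxed_one_hot_sample(logits, temperature, rng)
    print(temperature, [round(v, 3) for v in y])
```

At a high temperature the samples are smooth and far from one-hot, so the gradient of the relaxed sample differs from that of the discrete objective. At a low temperature the samples are nearly one-hot, but the gradients become much noisier.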

In [7]:
logits = tf.Variable([1.0]*len(target))

In [8]:
eta = 0.01  # learning rate
for i in range(10000):

    with tf.GradientTape() as tape:
        p = Dist(logits=logits, temperature=5.0)
        # Relaxed (reparameterized) samples, discretized with straight-through
        y = st(p.sample(1000))
        loss1 = tf.reduce_sum((y - target)**2, -1)
        loss = tf.reduce_mean(loss1)

    grad = tape.gradient(loss, logits)
    if i % 500 == 0:
        # A near-zero temperature makes the samples effectively discrete for reporting
        print(f"{i}. loss={loss:.4f} grad={grad} "
              f"distribution={tf.reduce_mean(st(Dist(logits=logits, temperature=0.001).sample(10000)), 0)}")

    logits.assign(logits - eta * grad)  # plain gradient-descent step

0. loss=0.8886 grad=[ 0.02230619 -0.05010794  0.02780175] distribution=[0.3359 0.3287 0.3354]
500. loss=0.7864 grad=[ 0.01358167 -0.03801492  0.02443325] distribution=[0.2998 0.405  0.2952]
1000. loss=0.7070 grad=[ 0.01298843 -0.02954884  0.0165604 ] distribution=[0.2814 0.4716 0.247 ]
1500. loss=0.6660 grad=[ 0.0073073  -0.02408819  0.01678089] distribution=[0.2549 0.5233 0.2218]
2000. loss=0.6188 grad=[ 0.00434065 -0.01833881  0.01399817] distribution=[0.2374 0.5643 0.1983]
2500. loss=0.6006 grad=[ 0.00465481 -0.01618235  0.01152754] distribution=[0.2339 0.586  0.1801]
3000. loss=0.5496 grad=[ 0.00165005 -0.01009198  0.00844193] distribution=[0.2189 0.6071 0.174 ]
3500. loss=0.5524 grad=[ 0.00320495 -0.01043981  0.00723486] distribution=[0.2022 0.6426 0.1552]
4000. loss=0.5092 grad=[-0.00021451 -0.0050528   0.00526731] distribution=[0.2052 0.6499 0.1449]
4500. loss=0.5070 grad=[-0.00038981 -0.004733    0.00512281] distribution=[0.1933 0.6634 0.1433]
5000. loss=0.4958 grad=[-0.0018089
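
The run above keeps the temperature fixed at 5.0. One possible way to anneal it is an exponential decay with a lower bound, sketched below. The function `annealed_temperature` and its default values (`initial`, `final`, `decay_rate`) are assumptions chosen for illustration, not settings used in this notebook.

```python
import math

def annealed_temperature(step, initial=5.0, final=0.1, decay_rate=1e-3):
    """Exponential decay starting at `initial`, clipped from below at `final`."""
    return max(final, initial * math.exp(-decay_rate * step))

for step in (0, 1000, 5000, 10000):
    print(step, round(annealed_temperature(step), 4))
```

Inside the training loop the relaxed distribution would then be built with `temperature=annealed_temperature(i)`. The lower bound keeps the temperature strictly positive, since the relaxation divides by it.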

## Mixture of Discrete Normalizing Flows (MDNF)

MDNF is biased in a similar way, i.e. the bias can be controlled by the temperature hyperparameter. However, the approximation uses a different parametrization, and therefore the optimization and gradients behave differently from the Gumbel-Softmax relaxation.

*Let's start by downloading and importing the necessary library:*

In [15]:
!git clone https://github.com/tkusmierczyk/mixture_of_discrete_normalizing_flows.git

In [11]:
import sys
sys.path.append("mixture_of_discrete_normalizing_flows/mdnf/")

from flows_factorized_mixture import FactorizedDiscreteFlowsMixture

*MDNF uses a different parametrization and therefore we create logits in a different way:*

In [12]:
eta = 0.01
p = FactorizedDiscreteFlowsMixture(N=1, K=len(target), B=10, temperature=5.)
logits = p.logits

In [13]:
# opt = tf.keras.optimizers.Adam(learning_rate=eta)

In [14]:
for i in range(10000):    
    #p.temperature = 1000./(i+1)
    with tf.GradientTape() as tape:
        y = st(p.sample(100)[:,0,:])
        loss1 = tf.reduce_sum( (y-target)**2, -1)    
        loss = tf.reduce_mean(loss1)

    grad = tape.gradient(loss, logits)    
    if i%500==0:
        y = st(p.sample(10000)[:,0,:])
        print(f"{i}. loss={loss:.4f} "
              f"distribution={tf.reduce_mean(y, 0)}"
             )
    
    #opt.apply_gradients(zip([grad], [logits]))
    logits.assign(logits-eta*grad)

0. loss=0.8000 distribution=[0.3 0.4 0.3]
500. loss=0.8200 distribution=[0.2 0.4 0.4]
1000. loss=0.8200 distribution=[0.2 0.4 0.4]
1500. loss=0.7200 distribution=[0.1 0.5 0.4]
2000. loss=0.7200 distribution=[0.1 0.5 0.4]
2500. loss=0.7200 distribution=[0.1 0.5 0.4]
3000. loss=0.6000 distribution=[0.1 0.6 0.3]
3500. loss=0.6000 distribution=[0.1 0.6 0.3]
4000. loss=0.5800 distribution=[0.2 0.6 0.2]
4500. loss=0.3600 distribution=[0.1 0.8 0.1]
5000. loss=0.2600 distribution=[0.  0.9 0.1]
5500. loss=0.3800 distribution=[0.  0.8 0.2]
6000. loss=0.3600 distribution=[0.1 0.8 0.1]
6500. loss=0.3600 distribution=[0.1 0.8 0.1]
7000. loss=0.2400 distribution=[0.1 0.9 0. ]
7500. loss=0.2600 distribution=[0.  0.9 0.1]
8000. loss=0.2600 distribution=[0.  0.9 0.1]
8500. loss=0.1400 distribution=[0. 1. 0.]
9000. loss=0.1400 distribution=[0. 1. 0.]
9500. loss=0.1400 distribution=[0. 1. 0.]
