# Variational Learning of Posteriors for Discrete Bayesian Networks using Gumbel-Softmax with relaxed priors


A Bayesian network (BN) represents a joint distribution of random variables factorized according to a directed
acyclic graph (DAG) that determines their conditional independencies. For BNs with latent nodes, determining the joint posterior is difficult even if the structure is known. For sufficiently small networks the true posterior can be evaluated by direct enumeration of all configurations, at a cost exponential in the number of latent nodes. We use this as ground truth and compare it against approximate posteriors represented with Gumbel-Softmax.

## Imports

In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
import numpy as np
import pandas as pd
from collections import Counter

import copy
import time

import gc

In [2]:
import sys
# Extend the module search path with ../mdnf before the project imports below.
# NOTE(review): presumably aux, inference, prob_recovery, etc. live in that directory — confirm.
sys.path.append('../mdnf')

In [3]:
import aux
import time_profiling
import inference

import prob_recovery
import bayesian_networks

In [4]:
import logging
# Configure the root logger at INFO level; every record is prefixed with the
# number of milliseconds elapsed since the logging module was loaded.
logging.basicConfig(format='%(relativeCreated)6d %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

In [5]:
# Record the TensorFlow version used for this run (formatting deferred to the logger).
logger.info("TF version=%s", tf.__version__)

  3537 TF version=2.3.0


## Configuration

In [6]:
# Optional overrides for the configuration constants defined in the next cell.
# The notebook can also be run as a script with args in the format
# KEY=VAL,KEY=[STRVAL],...
# NOTE(review): the exact parsing/typing rules are in aux.parse_script_args, not visible here.
args = aux.parse_script_args() 

  3541 parsing: <-f>


In [7]:
# --- reproducibility -------------------------------------------------------
SEED = args.get("SEED", 1)  # fixes the randomness of a run

# --- network and observed evidence ------------------------------------------
# Other model / evidence pairs that can be selected via MODEL=... EVIDENCE=...:
#   bnets/cancer.bif    Cancer-True
#   bnets/asia.bif      asia-yes/xray-yes
#   bnets/sachs.bif     Akt-LOW
#   bnets/hepar2.bif    carcinoma-present
MODEL = args.get("MODEL", "bnets/earthquake.bif")
EVIDENCE = args.get("EVIDENCE", "MaryCalls-True")
# The evidence string lists NODE-VALUE entries separated by "/"; parse it into a mapping.
EVIDENCE = aux.parse_dict(EVIDENCE, entries_separator="/", key2val_separator="-")

# --- optimisation -----------------------------------------------------------
OPTIMIZER = args.get("OPTIMIZER", "RMS").upper()
LR = args.get("LR", 0.01)

NSAMPLES = args.get("NSAMPLES", 100)  # number of samples used to estimate the ELBO
MAX_NITER = args.get("MAX_NITER", 10001)
# Stop if no improvement was seen within this many iterations.
NOIMPROV_NITER = args.get("NOIMPROV_NITER", MAX_NITER)

# --- temperatures -----------------------------------------------------------
PRIORS_TEMP = args.get("PRIORS_TEMP", 1.0)
BASE_TEMP = args.get("BASE_TEMP", 1.0)
ANNEAL_RATE = args.get("ANNEAL_RATE", 0.0)
MIN_TEMP = args.get("MIN_TEMP", 0.001)

# Whether samples are discretised with the straight-through estimator.
# NOTE(review): bool() of any non-empty string is True — confirm that
# parse_script_args delivers ST as a bool/int rather than the text "False".
ST = bool(args.get("ST", False))

# --- output -----------------------------------------------------------------
OUT = args.get("OUT", "BN_GS_KL.csv")  # where to save results
logger.info("Results output file: %s", OUT)

  3548 Results output file: BN_GS_KL.csv


In [20]:
# Column names of one results row: the run configuration first, followed by the
# per-iteration measurements appended inside the training loop.
COLS = [
    "ALG", "SEED", "MODEL", "EVIDENCE", "OPTIMIZER", "LR", "NSAMPLES", "MAX_NITER",
    "PRIORS_TEMP", "BASE_TEMP", "ANNEAL_RATE", "MIN_TEMP", "ST", "OUT",
    "iteration", "temp", "loss", "loss_evidence", "loss_priors", "loss_entropy",
    "kl1", "kl2", "nonzeros", "time_forward", "time_backward",
]

# Configuration values, in the same order as the leading entries of COLS.
CFG = [
    "GSKL",
    SEED,
    MODEL,
    " ".join("%s=%s" % pair for pair in EVIDENCE.items()),
    OPTIMIZER,
    LR,
    NSAMPLES,
    MAX_NITER,
    PRIORS_TEMP,
    BASE_TEMP,
    ANNEAL_RATE,
    MIN_TEMP,
    int(ST),
    OUT,
]

# zip() stops at the shorter list, so only the configuration columns are logged.
logger.info("CONFIGURATION:\n %s",
            "\n ".join("%s=%s" % pair for pair in zip(COLS, CFG)))

397735 CONFIGURATION:
 ALG=GSKL
 SEED=1
 MODEL=bnets/earthquake.bif
 EVIDENCE=MaryCalls=True
 OPTIMIZER=RMS
 LR=0.01
 NSAMPLES=100
 MAX_NITER=10001
 PRIORS_TEMP=1.0
 BASE_TEMP=1.0
 ANNEAL_RATE=0.0
 MIN_TEMP=0.001
 ST=0
 OUT=BN_GS_KL.csv


## Network & evidence

Load Bayesian network and fix evidence.

In [9]:
# Network object used for the rest of the notebook, built from the evidence
# mapping and the MODEL path.
# NOTE(review): presumably this reads the .bif file and clamps the evidence nodes —
# confirm in bayesian_networks.BayesianNetworkVI.
net = bayesian_networks.BayesianNetworkVI(EVIDENCE, MODEL)

  4904 +--------------+------------------+-------------------+------------------+-------------------+
| Burglary     | Burglary(True)   | Burglary(True)    | Burglary(False)  | Burglary(False)   |
+--------------+------------------+-------------------+------------------+-------------------+
| Earthquake   | Earthquake(True) | Earthquake(False) | Earthquake(True) | Earthquake(False) |
+--------------+------------------+-------------------+------------------+-------------------+
| Alarm(True)  | 0.95             | 0.94              | 0.29             | 0.001             |
+--------------+------------------+-------------------+------------------+-------------------+
| Alarm(False) | 0.05             | 0.06              | 0.71             | 0.999             |
+--------------+------------------+-------------------+------------------+-------------------+
  4907 +-----------------+------+
| Burglary(True)  | 0.01 |
+-----------------+------+
| Burglary(False) | 0.99 |
+-----------------+----

If the network is small enough, obtain the exact posterior by enumeration (otherwise no ground truth is available and `TARGET` is `None`):

In [10]:
# Ground-truth posterior; stays None when enumeration would be too expensive.
TARGET = None
if net.enumeration_size < 10e6:  # note: 10e6 == 1e7
    positions, probs = net.posteriors_via_enumeration()
    TARGET = bayesian_networks.as_tensor(positions, probs)

## Inference

Create an optimizer:

In [11]:
# Supported optimizers, keyed by the (upper-cased) OPTIMIZER setting.
OPTIMIZERS = {
    "RMS": tf.keras.optimizers.RMSprop,
    "ADAM": tf.keras.optimizers.Adam,
}

optimizer_class = OPTIMIZERS.get(OPTIMIZER)
if optimizer_class is None:
    raise ValueError("Unknown optimizer!")

optimizer = optimizer_class(learning_rate=LR)
logger.info("optimizer=%s", optimizer)

  4955 optimizer=<tensorflow.python.keras.optimizer_v2.rmsprop.RMSprop object at 0x7fdfb8e92790>


In [12]:
# Both sizes are derived from the network; they give the shape (N, K) of the logits.
K = net.cardinality        # set automatically
N = net.N - len(EVIDENCE)  # total node count minus the number of evidence entries
logger.info("N=%i K=%i", N, K)

  4960 N=4 K=2


In [13]:
# Temperature schedule of the relaxed posterior; the training loop calls
# temperature_annealing(iteration) once per step and converts the result to float.
# NOTE(review): presumably an exponential decay from BASE_TEMP at ANNEAL_RATE,
# floored at MIN_TEMP — confirm in inference.TemperatureAnnealingExp.
temperature_annealing = inference.TemperatureAnnealingExp(BASE_TEMP, ANNEAL_RATE, MIN_TEMP)

In [14]:
def straight_through_sample(y):
    """Discretise relaxed samples with the straight-through trick.

    The returned tensor has the value of the one-hot encoding of the arg-max
    along the last axis of `y`, while gradients flow as if `y` itself had
    been returned (the hard-minus-soft correction is excluded from the
    gradient computation).

    Args:
        y: tensor whose last axis enumerates the categories.

    Returns:
        Tensor with the same dtype as `y`, one-hot along the last axis.
    """
    num_categories = y.shape[-1]
    winners = tf.argmax(y, -1)
    hard = tf.cast(tf.one_hot(winners, num_categories), y.dtype)
    correction = tf.stop_gradient(hard - y)  # constant w.r.t. the gradient
    return correction + y

In [15]:
EPS = 1e-31  # added inside the logarithms below so that log(0) is never evaluated


def kl_divs(q, target, nsamples=max(100000, 5*K**N)):
    """Estimate both KL divergences between the sampled `q` and `target`.

    `q` is sampled `nsamples` times, the samples are discretised to one-hot
    vectors and turned into an empirical probability array, which is then
    compared against the ground-truth array `target`.

    Args:
        q: distribution object providing `sample(n)`.
        target: ground-truth probability array, or None when no ground truth
            is available (enumeration was skipped for a large network).
        nsamples: number of samples to draw. NOTE: the default is evaluated
            once, when this function is defined, from the K and N of that moment.

    Returns:
        Tuple `(kl, kl2, nonzeros)`:
        kl  -- sum of output * (log(output) - log(target)),
        kl2 -- sum of target * (log(target) - log(output)),
        nonzeros -- number of non-zero entries in the empirical array.
        When `target` is None nothing is sampled and `(nan, nan, -1)` is
        returned, so that callers can keep logging without a ground truth.
    """
    if target is None:
        # Without a ground truth there is nothing to compare against; avoid the
        # TypeError that `target + EPS` would raise and skip the costly sampling.
        return float("nan"), float("nan"), -1

    sample = straight_through_sample(q.sample(nsamples))
    output_probs = prob_recovery.recover_prob_array_tf_one_hot(sample)

    kl = np.sum(output_probs * (np.log(output_probs+EPS)-np.log(target+EPS)))
    kl2 = np.sum(target * (-np.log(output_probs+EPS)+np.log(target+EPS)))
    nonzeros = len(np.nonzero(output_probs.reshape(-1))[0])
    return kl, kl2, nonzeros

In [16]:
# Seed numpy's and TensorFlow's global random generators with the configured SEED
# before any random tensors (the initial logits, the samples) are drawn.
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [17]:
# Trainable parameters of the approximate posterior: one row of K logits per
# non-evidence node, initialised from a standard normal.
logits = tf.Variable(tf.random.normal((N, K)), name="logits", dtype='float32')

# Bookkeeping used by the training loop.
last_improvement = 0        # iteration at which the loss last improved
best_loss = float("inf")    # any first loss counts as an improvement (was a finite 1e6 sentinel)
time_forward = 0            # accumulated seconds spent in forward passes
time_backward = 0           # accumulated seconds spent in gradient updates
IMPROV_EPS = 0              # margin by which the loss must beat best_loss
results = []                # rows later written to the OUT csv file

In [18]:
def _save_results(rows):
    """Write all result rows collected so far to the OUT csv file (no header, no index)."""
    pd.DataFrame(rows).rename(columns=dict(enumerate(COLS))).to_csv(OUT, header=False, index=False)


for iteration in range(MAX_NITER):

    # ---- forward pass: Monte-Carlo estimate of the relaxed ELBO ----
    start_op_time = time.time()
    with tf.GradientTape() as tape:
        temp = float(temperature_annealing(iteration))
        q = tfd.RelaxedOneHotCategorical(temp, logits=logits)

        sample = q.sample(NSAMPLES)

        # Entropy estimate: negated mean of the log-densities summed over the last axis.
        # The tiny offset presumably keeps exact zeros out of log_prob.
        entropy = -tf.reduce_mean(tf.reduce_sum(q.log_prob(sample + 1e-31), -1))
        sample = straight_through_sample(sample) if ST else sample
        sample = net.set_evidence(sample)
        evidence = tf.reduce_mean(net.log_prob_evidence(sample))
        priors = tf.reduce_mean(net.log_relaxed_prob_priors(sample, PRIORS_TEMP, eps=1e-10))

        elbo = entropy + evidence + priors
        loss = -elbo

    time_forward += time.time() - start_op_time

    # Track the best solution. best_loss has to be updated here, otherwise every
    # iteration would count as an improvement and last_improvement would be useless.
    improved = loss + IMPROV_EPS < best_loss
    if improved:
        best_loss = float(loss)
        last_improvement = iteration

    # ---- periodic evaluation; must be invoked before the approximation update ----
    if iteration%50 == 0 or (iteration < 50 and iteration%10 == 0) or iteration<10:
        kl1, kl2, nonzeros = kl_divs(q, TARGET)
        print("%i. t_q=%.3f loss=%.3f p(x|z)=%.3f p(z;t=%s)=%.3f H(q)=%.3f (true: kl1=%.3f kl2=%.3f nz(q)=%d)" % \
              (iteration, temp, loss.numpy(), evidence, PRIORS_TEMP, priors, entropy, kl1, kl2, nonzeros))

        results.append(CFG+[iteration, temp, float(loss), float(evidence), float(priors), float(entropy),
                            float(kl1), float(kl2), nonzeros, time_forward, time_backward])
        if iteration%1000 == 0:
            _save_results(results)

    # ---- backward pass: gradient step on the logits ----
    start_op_time = time.time()
    grads = tape.gradient(loss, q.trainable_variables)
    optimizer.apply_gradients(zip(grads, q.trainable_variables))
    logits = q.logits
    time_backward += time.time() - start_op_time

    # Early stopping is left disabled, as in the original experiment:
#     if iteration >= 0 and NOIMPROV_NITER < iteration - last_improvement:
#         logger.info("[VariationalInference.fit] No improvement in recent %i iterations. Stop." % \
#                     NOIMPROV_NITER)
#         break

# Final flush: rows recorded after the last multiple of 1000 would otherwise
# never reach the file when MAX_NITER-1 is not a multiple of 1000.
if results:
    _save_results(results)

0. t_q=1.000 loss=12.285 p(x|z)=-2.984 p(z;t=1.0)=-7.668 H(q)=-1.633 (true: kl1=7.562 kl2=2.467 nz(q)=16)
1. t_q=1.000 loss=10.584 p(x|z)=-2.808 p(z;t=1.0)=-6.572 H(q)=-1.203 (true: kl1=7.313 kl2=2.374 nz(q)=16)
2. t_q=1.000 loss=10.746 p(x|z)=-2.546 p(z;t=1.0)=-7.216 H(q)=-0.984 (true: kl1=7.171 kl2=2.316 nz(q)=16)
3. t_q=1.000 loss=10.889 p(x|z)=-2.816 p(z;t=1.0)=-6.806 H(q)=-1.267 (true: kl1=7.021 kl2=2.272 nz(q)=16)
4. t_q=1.000 loss=10.747 p(x|z)=-2.555 p(z;t=1.0)=-7.185 H(q)=-1.007 (true: kl1=6.882 kl2=2.222 nz(q)=16)
5. t_q=1.000 loss=10.706 p(x|z)=-2.785 p(z;t=1.0)=-6.733 H(q)=-1.189 (true: kl1=6.825 kl2=2.215 nz(q)=16)
6. t_q=1.000 loss=10.002 p(x|z)=-2.494 p(z;t=1.0)=-6.598 H(q)=-0.910 (true: kl1=6.694 kl2=2.165 nz(q)=16)
7. t_q=1.000 loss=10.352 p(x|z)=-2.649 p(z;t=1.0)=-6.688 H(q)=-1.016 (true: kl1=6.606 kl2=2.140 nz(q)=16)
8. t_q=1.000 loss=10.035 p(x|z)=-2.670 p(z;t=1.0)=-6.340 H(q)=-1.025 (true: kl1=6.509 kl2=2.121 nz(q)=16)
9. t_q=1.000 loss=9.690 p(x|z)=-2.623 p(z;t=1.

3200. t_q=1.000 loss=3.034 p(x|z)=-1.738 p(z;t=1.0)=2.269 H(q)=-3.566 (true: kl1=1.429 kl2=1.997 nz(q)=16)
3250. t_q=1.000 loss=2.950 p(x|z)=-1.656 p(z;t=1.0)=2.452 H(q)=-3.745 (true: kl1=1.419 kl2=1.997 nz(q)=16)
3300. t_q=1.000 loss=3.152 p(x|z)=-1.833 p(z;t=1.0)=2.455 H(q)=-3.774 (true: kl1=1.409 kl2=1.988 nz(q)=16)
3350. t_q=1.000 loss=2.969 p(x|z)=-1.753 p(z;t=1.0)=2.170 H(q)=-3.386 (true: kl1=1.409 kl2=2.043 nz(q)=16)
3400. t_q=1.000 loss=3.049 p(x|z)=-1.673 p(z;t=1.0)=2.612 H(q)=-3.988 (true: kl1=1.429 kl2=1.967 nz(q)=16)
3450. t_q=1.000 loss=3.005 p(x|z)=-1.699 p(z;t=1.0)=2.230 H(q)=-3.535 (true: kl1=1.395 kl2=1.963 nz(q)=16)
3500. t_q=1.000 loss=3.058 p(x|z)=-1.676 p(z;t=1.0)=2.260 H(q)=-3.642 (true: kl1=1.404 kl2=1.988 nz(q)=16)
3550. t_q=1.000 loss=2.905 p(x|z)=-1.759 p(z;t=1.0)=3.068 H(q)=-4.214 (true: kl1=1.424 kl2=2.023 nz(q)=16)
3600. t_q=1.000 loss=3.004 p(x|z)=-1.694 p(z;t=1.0)=2.047 H(q)=-3.356 (true: kl1=1.412 kl2=1.996 nz(q)=16)
3650. t_q=1.000 loss=3.151 p(x|z)=-1.

7050. t_q=1.000 loss=2.929 p(x|z)=-1.848 p(z;t=1.0)=2.587 H(q)=-3.668 (true: kl1=1.399 kl2=2.006 nz(q)=16)
7100. t_q=1.000 loss=2.950 p(x|z)=-1.801 p(z;t=1.0)=2.263 H(q)=-3.412 (true: kl1=1.401 kl2=1.959 nz(q)=16)
7150. t_q=1.000 loss=2.883 p(x|z)=-1.642 p(z;t=1.0)=2.446 H(q)=-3.686 (true: kl1=1.424 kl2=2.014 nz(q)=16)
7200. t_q=1.000 loss=2.928 p(x|z)=-1.795 p(z;t=1.0)=2.228 H(q)=-3.361 (true: kl1=1.396 kl2=1.964 nz(q)=16)
7250. t_q=1.000 loss=2.967 p(x|z)=-1.600 p(z;t=1.0)=2.318 H(q)=-3.685 (true: kl1=1.415 kl2=2.025 nz(q)=16)
7300. t_q=1.000 loss=2.989 p(x|z)=-1.848 p(z;t=1.0)=2.386 H(q)=-3.526 (true: kl1=1.410 kl2=2.033 nz(q)=16)
7350. t_q=1.000 loss=2.980 p(x|z)=-1.812 p(z;t=1.0)=2.487 H(q)=-3.656 (true: kl1=1.404 kl2=1.981 nz(q)=16)
7400. t_q=1.000 loss=2.909 p(x|z)=-1.535 p(z;t=1.0)=2.042 H(q)=-3.416 (true: kl1=1.398 kl2=1.990 nz(q)=16)
7450. t_q=1.000 loss=3.119 p(x|z)=-1.829 p(z;t=1.0)=2.136 H(q)=-3.427 (true: kl1=1.403 kl2=1.979 nz(q)=16)
7500. t_q=1.000 loss=3.046 p(x|z)=-1.

In [19]:
# Final sanity check: empirical distribution of discretised samples from q next
# to the enumerated ground truth. Only possible when TARGET was enumerated;
# for large networks TARGET is None and `output_probs + TARGET` would raise.
if TARGET is None:
    print("No enumerated target distribution available; comparison skipped.")
else:
    print("Compare recovered vs target distribution (on values > eps):")
    output_probs = prob_recovery.recover_prob_array_tf_one_hot(straight_through_sample(q.sample(1000000)))
    # Show only configurations where either distribution has non-negligible mass.
    mask = (output_probs + TARGET > 1e-6)
    print(np.round(output_probs[mask].reshape(-1), 3), "\n", np.round(TARGET[mask].reshape(-1), 3))

Compare recovered vs target distribution (on values > eps):
[0.    0.    0.002 0.005 0.003 0.007 0.084 0.18  0.    0.    0.006 0.012
 0.008 0.017 0.215 0.46 ] 
 [0.006 0.001 0.275 0.031 0.171 0.019 0.029 0.003 0.    0.    0.    0.
 0.    0.006 0.023 0.436]
