# Variational Learning of Posteriors for Discrete Bayesian Networks using Mixture of Discrete Normalizing Flows

A Bayesian network (BN) represents a joint distribution of random variables factorized according to a directed
acyclic graph (DAG) that determines their conditional independencies. For BNs with latent nodes, determining the joint posterior of those nodes is difficult even if the structure is known. For sufficiently small networks the true posterior can be evaluated by direct enumeration of all configurations, at a cost exponential in the number of latent nodes. We use this as the ground truth and compare it against approximate posteriors represented with a mixture of discrete normalizing flows (MDNF).

## Imports

In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd

import copy
import time

import gc

In [2]:
import sys
sys.path.append('../mdnf')

In [3]:
import aux
import time_profiling

import base_constructors
import flows_mixture
import flows
import prob_recovery
import bayesian_networks
import inference

In [4]:
import logging
# Module-level logger used by all cells below. Every record is prefixed with
# %(relativeCreated): milliseconds elapsed since the logging module was loaded.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(relativeCreated)6d %(message)s')

In [5]:
# Report which TensorFlow version is in use.
logger.info("TF version=%s", tf.__version__)

  3477 TF version=2.2.0


## Configuration

In [6]:
# can be run as a script with args in format KEY=VAL,KEY=[STRVAL],...
# Every setting below is read from `args` with .get(NAME, default).
args = aux.parse_script_args() 

  3494 parsing: <-f>


In [7]:
SEED = args.get("SEED", 1) # fix randomness (used to seed numpy and TensorFlow)

# Bayesian network file and the observed (evidence) nodes.
MODEL = args.get("MODEL", "bnets/asia.bif")
EVIDENCE = args.get("EVIDENCE", "asia-yes/xray-yes")
# Alternative network/evidence pairs:
# MODEL = args.get("MODEL", "bnets/sachs.bif")
# EVIDENCE = args.get("EVIDENCE", "Akt-LOW")
# MODEL = args.get("MODEL", "bnets/hepar2.bif")
# EVIDENCE = args.get("EVIDENCE", "carcinoma-present")

# "node-value/node-value/..." string -> mapping (iterated with .items() when the configuration is stored)
EVIDENCE = aux.parse_dict(EVIDENCE, entries_separator="/", key2val_separator="-")  

# num categories, None = select automatically from Bayesian Network
K = args.get("K", None) 

# see create_base_mixture in base_constructors.py for options
# NOTE(review): the base is actually built with base_constructors.create_categorical_blocks
# further below -- confirm which function documents the accepted specifications.
BASE_SPECIFICATION = args.get("BASE_SPECIFICATION", "D(0.01)") 
B = args.get("B", 8) # how many flows in mixture (None or <=0 is later replaced by K**N)
FLOW_TYPE = args.get("FLOW_TYPE", "M") # transformation type

# structure of transformation network
HIDDEN_LAYERS = args.get("HIDDEN_LAYERS", 1) 
HIDDEN_NODES_PER_VARIABLE = args.get("HIDDEN_NODES_PER_VARIABLE", 1)

# 0=VIF, 1=BVIF/BVI; the complete id -> class mapping is ID2INFERENCE
INFERENCE = args.get("INFERENCE", 0)
TRAIN_FLOWS = bool(args.get("TRAIN_FLOWS", 1)) # if 0 only weights will be trained (= BVI)

# OPTIMIZER is upper-cased so that it can be matched against the OPTIMIZERS keys.
OPTIMIZER = args.get("OPTIMIZER", "RMS").upper()
LR = args.get("LR", 0.01)

# used only if weights are optimized separately (otherwise ignored);
# LR2 is the learning rate of optimizer_weights
LR2 = args.get("LR2", 0.1) 
SWITCH_NITER = args.get("SWITCH_NITER", 40) 

NSAMPLES = args.get("NSAMPLES", 100) # how many samples to estimate ELBO
# upper bound on the number of iterations (passed to the inference object as max_niter)
MAX_NITER = args.get("MAX_NITER", 100) 
NOIMPROV_NITER = args.get("NOIMPROV_NITER", 50) # stop if no improvement seen in niters

# temperature settings (passed to inference.TemperatureAnnealingExp in this order)
BASE_TEMP = args.get("BASE_TEMP", 0.1) 
ANNEAL_RATE = args.get("ANNEAL_RATE", 0.0)
MIN_TEMP = args.get("MIN_TEMP", 0.001)

# where to save results
# Default name: BN_AVI_<model>_<base spec>_<INFERENCE><TRAIN_FLOWS>_<B>_<SEED>.csv
# NOTE(review): INFERENCE and int(TRAIN_FLOWS) are written back-to-back without a
# separator (e.g. "01") -- confirm this is the intended naming.
OUT = args.get("OUT", "BN_AVI_%s_%s_%s%s_%s_%s.csv" % \
               (MODEL.replace("/", "_").replace(".bif", ""), BASE_SPECIFICATION, 
                INFERENCE, int(TRAIN_FLOWS), B, SEED))
logger.info("Results output file: %s" % OUT)

  3525 Results output file: BN_AVI_bnets_asia_D(0.01)_01_8_1.csv


In [8]:
# Store & print configuration.
# Each setting is declared once as a (name, value) pair; the pairs are then
# split into the two parallel lists CFGNAMES and CFG, so they cannot drift apart.
CFGNAMES, CFG = (list(column) for column in zip(*[
    ("MODEL", MODEL),
    ("EVIDENCE", " ".join("%s=%s" % (k, v) for k, v in EVIDENCE.items())),
    ("SEED", SEED),
    ("BASE_SPECIFICATION", BASE_SPECIFICATION),
    ("B", B),
    ("FLOW_TYPE", FLOW_TYPE),
    ("HIDDEN_NODES_PER_VARIABLE", HIDDEN_NODES_PER_VARIABLE),
    ("HIDDEN_LAYERS", HIDDEN_LAYERS),
    ("INFERENCE", int(INFERENCE)),
    ("TRAIN_FLOWS", int(TRAIN_FLOWS)),
    ("OPTIMIZER", OPTIMIZER),
    ("LR", LR),
    ("LR2", LR2),
    ("NSAMPLES", NSAMPLES),
    ("MAX_NITER", MAX_NITER),
    ("NOIMPROV_NITER", NOIMPROV_NITER),
    ("SWITCH_NITER", SWITCH_NITER),
    ("BASE_TEMP", BASE_TEMP),
    ("ANNEAL_RATE", ANNEAL_RATE),
    ("MIN_TEMP", MIN_TEMP),
]))

# One "NAME=value" line per setting.
logger.info("CONFIGURATION:\n " + "\n ".join(
    "%s=%s" % (cfg_name, cfg_value) for cfg_name, cfg_value in zip(CFGNAMES, CFG)))

  3539 CONFIGURATION:
 MODEL=bnets/asia.bif
 EVIDENCE=asia=yes xray=yes
 SEED=1
 BASE_SPECIFICATION=D(0.01)
 B=8
 FLOW_TYPE=M
 HIDDEN_NODES_PER_VARIABLE=1
 HIDDEN_LAYERS=1
 INFERENCE=0
 TRAIN_FLOWS=1
 OPTIMIZER=RMS
 LR=0.01
 LR2=0.1
 NSAMPLES=100
 MAX_NITER=100
 NOIMPROV_NITER=50
 SWITCH_NITER=40
 BASE_TEMP=0.1
 ANNEAL_RATE=0.0
 MIN_TEMP=0.001


Create an optimizer:

In [9]:
# Optimizer classes selectable through the (upper-cased) OPTIMIZER setting.
OPTIMIZERS = {
    "RMS": tf.keras.optimizers.RMSprop,
    "ADAM": tf.keras.optimizers.Adam,
}
if OPTIMIZER in OPTIMIZERS:
    optimizer_class = OPTIMIZERS[OPTIMIZER]
else:
    raise ValueError("Unknown optimizer!")

# Two independent instances of the same class, differing only in learning rate.
optimizer = optimizer_class(learning_rate=LR)
optimizer_weights = optimizer_class(learning_rate=LR2)
logger.info("optimizer=%s optimizer_weights=%s", optimizer, optimizer_weights)

  3551 optimizer=<tensorflow.python.keras.optimizer_v2.rmsprop.RMSprop object at 0x7f5c9a25cb10> optimizer_weights=<tensorflow.python.keras.optimizer_v2.rmsprop.RMSprop object at 0x7f5c9a25cad0>


## Network & evidence

Load Bayesian network and fix evidence.

In [10]:
# Build the network object from the EVIDENCE mapping and the MODEL file. `net` is used
# below for log_prob/set_evidence, posterior enumeration, and the N / cardinality values.
net = bayesian_networks.BayesianNetworkVI(EVIDENCE, MODEL)

  7284 +-----------+------+
| asia(yes) | 0.01 |
+-----------+------+
| asia(no)  | 0.99 |
+-----------+------+
  7285 +------------+------------+-----------+
| smoke      | smoke(yes) | smoke(no) |
+------------+------------+-----------+
| bronc(yes) | 0.6        | 0.3       |
+------------+------------+-----------+
| bronc(no)  | 0.4        | 0.7       |
+------------+------------+-----------+
  7287 +-----------+-------------+------------+-------------+------------+
| bronc     | bronc(yes)  | bronc(yes) | bronc(no)   | bronc(no)  |
+-----------+-------------+------------+-------------+------------+
| either    | either(yes) | either(no) | either(yes) | either(no) |
+-----------+-------------+------------+-------------+------------+
| dysp(yes) | 0.9         | 0.8        | 0.7         | 0.1        |
+-----------+-------------+------------+-------------+------------+
| dysp(no)  | 0.1         | 0.2        | 0.3         | 0.9        |
+-----------+-------------+------------+----------

Evaluate the model's joint log-probability with the network (the evidence is set on each sample first)


In [11]:
def log_prob(sample):
    """Return net.log_prob of `sample` after net.set_evidence has been applied to it."""
    return net.log_prob(net.set_evidence(sample))

If feasible, obtain the exact posterior by enumeration

In [12]:
# The exact posterior is enumerated only when the number of configurations is below the limit.
# NOTE(review): 10e6 equals 1e7 (ten million), not 1e6 -- confirm the intended limit.
if net.enumeration_size < 10e6:
    positions, probs = net.posteriors_via_enumeration()
    TARGET = bayesian_networks.as_tensor(positions, probs)
else:
    # No ground truth available; the final KL evaluation checks for this.
    TARGET = None

## Recording configuration & results

In [13]:
def _store_results(RESULTS, COLS, CFG, CFGNAMES, OUT):
    """Write the recorded result rows, extended with the configuration, to a CSV file.

    Args:
        RESULTS: sequence of result rows (tuples); the i-th field of every row
            is stored under the name COLS[i].
        COLS: column names for the fields of the result rows.
        CFG: configuration values; each one is stringified and repeated in
            every row.
        CFGNAMES: column names for the configuration values (paired with CFG
            by position).
        OUT: path of the CSV file to write (with header, without index).
    """
    results_pd = pd.DataFrame(RESULTS).rename(columns=dict(enumerate(COLS)))
    if len(RESULTS) == 0:
        # Without any row there are no positional columns to rename, which
        # would silently drop all result columns from the header. Keep the
        # file layout stable by declaring the columns explicitly.
        results_pd = pd.DataFrame(columns=list(COLS))
    for name, val in zip(CFGNAMES, CFG):
        results_pd[name] = str(val)

    logger.info("Writing %i data rows to: %s" % (len(RESULTS), OUT))
    results_pd.to_csv(OUT, header=True, index=False)

In [14]:
# Rows appended by record_status; the fields of each row follow the order of COLS.
RESULTS = []

# what to put into results    
COLS = ["wall_time", "time", "time_forward", "time_backward",
        "C", "iteration", "temp", "loss", "kl", "best_loss", "best_kl", "kl2"]   

# Reference point for the elapsed (wall) time reported by record_status.
start_time = time.time()
# Lowest loss seen so far and the kl measured at that point (updated by record_status).
best_loss, best_kl = float("inf"), float("inf")

def record_status(status, iteration, loss): # callback function
    """Log the optimization progress and append one row to RESULTS.

    Passed as the `callback` of `vi.fit`. Iterations that do not improve the
    best loss are recorded only while `iteration <= 10` and afterwards on
    every 10th iteration; all other calls return immediately.

    Args:
        status: object describing the optimizer state. The attributes read
            here are `base` (with `mixing_probs`), `flow` (with
            `temperature`), `time_forward`, `time_backward` and, if present,
            `C`.
        iteration: iteration number reported by the inference loop.
        loss: loss value at this iteration.

    Side effects:
        Updates the globals `best_loss` / `best_kl` when `loss` improves and
        appends a tuple ordered like COLS to the global RESULTS.
    """
    global best_loss, best_kl
    improved = loss < best_loss
    if not improved and iteration > 10 and iteration % 10 != 0:
        return

    kl, kl2, _ = prob_recovery.kl_divergences(status.base, status.flow, TARGET)
    if improved:
        best_loss, best_kl = loss, kl

    # `C` is optional on the status object. getattr with a default handles the
    # missing attribute without a bare `except:`, which would also swallow
    # KeyboardInterrupt and SystemExit.
    C = getattr(status, "C", "")
    temp = status.flow.temperature
    logger.info(("[%.0fs](%s) %s:%i. loss=%.3f " +
           "kl=%.2f (best: %.3f/%.2f) kl2=%.2f temp=%.4f\n\tmixing=%s") % \
          (time.time() - start_time, ("*" if improved else " "), C, iteration,
           loss, kl, best_loss, best_kl, kl2,
           # "%.4f" cannot format None, so a missing temperature is shown as nan
           temp if temp is not None else float("nan"),
           # the mixing weights are truncated to keep the log line short
           str(np.round(status.base.mixing_probs, 2))[:200]))

    # Field order must match COLS.
    RESULTS.append((time.time() - start_time,
                    status.time_forward + status.time_backward,
                    status.time_forward, status.time_backward,
                    C, iteration, status.flow.temperature,
                    loss, kl, best_loss, best_kl, kl2))

## Approximating distribution: set bases and flows

In [15]:
# Seed both random number generators before building the approximation.
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Number of categories: taken from the network unless given explicitly.
if K is None:
    K = net.cardinality
# net.N reduced by the number of evidence entries.
N = net.N - len(EVIDENCE)
# A missing or non-positive B is replaced by K**N.
if B is None or B <= 0:
    B = K ** N
logger.info("N=%i K=%i B=%i", N, K, B)

  7451 N=6 K=2 B=8


Construct bases

In [16]:
base = base_constructors.create_categorical_blocks(N, K, B, BASE_SPECIFICATION)

# Dump the probabilities of every base component (emitted at DEBUG level only).
for component_no, component in enumerate(base.components):
    logger.debug("base no%s:\n%s", component_no, np.round(component.probs, 3))

Construct flows

In [17]:
# Build one flow per mixture component (B in total), then wrap them in a mixture.
if not TRAIN_FLOWS:
    # Flows are not trained: use DummyFlow placeholders.
    mixture_flows = [flows.DummyFlow(temperature=BASE_TEMP) for _ in range(B)]     
elif FLOW_TYPE in ["F", "FU"]:
    # HIDDEN_LAYERS layers of type FLOW_TYPE, each specified as (FLOW_TYPE, None).
    # NOTE(review): BASE_TEMP is not passed in this branch, unlike the other two -- confirm intended.
    mixture_flows = [flows.DiscreteFlow(N,K, layers=[(FLOW_TYPE, None)]*HIDDEN_LAYERS) 
                     for _ in range(B)]
else:
    # Hidden layer width scales with the number of variables N.
    HIDDEN_NODES = int(N*HIDDEN_NODES_PER_VARIABLE) 
    mixture_flows = flows.parse_layers_specification([(FLOW_TYPE, [HIDDEN_NODES]*HIDDEN_LAYERS)]*B,
                     N, K, temperature=BASE_TEMP)

flow = flows_mixture.DiscreteFlowsMixture(N, K, B, flows=mixture_flows)

In [18]:
# Best-effort diagnostics: dump the joint probability array of every base
# distribution at DEBUG level. A failure here must not stop the experiment,
# so any exception is reported as a warning and the dump is skipped.
try:
    for i, d1 in enumerate(base.distributions):
        logger.debug("probabilities of component no %i:" % i)
        logger.debug(np.round(d1.get_joint_probability_array(), 3))
except Exception as e:
    # logger.warning replaces the deprecated logger.warn alias.
    logger.warning(" Failed: %s -> Skipping..." % e)

#logger.info("Trainable variables:\n %s" % "\n ".join([v.name for v in flow.trainable_variables]))

  
  7576  Failed: Outer product currently implemented only for 2D and 3D arrays! -> Skipping...


# Inference

In [19]:
# Seed both generators again right before inference -- presumably so that the
# optimization does not depend on how many random numbers were drawn while the
# bases and flows were constructed.
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [20]:
# Inference algorithms selectable through the integer INFERENCE setting.
ID2INFERENCE = {
    0: inference.VariationalInference,
    1: inference.BoostingVariationalInference,
    2: inference.IterativeVariationalInference,
    3: inference.BoostingVariationalInferenceAltering,
    4: inference.BoostingVariationalInferenceAlteringIndep,
}
inference_class = ID2INFERENCE[INFERENCE]

# All inference classes are constructed with the same keyword arguments.
vi = inference_class(
    log_prob=log_prob,
    base=base,
    flow=flow,
    temperature_annealing=inference.TemperatureAnnealingExp(BASE_TEMP, ANNEAL_RATE, MIN_TEMP),
    nsamples=NSAMPLES,
    max_niter=MAX_NITER,
    noimprov_niter=NOIMPROV_NITER,
    optimizer=optimizer,
    optimizer_weights=optimizer_weights,
    switch_niter=SWITCH_NITER,
)

logger.info("VI inference type: %s", vi)

  7599 VI inference type: <inference.VariationalInference object at 0x7f5c980ccc50>


In [21]:
# Run the optimization; record_status is invoked as the progress callback.
total_niters = vi.fit(callback=record_status)

  8045 [1s](*) :0. loss=23.168 kl=14.99 (best: 23.168/14.99) kl2=7.41 temp=0.1000
	mixing=[0.12 0.12 0.12 0.12 0.12 0.12 0.12 0.12]
  8564 [1s](*) :1. loss=13.970 kl=9.07 (best: 13.970/9.07) kl2=6.60 temp=0.1000
	mixing=[0.12 0.12 0.12 0.12 0.12 0.12 0.12 0.12]
  8940 [2s]( ) :2. loss=14.460 kl=9.72 (best: 13.970/9.07) kl2=5.63 temp=0.1000
	mixing=[0.12 0.12 0.12 0.12 0.12 0.12 0.12 0.12]
  9284 [2s]( ) :3. loss=20.794 kl=11.79 (best: 13.970/9.07) kl2=7.21 temp=0.1000
	mixing=[0.12 0.12 0.12 0.12 0.12 0.12 0.12 0.12]
  9623 [2s]( ) :4. loss=18.740 kl=10.46 (best: 13.970/9.07) kl2=6.68 temp=0.1000
	mixing=[0.12 0.12 0.12 0.12 0.12 0.12 0.12 0.12]
 10104 [3s](*) :5. loss=8.900 kl=2.20 (best: 8.900/2.20) kl2=6.04 temp=0.1000
	mixing=[0.12 0.12 0.12 0.12 0.12 0.12 0.12 0.12]
 10587 [3s](*) :6. loss=8.574 kl=1.96 (best: 8.574/1.96) kl2=5.70 temp=0.1000
	mixing=[0.12 0.12 0.12 0.12 0.12 0.12 0.12 0.12]
 10966 [4s]( ) :7. loss=11.896 kl=5.44 (best: 8.574/1.96) kl2=6.11 temp=0.1000
	mixing=[0.

In [22]:
# Persist the recorded rows together with the configuration columns to OUT.
_store_results(RESULTS, COLS, CFG, CFGNAMES, OUT)

 34151 Writing 21 data rows to: BN_AVI_bnets_asia_D(0.01)_01_8_1.csv


In [23]:
# Final evaluation of the best base/flow pair against the enumerated posterior
# (skipped when no enumerated posterior is available).
if TARGET is not None:
    base = vi.best_base
    flow = vi.best_flow
    kl, kl2, flow_output_probs = prob_recovery.kl_divergences(base, flow, TARGET)
    logger.info("kl=%.3f kl2=%.3f", kl, kl2)

 34219 kl=1.185 kl2=5.524


## Time measurements

In [24]:
# Print the report produced by time_profiling.get_report().
print(time_profiling.get_report())

                                        func  count      total    median  \
0                  DiscreteFlowsMixture.call     76  11.137210  0.138255   
1               DiscreteFlowsMixture.reverse     98   2.366740  0.021829   
2  FactorizedCategoricalMixture.log_prob_ext     98   0.048434  0.000474   
3   FactorizedCategoricalMixture.sample_extm     76   0.212890  0.002547   

       mean       min       max      q=.8  #max  
0  0.146542  0.126653  0.293996  0.157641     2  
1  0.024150  0.018183  0.143198  0.025152     1  
2  0.000494  0.000351  0.000907  0.000524     4  
3  0.002801  0.002316  0.007165  0.002985     1  
