# Variational Learning of Posteriors for Discrete Bayesian Networks using Mixture of Discrete Normalizing Flows (assumes factorized posteriors)

A Bayesian network (BN) represents a joint distribution of random variables factorized according to a directed
acyclic graph (DAG) that determines their conditional independencies. For BNs with latent nodes, determining the joint posterior over those nodes is difficult even if the structure is known. For sufficiently small networks, the true posterior can be evaluated by direct enumeration of all configurations, at a cost exponential in the number of latent nodes. We use this as ground truth and compare it against approximate posteriors represented with a mixture of discrete normalizing flows (MDNF).

## Imports

In [1]:
#!pip install pgmpy==0.1.10

In [2]:
import tensorflow as tf
import numpy as np
import pandas as pd

import copy
import time

import gc

In [3]:
import sys
sys.path.append('../mdnf')

In [4]:
import aux
import time_profiling

import base_constructors
import flows_mixture
import flows
import prob_recovery
import bayesian_networks
import inference



In [5]:
import logging
# Module-level logger used by every cell below.
logger = logging.getLogger(__name__)
# INFO level; each record is prefixed with %(relativeCreated), i.e. the number of
# milliseconds elapsed since the logging module was loaded.
logging.basicConfig(level=logging.INFO, format='%(relativeCreated)6d %(message)s')

In [6]:
# Record the TensorFlow version in the run log (formatting deferred to the logger).
logger.info("TF version=%s", tf.__version__)

  4904 TF version=2.3.0


## Configuration

In [7]:
# can be run as a script with args in format KEY=VAL,KEY=[STRVAL],...
# `args` is then queried with args.get(NAME, default) for every setting below.
args = aux.parse_script_args()

  4909 parsing: <-f>


In [8]:
# Every setting below may be overridden through `args`; the second argument of
# args.get(...) is the default used when the key is absent.
SEED = args.get("SEED", 0) # fix randomness (used to seed both NumPy and TensorFlow)

# Network file and evidence. Evidence is written as VARIABLE-VALUE pairs joined
# by "/" (see the parse_dict call below). The commented pairs are alternative
# defaults for other networks.
# MODEL = args.get("MODEL", "bnets/cancer.bif")
# EVIDENCE = args.get("EVIDENCE", "Cancer-True")
MODEL = args.get("MODEL", "bnets/earthquake.bif")
EVIDENCE = args.get("EVIDENCE", "MaryCalls-True")
# MODEL = args.get("MODEL", "bnets/asia.bif")
# EVIDENCE = args.get("EVIDENCE", "asia-yes/xray-yes")
# MODEL = args.get("MODEL", "bnets/sachs.bif")
# EVIDENCE = args.get("EVIDENCE", "Akt-LOW")
# MODEL = args.get("MODEL", "bnets/hepar2.bif")
# EVIDENCE = args.get("EVIDENCE", "carcinoma-present")

# Turn the evidence string into a dict (entries split on "/", key and value on "-").
EVIDENCE = aux.parse_dict(EVIDENCE, entries_separator="/", key2val_separator="-")  

# see create_base_mixture in base_constructors.py for options
# NOTE(review): this notebook builds the base with
# base_constructors.create_categorical_blocks -- confirm where the options are documented.
BASE_SPECIFICATION = args.get("BASE_SPECIFICATION", "p") 
B = args.get("B", 100) # how many flows in mixture (None or <=0 is later replaced by K**N)
FLOW_TYPE = args.get("FLOW_TYPE", "F") # transformation type

# structure of transformation network
HIDDEN_LAYERS = args.get("HIDDEN_LAYERS", 1) 
HIDDEN_NODES_PER_VARIABLE = args.get("HIDDEN_NODES_PER_VARIABLE", 1)

# Inference algorithm id; ids 2-4 select further classes, see ID2INFERENCE.
INFERENCE = args.get("INFERENCE", 0) # 0=VIF, 1=BVIF/BVI
TRAIN_FLOWS = bool(args.get("TRAIN_FLOWS", 1)) # if 0 only weights will be trained (= BVI)

# Optimizer name is upper-cased before lookup; "RMS" and "ADAM" are recognised.
OPTIMIZER = args.get("OPTIMIZER", "RMS").upper()
LR = args.get("LR", 0.01)

# IGNORED (used only if weights are optimized separately)
# LR2 is the learning rate of the separate weights optimizer.
LR2 = args.get("LR2", 0.1) 
SWITCH_NITER = args.get("SWITCH_NITER", 40) 

NSAMPLES = args.get("NSAMPLES", 100) # how many samples to estimate ELBO
MAX_NITER = args.get("MAX_NITER", 10000) 
NOIMPROV_NITER = args.get("NOIMPROV_NITER", 1000) # stop if no improvement seen in niters

# temperature settings (passed, in this order, to inference.TemperatureAnnealingExp)
BASE_TEMP = args.get("BASE_TEMP", 10.0) 
ANNEAL_RATE = args.get("ANNEAL_RATE", 0.01)
MIN_TEMP = args.get("MIN_TEMP", 0.001)

# where to save results
# Default name encodes: model path ("/" -> "_", ".bif" dropped), base
# specification, inference id immediately followed by the 0/1 train-flows flag,
# B and the seed.
OUT = args.get("OUT", "BN_AVI_%s_%s_%s%s_%s_%s.csv" % \
               (MODEL.replace("/", "_").replace(".bif", ""), BASE_SPECIFICATION, 
                INFERENCE, int(TRAIN_FLOWS), B, SEED))
logger.info("Results output file: %s" % OUT)

  4919 Results output file: BN_AVI_bnets_earthquake_p_01_100_0.csv


In [9]:
# Store & print configuration.
# Each setting is declared once as a (name, value) pair; the two parallel lists
# used when writing results are derived from it, so they cannot drift apart.
_CFG_ITEMS = [
    ("MODEL", MODEL),
    ("EVIDENCE", " ".join("%s=%s" % (k,v) for k,v in EVIDENCE.items())),
    ("SEED", SEED),
    ("BASE_SPECIFICATION", BASE_SPECIFICATION),
    ("B", B),
    ("FLOW_TYPE", FLOW_TYPE),
    ("HIDDEN_NODES_PER_VARIABLE", HIDDEN_NODES_PER_VARIABLE),
    ("HIDDEN_LAYERS", HIDDEN_LAYERS),
    ("INFERENCE", int(INFERENCE)),
    ("TRAIN_FLOWS", int(TRAIN_FLOWS)),
    ("OPTIMIZER", OPTIMIZER),
    ("LR", LR),
    ("LR2", LR2),
    ("NSAMPLES", NSAMPLES),
    ("MAX_NITER", MAX_NITER),
    ("NOIMPROV_NITER", NOIMPROV_NITER),
    ("SWITCH_NITER", SWITCH_NITER),
    ("BASE_TEMP", BASE_TEMP),
    ("ANNEAL_RATE", ANNEAL_RATE),
    ("MIN_TEMP", MIN_TEMP),
]
CFG = [val for _, val in _CFG_ITEMS]
CFGNAMES = [name for name, _ in _CFG_ITEMS]

# One "NAME=value" line per setting.
_cfg_lines = ["%s=%s" % (name, val) for name, val in zip(CFGNAMES, CFG)]
logger.info("CONFIGURATION:\n "+"\n ".join(_cfg_lines))

  4925 CONFIGURATION:
 MODEL=bnets/earthquake.bif
 EVIDENCE=MaryCalls=True
 SEED=0
 BASE_SPECIFICATION=p
 B=100
 FLOW_TYPE=F
 HIDDEN_NODES_PER_VARIABLE=1
 HIDDEN_LAYERS=1
 INFERENCE=0
 TRAIN_FLOWS=1
 OPTIMIZER=RMS
 LR=0.01
 LR2=0.1
 NSAMPLES=100
 MAX_NITER=10000
 NOIMPROV_NITER=1000
 SWITCH_NITER=40
 BASE_TEMP=10.0
 ANNEAL_RATE=0.01
 MIN_TEMP=0.001


## Network & evidence

Load Bayesian network and fix evidence.

In [10]:
# Build the network object from the evidence dict and the model file path; it is
# queried below for cardinality, N, enumeration_size, log_prob and set_evidence.
net = bayesian_networks.BayesianNetworkVI(EVIDENCE, MODEL)

  6118 +--------------+------------------+-------------------+------------------+-------------------+
| Burglary     | Burglary(True)   | Burglary(True)    | Burglary(False)  | Burglary(False)   |
+--------------+------------------+-------------------+------------------+-------------------+
| Earthquake   | Earthquake(True) | Earthquake(False) | Earthquake(True) | Earthquake(False) |
+--------------+------------------+-------------------+------------------+-------------------+
| Alarm(True)  | 0.95             | 0.94              | 0.29             | 0.001             |
+--------------+------------------+-------------------+------------------+-------------------+
| Alarm(False) | 0.05             | 0.06              | 0.71             | 0.999             |
+--------------+------------------+-------------------+------------------+-------------------+
  6120 +-----------------+------+
| Burglary(True)  | 0.01 |
+-----------------+------+
| Burglary(False) | 0.99 |
+-----------------+----

Evaluate the model's joint log-probability with the network


In [11]:
def log_prob(sample):
    """Return the network's log-probability of `sample` after the evidence is set on it."""
    with_evidence = net.set_evidence(sample)
    return net.log_prob(with_evidence)

If possible, obtain the posterior by enumeration

In [12]:
# Ground-truth posterior: only enumerated when the number of configurations is
# below the cap (note 10e6 == 1e7); otherwise no reference is available.
if not (net.enumeration_size < 10e6):
    TARGET = None
else:
    positions, probs = net.posteriors_via_enumeration()
    TARGET = bayesian_networks.as_tensor(positions, probs)

## Recording configuration & results

In [13]:
# Rows collected by the training callback (one tuple per reported iteration).
RESULTS = []

# Column names for the tuples stored in RESULTS, in the order they are appended.
COLS = [
    "wall_time",      # seconds since start_time
    "time",           # time_forward + time_backward
    "time_forward",
    "time_backward",
    "C",
    "iteration",
    "temp",           # flow temperature
    "loss",
    "kl1",            # first KL returned by the fast flow-based evaluation
    "best_loss",
    "best_kl",
    "kl2",            # second KL returned by the fast flow-based evaluation
    "kl1_samples",    # sample-based estimates (None unless the sanity check ran)
    "kl2_samples",
]

In [14]:
def _store_results(RESULTS, COLS, CFG, CFGNAMES, OUT):
    """Write the collected result rows, tagged with the run configuration, to CSV.

    RESULTS  -- sequence of row tuples (columns are positional).
    COLS     -- column names; positional column i is renamed to COLS[i].
    CFG      -- configuration values; each becomes a constant, stringified column.
    CFGNAMES -- names of those configuration columns (paired with CFG by position).
    OUT      -- destination path handed to DataFrame.to_csv.
    """
    position_to_name = {position: label for position, label in enumerate(COLS)}
    frame = pd.DataFrame(RESULTS)
    frame = frame.rename(columns=position_to_name)

    # Stamp every row with the configuration that produced it.
    for label, setting in zip(CFGNAMES, CFG):
        frame[label] = str(setting)

    logger.info("Writing %i data rows to: %s" % (len(RESULTS), OUT))
    frame.to_csv(OUT, header=True, index=False)

In [15]:
start_time = time.time()
best_loss, best_kl = float("inf"), float("inf")

def record_status(status, iteration, loss): # callback function
    """Log training progress and append one row to RESULTS.

    Does nothing unless the loss improved on the best value seen so far or the
    iteration is on the reporting schedule (every iteration below 10, every
    10th below 50, every 50th overall).

    Args:
        status: object exposing `base`, `flow`, `time_forward` and
            `time_backward`; an optional `C` attribute is reported if present.
        iteration: current iteration number.
        loss: current loss value.

    Side effects:
        Updates the module-level `best_loss` / `best_kl` and appends a tuple
        (ordered as COLS) to `RESULTS`.
    """
    global best_loss, best_kl
    improved = loss < best_loss
    scheduled = iteration<10 or iteration%50==0 or (iteration<50 and iteration%10==0)
    if not (improved or scheduled):
        return

    # fast evaluation using flows
    kl, kl2, _ = prob_recovery.kl_divergences(status.base, status.flow, TARGET)

    # sanity check: compare against KL evaluated from samples (very slow).
    # Needs the enumerated TARGET, so it is skipped when none is available
    # (TARGET+EPS would raise on None).
    kl11, kl12 = None, None
    if TARGET is not None and iteration%1000==0 and iteration>0:
        EPS = 1e-31
        # Sample from the very objects evaluated above (status.flow/status.base)
        # so that both KL estimates describe the same model.
        flow_output_probs = prob_recovery.recover_prob_array_flow_samples(
            status.flow, status.base, nsamples=max(100000, 5*K**N))
        kl11 = np.sum(flow_output_probs * (np.log(flow_output_probs+EPS)-np.log(TARGET+EPS)))
        kl12 = np.sum(TARGET * (-np.log(flow_output_probs+EPS)+np.log(TARGET+EPS)))

    if improved:
        best_loss, best_kl = loss, kl

    # `C` is optional on the status object; report an empty string when absent.
    C = getattr(status, "C", "")
    temp = status.flow.temperature
    logger.info(("[%.0fs](%s) %s:%i. loss=%.3f " +
           "kl=%.2f~=%s (best: %.3f/%.2f) kl2=%.2f~=%s temp=%.4f\n\tmixing=%s") % \
          (time.time() - start_time, ("*" if improved else " "), C, iteration, 
           loss, kl, kl11, best_loss, best_kl, kl2, kl12,
           temp if temp is not None else float("nan"),
           str(np.round(status.base.mixing_probs,2))[:200]))

    RESULTS.append((time.time() - start_time,
                    status.time_forward + status.time_backward, 
                    status.time_forward, status.time_backward,
                    C, iteration, temp,
                    loss, kl, best_loss, best_kl, kl2, kl11, kl12))

## Approximating distribution: set bases and flows

In [16]:
# Fix both random number generators before the approximation is built.
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Problem size: K comes from the network; N counts the nodes left after
# removing the evidence nodes.
K = net.cardinality
N = net.N - len(EVIDENCE)

# A missing or non-positive B falls back to K**N components.
if not (B is not None and B > 0):
    B = K**N

logger.info("N=%i K=%i B=%i", N, K, B)

  6191 N=4 K=2 B=100


Construct bases

In [17]:
# Base mixture with B components over N variables of cardinality K.
base = base_constructors.create_categorical_blocks(N, K, B, BASE_SPECIFICATION)

# Dump every component's probabilities (only visible at DEBUG level).
for i, component in enumerate(base.components):
    logger.debug("base no%s:\n%s" % (i, np.round(component.probs, 3)))

# Inference id 0 runs with the uniform mask switched on.
if INFERENCE == 0:
    base.uniform_mask = True

Construct flows

In [18]:
# Pick one flow per mixture component:
#   * flows not trained            -> placeholder DummyFlow objects
#   * trained, FLOW_TYPE "F"/"FU"  -> DiscreteFlow with HIDDEN_LAYERS layers of that type
#   * trained, any other type      -> built from a layers specification
# NOTE(review): the DiscreteFlow branch does not pass temperature=BASE_TEMP,
# unlike the other two -- confirm this is intended.
uses_layers_specification = TRAIN_FLOWS and FLOW_TYPE not in ("F", "FU")

if uses_layers_specification:
    HIDDEN_NODES = int(N*HIDDEN_NODES_PER_VARIABLE)
    layers_specification = [(FLOW_TYPE, [HIDDEN_NODES]*HIDDEN_LAYERS)]*B
    mixture_flows = flows.parse_layers_specification(layers_specification,
                                                     N, K, temperature=BASE_TEMP)
elif TRAIN_FLOWS:
    mixture_flows = [
        flows.DiscreteFlow(N, K, layers=[(FLOW_TYPE, None)]*HIDDEN_LAYERS)
        for _ in range(B)
    ]
else:
    mixture_flows = [flows.DummyFlow(temperature=BASE_TEMP) for _ in range(B)]

flow = flows_mixture.DiscreteFlowsMixture(N, K, B, flows=mixture_flows)

In [19]:
# Best-effort debug dump of every component's joint probability array.
try:
    for i, d1 in enumerate(base.distributions):
        logger.debug("probabilities of component no %i:" % i)
        logger.debug(np.round(d1.get_joint_probability_array(), 3))
except Exception as e:
    # Deliberately broad: this dump is purely diagnostic and must never stop the
    # run. Logger.warn is a deprecated alias, so Logger.warning is used.
    logger.warning(" Failed: %s -> Skipping..." % e)

#logger.info("Trainable variables:\n %s" % "\n ".join([v.name for v in flow.trainable_variables]))

  logger.warn(" Failed: %s -> Skipping..." % e)
  6648  Failed: Outer product currently implemented only for 2D and 3D arrays! -> Skipping...


# Inference

In [20]:
# Re-seed NumPy and TensorFlow immediately before inference (presumably so the
# run does not depend on how much randomness model construction consumed).
np.random.seed(SEED)
tf.random.set_seed(SEED)

Create an optimizer:

In [21]:
# Optimizer classes selectable through the OPTIMIZER setting.
OPTIMIZERS = {
    "RMS": tf.keras.optimizers.RMSprop,
    "ADAM": tf.keras.optimizers.Adam,
}

try:
    optimizer_class = OPTIMIZERS[OPTIMIZER]
except KeyError:
    raise ValueError("Unknown optimizer!") from None

# Two instances of the same class: LR for the main optimizer, LR2 for the weights.
optimizer, optimizer_weights = (
    optimizer_class(learning_rate=rate) for rate in (LR, LR2)
)
logger.info("optimizer=%s optimizer_weights=%s", optimizer, optimizer_weights)

  6661 optimizer=<tensorflow.python.keras.optimizer_v2.rmsprop.RMSprop object at 0x7fa7705b5670> optimizer_weights=<tensorflow.python.keras.optimizer_v2.rmsprop.RMSprop object at 0x7fa7705b5850>


In [22]:
# Inference algorithm selected by the integer INFERENCE setting.
ID2INFERENCE = {
    0: inference.VariationalInference,
    1: inference.BoostingVariationalInference,
    2: inference.IterativeVariationalInference,
    3: inference.BoostingVariationalInferenceAltering,
    4: inference.BoostingVariationalInferenceAlteringIndep,
}
inference_class = ID2INFERENCE[INFERENCE]

# Temperature schedule built from the three temperature settings.
temperature_annealing = inference.TemperatureAnnealingExp(BASE_TEMP, ANNEAL_RATE, MIN_TEMP)

vi = inference_class(
    log_prob=log_prob,
    base=base,
    flow=flow,
    temperature_annealing=temperature_annealing,
    nsamples=NSAMPLES,
    max_niter=MAX_NITER,
    noimprov_niter=NOIMPROV_NITER,
    optimizer=optimizer,
    optimizer_weights=optimizer_weights,
    switch_niter=SWITCH_NITER,
)

logger.info("VI inference type: %s", vi)

  6666 VI inference type: <inference.VariationalInference object at 0x7fa740983910>


In [23]:
# Run the optimization; record_status is passed as the progress callback.
# NOTE(review): the return value is presumably the number of iterations performed -- confirm.
total_niters = vi.fit(callback=record_status)

  7748 [ERROR][<inference.VariationalInference object at 0x7fa740983910>] Failed to make a copy of base & flow objects: cannot pickle '_thread.RLock' object
  7814 [2s](*) :0. loss=7.122 kl=3.26~=None (best: 7.122/3.26) kl2=1.39~=None temp=10.0000
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
  8297 [2s](*) :1. loss=6.786 kl=2.93~=None (best: 6.786/2.93) kl2=1.23~=None temp=9.9005
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
  8844 [3s](*) :2. loss=6.602 kl=2.74~=None (best: 6.602/2.74) kl2=1.19~=None temp=9.8020
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.0

 28062 [22s]( ) :50. loss=5.036 kl=1.18~=None (best: 4.761/0.90) kl2=1.01~=None temp=6.0653
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
 46176 [40s]( ) :100. loss=5.389 kl=1.53~=None (best: 4.761/0.90) kl2=0.39~=None temp=3.6788
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
 64102 [58s]( ) :150. loss=7.291 kl=3.43~=None (best: 4.761/0.90) kl2=1.62~=None temp=2.2313
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
 82172 [76s]( ) :200. loss=7.882 kl=4.02~=None (best: 4.761/0.90) kl2=7.69~=None temp=1.3534
	m

290319 [284s](*) :770. loss=4.436 kl=0.58~=None (best: 4.436/0.58) kl2=0.30~=None temp=0.0045
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
291476 [285s](*) :773. loss=4.431 kl=0.57~=None (best: 4.431/0.57) kl2=0.30~=None temp=0.0044
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
291918 [286s](*) :774. loss=4.392 kl=0.53~=None (best: 4.392/0.53) kl2=0.30~=None temp=0.0044
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
292352 [286s](*) :775. loss=4.390 kl=0.53~=None (best: 4.390/0.53) kl2=0.30~=None temp=0.00

594452 [588s]( ) :1600. loss=4.336 kl=0.48~=None (best: 4.334/0.48) kl2=0.22~=None temp=0.0010
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
612450 [606s]( ) :1650. loss=4.336 kl=0.48~=None (best: 4.334/0.48) kl2=0.22~=None temp=0.0010
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
630557 [624s]( ) :1700. loss=4.336 kl=0.48~=None (best: 4.334/0.48) kl2=0.22~=None temp=0.0010
	mixing=[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01
 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.
648725 [643s]( ) :1750. loss=4.336 kl=0.48~=None (best: 4.334/0.48) kl2=0.22~=None temp=

In [24]:
# Write the rows collected by the callback, together with the configuration columns, to OUT.
_store_results(RESULTS, COLS, CFG, CFGNAMES, OUT)

677687 Writing 87 data rows to: BN_AVI_bnets_earthquake_p_01_100_0.csv


In [25]:
# EPS = 1e-31

# def kl_divs(flow, base, target, nsamples=100000):
#     #sample = q.sample(nsamples)
#     #output_probs = prob_recovery.recover_prob_array_tf_one_hot( sample )    
#     output_probs = prob_recovery.recover_prob_array_flow_samples(flow, base, nsamples)
#     #print(output_probs)
    
#     kl =  np.sum(output_probs * (np.log(output_probs+EPS)-np.log(target+EPS)))
#     kl2 = np.sum(target * (-np.log(output_probs+EPS)+np.log(target+EPS)))        
#     return kl, kl2

In [26]:
# print( kl_divs(vi.best_flow, vi.best_base, TARGET, 1000000) )

In [27]:
# if TARGET is not None:
#     base, flow = vi.best_base, vi.best_flow
#     kl, kl2, flow_output_probs = prob_recovery.kl_divergences(base, flow, TARGET)
#     logger.info("kl=%.3f kl2=%.3f" % (kl, kl2))

#     mask=(flow_output_probs+TARGET>1e-6)
#     print(np.round(flow_output_probs[mask].reshape(-1), 3), "\n", np.round(TARGET[mask].reshape(-1), 3))

In [28]:
## Time measurements
# print(time_profiling.get_report())