In [3]:
import numpy as np
import time
import copy
import sys
sys.path.append('/home/ziniu.wzn/BayesCard')
import pandas as pd
import time
import bz2
import pickle
import logging
import ast

#from DataPrepare.join_data_preparation import JoinDataPreparator
from Models.pgmpy_BN import Pgmpy_BN

In [4]:
# Load a previously pickled model object into `BN`.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load checkpoints that come from a trusted source.
with open('clt.pkl', 'rb') as f:
    BN = pickle.load(f)

In [41]:
import itertools
import networkx as nx
import numpy as np
import time
from tqdm import tqdm
from collections import defaultdict
from itertools import chain
import copy
from Pgmpy.inference import Inference
from Pgmpy.factors import factor_product
from Pgmpy.models import BayesianModel, JunctionTree
from Pgmpy.inference.EliminationOrder import (
    WeightedMinFill,
    MinNeighbors,
    MinFill,
    MinWeight,
)
from Pgmpy.factors.discrete import TabularCPD


class VariableEliminationJIT(object):
    """Variable elimination specialised to a tree-shaped discrete Bayesian network.

    A query is answered by restricting the network to the sub-tree that connects
    the root with the queried variables and then folding the CPDs of that
    sub-tree into each other from the leaves up to the root using plain NumPy
    operations.

    The elimination in `query` asserts that every intermediate message is 1-D,
    which presumably means each node has at most one parent (Chow-Liu style
    tree) -- confirm before using with other structures.
    """

    def __init__(self, model, cpds, topological_order, topological_order_node, probs=None, root=True):
        """Store the model and pre-aligned CPDs and index the model's factors.

        Parameters
        ----------
        model: BayesianModel
            The network; `check_model()` is called on it first.
        cpds: list
            CPD objects aligned index-by-index with `topological_order_node`.
            Each must expose `.variable`, `.variables` and `.values`.
        topological_order: list
            Stored as-is; not read by the methods of this class.
        topological_order_node: list
            Node names in topological order (parents before children).
        probs: dict, optional
            Stored as `self.probs`; an empty dict is used when omitted.
        root: bool
            When true, the root node is located and cached in `self.root`.

        Raises
        ------
        AssertionError
            If `model` is not a `BayesianModel`.
        """
        model.check_model()
        self.cpds = cpds
        self.topological_order = topological_order
        self.topological_order_node = topological_order_node
        self.model = model
        self.probs = probs if probs is not None else dict()

        self.variables = model.nodes()

        self.cardinality = {}
        self.factors = defaultdict(list)

        # Raised explicitly (instead of `assert False`) so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if not isinstance(model, BayesianModel):
            raise AssertionError("ExactCLT does not support models other than Discrete BN")

        self.state_names_map = {}
        for node in model.nodes():
            cpd = model.get_cpds(node)
            if isinstance(cpd, TabularCPD):
                self.cardinality[node] = cpd.variable_card
                cpd = cpd.to_factor()
            # Register the factor under every variable in its scope.
            for var in cpd.scope():
                self.factors[var].append(cpd)
            self.state_names_map.update(cpd.no_to_name)

        if root:
            self.root = self.get_root()

    def get_root(self):
        """Return the network's root node.

        Starting from the first node of the model, the first predecessor is
        followed repeatedly until a node without predecessors is reached.
        """
        node = list(self.model.nodes)[0]
        while True:
            parent = next(self.model.predecessors(node), None)
            # Compare against None explicitly: a node label such as 0 or ""
            # is falsy but is still a valid parent.
            if parent is None:
                return node
            node = parent

    def steiner_tree(self, nodes):
        """Return the minimal part of the tree that contains a set of nodes.

        The result is `self.model.subgraph(...)` over every requested node
        that is reachable from `self.root`, plus all nodes on the path from the
        root to it. A `cardinalities` mapping copied from the model is attached
        to the returned sub-graph.

        Parameters
        ----------
        nodes: iterable
            Node names to connect. The argument itself is not modified.
        """
        # Work on a private set so the caller's container is left untouched
        # and membership tests / removals are O(1).
        remaining = set(nodes)
        sub_nodes = set()

        def walk(node, path):
            if not remaining:
                # Every requested node has been located; stop descending.
                return

            branch = path + [node]
            if node in remaining:
                sub_nodes.update(branch)
                remaining.discard(node)

            for child in self.model.successors(node):
                walk(child, branch)

        walk(self.root, [])
        sub_graph = self.model.subgraph(sub_nodes)
        sub_graph.cardinalities = defaultdict(int)
        for node in sub_graph.nodes:
            sub_graph.cardinalities[node] = self.model.cardinalities[node]
        return sub_graph

    def _get_working_factors(self, query=None, return_probs=False, reduce=True):
        """Collect private copies of the CPDs needed to answer `query`.

        Parameters
        ----------
        query: dict
            Dict of the form {variable: states}; only its keys are used here.
            `None` or an empty dict selects no variables.
        return_probs, reduce:
            Unused; kept for signature compatibility.

        Returns
        -------
        working_factors: dict
            Maps each node of the sub-tree to a list whose first entry is a
            deep copy of the node's own CPD, followed by the (copied) CPDs of
            its children inside the sub-tree.
        sub_graph_model:
            The sub-tree returned by `steiner_tree`.
        elimination_order: list
            Nodes of the sub-tree in reverse topological order (root last).
        """
        useful_var = list(query) if query else []
        sub_graph_model = self.steiner_tree(useful_var)
        kept = set(sub_graph_model.nodes)

        elimination_order = []
        working_cpds = []
        working_factors = dict()
        # Walk the topological order backwards so children precede parents.
        for ind in range(len(self.topological_order_node) - 1, -1, -1):
            node = self.topological_order_node[ind]
            if node in kept:
                elimination_order.append(node)
                # Copy: `query` overwrites `.values` on these objects.
                cpd = copy.deepcopy(self.cpds[ind])
                working_cpds.append(cpd)
                working_factors[node] = [cpd]

        # Attach every copied CPD to the other variables in its scope (its
        # parents), so a parent can see its children's messages.
        for cpd in working_cpds:
            for var in cpd.variables:
                if var != cpd.variable and var in working_factors:
                    working_factors[var].append(cpd)

        return working_factors, sub_graph_model, elimination_order

    def query(self, query, n_distinct=None):
        """Return the probability that every queried variable falls in its states.

        Parameters
        ----------
        query: dict
            {var: list of state indices}. `None` or an empty dict means no
            constraint.
        n_distinct: dict, optional
            {var: weights}, one weight per entry of `query[var]`. Used when a
            state is a bin that groups many raw values and only part of the
            bin matches. When omitted, selected states are summed unweighted.

        Returns
        -------
        The probability as a scalar. `1` is returned when nothing is
        constrained; `0` is returned when the query is non-empty but none of
        its variables is found below the root.
        """
        if not query:
            # No predicate: everything matches.
            return 1

        working_factors, _, elimination_order = self._get_working_factors(query)
        last = len(elimination_order) - 1
        for i, var in enumerate(elimination_order):
            is_root = i == last
            factors = working_factors[var]
            own_cpd = factors[0]  # Pr(var | Parent(var))
            constrained = var in query

            if len(factors) == 1:
                # Leaf of the sub-tree: no child messages to combine.
                if constrained:
                    rows = own_cpd.values[query[var]]
                    if n_distinct:
                        new_value = np.dot(n_distinct[var], rows)
                    else:
                        new_value = np.sum(rows, axis=0)
                    if is_root:
                        return new_value
                else:
                    if is_root:
                        return 1
                    new_value = np.ones(own_cpd.values.shape[-1])
            else:
                if constrained:
                    states = query[var]
                    self_value = own_cpd.values[states]
                    if n_distinct:
                        # Scale each selected state (first axis) by its weight.
                        self_value = (self_value.transpose() * n_distinct[var]).transpose()
                    messages = [cpd.values[states] for cpd in factors[1:]]
                else:
                    self_value = own_cpd.values
                    messages = [cpd.values for cpd in factors[1:]]

                # Children come earlier in the elimination order, so each must
                # already be reduced to a 1-D message indexed by var's state.
                for message in messages:
                    assert len(message.shape) == 1, "unreduced children"
                if len(messages) == 1:
                    children_value = messages[0]
                else:
                    children_value = np.prod(np.stack(messages), axis=0)

                if is_root:
                    return np.dot(self_value, children_value)
                new_value = np.dot(np.transpose(self_value), children_value)

            assert len(new_value.shape) == 1, f"unreduced variable {var}"
            # The same object sits in the parent's factor list, so this
            # publishes the message to the parent.
            own_cpd.values = new_value
        return 0

In [42]:
def align_cpds_in_topological(BN):
    """Order the network's CPDs so that every node comes after its dependencies.

    Parameters
    ----------
    BN:
        Object exposing `structure` (for node index i, an iterable of the
        indices it depends on), `node_names` (index -> name) and `model.cpds`
        (CPD objects with a `.variable` attribute).

    Returns
    -------
    new_cpds: list
        CPDs re-ordered to match `topological_order_node`.
    topological_order: list
        Node indices, dependencies first.
    topological_order_node: list
        The node names corresponding to `topological_order`.

    Raises
    ------
    ValueError
        If the dependency structure cannot be ordered (for instance because it
        contains a cycle); previously this case looped forever.
    AssertionError
        If some CPD could not be matched to a node name.
    """
    structure = BN.structure
    sampling_order = []
    placed = set()  # mirrors sampling_order for O(1) membership tests
    while len(sampling_order) < len(structure):
        progressed = False
        for i, deps in enumerate(structure):
            if i in placed:
                continue  # already ordered
            if all(d in placed for d in deps):
                sampling_order.append(i)
                placed.add(i)
                progressed = True
        if not progressed:
            # A full pass placed nothing, so the remaining nodes can never be
            # placed; bail out instead of spinning.
            raise ValueError(
                "BN.structure has no topological order (cyclic or unknown dependency)"
            )
    topological_order = sampling_order
    topological_order_node = [BN.node_names[i] for i in sampling_order]

    cpds = BN.model.cpds
    # Keep the first CPD seen for each variable, as the linear scan did.
    first_cpd_by_variable = {}
    for cpd in cpds:
        first_cpd_by_variable.setdefault(cpd.variable, cpd)
    new_cpds = [
        first_cpd_by_variable[name]
        for name in topological_order_node
        if name in first_cpd_by_variable
    ]
    assert len(cpds) == len(new_cpds)
    return new_cpds, topological_order, topological_order_node

In [34]:
# Replace `BN` with the model pickled at this checkpoint path, then prepare it
# for inference.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load checkpoints that come from a trusted source.
with open('/home/ziniu.wzn/BN_checkpoints/check_points/Census_chow-liu.pkl', 'rb') as f:
    BN = pickle.load(f)
BN.init_inference_method()

In [43]:

# Align the CPDs with a topological ordering of the nodes, then build the
# elimination engine on top of the loaded network.
cpds, topological_order, topological_order_node = align_cpds_in_topological(BN)
ve = VariableEliminationJIT(BN.model, cpds, topological_order, topological_order_node)

In [36]:
# Display the node names in the computed topological order.
topological_order_node

['dAge',
 'dIncome5',
 'dIncome7',
 'iRrelchld',
 'iRspouse',
 'iSchool',
 'iYearsch',
 'iYearwrk',
 'iDisabl2',
 'iLang1',
 'iMarital',
 'iMobility',
 'dOccup',
 'iRelat1',
 'iRelat2',
 'iRemplpar',
 'iRlabor',
 'iRownchld',
 'iSubfam1',
 'iSubfam2',
 'iTmpabsnt',
 'iWork89',
 'iWorklwk',
 'dAncstry1',
 'dAncstry2',
 'iDisabl1',
 'iEnglish',
 'iFertil',
 'dHispanic',
 'dIndustry',
 'iLooking',
 'iMeans',
 'iMilitary',
 'iMobillim',
 'iPerscare',
 'dPOB',
 'dPoverty',
 'iRagechld',
 'iRiders',
 'iRPOB',
 'iRvetserv',
 'iSept80',
 'iSex',
 'dTravtime',
 'iVietnam',
 'dWeek89',
 'iWWII',
 'dYrsserv',
 'iAvail',
 'iCitizen',
 'iClass',
 'dDepart',
 'iFeb55',
 'dHours',
 'iImmigr',
 'dIncome2',
 'dIncome3',
 'iKorean',
 'iMay75880',
 'iOthrserv',
 'dPwgt1',
 'dRearning',
 'dRpincome',
 'dHour89',
 'dIncome1',
 'dIncome4',
 'dIncome6',
 'dIncome8']

In [21]:
# Prepare the model for inference and read the benchmark queries, one per line.
BN.init_inference_method()
#with open("../Benchmark/DMV/query.sql") as f:
 #   queries = f.readlines()
with open("/home/ziniu.wzn/Census/cardinality/query_one_side.sql") as f:
    queries = f.readlines()
# Parser that turns one SQL text into the dict form passed to BN.query below.
from Evaluation.cardinality_estimation import parse_query_single_table

In [44]:
# Sanity check on one query: compare the model's own estimate (BN.query) with
# the estimate produced by VariableEliminationJIT, and time the latter.
import copy
# Each line is split on "||"; the SQL text is the first field.
query_str = queries[147].split("||")[0]
query_str = parse_query_single_table(query_str.strip(), BN)
print(query_str)
# A deep copy is passed, presumably because BN.query may modify its argument
# -- TODO confirm.
print(BN.query(copy.deepcopy(query_str)))
# query_decoding returns the query in the form ve.query consumes together with
# the matching n_distinct argument.
query_str, n_distinct = BN.query_decoding(query_str, None)
print(query_str)
tic = time.time()
a = ve.query(query_str, n_distinct)
print(time.time()-tic)
# Scale the probability by the table size to get an estimated row count.
print(a*BN.nrows)

{'iRelat2': [0], 'dTravtime': [1, 2, 0, 3], 'iYearsch': [11, 5, 10, 4, 8, 1, 7, 6, 2, 9, 3]}
1490293
{'iRelat2': [0], 'dTravtime': [3, 5, 0, 4], 'iYearsch': [1, 2, 0, 4, 7, 8, 6, 9, 15, 11, 14]}
0.002078533172607422
1490292.8364073704


In [30]:
# Scratch check of the transpose/multiply/transpose idiom used in
# VariableEliminationJIT.query: it scales an array along its FIRST axis by a
# 1-D weight vector (slice k of the result is multiplied by weight k).
(np.ones((3,5,6)).transpose() * np.asarray([1,2,3])).transpose()

array([[[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]],

       [[2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.]],

       [[3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3.]]])

In [31]:
# Scratch check: np.dot of a 1-D array with a 3-D array contracts against the
# second-to-last axis of the 3-D array (length 5 here), not the first, so this
# raises ValueError. It cannot replace the transpose idiom for N-D inputs.
np.dot(np.asarray([1,2,3]), np.ones((3,5,6))).shape

ValueError: shapes (3,) and (3,5,6) not aligned: 3 (dim 0) != 5 (dim 1)

In [19]:
# Show the raw SQL text of query 147 and what the parser makes of it.
query_str = queries[147].split("||")[0]
print(query_str)
# NOTE(review): the recorded result for this cell is an empty dict, i.e. no
# predicate of the printed query was mapped onto the loaded model.
parse_query_single_table(query_str.strip(), BN)

SELECT COUNT(*) FROM DMV WHERE Registration_Class IN [PAS, COM, LTR, BOT, MOT, TRL, SRF, PSD, OMT, OMS, ATV, SNO] AND Body_Type IN [SUBN, 4DSD, PICK, 2DSD, BOAT, LTRL, MCY, TRLR, VAN, CONV, ATV, SNOW, DUMP, BUS, H/TR, UTIL, SEMI, TRAC, H/WH, DELV, FLAT, P/SH, STAK, TANK, MOPD, TAXI, TOW, SEDN, REFG] AND Scofflaw_Indicator == N


{}

In [None]:
# Run every benchmark query through VariableEliminationJIT, recording the
# estimated row count, the true cardinality and the inference latency.
#with open("../Benchmark/DMV/query.sql") as f:
 #   queries = f.readlines()
with open("/home/ziniu.wzn/Census/cardinality/query_one_side.sql") as f:
    queries = f.readlines()

res = []    # estimated row counts
cards = []  # true cardinalities (last "||"-separated field of each line)
lat = []    # seconds spent inside ve.query, one entry per query
for q in queries:
    if not q.strip():
        continue  # tolerate blank lines, e.g. at the end of the file
    fields = q.split("||")
    query_str = fields[0]
    card = int(fields[-1])
    cards.append(card)
    query_str = parse_query_single_table(query_str.strip(), BN)
    #print(BN.query(query_str))
    query_str, n_distinct = BN.query_decoding(query_str, None)
    tic = time.perf_counter()
    # Pass n_distinct, as the single-query check does; without it the decoded
    # per-state weights are ignored.
    a = ve.query(query_str, n_distinct)
    # Measure once, before printing, so the stored latency excludes print().
    elapsed = time.perf_counter() - tic
    print(elapsed)
    lat.append(elapsed)
    res.append(a*BN.nrows)

In [None]:
# Report the mean latency (in ms, truncated to two decimals) and the q-error
# quantiles of the estimates collected by the evaluation loop.
print(f"average latency: {int(np.mean(lat)*100000)/100}ms")
pred = np.asarray(res, dtype=float)
# Clamp estimates to at least one row so the ratios below stay finite.
pred[pred<=1.0] = 1
cards = np.asarray(cards)
# q-error = max(estimate/truth, truth/estimate). The second ratio must be
# inverted, otherwise under-estimates are never penalised.
# NOTE(review): a true cardinality of 0 still yields an infinite error.
errors = np.maximum(np.divide(pred, cards), np.divide(cards, pred))
for i in [50,90,95,99,100]:
    print(f"{i}% quantile: {np.percentile(errors, i)}")

In [None]:
# Number of estimates that went into the metrics above.
len(pred)