In [154]:
import random, json
from collections import defaultdict
from html import escape
from itertools import product as iterprod

import ipywidgets as widgets
import networkx as nx
from ipywidgets import interact, interactive, interact_manual, Layout, Accordion
#from pgmpy.factors.continuous import ContinuousFactor
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import BeliefPropagation
from pgmpy.inference import VariableElimination
from pgmpy.models import BayesianNetwork
from pyvis.network import Network

In [433]:
class ProbvalError(Exception):
    """Raised when a probability outside the closed interval [0, 1] is supplied."""


class NodenameError(Exception):
    """Raised when a node name is rejected (e.g. it contains ';', '@' or ':')."""
    
class Model:
    """Container for the nodes of one network.

    Node names are used as lookup keys by ``get_node_byname`` and by the widget
    callbacks, so ``add_node`` keeps them unique.
    """

    def __init__(self):
        self.nodes = []  # Node objects; other cells sort this list by name in place

    def name_exists(self, name):
        """Return True if a node in this model is already called ``name``."""
        return any(n.name == name for n in self.nodes)

    def get_id(self):
        """Return a fresh node id: one more than the largest id in use, or 0 if empty."""
        return max((n.id for n in self.nodes), default=-1) + 1

    def add_node(self, name):
        """Create a ``Node`` called ``name``, append it to the model and return it.

        Raises:
            NodenameError: if ``Node`` rejects the name, or if the normalised
                name (``Node`` lower-cases it and drops double quotes) is
                already used by another node in this model.
        """
        new_node = Node(self, name)
        # Compare the normalised name, not the raw argument.
        if self.name_exists(new_node.name):
            raise NodenameError('duplicate node name: {!r}'.format(new_node.name))
        self.nodes.append(new_node)
        return new_node

    def get_node(self, node_id):
        """Return the node whose id equals ``node_id``, or None if there is none."""
        return next((n for n in self.nodes if n.id == node_id), None)

    def get_node_byname(self, node_name):
        """Return the first node whose name equals ``node_name``, or None."""
        return next((n for n in self.nodes if n.name == node_name), None)

    def dump(self, fname):
        """Serialise the model to JSON, write it to ``fname`` and return the JSON text.

        Edges are stored as lists of node ids and probabilities as
        ``probstring()`` strings; ``load_model`` reads this format back.
        """
        outmodel = {'nodenames': {}, 'graph': []}
        for n in self.nodes:
            outmodel['nodenames'][n.id] = n.name
            outmodel['graph'].append({
                'id': n.id,
                'edges_out': [i.id for i in n.edges_out],
                'edges_in': [i.id for i in n.edges_in],
                'unit_long': n.unit_long,
                'unit_short': n.unit_short,
                'values': n.values,
                'probs': [p.probstring() for p in n.probs],
            })
        out_j = json.dumps(outmodel)
        # `with` guarantees the file is closed even if the write fails.
        with open(fname, 'w') as ofile:
            ofile.write(out_j)
        return out_j
            
class Prob:
    """One conditional-probability entry: P(output | inputs) = prob.

    ``inputs`` is a sequence of (node_id, value) pairs for the parent states;
    ``output`` is a value of the node that owns this entry.
    """

    def __init__(self, input_pairs, output, prob):
        self.inputs, self.output, self.prob = input_pairs, output, prob

    def probstring(self):
        """Return the canonical ``output:id;val;;id;val@prob`` serialisation.

        The ``id;val`` parts are sorted as strings, so the result does not
        depend on the order in which the inputs were given.
        """
        pair_parts = sorted('%s;%s' % (str(pair[0]), str(pair[1])) for pair in self.inputs)
        return '%s:%s@%s' % (str(self.output), ';;'.join(pair_parts), str(self.prob))
    
class Node:
    """A discrete variable of a ``Model``: name, edges, states and probabilities."""

    def __init__(self, model, name):
        # ';', '@' and ':' are separators in the Prob.probstring() format.
        if ';' in name or '@' in name or ':' in name:
            raise NodenameError
        self.model = model
        self.id = model.get_id()
        self.name = name.lower().replace('"', '')
        self.edges_out = []   # nodes this node has an edge to
        self.edges_in = []    # nodes with an edge into this node
        self.unit_long = None
        self.unit_short = None
        self.values = []      # ordered, de-duplicated list of the node's states
        self.probs = []       # Prob entries owned by this node

    def set_unit(self, unit_long=None, unit_short=None):
        """Set the long and short unit labels."""
        self.unit_long = unit_long
        self.unit_short = unit_short

    def set_values(self, value_list):
        """Append each value of ``value_list`` that is not already present, keeping order."""
        for value in value_list:
            if value not in self.values:
                self.values.append(value)

    @staticmethod
    def _prob_key(input_pairs, output):
        """Return the ``output:id;val;;...`` prefix of a probstring (the part before '@')."""
        pair_strings = sorted('{0};{1}'.format(str(p[0]), str(p[1])) for p in input_pairs)
        return '{0}:{1}'.format(str(output), ';;'.join(pair_strings))

    def get_prob(self, input_pairs, output):
        """Return the stored Prob for ``output`` given ``input_pairs`` (any order), or None."""
        teststring = self._prob_key(input_pairs, output)
        for p in self.probs:
            # probstring() ends with '@<prob>'; only the part before it identifies the entry.
            if p.probstring().rsplit('@', 1)[0] == teststring:
                return p
        return None

    def set_prob(self, input_pairs, output, prob):
        """Set P(output | input_pairs) = prob, updating an existing entry or adding a new one.

        Raises:
            ProbvalError: if ``prob`` is outside [0, 1].
        """
        if prob < 0 or prob > 1:
            raise ProbvalError
        p = self.get_prob(input_pairs, output)
        if p is not None:
            p.prob = prob
        else:
            self.probs.append(Prob(input_pairs, output, prob))

In [441]:
def load_model(fname):
    """Rebuild a Model from a JSON file written by ``Model.dump``.

    Ids, names (re-normalised by ``Node``), units, values and probabilities are
    restored. Both edge lists are restored exactly as saved. ``run_query``
    builds each CPD's evidence list from ``edges_in``, so that order must
    survive the round trip. Edges are wired here directly, so this cell does
    not depend on ``connect_nodes`` from a later cell.
    """
    with open(fname, 'r') as ifile:
        inmodel = json.load(ifile)
    new_model = Model()
    new_nodes = {}
    for n in inmodel['graph']:
        # JSON object keys are strings, hence str(id) for the name lookup.
        new_node = Node(new_model, inmodel['nodenames'][str(n['id'])])
        new_node.id = n['id']
        new_node.unit_long = n['unit_long']
        new_node.unit_short = n['unit_short']
        new_node.values = n['values'][:]
        new_node.probs = [_parse_probstring(s) for s in n['probs']]
        new_nodes[n['id']] = new_node
        new_model.nodes.append(new_node)
    for n in inmodel['graph']:
        node = new_nodes[n['id']]
        node.edges_out = [new_nodes[i] for i in n['edges_out']]
        node.edges_in = [new_nodes[i] for i in n['edges_in']]
    return new_model


def _parse_probstring(text):
    """Parse one ``output:id;val;;id;val@prob`` string (see ``Prob.probstring``) into a Prob.

    Node ids are converted back to int. An empty pair list (a node with no
    parents) yields no inputs instead of a bogus empty pair.
    """
    output, rest = text.split(':', 1)
    pairs_field, prob_field = rest.rsplit('@', 1)
    inputs = []
    for pair in pairs_field.split(';;'):
        if not pair:
            continue
        node_id, value = pair.split(';', 1)
        inputs.append((int(node_id), value))
    return Prob(inputs, output, float(prob_field))

In [442]:
# Load a previously saved model from JSON.
# NOTE(review): the later `M = Model()` cell rebinds M and discards this model
# when the notebook is run top to bottom.
M=load_model('testdump.json')

{'nodenames': {'3': 'fuel mix', '4': 'fuel sales', '0': 'ghg', '2': 'temperature', '5': 'vehicle emissions', '1': 'vmt'}, 'graph': [{'id': 3, 'edges_out': [5], 'edges_in': [], 'unit_long': None, 'unit_short': None, 'values': ['high emission', 'low emission'], 'probs': ['high emission:@0.5', 'low emission:@0.5']}, {'id': 4, 'edges_out': [], 'edges_in': [1], 'unit_long': None, 'unit_short': None, 'values': ['high sales', 'moderate sales', 'low sales'], 'probs': ['high sales:1;very high@0.8', 'high sales:1;high@0.7', 'high sales:1;moderate@0.3', 'high sales:1;low@0.1', 'moderate sales:1;very high@0.15', 'moderate sales:1;high@0.2', 'moderate sales:1;moderate@0.5', 'moderate sales:1;low@0.2', 'low sales:1;very high@0.05', 'low sales:1;high@0.1', 'low sales:1;moderate@0.2', 'low sales:1;low@0.7']}, {'id': 0, 'edges_out': [2], 'edges_in': [5], 'unit_long': None, 'unit_short': None, 'values': ['high', 'moderate', 'low'], 'probs': ['high:5;high@0.6', 'high:5;moderate@0.3', 'high:5;low@0.1', 'm

In [4]:
def connect_nodes(from_node, to_node):
    """Add the directed edge from_node -> to_node, updating both adjacency lists.

    Connecting an already-connected pair is a no-op (no duplicate entries).
    """
    for adjacency, neighbour in ((from_node.edges_out, to_node), (to_node.edges_in, from_node)):
        if neighbour not in adjacency:
            adjacency.append(neighbour)
def disconnect_nodes(from_node, to_node):
    """Remove the edge from_node -> to_node and every probability of ``to_node``
    that is conditioned on ``from_node``.

    Missing edges are ignored. Input-pair ids are compared as strings because,
    depending on how the probabilities were created, they may be ints or strings.
    """
    if to_node in from_node.edges_out:
        from_node.edges_out.remove(to_node)
    if from_node in to_node.edges_in:
        to_node.edges_in.remove(from_node)
    parent_id = str(from_node.id)
    # p.inputs holds (id, value) pairs, not strings. Build the filtered list in one
    # pass (removing while iterating skips entries) and assign in place to keep
    # the list object.
    to_node.probs[:] = [
        p for p in to_node.probs
        if all(str(pair[0]) != parent_id for pair in p.inputs)
    ]

In [181]:
# Tiny hand-built demo model with a single edge vmt -> ghg.
# NOTE(review): this rebinds M, discarding any model loaded earlier.
M = Model()
g, d = (M.add_node(node_name) for node_name in ('ghg', 'vmt'))
connect_nodes(d, g)

In [182]:
# EDIT NODES

def add_nodes_from_box(e):
    """Button callback: add one node per comma-separated name typed into the textarea.

    Names are stripped, lower-cased and have double quotes removed. Blank
    entries (an empty box, a trailing comma) and names already present in ``M``
    are skipped, so the callback creates neither empty nor duplicate nodes.
    Names containing ';', '@' or ':' still make ``Node`` raise NodenameError.
    The textarea is then cleared and the node views refreshed.
    """
    namelist = [i.strip().lower().replace('"', '') for i in node_textinput.value.split(',')]
    for name in namelist:
        if name and not M.name_exists(name):
            M.add_node(name)
    node_textinput.value = ''
    refresh_nodelist()
    
# Node-editor widgets: the rendered node list, a textarea for new names and an
# "Add nodes" button that hands the textarea to add_nodes_from_box.
current_nodelist = widgets.HTML(value='')
node_textinput = widgets.Textarea(
    value='',
    placeholder='New nodes here, separated by commas',
    description='',
    disabled=False,
)
addnodes_button = widgets.Button(
    description='Add nodes',
    tooltip='Add nodes',
    disabled=False,
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
    icon='',          # FontAwesome name without the `fa-` prefix
)
addnodes_button.on_click(add_nodes_from_box)
editnode_box = widgets.HBox([current_nodelist, node_textinput, addnodes_button])
def refresh_nodelist():
    """Re-render the "Current nodes" panel and refresh the dependent editors.

    Sorts ``M.nodes`` by name in place. Names are HTML-escaped because they
    come from free-text user input.
    """
    M.nodes.sort(key=lambda x: x.name)
    items = '</li><li>'.join(escape(n.name) for n in M.nodes)
    markup = ('<div style="padding-right:20px"><b>Current nodes</b>'
              '<ul style="list-style-type: none"><li>{}</li></ul></div>').format(items)
    current_nodelist.set_trait('value', markup)
    # setup_connectormenu and update_matrix are defined in later cells. Skip
    # them until they exist so the first call (during Restart & Run All) does
    # not raise NameError. Those cells call them themselves once defined.
    for refresher_name in ('setup_connectormenu', 'update_matrix'):
        refresher = globals().get(refresher_name)
        if refresher is not None:
            refresher()

def setup_editnodes():
    """Initialise the node-editor panel; currently this is just a node-list refresh."""
    return refresh_nodelist()


setup_editnodes()
editnode_box

HBox(children=(HTML(value='<div style="padding-right:20px"><b>Current nodes</b><ul style="list-style-type: non…

In [184]:
# EDIT EDGES

in_checkboxes = {}    # node id -> Checkbox for an edge node -> focal node
out_checkboxes = {}   # node id -> Checkbox for an edge focal node -> node
# Checkbox observers ignore events while this is False (it is cleared while the
# checkboxes are updated programmatically).
listen_connections = False

nodelist = M.nodes
nodelist.sort(key=lambda x: x.name)
# Start on the alphabetically first node so that re-running the notebook
# reproduces the same editor state (a random pick was non-reproducible).
connection_focalnode = nodelist[0]

focus_selector = None

connector_menu = widgets.HBox([widgets.HTML(), widgets.HTML(), widgets.HTML()])

def _toggle_focal_edge(e, checkboxes, focal_is_source):
    """Apply a checkbox change as adding/removing the edge between its node and the focal node.

    The checkbox's description is the other node's name. ``focal_is_source``
    selects the edge direction. Events are ignored while ``listen_connections``
    is False.
    """
    if not listen_connections:
        return
    other_node = M.get_node_byname(e['owner'].description)
    if focal_is_source:
        endpoints = (connection_focalnode, other_node)
    else:
        endpoints = (other_node, connection_focalnode)
    if checkboxes[other_node.id].value:
        connect_nodes(*endpoints)
    else:
        disconnect_nodes(*endpoints)
    update_matrix()


def update_connections_in(e):
    """Observer for the left-hand checkboxes: toggle the edge checkbox-node -> focal node."""
    _toggle_focal_edge(e, in_checkboxes, focal_is_source=False)


def update_connections_out(e):
    """Observer for the right-hand checkboxes: toggle the edge focal node -> checkbox-node."""
    _toggle_focal_edge(e, out_checkboxes, focal_is_source=True)

def update_connectionmenus(e):
    """Dropdown observer: switch the focal node and sync the checkboxes to its edges.

    ``e['new']`` is a '<id>: <name>' label. Checkbox observers are muted while
    values are set programmatically and re-enabled in ``finally``, so an error
    cannot leave the edge editor permanently unresponsive. Nodes added after
    the checkboxes were built have no checkbox yet and are skipped.
    """
    global connection_focalnode
    global listen_connections
    listen_connections = False
    try:
        connection_focalnode = M.get_node(int(e['new'].split(':')[0]))
        for n in M.nodes:
            if n.id not in in_checkboxes or n.id not in out_checkboxes:
                continue
            is_focal = n is connection_focalnode
            in_checkboxes[n.id].value = n in connection_focalnode.edges_in
            in_checkboxes[n.id].disabled = is_focal
            out_checkboxes[n.id].value = n in connection_focalnode.edges_out
            out_checkboxes[n.id].disabled = is_focal
    finally:
        listen_connections = True
    
def setup_connectormenu():
    """(Re)build the edge-editor widgets from the current nodes of ``M``.

    Layout: checkboxes for edges *into* the focal node, the focal-node
    dropdown (labels '<id>: <name>'), and checkboxes for edges *out of* it.
    The children of the existing ``connector_menu`` HBox are replaced in
    place rather than rebinding the name, so an editor that is already
    displayed shows nodes added later.
    """
    global focus_selector
    global connector_menu
    global listen_connections
    global in_checkboxes
    global out_checkboxes
    global connection_focalnode

    # Read M.nodes directly: the module-level `nodelist` alias goes stale if M
    # is rebound (e.g. to the result of load_model).
    options = M.nodes
    options.sort(key=lambda x: x.name)
    if connection_focalnode not in options:
        connection_focalnode = options[0]

    focus_selector = widgets.Dropdown(
            options=['{0}: {1}'.format(str(i.id), i.name) for i in options],
            value='{0}: {1}'.format(str(connection_focalnode.id), connection_focalnode.name),
            description='Focal node',
            disabled=False,
        )
    focus_selector.observe(update_connectionmenus, names='value')

    in_checkboxes = {}
    out_checkboxes = {}
    for n in options:
        is_focal = n is connection_focalnode
        in_checkboxes[n.id] = widgets.Checkbox(
            value=n in connection_focalnode.edges_in,
            disabled=is_focal,
            description=n.name,
            indent=False
        )
        in_checkboxes[n.id].observe(update_connections_in, names='value')

        out_checkboxes[n.id] = widgets.Checkbox(
            value=n in connection_focalnode.edges_out,
            disabled=is_focal,
            description=n.name,
            indent=False
        )
        out_checkboxes[n.id].observe(update_connections_out, names='value')
    listen_connections = True

    children = (
        widgets.VBox(list(in_checkboxes.values())),
        focus_selector,
        widgets.VBox(list(out_checkboxes.values())),
    )
    if isinstance(connector_menu, widgets.HBox):
        connector_menu.children = children
    else:
        connector_menu = widgets.HBox(children)

In [186]:
# Build the edge editor for the current nodes of M and display it.
setup_connectormenu()
connector_menu

HBox(children=(VBox(children=(Checkbox(value=True, description='fuel mix', indent=False), Checkbox(value=False…

In [205]:
# CHECK FOR CYCLES

def find_cycles():
    """Print every directed cycle in M's graph, or 'No cycles found'.

    Each cycle is printed as 'a -> b -> ... -> a'. The model is later handed
    to pgmpy's BayesianNetwork, which needs an acyclic graph.
    """
    current_graph = nx.DiGraph()
    current_graph.add_edges_from(
        (src.name, dst.name) for src in M.nodes for dst in src.edges_out
    )
    cycles = list(nx.simple_cycles(current_graph))
    if not cycles:
        print('No cycles found')
        return
    for cycle in cycles:
        print('{0} -> {1}'.format(' -> '.join(cycle), cycle[0]))


find_cycles()

No cycles found


In [196]:
# Raw list of directed cycles in M's graph (the same graph find_cycles builds);
# [] means the graph is acyclic.
list(nx.simple_cycles(nx.DiGraph([(n.name, m.name) for n in M.nodes for m in n.edges_out])))

[]

In [183]:
# CONNECTIVITY MATRIX

matrix = widgets.HTML(value='')


def update_matrix():
    """Re-render ``matrix`` as an adjacency table of M's nodes.

    Rows are edge sources and columns are edge targets: the cell in row ``r``
    and column ``c`` holds 'X' when there is an edge r -> c. ``M.nodes`` is
    sorted by name in place. Names are HTML-escaped because they come from
    user input.
    """
    cell_style = 'padding:10px; border:1px solid gray; border-collapse:collapse'
    nodelist = M.nodes
    nodelist.sort(key=lambda x: x.name)
    header = ''.join(
        '<th style="{0}">{1}</th>'.format(cell_style, escape(col.name)) for col in nodelist
    )
    rows = []
    for row in nodelist:
        cells = ''.join(
            '<td style="{0}">{1}</td>'.format(cell_style, 'X' if row in col.edges_in else '')
            for col in nodelist
        )
        rows.append('<tr><th style="{0}">{1}</th>{2}</tr>'.format(cell_style, escape(row.name), cells))
    markup = ('<table style="border:1px solid gray; border-collapse:collapse">'
              '<tr><th style="{0}"></th>{1}</tr>{2}</table>').format(cell_style, header, ''.join(rows))
    matrix.set_trait('value', markup)


update_matrix()
matrix

HTML(value='<table style="border:1px solid gray; border-collapse:collapse"><tr><th style="padding:10px; border…

In [447]:
# GRAPH VISUALIZATION

# Draw M as an interactive pyvis graph (written to currentgraph.html).
edge_list = [(n.id, i.id) for n in M.nodes for i in n.edges_out]

g = Network(height=800, width=800, directed=True, notebook=True)
g.toggle_hide_edges_on_drag(False)
for n in M.nodes:
    g.add_node(n.id, label=n.name, color='white', shape='box')
for e in edge_list:
    g.add_edge(*e, arrowStrikethrough=False, color='black')
g.repulsion(node_distance=100, spring_length=200)
g.show('currentgraph.html')


In [206]:
# EDIT VALUES

nodelist = M.nodes
nodelist.sort(key=lambda x: x.name)
# Start on the alphabetically first node so re-running gives the same view
# (a random pick made the cell non-reproducible).
values_focalnode = nodelist[0]

def update_valueseditor(e):
    """Dropdown observer: load the selected node's values into the editor.

    ``e['new']`` is a '<id>: <name>' label. Setting the editor text fires
    ``update_buttonstate``, so the save button is disabled again afterwards:
    the editor now matches the stored values.
    """
    global values_focalnode
    values_focalnode = M.get_node(int(e['new'].split(':')[0]))
    values_editor.value = '\n'.join(values_focalnode.values)
    savevalues_button.disabled = True


def update_values(e):
    """Save-button callback: store the editor's lines as the focal node's values.

    Each line is stripped, lower-cased and has double quotes removed. Blank
    lines (e.g. a trailing newline) and repeats are dropped so they do not
    become extra states of the node.
    """
    cleaned = (i.strip().lower().replace('"', '') for i in values_editor.value.split('\n'))
    new_values = []
    for value in cleaned:
        if value and value not in new_values:
            new_values.append(value)
    values_focalnode.values = new_values
    savevalues_button.disabled = True


def update_buttonstate(e):
    """Editor observer: enable the save button once the text has changed."""
    savevalues_button.disabled = False

# Value-editor widgets: focal-node dropdown, one-value-per-line textarea and a
# save button. Callbacks are wired once every widget exists.
focus_selector_values = widgets.Dropdown(
    options=['{0}: {1}'.format(str(i.id), i.name) for i in nodelist],
    value='{0}: {1}'.format(str(values_focalnode.id), values_focalnode.name),
    description='Focal node',
    disabled=False,
)
values_editor = widgets.Textarea(
    value='\n'.join(values_focalnode.values),
    placeholder='Enter linebreak-separated values',
    description='',
    disabled=False,
)
savevalues_button = widgets.Button(
    description='Save values',
    tooltip='Save values',
    disabled=False,
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
    icon='',          # FontAwesome name without the `fa-` prefix
)

focus_selector_values.observe(update_valueseditor, names='value')
values_editor.observe(update_buttonstate, names='value')
savevalues_button.on_click(update_values)

values_box = widgets.HBox([focus_selector_values, values_editor, savevalues_button])
values_box

HBox(children=(Dropdown(description='Focal node', index=1, options=('3: fuel mix', '4: fuel sales', '0: ghg', …

In [235]:
# EDIT PROBS


nodelist = M.nodes
nodelist.sort(key=lambda x: x.name)
# Start on the alphabetically first node so re-running gives the same view
# (a random pick made the cell non-reproducible).
probs_focalnode = nodelist[0]
values_combos_display = widgets.HTML(value='')

def update_probseditor(e = None):
    """Show every (focal value, parent values...) combination with its stored probability.

    Called with a dropdown change ``e`` ('<id>: <name>' label) to switch the
    focal node, or with no argument to redraw the current one. Rows follow
    ``itertools.product`` over the focal node's values and then each parent's
    values in ``edges_in`` order. The Prob column shows the matching entry of
    ``probs_focalnode.probs``, or 'missing'. Writing template files is left
    to ``download_probtemplate``.
    """
    global probs_focalnode
    if e:
        probs_focalnode = M.get_node(int(e['new'].split(':')[0]))

    input_nodes = probs_focalnode.edges_in
    all_values = [probs_focalnode.values] + [i.values for i in input_nodes]

    # Stored probabilities keyed by (output, sorted (str(id), value) pairs).
    # Ids are compared as strings since they may be stored as int or str.
    # Pairs that do not have exactly two fields (e.g. an empty placeholder) are ignored.
    stored = {}
    for p in probs_focalnode.probs:
        pairs = tuple(sorted((str(pair[0]), str(pair[1])) for pair in p.inputs if len(pair) == 2))
        stored[(str(p.output), pairs)] = p.prob

    th = '<th style="padding-left:5px;padding-right:5px">{}</th>'
    td = '<td style="padding-left:5px;padding-right:5px">{}</td>'
    headers = [probs_focalnode.name] + [i.name for i in input_nodes] + ['Prob']
    rows = [''.join(th.format(escape(h)) for h in headers)]
    for combo in iterprod(*all_values):
        pairs = tuple(sorted((str(node.id), str(value)) for node, value in zip(input_nodes, combo[1:])))
        prob = stored.get((str(combo[0]), pairs))
        cells = [td.format(escape(str(v))) for v in combo]
        cells.append(td.format('missing' if prob is None else prob))
        rows.append(''.join(cells))
    table = '<table>{}</table>'.format(''.join('<tr>{}</tr>'.format(r) for r in rows))
    values_combos_display.set_trait('value', table)
    
def download_probtemplate(e):
    """Button callback: write 'prob template <node>.txt' for the focal node.

    Tab-separated layout: a header of the focal node's name, its parents'
    names (``edges_in`` order) and 'Prob', followed by one row per
    combination of values. Rows hold only the value columns; the user
    appends each probability. ``upload_probtemplate`` reads the completed
    copy ('... completed.txt') back.
    """
    input_nodes = probs_focalnode.edges_in
    all_values = [probs_focalnode.values] + [i.values for i in input_nodes]
    header = [probs_focalnode.name] + [i.name for i in input_nodes] + ['Prob']
    # `with` closes the file even if a write fails.
    with open('prob template {}.txt'.format(probs_focalnode.name), 'w') as prob_template:
        prob_template.write('\t'.join(header))
        prob_template.write('\n')
        for combo in iterprod(*all_values):
            prob_template.write('\t'.join(combo))
            prob_template.write('\n')
    
def upload_probtemplate(e):
    """Button callback: replace the focal node's probabilities from
    'prob template <node> completed.txt'.

    Expects the tab-separated layout written by ``download_probtemplate``,
    with a probability as the last column of every data row. The header's
    last column is ignored and blank lines are skipped. The whole file is
    parsed before ``probs_focalnode.probs`` is replaced, so a malformed file
    leaves the existing probabilities untouched.

    Raises:
        ProbvalError: if a probability is outside [0, 1].
        ValueError: if a probability is not a number, or a parent name in the
            header is not a node of ``M``.
    """
    with open('prob template {} completed.txt'.format(probs_focalnode.name), 'r') as prob_template:
        ilines = prob_template.readlines()
    header = [j.strip() for j in ilines[0].split('\t')[:-1]]
    inodes = [M.get_node_byname(name) for name in header]
    # Column 0 is the focal node; only the parent columns need to resolve to nodes.
    unknown = [name for name, node in zip(header[1:], inodes[1:]) if node is None]
    if unknown:
        raise ValueError('unknown node(s) in template header: {}'.format(unknown))
    new_probs = []
    for line in ilines[1:]:
        if not line.strip():
            continue
        items = [i.strip() for i in line.split('\t')]
        valpairs = [(inodes[n_idx].id, items[n_idx]) for n_idx in range(1, len(inodes))]
        prob = float(items[-1])
        if prob < 0 or prob > 1:
            raise ProbvalError
        new_probs.append(Prob(valpairs, items[0], prob))
    probs_focalnode.probs = new_probs

# Probability-editor controls: focal-node dropdown plus buttons that write a
# blank template and read back a completed one.
focus_selector_probs = widgets.Dropdown(
    options=['{0}: {1}'.format(str(i.id), i.name) for i in nodelist],
    value='{0}: {1}'.format(str(probs_focalnode.id), probs_focalnode.name),
    description='Focal node',
    disabled=False,
)
update_probseditor()  # initial render for the starting focal node

download_prob_template_button = widgets.Button(
    description='Download template',
    tooltip='Download probability template',
    disabled=False,
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
    icon='',          # FontAwesome name without the `fa-` prefix
)
download_prob_template_button.on_click(download_probtemplate)

upload_prob_template_button = widgets.Button(
    description='Upload template',
    tooltip='Upload completed probability template',
    disabled=False,
    button_style='',
    icon='',
)
upload_prob_template_button.on_click(upload_probtemplate)

focus_selector_probs.observe(update_probseditor, names='value')
prob_controls = widgets.VBox([focus_selector_probs, download_prob_template_button, upload_prob_template_button])
widgets.HBox([prob_controls, values_combos_display])


HBox(children=(VBox(children=(Dropdown(description='Focal node', options=('3: fuel mix', '4: fuel sales', '0: …

['fuel mix\t\n', 'high emission\t0.4\n', 'low emission\t0.6\n']
['fuel mix\t\n', 'high emission\t0.4\n', 'low emission\t0.6\n']
['fuel sales\tvmt\t\n', 'high sales\tvery high\t0.8\n', 'high sales\thigh\t0.7\n', 'high sales\tmoderate\t0.3\n', 'high sales\tlow\t0.1\n', 'moderate sales\tvery high\t0.15\n', 'moderate sales\thigh\t0.2\n', 'moderate sales\tmoderate\t0.5\n', 'moderate sales\tlow\t0.3\n', 'low sales\tvery high\t0.05\n', 'low sales\thigh\t0.1\n', 'low sales\tmoderate\t0.2\n', 'low sales\tlow\t0.7\n']
['ghg\tvehicle emissions\t\n', 'high\thigh\t0.6\n', 'high\tmoderate\t0.3\n', 'high\tlow\t0.1\n', 'moderate\thigh\t0.3\n', 'moderate\tmoderate\t0.5\n', 'moderate\tlow\t0.2\n', 'low\thigh\t0.1\n', 'low\tmoderate\t0.2\n', 'low\tlow\t0.7\n']
['temperature\tghg\t\n', 'very high\thigh\t0.4\n', 'very high\tmoderate\t0.2\n', 'very high\tlow\t0.1\n', 'high\thigh\t0.4\n', 'high\tmoderate\t0.4\n', 'high\tlow\t0.2\n', 'moderate\thigh\t0.2\n', 'moderate\tmoderate\t0.4\n', 'moderate\tlow\t0.7\n'

In [446]:
# RUN SCENARIOS

def _build_inference():
    """Build a pgmpy BayesianNetwork from ``M`` and return a VariableElimination engine.

    Every node is added explicitly, so nodes without edges still get a CPD.
    CPD columns follow ``itertools.product`` order over the parents' values
    (parents in ``edges_in`` order, last parent varying fastest), which is
    the column layout TabularCPD expects for ``evidence=[parent names]``.
    Probabilities are looked up by (output, parent values) rather than by
    their position in ``n.probs``.

    Raises:
        ValueError: if a node has no probability for some combination.
    """
    card = {n.name: len(n.values) for n in M.nodes}
    model = BayesianNetwork([(n.name, child.name) for n in M.nodes for child in n.edges_out])
    model.add_nodes_from([n.name for n in M.nodes])
    cpds = []
    for n in M.nodes:
        parents = n.edges_in
        # Ids compared as strings (they may be stored as int or str); pairs
        # without exactly two fields (empty placeholders) are ignored.
        lookup = {}
        for p in n.probs:
            pairs = tuple(sorted((str(pair[0]), str(pair[1])) for pair in p.inputs if len(pair) == 2))
            lookup[(str(p.output), pairs)] = p.prob
        table = []
        for output in n.values:
            row = []
            for parent_values in iterprod(*[parent.values for parent in parents]):
                key = (str(output), tuple(sorted(
                    (str(parent.id), str(v)) for parent, v in zip(parents, parent_values))))
                if key not in lookup:
                    raise ValueError('node {!r} has no probability for {!r} given {!r}'.format(
                        n.name, output, parent_values))
                row.append(lookup[key])
            table.append(row)
        cpds.append(TabularCPD(
            variable=n.name,
            variable_card=card[n.name],
            values=table,
            evidence=[parent.name for parent in parents],
            evidence_card=[card[parent.name] for parent in parents],
            state_names={i.name: [str(j) for j in i.values] for i in [n] + parents},
        ))
    model.add_cpds(*cpds)
    return VariableElimination(model)


def _render_query_results(nodelist, results):
    """Return the HTML results table: one row per (node, state) with its probability in percent.

    The most probable state(s) of each node are highlighted. States are
    compared at the displayed 0.1% precision, so ties are all highlighted.
    """
    border = 'border:1px solid gray; border-collapse:collapse'
    pad = 'padding-left:5px;padding-right:5px'
    parts = ['<table style="{0}"><tr style="{0}">'.format(border)]
    parts += ['<th style="{0};{1}">{2}</th>'.format(border, pad, h) for h in ('factor', 'state', 'prob')]
    parts.append('</tr>')
    for n in nodelist:
        probs_pct = [p * 100 for p in results[n]]
        max_label = '{0:.1f}'.format(max(probs_pct))
        for v_idx, v in enumerate(n.values):
            parts.append('<tr style="border:1px solid gray;border-collapse:collapse">')
            if v_idx == 0:
                parts.append('<th rowspan="{0}" style="{1}">{2}</th>'.format(
                    len(n.values), pad, escape(n.name)))
            is_max = '{0:.1f}'.format(probs_pct[v_idx]) == max_label
            highlight = ';background-color:#ffff00;font-weight:bold' if is_max else ''
            parts.append('<th style="{0};{1}{2}">{3}</th><td style="{0};{1}">{4:.1f}</td></tr>'.format(
                border, pad, highlight, escape(str(v)), probs_pct[v_idx]))
    parts.append('</table>')
    return ''.join(parts)


def run_query(e=None):
    """Recompute and display every node's distribution given the selected evidence.

    Called with no argument for the initial render, or as a ToggleButtons
    observer: ``e['owner'].description`` names the node and ``e['new']`` is
    its selected state ('no data' clears that node's evidence). Observed
    nodes show 1 on their chosen state. Every other node gets its
    VariableElimination marginal given all observed nodes.
    """
    if e:
        node = M.get_node_byname(e['owner'].description)
        current_states[node.name] = None if e['new'] == 'no data' else e['new']

    infer = _build_inference()
    nodelist = M.nodes
    nodelist.sort(key=lambda x: x.name)
    # .get: nodes added after the panel was built have no entry yet (unobserved).
    evidence = {n.name: current_states.get(n.name) for n in nodelist
                if current_states.get(n.name) is not None}
    results = {}
    for i in nodelist:
        if i.name in evidence:
            results[i] = [int(evidence[i.name] == j) for j in i.values]
        else:
            results[i] = infer.query([i.name], show_progress=False, evidence=evidence).values
    probdist_resultbox.set_trait('value', _render_query_results(nodelist, results))
    
# Scenario panel: one ToggleButtons row per node ('no data' = unobserved).
nodelist = M.nodes
nodelist.sort(key=lambda x: x.name)
# Evidence selected in the panel, keyed by node name; None means unobserved.
current_states = {n.name: None for n in nodelist}

state_buttonsets = []
for n in nodelist:
    state_buttons = widgets.ToggleButtons(
        options = ['no data'] + n.values,
        description = n.name,
        disabled=False,
        button_style='', # 'success', 'info', 'warning', 'danger' or ''
        tooltips=[],
    )
    state_buttons.observe(run_query, names='value')
    state_buttonsets.append(state_buttons)

state_toggles_box = widgets.VBox(state_buttonsets)
probdist_resultbox = widgets.HTML(value='')
display(widgets.HBox([state_toggles_box, probdist_resultbox]))
run_query()


HBox(children=(VBox(children=(ToggleButtons(description='fuel mix', options=('no data', 'high emission', 'low …

In [415]:
# Inspect the evidence currently selected in the scenario panel (None = unobserved).
current_states

{'fuel mix': 'high emission',
 'fuel sales': None,
 'ghg': None,
 'temperature': None,
 'vehicle emissions': None,
 'vmt': None}