# Visual Aid for Bayes' Theorem 

## Code

[Jump to Interactive Section](#interactive)

In [None]:
from __future__ import division

import numpy as np
import math
from six.moves import xrange

from plotly import offline as pyoffline
import plotly.graph_objs as go
pyoffline.init_notebook_mode(connected=True)

from IPython.display import display, HTML, clear_output
from ipywidgets import interact, interact_manual, FloatSlider, IntSlider, Button, HBox, VBox
from ipywidgets import HTML as HTMLWidget

In [None]:
# Hex colour codes used for the slider handles, the highlighted scenario
# text and the chart traces.
colorMap = dict(
    green="#3FB63F",
    blue="#53ACE4",
    yellow="#FAB00E",
    red="#FF4C4C",
    purple="#663399",
)


def _build_local_css(palette, names):
    """Return the notebook stylesheet for the given colour names.

    For each name, one rule colours the slider handle of widgets carrying the
    ``color-<name>`` class and one rule colours ``span.color-<name>`` text.
    ``names`` fixes the order in which the rules are emitted.
    """
    slider_rule = (
        "div.jupyter-widgets.color-{name} .ui-slider-handle {{\n"
        "    background-color: {value};\n"
        "}}\n"
    )
    span_rule = (
        "span.color-{name} {{\n"
        "    color: {value};\n"
        "}}\n"
    )
    pieces = [
        '\n<style type="text/css">\n',
        # Drop the default handle image so the background colour shows.
        "div.jupyter-widgets .ui-slider-handle {\n"
        "    background-image: none;\n"
        "}\n",
    ]
    pieces.extend(slider_rule.format(name=name, value=palette[name]) for name in names)
    pieces.append("\n")
    pieces.extend(span_rule.format(name=name, value=palette[name]) for name in names)
    pieces.append("</style>\n")
    return "".join(pieces)


local_css = _build_local_css(colorMap, ("green", "blue", "yellow", "red", "purple"))
display(HTML(local_css))

In [None]:
class Confusion(object):
    """Confusion-matrix figures for a binary test, with Plotly renderings.

    From a prevalence, a true positive rate and a false positive rate the
    instance derives the complementary rates, the head counts of the four
    outcomes in ``population``, and a ``num_points`` x ``num_points`` grid
    layout used by the 3D surface plot.

    Attributes filled in by ``setupVars``:
        rates:  rates keyed by "prevalence", "absence" and the four outcome labels.
        points: number of grid cells given to each of those keys.
        counts: head counts for the four outcomes and their pairwise sums.
    Attributes filled in by ``setupCons``:
        colors: outcome label -> hex colour taken from ``colorMap``.
        labels_condition_positive / labels_condition_negative:
            the [upper, lower] outcome labels of each condition column.
    """

    def __init__(self, prevalence=0.01, tpr=0.8, fpr=0.096, population=10000, num_points=100):
        """Store the inputs and derive every dependent value.

        Args:
            prevalence: fraction of the population that has the condition.
            tpr: true positive rate (over condition positive).
            fpr: false positive rate (over condition negative).
            population: number of people the counts are scaled to.
            num_points: side length of the surface grid, before padding.
        """
        self.prevalence = prevalence
        self.tpr = tpr
        self.fpr = fpr
        self.population = population
        self.num_points = num_points
        self.setupVars()
        self.setupCons()

    def setupCons(self):
        """Set the constant colour mapping and the label pairs per condition."""
        self.colors = {
            "True Positive": colorMap["green"],
            "False Negative": colorMap["blue"],
            "False Positive": colorMap["yellow"],
            "True Negative": colorMap["red"],
        }
        self.labels_condition_positive = ["True Positive", "False Negative"]
        self.labels_condition_negative = ["False Positive", "True Negative"]

    def _gridCells(self, rate):
        """Return how many of the ``num_points`` grid cells ``rate`` covers, rounded up."""
        # Round away float noise before taking the ceiling: 0.14 * 100
        # evaluates to 14.000000000000002, which would otherwise become 15.
        return int(math.ceil(round(rate * self.num_points, 9)))

    def setupVars(self):
        """Compute ``self.rates``, ``self.points`` and ``self.counts`` from the inputs."""
        rates = {}
        rates["prevalence"] = self.prevalence  # condition positive
        rates["True Positive"] = self.tpr  # recall (over condition positive)
        rates["False Positive"] = self.fpr  # (over condition negative)
        rates["absence"] = 1 - rates["prevalence"]  # condition negative
        rates["False Negative"] = 1 - rates["True Positive"]  # (over condition positive)
        rates["True Negative"] = 1 - rates["False Positive"]  # (over condition negative)
        self.rates = rates

        # A region that is neither empty nor the whole axis is drawn with at
        # least 10% and at most 90% of the grid, so the plot is not to scale
        # for very small or very large rates.
        min_display = int(math.ceil(0.10 * self.num_points))
        max_display = self.num_points - min_display

        points = {}
        for key in ["prevalence", "True Positive", "False Positive"]:
            cells = self._gridCells(rates[key])
            if cells < min_display and cells != 0:
                cells = min_display
            elif cells > max_display and cells != self.num_points:
                cells = max_display
            points[key] = cells
        points["absence"] = self.num_points - points["prevalence"]
        points["False Negative"] = self.num_points - points["True Positive"]
        points["True Negative"] = self.num_points - points["False Positive"]
        self.points = points

        counts = {}
        counts["True Positive"] = int(round(self.population * rates["prevalence"] * rates["True Positive"]))
        counts["False Negative"] = int(round(self.population * rates["prevalence"] * rates["False Negative"]))
        counts["False Positive"] = int(round(self.population * rates["absence"] * rates["False Positive"]))
        counts["True Negative"] = int(round(self.population * rates["absence"] * rates["True Negative"]))

        counts["Condition Positive"] = counts["True Positive"] + counts["False Negative"]
        counts["Condition Negative"] = counts["False Positive"] + counts["True Negative"]
        counts["Prediction Positive"] = counts["True Positive"] + counts["False Positive"]
        counts["Prediction Negative"] = counts["False Negative"] + counts["True Negative"]
        counts["Prediction True"] = counts["True Positive"] + counts["True Negative"]
        counts["Prediction False"] = counts["False Positive"] + counts["False Negative"]
        self.counts = counts

    def makeSurface(self, label, pos):
        """Build the z grid for one outcome.

        The grid has four quadrants numbered 0 (upper, prevalence columns),
        1 (upper, absence columns), 2 (lower, prevalence columns) and
        3 (lower, absence columns). Quadrant ``pos`` is filled with
        ``self.counts[label]``; every other cell is 0. The upper/lower row
        split follows the label's own condition column.

        Returns:
            A list of ``num_points + 2`` rows of ``num_points + 2`` values;
            the extra outer ring is all zeros.
        """
        points = self.points
        if label in self.labels_condition_positive:
            pair = self.labels_condition_positive
        else:
            pair = self.labels_condition_negative
        up, down = pair[0], pair[1]
        height = self.counts[label]

        def makeRow(left, right):
            # One padded data row: zero, prevalence cells, absence cells, zero.
            left_val = height if pos == left else 0
            right_val = height if pos == right else 0
            return [0] + [left_val] * points["prevalence"] + [right_val] * points["absence"] + [0]

        def blankRow():
            return [0] * (self.num_points + 2)

        surface = [blankRow()]
        surface.extend(makeRow(0, 1) for _ in range(points[up]))
        surface.extend(makeRow(2, 3) for _ in range(points[down]))
        surface.append(blankRow())
        return surface

    def writeSurfaceTooltips(self, label):
        """Return the hover text grid for one outcome's surface.

        Every cell holds the same text: the outcome's rate times the
        prevalence (or absence) times the population, and the resulting
        count. The grid has ``num_points + 2`` rows and columns so it matches
        the padded grid returned by ``makeSurface``.
        """
        # ':g' keeps float noise such as 0.19999999999999996 out of the text.
        if label in self.labels_condition_positive:
            prev_abs = "{:g} (Prevalence | Condition)".format(self.rates["prevalence"])
        else:
            prev_abs = "{:g} (Absence | Condition)".format(self.rates["absence"])
        text = '{rate:g} ({label} Rate | Prediction)<br> x {prev_abs}<br> x {pop} (Population)<br> = {count} cases'.format(
            rate=self.rates[label],
            label=label,
            prev_abs=prev_abs,
            pop=self.population,
            count=self.counts[label],
        )
        side = self.num_points + 2
        return [[text] * side for _ in range(side)]

    def _conditionTicks(self):
        """Return ``(tickvals, ticktext)`` for the Condition axis.

        Labels sit at the midpoint of the prevalence region and at the
        midpoint of the absence region; the region borders get empty labels.
        """
        split = self.points["prevalence"]
        tickvals = [0, split / 2.0, split, (self.num_points + split) / 2.0, self.num_points]
        ticktext = [
            "",
            "+ ({:.1%})".format(self.rates["prevalence"]),
            "",
            "- ({:.1%})".format(self.rates["absence"]),
            "",
        ]
        return tickvals, ticktext

    def drawSurfacePlot(self):
        """Render the four outcome surfaces in one 3D Plotly figure."""
        quadrants = [
            ("True Positive", 0),
            ("False Negative", 2),
            ("False Positive", 1),
            ("True Negative", 3),
        ]
        surfaceData = [dict(
            z=self.makeSurface(label, pos),
            name=label,
            colorscale=[[0, "white"], [1, self.colors[label]]],
            text=self.writeSurfaceTooltips(label),
            hoverinfo="text+name",
            hoverlabel=dict(
                font=dict(
                    color=self.colors[label]
                )
            ),
            type="surface",
            showscale=False
        ) for label, pos in quadrants]

        condition_tickvals, condition_ticktext = self._conditionTicks()
        # A plain dict is used for the scene instead of the deprecated
        # ``go.Scene`` wrapper.
        surfaceLayout = go.Layout(
            scene=dict(
                xaxis=dict(
                    autorange=True,
                    title="Condition",
                    tickmode='array',
                    tickvals=condition_tickvals,
                    ticktext=condition_ticktext,
                ),
                yaxis=dict(
                    autorange=True,
                    title="Prediction",
                    tickmode='array',
                    tickvals=[0, self.points["False Positive"], self.points["True Positive"], self.num_points],
                    ticktext=["+", "", "", "-"],
                ),
                zaxis=dict(
                    autorange=True,
                    title="Count",
                ),
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=-1.25, y=-1.25, z=1.25),
                ),
            )
        )
        surfaceFig = go.Figure(data=surfaceData, layout=surfaceLayout)
        pyoffline.iplot(surfaceFig)

    def genPieDomains(self, rows, cols):
        """Split the unit square into ``rows`` x ``cols`` equal cells.

        Returns:
            A list of ``{'x': [left, right], 'y': [bottom, top]}`` dicts in
            row-major order, starting with the top row. Empty when ``rows``
            or ``cols`` is not positive.
        """
        if rows <= 0 or cols <= 0:
            return []
        rowsize = 1.0 / rows
        colsize = 1.0 / cols
        return [
            {
                'x': [col * colsize, (col + 1) * colsize],
                'y': [1 - (row + 1) * rowsize, 1 - row * rowsize],
            }
            for row in range(rows)
            for col in range(cols)
        ]

    def bakePies(self, pies, title=""):
        """Render the given pie specifications in a two-column figure.

        Args:
            pies: list of dicts with "labels" (keys of ``self.counts``) and
                optional "sublabels" and "title".
            title: figure title.

        With more than one pie, each pie is drawn as a donut and its title
        is placed at the centre of its cell. Nothing is drawn for an empty
        ``pies``.
        """
        num_pies = len(pies)
        if num_pies == 0:
            return
        # (num_pies + 1) // 2 is ceil(num_pies / 2) without relying on true division.
        domains = self.genPieDomains(rows=(num_pies + 1) // 2, cols=min([2, num_pies]))
        data = []
        annotations = []
        for index, item in enumerate(pies):
            labels = item["labels"]
            sublabels = item.get("sublabels", [])
            subtitle = item.get("title", "")
            domain = domains[index]

            pie = go.Pie(
                labels=labels,
                values=[self.counts[label] for label in labels],
                marker=dict(
                    # Labels without an entry in self.colors yield None here.
                    colors=[self.colors.get(label) for label in labels]
                ),
                name=subtitle,
                textinfo='percent',
                text=sublabels,
                domain=domain
            )
            if num_pies > 1:
                pie["hole"] = 0.4
            data.append(pie)

            annotations.append(dict(
                showarrow=False,
                text=subtitle,
                x=np.mean(domain["x"]),
                y=np.mean(domain["y"]),
                xanchor="center",
                yanchor="middle"
            ))
        layout = {
            "title": title,
        }
        if num_pies > 1:
            layout["annotations"] = annotations
        pieFig = go.Figure(data=data, layout=layout)
        pyoffline.iplot(pieFig)

    def drawPieCharts(self):
        """Render two figures of four donut charts each, built from ``self.counts``."""
        self.bakePies([
            dict(labels=["True Positive", "False Positive"], sublabels=["Precision / Positive Predictive Value (PPV)", "False Discovery Rate (FDR)"], title="Prediction Positive"),
            dict(labels=["True Negative", "False Negative"], sublabels=["Negative Predictive Value (NPV)", "False Omission Rate (FOR)"], title="Prediction Negative"),
            dict(labels=["True Positive", "False Negative"], sublabels=["Recall / Sensitivity / True Positive Rate (TPR)", "False Negative Rate (FNR)"], title="Condition Positive"),
            dict(labels=["False Positive", "True Negative"], sublabels=["Fall-out / False Positive Rate (FPR)", "Specificity / True Negative Rate (TNR)"], title="Condition Negative"),
        ])
        self.bakePies([
            dict(labels=["Condition Positive", "Condition Negative"], sublabels=["Prevalence", ""], title="Condition"),
            dict(labels=["Prediction Positive", "Prediction Negative"], title="Prediction Result"),
            dict(labels=["Prediction True", "Prediction False"], sublabels=["Accuracy", ""], title="Prediction Accuracy"),
            dict(labels=["True Positive", "False Negative", "False Positive", "True Negative"], title="Population"),
        ])

In [None]:
# def updatePlots(**kwargs):
#     conf = Confusion(**kwargs)
#     conf.drawSurfacePlot()
#     conf.drawPieCharts()

# interact_manual(
#     updatePlots,
#     prevalence=prevalence,
#     population=population,
#     tpr=tpr,
#     fpr=fpr,
# )

def setup_one_sum(a, b):
    """Keep two value widgets complementary so that ``a.value + b.value == 1``.

    Each widget observes the other: when one changes, the other is set to
    ``1 - value``. The write is skipped when the target already holds the
    complement (within a small tolerance). Without that guard the echo from
    the second widget overwrites the one the user just moved with a
    float-rounded value (0.3 -> 0.7 -> 0.30000000000000004).

    Args:
        a, b: objects exposing a numeric ``value`` attribute and an
            ``observe(handler, name)`` method, such as the FloatSliders
            in this notebook.
    """
    tolerance = 1e-9

    def sync(target, source):
        complement = 1 - source.value
        if abs(target.value - complement) > tolerance:
            target.value = complement

    def update_a(*args):
        sync(a, b)
    b.observe(update_a, 'value')

    def update_b(*args):
        sync(b, a)
    a.observe(update_b, 'value')

def _percent_slider(label, start):
    """Build a 0-to-1 FloatSlider in 0.01 steps whose readout is a percentage."""
    return FloatSlider(
        value=start,
        min=0,
        max=1,
        step=0.01,
        description=label,
        readout_format='.1%',
        continuous_update=True,
    )


# Input controls for the scenario.
prevalence = _percent_slider("prevalence", 0.01)
population = IntSlider(
    value=10000,
    min=1,
    max=10000000,
    description="population",
    readout_format=',',
    continuous_update=True,
)

# Complementary pair over condition positive.
tpr = _percent_slider("true positive rate", 0.8)
fnr = _percent_slider("false negative rate", 0.2)

# Complementary pair over condition negative.
fpr = _percent_slider("false positive rate", 0.096)
tnr = _percent_slider("true negative rate", 0.904)

# Link each pair so moving one slider updates the other.
for _first, _second in ((fpr, tnr), (tpr, fnr)):
    setup_one_sum(_first, _second)

# Tag each slider with the CSS class that colours its handle.
for _slider, _css_class in (
    (tpr, 'color-green'),
    (fnr, 'color-blue'),
    (fpr, 'color-yellow'),
    (tnr, 'color-red'),
    (prevalence, 'color-purple'),
):
    _slider.add_class(_css_class)

def get_scenario(prevalence_value=None, tpr_value=None, fpr_value=None):
    """Return the scenario description as an HTML string.

    Args:
        prevalence_value: prevalence to show; defaults to ``prevalence.value``.
        tpr_value: true positive rate to show; defaults to ``tpr.value``.
        fpr_value: false positive rate to show; defaults to ``fpr.value``.

    Each rate is rendered as a percentage with one decimal place inside a
    ``span`` carrying the colour class of the matching slider.
    """
    if prevalence_value is None:
        prevalence_value = prevalence.value
    if tpr_value is None:
        tpr_value = tpr.value
    if fpr_value is None:
        fpr_value = fpr.value
    # The final line break pair is written "<br><br>" like the others in
    # this text; "<br></br>" is not valid because br has no closing tag.
    return """<h2>Visual Aid for Bayes' Theorem</h2>
Scenario:
<span class="color-purple">{prevalence:.1%}</span> of people who participate in routine screening have a certain disease. 
<span class="color-green">{tpr:.1%}</span> of people with that disease will test positive.
<span class="color-yellow">{fpr:.1%}</span> of people without that disease will also test positive.
Someone tested positive in a routine screening.
What is the probability that this person actually has that disease?<br><br>

Click on the 'Run' button below. You can compare the relative heights of the <span class="color-green">true positive</span> and <span class="color-yellow">false positive</span> bars to get a general idea of the ratio,
or look at the percentage for <span class="color-green">true positives</span> in the first donut chart to find the answer.<br><br>

Note: The 3D bar chart is only an approximation. It may not be to scale to make smaller amounts more visible.<br><br>
""".format(prevalence=prevalence_value, tpr=tpr_value, fpr=fpr_value)

# Initial scenario text, rendered from the sliders' starting values.
scenario_text = get_scenario()
scenario = HTMLWidget(value=scenario_text)


def update_scenario(*args):
    """Re-render the scenario text from the current slider values."""
    scenario.value = get_scenario()


# Only these three sliders appear in the scenario text.
prevalence.observe(update_scenario, 'value')
tpr.observe(update_scenario, 'value')
fpr.observe(update_scenario, 'value')

def updatePlots(b):
    """Button click handler: clear the old output and redraw every chart.

    ``b`` is the argument passed by ``on_click``; it is not used.
    """
    clear_output(wait=True)
    settings = dict(
        prevalence=prevalence.value,
        population=population.value,
        tpr=tpr.value,
        fpr=fpr.value,
    )
    conf = Confusion(**settings)
    conf.drawSurfacePlot()
    conf.drawPieCharts()


runButton = Button(description="Run")
runButton.on_click(updatePlots)

# Control panel: scenario text on top, three slider columns in the middle,
# the Run button at the bottom.
box = VBox(children=[
    scenario,
    HBox(children=[
        VBox(children=[population, prevalence]),
        VBox(children=[tpr, fnr]),
        VBox(children=[fpr, tnr]),
    ]),
    runButton,
])

<a id="interactive"></a>

In [None]:
# Show the assembled control panel (scenario text, sliders and Run button).
display(box)