In [1]:
import base64
import io
import json
from typing import List

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from IPython.display import display, HTML

In [2]:
# %load html-table
# import numpy as np 
# from IPython.display import display, HTML
# from typing import List

def h_table(
    rows: list,
    has_header: bool = True,
    header_rows: int = 1,
    dark_header: bool = False,
    header_color: str = "#ffffff",
    font_size: int = 16,
    row_margin: str = "2px",
    header_border_bottom: bool = True,
    body_border_bottom: bool = True,
    table_title: str = "",
    table_align: str = "center",
    table_border_all: bool = False,
    border_thickness: str = "1px",
    border_color: str = "Black",
    stripe: bool = True,
    stripe_intensity: float = 0.05,
    display_html: bool = True,
    return_html: bool = False,
    show_args: bool = False,
):
    """
    ---------------------------------------------------------------------------
    h_table() takes a list of rows, each of which is a string, and builds a
        one-column HTML table. Every row is rendered inside a <pre> element
        in a monospace font, so pre-aligned text columns stay aligned.

    The first `header_rows` rows go into the <thead>; the remaining rows go
        into the <tbody>. A non-empty `table_title` is rendered (larger) above
        the header rows inside the <thead>. Note: `has_header` is only recorded
        for show_args; the header is controlled by `header_rows` and
        `table_title`.

    (Older notes say the rows are typically produced by kw_format()/kw_f1();
        those helpers are not defined in this notebook.)

    The keyword arguments other than the list of rows determine the html
        formatting applied to the rows.

    h_table() applies the implied CSS styles inline so the table can be
        included in an html page without knowing what styles it is using and
        without changing its styles.

    Use 'show_args = True' to print and return a dict of all the keyword
        arguments that determine the style of the table (no table is built).

    The default value for dark_header is False. In this case, the font color
        for the header is black and the grey stripes start on row 1 of the
        body (when there is a header).

    If 'dark_header = True', the header background is header_color, the font
        color for the header switches to white, and the stripes (a tint of
        header_color) start on row 2 of the body instead of row 1.

    If 'stripe = False', every body row has a white background.

    Returns the HTML string when return_html is True, otherwise None.
    Raises ValueError if `rows` runs out before `header_rows` rows are read.
    ---------------------------------------------------------------------------
    """
    htable_arg_dict = {
        "has_header": has_header,
        "header_rows": header_rows,
        "dark_header": dark_header,
        "header_color": header_color,
        "font_size": font_size,
        "row_margin": row_margin,
        "header_border_bottom": header_border_bottom,
        "body_border_bottom": body_border_bottom,
        "table_title": table_title,
        "table_align": table_align,
        "table_border_all": table_border_all,
        "border_thickness": border_thickness,
        "border_color": border_color,
        "stripe": stripe,
        "stripe_intensity": stripe_intensity,
        "display_html": display_html,
        "return_html": return_html,
        "show_args": show_args,
    }
    if show_args:
        print("'show_args = True' returns a dictionary of argument names and values,")
        print("    and prints the argument names and values.")
        print()
        print("In both cases, the results reflect any supplied arguments and")
        print("    any defaults that have not been overwritten.")
        print()
        for key, value in htable_arg_dict.items():
            print(key, " = ", repr(value))
        return htable_arg_dict

    # Create strings for HTML styles.
    #   Unstriped body rows get a white background. stripe_bkgd is assigned
    #   unconditionally so body_row_f() never sees an unbound name when
    #   stripe=False (previously a NameError).
    white_bkgd = "background: #ffffff; "
    stripe_bkgd = white_bkgd

    #   If dark_header, calculate light version of header background for stripe
    if dark_header:
        header_bkgd = "background: " + header_color + ";"
        h_font_color = "color: #ffffff;"
        if stripe:
            stripe_bkgd = "background: " + tint(header_color, stripe_intensity) + "; "
        even_stripes = True

    else:  # default gray
        header_bkgd = "background: #ffffff; "
        h_font_color = "color: #000000; "
        if stripe:
            # stripe_intensity / 0.05 is 1.0 at the default, i.e. #f5f5f5 itself
            default_grey = tint("#f5f5f5", stripe_intensity / 0.05)
            stripe_bkgd = "background: " + default_grey + "; "
        even_stripes = False

    #   Font Size
    table_style = "font-size: " + str(font_size) + "px; "

    #   Borders (table_border_all takes precedence over body_border_bottom)
    border_style = border_thickness + " solid " + border_color

    if table_border_all:
        table_border_style = "border: " + border_style + "; "
        table_border_style += "border-collapse: collapse; "
    elif body_border_bottom:
        table_border_style = "border-bottom: " + border_style + "; "
        table_border_style += "border-collapse: collapse; "
    else:
        table_border_style = "border-bottom: none; "

    table_style += table_border_style
    if table_align == "left":
        # CSS has no 'align' property, and the old 'align="left"' text closed
        # the double-quoted style attribute early, producing malformed HTML.
        # Pin the table to the left edge with margins instead.
        table_style += "margin-left: 0; margin-right: auto; "
    else:
        table_style += "margin: 32px; margin-left: auto; margin-right: auto; "

    if header_border_bottom:
        header_style = "border-bottom: " + border_style + "; "
    else:
        header_style = "border-bottom: none; "

    #   Creating the table
    #       build the header from the first header_rows elements of row_gen
    #       turn each remaining element into a body row and append to b_rows
    #       wrap b_rows in body tags, then wrap header and body in table tags

    #   Define local variables
    b_rows = ""

    #   Functions for creating
    #       table_f -> table tags
    #       title_f -> inline style for the optional title
    #       header_f -> header tags and header content
    #       body_f -> body tags
    #       body_row_f -> body row tags

    font_fam = """font-family: Menlo, "DejaVu Sans Mono", "Roboto Mono", "Courier New", Courier, monospace; """

    def table_f(header, body):
        h = '<table style="{0}">\n'.format(table_style)
        h += "{0}\n".format(header)
        h += "{0}\n".format(body)
        h += "</table>\n"
        return h

    def title_f():
        title_font_size = str(1.5 * font_size)
        title_style = "text-align: center; margin-bottom: " + title_font_size + "px; "
        title_style += "font-size: " + title_font_size + "px; "
        pre_style = font_fam + "{0}{1}{2}".format(
            header_bkgd, h_font_color, title_style
        )
        return pre_style

    def header_f(table_title):
        pre_style = font_fam + "{0}{1}".format(header_bkgd, h_font_color)
        h = '  <thead style="{0}">\n'.format(header_style)
        h += '    <tr style="{0}">\n'.format("")
        h += '      <th style="{0}">\n'.format(header_bkgd)
        if len(table_title) > 0:
            h += "        <pre style='margin: " + row_margin + "; {0}'>{1}</pre>\n".format(title_f(), table_title)
        for _ in range(header_rows):
            try:
                header_row = next(row_gen)
            except StopIteration:
                # A bare StopIteration here is confusing to callers; say what is wrong.
                raise ValueError(
                    "h_table: rows ran out before header_rows={0} header rows were read".format(header_rows)
                ) from None
            h += "        <pre style='margin: " + row_margin + "; {0}'>{1}</pre>\n".format(pre_style, header_row)
        h += "      </th>\n"
        h += "    </tr>\n"
        h += "  </thead>\n"
        return h

    def body_f(b_rows):
        h = "  <tbody>\n{0}\n".format(b_rows)
        h += "  </tbody>\n"
        return h

    def body_row_f(b_row, row_num, even_stripes):
        # Stripes go on even row numbers when even_stripes, else on odd ones.
        if (row_num % 2 == 0) == even_stripes:
            bkgd = stripe_bkgd
        else:
            bkgd = white_bkgd

        pre_style = font_fam + "{0}".format(bkgd)

        h = "    <tr>\n"
        h += '      <td style="{0}">\n'.format(bkgd)
        h += "        <pre style='margin: " + row_margin + "; {0}'>{1}</pre>\n".format(pre_style, b_row)
        h += "      </td>\n"
        h += "    </tr>\n"
        return h, row_num + 1

    # Single pass over rows: header_f() consumes the header rows, the loop
    # below consumes the rest.
    row_gen = (elem for elem in rows)

    #   Build header, remove final newline
    if header_rows > 0 or len(table_title) > 0:
        header = header_f(table_title)[0:-1]
        row_num = 1
    else:
        header = ""
        row_num = 0

    #   Iterate through remaining elements in the generator, create the string of body rows
    for elem in row_gen:
        h_row, row_num = body_row_f(elem, row_num, even_stripes)
        b_rows += h_row
    body = body_f(b_rows[0:-1])[0:-1]

    #   Create the table from header and body
    table = "<div>\n" + table_f(header, body) + "</div>"

    if display_html:
        display(HTML(table))

    return table if return_html else None


# ===================================================================================
#  Color functions
# ===================================================================================


def shrink_dist_to_max(h, frac):
    """Move one 0-255 channel value h toward 0xFF (white).

    Returns int(frac * h + (1 - frac) * 0xFF): frac = 1 gives h back,
    frac = 0 gives 0xFF. The result is truncated, not rounded, and is not
    clamped, so frac outside [0, 1] can produce values outside 0..255.
    """
    return int(frac * h + (1 - frac) * 0xFF)


def tint(color, frac):
    """tint() takes a color as input, returns a tint that is
            only frac times the distance from white of color

       Inputs:
            color is a string in html form for hex representation
                starts w "#" followed by six hex digits;
                e.g. "#0000ff" for blue
            frac is a float, normally between 0 and 1
                closer to zero yields a lighter tint; values outside
                [0, 1] are accepted and each channel is clamped to 0..255

       Return value:
            a string in the same "#rrggbb" form (always six lowercase hex
                digits) that takes the original distance of color from
                white and reduces it to a value of size frac
    """
    def channel(start):
        value = shrink_dist_to_max(int(color[start:start + 2], 16), frac)
        # Clamp: h_table passes stripe_intensity / 0.05, which may exceed 1.
        return max(0, min(0xFF, value))

    r, g, b = channel(1), channel(3), channel(5)
    # Zero-pad every channel; hex() of the packed integer dropped leading
    # zeros, e.g. "#0000ff" came back as "#ff".
    return f"#{r:02x}{g:02x}{b:02x}"


In [3]:
# %load ad-tax
# import numpy as np
# import matplotlib.pyplot as plt
# import matplotlib.ticker as ticker
# from html-table import h_table 
# from IPython.display import display, HTML

def ar(x: float, b: list, r: list):
    """
    Average tax rate paid on total revenue x under a marginal bracket schedule.

    Parameters
    ----------
    x, revenue
    b, list of bracket lower bounds (bracket k covers (b[k], b[k+1]];
       the last bracket is open-ended)
    r, marginal tax rate for each bracket, parallel to b

    Returns
    -------
    (tax owed on x) / x, or 0 when x falls in no bracket
    (e.g. x <= b[0], which also avoids dividing by zero at x = 0).
    """
    lower = b
    upper = b[1:] + [np.inf]

    # Tax owed on each completely filled bracket, shifted by one so that
    # tax_below[k] is the tax on everything below bracket k.
    full_bracket_tax = [0] + [(upper[k] - lower[k]) * r[k] for k in range(len(b) - 1)]
    tax_below = np.array(full_bracket_tax).cumsum()

    # Scan from the top so that, should several brackets match, the highest
    # index wins (same as a forward scan that keeps the last match).
    for k in reversed(range(len(b))):
        if lower[k] < x <= upper[k]:
            return ((x - lower[k]) * r[k] + tax_below[k]) / x
    return 0

def us_revenue():
    """
    Return a freshly built dict of US revenue per company.

    Keys are company names (in the order listed below); each value is a new
    list of six yearly revenue figures for 2018 through 2023, so index 3 is
    2021. Callers elsewhere divide by 10**9 and label the result "billion"
    (USD, per the figure label), so the raw figures are presumably USD.
    """
    per_company = [
        ("Amazon",
         [7405535228.00, 10319981126.00, 15734930757.00,
          20471812845.12, 26203920441.75, 31968782938.94]),
        ("Facebook",
         [24522653399.57, 31266383084.45, 38301319278.45,
          48475395994.50, 57064562199.66, 65375870105.82]),
        ("Google",
         [36476873676.10, 41799332909.88, 44064482272.17,
          54933079215.58, 60698189040.49, 66472154111.69]),
        ("Hulu",
         [1455200000.00, 1951136805.00, 2550075570.00,
          3390134249.43, 4169865126.80, 4982988826.52]),
        ("IAC",
         [502800960.82, 617207582.40, 545000828.13,
          602225915.08, 641370599.56, 670232276.54]),
        ("Microsoft",
         [4564720620.05, 5294834699.31, 5559599013.53,
          6654598018.60, 7508823966.43, 8301276813.48]),
        ("Reddit",
         [76903200.00, 103047001.20, 176918167.48,
          247889909.99, 315471183.58, 373751216.05]),
        ("Roku",
         [290441250.00, 528408003.20, 825367180.00,
          1461385859.79, 2062842763.21, 2743812995.34]),
        ("Snapchat",
         [668196016.99, 877990946.26, 1253547960.84,
          1824657820.00, 2495205459.02, 3333581435.90]),
        ("Spotify",
         [364553024.33, 449755501.87, 519348412.82,
          698564719.72, 873205899.65, 1030382961.59]),
        ("Twitter",
         [1321202682.67, 1604259465.12, 1696783197.53,
          2206578767.25, 2560867054.12, 2752932083.18]),
        ("Verizon Media",
         [3450212340.50, 3360506819.64, 3181279789.26,
          3435782172.40, 3641929102.75, 3787606266.86]),
        ("Yelp",
         [894316739.97, 963248050.00, 825245505.00,
          918212191.63, 1012078012.83, 1072802693.60]),
    ]
    return dict(per_company)

def table_marg_rates(b, r):
    """
    Build an HTML table listing the marginal tax rate for each revenue bracket.

    Parameters
    ----------
    b, list of bracket lower bounds (shown as "billion" in the table)
    r, corresponding list of marginal tax rates (shown as percentages)

    Returns
    -------
    HTML string from h_table(); it is not displayed here.
    """
    upper_bounds = b[1:] + [np.inf]

    lines = [" For Revenue Between " + " " * 5 + "Marginal Tax Rate" + " " * 2]
    for j, lower in enumerate(b):
        upper = upper_bounds[j]
        rate = r[j]
        # The top bracket is open-ended and is labelled "Above ...".
        if upper == np.inf:
            text = " " * 6 + f"Above {lower:>3} billion {rate:>17.1%}"
        else:
            text = " " * 5 + f" {lower:>2} and {upper:>2} billion {rate:>17.1%}"
        lines.append(text + " " * 8)

    return h_table(lines, font_size=12, row_margin = "4px", display_html = False, return_html = True)
    
    
def table_tax_paid_by_firm(b, r):
    """
    Build an HTML table of the tax each firm would pay on its 2021 US revenue.

    Firms are ordered from largest to smallest 2021 revenue (index 3 of each
    us_revenue() list) and followed by a "Total" row. Each amount is
    ar(rev, b, r) * rev with rev in billions.

    Parameters
    ----------
    b, list of tax bracket lower bounds
    r, corresponding list of marginal tax rates

    Returns
    -------
    HTML string from h_table(); it is not displayed here.
    """
    ranked = sorted(us_revenue().items(), key=lambda company: company[1][3], reverse = True)
    names = [name for name, _ in ranked]
    rev_2021 = np.array([revenues[3] / 10**9 for _, revenues in ranked])
    avg_rates = np.array([ar(x, b, r) for x in rev_2021])
    tax_paid = np.array([rate * rev for rate, rev in zip(avg_rates, rev_2021)])

    firm_hdr = "Firm"
    paid_hdr = "Tax Paid"
    # NOTE(review): the header separates its columns with two spaces while
    # the body rows use one; confirm whether that offset is intentional.
    lines = [f"{firm_hdr: ^15}  {paid_hdr: ^15}"]
    lines.extend(f"{name: <15} {paid:^15.1f}" for name, paid in zip(names, tax_paid))
    total_hdr = "Total"
    lines.append(f"{total_hdr: <15} {tax_paid.sum():^15.1f}")

    return h_table(lines, font_size=12, row_margin = "4px", display_html = False, return_html = True)

    
def table_tax_paid_total(b, r, first_year: int = 2018):
    """
    Build an HTML table of total tax receipts per year, summed over all firms.

    Parameters
    ----------
    b, list of tax bracket lower bounds
    r, corresponding list of marginal tax rates
    first_year, calendar year of the first entry in each us_revenue() list
        (default 2018, matching us_revenue()'s 2018-2023 data)

    Returns
    -------
    HTML string from h_table() with a two-line header ("Year" /
    "Tax Receipts (billion)") and one row per year of data; not displayed.
    """
    # Sort as the other tables do: summing in the same firm order keeps the
    # floating-point totals identical.
    ranked = sorted(us_revenue().items(), key=lambda company: company[1][3], reverse = True)

    # rows = firms, columns = years, revenue in billions
    rev_billions = np.array([np.array(revenues) / 10**9 for _, revenues in ranked])
    avg_rates = np.array([[ar(x, b, r) for x in firm_rev] for firm_rev in rev_billions])

    # Tax paid per firm per year, summed over firms -> one total per year.
    tax_payments = (avg_rates * rev_billions).sum(0)

    # Distinct label names: the old code reused `b` for "(billion)", shadowing
    # the bracket list.
    year_label = "Year"
    receipts_label = "Tax Receipts"
    units_label = "(billion)"
    blank = ""
    row_list = [
        f"{year_label: ^10} {receipts_label: ^15}",
        f"{blank: ^10} {units_label: ^15}",
    ]

    # One row per year actually present in the data (was hard-coded to 6).
    for year, payment in enumerate(tax_payments, start=first_year):
        row_list.append(f"{year: ^10d} {payment: ^15.1f}")

    return h_table(row_list, header_rows = 2, font_size=12, row_margin = "4px", display_html = False, return_html = True)

# Shared matplotlib rcParams, applied via plt.style.context(base_style).
# dpi is raised so figures render larger in the notebook, so every font and
# line size below is set smaller than matplotlib's defaults to compensate.
# figsize uses the golden ratio: width = 1.6180339 * height.
base_style = {
    # figure
    'figure.dpi': 300,
    'figure.figsize': [3*1.6180339, 3],
    'figure.facecolor': 'white',
    'figure.titlesize': 8,
    # text sizes
    'axes.titlesize': 8,
    'axes.labelsize': 6,
    'ytick.labelsize': 4,
    'xtick.labelsize': 4,
    'legend.fontsize': 5,
    # lines and markers
    'lines.linewidth': 1,
    'lines.markersize': 3,
    # tick marks
    'xtick.major.size': 2.0,
    'xtick.major.width': 0.3,
    'ytick.major.size': 2.0,
    'ytick.major.width': 0.3,
}

def floating_spines(ax, axis = 'l'):
    """Restyle an Axes to show thin "floating" spines on two sides only.

    The top spine is hidden. The bottom spine and one vertical spine are
    shown with linewidth 0.2 and pushed 5 points outward; the other vertical
    spine is hidden.

    Parameters
    ----------
    ax : matplotlib Axes (anything exposing a ``spines`` mapping)
        The axes to modify in place.
    axis : str, optional
        'l' keeps the left spine; any other value keeps the right spine.
        By default 'l'.

    Returns
    -------
    The same ``ax`` object, for chaining.
    """
    keep_left = axis == 'l'
    shown = 'left' if keep_left else 'right'
    hidden = 'right' if keep_left else 'left'

    ax.spines['top'].set_visible(False)
    ax.spines[hidden].set_visible(False)
    for name in (shown, 'bottom'):
        spine = ax.spines[name]
        spine.set_visible(True)
        spine.set_linewidth(0.2)
        spine.set_position(('outward', 5))
    return ax

def us_fig(b: list, r: list):
    """
    Plot average tax rate against 2021 US revenue and return it as a PNG data URI.

    Draws the curve ar(x, b, r) for x = 0..59 (billions) and overlays one
    point per firm from us_revenue(). The five largest 2021 firms are
    labelled; the fifth label reads "All others".

    Parameters
    ----------
    b, list of tax bracket lower bounds
    r, corresponding list of marginal tax rates

    Returns
    -------
    str of the form 'data:image/png;base64,...', usable as an <img> src.
    """
    us_rev_s = sorted(us_revenue().items(), key=lambda company: company[1][3], reverse = True)
    us_names = [company[0] for company in us_rev_s]
    us_rev_2021 = np.array([company[1][3] / 10**9 for company in us_rev_s])
    us_avg_r = np.array([ar(x, b, r) for x in us_rev_2021])

    rev_for_line = range(60)
    ar_for_line = [ar(x, b, r) for x in rev_for_line]

    # Per-label tweaks for the five largest firms:
    # (x offset in billions, y offset in percentage points, label text, firm index).
    # "All others" replaces the firm name; any other text is appended to it.
    adjustments = [
        (+0.5, -1.5, "", 0),
        (+0.5, -1.5, "", 1),
        (+0.5, -1.5, "", 2),
        (+1.0, -0.4, "", 3),
        (+2.0, -0.8, "All others", 4)
    ]

    with plt.style.context(base_style):
        fig = plt.figure()
        ax = fig.add_axes([0.1, 0.15, 0.8, 0.75])
        ax = floating_spines(ax)
        ax.set_xlim(-1, 60)
        ax.set_ylim(-0.01, 0.4)
        ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax = 1.0, decimals = 0))
        ax.spines['left'].set_bounds(low = 0, high = 0.4)
        ax.spines['bottom'].set_bounds(low = 0, high = 60)
        ax.plot(rev_for_line, ar_for_line)
        ax.plot(us_rev_2021, us_avg_r, ls = '', marker = 'o', ms = 1.5)
        for j, (dx, dy, text, _) in enumerate(adjustments):
            label = text if text == "All others" else us_names[j] + text
            # y axis holds fractions (PercentFormatter xmax=1), so dy / 100.
            ax.text(us_rev_2021[j] + dx, us_avg_r[j] + dy / 100, label, fontsize = 4)

        ax.set_title("Average Tax Rate as a Function of Total Revenue", pad = 10)
        ax.set_xlabel("US Revenue 2021, Billion USD per Year")

        # Render while base_style is still active: some artists (e.g. ticks)
        # may be created lazily at draw time and would otherwise read the
        # default rcParams.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')

    # Release the figure; otherwise pyplot keeps every figure created by
    # repeated calls alive.
    plt.close(fig)

    img_str = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode('UTF-8')
    return img_str

def split(s, b, r, revenue = 50):
    """
    Format the total tax paid when `revenue` is divided evenly among `s` firms.

    Each firm earns revenue / s and pays the average rate ar(revenue / s, b, r)
    on it; the s identical bills are summed. Returns a string like
    "$12.3 billion". s == 0 raises ZeroDivisionError.
    """
    per_firm = revenue / s
    rate = ar(per_firm, b, r)
    tax_bill = s * (rate * revenue / s)
    return f"${tax_bill: <2.1f} billion"

In [4]:
# Default tax schedule: b0 holds bracket lower bounds (in billions of revenue)
# and r0 the marginal rate applied above each bound; the lists are parallel.
b0 = [0, 5, 10, 15, 20, 25, 30, 35, 40, 50, 60]
r0 = [0, 0.05, 0.125, 0.20, 0.275, 0.35, 0.425, 0.50, 0.575, 0.65, 0.725]

In [5]:
# print(table_tax_paid_total(b0, r0))

In [6]:
# table_tax_paid_total() builds with display_html=False, so render it here.
display(HTML(data=table_tax_paid_total(b0, r0)))

Year Tax Receipts (billion)
2018 11.1
2019 16.5
2020 21.8
2021 35.2
2022 46.1
2023 58.3


In [7]:
# table_marg_rates() builds with display_html=False, so render it here.
display(HTML(data=table_marg_rates(b0, r0)))

For Revenue Between Marginal Tax Rate
0 and 5 billion 0.0%
5 and 10 billion 5.0%
10 and 15 billion 12.5%
15 and 20 billion 20.0%
20 and 25 billion 27.5%
25 and 30 billion 35.0%
30 and 35 billion 42.5%
35 and 40 billion 50.0%
40 and 50 billion 57.5%
50 and 60 billion 65.0%


In [8]:
def _set_inner_html_js(element_id: str, html: str) -> str:
    """Return a self-invoking JS snippet that sets element_id's innerHTML to html."""
    # json.dumps produces a valid JavaScript string literal; repr() produces a
    # Python literal that is only JS-compatible for simple text.
    return (
        "(function() {document.getElementById(" + json.dumps(element_id)
        + ").innerHTML = " + json.dumps(html) + "})();"
    )


# Script that fills the two placeholder elements with the default-schedule tables.
s1 = _set_inner_html_js("total-revenue-table-default", table_tax_paid_total(b0, r0))
s1 += "\n\n"
s1 += _set_inner_html_js("marginal-rate-table-default", table_marg_rates(b0, r0))

In [9]:
# print(s1) 

In [10]:
# Write the generated script next to the site's other JS. Use an explicit
# encoding so the output does not depend on the platform default.
with open('../js/load-tables.js', 'w', encoding='utf-8') as f:
    f.write(s1)