In [1]:
"""
We want an interactive finance chart for monthly costs, to see affordability based on several factors

Costs:
* Utilities
* Insurance
* Property Taxes

Based on factors
* Sale price
* part that is land cost
* Number of rooms


For each factor, show a line chart of the factor's value vs. monthly cost
"""

import pandas as pd
import altair as alt

In [2]:
import dataclasses
import typing

In [35]:
@dataclasses.dataclass
class Variable:
    """One adjustable input of the cost model, swept over a numeric range.

    ``generate_plot`` reads these attributes: ``start``/``stop``/``step`` are
    passed to ``alt.sequence`` to generate the x-values, ``default`` is the
    initial value of the variable's selection, ``label`` is the x-axis title
    and ``axis_format`` is the x-axis number format.

    The numeric fields are annotated ``float`` rather than ``int`` because
    fractional ranges are used in this file (the interest rate steps by
    0.01); integer values remain valid arguments.
    """

    # Human-readable name, shown as the x-axis title.
    label: str
    # First value of the generated sequence.
    start: float
    # Upper bound handed to ``alt.sequence``.
    stop: float
    # Increment between consecutive generated values.
    step: float
    # Value the variable's selection is initialised with.
    default: float
    # Kind of quantity: '$' for money, '%' for a ratio shown as a percentage.
    tp: typing.Literal['$', '%']

    @property
    def axis_format(self) -> str:
        """Return the axis number-format string matching ``tp``.

        Returns:
            ``'$.2s'`` when ``tp`` is ``'$'`` and ``'%'`` when ``tp`` is ``'%'``.

        Raises:
            KeyError: if ``tp`` is neither ``'$'`` nor ``'%'``.
        """
        return {
            '$': '$.2s',
            '%': '%'
        }[self.tp]

def generate_plot(fn, **variables):
    """Build the interactive monthly-cost dashboard, save it and return it.

    Args:
        fn: Callable invoked with one keyword argument per entry of
            ``variables``. The arguments it receives are Altair expressions
            (either a data field or a selection field), and it must return a
            mapping of cost-category name to monthly-cost expression.
        **variables: One object per model input, keyed by the argument name
            ``fn`` expects. Each object must provide ``start``, ``stop``,
            ``step``, ``default``, ``label`` and ``axis_format`` (see
            ``Variable``).

    Returns:
        The concatenated chart: a donut chart of the cost per category on
        top, and below it one line chart per variable laid out in two
        columns with a shared y scale.

    Side effects:
        Writes the chart to ``index.html`` in the current working directory.
    """
    names = list(variables.keys())

    # Generated x-values for every variable; the generated field is named
    # after the variable itself.
    sequences = {}
    for name, var in variables.items():
        sequences[name] = alt.sequence(var.start, var.stop, var.step, as_=name)

    # One hover-driven selection per variable, starting at its default.
    selections = {}
    for name, var in variables.items():
        selections[name] = alt.selection(
            type='single',
            on='mouseover',
            nearest=True,
            fields=[name],
            init={name: var.default},
        )

    def selected(name):
        # Expression referring to the currently selected value of `name`.
        return getattr(selections[name], name)

    # Line per variable: total monthly cost as that variable is swept while
    # every other variable is pinned to its current selection.
    line_charts = {}
    for swept, var in variables.items():
        sweep_base = alt.Chart(sequences[swept])
        fn_inputs = {}
        for name in names:
            if name == swept:
                fn_inputs[name] = getattr(alt.datum, name)
            else:
                fn_inputs[name] = selected(name)
        # Sum of every category returned by `fn`.
        total_cost = sum(fn(**fn_inputs).values())
        line_charts[swept] = sweep_base.transform_calculate(
            monthly_cost=total_cost
        ).mark_line().encode(
            alt.X(
                field=swept,
                type='quantitative',
                axis=alt.Axis(format=var.axis_format, title=var.label),
            ),
            alt.Y(
                'monthly_cost:Q',
                axis=alt.Axis(title='Monthly Cost', format='$.2s'),
            ),
        )

    # Fully transparent points that carry the selection; they are what picks
    # up the x-value under the cursor.
    hover_charts = {}
    for name in names:
        hidden_points = alt.Chart(sequences[name]).mark_point().encode(
            alt.X(field=name, type='quantitative'),
            opacity=alt.value(0),
        )
        hover_charts[name] = hidden_points.add_selection(selections[name])

    # Rule drawn at the currently selected x-value.
    rule_charts = {}
    for name in names:
        rule = alt.Chart(sequences[name]).mark_rule().encode(
            alt.X(field=name, type='quantitative'),
        )
        rule_charts[name] = rule.transform_filter(selections[name])

    # Per-category costs with every input taken from its selection.
    costs_by_category = fn(**{name: selected(name) for name in selections})

    # Chain of nested if-expressions that maps a row's `category` to its
    # cost; the innermost fallback is None.
    category_cost = None
    for category, cost in costs_by_category.items():
        category_cost = alt.expr.if_(
            alt.datum.category == category, cost, category_cost
        )

    category_rows = [{"category": category} for category in costs_by_category.keys()]
    pie_base = alt.Chart(alt.InlineData(category_rows)).transform_calculate(
        cost=category_cost
    )

    pie_with_theta = pie_base.encode(
        theta=alt.Theta("cost:Q", stack=True),
        tooltip=['category:N', 'cost:Q'],
    )

    category_legend = alt.Legend(orient='top', title='Monthly cost by category')
    donut = pie_with_theta.mark_arc(innerRadius=30, outerRadius=120).encode(
        color=alt.Color("category:N", legend=category_legend)
    )

    # Cost label placed outside each arc.
    slice_labels = pie_with_theta.mark_text(radius=140, size=10).encode(
        alt.Text("cost:Q", format='$.2s')
    )

    # Sum over all categories, placed at the centre of the donut.
    total_label = pie_base.mark_text(radius=0, size=20).encode(
        alt.Text("cost:Q", aggregate='sum', format='$.2s')
    )

    donut_layer = donut + slice_labels + total_label
    factor_panels = [
        alt.layer(line_charts[name], hover_charts[name], rule_charts[name])
        for name in names
    ]
    factor_grid = alt.concat(*factor_panels, columns=2).resolve_scale(y='shared')

    chart = alt.vconcat(donut_layer, factor_grid)
    chart.save('index.html')
    return chart

In [53]:
def monthly_cost_breakdown(property_cost, interest_rate, insurance, lawyer_fees, clt_value):
    """Return the monthly cost per category for the given inputs.

    ``generate_plot`` calls this with Altair expressions, so each returned
    value is an expression too. ``insurance`` and the interest-based terms
    are divided by 12 to turn a yearly amount into a monthly one.
    """
    return_on_purchase = property_cost * interest_rate / 12
    return_on_fees = lawyer_fees * interest_rate / 12
    return_on_clt_land = clt_value * interest_rate / 12
    return {
        "Home Insurance": insurance / 12,
        "Investor Return": return_on_purchase + return_on_fees - return_on_clt_land,
        # Lease fee of 75 applies only when the CLT-owned value is above zero.
        "CLT Lease Fee": alt.expr.if_(clt_value <= 0, 0, 75),
    }


generate_plot(
    monthly_cost_breakdown,
    property_cost=Variable(
        label="Purchase price",
        start=110_000, stop=300_000, step=5_000, default=170_000, tp='$',
    ),
    clt_value=Variable(
        label="Land value owned by CLT",
        start=0, stop=150_000, step=10_000, default=70_000, tp='$',
    ),
    interest_rate=Variable(
        label="Investor interest rate",
        start=0, stop=0.10, step=0.01, default=0.03, tp='%',
    ),
    insurance=Variable(
        label="Home Insurance (annual)",
        start=300, stop=800, step=100, default=500, tp='$',
    ),
    lawyer_fees=Variable(
        label="Lawyer fees and closing costs",
        start=2_000, stop=20_000, step=1_000, default=10_000, tp='$',
    ),
)