In [65]:
# load ./out/training_data.json into df
import pandas as pd
import altair as alt
alt.data_transformers.enable("vegafusion")

# Load the training records; the path is relative to the notebook's working directory.
df = pd.read_json('./out/training_data.json')

# Encode complexity as an integer code so ordinal encodings order the levels
# simple -> extra complex. Series.map leaves NaN for any value that is not a
# key of this mapping.
df['chart_complexity_ordinal'] = df['chart_complexity'].map({
    'simple': 0,
    'medium': 1,
    'complex': 2,
    'extra complex': 3,
    })

# Vega expression translating an ordinal code back into its display name.
# The `|| datum.label` fallback keeps the raw label visible for any value that
# is not listed, instead of rendering an empty label.
complexity_labelExpr = """
        {
            0: 'Simple',
            1: 'Medium',
            2: 'Complex',
            3: 'Extra Complex'
        }[datum.label] || datum.label
        """
# The same expression is used for legends and for axes.
chart_complexity_legend = alt.Legend(
        labelExpr=complexity_labelExpr
        )
chart_complexity_format = alt.Axis(
    labelExpr=complexity_labelExpr
)

# Vega expression translating a chart_type identifier into a display name.
# Unlisted chart types fall back to their raw identifier rather than to a
# blank label.
legend_labelExpr = """
    {
        "scatterplot": "Scatterplot",
        "barchart": "Barchart",
        "stacked_bar": "Stacked Bar",
        "grouped_bar": "Grouped Bar",
        "normalized_bar": "Normalized Bar",
        "circular": "Circular",
        "table": "Table",
        "line": "Line",
        "area": "Area",
        "grouped_line": "Grouped Line",
        "grouped_area": "Grouped Area",
        "grouped_scatter": "Grouped Scatter",
        "heatmap": "Heatmap",
        "histogram": "Histogram",
        "dot": "Dot",
        "grouped_dot": "Grouped Dot"
    }[datum.label] || datum.label
    """

chart_type_legend = alt.Legend(
    labelExpr=legend_labelExpr
)

chart_type_format = alt.Axis(
    labelExpr=legend_labelExpr
)

In [66]:
# Histogram of df['spec_key_count'], colored by chart complexity.
#
# The x axis and the tooltip share one Bin definition. In an aggregated chart
# every encoded field without an aggregate or bin acts as a group-by key, so
# an unbinned spec_key_count in the tooltip would split each bar into one mark
# per distinct value and report per-value counts instead of per-bar counts.
spec_key_bin = alt.Bin(maxbins=100)

histogram = alt.Chart(df).mark_bar().encode(
    alt.X("spec_key_count:Q", bin=spec_key_bin),
    alt.Y("count():Q"),
    alt.Color(
        "chart_complexity_ordinal:N",
        title="Chart Complexity",  # otherwise the legend shows the raw column name
        legend=chart_complexity_legend,
    ),
    tooltip=[
        alt.Tooltip("count():Q"),
        alt.Tooltip("spec_key_count:Q", bin=spec_key_bin),
    ],
).properties(
    title="Distribution of spec_key_count by chart complexity",
    width=800,
    height=400
).configure_axis(
    labelFontSize=12,
    titleFontSize=14
).configure_title(
    fontSize=16
)

histogram.display()

In [67]:
# Count records per (complexity, chart type) pair. groupby drops rows whose
# key is NaN by default, so records with an unmapped complexity are excluded.
grouped = df.groupby(['chart_complexity_ordinal', 'chart_type']).size().reset_index(name='count')

# Cells with a count above this value get white text so the number stays
# readable on the darker end of the 'blues' scheme.
TEXT_CONTRAST_THRESHOLD = 6000

# Position, axis formatting and tooltip are declared once and inherited by
# both layers, so the axis labels do not depend on how the layers' axes merge
# and the tooltip also appears when hovering a count label.
heatmap_base = alt.Chart(grouped).encode(
    y=alt.Y('chart_type:N', title='Chart Type', axis=chart_type_format),
    x=alt.X('chart_complexity_ordinal:O', title='Chart Complexity', axis=chart_complexity_format),
    tooltip=[
        alt.Tooltip('chart_type:N', title='Chart Type'),
        alt.Tooltip('chart_complexity_ordinal:O', title='Chart Complexity'),
        alt.Tooltip('count:Q', title='Count'),
    ],
)

# Colored cells: darker blue means more records.
heatmap_rects = heatmap_base.mark_rect().encode(
    color=alt.Color('count:Q', title='Count', scale=alt.Scale(scheme='blues')),
)

# The count printed on top of each cell.
textlayer = heatmap_base.mark_text().encode(
    text=alt.Text('count:Q'),
    color=alt.condition(
        alt.datum.count > TEXT_CONTRAST_THRESHOLD,
        alt.value('white'),  # dark cell: use white text
        alt.value('black')   # light cell: use black text
    ),
)

heatmap = (heatmap_rects + textlayer).properties(
    width=300,
    height=400
).configure_axisX(
    labelAngle=-45,
)

heatmap.display()