In [None]:
import plotnine as p9
import polars as pl

In [None]:
# Load the unit table. Columns read in this cell: unit_name, row, column,
# is_base_unit, mult_dependency and divide_dependency.
df_units = pl.read_csv("../data/units.csv")

# Grid position (column, row) of every unit, built once so the loop below does
# not have to re-filter the whole frame for each dependency. `setdefault` keeps
# the first occurrence of a name, matching a "filter then take [0]" lookup.
unit_positions = {}
for _unit in df_units.iter_rows(named=True):
    unit_positions.setdefault(
        _unit["unit_name"], (_unit["column"], _unit["row"])
    )

# One record per arrow, running from a dependency ("start") to the derived
# unit that is built from it ("end").
arrows_dict = []
for row in df_units.iter_rows(named=True):
    if row["is_base_unit"]:
        # Base units have no dependencies, hence no incoming arrows.
        continue

    for _type, _var in zip(
        ["MULTIPLICATION", "DIVISION"],
        ["mult_dependency", "divide_dependency"],
    ):
        # A blank CSV field may be read as null rather than ""; treat both as
        # "no dependencies" instead of failing on None.split().
        for _m in (row[_var] or "").split(","):
            if _m == "":
                continue
            if _m not in unit_positions:
                raise ValueError(
                    f"Unit {row['unit_name']!r} lists unknown dependency "
                    f"{_m!r} in column {_var!r}"
                )
            _col, _row = unit_positions[_m]

            arrows_dict.append(
                {
                    # Keyed by the dependency (start) unit, not the derived
                    # unit: this is the column later joined against df_units.
                    "unit_name": _m,
                    "type": _type,
                    "color": _m,
                    "start": _m,
                    "start_column": _col,
                    "start_row": _row,
                    "end": row["unit_name"],
                    "end_column": row["column"],
                    "end_row": row["row"],
                    "start_end": _m + " " + row["unit_name"],
                }
            )


df_arrows = pl.DataFrame(arrows_dict)


# Long format: two rows per arrow (its start point and its end point) sharing
# the same "start_end" key, so a line geom grouped on that key draws one
# segment per arrow.
df_arrows2 = pl.concat(
    [
        df_arrows.with_columns(x=pl.col("start_column"), y=pl.col("start_row")),
        df_arrows.with_columns(x=pl.col("end_column"), y=pl.col("end_row")),
    ]
).sort("start_end")

In [None]:
# Attach arrow records to the unit table. Arrow rows are keyed on the unit the
# arrow starts from; with a left join, units that start no arrow are kept and
# get nulls in the arrow columns.
df = df_units.join(
    df_arrows2,
    on="unit_name",
    how="left",
    coalesce=True,
)
df

In [None]:
# `df` comes from a left join against two-rows-per-arrow data, so a unit that
# starts k arrows appears 2k times. Keep a single row per tile; otherwise every
# tile and label of such a unit is drawn 2k times on the same spot, which wastes
# work and makes the overplotted text render heavier than intended.
_tile_key = ["unit_name", "column", "row"]
df_derived = df.filter(~pl.col("is_base_unit")).unique(
    subset=_tile_key, keep="first", maintain_order=True
)
df_base = df.filter(pl.col("is_base_unit")).unique(
    subset=_tile_key, keep="first", maintain_order=True
)

# Layers are added bottom-up: arrows first so the tiles cover the line ends,
# then base-unit tiles with their labels, then derived-unit tiles with theirs.
(
    p9.ggplot(mapping=p9.aes(x="column", y="row"))
    # ARROWS: dotted for division dependencies, solid for multiplication.
    + p9.geom_line(
        data=df_arrows2.filter(pl.col("type") == "DIVISION"),
        linetype="dotted",
        mapping=p9.aes(x="x", y="y", group="start_end"),
    )
    + p9.geom_line(
        data=df_arrows2.filter(pl.col("type") == "MULTIPLICATION"),
        linetype="solid",
        mapping=p9.aes(x="x", y="y", group="start_end"),
    )
    # BASE: larger tiles filled by the "color" column, symbol above the name.
    + p9.geom_tile(
        data=df_base,
        mapping=p9.aes(width=0.9, height=0.9, fill="color"),
    )
    + p9.geom_text(
        data=df_base,
        mapping=p9.aes(label="unit_symbol"),
        color="white",
        nudge_y=+0.2,
        size=30,
    )
    + p9.geom_text(
        data=df_base,
        mapping=p9.aes(label="unit_name"),
        color="white",
        size=20,
        nudge_y=-0.2,
    )
    # DERIVED: smaller light-grey tiles with symbol, name and base-unit
    # composition stacked from top to bottom.
    + p9.geom_tile(
        data=df_derived,
        mapping=p9.aes(width=0.7, height=0.7),
        fill="#f0f0f0",
    )
    + p9.geom_text(
        data=df_derived,
        mapping=p9.aes(label="unit_symbol"),
        size=15,
        nudge_y=+0.2,
    )
    + p9.geom_text(
        data=df_derived,
        mapping=p9.aes(label="unit_name"),
        nudge_y=0,
    )
    + p9.geom_text(
        data=df_derived,
        mapping=p9.aes(label="base_units_composite"),
        nudge_y=-0.2,
    )
    # The y axis is reversed so that row 0 is drawn at the top of the figure.
    + p9.scale_y_reverse()
    + p9.theme_void()
    + p9.coord_equal(expand=False)
    + p9.theme(
        legend_position="none",
        figure_size=(10, 6),
        plot_margin_left=0.01,
        plot_margin_right=0.01,
        plot_background=p9.element_rect(fill="white"),
    )
)