In [50]:
import altair as alt
import pandas as pd

# Name of the CSV column holding the outcome that is mapped to a text label.
risk_col = "lung_cancer_risk"

# Display label -> CSV column name for every feature offered in the axis
# dropdowns. Insertion order determines the order of the dropdown options.
categories = {
    "BMI": "bmi",
    "Exercise hours/week": "exercise_hours_per_week",
    "Diet quality": "diet_quality",
    "Alcohol units/week": "alcohol_units_per_week",
    "Cigarettes/day": "cigarettes_per_day",
    "Smoking years": "smoking_years",
}
vals = list(categories.values())

# Load the data and derive two helper columns:
#   risk_label - text label for the 0/1 outcome (any other value maps to NaN)
#   patient_id - positional row index, used as a per-row identifier
df = (
    pd.read_csv("lung_cancer.csv")
    .reset_index(drop=True)
    .assign(
        risk_label=lambda frame: frame[risk_col].map(
            {0: "Not at risk", 1: "At risk"}
        ),
        patient_id=lambda frame: frame.index,
    )
)

def _axis_dropdown(param_name, default_label, caption):
    """Return a parameter bound to a dropdown that lists every feature column.

    param_name    -- name of the parameter inside the chart specification
    default_label -- key of ``categories`` whose column is selected initially
    caption       -- text shown next to the dropdown widget
    """
    selector = alt.binding_select(
        options=vals,
        labels=list(categories),
        name=caption,
    )
    return alt.param(
        name=param_name,
        value=categories[default_label],
        bind=selector,
    )


# One dropdown per axis; the parameter value is the CSV column name to plot.
xparam = _axis_dropdown("xvar", "Cigarettes/day", "X-axis: ")
yparam = _axis_dropdown("yvar", "Exercise hours/week", "Y-axis: ")

# Chart over the wide-format frame with both axis dropdowns attached.
base = alt.Chart(df).add_params(xparam, yparam)

# Expose the currently selected columns as the computed fields "x" and "y".
#
# Each row looks up its own value via `datum[<param name>]`, so the fields
# follow the dropdowns without reshaping the data. The previous
# fold -> filter -> aggregate -> pivot chain labelled every row as 'x' when
# both dropdowns pointed at the same column, so no "y" field was produced and
# the plot went blank. A direct lookup always defines both fields and avoids
# multiplying the row count by the number of selectable columns.
data = base.transform_calculate(
    x=f"datum[{xparam.name}]",
    y=f"datum[{yparam.name}]",
)

# Scatter plot of the two selected variables: one semi-transparent point per
# row, coloured by risk label, with the label and both values in the tooltip.
chart = data.mark_point().encode(
    x="x:Q",
    y="y:Q",
    opacity=alt.value(0.5),
    color=alt.Color("risk_label:N", title="Lung Cancer Risk"),
    tooltip=["risk_label:N", "x:Q", "y:Q"],
).properties(width=500, height=500)

# Last expression of the cell: renders the chart in the notebook.
chart