In [1]:
import altair as alt
import pandas as pd

df = pd.read_csv("lung_cancer.csv")
df.head()

Unnamed: 0,age,gender,education_years,income_level,smoker,smoking_years,cigarettes_per_day,pack_years,passive_smoking,air_pollution_index,...,bmi,oxygen_saturation,fev1_x10,crp_level,xray_abnormal,exercise_hours_per_week,diet_quality,alcohol_units_per_week,healthcare_access,lung_cancer_risk
0,60,1,20,2,1,16,15,12,0,71,...,20,94,29,6,1,5,4,13,3,1
1,53,0,12,4,0,0,0,0,1,66,...,25,96,35,4,0,5,2,0,3,0
2,62,1,15,3,1,9,29,13,1,69,...,23,95,29,9,1,1,4,2,1,1
3,73,1,12,3,0,0,0,0,0,47,...,18,96,32,0,0,0,3,10,4,0
4,52,1,13,1,0,0,0,0,0,94,...,16,97,36,8,0,6,2,9,2,0


In [2]:
# Interactive scatter plot

categories = {
    "BMI": "bmi",
    "Exercise hours/week": "exercise_hours_per_week",
    "Diet quality": "diet_quality",
    "Alcohol units/week": "alcohol_units_per_week",
    "Cigarettes/day": "cigarettes_per_day",
    "Smoking years": "smoking_years",
}
vals = list(categories.values())
risk_col = "lung_cancer_risk"

df["risk_label"] = df[risk_col].map({0: "Not at risk", 1: "At risk"})

# add patient ids
df = df.reset_index(drop=True)
df["patient_id"] = df.index

xparam = alt.param(
    name="xvar",
    value=categories["Cigarettes/day"],
    bind=alt.binding_select(options=vals, labels=list(categories.keys()), name="X-axis: ")
)

yparam = alt.param(
    name="yvar",
    value=categories["Exercise hours/week"],
    bind=alt.binding_select(options=vals, labels=list(categories.keys()), name="Y-axis: ")
)

# switch to long format to iterate through
base = alt.Chart(df).add_params(xparam, yparam).transform_fold(
    vals, as_=["variable", "value"]
)

# compare only the chosen X or Y variable
data = (
    base
    .transform_filter((alt.datum.variable == xparam) | (alt.datum.variable == yparam))
    .transform_calculate(
        axis="if(datum.variable === xvar, 'x', 'y')"
    )
    .transform_aggregate(
        val="max(value)",  
        groupby=["patient_id", "axis", "risk_label"]
    )
    .transform_pivot(
        "axis", value="val", groupby=["patient_id", "risk_label"]
    )
)

chart = (
    data
    .mark_point()
    .encode(
        alt.X("x:Q"),
        alt.Y("y:Q"),
        alt.OpacityValue(0.5),
        alt.Color("risk_label:N", title="Lung Cancer Risk"),
        alt.Tooltip(["risk_label:N", "x:Q", "y:Q"])
    )
    .properties(width=500, height=500)
)

chart

In [16]:
# Histogram showing lung cancer risk based on age group, faceted by gender

df["Gender"] = df["gender"].map({0 : "Female", 1 : "Male"})

alt.Chart(df).mark_bar().encode(
    x = alt.X("age:Q", bin=alt.BinParams(maxbins=10), title="Age (years)"),
    y = alt.Y("count():Q", title="Number of patients"),
    color = alt.Color("mean(lung_cancer_risk):Q", title="Chance to be at elevated risk")
).facet(
    facet = "Gender:N",
    columns = 2,
    title = "Histogram of lung cancer risk based on age group faceted by gender"
)