
# Interactive Scatterplots (Age, Cortical Thickness & Surface Area) 

This notebook loads **ESMRMB25_sample_data_ct_sa.csv** and builds an interactive Bokeh scatter plot with:
- X/Y axis switching (age, cortical thickness, surface area)
- Horizontal jitter on the x-axis to reduce overplotting
- An age-range filter and an optional lobe filter
- Hover tooltips

In [None]:
# Download the requirements.txt file (if running on your own computer and not using Google Colab).
# The -nc (no-clobber) flag makes wget skip the download when the file already exists locally.
! wget -nc https://raw.githubusercontent.com/saigerutherford/esmrmb_data_viz/main/requirements.txt

In [None]:
# Download the requirements_colab.txt file (if using Google Colab and not running on your own computer).
# The -nc (no-clobber) flag makes wget skip the download when the file already exists locally.
! wget -nc https://raw.githubusercontent.com/saigerutherford/esmrmb_data_viz/main/requirements_colab.txt

In [None]:
# Download the sample data set (CSV) into the current working directory.
# The -nc (no-clobber) flag makes wget skip the download when the file already exists locally.
! wget -nc https://raw.githubusercontent.com/saigerutherford/esmrmb_data_viz/main/ESMRMB25_sample_data_ct_sa.csv

In [None]:
# Install the required Python packages; uncomment the line that matches how you are running the tutorial
#! pip install -r requirements.txt # uncomment this line if running on a personal computer (not via Google Colab)
#! pip install -r requirements_colab.txt # uncomment this line if running on Google Colab

In [None]:
from bokeh.io import output_notebook
# Configure Bokeh so that later show() calls render inline in the notebook output
# instead of writing/opening a standalone HTML file.
output_notebook()

In [None]:
import pandas as pd, numpy as np

# Adjust the path to where you downloaded the sample data
csv_path = "ESMRMB25_sample_data_ct_sa.csv"
df = pd.read_csv(csv_path)

# The numeric columns every plot below relies on. Check for them up front so a
# wrong or truncated CSV fails with a readable message rather than an opaque
# KeyError raised later from dropna(subset=...).
required_cols = ["age", "thickness", "surf_area"]
missing_cols = [c for c in required_cols if c not in df.columns]
if missing_cols:
    raise KeyError(f"{csv_path} is missing required column(s): {missing_cols}")

# Coerce numeric columns we will plot; unparseable entries become NaN
for c in required_cols:
    df[c] = pd.to_numeric(df[c], errors="coerce")

# Drop rows missing critical numerics
df = df.dropna(subset=required_cols).reset_index(drop=True)

# Size column for scatter based on surface area: linearly map the observed
# surface-area range onto marker sizes between size_min and size_max.
size_min, size_max = 6, 22
tmin, tmax = df["surf_area"].min(), df["surf_area"].max()
if tmin == tmax:
    # Degenerate range (all values equal): use the mid-point size for every marker
    df["size_var"] = (size_min + size_max) / 2.0
else:
    df["size_var"] = np.interp(df["surf_area"], [tmin, tmax], [size_min, size_max])

df.head()

In [None]:
# Per-subject summary statistics of cortical thickness
df.groupby("subject_id")["thickness"].describe()

In [None]:
# Per-subject summary statistics of surface area
df.groupby("subject_id")["surf_area"].describe()


## Static scatter + hover (quick)


In [None]:
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.io import show
from bokeh.transform import jitter

# Data source backing the quick-look figure
source_quick = ColumnDataSource(df)

p = figure(
    height=600,
    width=1000,
    tools="pan,wheel_zoom,box_zoom,reset,save",
    x_axis_label="Age",
    y_axis_label="Cortical Thickness",
)

# Age is jittered horizontally so markers sharing the same age do not stack
# exactly on top of each other.
p.scatter(
    x=jitter("age", width=0.9, range=p.x_range),
    y="thickness",
    size=8,
    alpha=0.7,
    source=source_quick,
)

# Tooltip rows: (label shown, "@column" looked up in the data source)
quick_tooltips = [("ROI", "@ROI_name"), ("Age", "@age"), ("CT", "@thickness")]
p.add_tools(HoverTool(tooltips=quick_tooltips))

show(p)

In [None]:
from bokeh.models import (
    ColumnDataSource, Select, RangeSlider, CheckboxGroup, CustomJS, Div, HoverTool,
    Jitter
)
from bokeh.layouts import column, row
from bokeh.plotting import figure, show
from bokeh.transform import factor_cmap
from bokeh.palettes import Category10

# -----------------------------
# 0) Make sure the marker-size column exists (scatter uses size="size_var")
# -----------------------------
if "size_var" not in df.columns:
    size_min, size_max = 6, 22
    tmin, tmax = df["surf_area"].min(), df["surf_area"].max()
    if tmin == tmax:
        # All surface areas identical: every marker gets the mid-point size
        df["size_var"] = (size_min + size_max) / 2.0
    else:
        # Linear map from [tmin, tmax] onto [size_min, size_max]
        df["size_var"] = np.interp(
            df["surf_area"], [tmin, tmax], [size_min, size_max]
        )

# -----------------------------
# 1) Numeric columns selectable as axes, as (column name, display label) pairs
# -----------------------------
numeric_fields = ["age", "thickness", "surf_area"]
axis_options = [
    (name, "Surface area") if name == "surf_area" else (name, name.capitalize())
    for name in numeric_fields
    if name in df.columns
]

# -----------------------------
# 2) Categorical column used for colour (and legend), when present
# -----------------------------
cat_col = "lobe" if "lobe" in df.columns else None

# Defaults for the "no categorical column" case
cats, palette_vals = [], []
if cat_col:
    cats = sorted(df[cat_col].dropna().unique().tolist())
    palette_vals = Category10[10]
    if len(cats) > 10:
        # More categories than colours: repeat the palette, then trim to length
        palette_vals = (Category10[10] * ((len(cats) // 10) + 1))[: len(cats)]

# -----------------------------
# 3) Data sources: "original" keeps every row (read by the filter callback),
#    "source" is the working copy the glyph actually draws from
# -----------------------------
original = ColumnDataSource(df.copy())
source   = ColumnDataSource(df.copy())

# -----------------------------
# 4) Widgets (no jitter slider)
# -----------------------------
x_select = Select(title="X Axis", value="age", options=list(axis_options))
y_select = Select(title="Y Axis", value="thickness", options=list(axis_options))

# Slider bounds are the observed age range, widened outwards to whole numbers
age_min = int(np.floor(df["age"].min()))
age_max = int(np.ceil(df["age"].max()))
age_slider = RangeSlider(
    title="Age range",
    start=age_min,
    end=age_max,
    value=(age_min, age_max),
    step=1,
)

# One checkbox per category, all ticked initially; no widget when there are no categories
if cats:
    lobe_group = CheckboxGroup(labels=cats, active=list(range(len(cats))))
else:
    lobe_group = None

# -----------------------------
# 5) Figure + glyph:
# -----------------------------
plot = figure(height=600, width=1000, tools="pan,wheel_zoom,box_zoom,reset,save")

# Fixed jitter width = 0.6
jit = Jitter(width=0.6, range=plot.x_range)

# Keyword arguments common to the coloured and the uncoloured scatter
scatter_kwargs = dict(
    x={"field": x_select.value, "transform": jit},
    y={"field": y_select.value},
    size="size_var",
    line_color=None,
    alpha=0.85,
    source=source,
)
if cat_col:
    # Colour by category and let Bokeh build one legend entry per category
    scatter_kwargs["fill_color"] = factor_cmap(cat_col, palette=palette_vals, factors=cats)
    scatter_kwargs["legend_field"] = cat_col

r = plot.scatter(**scatter_kwargs)

if cat_col:
    # Legend tweaks: compact styling, then move the legend outside the plot area
    legend_style = {
        "label_text_font_size": "8pt",
        "glyph_width": 10,
        "spacing": 2,
        "padding": 4,
    }
    for legend_attr, legend_value in legend_style.items():
        setattr(plot.legend, legend_attr, legend_value)
    plot.add_layout(plot.legend[0], "right")

# Axis labels reflect current selections
label_for = dict(axis_options)
plot.xaxis.axis_label = label_for[x_select.value]
plot.yaxis.axis_label = label_for[y_select.value]

# Hover info: (label shown, "@column" looked up in the data source)
hover_rows = [
    ("ROI", "@ROI_name"),
    ("age", "@age"),
    ("thickness", "@thickness"),
    ("surf_area", "@surf_area"),
    ("sex", "@sex"),
]
plot.add_tools(HoverTool(tooltips=hover_rows))

# -----------------------------
# 6) CustomJS callback:
#    Runs in the browser whenever a widget changes. It rebuilds `source` from
#    `original` (age range, finite x/y, active lobes), re-points the glyph at
#    the selected columns, and updates the axis labels.
# -----------------------------
# Objects exposed to the JavaScript snippet under these names. The axis models
# are passed in explicitly so the callback does not depend on the BokehJS plot
# object exposing an `xaxis` / `yaxis` accessor.
args = dict(
    source=source, original=original,
    x_select=x_select, y_select=y_select,
    age_slider=age_slider, plot=plot, r=r, jit=jit,
    xaxis=plot.xaxis[0], yaxis=plot.yaxis[0],
)
if lobe_group:
    args["lobe_group"] = lobe_group

callback = CustomJS(args=args, code="""
    // Display label for a Select's current value. Options are [value, label]
    // pairs; fall back to the raw value if an option is a plain string.
    const label_of = (select) => {
        for (const opt of select.options) {
            if (Array.isArray(opt) && opt[0] === select.value) return opt[1];
        }
        return select.value;
    };

    const data = {...original.data};
    const x = x_select.value;
    const y = y_select.value;
    const a0 = age_slider.value[0];
    const a1 = age_slider.value[1];

    // Start with an empty array for every column of the original source
    const out = {};
    for (const k of Object.keys(data)) out[k] = [];

    // Set of ticked lobe labels, or null when there is no lobe widget
    let active_lobes = null;
    if (typeof lobe_group !== 'undefined' && lobe_group) {
        active_lobes = new Set(lobe_group.active.map(i => lobe_group.labels[i]));
    }

    const N = data['age'].length;
    for (let i = 0; i < N; i++) {
        const age = data['age'][i];
        const xv = data[x][i];
        const yv = data[y][i];
        const age_ok = Number.isFinite(age) && age >= a0 && age <= a1;
        const finite_ok = Number.isFinite(xv) && Number.isFinite(yv);
        const lobe_ok = active_lobes ? active_lobes.has(data['lobe'][i]) : true;

        if (age_ok && finite_ok && lobe_ok) {
            for (const k of Object.keys(data)) out[k].push(data[k][i]);
        }
    }

    source.data = out;
    source.change.emit();

    // rebind glyph fields (keeps the same transform object)
    r.glyph.x = { field: x, transform: jit };
    r.glyph.y = { field: y };
    r.glyph.change.emit();

    // Use the same display labels the figure was created with
    // (e.g. "Surface area"), not the raw column names.
    xaxis.axis_label = label_of(x_select);
    yaxis.axis_label = label_of(y_select);

    plot.change.emit();
    // NOTE(review): request_render may not exist on the plot model in every
    // BokehJS version; guard it so the callback never throws here.
    if (typeof plot.request_render === 'function') plot.request_render();
""")

# Wire the widgets to the callback
for w in [x_select, y_select, age_slider]:
    w.js_on_change("value", callback)
if lobe_group:
    lobe_group.js_on_change("active", callback)

# -----------------------------
# 7) Layout and show: plot on the left, widget column on the right
# -----------------------------
controls = [x_select, y_select, age_slider]
if lobe_group:
    # Heading plus the checkbox list, only when a lobe filter exists
    controls.extend([Div(text="<b>Lobes</b>"), lobe_group])

sidebar = column(*controls)
show(row(plot, sidebar))