In [2]:
%run constants.ipynb
%run clean_functions.ipynb

In [14]:
def read_clean_data() -> tuple[pd.DataFrame, dict[str, str]]:
    """Load the run arguments and the cleaned dataframe they point to.

    Returns:
        A ``(df, arguments)`` pair. ``arguments`` is the JSON object stored in
        ``ARGUMENTS_FILENAME``; ``df`` is the CSV named by
        ``arguments["dataframe_outfile"]``, indexed by the parsed
        "date clean" column and sorted by that index.
    """
    with open(ARGUMENTS_FILENAME, "r") as arguments_file:
        arguments = json.load(arguments_file)

    cleaned = pd.read_csv(arguments["dataframe_outfile"])
    # Parse the dates first so the index sorts chronologically, not lexically.
    cleaned["date clean"] = pd.to_datetime(cleaned["date clean"])
    cleaned = cleaned.set_index("date clean").sort_index()

    return cleaned, arguments

#### Cervical mucus charting

In [None]:
def lookup_color(col: pd.Series) -> list[str]:
    """Return one CSS style string per cell of ``col``.

    Cells whose text is a key of ``cervical_mucus_colors`` get that color as
    background with bold white text; every other cell (dates, descriptions,
    blanks) gets a plain white background with black text.

    Intended for use with ``Styler.apply``, which calls it once per column.
    """
    # The original fallback read "background-color: white);" -- the stray ")"
    # made the declaration invalid CSS.
    plain_style = "background-color: white; color: black;"
    return [
        f"background-color: {cervical_mucus_colors[text]}; color: white; font-weight: bold;"
        if text in cervical_mucus_colors
        else plain_style
        for text in col
    ]


def get_chart_with_text(valid_subset: pd.DataFrame, n_chars: int = 30) -> pd.DataFrame:
    """Build the cycle chart with a date, symbol and description row per cycle.

    Args:
        valid_subset: Frame indexed by "date clean" (datetime) that has the
            columns "nth_cycle", "day of cycle", "symbol" and
            "cervical mucus description".
        n_chars: Maximum number of description characters kept per cell.

    Returns:
        A frame indexed by ("nth_cycle", "variable") with one column per day of
        cycle. Within each cycle the rows are ordered date, symbol,
        description; cells with no data hold an empty string.
    """
    row_order = ["date", "symbol", "description"]

    table_df = valid_subset[
        ["nth_cycle", "day of cycle", "symbol", "cervical mucus description"]
    ].reset_index()
    table_df["date"] = table_df["date clean"].dt.strftime("%m/%d")
    # Truncate so long free-text notes do not blow up the column width.
    table_df["description"] = table_df["cervical mucus description"].str[:n_chars]

    table = (
        table_df.melt(id_vars=["nth_cycle", "day of cycle"], value_vars=row_order)
        .pivot(index=["nth_cycle", "variable"], columns="day of cycle", values="value")
        .reset_index()
        .fillna("")
    )

    # An ordered categorical makes the sort follow row_order, not the alphabet.
    # This single sort fully determines the row order, so no preliminary sorting
    # is needed.
    table["variable"] = pd.Categorical(table["variable"], row_order)
    return table.sort_values(["nth_cycle", "variable"]).set_index(["nth_cycle", "variable"])


def chart_cervical_mucus(
    df: pd.DataFrame,
    with_text: bool,  # whether to include dates and descriptions
    outfile: Union[str, None] = None,
) -> "pd.io.formats.style.Styler":
    """Render the cervical-mucus chart as a styled table, one row per cycle.

    Args:
        df: Frame with "day of cycle", "nth_cycle", "valid_cycle" and "symbol"
            columns (plus "cervical mucus description" and a "date clean"
            index when ``with_text`` is true).
        with_text: If true, each cycle gets date / symbol / description rows;
            otherwise a single row of symbols.
        outfile: Optional file stem. When given, the styled table is exported to
            ``<output_directory>/<outfile>.png``, where the directory is read
            from the notebook-level ``arguments`` mapping.

    Returns:
        The styled table.
    """
    # The return annotation is a string on purpose: pandas imports
    # ``pandas.io.formats.style`` lazily, so evaluating the attribute chain at
    # definition time can fail on a fresh kernel.

    # Create dataset where we have cycle data, so we can always pivot on day
    subset = df[df["day of cycle"].notnull()].copy()
    subset["day of cycle"] = subset["day of cycle"].astype(int)
    subset["nth_cycle"] = subset["nth_cycle"].astype(int)
    subset = subset[subset["valid_cycle"]]

    if with_text:
        table = get_chart_with_text(subset)
    else:
        table = subset.pivot(
            index="nth_cycle",
            columns="day of cycle",
            values="symbol",
        ).fillna("")

    # lookup_color is applied column by column and colors the symbol cells.
    styled_table = table.style.apply(lookup_color)
    styled_table = styled_table.set_properties(**{"max-width": "80px"})

    # Save the charting image as a PNG file
    if outfile is not None:
        dfi.export(
            styled_table,
            f"{arguments['output_directory']}/{outfile}.png",
            dpi=300,
            max_cols=-1,
        )

    return styled_table

#### Plot a categorical variable over time

In [None]:
def plot_categorical_rate(
    metric: str,
    days_mean: int,
    possible_values: list[str],
    nan_is_no: bool = False,
    data: Union[pd.DataFrame, None] = None,
):
    """Plot how often each value of a categorical metric occurs over time.

    For every entry of ``possible_values`` a 0/1 indicator column named
    ``"<metric>_<value>"`` is written to the frame, and its rolling mean over
    ``days_mean`` days is plotted as a percentage. Values that never occur are
    left out of the plot.

    Args:
        metric: Name of the categorical column.
        days_mean: Width of the rolling window, in days (needs a datetime index).
        possible_values: Category values to chart.
        nan_is_no: If true, missing entries count toward the "no" value.
        data: Frame to use. Defaults to the notebook-level ``df``, which is what
            the function always used before this parameter existed. The
            indicator columns are added to whichever frame is used.
    """
    frame = df if data is None else data

    rate_metrics = []
    for possible_value in possible_values:
        rate_metric = f"{metric}_{possible_value}"
        frame[rate_metric] = (frame[metric] == possible_value).astype(int)
        # assume no value means "no", if desired
        if nan_is_no and possible_value == "no":
            frame.loc[frame[metric].isnull(), rate_metric] = 1
        rate_metrics.append(rate_metric)

    fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")
    for possible_value, rate_metric in zip(possible_values, rate_metrics):
        # A value that never occurs would only add a flat line at 0%.
        if frame[rate_metric].mean() == 0:
            continue

        ax.plot(
            frame[rate_metric].rolling(f"{days_mean}D").mean() * 100,
            label=possible_value,
        )

    # Date formatting does not depend on the series, so do it once.
    fig.autofmt_xdate()

    ax.set_title(f"{metric.title()} frequency", size=14)
    ax.set_xlabel("Date", size=12)
    ax.set_ylabel(f"% of days with this type of {metric} ({days_mean}-day rolling mean)", size=12)
    ax.legend(title=metric.title())

#### Plot numeric metrics over time

In [25]:
def plot_numeric(df: pd.DataFrame, metrics: list[str], days_window: int) -> None:
    """Plot the rolling mean of one or several numeric metrics against date.

    A single metric is titled with its own cleaned-up name; any other number of
    metrics shares one axis, with a legend and a generic title.
    """
    fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")
    window = f"{days_window}D"

    if len(metrics) != 1:
        # compare several metrics
        for name in metrics:
            smoothed = df[name].rolling(window).mean()
            plt.plot(smoothed, label=name.replace("_", " ").replace(" numeric", ""))

        plt.title(f"{days_window}-day rolling mean", size=14)
        plt.xlabel("Date", size=12)
        plt.ylabel(f"{days_window}-day rolling mean", size=12)
        plt.legend(title="Metric")
    else:
        only_metric = metrics[0]
        plt.plot(df[only_metric].rolling(window).mean())
        plt.title(only_metric.replace("_", " ").replace(" numeric", ""))
        plt.ylabel(f"{days_window}-day rolling mean")

    fig.autofmt_xdate()

#### Superimpose menstruation and ovulation

In [None]:
def plot_superimpose_cycle(
    df: pd.DataFrame,
    metric: str,
    days_window: int,
    only_valid_cycles: bool,
) -> None:
    """Plot a metric's rolling mean with peak days and menstruation shaded.

    Peak days are drawn as blue bands one day wide. Menstruation is drawn as a
    salmon band from day 1 of each cycle to the row flagged
    "last_bleeding_day". The figure is saved under the output directory taken
    from the notebook-level ``arguments`` mapping.
    """
    source = df.loc[df["valid_cycle"]] if only_valid_cycles else df
    plot_data = source.copy()

    fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")

    plt.plot(
        plot_data[metric].rolling(f"{days_window}D").mean(),
        label=metric.replace("_", " ").replace(" numeric", ""),
    )

    ## Peak day
    peak_days = plot_data.loc[plot_data["peak_day"] == True].index.values
    for position, peak_day in enumerate(peak_days):
        # approximate ovulation date, made a bit wider for visibility
        plt.axvspan(
            xmin=peak_day + np.timedelta64(-1,'D'),
            xmax=peak_day,
            color="royalblue",
            alpha=0.4,
            # only the first band gets a legend entry
            label="Peak Day" if position == 0 else None,
        )

    ## Menstruation
    # Holds the legend text until the first band is drawn, then None.
    menstruation_label = "Menstruation"
    for cycle in plot_data["nth_cycle"].unique():
        if np.isnan(cycle):
            continue

        in_cycle = plot_data["nth_cycle"] == cycle
        first_day_df = plot_data.loc[in_cycle & (plot_data["day of cycle"] == 1)]
        last_day_df = plot_data.loc[in_cycle & plot_data["last_bleeding_day"]]
        # Without both endpoints there is no span to draw for this cycle.
        if first_day_df.empty or last_day_df.empty:
            continue

        plt.axvspan(
            xmin=first_day_df.index.values[0],
            xmax=last_day_df.index.values[0],
            color="salmon",
            alpha=0.4,
            label=menstruation_label,
        )
        menstruation_label = None

    fig.autofmt_xdate()

    heading = f"{metric} {days_window}-day rolling mean"
    plt.title(heading, size=14)
    plt.ylabel(heading, size=12)
    plt.legend(title="Metric/event")

    fig.savefig(
        f"{arguments['output_directory']}/{metric}_valid_only_{only_valid_cycles}.png",
    )

### Plot by cycle day, rather than by date

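Note that metrics can optionally be standardized (centered on their mean and divided by two standard deviations) so that several can be compared on one axis.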

In [30]:
def get_metric_label(metric: str) -> str:
    """Turn a column name into a display label.

    Underscores become spaces and every " numeric" fragment is dropped.
    """
    spaced = " ".join(metric.split("_"))
    return spaced.replace(" numeric", "")

In [31]:
def plot_by_cycle_day(
    df: pd.DataFrame,
    metrics: list[str],
    days_window: int,
    standardize: bool,
) -> None:
    """Plot metrics against day of cycle, pooling all valid cycles.

    Args:
        df: Frame with "valid_cycle" and "day of cycle" columns and the metrics.
        metrics: Columns to plot, one line each.
        days_window: Size of the rolling window applied before plotting.
        standardize: If true, each metric is centered on its mean and divided
            by two standard deviations so several metrics share one scale.
    """
    plot_data = df.loc[df["valid_cycle"]].copy()
    plot_data = plot_data.set_index("day of cycle").sort_index()

    # The label depends only on the flag. Setting it before the loop keeps it
    # defined even when `metrics` is empty.
    if standardize:
        ylabel = f"Change in 2 SD, {days_window}-day rolling mean"
    else:
        ylabel = f"{days_window}-day rolling mean"

    fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")

    for metric in metrics:
        # create a standardized metric, so we can compare many at once
        if standardize:
            plot_data[metric] = (
                (plot_data[metric] - plot_data[metric].mean())
                / (2 * plot_data[metric].std())
            )

        # NOTE(review): an integer rolling window counts rows, and rows from all
        # cycles are pooled here, so the window may span fewer than
        # `days_window` distinct cycle days -- confirm this is intended.
        # TODO: fit ARIMA model
        sns.lineplot(
            x=plot_data.index,
            y=plot_data[metric].rolling(window=days_window).mean(),
            # None (not False) is the value that turns the confidence band off.
            ci=None,
            label=get_metric_label(metric),
        )

    plt.xlabel("Day of Cycle", size=12)
    plt.ylabel(ylabel, size=12)
    plt.title(f"By day of cycle, {days_window}-day mean", size=14)
    plt.legend(title="Metric")

In [27]:
def histogram(df: pd.DataFrame, metric: str) -> None:
    """Draw a 20-bin histogram of one metric on a new figure."""
    fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")
    df[metric].hist(bins=20)

    # The same cleaned-up name serves as both the x label and the title.
    label = get_metric_label(metric)
    plt.title(label, size=14)
    plt.xlabel(label, size=12)
    plt.ylabel("Frequency", size=12)

In [76]:
def plot_by_group(
    plot_type: str,
    df: pd.DataFrame,
    x_var: str,
    y_var: str,
    start_dt: Union[datetime.datetime, None] = None,
) -> None:
    """Plot ``y_var`` against ``x_var`` as a violin, line or scatter plot.

    Args:
        plot_type: One of "violin", "line" or "scatter".
        df: Frame holding both columns; needs a datetime index if ``start_dt``
            is given.
        x_var: Column for the x axis.
        y_var: Column for the y axis.
        start_dt: If given, only rows on or after this date are plotted.

    Raises:
        ValueError: If ``plot_type`` is not one of the supported kinds.
    """
    supported_types = ("violin", "line", "scatter")
    # Fail before any figure is created; an unknown type used to produce an
    # empty, fully labelled plot without any error.
    if plot_type not in supported_types:
        raise ValueError(
            f"Unknown plot_type {plot_type!r}; expected one of {supported_types}."
        )

    if start_dt is not None:
        plot_data = df[df.index >= start_dt].copy()
    else:
        plot_data = df.copy()

    fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")
    if plot_type == "violin":
        sns.violinplot(
            x=plot_data[x_var],
            y=plot_data[y_var],
            color="gray",
            ax=ax,
        )
    elif plot_type == "line":
        sns.lineplot(
            x=plot_data[x_var],
            y=plot_data[y_var],
            ax=ax,
        )
    else:  # "scatter"
        sns.scatterplot(
            x=plot_data[x_var],
            y=plot_data[y_var],
            ax=ax,
        )

    x_label = get_metric_label(x_var)
    y_label = get_metric_label(y_var)
    ax.set_xlabel(x_label, size=12)
    ax.set_ylabel(y_label, size=12)
    ax.set_title(f"{y_label} by {x_label}", size=14)

### Snapshot of condition (symptoms, behavior)

For example, this can give a doctor a brief overview.

In [None]:
def clean_categorical(s: str) -> float:
    """Reduce a free-text answer to a binary "was this an issue?" flag.

    Matching is case-insensitive and anchored at the start of the string.

    Returns:
        1 if the text starts with "yes", "a bit", "a little" or "extreme";
        0 if it starts with "no";
        NaN if it is exactly "unknown" (so the entry can be ignored);
        0 for anything else, i.e. missing or unrecognized text counts as "no".
    """
    # we are only trying to determine if it's an issue, not its magnitude
    if re.match(r"yes|a bit|a little|extreme", s, re.I):
        return 1
    if re.match(r"no", s, re.I):
        return 0
    # if we want to ignore missing values; case-insensitive like the checks above
    if s.lower() == "unknown":
        return np.nan
    return 0


def generate_summary(
    df: pd.DataFrame,
    start_dt: datetime.datetime,
    end_dt: datetime.datetime,
    categorical_metrics: Union[list[str], None] = None,
    numeric_metrics: Union[list[str], None] = None,
    min_pain: int = 4,  # minimum pain (0-10) to consider it worth reporting
) -> str:
    """Summarize symptoms and behavior over a date range as plain text.

    Args:
        df: Frame with a datetime index and the requested metric columns.
        start_dt: First day of the period (inclusive).
        end_dt: Last day of the period (inclusive).
        categorical_metrics: Metrics reported as "% of days". Columns with
            "pain" in the name are thresholded at ``min_pain``; bool columns
            are used as-is; object columns go through ``clean_categorical``.
        numeric_metrics: Metrics reported as a rounded mean.
        min_pain: Pain level at or above which a day counts as a pain day.

    Returns:
        The date range, then the categorical lines in descending order of
        percentage, then one line per numeric metric. Metrics with no data in
        the period are left out.
    """
    # None instead of [] as default, so no list object is shared between calls.
    categorical_metrics = [] if categorical_metrics is None else categorical_metrics
    numeric_metrics = [] if numeric_metrics is None else numeric_metrics

    in_period = (df.index >= start_dt) & (df.index <= end_dt)
    time_period = df[in_period].copy()
    for metric in numeric_metrics:
        # Strip thousands separators ("1,000") before converting.
        time_period[metric] = pd.to_numeric(
            time_period[metric].astype(str).str.replace(",", ""),
            errors="coerce",
        )

    start_dt_str = start_dt.strftime("%m/%d/%y")
    end_dt_str = end_dt.strftime("%m/%d/%y")
    summary = f"{start_dt_str} - {end_dt_str}\n\n"

    # go in descending order of %
    cat_summaries = []
    for metric in categorical_metrics:
        # Pain metrics are numeric, so we need to create a bool for "pain was >= {min_pain}"
        if "pain" in metric:
            time_period[f"{metric}_cat"] = (time_period[metric] >= min_pain).astype(int)
        elif df[metric].dtype == bool:
            time_period[f"{metric}_cat"] = time_period[metric]
        elif df[metric].dtype == object:
            time_period[f"{metric}_cat"] = time_period[metric].astype(str).map(clean_categorical)
        # NOTE(review): any other dtype creates no "_cat" column, so the lookup
        # below raises KeyError -- decide how such columns should be handled.

        share = time_period[f"{metric}_cat"].mean()
        # No usable data in the period: skip, as the numeric metrics do below,
        # instead of crashing on round(nan).
        if pd.isna(share):
            continue
        pct = int(round(share * 100))

        metric_str = metric.replace(" numeric", "").replace("_bool", "")

        if "pain" in metric:
            # add 1 b/c drs don't take pain seriously
            metric_str = f"{metric_str} at least {min_pain + 1} / 10"

        cat_summaries.append((pct, f"{pct}% of days had {metric_str}."))

    cat_summaries.sort(reverse=True)
    summary += "\n".join([item[1] for item in cat_summaries])
    summary += "\n\n"

    for metric in numeric_metrics:
        mean = time_period[metric].mean()
        if np.isnan(mean):
            continue
        # doctors like round numbers
        mean = int(round(mean))

        if metric == "steps":
            mean_str = "{:,}".format(mean)
        elif metric == "previous night Oura score":
            mean_str = f"{mean} / 100"
        elif "pain" in metric:
            mean_str = f"{mean} / 10"
        else:
            mean_str = str(mean)

        if metric in ["BMs", "coffee", "steps"]:
            metric_str = f"{metric} / day"
        elif metric == "previous night Oura score":
            metric_str = "sleep score"
        else:
            metric_str = metric

        res = f"{mean_str} {metric_str}"
        summary += res + "\n"

    return summary

### Snapshot of medications

In [None]:
def get_medications_quantified(df: pd.DataFrame) -> list[str]:
    """List the medication columns that carry a dosage unit.

    A column qualifies when its name ends with a space followed by one of
    ``DRUG_UNITS``. Results are grouped by unit, in ``DRUG_UNITS`` order, and
    keep the frame's column order within each unit.
    """
    # Already covered by the separate THC and CBD mg columns.
    redundant = ("THC/CBD ml", "THC ml")
    return [
        column
        for unit in DRUG_UNITS
        for column in df.columns
        if column.endswith(f" {unit}") and column not in redundant
    ]

In [None]:
def generate_med_summary(
    df: pd.DataFrame,
    end_dt: datetime.datetime,
    n_days: int,  # total days to include, including end_dt
) -> str:
    """Summarize mean medication dosages over the ``n_days`` ending at ``end_dt``.

    For each dosage column whose rounded mean over the period is non-zero, the
    summary gives the mean, the unit, the medication name and the date it was
    most recently started. A start is a positive dose whose two preceding rows
    were both zero or missing.

    Args:
        df: Frame with a datetime index and the dosage columns.
        end_dt: Last day of the period (inclusive).
        n_days: Number of days in the period, counting ``end_dt``.

    Returns:
        The date (or date range) followed by one line per medication.
    """
    medications = get_medications_quantified(df)

    start_dt = end_dt - datetime.timedelta(n_days - 1)
    in_period = (df.index >= start_dt) & (df.index <= end_dt)
    time_period = df[in_period].copy()
    # Fewer rows than requested days means some days were never logged.
    if len(time_period) < n_days:
        missing_days = n_days - len(time_period)
        print(f"WARNING: missing {missing_days} days.")

    start_dt_str = start_dt.strftime("%m/%d/%y")
    end_dt_str = end_dt.strftime("%m/%d/%y")
    if n_days == 1:
        summary = f"{start_dt_str}\n\n"
    else:
        summary = f"{start_dt_str} - {end_dt_str}\n\n"

    for med in medications:
        period_mean = time_period[med].mean()
        # No data in the period (NaN mean): nothing to report for this med.
        if pd.isna(period_mean):
            continue
        mean = int(round(period_mean))
        if mean == 0:
            continue
        med_str, unit = med.rsplit(" ", 1)

        doses = df[med]
        # Each comparison needs its own parentheses: `&` binds tighter than
        # `>`, so `a & b & doses > 0` would be read as `(a & b & doses) > 0`.
        is_start = (
            (doses.shift(1).fillna(0) == 0)
            & (doses.shift(2).fillna(0) == 0)
            & (doses > 0)
        )
        # NOTE(review): this scans the whole frame, so a restart after end_dt
        # would be reported -- confirm whether it should stop at end_dt.
        start_dates = doses[is_start].index
        if len(start_dates) > 0:
            started = f", started {start_dates[-1].strftime('%m/%d/%y')}"
        else:
            started = ""

        summary += f"{mean} {unit} {med_str}{started}\n"

    return summary

In [None]:
def get_med_plot_data(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Prepare the data behind the medications-over-time scatter plot.

    Returns:
        ``(plot_data, max_dates_df)``. ``max_dates_df`` has one row per
        medication taken within the last 180 days of the frame, with the last
        date it was taken and a ``plot_value`` (1, 2, ...) that serves as its
        row on the y axis. ``plot_data`` is a long frame of date / med / value,
        where value is that ``plot_value`` on days the med was taken and NaN
        otherwise.
    """
    dated = df.copy()
    dated["date_clean"] = dated.index

    # remove redundant meds already captured by "THC" and "CBD"
    redundant = ["THC/CBD", "CBD oil", "THC/CBD oil"]
    medications = [med for med in get_potential_drugs(dated) if med not in redundant]

    # Last date on which each medication was taken.
    max_dates = {
        f"daily_{med}": dated.loc[dated[f"daily_{med}"] == True, "date_clean"].max()
        for med in medications
    }

    max_dates_df = pd.DataFrame.from_dict(max_dates, orient="index").reset_index()
    max_dates_df.columns = ["med", "date"]
    max_dates_df["med"] = max_dates_df["med"].str.replace("daily_", "")
    max_dates_df = max_dates_df[max_dates_df["date"].notnull()]

    # ignore medications I haven't taken in a long time
    max_days_old = 6 * 30
    last_date = df[-1:].index.values[0]
    min_date = pd.to_datetime(last_date) - datetime.timedelta(days=max_days_old)
    recent = max_dates_df["date"] > min_date
    max_dates_df = max_dates_df[recent]
    max_dates_df = max_dates_df[max_dates_df["med"].isin(medications)]

    # Two stable sorts: order by last date, breaking ties by reverse med name.
    max_dates_df = (
        max_dates_df.sort_values(["med"], ascending=False, kind="mergesort")
        .sort_values(["date"], ascending=True, kind="mergesort")
        .reset_index(drop=True)
    )
    max_dates_df["plot_value"] = max_dates_df.index + 1
    plot_values = max_dates_df.set_index("med")["plot_value"].to_dict()

    # Scale each daily indicator by its medication's row number.
    cols = [f"daily_{med}" for med in plot_values.keys()]
    scaled = df[cols] * plot_values.values()
    plot_data = scaled.reset_index().melt(id_vars="date clean")
    plot_data.columns = ["date", "med", "value"]

    # we don't want to show dates where the med wasn't taken
    not_taken = plot_data["value"] == 0
    plot_data.loc[not_taken, "value"] = np.nan
    return plot_data, max_dates_df

In [None]:
def plot_meds_over_time(df: pd.DataFrame) -> None:
    """Plot which medications were taken on which dates, one row per medication.

    Each medication is labelled just after its last dose instead of in a
    legend. The figure is saved as ``medications.png`` under the output
    directory read from the notebook-level ``arguments`` mapping. Nothing is
    returned.
    """
    plot_data, max_dates_df = get_med_plot_data(df)

    fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")
    sns.scatterplot(
        x=plot_data["date"],
        y=plot_data["value"],
        hue=plot_data["med"],
        legend=False,
        ax=ax,
    )

    ax.set_title("Medications", size=14)
    ax.set_xlabel("Date", size=12)
    ax.set_ylabel("")

    # Leave room to the right of the last date for the text labels.
    ax.set_xlim(right=plot_data["date"].max() + datetime.timedelta(days=80))

    for _, row in max_dates_df.iterrows():
        ax.text(
            row["date"] + datetime.timedelta(days=5),
            row["plot_value"],
            row["med"],
        )

    # The y values are just row numbers, so ticks carry no meaning.
    ax.yaxis.set_ticks([])
    fig.autofmt_xdate()

    fig.savefig(
        f"{arguments['output_directory']}/medications.png",
    )

In [None]:
def get_diet_data(
    df: pd.DataFrame,
    diets: list[str],
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, dict[str, int]]:
    """Build the frames used to plot digestive symptoms against diets.

    Args:
        df: Frame indexed by "date clean" with a "diet name" column and the
            symptom indicator columns listed in ``metrics`` below.
        diets: Diet names to build indicator columns for.

    Returns:
        A 4-tuple ``(df, metric_data, diet_data, metrics_dict)``:
        ``df`` is the input frame, unchanged; ``metric_data`` is a long frame
        of date / metric / value holding only non-zero rows, where value is
        the metric's number; ``diet_data`` has one boolean column per diet;
        ``metrics_dict`` maps each metric name (without "_bool") to its number.
    """
    diet_data = pd.DataFrame()
    for diet in diets:
        diet_data[diet] = df["diet name"] == diet

    metrics = [
        "BMs > 3",
        "stomach pain > 4",
        "diarrhea_bool",
        "uncomfortable stomach_bool",
        "bloated stomach_bool",
    ]

    # Number the metrics 1..n; scaling each indicator by its number gives every
    # metric a distinct value.
    metric_numbers = list(range(1, len(metrics) + 1))
    metric_data = (df[metrics] * metric_numbers).reset_index()
    metric_data = metric_data.melt(id_vars="date clean")
    metric_data.columns = ["date", "metric", "value"]
    # Copy so the assignment below acts on an independent frame, not a slice.
    metric_data = metric_data[metric_data["value"] != 0].copy()
    metric_data["metric"] = metric_data["metric"].str.replace("_bool", "", regex=False)
    metrics_dict = {
        metric.replace("_bool", ""): number
        for metric, number in zip(metrics, metric_numbers)
    }

    return df, metric_data, diet_data, metrics_dict

In [None]:
def plot_diet_metric(
    metric: str,
    df: pd.DataFrame,
    diet_data: pd.DataFrame,
    start_dt: Union[datetime.datetime, None] = None,
    end_dt: Union[datetime.datetime, None] = None,
    rolling_days: int = 7,
    include_scatter: bool = False,
    include_foods: bool = False,
) -> None:
    """Plot a metric's rolling mean with the time span of each diet shaded.

    Args:
        metric: Column of ``df`` to plot.
        df: Frame indexed by "date clean" (needs an "adding back" column when
            ``include_foods`` is true).
        diet_data: Frame with one boolean column per diet, indexed by date.
        start_dt: If given, drop rows before this date.
        end_dt: If given, drop rows after this date.
        rolling_days: Number of rows in the rolling-mean window.
        include_scatter: Also draw the raw values as points.
        include_foods: Also label the dates where foods were added back.

    The figure is saved as ``diet_<metric>.png`` under the output directory
    read from the notebook-level ``arguments`` mapping.
    """
    # The date annotations use datetime.datetime: `pd.datetime` no longer exists
    # in pandas 2.0+, and annotations are evaluated when the function is defined.
    plot_data = df[metric].reset_index()
    plot_data = plot_data[plot_data[metric].notnull()]

    if start_dt is not None:
        plot_data = plot_data[plot_data["date clean"] >= start_dt]
    if end_dt is not None:
        plot_data = plot_data[plot_data["date clean"] <= end_dt]

    # Computed once; used for the line and for placing the food labels.
    rolling_mean = plot_data[metric].rolling(rolling_days).mean()

    fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")
    sns.lineplot(
        x=plot_data["date clean"],
        y=rolling_mean,
        ax=ax,
    )

    if include_scatter:
        sns.scatterplot(
            x=plot_data["date clean"],
            y=plot_data[metric],
            ax=ax,
        )

    ## Shade areas by diet
    diet_colors = dict(zip(diet_data.columns, sns.color_palette("pastel")[: len(diet_data.columns)]))
    for diet, color in diet_colors.items():
        date_df = diet_data[diet_data[diet] == True]
        # A diet that never occurs has no dates; its span bounds would be NaT.
        if date_df.empty:
            continue
        ax.axvspan(
            # do not shade before the first plotted date
            xmin=max(plot_data["date clean"].min(), date_df.index.min()),
            xmax=date_df.index.max(),
            alpha=0.4,
            label=diet,
            color=color,
        )

    ## Add labels of foods added back
    if include_foods:
        foods_added_back = df["adding back"][df["adding back"].notnull()].reset_index()
        label_height = rolling_mean.max() * 0.9
        for _, row in foods_added_back.iterrows():
            ax.text(
                row["date clean"],
                label_height,
                row["adding back"],
                rotation=70,
            )

    metric_str = metric.replace("_bool", "")
    ax.set_title(f"{metric_str} by Diet", size=14)
    ax.set_xlabel("Date", size=12)
    ax.set_ylabel(f"{metric_str}, {rolling_days}-day rolling mean", size=12)
    fig.autofmt_xdate()
    # One legend call: a second ax.legend() would replace the first legend and
    # drop its title and font size.
    legend = ax.legend(title="Diet", fontsize=12, bbox_to_anchor=(1, 0.95))
    plt.setp(legend.get_title(), fontsize="large")

    metric_str = metric.replace(" ", "_")
    fig.savefig(
        f"{arguments['output_directory']}/diet_{metric_str}.png",
    )