In [1]:
# visualize_clusters.ipynb
import pandas as pd
import sys
import pandas as pd
import matplotlib.pyplot as plt
import textwrap
import glob
import seaborn as sns
import numpy as np
from collections import defaultdict
import matplotlib.backends.backend_pdf
import os
import glob
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import utils

matplotlib.rcParams["font.family"] = "Arial"  # or 'Helvetica' if available


# YYYYMMDD = "20231209"
YYYYMMDD = "20231227"
CLUSTER_DIR = f"./outputs/{YYYYMMDD}/clusters/"  # Input
RECLUSTER_DIR = f"./outputs/{YYYYMMDD}/recluster/"  # Input
OUTPUT_DIR = f"./outputs/{YYYYMMDD}/"  # TODO: change var name to Base-dir
OUTPUT_REPORT_FILE = f"{OUTPUT_DIR}/cluster_report_robust.txt"
OUTPUT_CLUSTER_DEFS_FILE = f"{OUTPUT_DIR}/cluster_def_all.csv"
OUTPUT_ROBUST_CLUSTERS_PRIMARY_FILE = f"{OUTPUT_DIR}/robust_primary_cluster_all.csv"
OUTPUT_ROBUST_CLUSTERS_SECONDARY_FILE = f"{OUTPUT_DIR}/robust_secondary_cluster_all.csv"
# if not os.path.exists(OUTPUT_DIR):
#     os.mkdir(OUTPUT_DIR)


questions_file = "prompts/questions_8.txt"
questions_df = pd.read_csv(questions_file, sep="\t")
questions_dict = {
    f"Q{(row.question_number-1):02d}": row.question
    for index, row in questions_df.iterrows()
}

cluster_defs = glob.glob(CLUSTER_DIR + "/Q*metadata*.csv")
cluster_defs.sort()
print(cluster_defs[:5])
print(pd.read_csv(cluster_defs[0]).head())
print()

cluster_files = glob.glob(CLUSTER_DIR + "/Q*cluster*.csv")
cluster_files.sort()
print(cluster_files[:5])
print(
    pd.read_csv(cluster_files[0]).head()
)  # Has secondary clusters. Note columns: subject_id top_level_cluster_id  secondary_cluster_ids
print(
    pd.read_csv(cluster_files[3]).head()
)  # No secondary clusters. Note columns: subject_id cluster_ids
print()

recluster_files = glob.glob(RECLUSTER_DIR + "/Q*recluster*.tsv")
recluster_files.sort()
print(recluster_files[:5])
print(
    pd.read_csv(recluster_files[0], sep="\t").head()
)  # Has secondary clusters. Note columns: subject_id top_level_cluster_id  secondary_cluster_ids
print(
    pd.read_csv(recluster_files[3], sep="\t").head()
)  # Has secondary clusters with dummy assignments. Use the respective cluster_file to figure out if one-level or two-level clustering has been done
print()

['./outputs/20231227/clusters/Q00_metadata.csv', './outputs/20231227/clusters/Q01_metadata.csv', './outputs/20231227/clusters/Q02_metadata.csv', './outputs/20231227/clusters/Q03_metadata.csv', './outputs/20231227/clusters/Q04_metadata.csv']
  cluster_id                    cluster_name   
0         C1            Young Adults (22-33)  \
1         C2      Middle-aged Adults (34-45)   
2         C3            Older Adults (46-60)   
3         C4          Seniors (61 and above)   
4         C5  Unclear/irrelevant/no response   

                                 cluster_description  
0       Subjects who are between 22 and 33 years old  
1       Subjects who are between 34 and 45 years old  
2       Subjects who are between 46 and 60 years old  
3             Subjects who are 61 years old or above  
4  Subjects who have not provided their age or th...  

['./outputs/20231227/clusters/Q00_clusters.csv', './outputs/20231227/clusters/Q01_clusters.csv', './outputs/20231227/clusters/Q02_clusters.

In [2]:
def explode_cluster_ids(df, column="cluster_ids"):
    assert column in df.columns, df.columns
    df[column] = df[column].str.split(",")
    df = df.explode(column)
    df[column] = df[column].str.strip()
    df = df.reset_index(drop=True)
    return df


def get_cluster_count_mean_stds(exploded_df, column="cluster_ids"):
    # Get per-cluster mean and std (over iterations) of number of subjects assigned to each cluster (in exploded df format)
    assert "iteration" in exploded_df.columns
    assert "subject_id" in exploded_df.columns
    assert column in exploded_df.columns
    per_iteration_counts = (
        exploded_df.groupby([column, "iteration"])
        .agg(subject_id_count=("subject_id", "count"))
        .reset_index()
    )
    per_cluster_mean_stds = (
        per_iteration_counts.groupby([column])
        .agg(
            count_mean=("subject_id_count", "mean"),
            count_std=("subject_id_count", "std"),
        )
        .reset_index()
    )
    return per_cluster_mean_stds


# Utils
def get_robust_clusters(exploded_df, column="cluster_ids"):
    """Get robust cluster assignments (non-exclusive) are those that have a frequency >= 0.5 (frequency over iterations)"""
    assert [
        col in exploded_df.columns.tolist()
        for col in ["subject_id", "iteration", column]
    ]
    num_total_iterations = exploded_df["iteration"].unique().shape[0]
    threshold = np.ceil(num_total_iterations / 2)

    # calculate how many times each cluster is assigned to each subject across iterations
    repeats_per_cluster_subject = exploded_df.groupby(["subject_id", column]).agg(
        num_iterations=(column, "count")
    )
    # get robust clusters
    robust_cluster_subject = repeats_per_cluster_subject[
        repeats_per_cluster_subject["num_iterations"] >= threshold
    ].reset_index()
    # aggregate to one line per subject
    robust_cluster_subject = robust_cluster_subject.groupby("subject_id").agg(
        robust_cluster_ids=(column, lambda x: ",".join(x))
    )
    return robust_cluster_subject


def get_robust_clusters(exploded_df, column="cluster_ids"):
    """Get robust cluster assignments (non-exclusive) are those that have a frequency >= 0.5 (frequency over iterations)"""
    assert all(
        [
            col in exploded_df.columns.tolist()
            for col in ["subject_id", "iteration", column]
        ]
    )
    num_total_iterations = exploded_df["iteration"].nunique()
    threshold = np.ceil(num_total_iterations / 2)

    # Calculate how many times each cluster is assigned to each subject across iterations
    repeats_per_cluster_subject = exploded_df.groupby(["subject_id", column]).agg(
        num_iterations=(column, "count")
    )

    # Calculate the fraction of times each cluster is assigned
    repeats_per_cluster_subject["fraction"] = (
        repeats_per_cluster_subject["num_iterations"] / num_total_iterations
    )

    # Get robust clusters
    robust_cluster_subject = repeats_per_cluster_subject[
        repeats_per_cluster_subject["num_iterations"] >= threshold
    ].reset_index()

    # Aggregate to one line per subject and include fractions
    robust_cluster_subject = robust_cluster_subject.groupby("subject_id").agg(
        robust_cluster_ids=(column, lambda x: ",".join(x.astype(str))),
        fraction=("fraction", lambda x: ",".join(x.astype(str))),
    )

    return robust_cluster_subject


def add_q_idx_column(df, q_idx):
    # Add q_idx column to df and move it to the front
    df = df.assign(q_idx=q_idx)
    df.insert(0, "q_idx", df.pop("q_idx"))
    return df

In [3]:
# Load data
cluster_defs = {
    os.path.basename(f).split("_")[0]: f
    for f in glob.glob(CLUSTER_DIR + "/Q*metadata*.csv")
}
cluster_files = glob.glob(CLUSTER_DIR + "/Q*cluster*.csv")
cluster_files.sort()
recluster_files_all = glob.glob(RECLUSTER_DIR + "/Q*recluster*.tsv")

#
cluster_report = defaultdict(dict)
concatenated_cluster_def = []
concatenated_primary_cluster_df = []
concatenated_secondary_cluster_df = []
for cluster_file in cluster_files[:-1]:  # Skip last file (Q41)
    # Get associated question
    q_idx = os.path.basename(cluster_file).split("_")[0]
    # print(q_idx)
    question = questions_dict[q_idx]
    # print(f"{q_idx}: {question}")

    # Get associated cluster file
    cluster_df = pd.read_csv(cluster_file)
    assert cluster_df.shape[0] == len(utils.ALL_SUBJECT_IDS)

    cluster_df["iteration"] = 0
    # print(cluster_df.columns)
    is_two_level_clustering = (
        True if "top_level_cluster_id" in cluster_df.columns else False
    )
    # Is primary clustering non-exclusive?
    col_name = "top_level_cluster_id" if is_two_level_clustering else "cluster_ids"
    is_primary_non_exclusive = (
        cluster_df[col_name].fillna("").str.split(",").apply(len).max()
        > 1  # fillna("") is to handle NAs
    )

    # Get associated metadata
    cluster_def = pd.read_csv(cluster_defs[q_idx])

    concatenated_cluster_def.append(add_q_idx_column(cluster_def, q_idx))

    # Get associated recluster files and concatenate
    recluster_files = glob.glob(RECLUSTER_DIR + f"/{q_idx}*recluster*.tsv")
    recluster_files.sort()
    # print(recluster_files)
    cluster_reclusters_df = [cluster_df]
    for recluster_file in recluster_files:
        # print(recluster_file)
        recluster_df = pd.read_csv(recluster_file, sep="\t")
        assert recluster_df.shape[0] == len(utils.ALL_SUBJECT_IDS)

        recluster_df["iteration"] = int(
            os.path.basename(recluster_file).split("recluster")[1].split(".")[1]
        )
        cluster_reclusters_df.append(recluster_df)
    cluster_reclusters_df = pd.concat(cluster_reclusters_df)
    try:
        assert cluster_reclusters_df.shape[0] == cluster_df.shape[0] * (
            len(recluster_files) + 1
        )
    except AssertionError:
        print("Something wrong with recluster file number-of-rows")
        print(
            cluster_df.shape,
            [pd.read_csv(rf, sep="\t").shape for rf in recluster_files],
        )
        # print(cluster_reclusters_df.shape, cluster_df.shape)
        # print(cluster_reclusters_df.head())
        # raise AssertionError

    # Transform
    # If two-level clustering, rename "top_level_cluster_id" to "cluster_ids"
    if is_two_level_clustering:
        cluster_df.rename(columns={"top_level_cluster_id": "cluster_ids"}, inplace=True)
        cluster_reclusters_df.rename(
            columns={"top_level_cluster_id": "cluster_ids"}, inplace=True
        )

    # Check if population-specific question & subset
    specific_population = None
    for (
        population_name,
        population_questions,
    ) in utils.population_specific_question_ids.items():
        if q_idx in population_questions:
            specific_population = population_name
            break
    if specific_population is not None:
        cluster_df = cluster_df[
            cluster_df["subject_id"].isin(
                utils.population_subject_ids[specific_population]
            )
        ]
        assert cluster_df.shape[0] == len(
            utils.population_subject_ids[specific_population]
        )

        cluster_reclusters_df = cluster_reclusters_df[
            cluster_reclusters_df["subject_id"].isin(
                utils.population_subject_ids[specific_population]
            )
        ]
        assert cluster_reclusters_df.shape[0] == cluster_df.shape[0] * (
            len(recluster_files) + 1
        )

    # Process primary and secondary clusters
    exploded_primary_df = explode_cluster_ids(
        cluster_reclusters_df, column="cluster_ids"
    )
    primary_cluster_mean_stds = get_cluster_count_mean_stds(exploded_primary_df)
    robust_primary_clusters_df = get_robust_clusters(exploded_primary_df)
    robust_primary_clusters_df.rename(
        columns={"robust_cluster_ids": "cluster_ids"}, inplace=True
    )
    concatenated_primary_cluster_df.append(
        add_q_idx_column(robust_primary_clusters_df.reset_index(), q_idx)
    )
    if is_two_level_clustering:
        cluster_reclusters_df["secondary_cluster_ids"].fillna("", inplace=True)
        exploded_secondary_df = explode_cluster_ids(
            cluster_reclusters_df, column="secondary_cluster_ids"
        )
        secondary_cluster_mean_stds = get_cluster_count_mean_stds(
            exploded_secondary_df, column="secondary_cluster_ids"
        )
        robust_secondary_clusters_df = get_robust_clusters(
            exploded_secondary_df, column="secondary_cluster_ids"
        )
        robust_secondary_clusters_df.rename(
            columns={"robust_cluster_ids": "secondary_cluster_ids"}, inplace=True
        )
        is_secondary_non_exclusive = (
            cluster_reclusters_df["secondary_cluster_ids"]
            .fillna("")
            .astype(str)  # Ensure all entries are strings
            .str.split(",")
            .apply(len)
            .max()
            > 1
        )
        concatenated_secondary_cluster_df.append(
            add_q_idx_column(robust_secondary_clusters_df.reset_index(), q_idx)
        )

    # Save to dict
    cluster_report[q_idx]["question"] = question
    cluster_report[q_idx]["specific_population"] = specific_population
    cluster_report[q_idx]["cluster_def"] = cluster_def
    cluster_report[q_idx]["cluster_df"] = cluster_df
    cluster_report[q_idx]["cluster_reclusters_df"] = cluster_reclusters_df
    cluster_report[q_idx]["primary_cluster_mean_stds"] = primary_cluster_mean_stds
    cluster_report[q_idx]["is_primary_non_exclusive"] = is_primary_non_exclusive
    cluster_report[q_idx]["robust_primary_clusters_df"] = robust_primary_clusters_df
    cluster_report[q_idx]["is_two_level_clustering"] = is_two_level_clustering
    cluster_report[q_idx]["secondary_cluster_mean_stds"] = (
        secondary_cluster_mean_stds if is_two_level_clustering else None
    )
    cluster_report[q_idx]["robust_secondary_clusters_df"] = (
        robust_secondary_clusters_df if is_two_level_clustering else None
    )
    cluster_report[q_idx]["is_secondary_non_exclusive"] = (
        is_secondary_non_exclusive if is_two_level_clustering else None
    )

    # Validation
    if specific_population in utils.population_subject_ids.keys():
        if (
            robust_primary_clusters_df.shape[0]
            != len(utils.population_subject_ids[specific_population])
            and specific_population is not None
        ):
            print(
                f"WARNING: {q_idx} has {robust_primary_clusters_df.shape[0]} subjects instead of {len(utils.population_subject_ids[specific_population])} [{specific_population}]"
            )
            # print(robust_primary_clusters_df.index.to_list())
            # print(robust_primary_clusters_df.head())
            print(
                "Robust clusters not found for subjects: ",
                set(utils.population_subject_ids[specific_population])
                - set(robust_primary_clusters_df.index.to_list()),
            )

    else:
        if robust_primary_clusters_df.shape[0] != len(utils.ALL_SUBJECT_IDS):
            print(
                f"WARNING: {q_idx} has {robust_primary_clusters_df.shape[0]} subjects instead of {len(utils.ALL_SUBJECT_IDS)}"
            )
            # print(robust_primary_clusters_df.head())
            print(
                "Robust clusters not found for subjects: ",
                set(utils.ALL_SUBJECT_IDS)
                - set(robust_primary_clusters_df.index.to_list()),
            )

# Done processing all questions -- save to files
pd.concat(concatenated_cluster_def).to_csv(OUTPUT_CLUSTER_DEFS_FILE, index=False)
pd.concat(concatenated_primary_cluster_df).to_csv(
    OUTPUT_ROBUST_CLUSTERS_PRIMARY_FILE, index=False
)
pd.concat(concatenated_secondary_cluster_df).to_csv(
    OUTPUT_ROBUST_CLUSTERS_SECONDARY_FILE, index=False
)
print("Saved cluster definitions and robust cluster assignments...")

print(q_idx, question)
print()
print(cluster_df)
# print(primary_cluster_mean_stds)
print()
# print(cluster_def)
# robust_primary_clusters_df.reset_index()
# cluster_def

Robust clusters not found for subjects:  {'C065'}
Robust clusters not found for subjects:  {'C060', 'C008'}
Robust clusters not found for subjects:  {'C084'}
Robust clusters not found for subjects:  {'C095', 'C064'}
Robust clusters not found for subjects:  {'C041'}
Robust clusters not found for subjects:  {'C002', 'C093', 'C073', 'C056'}
Robust clusters not found for subjects:  {'C072'}
Saved cluster definitions and robust cluster assignments...
Q40 If student or trainee, do you agree with your school's policies regarding medical students' roles at this time?

   subject_id cluster_ids secondary_cluster_ids  iteration
3        C004          C3                  C3.1          0
5        C006          C1                  C1.1          0
6        C007          C1                  C1.3          0
7        C008          C2                  C2.1          0
22       C024          C1                  C1.1          0
23       C025          C1                  C1.1          0
27       C030       

In [4]:
# DEBUG
# Group by subject_id and iteration to get a list of cluster_ids for each subject_id and iteration
# grouped_grouped_df = (
#     exploded_primary_df.groupby(["subject_id", "iteration"])["cluster_ids"]
#     .agg(list)
#     .reset_index()
#     .groupby("subject_id")["cluster_ids"]
#     .agg(list)
#     .reset_index()
# )
# print(grouped_grouped_df)
# print(get_robust_clusters(exploded_primary_df))
# print(cluster_df)

# secondary_cluster_mean_stds = get_cluster_count_mean_stds(exploded_secondary_df)
# exploded_secondary_df
# cluster_reclusters_df
# exploded_secondary_df
# robust_secondary_clusters_df

# DEBUG
# for q_idx, q_data in cluster_report.items():
#     if q_data["is_two_level_clustering"]:
#         is_secondary_non_exclusive = q_data["is_secondary_non_exclusive"]
#         if is_secondary_non_exclusive:
#             print(q_idx, q_data["question"])
#             cluster_reclusters_df = q_data["cluster_reclusters_df"]
#             rows_with_nonexclusive_secondary_clusters = cluster_reclusters_df[
#                 cluster_reclusters_df["secondary_cluster_ids"]
#                 .fillna("")
#                 .astype(str)  # Ensure all entries are strings
#                 .str.split(",")
#                 .apply(len)
#                 > 1
#             ]
#             print(rows_with_nonexclusive_secondary_clusters)

In [5]:
## Make a text report of the clusters
def print_cluster_report(q_idx, q_data):
    question = q_data["question"]
    specific_population = q_data["specific_population"]
    if specific_population is not None:
        header_suffix = f" [ONLY {specific_population}]"
    else:
        header_suffix = ""

    print(f"\nQuestion {q_idx} - {question} {header_suffix}")

    cluster_defs = q_data["cluster_def"]
    per_cluster_data = q_data["primary_cluster_mean_stds"]
    cluster_df = q_data["cluster_df"]

    total_subjects = cluster_df.shape[0]

    is_primary_non_exclusive = q_data["is_primary_non_exclusive"]
    title_suffix = ""
    if is_primary_non_exclusive:
        title_suffix = f" [non-exclusive membership]"

    total_reported_subjects = 0
    print(f"  Primary Clusters{title_suffix}:")
    for i, row in per_cluster_data.iterrows():
        cluster_id = row["cluster_ids"]
        count_mean = int(row["count_mean"])
        count_std = round(row["count_std"], 2)
        total_reported_subjects += count_mean

        cluster_name = cluster_defs[cluster_defs["cluster_id"] == cluster_id][
            "cluster_name"
        ].values[0]
        cluster_desc = cluster_defs[cluster_defs["cluster_id"] == cluster_id][
            "cluster_description"
        ].values[0]

        print(
            f"    [{cluster_id}] {cluster_name}: {count_mean} +/- {count_std} s.d. out of {total_subjects} subjects ({100*count_mean/total_subjects:.2f}%) [{cluster_desc}]"  # -- {count_mean}/{total_subjects} subject(s) ({100*count_mean/total_subjects:.2f}%)
        )
    print(
        f"    Total # subjects across all Primary clusters: {total_reported_subjects} out of {total_subjects} subjects ignoring s.d."
    )

    # secondary clusters
    if q_data["is_two_level_clustering"]:
        try:
            cluster_df = q_data["cluster_df"]
            if cluster_df["secondary_cluster_ids"].isna().sum() == cluster_df.shape[0]:
                # print("No secondary clusters")
                return
            # print(cluster_df)

            cluster_reclusters_df = q_data["cluster_reclusters_df"]
            # print(cluster_reclusters_df)

            cluster_def = q_data["cluster_def"]
            # print(cluster_def)
            # print(cluster_def.shape)

            per_cluster_data = q_data["secondary_cluster_mean_stds"]
            # print(per_cluster_data)

            is_secondary_non_exclusive = q_data["is_secondary_non_exclusive"]
            title_suffix = ""
            if is_secondary_non_exclusive:
                title_suffix = f" [non-exclusive membership]"

            print(f"  Secondary Clusters{title_suffix}:")
            for i, row in per_cluster_data.iterrows():
                try:
                    cluster_id = row["secondary_cluster_ids"]
                    if (
                        cluster_id == ""
                        or cluster_id is None
                        or cluster_id == "nan"
                        or cluster_id == "None"
                        or cluster_id == "N/A"
                        or cluster_id == "NA"
                        or cluster_id == "--"
                        or cluster_id == "Not available"
                        # or cluster_id.isna()
                    ):
                        continue
                    count_mean = int(row["count_mean"])
                    count_std = round(row["count_std"], 2)

                    cluster_name = cluster_defs[
                        cluster_defs["cluster_id"] == cluster_id
                    ]["cluster_name"].values[0]
                    cluster_desc = cluster_defs[
                        cluster_defs["cluster_id"] == cluster_id
                    ]["cluster_description"].values[0]

                    print(
                        f"    [{cluster_id}] {cluster_name}: {count_mean} +/- {count_std} s.d. out of {total_subjects} subjects ({100*count_mean/total_subjects:.2f}%) [{cluster_desc}]"  # -- {count_mean}/{total_subjects} subject(s) ({100*count_mean/total_subjects:.2f}%)
                    )
                except Exception as e:
                    print(e, f"[{cluster_id}]")
                    # print(f"Missing data for cluster: [{cluster_id}]")
        except Exception as e:
            print(e, q_idx)
            print(f"Exception for question {q_idx}")


for q_idx, q_data in cluster_report.items():
    print_cluster_report(q_idx, q_data)


# Save STDOUT to cluster_report.txt
def print_report_to_file(filename):
    original_stdout = sys.stdout  # Save a reference to the original standard output

    with open(filename, "w") as f:
        sys.stdout = f  # Change the standard output to the file we created.
        for q_idx, q_data in cluster_report.items():
            print_cluster_report(q_idx, q_data)
        sys.stdout = original_stdout  # Reset the standard output to its original value


print_report_to_file(OUTPUT_REPORT_FILE)


Question Q00 - How old are you? 
  Primary Clusters:
    [C1] Young Adults (22-33): 37 +/- 0.0 s.d. out of 93 subjects (39.78%) [Subjects who are between 22 and 33 years old]
    [C2] Middle-aged Adults (34-45): 30 +/- 0.0 s.d. out of 93 subjects (32.26%) [Subjects who are between 34 and 45 years old]
    [C3] Older Adults (46-60): 15 +/- 0.0 s.d. out of 93 subjects (16.13%) [Subjects who are between 46 and 60 years old]
    [C4] Seniors (61 and above): 5 +/- 0.0 s.d. out of 93 subjects (5.38%) [Subjects who are 61 years old or above]
    [C5] Unclear/irrelevant/no response: 6 +/- 0.0 s.d. out of 93 subjects (6.45%) [Subjects who have not provided their age or their response was unclear or irrelevant]
    Total # subjects across all Primary clusters: 93 out of 93 subjects ignoring s.d.

Question Q01 - Where do you live? 
  Primary Clusters:
    [C1] Houston, Texas: 41 +/- 0.45 s.d. out of 93 subjects (44.09%) [Subjects reporting their location in Houston, Texas, regardless of specific

In [6]:
# Bar plots
filename = f"{OUTPUT_DIR}/bar_charts_reclustered.pdf"
if not os.path.exists(f"{OUTPUT_DIR}/bar_charts"):
    os.makedirs(f"{OUTPUT_DIR}/bar_charts")
pdf = matplotlib.backends.backend_pdf.PdfPages(filename=filename)
for q_idx, q_data in cluster_report.items():
    # Extract data
    Q = q_idx
    question = q_data["question"]
    cluster_defs = q_data["cluster_def"]
    per_cluster_data = q_data["primary_cluster_mean_stds"]

    # Check if non-exclusive
    is_primary_non_exclusive = q_data["is_primary_non_exclusive"]
    if is_primary_non_exclusive:
        title_suffix = "[non-exclusive membership]"
    else:
        title_suffix = ""

    # Check if population-specific question & subset
    specific_population = q_data["specific_population"]
    if specific_population is not None:
        title_suffix += f" [Only {specific_population}]"

    # Get cluster names
    cluster_names = []
    for cluster_id in per_cluster_data["cluster_ids"].values:
        cluster_name = cluster_defs[cluster_defs["cluster_id"] == cluster_id][
            "cluster_name"
        ].values[0]
        cluster_names.append(cluster_name)

    # Wrap text for long cluster names
    labels = [
        "\n".join(
            textwrap.wrap(
                cluster_name, width=20, break_long_words=False, break_on_hyphens=False
            )
        )
        for cluster_name in cluster_names
    ]

    # Setup plot
    colors = plt.cm.tab20c.colors
    fig, ax = plt.subplots(figsize=(6, 2.5))

    ax.set_ylabel("Number of Subjects")
    wrapped_title = "\n".join(
        textwrap.wrap(f"{Q}: {question}\n{title_suffix} ", width=60)
    )
    ax.set_title(f"{wrapped_title}", pad=10)  # Add padding

    # Add bars
    cluster_ids = per_cluster_data["cluster_ids"].values
    counts = per_cluster_data["count_mean"].values
    stds = per_cluster_data["count_std"].values
    # ax.bar(cluster_ids, counts, yerr=stds, capsize=5)
    ax.bar(
        cluster_ids,
        counts,
        yerr=stds,
        capsize=5,
        align="center",
        alpha=0.7,
        color=colors[: len(counts)],
    )
    ax.set_xticks(cluster_ids)
    ax.set_xticklabels(labels, rotation=90)

    # Save figure
    individual_filename = f"{OUTPUT_DIR}/bar_charts/question_{q_idx}.pdf"
    plt.savefig(individual_filename, bbox_inches="tight")
    print(f"Saved {individual_filename}")

    pdf.savefig(fig, bbox_inches="tight")
    # plt.show()
    plt.close()
pdf.close()
print(f"Saved {filename}")

Saved ./outputs/20231227//bar_charts/question_Q00.pdf
Saved ./outputs/20231227//bar_charts/question_Q01.pdf
Saved ./outputs/20231227//bar_charts/question_Q02.pdf
Saved ./outputs/20231227//bar_charts/question_Q03.pdf
Saved ./outputs/20231227//bar_charts/question_Q04.pdf
Saved ./outputs/20231227//bar_charts/question_Q05.pdf
Saved ./outputs/20231227//bar_charts/question_Q06.pdf
Saved ./outputs/20231227//bar_charts/question_Q07.pdf
Saved ./outputs/20231227//bar_charts/question_Q08.pdf
Saved ./outputs/20231227//bar_charts/question_Q09.pdf
Saved ./outputs/20231227//bar_charts/question_Q10.pdf
Saved ./outputs/20231227//bar_charts/question_Q11.pdf
Saved ./outputs/20231227//bar_charts/question_Q12.pdf
Saved ./outputs/20231227//bar_charts/question_Q13.pdf
Saved ./outputs/20231227//bar_charts/question_Q14.pdf
Saved ./outputs/20231227//bar_charts/question_Q15.pdf
Saved ./outputs/20231227//bar_charts/question_Q16.pdf
Saved ./outputs/20231227//bar_charts/question_Q17.pdf
Saved ./outputs/20231227//ba

In [9]:
filename = f"{OUTPUT_DIR}/bar_charts_reclustered.pdf"
if not os.path.exists(f"{OUTPUT_DIR}/bar_charts"):
    os.makedirs(f"{OUTPUT_DIR}/bar_charts")
pdf = matplotlib.backends.backend_pdf.PdfPages(filename=filename)

for q_idx, q_data in cluster_report.items():
    # Extract data
    Q = q_idx
    question = q_data["question"]
    cluster_defs = q_data["cluster_def"]
    per_cluster_data = q_data["primary_cluster_mean_stds"]

    # Check if non-exclusive
    is_primary_non_exclusive = q_data["is_primary_non_exclusive"]
    title_suffix = "[non-exclusive membership]" if is_primary_non_exclusive else ""

    # Check if population-specific question & subset
    specific_population = q_data["specific_population"]
    if specific_population is not None:
        title_suffix += f" [Only {specific_population}]"

    # Get cluster names
    cluster_names = [
        cluster_defs[cluster_defs["cluster_id"] == cluster_id]["cluster_name"].values[0]
        for cluster_id in per_cluster_data["cluster_ids"].values
    ]

    # Wrap text for long cluster names
    labels = [
        "\n".join(
            textwrap.wrap(
                cluster_name, width=20, break_long_words=False, break_on_hyphens=False
            )
        )
        for cluster_name in cluster_names
    ]

    # Setup plot for horizontal bars
    colors = plt.cm.tab20c.colors
    # fig, ax = plt.subplots(
    #     figsize=(10, len(cluster_names) * 0.5)
    # )  # Adjust figure size based on the number of clusters
    fig, ax = plt.subplots(
        figsize=(2.5, max(3, len(cluster_names) * 0.7))
    )  # Adjust figure size based on the number of clusters

    ax.set_xlabel("Number of Subjects")
    wrapped_title = "\n".join(
        textwrap.wrap(f"{Q}: {question} {title_suffix}", width=80)
    )
    ax.set_title(f"{wrapped_title}", pad=20)

    # Add horizontal bars
    cluster_ids = range(
        len(cluster_names)
    )  # Use sequential numbers as cluster IDs for y-axis
    counts = per_cluster_data["count_mean"].values
    stds = per_cluster_data["count_std"].values
    ax.barh(
        cluster_ids,
        counts,
        xerr=stds,
        capsize=5,
        align="center",
        alpha=0.7,
        color=colors[: len(counts)],
    )
    ax.set_yticks(cluster_ids)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # Invert y-axis so the first item is on top

    # Save figure
    individual_filename = f"{OUTPUT_DIR}/bar_charts/question_{q_idx}.pdf"
    plt.savefig(individual_filename, bbox_inches="tight")
    print(f"Saved {individual_filename}")

    pdf.savefig(fig, bbox_inches="tight")
    # plt.show()
    plt.close()

pdf.close()
print(f"Saved {filename}")

Saved ./outputs/20231227//bar_charts/question_Q00.pdf
Saved ./outputs/20231227//bar_charts/question_Q01.pdf
Saved ./outputs/20231227//bar_charts/question_Q02.pdf
Saved ./outputs/20231227//bar_charts/question_Q03.pdf
Saved ./outputs/20231227//bar_charts/question_Q04.pdf
Saved ./outputs/20231227//bar_charts/question_Q05.pdf
Saved ./outputs/20231227//bar_charts/question_Q06.pdf
Saved ./outputs/20231227//bar_charts/question_Q07.pdf
Saved ./outputs/20231227//bar_charts/question_Q08.pdf
Saved ./outputs/20231227//bar_charts/question_Q09.pdf
Saved ./outputs/20231227//bar_charts/question_Q10.pdf
Saved ./outputs/20231227//bar_charts/question_Q11.pdf
Saved ./outputs/20231227//bar_charts/question_Q12.pdf
Saved ./outputs/20231227//bar_charts/question_Q13.pdf
Saved ./outputs/20231227//bar_charts/question_Q14.pdf
Saved ./outputs/20231227//bar_charts/question_Q15.pdf
Saved ./outputs/20231227//bar_charts/question_Q16.pdf
Saved ./outputs/20231227//bar_charts/question_Q17.pdf
Saved ./outputs/20231227//ba