# Additional visualizations

In [None]:
from pipeline import *
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches
from matplotlib import cm
from matplotlib.colors import ListedColormap
from adjustText import adjust_text
# Base path handed to the loading/plotting functions below. The user name is
# read from USER, falling back to USERNAME when USER is not set.
# NOTE(review): `userpath` and `os` are not imported explicitly in this cell;
# presumably both arrive through `from pipeline import *` — confirm.
path = userpath(
    os.environ.get("USER", os.environ.get("USERNAME")),
    project="hcc",
)

### Donut charts

In [None]:
def create_ethnicity_donut(df, path, fontsize=12, category="ALL"):
    """Draw a donut chart of ethnicity shares, save it as SVG and show it.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a 'Count.Var1' column (category label) and a 'Count.Freq'
        column (count per label). The frame passed in is not modified.
    path : str
        Base directory; the figure is written to
        ``<path>/ext_val_visuals/Ethnicity_<category>.svg`` (the directory is
        created when missing).
    fontsize : int, optional
        Font size of labels drawn inside wedges. Labels placed outside use
        ``fontsize - 2`` and the centre text uses ``fontsize + 8``.
    category : str, optional
        Name printed to stdout, shown in the centre of the donut and used in
        the output file name.

    Returns
    -------
    None
    """
    # Work on a copy so the relabelling below never leaks into the caller's
    # DataFrame (the original implementation overwrote the caller's column).
    df = df.copy()

    # Combine the non-response answers into one bucket and shorten the long
    # Pacific Islander label.
    df['Count.Var1'] = df['Count.Var1'].replace({
        "I prefer not to answer": "No Answer",
        "None Indicated": "No Answer",
        "None of these": "No Answer",
        "PMI: Skip": "No Answer",
        "Native Hawaiian or Other Pacific Islander": "Pacific Islander"
    })

    # Group by the combined categories and sum their frequencies
    df = df.groupby('Count.Var1').agg({'Count.Freq': 'sum'}).reset_index()

    # Total number of cases and each category's share of it
    total_N = df['Count.Freq'].sum()
    df['Percentage'] = df['Count.Freq'] / total_N * 100

    # Largest share first; the fresh 0..n-1 index lines rows up with wedges
    df = df.sort_values(by='Percentage', ascending=False).reset_index(drop=True)

    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(aspect="equal"))

    export_string = category
    print(export_string)

    # One viridis sample per category, reversed so the sequence starts at the
    # yellow end. `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in
    # 3.9; `pyplot.get_cmap` takes the same (name, lut) arguments.
    viridis = plt.get_cmap('viridis', len(df))
    colors = viridis(np.linspace(0, 1, len(df)))[::-1]

    # width=0.5 leaves the hole in the middle that makes the pie a donut
    wedges, _ = ax.pie(df['Percentage'], wedgeprops=dict(width=0.5), startangle=90, colors=colors)

    # Text objects are collected so adjust_text can de-overlap them afterwards
    text_objects = []

    # Label inside the wedge when it is large enough (>= 5 %), otherwise
    # outside the ring with a connector line back to the wedge edge.
    for wedge, name, pct in zip(wedges, df['Count.Var1'], df['Percentage']):
        # Direction of the wedge's angular midpoint on the unit circle
        ang = (wedge.theta2 - wedge.theta1) / 2. + wedge.theta1
        y = np.sin(np.deg2rad(ang))
        x = np.cos(np.deg2rad(ang))

        horizontalalignment = 'right' if x < 0 else 'left'
        label = f"{name}\n({pct:.1f}%)"

        if pct >= 5:
            text = ax.text(x*0.75, y*0.75, label,
                           ha=horizontalalignment, va='center', fontsize=fontsize)
        else:
            text = ax.text(x*1.4, y*1.4, label,
                           ha=horizontalalignment, va='center', fontsize=fontsize-2)
            connectionstyle = "angle,angleA=0,angleB=90" if y > 0 else "angle,angleA=0,angleB=-90"
            ax.annotate('', xy=(x, y), xytext=(x*1.2, y*1.2),
                        arrowprops=dict(arrowstyle="-", color='black', connectionstyle=connectionstyle))
        text_objects.append(text)

    # Adjust text to minimize overlap on both axes.
    # NOTE(review): these keyword names are kept exactly as written; confirm
    # they match the adjustText release that is installed.
    adjust_text(text_objects, expand_text=(1.2, 1.2), expand_points=(2, 2),
                expand_objects=(1.2, 1.2), ax=ax, only_move={'text': 'xy', 'points': 'xy'})

    # Add total number in the center
    ax.text(0, 0, f"{export_string}\nN = {total_N:,}", ha='center', va='center', fontsize=fontsize+8, fontweight='bold')

    # exist_ok=True replaces the separate exists() check and tolerates the
    # directory being created between check and creation.
    visuals_path = os.path.join(path, "ext_val_visuals")
    os.makedirs(visuals_path, exist_ok=True)

    svg_path = os.path.join(visuals_path, f"Ethnicity_{export_string}.svg")
    fig.savefig(svg_path, format="svg", bbox_inches="tight", transparent=True)

    # Show plot
    plt.show()

def process_dataframe(df):
    """Collapse raw counts to one row per label with its percentage share.

    Rows of `df` are grouped on 'Count.Var1' and their 'Count.Freq' values
    summed. A 'Percentage' column (share of the grand total, 0-100) is added
    and the rows are returned largest share first with a fresh 0..n-1 index.
    The input frame is left unchanged; a new DataFrame is returned.
    """
    # One summed frequency per distinct label
    per_label = df.groupby('Count.Var1')['Count.Freq'].sum().reset_index()

    # Share of the grand total, expressed in percent
    grand_total = per_label['Count.Freq'].sum()
    per_label['Percentage'] = per_label['Count.Freq'] / grand_total * 100

    # Biggest share first, renumbered from zero
    ranked = per_label.sort_values(by='Percentage', ascending=False)
    return ranked.reset_index(drop=True)

def process_and_plot_ethnicity_data(path, categories):
    """Load, summarise and plot the ethnicity counts of every category.

    For each entry of `categories` this reads
    ``<path>/ext_val/data/Ethnicity_counts_<category>.xlsx``, aggregates it
    with `process_dataframe`, draws it with `create_ethnicity_donut` at font
    size 26, and prints a confirmation line.
    """
    for category in categories:
        # Input workbook for this category
        workbook = os.path.join(
            path, "ext_val", "data", f"Ethnicity_counts_{category}.xlsx"
        )

        # Read the sheet and reduce it to one row per label with percentages
        summary = process_dataframe(pd.read_excel(workbook))

        # Draw, save and show the donut for this category
        create_ethnicity_donut(summary, path, fontsize=26, category=category)

        print(f"Processed and plotted data for {category}")

In [None]:

# Category labels to plot. Each label selects the input workbook
# Ethnicity_counts_<label>.xlsx and names the output Ethnicity_<label>.svg.
categories = ["All", "HCC"]

# Run the processing and plotting
process_and_plot_ethnicity_data(path, categories)


# For a single donut from an already-loaded DataFrame `df`:
# create_ethnicity_donut(df, path, fontsize=26, category="All")