In [2]:
import pandas as pd

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
%run config/seaborn_config.ipynb

### **Data Exploration**

##### Data Frame structure

In [None]:
def explore_dataframe(data: pd.DataFrame):
    """
    Print a structural overview of a DataFrame and show its first and last rows.

    The overview consists of:
    - the (rows, columns) shape,
    - the row count and the column count on separate lines,
    - the dtype of every column,
    - the first and the last record, rendered with ``display``.

    Parameters:
    data (pd.DataFrame): The DataFrame to inspect.

    Returns:
    None: Everything is printed / displayed; nothing is returned.
    """
    n_rows, n_cols = len(data), len(data.columns)

    print("Shape of the DataFrame:", data.shape, "\n")
    print(f"df contains {n_rows} rows.")
    print(f"df contains {n_cols} columns:", "\n")
    print(data.dtypes)

    # ``display`` is the notebook (IPython) rich-output helper; it is not defined in this file
    sample_rows = (("\nFirst record:", data.head(1)), ("\nLast record:", data.tail(1)))
    for heading, record in sample_rows:
        print(heading)
        display(record)


##### Column type mapping

In [None]:
def change_column_types(data: pd.DataFrame, column_type_mapping: dict) -> pd.DataFrame:
    """
    Cast DataFrame columns to new dtypes, printing the dtypes before and after the cast.

    Parameters:
    data (pd.DataFrame): Source DataFrame; the cast is done with ``astype``, which returns a new object.
    column_type_mapping (dict): Maps column name -> target dtype (anything ``DataFrame.astype`` accepts).

    Returns:
    mapped_data (pd.DataFrame): The DataFrame with the requested columns cast.
    """
    print("Column types before mapping")
    print(data.dtypes, "\n")

    converted = data.astype(column_type_mapping)

    print("Column types after mapping")
    print(converted.dtypes)
    return converted

##### Separate columns by type

In [None]:
def separate_columns_by_type(data: pd.DataFrame, primary_key: list):
    """
    Split a DataFrame's columns into key, categorical, numerical and text groups and print them.

    Columns are grouped by dtype:
    - ``category`` dtype -> categorical,
    - any numeric dtype (``select_dtypes(include="number")``) -> numerical,
    - pandas ``string`` dtype -> text.
    Primary-key columns are reported in their own group and removed from the other three.
    Columns whose dtype matches none of the selectors (e.g. ``object`` or datetime) appear
    in no dtype group.

    Parameters:
    data (pd.DataFrame): The DataFrame to analyze.
    primary_key (list | str): Column name(s) forming the primary key. A single column
        name may be given as a string. Names that are not columns of ``data`` are allowed.

    Returns:
    tuple: ``(key_cols, categorical_cols, numerical_cols, text_cols)``, each a list of column names.
    """
    # Accept a single name as well as a list; a bare string would otherwise be
    # iterated character by character.
    key_cols = [primary_key] if isinstance(primary_key, str) else list(primary_key)
    key_set = set(key_cols)

    def _non_key_columns(dtype_selector: str) -> list:
        """Return the columns matching ``dtype_selector`` that are not key columns."""
        selected = data.select_dtypes(include=dtype_selector).columns
        return [col for col in selected if col not in key_set]

    categorical_cols = _non_key_columns("category")
    numerical_cols = _non_key_columns("number")
    text_cols = _non_key_columns("string")

    print("Key Columns:", key_cols)
    print("Categorical Columns:", categorical_cols)
    print("Numerical Columns:", numerical_cols)
    print("Text Columns:", text_cols)

    return key_cols, categorical_cols, numerical_cols, text_cols

##### Missing values

In [None]:
def explore_missing_values(data: pd.DataFrame, show_all: bool = False) -> pd.DataFrame:
    """
    Summarize missing values per column of a pandas DataFrame.

    Parameters:
    data (pd.DataFrame): The DataFrame to analyze.
    show_all (bool): If True, also list columns without missing values; by default
        only columns with at least one missing value are returned.

    Returns:
    missing_summary (pd.DataFrame): One row per column with "Missing Count" and
        "Missing Percentage" (0-100), sorted by percentage in descending order.
        For a DataFrame with no rows, the percentages are NaN (0 / 0).
    """
    # Number and percentage of missing values per column
    missing_count = data.isna().sum()
    missing_percentage = missing_count / len(data) * 100

    missing_summary = pd.DataFrame(
        {"Missing Count": missing_count, "Missing Percentage": missing_percentage}
    )

    # Hide fully-populated columns unless explicitly requested, for readability
    if not show_all:
        missing_summary = missing_summary[missing_summary["Missing Count"] > 0]

    return missing_summary.sort_values(by="Missing Percentage", ascending=False)

In [None]:
def remove_missing_values(
    data: pd.DataFrame, columns_to_drop: list = None, how: str = "any", axis: int = 0
):
    """
    Remove rows or columns with missing values from a DataFrame, or drop a given list of columns.

    Parameters:
    data (pd.DataFrame): The DataFrame to clean. It is not modified.
    columns_to_drop (list | None): Columns to drop when ``how="list"``; ignored otherwise.
    how (str): Which removal rule to apply:
               - "any": drop a row/column if it contains any missing value.
               - "all": drop a row/column only if all its values are missing.
               - "list": drop the user-defined ``columns_to_drop`` (requires ``axis=1``).
    axis (int): For "any"/"all", what to drop:
                - 0: rows with missing values.
                - 1: columns with missing values.

    Returns:
    cleaned_data (pd.DataFrame): The cleaned DataFrame.

    Raises:
    ValueError: If ``how``/``axis`` is not a supported combination, or if
        ``how="list"`` is used without ``columns_to_drop``.
    """
    print(f"Number of records before missing value removal: {len(data)}")

    if how in ("any", "all") and axis in (0, 1):
        cleaned_data = data.dropna(how=how, axis=axis)

    elif how == "list":
        if axis != 1:
            raise ValueError("When 'how' is 'list', axis must be '1'.")
        if columns_to_drop is None:
            raise ValueError("When 'how' is 'list', 'columns_to_drop' must be provided.")
        cleaned_data = data.drop(columns=columns_to_drop)

    else:
        raise ValueError(
            f"Unsupported combination how={how!r}, axis={axis!r}: "
            "'how' must be 'any', 'all' or 'list' and 'axis' must be 0 or 1."
        )

    print(f"Number of records after missing value removal: {len(cleaned_data)}")

    return cleaned_data

##### Outliers

In [None]:
# Function to explore outliers (and extreme values) for continuous variables
def explore_outliers(data: pd.DataFrame, columns: list, method: str = "iqr"):
    """
    Plot one boxplot per continuous column, annotated with outlier statistics.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data to analyze.
    columns (list): Names of the continuous columns to plot.
    method (str): The outlier-detection method; only "iqr" is supported.

    Raises:
    ValueError: If ``method`` is not supported. It is checked before any plotting,
        so no half-finished figure is produced.
    """
    if method != "iqr":
        raise ValueError(f"Unsupported outlier method: {method!r}. Only 'iqr' is supported.")

    for col in columns:
        # Boxplot of the non-missing values
        plot_boxplot(data, col)

        # Outlier count and share, written next to the boxplot
        add_outlier_statistics(data, col, method)

        plt.show()

In [None]:
# Sub-function to plot a boxplot
def plot_boxplot(data, col):
    """
    Draw a seaborn boxplot of one continuous column, ignoring rows where it is missing.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data.
    col (str): The name of the continuous variable.
    """
    non_missing = data.dropna(subset=[col])
    box_color = custom_colors["blue"]
    sns.boxplot(x=col, data=non_missing, width=0.5, color=box_color)

    plt.title(f"Boxplot of **{col}**")
    plt.xlabel(col)

In [None]:
# Sub-function to add outlier statistics to boxplot
def add_outlier_statistics(data: pd.DataFrame, col: str, method: str):
    """
    Write the number and percentage of outliers in a column to the right of the current plot.

    The text colour encodes severity: black when there are no outliers, orange when
    they make up less than 10% of the non-missing values, red otherwise.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data.
    col (str): The name of the continuous variable.
    method (str): The outlier-detection method; only "iqr" is supported.

    Raises:
    ValueError: If ``method`` is not supported.
    """
    if method != "iqr":
        raise ValueError(f"Unsupported outlier method: {method!r}. Only 'iqr' is supported.")

    # Missing values are neither inliers nor outliers, so only count non-missing ones
    n = int(data[col].notnull().sum())
    n_outliers = n - len(remove_outliers_from_column(data, col, method))
    # Guard against an empty / all-missing column (previously a ZeroDivisionError)
    pct_outliers = n_outliers / n * 100 if n > 0 else 0.0

    if pct_outliers == 0:
        color = basic_colors["black"]
    elif pct_outliers < 10:
        color = custom_colors["orange"]
    else:
        color = custom_colors["red"]

    for y, label in (
        (0.90, f"Outliers: {n_outliers}"),
        (0.85, f"Outliers %: {pct_outliers:.1f}"),
    ):
        plt.text(x=1.02, y=y, s=label, transform=plt.gca().transAxes, color=color)

In [None]:
# Function to remove outliers from continuous variable
def remove_outliers_from_column(
    data: pd.DataFrame, col: str, method: str = "iqr"
) -> pd.Series:
    """
    Return the values of a continuous column with outliers removed.

    With ``method="iqr"`` a value is kept when it lies within
    [Q1 - 1.5 * IQR, Q3 + 1.5 * IQR], bounds inclusive. Missing values are
    dropped as well, because comparisons with NaN evaluate to False.

    Parameters:
    data (pd.DataFrame): The input DataFrame.
    col (str): The name of the column to remove outliers from.
    method (str): The outlier-detection method; only "iqr" is supported.

    Returns:
    pd.Series: The retained values of ``col``, keeping their original index labels.

    Raises:
    ValueError: If ``method`` is not supported (previously an UnboundLocalError).
    """
    if method != "iqr":
        raise ValueError(f"Unsupported outlier method: {method!r}. Only 'iqr' is supported.")

    series = data[col]

    # Interquartile range from the 25th and 75th percentiles
    q1 = series.quantile(0.25)
    q3 = series.quantile(0.75)
    iqr = q3 - q1

    # Tukey fences: anything outside is considered an outlier
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr

    return series[(series >= lower_bound) & (series <= upper_bound)]

In [None]:
def remove_outliers_from_dataframe(data: pd.DataFrame, cols: list, method="iqr"):
    """
    Drop rows that contain an outlier in any of the given columns (IQR method).

    Columns are processed in the given order. Each column's IQR fences are computed on
    the rows that survived the previous columns, so the result can depend on column order.
    Rows with a missing value in a processed column are dropped too, because NaN fails
    both bound comparisons. Names in ``cols`` that are not columns of ``data`` are skipped.

    Parameters:
    data (pd.DataFrame): The DataFrame from which to remove outliers; it is not modified.
    cols (list): Column names from which to remove outliers.
    method (str): The outlier-detection method; only "iqr" is supported.

    Returns:
    pd.DataFrame: A copy of ``data`` without the outlier rows.

    Raises:
    ValueError: If ``method`` is not supported (previously this was a silent no-op).
    """
    if method != "iqr":
        raise ValueError(f"Unsupported outlier method: {method!r}. Only 'iqr' is supported.")

    print(f"Number of records before outlier removal: {len(data)}")

    cleaned_data = data.copy()

    for col in cols:
        if col not in cleaned_data.columns:
            continue

        # Tukey fences from the interquartile range of the remaining rows
        q1 = cleaned_data[col].quantile(0.25)
        q3 = cleaned_data[col].quantile(0.75)
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr

        cleaned_data = cleaned_data[
            (cleaned_data[col] >= lower_bound) & (cleaned_data[col] <= upper_bound)
        ]

    print(f"Number of records after outlier removal: {len(cleaned_data)}")

    return cleaned_data

##### Univariate distributions

In [3]:
# Main function
def explore_univariate(
    data: pd.DataFrame,
    columns: list,
    dist_type: str,
    relative_frequency: bool = True,
    n_bins: int = 20,
    remove_outliers: bool = False,
):
    """
    Plot the univariate distribution of every listed column.

    Args:
    - data: A pandas DataFrame.
    - columns: Column names to plot, in order.
    - dist_type: "discrete" (bar chart per column) or "continuous" (histogram per column).
    - relative_frequency: Discrete only; plot proportions instead of counts when True.
    - n_bins: Continuous only; number of histogram bins.
    - remove_outliers: Continuous only; drop outliers before plotting when True.

    Raises:
    - ValueError: if dist_type is neither "discrete" nor "continuous".
    """
    if dist_type not in ("discrete", "continuous"):
        raise ValueError("dist_type must be either 'discrete' or 'continuous'.")

    plot_discrete = dist_type == "discrete"
    for column in columns:
        if plot_discrete:
            explore_discrete(data, column, relative_frequency)
        else:
            explore_continuous(data, column, n_bins, remove_outliers)

In [None]:
def explore_discrete(data: pd.DataFrame, col: str, relative_frequency: bool):
    """
    Render and show one bar chart for a discrete variable, with a cardinality note.

    Args:
    - data: A pandas DataFrame.
    - col: The column to plot.
    - relative_frequency: Plot proportions when True, absolute counts otherwise.
    """
    stat, label_format, y_max = configure_barplot(data, col, relative_frequency)
    plot_barplot(data, col, stat=stat, ylim=y_max, fmt=label_format)
    add_cardinality_info(data, col)
    plt.show()

In [None]:
# Sub-function to configure barplot
def configure_barplot(data: pd.DataFrame, col: str, relative_frequency: bool):
    """
    Choose the seaborn stat, bar-label format and y-axis ceiling for a discrete bar chart.

    Args:
    - data: A pandas DataFrame.
    - col: The column to plot.
    - relative_frequency: Proportions when True, absolute counts otherwise.

    Returns:
    - (stat, fmt, ylim): ("probability", "%.2f", 1) for proportions, or
      ("count", "%.0f", 1.2 * largest category count) for counts.
    """
    if not relative_frequency:
        # 20% headroom above the tallest bar (presumably so bar labels fit)
        return "count", "%.0f", data[col].value_counts().max() * 1.2
    return "probability", "%.2f", 1

In [None]:
# Sub-function to plot barplot
def plot_barplot(data: pd.DataFrame, col: str, stat: str, ylim: float, fmt: str):
    """
    Draw a bar plot (seaborn discrete histogram) for a discrete variable.

    Args:
    - data: A pandas DataFrame.
    - col: The column to plot.
    - stat: Statistic to display ('count' or 'probability').
    - ylim: Upper limit of the y-axis.
    - fmt: printf-style format for the bar value labels.
    """
    ax = sns.histplot(
        data=data,
        x=col,
        stat=stat,
        discrete=True,
        shrink=0.95,
        alpha=0.50,
        color=custom_colors["blue"],
    )

    # Label bars with their values only when there are at most 20 distinct values
    if hasattr(ax, "containers") and data[col].nunique() <= 20:
        for bar in ax.containers:
            ax.bar_label(bar, fmt=fmt, fontsize=10)

    plt.title(f"Distribution of **{col}**")
    plt.xlabel(col)
    # One tick per observed value, in sorted order. Built from the Series so it works
    # for every dtype: ``Series.unique()`` returns a plain NumPy array (no ``dropna``)
    # for numeric columns, which previously raised AttributeError here.
    tick_values = data[col].dropna().drop_duplicates().sort_values()
    plt.xticks(tick_values.to_numpy())
    plt.xticks(rotation=90)
    plt.ylabel("Frequency")
    plt.ylim(0, ylim)

In [None]:
# Sub-function to add cardinality information to barplot
def add_cardinality_info(data: pd.DataFrame, col: str):
    """
    Write a column's number of distinct values to the right of the current plot.

    A second, coloured line flags moderate (10-49 distinct values, orange) or
    high (50 or more, red) cardinality. Missing values are not counted.
    """
    n_unique = data[col].nunique()

    plt.text(
        x=1.02,
        y=0.90,
        s=f"Unique values: {n_unique}",
        transform=plt.gca().transAxes,
        color="#000000",
    )

    if n_unique >= 50:
        warning, tone = "High Cardinality", custom_colors["red"]
    elif n_unique >= 10:
        warning, tone = "Moderate Cardinality", custom_colors["orange"]
    else:
        return

    plt.text(x=1.02, y=0.80, s=warning, transform=plt.gca().transAxes, color=tone)

In [None]:
# Refactored function to explore continuous variables
def explore_continuous(data: pd.DataFrame, col: str, n_bins: int, remove_outliers: bool):
    """
    Plot and show a histogram of a continuous variable, annotated with summary statistics.

    Args:
    - data: A pandas DataFrame; it is not modified.
    - col: The column to plot.
    - n_bins: Number of histogram bins.
    - remove_outliers: If True, values outside the IQR fences are excluded
      (set to NaN) before computing statistics and plotting.
    """
    # Work on a copy of just the plotted column so the caller's DataFrame is untouched
    plot_data = data[[col]].copy()

    if remove_outliers:
        # Index alignment turns the removed (outlier) rows into NaN.
        # Previously this called ``remove_outliers_iqr``, which is not defined in this file.
        plot_data[col] = remove_outliers_from_column(plot_data, col, method="iqr")

    stats = calculate_continuous_stats(plot_data, col)

    plot_histplot(plot_data, col, n_bins)
    add_continuous_stats_info(stats)

    plt.legend()
    plt.show()


In [None]:
# Sub-function to calculate continuous statistics
def calculate_continuous_stats(data: pd.DataFrame, col: str):
    """
    Calculate basic summary statistics for a continuous variable.

    Missing values are skipped (pandas default).

    Args:
    - data: A pandas DataFrame.
    - col: The column to summarize.

    Returns:
    - A dict with keys "min", "max", "mean", "median", "mode", "var" (sample variance),
      "std", "skew" and "kurt". "mode" is the first (smallest) mode, or NaN when the
      column has no non-missing values.
    """
    values = data[col]
    modes = values.mode()

    return {
        "min": values.min(),
        "max": values.max(),
        "mean": values.mean(),
        "median": values.median(),
        # mode() is empty for an empty / all-missing column; ``mode()[0]`` raised there
        "mode": modes.iloc[0] if not modes.empty else float("nan"),
        "var": values.var(),
        "std": values.std(),
        "skew": values.skew(),
        "kurt": values.kurt(),
    }

In [None]:
# Sub-function to plot histplot
def plot_histplot(data: pd.DataFrame, col: str, n_bins: int):
    """
    Draw a probability-normalised histogram with a KDE overlay for a continuous variable.

    Args:
    - data: A pandas DataFrame.
    - col: The column to plot.
    - n_bins: Number of histogram bins.
    """
    histogram_style = {
        "bins": n_bins,
        "kde": True,
        "stat": "probability",
        "alpha": 0.50,
        "color": custom_colors["blue"],
    }
    sns.histplot(data=data, x=col, **histogram_style)

    plt.title(f"Distribution of **{col}**")
    plt.xlabel(col)
    plt.ylabel("Frequency")

In [None]:
# Sub-function to add statistical annotations to plots
def add_continuous_stats_info(stats: dict):
    """
    Overlay summary statistics on the current plot.

    Draws black vertical reference lines (with legend labels) at the mean, median and
    mode. Writes min/max, variance/std dev and skewness/kurtosis as text to the right
    of the axes. All values are shown with one decimal.

    Args:
    - stats: A dict with the keys "mean", "median", "mode", "min", "max",
      "var", "std", "skew" and "kurt".
    """
    # (stats key, line style, legend name) for the vertical reference lines
    reference_lines = (
        ("mean", "-", "Mean"),
        ("median", "--", "Median"),
        ("mode", ":", "Mode"),
    )
    for key, line_style, name in reference_lines:
        plt.axvline(
            stats[key],
            color="#000000",
            linestyle=line_style,
            alpha=1.0,
            label=f"{name}: {stats[key]:.1f}",
        )

    # (vertical position in axes coordinates, caption, stats key) for the side text
    side_text = (
        (0.80, "Min", "min"),
        (0.75, "Max", "max"),
        (0.50, "Variance", "var"),
        (0.45, "Std Dev", "std"),
        (0.20, "Skewness", "skew"),
        (0.15, "Kurtosis", "kurt"),
    )
    for y_pos, caption, key in side_text:
        plt.text(
            x=1.02,
            y=y_pos,
            s=f"{caption}: {stats[key]:.1f}",
            transform=plt.gca().transAxes,
            color="#000000",
        )

##### Bivariate relationships

In [None]:
# Main function
def explore_bivariate_relationships(
    data: pd.DataFrame,
    categorical_columns: list,
    numerical_columns: list,
    y_column: str,
    y_type: str,
):
    """
    Main function to explore bivariate relationships between features and a target variable.

    For a continuous target, it draws scatterplots (with Pearson/Spearman correlations)
    for the numerical features and grouped boxplots for the categorical features.
    For a discrete target, it draws grouped histograms for the numerical features and
    stacked barplots for the categorical features.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data to analyze.
    categorical_columns (list): List of categorical column names.
    numerical_columns (list): List of numerical column names.
    y_column (str): The name of the target variable column.
    y_type (str): The type of the target variable ('continuous' or 'discrete').

    Raises:
    ValueError: If y_type is neither 'continuous' nor 'discrete' (previously a silent no-op).
    """
    if y_type == "continuous":
        print("**Continuous features**")
        explore_continuous_continuous(data, numerical_columns, y_column)

        print("**Discrete features**")
        explore_continuous_discrete(data, categorical_columns, y_column)

    elif y_type == "discrete":
        print("**Continuous features**")
        explore_discrete_continuous(data, numerical_columns, y_column)

        print("**Discrete features**")
        explore_discrete_discrete(data, categorical_columns, y_column)

    else:
        raise ValueError("y_type must be either 'continuous' or 'discrete'.")

In [None]:
# Sub-function for relationships between continuous features and a continuous target variable
def explore_continuous_continuous(data, numerical_columns, y_column):
    """
    Draw one annotated scatterplot per numerical feature against a continuous target.

    The target column itself is skipped if it appears among the features.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data to analyze.
    numerical_columns (list): List of numerical column names.
    y_column (str): The name of the target variable column.
    """
    features = (name for name in numerical_columns if name != y_column)
    for feature in features:
        pearson, spearman = calculate_correlation_coefficients(data, feature, y_column)
        plot_scatterplot(data, feature, y_column, pearson, spearman)

In [None]:
# Sub-function to calculate correlation coefficients
def calculate_correlation_coefficients(data, col, y_column):
    """
    Return the Pearson and Spearman correlations between a feature and the target.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data.
    col (str): The name of the numerical feature column.
    y_column (str): The name of the target variable column.

    Returns:
    tuple: (pearson, spearman) correlation coefficients, computed with ``Series.corr``.
    """
    feature, target = data[col], data[y_column]
    pearson_corr, spearman_corr = (
        feature.corr(target, method=method) for method in ("pearson", "spearman")
    )
    return pearson_corr, spearman_corr

In [None]:
# Sub-function to plot a scatterplot
def plot_scatterplot(data, col, y_column, pearson_corr, spearman_corr):
    """
    Draw and show a scatterplot (with dashed regression line) of a feature against the target,
    with the correlation coefficients written to the right of the axes.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data.
    col (str): The name of the numerical feature column.
    y_column (str): The name of the target variable column.
    pearson_corr (float): Pearson correlation coefficient.
    spearman_corr (float): Spearman correlation coefficient.
    """
    point_style = {"s": 50, "alpha": 0.10, "color": custom_colors["blue"]}
    trend_style = {"color": basic_colors["black"], "linewidth": 2, "linestyle": "--"}
    sns.regplot(
        x=col,
        y=y_column,
        data=data,
        ci=None,
        marker="o",
        scatter_kws=point_style,
        line_kws=trend_style,
    )

    plt.title(f"Scatterplot of **{col}** / **{y_column}**")
    plt.xlabel(col)
    plt.ylabel(y_column)

    annotations = (
        (0.90, f"Pearson: {pearson_corr:.2f}"),
        (0.85, f"Spearman: {spearman_corr:.2f}"),
    )
    for y_pos, label in annotations:
        plt.text(x=1.02, y=y_pos, s=label, transform=plt.gca().transAxes, color="#000000")

    plt.show()

In [None]:
# Sub-function for relationships between discrete features and a continuous target variable
def explore_continuous_discrete(data, categorical_columns, y_column):
    """
    Draw one grouped boxplot of a continuous target per categorical feature.

    The overall target median is computed once and passed to every plot as a reference.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data to analyze.
    categorical_columns (list): List of categorical column names.
    y_column (str): The name of the target variable column.
    """
    overall_median = data[y_column].median()
    for feature in categorical_columns:
        plot_grouped_boxplot(data, feature, y_column, overall_median)

In [None]:
# Sub-function to plot a boxplot
def plot_grouped_boxplot(data, col, y_column, y_median):
    """
    Draw and show boxplots of a continuous target for each level of a categorical feature,
    with a dashed horizontal line at the overall target median.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data.
    col (str): The name of the categorical feature column.
    y_column (str): The name of the target variable column.
    y_median (float): The median of the target variable.
    """
    sns.boxplot(x=col, y=y_column, data=data, width=0.5, color=custom_colors["blue"])

    plt.title(f"Grouped Boxplot of **{col}** / **{y_column}**")
    plt.xlabel(col)
    plt.ylabel(y_column)

    median_line = {"alpha": 0.50, "color": basic_colors["black"], "linestyle": "--", "linewidth": 2}
    plt.axhline(y=y_median, **median_line)

    plt.show()

In [None]:
# Sub-function for relationships between continuous features and a discrete target variable
def explore_discrete_continuous(data, numerical_columns, y_column):
    """
    Explore relationships between continuous features and a discrete target variable.

    Draws one grouped histogram per numerical feature. As in the other bivariate
    explorers, the target column itself is skipped if it appears among the features
    (e.g. a numerically-encoded class label); plotting it against itself is meaningless.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data to analyze.
    numerical_columns (list): List of numerical column names.
    y_column (str): The name of the target variable column.
    """
    for col in numerical_columns:
        if col == y_column:
            continue

        plot_grouped_histplot(data, col, y_column)

In [None]:
# Sub-function to plot a stacked histogram
def plot_grouped_histplot(data, col, y_column):
    """
    Draw and show overlaid histograms (with KDE) of a continuous feature, one per target class.

    Each class is normalised separately (``common_norm=False``). Colours are taken from the
    start of ``custom_qualitative_palette``, one per distinct target value.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data.
    col (str): The name of the numerical feature column.
    y_column (str): The name of the target variable column.
    """
    class_colors = custom_qualitative_palette[: data[y_column].nunique()]

    sns.histplot(
        data=data,
        x=col,
        hue=y_column,
        kde=True,
        stat="probability",
        bins=10,
        common_norm=False,
        alpha=0.50,
        palette=class_colors,
    )

    plt.title(f"Grouped Histplot of **{col}** / **{y_column}**")
    plt.xlabel(col)
    plt.ylabel(f"Probability of {y_column}")

    plt.show()

In [None]:
# Sub-function for relationships between discrete features and a discrete target variable
def explore_discrete_discrete(data, categorical_columns, y_column):
    """
    Draw one stacked barplot per categorical feature against a discrete target.

    The target column itself is skipped if it appears among the features.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data to analyze.
    categorical_columns (list): List of categorical column names.
    y_column (str): The name of the target variable column.
    """
    for feature in (name for name in categorical_columns if name != y_column):
        plot_stacked_barplot(data, feature, y_column)

In [None]:
# Sub-function to plot a stacked barplot
def plot_stacked_barplot(data, col, y_column):
    """
    Draw and show a stacked bar chart of target-class proportions for each feature level.

    Each bar is one level of ``col``. Its segments are the row-normalised shares of each
    target class (from ``pd.crosstab(..., normalize="index")``), labelled with two decimals.

    Parameters:
    data (pd.DataFrame): The DataFrame containing the data.
    col (str): The name of the categorical feature column.
    y_column (str): The name of the target variable column.
    """
    shares = pd.crosstab(data[col], data[y_column], normalize="index")
    class_colors = custom_qualitative_palette[: data[y_column].nunique()]

    ax = shares.plot(kind="bar", stacked=True, alpha=0.50, color=class_colors)

    if hasattr(ax, "containers"):
        for segment in ax.containers:
            ax.bar_label(segment, label_type="center", fmt="%.2f", fontsize=10)

    plt.title(f"Stacked Barplot of **{col}** / **{y_column}**")
    plt.xlabel(col)
    plt.ylabel(f"Probability of {y_column}")
    plt.legend(title=y_column, bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=0)

    plt.show()