In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import matplotlib.dates as mdates
import os

# Read the gap-filled dataset; the 'date' column is parsed into datetimes on load
df = pd.read_csv('df_data_filled_multi.csv', parse_dates=['date'])

# Flag negative PM2.5 readings once so the count and the replacement share one mask
negative_mask = df['pm2_5'] < 0
negative_count = negative_mask.sum()
negative_percentage = negative_count / len(df) * 100

print(f"Negative PM2.5 values found: {negative_count} ({negative_percentage:.4f}%)")
print("These values were set to zero before further analysis.")

# Overwrite the flagged readings with zero
df.loc[negative_mask, 'pm2_5'] = 0

# 1. Basic PM2.5 Statistics
def basic_pm25_statistics(df):
    """
    Build a single Series of descriptive statistics for the 'pm2_5' column.

    Starts from ``Series.describe()`` and appends the median, first mode,
    variance, skewness, kurtosis and a fixed set of percentiles.

    Parameters:
    - df: DataFrame containing a numeric 'pm2_5' column

    Returns:
    - Series indexed by statistic name
    """
    pm25 = df['pm2_5']
    summary = pm25.describe()

    # Extra indicators that describe() does not provide
    extras = {
        'median': pm25.median(),
        'mode': pm25.mode()[0],  # first mode when several values tie
        'variance': pm25.var(),
        'skewness': pm25.skew(),
        'kurtosis': pm25.kurtosis(),
    }
    for label, value in extras.items():
        summary[label] = value

    # Percentiles; 25/50/75 overwrite the entries describe() already created
    for pct in (1, 5, 10, 25, 50, 75, 90, 95, 99):
        summary[f'{pct}%'] = pm25.quantile(pct / 100)

    return summary

# 2. PM2.5 Distribution
def plot_pm25_distribution(df):
    """
    Save a two-panel figure describing the PM2.5 distribution.

    Left panel: histogram with KDE plus vertical lines at the mean and median.
    Right panel: Q-Q plot against a normal distribution.
    The figure is written to output_diagrams/pm25_distribution.png and closed.

    Parameters:
    - df: DataFrame containing a numeric 'pm2_5' column
    """
    # Missing values would corrupt the Q-Q plot, so work on the non-null readings
    pm25 = df['pm2_5'].dropna()

    fig, (ax_hist, ax_qq) = plt.subplots(1, 2, figsize=(12, 6))

    # Histogram of PM2.5 distribution
    sns.histplot(pm25, kde=True, bins=30, ax=ax_hist)
    ax_hist.set_title('PM2.5 Distribution')
    ax_hist.set_xlabel('PM2.5 Concentration (μg/m³)')
    ax_hist.set_ylabel('Frequency')
    ax_hist.axvline(pm25.mean(), color='red', linestyle='--', label='Mean')
    ax_hist.axvline(pm25.median(), color='green', linestyle='-.', label='Median')
    ax_hist.legend()

    # Q-Q plot to check for normal distribution
    stats.probplot(pm25, dist="norm", plot=ax_qq)
    ax_qq.set_title('Q-Q Plot (Normal Distribution)')

    fig.tight_layout()

    # savefig does not create missing directories, so make sure the target exists
    os.makedirs("output_diagrams", exist_ok=True)
    fig.savefig(os.path.join("output_diagrams", 'pm25_distribution.png'), dpi=600)
    plt.close(fig)

# 3. PM2.5 Time Series
def plot_pm25_timeseries(df):
    """
    Plot the raw and daily-averaged PM2.5 series and return the averages.

    The figure (raw series on top, daily means below, shared X axis) is written
    to output_diagrams/pm25_timeseries.png and closed.

    Parameters:
    - df: DataFrame with a datetime 'date' column and a numeric 'pm2_5' column

    Returns:
    - (daily_avg, monthly_avg): two DataFrames with 'date' and 'pm2_5' columns.
      The monthly averages are computed and returned but not plotted.
    """
    # Daily means: group on the calendar date, then restore a datetime dtype
    daily_avg = df.groupby(df['date'].dt.date)['pm2_5'].mean().reset_index()
    daily_avg['date'] = pd.to_datetime(daily_avg['date'])

    # NOTE(review): the 'M' alias is reported as deprecated in favour of 'ME' in
    # recent pandas releases — confirm against the installed version before changing.
    monthly_avg = df.groupby(pd.Grouper(key='date', freq='M'))['pm2_5'].mean().reset_index()

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    # Raw time series (titled as hourly data)
    ax1.plot(df['date'], df['pm2_5'], 'b-', alpha=0.5, linewidth=0.8)
    ax1.set_title('Hourly PM2.5 Concentrations')
    ax1.set_ylabel('PM2.5 (μg/m³)')
    ax1.grid(True, alpha=0.3)

    # Daily averages
    ax2.plot(daily_avg['date'], daily_avg['pm2_5'], 'r-', linewidth=1.2)
    ax2.set_title('Daily Average PM2.5 Concentrations')
    ax2.set_ylabel('PM2.5 (μg/m³)')
    ax2.set_xlabel('Date')
    ax2.grid(True, alpha=0.3)

    # One tick per month, labelled as YYYY-MM and rotated for readability
    ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    ax2.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
    ax2.tick_params(axis='x', rotation=45)

    fig.tight_layout()

    # savefig does not create missing directories, so make sure the target exists
    os.makedirs("output_diagrams", exist_ok=True)
    fig.savefig(os.path.join("output_diagrams", 'pm25_timeseries.png'), dpi=600)
    plt.close(fig)

    # Return daily and monthly averages for further analysis
    return daily_avg, monthly_avg

# 4. Seasonal Statistics
def seasonal_statistics(df):
    """
    Summarise PM2.5 per season and save a seasonal box plot.

    The 'season' column is mapped through {0: Winter, 1: Spring, 2: Summer,
    3: Autumn}; rows whose code is not in that mapping get no season name and
    therefore do not contribute to the table.
    The box plot is written to output_diagrams/pm25_seasonal_boxplot.png.

    Parameters:
    - df: DataFrame with 'season' and 'pm2_5' columns (not modified)

    Returns:
    - DataFrame with columns Season, Count, Mean, Std, Min, Median, Max,
      5th Percentile, 95th Percentile, ordered Winter -> Autumn
    """
    season_order = ['Winter', 'Spring', 'Summer', 'Autumn']
    season_names = dict(enumerate(season_order))

    df_copy = df.copy()  # keep the caller's frame free of the helper column
    df_copy['season_name'] = df_copy['season'].map(season_names)

    seasonal_stats = df_copy.groupby('season_name')['pm2_5'].agg([
        'count', 'mean', 'std', 'min', 'median', 'max',
        lambda x: x.quantile(0.05),  # 5th percentile
        lambda x: x.quantile(0.95)   # 95th percentile
    ]).reset_index()

    seasonal_stats.columns = ['Season', 'Count', 'Mean', 'Std', 'Min', 'Median', 'Max',
                              '5th Percentile', '95th Percentile']

    # groupby sorts the labels alphabetically; restore the same seasonal order
    # the box plot uses so table and figure read the same way
    seasonal_stats = seasonal_stats.sort_values(
        'Season', key=lambda names: names.map(season_order.index)
    ).reset_index(drop=True)

    # Plotting seasonal distributions
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.boxplot(x='season_name', y='pm2_5', data=df_copy, order=season_order, ax=ax)
    ax.set_xlabel('Season')
    ax.set_ylabel('PM2.5 (μg/m³)')
    ax.grid(True, alpha=0.3)

    # savefig does not create missing directories, so make sure the target exists
    os.makedirs("output_diagrams", exist_ok=True)
    fig.savefig(os.path.join("output_diagrams", 'pm25_seasonal_boxplot.png'), dpi=600)
    plt.close(fig)

    return seasonal_stats

# 5. Monthly Statistics
def monthly_statistics(df):
    """
    Summarise PM2.5 per calendar month and save a bar chart of the monthly means.

    Months are pooled across years (the grouping key is the month number only).
    The chart is written to output_diagrams/pm25_monthly_averages.png.

    Parameters:
    - df: DataFrame with a datetime 'date' column and a 'pm2_5' column (not modified)

    Returns:
    - DataFrame with columns Month, Count, Mean, Std, Min, Median, Max,
      one row per month present in the data, in calendar order
    """
    month_names = {1: 'Jan', 2: 'Feb', 3: 'Mar', 4: 'Apr', 5: 'May',
                   6: 'Jun', 7: 'Jul', 8: 'Aug', 9: 'Sep',
                   10: 'Oct', 11: 'Nov', 12: 'Dec'}
    month_order = [month_names[i] for i in range(1, 13)]

    df_copy = df.copy()  # keep the caller's frame free of the helper column
    df_copy['month'] = df_copy['date'].dt.month

    # Group on the month NUMBER so rows come out Jan..Dec; grouping on the
    # abbreviation would sort them alphabetically (Apr, Aug, Dec, ...)
    monthly_stats = df_copy.groupby('month')['pm2_5'].agg([
        'count', 'mean', 'std', 'min', 'median', 'max'
    ]).reset_index()
    monthly_stats['month'] = monthly_stats['month'].map(month_names)

    # Renaming columns for clarity
    monthly_stats.columns = ['Month', 'Count', 'Mean', 'Std', 'Min', 'Median', 'Max']

    # Plotting average values by month
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.barplot(x='Month', y='Mean', data=monthly_stats,
                order=month_order, color='steelblue', ax=ax)
    ax.set_xlabel('Month')
    ax.set_ylabel('Average PM2.5 (μg/m³)')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(True, alpha=0.3)

    # savefig does not create missing directories, so make sure the target exists
    os.makedirs("output_diagrams", exist_ok=True)
    fig.savefig(os.path.join("output_diagrams", 'pm25_monthly_averages.png'), dpi=600)
    plt.close(fig)

    return monthly_stats

# Run the descriptive analysis end to end
stats_summary = basic_pm25_statistics(df)
print("Basic PM2.5 Statistics:")
print(stats_summary)

# Each of these calls saves a figure under output_diagrams/
plot_pm25_distribution(df)
daily_avg, monthly_avg = plot_pm25_timeseries(df)
seasonal_stats = seasonal_statistics(df)
monthly_stats = monthly_statistics(df)

# Print the seasonal and monthly tables, each under its own heading
for heading, table in (
    ("\nSeasonal Statistics:", seasonal_stats),
    ("\nMonthly Statistics:", monthly_stats),
):
    print(heading)
    print(table)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.dates as mdates
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

def visualize_who_exceedances(df, pm25_column='pm2_5', date_column='date', who_daily_threshold=15, 
                             output_file=None, figsize=(20, 15)):
    """
    Visualizes exceedances of the WHO daily standard for PM2.5 and outputs statistics.

    One day-by-hour heatmap is drawn per calendar month between the first and
    last record. Each cell is coloured by exceedance band:
    0 = up to 1x the threshold (inclusive), 1 = (1x, 2x], 2 = (2x, 4x],
    3 = (4x, 8x], 4 = above 8x. Readings of exactly 0 fall in band 0; negative
    or missing readings get no band and are left blank.

    Parameters:
    - df: DataFrame with PM2.5 data (not modified; a copy is used)
    - pm25_column: name of the column with PM2.5 data
    - date_column: name of the column with dates
    - who_daily_threshold: WHO threshold value (default 15 μg/m³)
    - output_file: file name for saving (if None, the image is not saved)
    - figsize: figure size for the default 3x4 grid; the height is scaled up
      proportionally when more than 12 months need to be shown

    Returns:
    - a dictionary with exceedance statistics: total_hours, valid_hours,
      below_threshold, exceedance_1x/2x/4x/8x_plus, max_value, monthly_stats
      (keyed by 'YYYY-MM') and exceedance_hours_percent

    Raises:
    - ValueError: if the date column contains no valid dates
    """
    import os  # local import: this notebook cell does not import os itself

    category_keys = ["below_threshold", "exceedance_1x", "exceedance_2x",
                     "exceedance_4x", "exceedance_8x_plus"]

    def _band_counts(frame):
        # Number of rows in each exceedance band (band code == position in category_keys)
        return {key: int((frame['exceedance_category'] == code).sum())
                for code, key in enumerate(category_keys)}

    def _exceedance_percent(counts, valid_hours):
        # Share of valid hours that fall in any band above the threshold
        if not valid_hours:
            return float('nan')
        exceeded = sum(counts[key] for key in category_keys[1:])
        return exceeded / valid_hours * 100

    # Data preparation
    df = df.copy()
    df[date_column] = pd.to_datetime(df[date_column])
    if not df[date_column].notna().any():
        raise ValueError("No valid dates available to build the monthly heatmaps")
    df['year'] = df[date_column].dt.year
    df['month'] = df[date_column].dt.month
    df['day'] = df[date_column].dt.day
    df['hour'] = df[date_column].dt.hour

    # Band codes 0..4 as plain numbers. include_lowest=True keeps readings of
    # exactly 0 in band 0; without it they would be left uncategorised.
    bin_edges = [0, who_daily_threshold, 2 * who_daily_threshold,
                 4 * who_daily_threshold, 8 * who_daily_threshold, float('inf')]
    df['exceedance_category'] = pd.cut(
        df[pm25_column], bins=bin_edges, labels=False, include_lowest=True
    )

    # First day (midnight) of every month from the first to the last record
    first_month = df[date_column].min().normalize().replace(day=1)
    last_month = df[date_column].max().normalize().replace(day=1)
    month_range = pd.date_range(start=first_month, end=last_month, freq="MS")

    # Discrete colormap: one colour per band
    colors = ['lightblue', 'orange', 'red', 'purple', 'indigo']
    exceedance_cmap = ListedColormap(colors)

    # Grid: the original 3x4 layout, extended by whole rows when the data spans
    # more than 12 months so that no month is silently dropped
    ncols = 4
    nrows = max(3, -(-len(month_range) // ncols))
    if nrows > 3:
        figsize = (figsize[0], figsize[1] * nrows / 3)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    axes = axes.flatten()

    # Overall statistics
    valid_hours = int(df[pm25_column].notna().sum())
    stats = {"total_hours": len(df), "valid_hours": valid_hours}
    stats.update(_band_counts(df))
    stats["max_value"] = df[pm25_column].max()
    stats["monthly_stats"] = {}
    stats["exceedance_hours_percent"] = _exceedance_percent(stats, valid_hours)

    # Plotting heatmaps for each month
    for ax, month_start in zip(axes, month_range):
        month_name = month_start.strftime("%Y-%m")

        # Filtering data for the current month
        month_data = df[
            (df['year'] == month_start.year) &
            (df['month'] == month_start.month)
        ]

        days_in_month = month_start.days_in_month
        hourly_data = np.full((days_in_month, 24), np.nan)

        # Scatter the band codes into the (day, hour) grid in one assignment
        day_idx = month_data['day'].to_numpy().astype(int) - 1
        hour_idx = month_data['hour'].to_numpy().astype(int)
        hourly_data[day_idx, hour_idx] = month_data['exceedance_category'].to_numpy(dtype=float)

        # Calculating monthly statistics (only for months that have valid readings)
        valid_data = month_data[month_data[pm25_column].notna()]
        if not valid_data.empty:
            month_stats = {"valid_hours": len(valid_data)}
            month_stats.update(_band_counts(valid_data))
            month_stats["max_value"] = valid_data[pm25_column].max()
            month_stats["exceedance_hours_percent"] = _exceedance_percent(
                month_stats, month_stats["valid_hours"]
            )
            stats["monthly_stats"][month_name] = month_stats

        # Plotting the heatmap; vmin/vmax pin band k to colour k
        im = ax.imshow(hourly_data, aspect='auto', cmap=exceedance_cmap,
                       interpolation='nearest', vmin=0, vmax=4)

        # Setting up the X-axis (hours)
        ax.set_xticks(np.arange(0, 24, 6))
        ax.set_xticklabels([f"{h}:00" for h in range(0, 24, 6)])
        ax.set_xlabel("Hour of day")

        # Setting up the Y-axis (days): weekly ticks for normal-length months
        day_ticks = range(0, days_in_month, 7) if days_in_month > 15 else range(days_in_month)
        ax.set_yticks(np.array(day_ticks))
        ax.set_yticklabels([str(day + 1) for day in day_ticks])
        ax.set_ylabel("Day of month")

        # Title with additional information
        if month_name in stats["monthly_stats"]:
            month_exc_percent = stats["monthly_stats"][month_name]["exceedance_hours_percent"]
            ax.set_title(f"{month_name}\nWHO Exceedances: {month_exc_percent:.1f}%", fontsize=10)
        else:
            ax.set_title(f"{month_name}", fontsize=10)

        # Add grid
        ax.grid(which="minor", color="w", linestyle='-', linewidth=0.5, alpha=0.2)
        ax.minorticks_on()

    # Hide empty axes
    for ax in axes[len(month_range):]:
        ax.axis("off")

    # Lay out the month panels first; the manually placed colour-bar axes
    # added afterwards sit in the strip reserved on the right
    plt.tight_layout(rect=[0, 0, 0.9, 0.95])

    # Add a color bar; ticks sit at the centre of each of the five colour bands
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_ticks([0.4, 1.2, 2, 2.8, 3.6])
    cbar.set_ticklabels([
        f'<{who_daily_threshold} μg/m³', 
        f'{who_daily_threshold}-{2*who_daily_threshold} μg/m³', 
        f'{2*who_daily_threshold}-{4*who_daily_threshold} μg/m³', 
        f'{4*who_daily_threshold}-{8*who_daily_threshold} μg/m³', 
        f'>{8*who_daily_threshold} μg/m³'
    ])

    # Saving the image
    if output_file:
        target_dir = os.path.dirname(output_file)
        if target_dir:
            # savefig does not create missing directories
            os.makedirs(target_dir, exist_ok=True)
        plt.savefig(output_file, dpi=600, bbox_inches='tight')
        print(f"Figure saved to {output_file}")

    plt.show()

    return stats

# Example usage
exceedance_stats = visualize_who_exceedances(
    df,  # DataFrame with imputed values
    pm25_column='pm2_5', 
    date_column='date',
    who_daily_threshold=15,
    output_file=os.path.join("output_diagrams", 'who_exceedances_heatmap.png')
)

# Output statistics for the article text
valid_hours = exceedance_stats['valid_hours']

# Count exceedance hours from the bands themselves so the count always agrees
# with exceedance_hours_percent (valid - below would also include any hours
# that were not assigned to a band)
exceedance_hours = sum(
    exceedance_stats[key]
    for key in ('exceedance_1x', 'exceedance_2x', 'exceedance_4x', 'exceedance_8x_plus')
)

print(f"Total hours analyzed: {valid_hours}")
print(f"Hours below WHO threshold: {exceedance_stats['below_threshold']} ({exceedance_stats['below_threshold']/valid_hours*100:.1f}%)")
print(f"Hours with exceedances: {exceedance_hours} ({exceedance_stats['exceedance_hours_percent']:.1f}%)")
for band_label, band_key in (
    ("1-2x exceedances", 'exceedance_1x'),
    ("2-4x exceedances", 'exceedance_2x'),
    ("4-8x exceedances", 'exceedance_4x'),
    ("Over 8x exceedances", 'exceedance_8x_plus'),
):
    print(f"{band_label}: {exceedance_stats[band_key]} ({exceedance_stats[band_key]/valid_hours*100:.1f}%)")
print(f"Maximum PM2.5 concentration: {exceedance_stats['max_value']:.1f} μg/m³")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap

# PM2.5 AQI bands: (C_lo, C_hi, I_lo, I_hi, category, colour).
# Concentrations are stored in TENTHS of µg/m³ (120 == 12.0 µg/m³) so that band
# membership is decided with exact integer comparisons.
# NOTE(review): these are the breakpoints this notebook already used
# (12.0 / 35.4 / 55.4 / 150.4 / 250.4 / 350.4 / 500.4); EPA is understood to
# have revised the PM2.5 breakpoints in 2024 — confirm which table the study
# should follow. The seventh category name is this notebook's own label.
_PM25_AQI_BANDS = (
    (0, 120, 0, 50, 'Good', '#00E400'),                                # Green
    (121, 354, 51, 100, 'Moderate', '#FFFF00'),                        # Yellow
    (355, 554, 101, 150, 'Unhealthy for Sensitive Groups', '#FF7E00'),  # Orange
    (555, 1504, 151, 200, 'Unhealthy', '#FF0000'),                     # Red
    (1505, 2504, 201, 300, 'Very Unhealthy', '#99004C'),               # Purple
    (2505, 3504, 301, 400, 'Hazardous', '#7E0023'),                    # Maroon
    (3505, 5004, 401, 500, 'Extremely Hazardous', '#663300'),          # Brown
)

# Values at or above this concentration are beyond the last band
_PM25_AQI_CEILING = 500.5


def _pm25_to_aqi(pm25):
    """Return (AQI, category, colour) for one PM2.5 value, or NaNs if undefined."""
    if pd.isna(pm25) or pm25 < 0:
        return np.nan, np.nan, np.nan

    last_band = _PM25_AQI_BANDS[-1]
    if pm25 >= _PM25_AQI_CEILING:
        # Above the scale: report one past the top AQI with the last category
        return last_band[3] + 1, last_band[4], last_band[5]

    # Truncate to one decimal place, expressed in integer tenths. The tiny
    # epsilon absorbs binary rounding noise (e.g. 35.4 * 10 -> 353.999...).
    tenths = int(np.floor(pm25 * 10 + 1e-9))

    for c_lo, c_hi, i_lo, i_hi, category, colour in _PM25_AQI_BANDS:
        if c_lo <= tenths <= c_hi:
            # Linear interpolation inside the band, between ITS OWN end points
            aqi = i_lo + (i_hi - i_lo) * (tenths - c_lo) / (c_hi - c_lo)
            return int(round(aqi)), category, colour

    return np.nan, np.nan, np.nan


def calculate_us_epa_aqi(df, pm25_column='pm2_5'):
    """
    Calculates the Air Quality Index (AQI) according to the US EPA methodology based on PM2.5 concentrations.

    Each concentration is truncated to 0.1 µg/m³, matched to its band and
    linearly interpolated between that band's own concentration and AQI end
    points, so the AQI value always lies inside the range of the reported
    category (e.g. 12.0 µg/m³ -> AQI 50 'Good', 12.1 µg/m³ -> AQI 51 'Moderate').

    Missing or negative concentrations yield NaN in all three output columns.
    Concentrations of 500.5 µg/m³ or more are reported as AQI 501 in the last
    category.

    Parameters:
    - df: DataFrame with PM2.5 data (not modified; a copy is returned)
    - pm25_column: name of the column with PM2.5 concentrations (µg/m³)

    Returns:
    - DataFrame with added 'AQI', 'AQI_Category' and 'AQI_Color' columns
    """
    df = df.copy()

    aqi_values = []
    aqi_cats = []
    aqi_color_values = []
    for pm25 in df[pm25_column]:
        aqi, category, colour = _pm25_to_aqi(pm25)
        aqi_values.append(aqi)
        aqi_cats.append(category)
        aqi_color_values.append(colour)

    # Adding results to the DataFrame
    df['AQI'] = aqi_values
    df['AQI_Category'] = aqi_cats
    df['AQI_Color'] = aqi_color_values

    return df

def analyze_aqi_distribution(df):
    """
    Summarise AQI values and AQI categories, overall and per month.

    Parameters:
    - df: DataFrame with 'AQI' and 'AQI_Category' columns; per-month figures
      are produced only when a 'date' column is also present

    Returns:
    - a dictionary with category counts and percentages, mean/median/min/max
      AQI, 'monthly_stats' (mean/min/max AQI per 'YYYY-MM') and
      'monthly_categories' (one dict of counts and percentages per month);
      the two monthly entries are None when there is no 'date' column
    """
    aqi = df['AQI']
    categories = df['AQI_Category']

    summary = {
        'category_counts': categories.value_counts(),
        'category_percentages': categories.value_counts(normalize=True) * 100,
        'mean_aqi': aqi.mean(),
        'median_aqi': aqi.median(),
        'min_aqi': aqi.min(),
        'max_aqi': aqi.max(),
        'monthly_stats': None,
        'monthly_categories': None,
    }

    # Without dates there is nothing to break down by month
    if 'date' not in df.columns:
        return summary

    # Label every row with its 'YYYY-MM' month on a copy of the input
    dated = df.copy()
    dated['month'] = pd.to_datetime(df['date']).dt.strftime('%Y-%m')
    summary['monthly_stats'] = dated.groupby('month')['AQI'].agg(['mean', 'min', 'max'])

    # Category breakdown per month, in order of first appearance
    breakdown = []
    for label in dated['month'].unique():
        month_categories = dated.loc[dated['month'] == label, 'AQI_Category']
        breakdown.append({
            'month': label,
            'counts': month_categories.value_counts(),
            'percentages': month_categories.value_counts(normalize=True) * 100,
        })
    summary['monthly_categories'] = breakdown

    return summary

def visualize_aqi_distribution(df, stats, output_file=None):
    """
    Visualizes the distribution of AQI categories and values.

    Figure 1: a pie chart of category shares ('Very Unhealthy' and 'Hazardous'
    are merged into one slice) next to a histogram of AQI values annotated
    with mean/median/min/max.
    Figure 2 (only when monthly statistics are available): a heatmap of the
    percentage of each category per month.

    Parameters:
    - df: DataFrame with calculated AQI ('AQI', 'AQI_Category', 'AQI_Color' columns)
    - stats: statistics obtained from analyze_aqi_distribution
    - output_file: path to save the image; the monthly heatmap is saved next to
      it with '_monthly_distribution' inserted before the file extension
    """
    import os  # local import: this notebook cell does not import os itself

    category_sequence = [
        'Good', 'Moderate', 'Unhealthy for Sensitive Groups',
        'Unhealthy', 'Very Unhealthy', 'Hazardous', 'Extremely Hazardous'
    ]
    merged_label = "Very Unhealthy & Hazardous"
    merged_members = ("Very Unhealthy", "Hazardous")

    def _first_color(category_names):
        # Colour of the first listed category that actually occurs in df
        for name in category_names:
            matches = df.loc[df['AQI_Category'] == name, 'AQI_Color']
            if not matches.empty:
                return matches.iloc[0]
        return None

    if output_file:
        target_dir = os.path.dirname(output_file)
        if target_dir:
            # savefig does not create missing directories
            os.makedirs(target_dir, exist_ok=True)

    # Create a figure with two side-by-side plots
    fig = plt.figure(figsize=(18, 8))

    # 1. Pie chart for the overall distribution of AQI categories
    ax1 = plt.subplot2grid((1, 2), (0, 0))

    # Sum the counts per slice, folding the merged categories together
    grouped_counts = {}
    for category, count in stats['category_counts'].items():
        slice_name = merged_label if category in merged_members else category
        grouped_counts[slice_name] = grouped_counts.get(slice_name, 0) + count

    # Fixed slice order for the pie and its legend; every category that can
    # occur is listed so none is silently left out of the chart
    slice_order = [
        "Good",
        "Moderate",
        "Unhealthy for Sensitive Groups",
        "Unhealthy",
        merged_label,
        "Extremely Hazardous",
    ]
    categories = [name for name in slice_order if name in grouped_counts]
    counts = [grouped_counts[name] for name in categories]

    # Slice colours come from the AQI_Color column; the merged slice prefers the
    # 'Very Unhealthy' colour and falls back to 'Hazardous' if only that occurs
    colors = [
        _first_color(merged_members) if name == merged_label else _first_color((name,))
        for name in categories
    ]

    # Create a pie chart; labels go in the legend to prevent overlap
    wedges, _, autotexts = ax1.pie(
        counts, 
        labels=None,
        autopct='%1.1f%%',
        startangle=90,
        colors=colors,
        wedgeprops={'edgecolor': 'w', 'linewidth': 1}
    )

    # Customize the appearance of percentages
    for autotext in autotexts:
        autotext.set_fontsize(11)
        autotext.set_color('black')
        autotext.set_fontweight('bold')

    # Add the legend separately from the chart in the correct order
    ax1.legend(wedges, categories, loc="upper right", fontsize=12, bbox_to_anchor=(1.4, 1), 
               frameon=True, framealpha=0.8)

    ax1.set_title('Distribution of AQI Categories', fontsize=16)

    # 2. Histogram of AQI values
    ax2 = plt.subplot2grid((1, 2), (0, 1))
    sns.histplot(df['AQI'].dropna(), bins=20, kde=True, ax=ax2)
    ax2.set_xlabel('AQI Value', fontsize=14)
    ax2.set_ylabel('Frequency', fontsize=14)
    ax2.set_title('Distribution of AQI Values', fontsize=16)
    ax2.tick_params(axis='both', which='major', labelsize=12)

    # Add an annotation with statistics
    stats_text = (f"Mean AQI: {stats['mean_aqi']:.1f}\n"
                  f"Median AQI: {stats['median_aqi']:.1f}\n"
                  f"Min AQI: {stats['min_aqi']:.1f}\n"
                  f"Max AQI: {stats['max_aqi']:.1f}")
    ax2.annotate(stats_text, xy=(0.85, 0.95), xycoords='axes fraction',
                 horizontalalignment='center', verticalalignment='top', fontsize=12, 
                 bbox=dict(boxstyle="round,pad=0.3", fc="w", ec="k", alpha=0.8))

    # Adjusting the overall layout
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    # Saving the image
    if output_file:
        plt.savefig(output_file, dpi=600, bbox_inches='tight')
        print(f"Figure saved to {output_file}")

    plt.show()

    # Additionally: heatmap of the monthly category distribution
    if stats['monthly_categories'] is not None:
        month_index = stats['monthly_stats'].index
        fig, ax = plt.subplots(figsize=(16, len(month_index) * 0.8))

        # Columns: every category seen in any month, in severity order
        # (unknown names, if any, go last)
        seen_categories = set()
        for item in stats['monthly_categories']:
            seen_categories.update(item['counts'].index)

        def _severity_rank(name):
            return category_sequence.index(name) if name in category_sequence else 999

        all_categories = sorted(seen_categories, key=_severity_rank)

        # Row labels: 'YYYY-MM' -> 'Mon YY'
        month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 
                       'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
        month_abbr = []
        for month in month_index:
            year, m = month.split('-')
            month_abbr.append(f"{month_names[int(m)-1]} {year[-2:]}")

        # Cell values: percentage of each category per month (0 when absent)
        by_month = {item['month']: item for item in stats['monthly_categories']}
        heatmap_data = []
        for month in month_index:
            month_entry = by_month.get(month)
            heatmap_data.append([
                month_entry['percentages'].get(category, 0) if month_entry else 0
                for category in all_categories
            ])

        # Increase the font size of axis labels
        ax.tick_params(axis='both', which='major', labelsize=11)
        # Create a heatmap
        sns.heatmap(heatmap_data, annot=True, fmt='.1f', cmap='YlOrRd', 
                    xticklabels=all_categories, yticklabels=month_abbr, ax=ax, annot_kws={"size": 12})

        plt.tight_layout()

        if output_file:
            # Insert the suffix before the extension only; replacing every '.'
            # would also rewrite dots in directory names, and a path without an
            # extension would overwrite the first figure
            root, ext = os.path.splitext(output_file)
            monthly_output = f"{root}_monthly_distribution{ext}"
            plt.savefig(monthly_output, dpi=600, bbox_inches='tight')

        plt.show()

# Main block of code for data processing and visualization
def analyze_air_quality_index(df, pm25_column='pm2_5', date_column='date', output_file=None):
    """
    Comprehensive analysis of the air quality index based on PM2.5.

    Calculates the AQI, derives distribution statistics, draws the figures and
    prints a short summary.

    Parameters:
    - df: DataFrame with PM2.5 data
    - pm25_column: name of the column with PM2.5 concentrations
    - date_column: name of the column with dates, used for the monthly breakdown
    - output_file: path for saving visualizations

    Returns:
    - (aqi_stats, df_with_aqi): the statistics dictionary and the input data
      with the AQI columns added
    """
    # Calculate AQI
    df_with_aqi = calculate_us_epa_aqi(df, pm25_column)

    # The distribution analysis looks for a column literally named 'date';
    # expose the caller's date column under that name on a working copy so the
    # date_column argument is honoured without altering the returned frame
    analysis_df = df_with_aqi
    if date_column != 'date' and date_column in df_with_aqi.columns:
        analysis_df = df_with_aqi.assign(date=df_with_aqi[date_column])

    # Analyze AQI distribution
    aqi_stats = analyze_aqi_distribution(analysis_df)

    # Visualize results
    visualize_aqi_distribution(analysis_df, aqi_stats, output_file)

    # Output main statistics for use in the article text
    print("\nAir Quality Index (AQI) Analysis Results:")
    print(f"Mean AQI: {aqi_stats['mean_aqi']:.1f}")
    print(f"Median AQI: {aqi_stats['median_aqi']:.1f}")
    print(f"Minimum AQI: {aqi_stats['min_aqi']:.1f}")
    print(f"Maximum AQI: {aqi_stats['max_aqi']:.1f}")
    print("\nAQI Category Distribution:")

    category_percentages = aqi_stats['category_percentages']
    for category, count in aqi_stats['category_counts'].items():
        percentage = category_percentages[category]
        print(f"{category}: {count} hours ({percentage:.1f}%)")

    return aqi_stats, df_with_aqi

# Run the AQI analysis on the imputed data and save the main figure
aqi_figure_path = os.path.join("output_diagrams", 'air_quality_index_analysis.png')
aqi_stats, df_with_aqi = analyze_air_quality_index(
    df,
    pm25_column='pm2_5',
    date_column='date',
    output_file=aqi_figure_path,
)

# Keep the AQI-annotated data on disk for later analysis steps
df_with_aqi.to_csv('data_with_aqi.csv', index=False)