This Jupyter Notebook loads CMIP6 CESM2-WACCM temperature (`tas`) data, identifies heatwave events at the grid point nearest Los Angeles, and analyzes their frequency, duration, and intensity. It produces visualizations that compare the historical period (1985–2014) with two future scenarios (SSP2-4.5 and SSP5-8.5, 2015–2100).

In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import calendar
import os
from scipy import stats

# ---------------------------------------------------------------------------
# Analysis configuration
# ---------------------------------------------------------------------------
# Named latitude bands as (south, north) in degrees. main() reads only the
# "Western US" band, using it as the latitude slice of the regional crop;
# the "Los Angeles" entry is not read by the code in this notebook.
LOCATIONS = dict([
    ("Western US", (30, 49)),
    ("Los Angeles", (33.9, 34.3)),
])
# Climate model(s) this notebook is set up for (the dataset paths in main()
# are hard-coded for this model).
MODELS = ["CESM2-WACCM"]
# (start, end) date strings bounding the baseline and projection windows.
HISTORICAL_PERIOD = "1985-01-01", "2014-12-31"
FUTURE_PERIOD = "2015-01-01", "2100-12-31"

# Every figure is written below ./figures, so make sure it exists up front.
os.makedirs(name='figures', exist_ok=True)

# Function to extract data for a region
def extract_cmip6_data(dataset, var, lat_range, lon_range):
    """Return ``dataset[var]`` cropped to a latitude/longitude box.

    Args:
        dataset: xarray Dataset with ``lat`` and ``lon`` coordinates.
        var: Name of the variable to extract (e.g. ``'tas'``).
        lat_range: ``(south, north)`` latitude bounds in degrees.
        lon_range: ``(west, east)`` longitude bounds in degrees on the
            -180..180 convention (e.g. ``(-125, -100)``).

    Returns:
        The cropped DataArray.
    """
    # Convert a 0..360 longitude axis to -180..180 so western-hemisphere
    # bounds such as (-125, -100) can be used directly.
    if dataset.lon.max() > 180:
        dataset = dataset.assign_coords(lon=(((dataset.lon + 180) % 360) - 180))
        dataset = dataset.sortby('lon')
    # Label-based slices only select data along an ascending coordinate; a
    # north-to-south latitude axis would otherwise give an empty selection.
    if dataset.lat.size > 1 and float(dataset.lat[0]) > float(dataset.lat[-1]):
        dataset = dataset.sortby('lat')
    return dataset[var].sel(lat=slice(*lat_range), lon=slice(*lon_range))

# Function to find nearest point to Los Angeles in the dataset
def find_la_point(data_array, target_lat=34.1, target_lon=-118.3):
    """Find the grid point nearest to a target location (Los Angeles by default).

    The nearest latitude and the nearest longitude are chosen independently
    from the grid's coordinate vectors.

    Args:
        data_array: Object exposing ``lat`` and ``lon`` coordinates with a
            ``.values`` array (e.g. an xarray DataArray).
        target_lat: Target latitude in degrees (default: Los Angeles).
        target_lon: Target longitude in degrees (default: Los Angeles). Either
            the -180..180 or the 0..360 convention works, whatever the grid uses.

    Returns:
        ``(nearest_lat, nearest_lon)`` taken from the grid's own coordinate values.
    """
    lats = np.asarray(data_array.lat.values)
    lons = np.asarray(data_array.lon.values)

    nearest_lat = lats[np.abs(lats - target_lat).argmin()]

    # Measure longitude distance around the circle so that a 0..360 grid still
    # matches a -180..180 target (241.7 and -118.3 are the same meridian).
    lon_distance = np.abs((lons - target_lon + 180) % 360 - 180)
    nearest_lon = lons[lon_distance.argmin()]

    print(f"Using coordinates: lat={nearest_lat}, lon={nearest_lon}")
    return nearest_lat, nearest_lon

# Kelvin -> degrees Celsius conversion
def kelvin_to_celsius(temp_data):
    """Return *temp_data*, given in kelvin, expressed in degrees Celsius.

    This is a single subtraction, so it applies element-wise to scalars,
    NumPy arrays and xarray objects alike.
    """
    zero_celsius_in_kelvin = 273.15
    return temp_data - zero_celsius_in_kelvin

# Function to identify heatwaves
def identify_heatwaves(temp_df, threshold_percentile=95, min_duration=3,
                       monthly_thresholds=None):
    """
    Identify heatwaves in temperature data using a percentile threshold.

    A day is "hot" when its 'tas' exceeds the threshold for its calendar
    month. A heatwave is a run of at least ``min_duration`` consecutive rows
    (days) that are all hot.

    Args:
        temp_df: DataFrame with 'time' and 'tas' columns, one row per day in
            chronological order. Any index is accepted (it need not start at
            0), and the frame is not modified.
        threshold_percentile: Percentile used to derive monthly thresholds
            from ``temp_df`` itself (default: 95th percentile). Ignored when
            ``monthly_thresholds`` is given.
        min_duration: Minimum number of consecutive hot days to qualify as a
            heatwave.
        monthly_thresholds: Optional dict/Series mapping month number (1-12)
            to a threshold temperature. Pass thresholds computed on a
            baseline period to measure other periods against a fixed
            reference. By default each call uses thresholds from its own data.

    Returns:
        DataFrame with one row per heatwave and columns 'start_date',
        'end_date', 'duration', 'max_temp', 'mean_temp', 'year', 'month'
        (year and month of the start date), or an empty DataFrame if no
        heatwave is found.
    """
    if temp_df.empty:
        return pd.DataFrame()

    # Work positionally (index 0..n-1). The caller's frame is often a
    # filtered slice whose labels do not start at 0; label arithmetic on it
    # previously dropped a heatwave running to the end of the series.
    df = temp_df.reset_index(drop=True)
    # Attribute access works for both pandas Timestamps and cftime dates.
    months = df['time'].apply(lambda t: t.month)

    if monthly_thresholds is None:
        monthly_thresholds = df['tas'].groupby(months).quantile(threshold_percentile / 100)
    threshold = months.map(monthly_thresholds)

    # NaN temperatures or thresholds compare False, i.e. count as not hot.
    hot = df['tas'] > threshold
    # A new run id starts whenever the hot/not-hot state changes.
    run_id = hot.ne(hot.shift(fill_value=False)).cumsum()

    heatwave_events = []
    for _, event_data in df[hot].groupby(run_id[hot]):
        duration = len(event_data)
        if duration < min_duration:
            continue
        start_date = event_data['time'].iloc[0]
        heatwave_events.append({
            'start_date': start_date,
            'end_date': event_data['time'].iloc[-1],
            'duration': duration,
            'max_temp': event_data['tas'].max(),
            'mean_temp': event_data['tas'].mean(),
            'year': start_date.year,
            'month': start_date.month,
        })

    if not heatwave_events:
        return pd.DataFrame()
    return pd.DataFrame(heatwave_events)

# Function to analyze monthly heatwave characteristics
def analyze_monthly_heatwaves(heatwave_df, period_label, n_years=None):
    """Summarise heatwave events by calendar month.

    Args:
        heatwave_df: Output of ``identify_heatwaves`` (needs 'month', 'year',
            'duration' and 'max_temp' columns); may be empty.
        period_label: Value stored in the 'period' column (e.g. "Historical").
        n_years: Number of years in the analysed period, used to turn event
            counts into events per year. If omitted, the span between the
            first and last year that contain a heatwave is used instead. That
            over-states the rate when the period begins or ends with
            heatwave-free years.

    Returns:
        DataFrame indexed by month 1-12 with columns 'frequency' (events per
        year), 'avg_duration', 'avg_max_temp', 'month_name' and 'period'.
        Months without heatwaves get 0 in the numeric columns.

    Raises:
        ValueError: if ``n_years`` is given and not positive.
    """
    months = range(1, 13)
    numeric_columns = ['frequency', 'avg_duration', 'avg_max_temp']

    if heatwave_df.empty:
        monthly_stats = pd.DataFrame(0, index=months, columns=numeric_columns)
    else:
        if n_years is None:
            n_years = heatwave_df['year'].max() - heatwave_df['year'].min() + 1
        if n_years <= 0:
            raise ValueError(f"n_years must be positive, got {n_years}")
        by_month = heatwave_df.groupby('month')
        monthly_stats = pd.DataFrame({
            'frequency': by_month.size() / n_years,
            'avg_duration': by_month['duration'].mean(),
            'avg_max_temp': by_month['max_temp'].mean(),
        }).reindex(months).fillna(0)  # months with no heatwaves -> 0

    # Same column order on both paths so the CSV outputs line up.
    monthly_stats['month_name'] = [calendar.month_abbr[m] for m in monthly_stats.index]
    monthly_stats['period'] = period_label
    return monthly_stats

# Grouped bar chart of monthly heatwave frequency for the three scenarios
def plot_heatwave_frequency(hist_stats, ssp245_stats, ssp585_stats):
    """Save a grouped bar chart of heatwaves per year for each calendar month.

    Each ``*_stats`` argument is a 12-row monthly summary with a 'frequency'
    column, as returned by ``analyze_monthly_heatwaves``. Bars taller than
    0.05 are annotated with their value. The chart is written to
    ``figures/heatwave_frequency_comparison.png`` and the figure is closed.
    """
    bar_width = 0.25
    positions = np.arange(len(range(1, 13)))
    tick_labels = [calendar.month_abbr[month] for month in range(1, 13)]

    # (x offset, legend label, fill colour, per-month frequencies)
    series = [
        (-bar_width, 'Historical (1985-2014)', '#1f77b4', hist_stats['frequency'].values),
        (0, 'SSP245 (2015-2100)', '#ff7f0e', ssp245_stats['frequency'].values),
        (bar_width, 'SSP585 (2015-2100)', '#d62728', ssp585_stats['frequency'].values),
    ]

    plt.figure(figsize=(14, 8))
    for offset, label, colour, freqs in series:
        plt.bar(positions + offset, freqs, bar_width, label=label, color=colour, edgecolor='black')

    plt.xlabel('Month', fontsize=14)
    plt.ylabel('Average Number of Heatwaves per Year', fontsize=14)
    plt.title('Monthly Heatwave Frequency in Los Angeles: Historical vs Future Scenarios', fontsize=16)
    plt.xticks(positions, tick_labels)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Annotate only the bars that are tall enough to be worth reading.
    for offset, _label, _colour, freqs in series:
        for idx, value in enumerate(freqs):
            if not value > 0.05:
                continue
            plt.text(idx + offset, value + 0.02, f'{value:.2f}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig('figures/heatwave_frequency_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()
    
def plot_heatwave_characteristics(hist_stats, ssp245_stats, ssp585_stats):
    """Save a two-panel bar chart of monthly heatwave duration and peak temperature.

    The top panel shows 'avg_duration' (days) and the bottom panel shows
    'avg_max_temp', labelled in °C. The values are plotted as given; no unit
    conversion happens here. Each scenario gets one bar per month. The chart
    is written to ``figures/heatwave_characteristics_comparison.png`` and the
    figure is closed.
    """
    _, (top_ax, bottom_ax) = plt.subplots(2, 1, figsize=(14, 12), sharex=True)

    month_numbers = np.array(range(1, 13))
    bar_width = 0.25
    # (x offset, legend label, fill colour, monthly summary)
    scenarios = (
        (-bar_width, 'Historical (1985-2014)', '#1f77b4', hist_stats),
        (0, 'SSP245 (2015-2100)', '#ff7f0e', ssp245_stats),
        (bar_width, 'SSP585 (2015-2100)', '#d62728', ssp585_stats),
    )

    def draw_grouped_bars(ax, column):
        # One bar per scenario per month, offset around the month's tick.
        for offset, label, colour, summary in scenarios:
            ax.bar(month_numbers + offset, summary[column].values, width=bar_width,
                   label=label, color=colour, edgecolor='black', linewidth=0.5)

    draw_grouped_bars(top_ax, 'avg_duration')
    top_ax.set_ylabel('Average Duration (days)', fontsize=14)
    top_ax.set_title('Average Heatwave Duration by Month', fontsize=16)
    top_ax.grid(True, linestyle='--', alpha=0.7)
    top_ax.legend(loc='upper left')

    draw_grouped_bars(bottom_ax, 'avg_max_temp')
    bottom_ax.set_xlabel('Month', fontsize=14)
    bottom_ax.set_ylabel('Average Maximum Temperature (°C)', fontsize=14)
    bottom_ax.set_title('Average Heatwave Maximum Temperature by Month', fontsize=16)
    bottom_ax.set_xticks(month_numbers)
    bottom_ax.set_xticklabels([calendar.month_abbr[m] for m in range(1, 13)])
    bottom_ax.grid(True, linestyle='--', alpha=0.7)
    bottom_ax.legend(loc='upper left')

    plt.tight_layout()
    plt.savefig('figures/heatwave_characteristics_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()

# Function to analyze yearly trends in heatwave frequency
def analyze_yearly_trends(hist_hw_df, ssp245_hw_df, ssp585_hw_df, year_ranges=None):
    """Plot yearly heatwave counts and a linear trend line for each scenario.

    Args:
        hist_hw_df, ssp245_hw_df, ssp585_hw_df: Heatwave event tables from
            ``identify_heatwaves`` (need a 'year' column); may be empty.
        year_ranges: Optional dict mapping 'Historical' / 'SSP245' / 'SSP585'
            to an inclusive ``(first_year, last_year)`` span. Every year in
            the span is included, with count 0 when it has no events. Without
            it, zero-event years are filled in only between the first and
            last year that had an event.

    Saves ``figures/heatwave_yearly_trends.png``. When no scenario has data
    to plot, it prints a message and returns None without creating a figure.
    """
    year_ranges = year_ranges or {}

    def get_yearly_counts(df, span):
        # Count events per year INCLUDING years with zero events; leaving
        # those years out would bias the fitted trend upward.
        if df.empty and span is None:
            return pd.DataFrame()
        counts = df.groupby('year').size() if not df.empty else pd.Series(dtype='int64')
        if span is None:
            span = (counts.index.min(), counts.index.max())
        all_years = pd.Index(range(int(span[0]), int(span[1]) + 1), name='year')
        return counts.reindex(all_years, fill_value=0).reset_index(name='count')

    yearly_tables = []
    for scenario, hw_df in (('Historical', hist_hw_df),
                            ('SSP245', ssp245_hw_df),
                            ('SSP585', ssp585_hw_df)):
        yearly = get_yearly_counts(hw_df, year_ranges.get(scenario))
        if not yearly.empty:
            yearly['scenario'] = scenario
            yearly_tables.append(yearly)

    if not yearly_tables:
        print("No yearly data available to plot.")
        return

    combined_df = pd.concat(yearly_tables, ignore_index=True)

    plt.figure(figsize=(16, 8))

    colors = {'Historical': '#1f77b4', 'SSP245': '#ff7f0e', 'SSP585': '#d62728'}
    for scenario, color in colors.items():
        scenario_data = combined_df[combined_df['scenario'] == scenario]
        if scenario_data.empty:
            continue
        years = scenario_data['year'].values
        counts = scenario_data['count'].values

        plt.scatter(years, counts, color=color, alpha=0.6, label=f'{scenario} Data')

        if len(years) > 1:  # a regression needs at least two points
            fit = stats.linregress(years, counts)
            line_years = np.array([years.min(), years.max()])
            line_counts = fit.slope * line_years + fit.intercept
            plt.plot(line_years, line_counts, color=color, linewidth=2,
                     label=f'{scenario} Trend (Slope: {fit.slope:.4f} events/year)')

            # Annotate the trend expressed per decade at the line's right end.
            per_decade = fit.slope * 10
            plt.text(line_years[-1], line_counts[-1],
                     f"{'+' if per_decade > 0 else ''}{per_decade:.2f} events/decade",
                     ha='right', va='bottom', color=color, fontsize=12)

    plt.xlabel('Year', fontsize=14)
    plt.ylabel('Number of Heatwave Events', fontsize=14)
    plt.title('Yearly Heatwave Frequency in Los Angeles (1985-2100)', fontsize=16)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(loc='upper left')

    # Vertical line separating the historical and future periods.
    plt.axvline(x=2014.5, color='black', linestyle='--', alpha=0.7)
    plt.text(2014.5, plt.ylim()[1]*0.95, 'Historical | Future', ha='center', va='top', fontsize=12)

    plt.tight_layout()
    plt.savefig('figures/heatwave_yearly_trends.png', dpi=300, bbox_inches='tight')
    plt.close()

# Helper: load one scenario's Los Angeles temperature series
def _load_la_dataframe(path, period, time_coder, lat_range, lon_range):
    """Return the 'tas' series (°C) at the grid point nearest Los Angeles.

    Opens ``path``, crops 'tas' to the regional box, and picks the grid point
    nearest Los Angeles. It then restricts the series to ``period``, a
    ``(start, end)`` pair of date strings whose end day is included in full.
    The result is a DataFrame with 'time' and 'tas' columns, plus any scalar
    coordinates that were carried along. The file is closed before returning.
    """
    with xr.open_dataset(path, decode_times=time_coder) as ds:
        regional_tas = extract_cmip6_data(ds, 'tas', lat_range, lon_range)
        la_lat, la_lon = find_la_point(regional_tas)
        # Slice on the time index rather than comparing str(time) with a bare
        # date: '2014-12-31 00:00:00' > '2014-12-31', so the string test
        # silently dropped the last day of every period.
        la_tas = regional_tas.sel(lat=la_lat, lon=la_lon).sel(time=slice(*period))
        # to_dataframe() loads the values, so do it while the file is open.
        return kelvin_to_celsius(la_tas).to_dataframe(name='tas').reset_index()


# Main function
def main():
    """Run the Los Angeles heatwave analysis end to end.

    Loads the CESM2-WACCM 'tas' files for the historical, SSP245 and SSP585
    runs and detects heatwaves at the Los Angeles grid point. It then writes
    comparison figures under ``figures/`` and per-month summary CSVs to the
    working directory.
    """
    print("Loading temperature datasets...")
    time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)

    lat_range = LOCATIONS['Western US']
    lon_range = (-125, -100)  # Western US longitude range

    # label -> (input file, analysis period)
    scenarios = {
        "Historical": ("datasets/cmip6_tas_historical_CESM2-WACCM.nc", HISTORICAL_PERIOD),
        "SSP245": ("datasets/cmip6_tas_ssp245_CESM2-WACCM.nc", FUTURE_PERIOD),
        "SSP585": ("datasets/cmip6_tas_ssp585_CESM2-WACCM.nc", FUTURE_PERIOD),
    }

    la_dfs = {}
    for label, (path, period) in scenarios.items():
        print(f"Extracting Los Angeles temperature data ({label})...")
        la_dfs[label] = _load_la_dataframe(path, period, time_coder, lat_range, lon_range)

    # NOTE(review): identify_heatwaves derives its monthly percentile
    # thresholds from the data it is given. Each future scenario is therefore
    # measured against its own (warmer) climate, not the historical baseline.
    # Confirm this is intended before reading the plots as a
    # historical-vs-future comparison.
    heatwaves = {}
    for label, df in la_dfs.items():
        print(f"Identifying heatwaves in {label} data...")
        heatwaves[label] = identify_heatwaves(df)

    print("Analyzing monthly heatwave patterns...")
    monthly_stats = {label: analyze_monthly_heatwaves(hw_df, label)
                     for label, hw_df in heatwaves.items()}

    print("Creating heatwave frequency comparison plot...")
    plot_heatwave_frequency(monthly_stats["Historical"], monthly_stats["SSP245"], monthly_stats["SSP585"])

    print("Creating heatwave characteristics comparison plot...")
    plot_heatwave_characteristics(monthly_stats["Historical"], monthly_stats["SSP245"], monthly_stats["SSP585"])

    print("Creating yearly trend analysis plot...")
    analyze_yearly_trends(heatwaves["Historical"], heatwaves["SSP245"], heatwaves["SSP585"])

    # Save the monthly summaries for further analysis, e.g.
    # historical_monthly_heatwaves.csv, ssp245_monthly_heatwaves.csv, ...
    for label, summary in monthly_stats.items():
        summary.to_csv(f'{label.lower()}_monthly_heatwaves.csv')

    print("Analysis complete! Heatwave visualizations have been saved.")

# Run main function
if __name__ == "__main__":
    main()

Loading temperature datasets...
Extracting regional temperature data...
Finding Los Angeles coordinates in datasets...
Using coordinates: lat=34.3979057591623, lon=-118.75
Using coordinates: lat=34.3979057591623, lon=-118.75
Using coordinates: lat=34.3979057591623, lon=-118.75
Extracting LA point temperature data...
Converting temperature units from Kelvin to Celsius...
Converting to DataFrames...
Identifying heatwaves in historical data...
Identifying heatwaves in SSP245 scenario...
Identifying heatwaves in SSP585 scenario...
Analyzing monthly heatwave patterns...
Creating heatwave frequency comparison plot...
Creating heatwave characteristics comparison plot...
Creating yearly trend analysis plot...
Analysis complete! Heatwave visualizations have been saved.
