In [4]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, fixed
from IPython.display import display, clear_output
import scipy.stats as stats

# Function to load and preprocess the datasets
def load_datasets():
    ecmwf = xr.open_dataset(r"/Users/riya/Desktop/SAC/datasets/ecmwf datas.nc")
    precip = xr.open_dataset(r"/Users/riya/Desktop/SAC/datasets/precip_konkan.nc")
    sst = xr.open_dataset(r"/Users/riya/Desktop/SAC/datasets/noaa_sst_masked.nc")
   
    # Prepare variables dictionary
    variables = {}
   
    # Add ECMWF variables
    for var in ecmwf.data_vars:
        if var not in ['skt']:
             variables[f"{var}"] = ecmwf[var]
    # Add PRECIP variables
    for var in precip.data_vars:
        if var not in ['time_bnds', 'lat_bnds', 'lon_bnds']:  # Skip bounds variables
            variables[f"{var}"] = precip[var]
   
    # Add SST variables
    for var in sst.data_vars:
        variables[f"{var}"] = sst[var]
   
    return variables

# Function to compute area-averaged time series for a variable
def compute_area_average(variable):
    # Check if variable has time dimension
    if 'time' in variable.dims or 'valid_time' in variable.dims:
        # Determine time dimension name
        time_dim = 'time' if 'time' in variable.dims else 'valid_time'
       
        # Calculate area average - handle different coordinate names
        if 'lat' in variable.dims and 'lon' in variable.dims:
            return variable.mean(dim=['lat', 'lon'])
        elif 'latitude' in variable.dims and 'longitude' in variable.dims:
            return variable.mean(dim=['latitude', 'longitude'])
        else:
            print(f"Warning: Cannot compute area average for variable with dimensions {variable.dims}")
            return None
    else:
        print(f"Warning: Variable does not have a time dimension. Dimensions: {variable.dims}")
        return None

# Function to compute lagged correlation
def compute_lagged_correlation(ts1, ts2, max_lag=12):
    """
    
    ts1, ts2 : xarray DataArray
        Time series to correlate
  
    best_lag : int
        Lag that produces the maximum correlation
    """
    # Ensure the time series have compatible time coordinates
    # Convert to pandas Series for easier manipulation
    if isinstance(ts1, xr.DataArray) and isinstance(ts2, xr.DataArray):
        # Extract time dimension name
        time_dim1 = 'time' if 'time' in ts1.dims else 'valid_time'
        time_dim2 = 'time' if 'time' in ts2.dims else 'valid_time'
       
        # Convert to pandas Series
        s1 = ts1.to_series()
        s2 = ts2.to_series()
       
        # Reset indices if they're MultiIndex
        if isinstance(s1.index, pd.MultiIndex):
            s1 = s1.reset_index().set_index(time_dim1)
        if isinstance(s2.index, pd.MultiIndex):
            s2 = s2.reset_index().set_index(time_dim2)
       
        # Make sure the series are sorted by time
        s1 = s1.sort_index()
        s2 = s2.sort_index()
       
        # Resample to ensure consistent monthly frequency
        s1 = s1.resample('MS').mean()
        s2 = s2.resample('MS').mean()
       
        # Compute correlations for different lags
        corrs = []
        lags = list(range(-max_lag, max_lag + 1))
       
        for lag in lags:
            if lag < 0:
                # s1 leads s2
                s1_lagged = s1.shift(-lag)
                corr = s1_lagged.corr(s2)
            else:
                # s2 leads s1
                s2_lagged = s2.shift(lag)
                corr = s1.corr(s2_lagged)
           
            corrs.append(corr)
       
        # Find the lag with maximum correlation
        corrs = np.array(corrs)
        max_idx = np.nanargmax(np.abs(corrs))
        max_corr = corrs[max_idx]
        best_lag = lags[max_idx]
       
        return corrs, max_corr, best_lag
    else:
        print("Error: Input time series are not xarray DataArrays")
        return None, None, None

# Function to compute corrected correlation using the best lag
def compute_corrected_correlation(ts1, ts2, best_lag):
  
    # Convert to pandas Series
    time_dim1 = 'time' if 'time' in ts1.dims else 'valid_time'
    time_dim2 = 'time' if 'time' in ts2.dims else 'valid_time'
   
    s1 = ts1.to_series()
    s2 = ts2.to_series()
   
    # Reset indices if they're MultiIndex
    if isinstance(s1.index, pd.MultiIndex):
        s1 = s1.reset_index().set_index(time_dim1)
    if isinstance(s2.index, pd.MultiIndex):
        s2 = s2.reset_index().set_index(time_dim2)
   
    # Make sure the series are sorted by time
    s1 = s1.sort_index()
    s2 = s2.sort_index()
   
    # Resample to ensure consistent monthly frequency
    s1 = s1.resample('MS').mean()
    s2 = s2.resample('MS').mean()
   
    # Apply lag correction
    if best_lag < 0:
        # s1 leads s2
        s1_lagged = s1.shift(-best_lag)
        corrected_corr = s1_lagged.corr(s2)
    else:
        # s2 leads s1
        s2_lagged = s2.shift(best_lag)
        corrected_corr = s1.corr(s2_lagged)
   
    return corrected_corr

# Function to plot the results
def plot_correlation_results(ts1, ts2, var1_name, var2_name, corrs, lags, best_lag, max_corr, corrected_corr):
   
    if isinstance(ts1, xr.DataArray) and isinstance(ts2, xr.DataArray):
        time_dim1 = 'time' if 'time' in ts1.dims else 'valid_time'
        time_dim2 = 'time' if 'time' in ts2.dims else 'valid_time'
       
        s1 = ts1.to_series()
        s2 = ts2.to_series()
       
        # Reset indices if they're MultiIndex
        if isinstance(s1.index, pd.MultiIndex):
            s1 = s1.reset_index().set_index(time_dim1)
        if isinstance(s2.index, pd.MultiIndex):
            s2 = s2.reset_index().set_index(time_dim2)
       
        # Make sure the series are sorted by time
        s1 = s1.sort_index()
        s2 = s2.sort_index()
       
        # Resample to ensure consistent monthly frequency
        s1 = s1.resample('MS').mean()
        s2 = s2.resample('MS').mean()
    else:
        print("Error: Cannot plot, input time series are not xarray DataArrays")
        return

    # Determine which variable is affecting which based on the lag
    if best_lag < 0:
        cause_var = var1_name
        effect_var = var2_name
        lag_months = abs(best_lag)
    elif best_lag > 0:
        cause_var = var2_name
        effect_var = var1_name
        lag_months = best_lag
    else:  # best_lag == 0
        cause_var = "Neither"
        effect_var = "Both variables are synchronous"
        lag_months = 0
   
    # Print the causality statement
    if lag_months > 0:
        print(f"\nCausality Relationship: {cause_var} is affecting {effect_var} with a lag of {lag_months} months.")
    else:
        print(f"\nCausality Relationship: {effect_var}")
   
    # Create the figure with just the correlation plot
    plt.figure(figsize=(10, 6))
   
    # Plot lagged correlation
    plt.plot(lags, corrs, 'o-')
    plt.axvline(x=best_lag, color='r', linestyle='--', label=f'Best Lag: {best_lag}')
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    plt.title('Lagged Correlation')
    plt.xlabel('Lag (months)')
    plt.ylabel('Correlation')
    plt.grid(True)
    plt.legend()
   
    # Add text with correlation information
    text_info = (f"Original Correlation: {round(corrs[lags.index(0)], 3)}\n"
                f"Max Correlation: {round(max_corr, 3)}\n"
                f"Best Lag: {best_lag} months\n"
                f"Corrected Correlation: {round(corrected_corr, 3)}")
   
    plt.text(0.02, 0.02, text_info, transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.8))
   
    plt.tight_layout()
    plt.show()

# Create the main application function
def correlation_analysis_app():

    try:
        variables = load_datasets()
       
        var_names = list(variables.keys())
       
        # Create widgets
        var1_dropdown = widgets.Dropdown(
            options=var_names,
            description='Variable 1:',
            disabled=False,
            style={'description_width': 'initial'}
        )
       
        var2_dropdown = widgets.Dropdown(
            options=var_names,
            description='Variable 2:',
            disabled=False,
            style={'description_width': 'initial'}
        )
       
        output = widgets.Output()
       
        # Create compute button
        compute_button = widgets.Button(
            description='Compute',
            button_style='success',
            tooltip='Click to compute correlation analysis'
        )
       
        # Define the function to run when the button is clicked
        def on_compute_button_clicked(b):
            with output:
                clear_output()
               
                # Get selected variables
                var1_name = var1_dropdown.value
                var2_name = var2_dropdown.value
            
               
                # Get the variable data
                var1 = variables[var1_name]
                var2 = variables[var2_name]
               
                # Compute area averages
                ts1 = compute_area_average(var1)
                ts2 = compute_area_average(var2)
               
                if ts1 is None or ts2 is None:
                    print("Error: Could not compute area averages for the selected variables.")
                    return
               
                # Compute lagged correlation
                max_lag = 12  # Fixed max lag of 12 months
                lags = list(range(-max_lag, max_lag + 1))
                corrs, max_corr, best_lag = compute_lagged_correlation(ts1, ts2, max_lag)
               
                if corrs is None:
                    print("Error: Could not compute lagged correlation.")
                    return
               
                # Compute corrected correlation
                corrected_corr = compute_corrected_correlation(ts1, ts2, best_lag)
               
                # Create a table with the results
                results = pd.DataFrame({
                    'Variable 1': [var1_name],
                    'Variable 2': [var2_name],
                    'Original Correlation': [corrs[lags.index(0)]],
                    'Max Correlation': [max_corr],
                    'Best Lag (months)': [best_lag],
                    'Corrected Correlation': [corrected_corr]
                })
               
                # Print the results table
                print("\nCorrelation Analysis Results:")
                display(results)
               
                # Plot the results
                plot_correlation_results(ts1, ts2, var1_name, var2_name, corrs, lags, best_lag, max_corr, corrected_corr)
       
        # Attach the function to the button
        compute_button.on_click(on_compute_button_clicked)
       
        # Display the widgets and output
        display(var1_dropdown, var2_dropdown, compute_button, output)
       
    except Exception as e:
        print(f"Error loading datasets: {e}")
       

if __name__ == "__main__":
    correlation_analysis_app()

Error loading datasets: found the following matches with the input file in xarray's IO backends: ['netcdf4', 'h5netcdf']. But their dependencies may not be installed, see:
https://docs.xarray.dev/en/stable/user-guide/io.html 
https://docs.xarray.dev/en/stable/getting-started-guide/installing.html


In [2]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# Function to load and preprocess the datasets
def load_datasets():
    # Load the datasets
    ecmwf = xr.open_dataset(r'datasets/ecmwf datas.nc')
    precip = xr.open_dataset(r'datasets/precip_konkan.nc')
    sst = xr.open_dataset(r'datasets/noaa_sst_masked.nc')
    
    # Prepare variables dictionary
    variables = {}
    
    # Add ECMWF variables
    for var in ecmwf.data_vars:
        if var not in ['skt']:
            variables[f"{var}"] = ecmwf[var]
    
    # Add PRECIP variables
    for var in precip.data_vars:
        if var not in ['time_bnds', 'lat_bnds', 'lon_bnds']:  # Skip bounds variables
            variables[f"{var}"] = precip[var]
    
    # Add SST variables
    for var in sst.data_vars:
        variables[f"{var}"] = sst[var]
    
    return variables

# Function to compute area-averaged time series for a variable
def compute_area_average(variable):
    # Check if variable has time dimension
    if 'time' in variable.dims or 'valid_time' in variable.dims:
        # Determine time dimension name
        time_dim = 'time' if 'time' in variable.dims else 'valid_time'
        
        # Calculate area average - handle different coordinate names
        if 'lat' in variable.dims and 'lon' in variable.dims:
            return variable.mean(dim=['lat', 'lon'])
        elif 'latitude' in variable.dims and 'longitude' in variable.dims:
            return variable.mean(dim=['latitude', 'longitude'])
        else:
            print(f"Warning: Cannot compute area average for variable with dimensions {variable.dims}")
            return None
    else:
        print(f"Warning: Variable does not have a time dimension. Dimensions: {variable.dims}")
        return None

# Function to compute lagged correlation
def compute_lagged_correlation(ts1, ts2, max_lag=12):
    # Ensure the time series have compatible time coordinates
    # Convert to pandas Series for easier manipulation
    if isinstance(ts1, xr.DataArray) and isinstance(ts2, xr.DataArray):
        # Extract time dimension name
        time_dim1 = 'time' if 'time' in ts1.dims else 'valid_time'
        time_dim2 = 'time' if 'time' in ts2.dims else 'valid_time'
        
        # Convert to pandas Series
        s1 = ts1.to_series()
        s2 = ts2.to_series()
        
        # Reset indices if they're MultiIndex
        if isinstance(s1.index, pd.MultiIndex):
            s1 = s1.reset_index().set_index(time_dim1)
        if isinstance(s2.index, pd.MultiIndex):
            s2 = s2.reset_index().set_index(time_dim2)
        
        # Make sure the series are sorted by time
        s1 = s1.sort_index()
        s2 = s2.sort_index()
        
        # Resample to ensure consistent monthly frequency
        s1 = s1.resample('MS').mean()
        s2 = s2.resample('MS').mean()
        
        # Compute correlations for different lags
        corrs = []
        lags = list(range(-max_lag, max_lag + 1))
        
        for lag in lags:
            if lag < 0:
                # s1 leads s2
                s1_lagged = s1.shift(-lag)
                corr = s1_lagged.corr(s2)
            else:
                # s2 leads s1
                s2_lagged = s2.shift(lag)
                corr = s1.corr(s2_lagged)
            
            corrs.append(corr)
        
        # Find the lag with maximum correlation
        corrs = np.array(corrs)
        max_idx = np.nanargmax(np.abs(corrs))
        max_corr = corrs[max_idx]
        best_lag = lags[max_idx]
        
        # Get original correlation (at lag 0)
        orig_corr = corrs[lags.index(0)]
        
        return orig_corr, max_corr, best_lag
    else:
        print("Error: Input time series are not xarray DataArrays")
        return None, None, None

# Function to create correlation matrices
def create_correlation_matrices(variables_dict, max_lag=12):
    # Get variable names
    var_names = list(variables_dict.keys())
    n_vars = len(var_names)
    
    # Initialize matrices for original correlations, corrected correlations, and best lags
    orig_corr_matrix = np.zeros((n_vars, n_vars))
    corr_matrix = np.zeros((n_vars, n_vars))
    lag_matrix = np.zeros((n_vars, n_vars), dtype=int)
    
    # Compute area averages for all variables
    time_series = {}
    for var_name in var_names:
        ts = compute_area_average(variables_dict[var_name])
        if ts is not None:
            time_series[var_name] = ts
        else:
            print(f"Warning: Could not compute area average for {var_name}")
    
    # Compute correlations for all variable pairs
    for i, var1 in enumerate(var_names):
        if var1 not in time_series:
            continue
            
        for j, var2 in enumerate(var_names):
            if var2 not in time_series:
                continue
                
            if i == j:
                # Variable with itself: correlation = 1, lag = 0
                orig_corr_matrix[i, j] = 1.0
                corr_matrix[i, j] = 1.0
                lag_matrix[i, j] = 0
            elif i < j:  # Only compute for upper triangle
                orig_corr, max_corr, best_lag = compute_lagged_correlation(
                    time_series[var1], time_series[var2], max_lag
                )
                
                if orig_corr is not None:
                    # Fill both upper and lower triangles
                    orig_corr_matrix[i, j] = orig_corr
                    orig_corr_matrix[j, i] = orig_corr
                    
                    corr_matrix[i, j] = max_corr
                    corr_matrix[j, i] = max_corr
                    
                    lag_matrix[i, j] = best_lag
                    lag_matrix[j, i] = -best_lag  # Reverse lag for opposite direction
    
    return orig_corr_matrix, corr_matrix, lag_matrix, var_names

# Function to plot correlation matrices
def plot_correlation_matrices(orig_corr_matrix, corr_matrix, lag_matrix, var_names):
    # Create a figure with two subplots side by side
    fig, axes = plt.subplots(1, 2, figsize=(20, 9))
    
    # Plot original correlation matrix
    sns.heatmap(orig_corr_matrix, annot=True, fmt=".2f", cmap="RdBu_r", 
                vmin=-1, vmax=1, center=0, ax=axes[0], 
                xticklabels=var_names, yticklabels=var_names)
    axes[0].set_title('Original Correlation Matrix', fontsize=16)
    axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=45, ha='right')
    
    # Plot corrected correlation matrix
    sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap="RdBu_r", 
                vmin=-1, vmax=1, center=0, ax=axes[1], 
                xticklabels=var_names, yticklabels=var_names)
    axes[1].set_title('Corrected Correlation Matrix (After Best Lag)', fontsize=16)
    axes[1].set_xticklabels(axes[1].get_xticklabels(), rotation=45, ha='right')
    
    # Add text box to explain colors
    textstr = 'Left: Original correlations (lag=0)\nRight: Corrected correlations (best lag)'
    fig.text(0.5, 0.02, textstr, fontsize=14, 
             bbox=dict(facecolor='white', alpha=0.8), 
             horizontalalignment='center')
    
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    plt.suptitle('Correlation Matrices', fontsize=18)
    plt.subplots_adjust(bottom=0.15)
    
    # Display the plot
    plt.show()
    
    # Create a list of all variable pairs (upper triangle only to avoid repetition)
    lag_pairs = []
    for i in range(len(var_names)):
        for j in range(i+1, len(var_names)):
            var1 = var_names[i]
            var2 = var_names[j]
            lag = lag_matrix[i, j]
            
            # Include all combinations, including those with zero lag
            if lag > 0:
                lag_pairs.append(f"{var2} leads {var1} by {lag} months")
            elif lag < 0:
                lag_pairs.append(f"{var1} leads {var2} by {abs(lag)} months")
            else:  # lag == 0
                lag_pairs.append(f"{var1} and {var2} are synchronous (lag = 0)")
    
    # Print out all lag pairs
    print("\nBest Time Lags (months):")
    for pair in lag_pairs:
        print(f"- {pair}")

# Main function
def correlation_matrix_analysis():
    try:
        # Load datasets
        print("Loading datasets...")
        variables = load_datasets()
        
        print(f"Found {len(variables)} variables. Computing correlations...")
        
        # Compute correlation matrices
        orig_corr_matrix, corr_matrix, lag_matrix, var_names = create_correlation_matrices(variables)
        
        # Plot the matrices
        plot_correlation_matrices(orig_corr_matrix, corr_matrix, lag_matrix, var_names)
        
    except Exception as e:
        print(f"Error in correlation matrix analysis: {e}")

# Run the analysis
if __name__ == "__main__":
    correlation_matrix_analysis()

Loading datasets...
Error in correlation matrix analysis: [Errno 2] No such file or directory: '/Users/riya/Desktop/SAC/notebooks/datasets/ecmwf datas.nc'
