In [1]:
import os
import numpy as np
import xarray as xr
import pandas as pd
from pathos.multiprocessing import ProcessingPool as Pool
import multiprocessing
import datetime

def load_netcdf(file_path):
    """Open ``file_path`` with ``xr.open_dataset``.

    Returns:
        The opened Dataset, or ``None`` if opening raised any exception. In the
        failure case a message naming the file and the error is printed and
        the exception is not re-raised.
    """
    dataset = None
    try:
        dataset = xr.open_dataset(file_path)
    except Exception as err:  # best-effort: report, then signal failure via None
        print(f"Error loading NetCDF file {file_path}: {err}")
    return dataset

def preprocess_dataset(ds, rename_dict=None):
    """Optionally rename, then normalise the lat/lon coordinates of ``ds``.

    Args:
        ds: Dataset that has (after renaming) ``latitude`` and ``longitude``.
        rename_dict: optional mapping passed to ``ds.rename``; skipped when
            empty or ``None``.

    Returns:
        The Dataset with ``latitude`` and ``longitude`` cast to float64 and
        rounded to 2 decimals. When no rename is applied this is the same
        object that was passed in, modified in place.
    """
    if rename_dict:
        ds = ds.rename(rename_dict)
    # Rounding presumably lets coordinates from different files compare equal
    # when grids are later merged on lon/lat (see process_outlook).
    for coord_name in ('latitude', 'longitude'):
        ds[coord_name] = ds[coord_name].astype('float64').round(2)
    return ds

def extract_data_slice(ds, var_name, time_index=-1, nbins=2, time=1):
    """Extract a 2D slice of ``ds[var_name]`` as a NumPy array.

    The slice depends on the variable's rank:

    * 3D: ``[:, :, time_index]``. ``time_index`` indexes the LAST axis
      (default ``-1``, the last entry); ``nbins`` and ``time`` are ignored.
    * 4D: ``[nbins - 1, time - 1, :, :]``. ``nbins`` and ``time`` are 1-based
      positions along the first two axes; ``time_index`` is ignored.

    The DataArray is indexed before ``.values`` is taken. For a lazily loaded
    variable (``xr.open_dataset``'s default) only the requested slice is read,
    not the whole array.

    Raises:
        ValueError: if the variable is neither 3D nor 4D.
    """
    var = ds[var_name]
    if var.ndim == 3:
        return var[:, :, time_index].values
    if var.ndim == 4:
        return var[nbins - 1, time - 1, :, :].values
    raise ValueError(f"Unexpected number of dimensions: {var.ndim}")

def create_dataframe(lon, lat, data, columns):
    """Flatten gridded values into a long-format DataFrame.

    Args:
        lon: longitudes of the grid columns (typically 1D).
        lat: latitudes of the grid rows (typically 1D).
        data: grid values laid out like ``np.meshgrid(lon, lat)``, i.e. shaped
            ``(len(lat), len(lon))`` for 1D coordinates. Anything convertible
            to an array is accepted: NumPy arrays, xarray DataArrays, or
            nested lists.
        columns: only the LAST entry is used, as the name of the value column.
            The coordinate columns are always ``'lon'`` and ``'lat'``.

    Returns:
        DataFrame with columns ``lon``, ``lat`` and ``columns[-1]``. It has one
        row per grid cell, in row-major order (lat outer, lon inner).

    Raises:
        ValueError: if ``data`` is 2D but not shaped like the (lat, lon) grid,
            e.g. a transposed (lon, lat) array that would otherwise be
            silently mis-assigned to coordinates. Also raised if the total size
            of ``data`` does not match the grid. A transposed SQUARE grid
            cannot be detected.
    """
    lon_grid, lat_grid = np.meshgrid(lon, lat)
    # asanyarray converts DataArrays (via __array__) and lists, while leaving
    # ndarray subclasses such as masked arrays intact.
    values = np.asanyarray(data)
    if values.ndim == 2 and values.shape != lon_grid.shape:
        raise ValueError(
            f"data shape {values.shape} does not match the (lat, lon) grid "
            f"shape {lon_grid.shape}"
        )
    if values.size != lon_grid.size:
        raise ValueError(
            f"data has {values.size} values but the grid has {lon_grid.size} cells"
        )
    return pd.DataFrame({
        'lon': lon_grid.flatten(),
        'lat': lat_grid.flatten(),
        columns[-1]: values.flatten(),
    })
def classify_drought(row):
    """Map one (cdi, rain) observation to a drought-outlook code.

    ``row`` is anything indexable by ``'cdi'`` and ``'rain'``, such as a pandas
    row Series or a plain dict. Codes, as labelled in the branches below:

        1 = no drought, 2 = removed, 3 = improved,
        4 = develops,   5 = persists, 6 = worsens

    A CDI below 0.2 is treated as an existing drought, and its outcome is
    decided by rainfall (thresholds 50 and 70). Otherwise a drought develops
    only when rainfall is below 30. The rainfall thresholds presumably refer to
    a percentage; TODO confirm units against the input data.

    NaN handling: every comparison with NaN is False. A NaN ``cdi`` therefore
    takes the "no existing drought" path. A NaN ``rain`` with low CDI takes
    the removed/improved path.
    """
    cdi = row['cdi']
    rain = row['rain']

    # Guard: no existing drought signal (also reached when cdi is NaN).
    if not cdi < 0.2:
        if rain < 30:
            return 4  # Develops
        return 1  # No drought

    # From here on cdi < 0.2: an existing drought.
    if rain < 50:
        if cdi < 0.02:
            return 5  # Persists
        return 6  # Worsens
    if rain < 70:
        return 5  # Persists
    # Wet enough for relief; cdi < 0.2 already holds, so only the lower bound matters.
    if cdi >= 0.1:
        return 2  # Removed
    return 3  # Improved

def parallel_classify(df, num_cores):
    """Apply ``classify_drought`` to every row of ``df`` using a process pool.

    Args:
        df: DataFrame with at least ``cdi`` and ``rain`` columns.
        num_cores: number of worker processes for the pathos pool.

    Returns:
        list of drought codes, one per row of ``df``, in row order.
    """
    if df.empty:
        # Nothing to classify, so avoid starting worker processes.
        return []
    # Plain dicts holding only the two inputs classify_drought reads are much
    # cheaper to build and pickle to workers than one pandas Series per row.
    rows = df[['cdi', 'rain']].to_dict('records')
    with Pool(num_cores) as p:
        return p.map(classify_drought, rows)

def create_output_dataset(df_out, lat, lon, time):
    """Wrap the ``outlook`` column of ``df_out`` as an xarray Dataset.

    Args:
        df_out: DataFrame whose ``outlook`` column holds ``len(lat) * len(lon)``
            values in row-major (lat outer, lon inner) order.
        lat: latitude coordinate values (grid rows).
        lon: longitude coordinate values (grid columns).
        time: single value stored in a ``time`` variable along a new
            length-1 ``time`` dimension. ``outlook`` itself has no time axis.

    Returns:
        Dataset with the 2D ``outlook`` variable, which carries ``varunit`` and
        ``longname`` attributes, plus the ``time`` entry.
    """
    grid = df_out['outlook'].values.reshape(len(lat), len(lon))
    outlook = xr.DataArray(
        grid,
        coords=[('lat', lat), ('lon', lon)],
        dims=['lat', 'lon'],
        name='outlook',
        attrs={'varunit': '', 'longname': 'drought outlook'},
    )
    result = outlook.to_dataset()
    result['time'] = ('time', [time])
    return result

def save_netcdf(ds, file_path):
    """Write ``ds`` to ``file_path`` as NetCDF, overwriting any existing file.

    Best-effort: any exception from the write is reported on stdout and not
    re-raised. A success message is printed otherwise.
    """
    try:
        ds.to_netcdf(file_path, mode='w')
    except Exception as exc:
        print(f"Error saving NetCDF file {file_path}: {exc}")
    else:
        print(f"File saved successfully: {file_path}")

def print_value_counts(da):
    """Print one ``Value: <v>, Count: <n>`` line per distinct value in ``da``.

    ``da`` only needs a ``.values`` array attribute, e.g. an xarray DataArray
    or a pandas Series. Values appear in ``np.unique``'s sorted order. Recent
    NumPy versions collapse all NaNs into a single trailing entry.
    """
    flat = da.values.ravel()
    distinct, tallies = np.unique(flat, return_counts=True)
    for line in (f"Value: {v}, Count: {n}" for v, n in zip(distinct, tallies)):
        print(line)

def calculate_3month_average(ds, var_name, nbins=2, months=3, time_dim="time"):
    """Average the first ``months`` steps of one bin of ``ds[var_name]``.

    The variable is indexed positionally as ``[nbins - 1, :months, :, :]``.
    It is therefore assumed to be 4D, with the bin axis first and the axis
    being averaged second (presumably (bin, time, lat, lon); TODO confirm
    against the rainfall file). ``nbins`` uses the same 1-based convention as
    ``extract_data_slice``. The defaults select ``[1, :3, :, :]`` and average
    over ``"time"``.

    Args:
        ds: Dataset containing ``var_name``.
        var_name: name of the 4D variable.
        nbins: 1-based position along the first axis.
        months: number of leading steps along the second axis to average.
        time_dim: name of the dimension averaged over. It must be the second
            axis for the slice above to be meaningful.

    Returns:
        DataArray with ``time_dim`` reduced away (2D for the assumed layout).
    """
    rain_var = ds[var_name]
    first_months = rain_var[nbins - 1, :months, :, :]
    return first_months.mean(dim=time_dim)

def process_drought_outlook(cdi_path, rain_path, output_path_1month, output_path_3month):
    """Build and save the 1-month and 3-month-average drought outlooks.

    Steps:
      1. Load the CDI file, round its coordinates and take the default 2D
         slice of ``cdi`` from ``extract_data_slice``.
      2. Load the rainfall file, rename ``lat``/``lon`` to
         ``latitude``/``longitude`` and round them the same way.
      3. 1-month outlook: default 2D slice of ``percentage_of_ensembles``.
      4. 3-month outlook: ``calculate_3month_average`` of the same variable.

    Each outlook is classified and written by ``process_outlook`` and tagged
    with the rainfall file's first ``time`` value. The function returns early
    if either file cannot be opened; ``load_netcdf`` prints the error. Every
    dataset that was opened is closed before returning.
    """
    ds_cdi_raw = load_netcdf(cdi_path)
    if ds_cdi_raw is None:
        return
    ds_rain_raw = None
    try:
        ds_cdi = preprocess_dataset(ds_cdi_raw)
        cdi_slice = extract_data_slice(ds_cdi, "cdi")
        cdi_df = create_dataframe(ds_cdi.longitude.values, ds_cdi.latitude.values,
                                  cdi_slice, ['lon', 'lat', 'cdi'])

        ds_rain_raw = load_netcdf(rain_path)
        if ds_rain_raw is None:
            return
        ds_rain = preprocess_dataset(ds_rain_raw, {'lat': 'latitude', 'lon': 'longitude'})
        outlook_time = ds_rain.time[0].values
        rain_lon = ds_rain.longitude.values
        rain_lat = ds_rain.latitude.values

        # 1-month outlook
        rain_slice_1month = extract_data_slice(ds_rain, "percentage_of_ensembles")
        rain_df_1month = create_dataframe(rain_lon, rain_lat, rain_slice_1month, ['lon', 'lat', 'rain'])
        process_outlook(cdi_df, rain_df_1month, ds_cdi, outlook_time, output_path_1month, "1-month")

        # 3-month average outlook
        rain_avg_3month = calculate_3month_average(ds_rain, "percentage_of_ensembles")
        rain_df_3month = create_dataframe(rain_lon, rain_lat, rain_avg_3month, ['lon', 'lat', 'rain'])
        process_outlook(cdi_df, rain_df_3month, ds_cdi, outlook_time, output_path_3month, "3-month average")
    finally:
        # Close the objects returned by open_dataset itself, so the file
        # handles are released even if a later step raises.
        for opened in (ds_cdi_raw, ds_rain_raw):
            if opened is not None:
                opened.close()


   

def process_outlook(cdi_df, rain_df, ds_cdi, time, output_path, outlook_type):
    """Classify one outlook on the CDI grid, save it as NetCDF and print counts.

    Args:
        cdi_df: long-format CDI grid with ``lon``, ``lat`` and ``cdi`` columns.
        rain_df: long-format rainfall grid with ``lon``, ``lat`` and ``rain``.
        ds_cdi: CDI Dataset; its ``latitude``/``longitude`` define the output grid.
        time: value stored in the output's length-1 ``time`` variable.
        output_path: destination NetCDF path.
        outlook_type: label used only in the printed summary.

    Cells lacking either a CDI or a rainfall value are left as NaN.

    Raises:
        pandas.errors.MergeError: if ``rain_df`` has duplicate (lon, lat)
            pairs. A left merge would then emit extra rows, and the results
            could no longer be mapped one-to-one back onto the CDI grid.
    """
    # The left merge keeps every CDI cell in cdi_df's row order. `validate`
    # stops duplicate rainfall keys from silently adding rows.
    join_df = cdi_df.merge(rain_df, how='left', on=['lon', 'lat'], validate='many_to_one')
    rmna_df = join_df.dropna()
    num_cores = min(multiprocessing.cpu_count(), 4)
    classified = parallel_classify(rmna_df, num_cores)

    # Build the output from join_df so its index is exactly the one that
    # rmna_df's labels refer to, whatever index cdi_df happens to carry.
    df_out = pd.DataFrame({'lat': join_df['lat'], 'lon': join_df['lon'], 'outlook': np.nan})
    df_out.loc[rmna_df.index, 'outlook'] = classified
    df_out['outlook'] = df_out['outlook'].astype(np.float32)

    # create_output_dataset reshapes rows to (lat, lon). This assumes cdi_df
    # was built lat-major from ds_cdi's own coordinates.
    ds_out = create_output_dataset(df_out, ds_cdi.latitude.values, ds_cdi.longitude.values, time)
    save_netcdf(ds_out, output_path)
    print(f"Value counts for {outlook_type} outlook:")
    print_value_counts(ds_out.outlook)

if __name__ == "__main__":
    # Year and month both come from the clock, so the file names track the
    # current run instead of a fixed year.
    now = datetime.datetime.now()
    current_month = now.strftime("%m")
    current_year = now.strftime("%Y")
    base_dir = "/Users/sabinmaharjan/projects/python/do/static"

    cdi_path = os.path.join(base_dir, "file", "cdi_1.nc")
    rain_path = os.path.join(
        base_dir, "file",
        f"p_atmos_q5_pr_s_maq5_pumedian_{current_year}{current_month}01_rt.nc",
    )
    output_path_1month = os.path.join(
        base_dir, "result", "nc", "1_months",
        f"{current_month}_Final_{current_year}.nc",
    )
    output_path_3month = os.path.join(
        base_dir, "result", "nc", "3_months",
        f"app{current_month}_Final_{current_year}.nc",
    )

    process_drought_outlook(cdi_path, rain_path, output_path_1month, output_path_3month)

File saved successfully: /Users/sabinmaharjan/projects/python/do/static/result/nc/1_months/10_Final_2024.nc
Value counts for 1-month outlook:
Value: 1.0, Count: 238464
Value: 2.0, Count: 3440
Value: 3.0, Count: 962
Value: 4.0, Count: 2200
Value: 5.0, Count: 12349
Value: 6.0, Count: 16147
Value: nan, Count: 299159


AttributeError: 'DataArray' object has no attribute 'flatten'