In [None]:
import argparse
import glob
from datetime import datetime, timedelta
import numpy as np
import xarray as xr
import dask
from distributed import Client
from dask_jobqueue import PBSCluster
import os
import numpy as np
import matplotlib.pyplot as plt
from dask import delayed
from tqdm import tqdm
import pandas as pd

## Functions

In [None]:
# Function to get dates 15 days before and after a given day of the year
def get_dates_15_days_around(day_of_year, dminus, dplus, string_format_forecast_File='%Y-%m-%dT%H:%M:%S.%fZ', file_interval=0.5, year=None):
    """Build formatted timestamps in a window around a day of the year.

    Parameters
    ----------
    day_of_year : int
        Centre of the window (1-366), interpreted as a day of ``year``.
    dminus, dplus : float
        Window offsets in days relative to ``day_of_year``. The sampled
        offsets are ``np.arange(dminus, dplus + 1, file_interval)``, so with
        the default 0.5-day step both half-day stamps of the last day are
        included.
    string_format_forecast_File : str
        ``strftime`` pattern for each timestamp; every ``'Z'`` in the
        formatted result is replaced by ``'.000000000'``.
    file_interval : float or int
        Step between consecutive timestamps, in days.
    year : int, optional
        Calendar year used to turn ``day_of_year`` into a date. Defaults to
        the leap year 2000, so every day of year 1-366 is valid and the
        result does not depend on the date the code is run (previously the
        current year was used, which made day 366 fail in non-leap years).

    Returns
    -------
    list of str
    """
    if year is None:
        year = 2000  # any leap year works; only month/day/time matter for year-less formats
    date_obj = datetime.strptime(f"{year} {day_of_year:03d}", "%Y %j")
    return [
        # float() because timedelta rejects NumPy integer scalars (integer file_interval)
        (date_obj + timedelta(days=float(delta))).strftime(string_format_forecast_File).replace('Z', '.000000000')
        for delta in np.arange(dminus, dplus + 1, file_interval)
    ]

def find_intersections(dates_get, DS1_time):
    """Return the timestamps in ``DS1_time`` whose year-less part is in ``dates_get``.

    Parameters
    ----------
    dates_get : iterable of str
        Year-less stamps such as ``'-01-01T00:00:00.000000.000000000'``; the
        doubled fractional-seconds suffix ``'.000000.000000000'`` is
        normalised to ``'.000000000'`` before matching.
    DS1_time : iterable
        Time values; each is converted with ``str`` and compared after
        dropping its first four characters (assumes a 4-digit year prefix).

    Returns
    -------
    list of str
        Matching entries of ``DS1_time`` (as strings) in their original order.
    """
    # A set makes each membership test O(1) instead of a scan over dates_get.
    wanted = {date.replace('.000000.000000000', '.000000000') for date in dates_get}
    # Compare with the year stripped so matches are found in every year.
    return [stamp for stamp in (str(value) for value in DS1_time) if stamp[4:] in wanted]


def gaussian_weights(dates, target_date, sigma):
    """Gaussian weights of ``dates`` by their time offset from ``target_date``.

    Parameters
    ----------
    dates : iterable of str
        Timestamps parseable by ``pd.Timestamp`` (e.g.
        ``'2020-01-01T12:00:00.000000'`` or the 9-digit form produced by
        ``str(np.datetime64(..., 'ns'))``).
    target_date : str
        Centre of the Gaussian, in the same kind of format.
    sigma : float
        Standard deviation of the Gaussian, in days.

    Returns
    -------
    numpy.ndarray
        ``exp(-0.5 * (offset_days / sigma) ** 2)`` for each date, where
        ``offset_days`` is the signed offset in (fractional) days.
    """
    target = pd.Timestamp(target_date)
    # Fractional days: timedelta.days floors negative offsets (-12 h -> -1 day)
    # while +12 h -> 0 days, which made the weights asymmetric.
    offset_days = np.array(
        [(pd.Timestamp(date) - target).total_seconds() / 86400.0 for date in dates],
        dtype=float,
    )
    return np.exp(-0.5 * (offset_days / sigma) ** 2)

## Dask

In [None]:
# Dask cluster setup: PBS-submitted workers on the 'casper' queue
project_num = 'NAML0001'
pbs_options = dict(
    account=project_num,
    walltime='04:00:00',
    cores=1,
    memory='25GB',
    shared_temp_directory='/glade/derecho/scratch/wchapman/tmp',
    queue='casper',
)
cluster = PBSCluster(**pbs_options)

cluster.scale(jobs=22)  # request 22 PBS jobs
client = Client(cluster)
client  # last expression: display the client's rich repr in the notebook

In [None]:
dirin = '/glade/campaign/cisl/aiml/ERA5_Forecasts/'
# Every Zarr store directly inside dirin, in lexicographic path order
FNS = sorted(glob.glob(f"{dirin}*.zarr"))

In [None]:
# The glob must have found at least two Zarr stores; fail with a clear message
# (instead of a bare IndexError) if the directory is missing or empty.
assert len(FNS) >= 2, f"Expected at least 2 .zarr stores in {dirin}, found {len(FNS)}"
# Note the ordering: the first store in sorted order is DS2, the second is DS1.
DS2 = xr.open_zarr(FNS[0])
DS1 = xr.open_zarr(FNS[1])
# Time axes as NumPy arrays, used later for year-less string matching.
DS1_time = np.array(DS1['time'])
DS2_time = np.array(DS2['time'])

In [None]:
# 12-hourly timestamps spanning 2020, with both end points included
ddo = pd.date_range('2020-01-01', '2021-01-01', freq='12h')

In [None]:
import numpy as np
import xarray as xr
import dask.array as da
from datetime import datetime, timedelta
import os

# Function to get dates 15 days before and after a given day of the year
def get_dates_15_days_around(day_of_year, dminus, dplus, string_format_forecast_File, file_interval=0.5, year=None):
    """Build formatted timestamps in a window around a day of the year.

    Parameters
    ----------
    day_of_year : int
        Centre of the window (1-366), interpreted as a day of ``year``.
    dminus, dplus : float
        Window offsets in days relative to ``day_of_year``. The sampled
        offsets are ``np.arange(dminus, dplus + 1, file_interval)``, so with
        a 0.5-day step both half-day stamps of the last day are included.
    string_format_forecast_File : str
        ``strftime`` pattern for each timestamp; every ``'Z'`` in the
        formatted result is replaced by ``'.000000000'``.
    file_interval : float or int
        Step between consecutive timestamps, in days.
    year : int, optional
        Calendar year used to turn ``day_of_year`` into a date. Defaults to
        the leap year 2000, so every day of year 1-366 is valid and the result
        does not depend on the date the code is run (previously the current
        year was used, which made day 366 fail in non-leap years and shifted
        days after Feb 28 relative to leap-year targets).

    Returns
    -------
    list of str
    """
    if year is None:
        year = 2000  # any leap year works; only month/day/time matter for year-less formats
    date_obj = datetime.strptime(f"{year} {day_of_year:03d}", "%Y %j")
    return [
        # float() because timedelta rejects NumPy integer scalars (integer file_interval)
        (date_obj + timedelta(days=float(delta))).strftime(string_format_forecast_File).replace('Z', '.000000000')
        for delta in np.arange(dminus, dplus + 1, file_interval)
    ]

# Function to find intersections without the year
def find_intersections(dates_get, DS1_time):
    """Return the timestamps in ``DS1_time`` whose year-less part is in ``dates_get``.

    ``dates_get`` holds year-less stamps (e.g. ``'-01-01T00:00:00.000000.000000000'``);
    the doubled suffix ``'.000000.000000000'`` is normalised to ``'.000000000'``.
    Each value of ``DS1_time`` is converted with ``str`` and compared after
    dropping its first four characters (assumes a 4-digit year prefix).
    Matching entries are returned as strings in their original order.
    """
    # A set makes each membership test O(1) instead of a scan over dates_get.
    wanted = {date.replace('.000000.000000000', '.000000000') for date in dates_get}
    return [stamp for stamp in (str(value) for value in DS1_time) if stamp[4:] in wanted]

# Function to calculate Gaussian weights with wrapping around the year
def gaussian_weights(dates, target_doy, sigma):
    """Gaussian weights by cyclic day-of-year distance to ``target_doy``.

    Only the leading ``YYYY-MM-DD`` of each entry in ``dates`` is used (the
    time of day is ignored). The distance is measured in whole days and wraps
    with a 365-day period: ``min(|doy - target_doy|, 365 - |doy - target_doy|)``.
    Each weight is ``exp(-0.5 * (distance / sigma) ** 2)``.

    NOTE(review): the day of year comes from each date's own calendar year, so
    dates after Feb 28 in non-leap years sit one day off relative to leap-year
    targets, and with the 365-day wrap ``target_doy=366`` is at distance 0
    from day 1. Confirm this is acceptable for the climatology.
    """
    def _cyclic_gap(doy):
        gap = abs(doy - target_doy)
        return min(gap, 365 - gap)

    day_numbers = (
        datetime.strptime(stamp[:10], "%Y-%m-%d").timetuple().tm_yday
        for stamp in dates
    )
    return np.array([np.exp(-0.5 * (_cyclic_gap(doy) / sigma) ** 2) for doy in day_numbers])

# Example usage
sigma = 10  # Standard deviation (days) of the Gaussian day-of-year weighting
# Target days of year come from 2020, a leap year, so day 366 occurs.
ddo = pd.date_range(start='2020-01-01',end='2021-01-01',freq='12h')
output_dir = '/glade/derecho/scratch/wchapman/IFS_forecast_climo'
os.makedirs(output_dir, exist_ok=True)
# DS1, DS2, DS1_time and DS2_time come from the data-loading cells above.
for tt in ddo:
    target_doy = tt.dayofyear
    output_file = f'{output_dir}/IFS_forecast_climo_{target_doy:03}.nc'
    print(f'attempting: {output_file}')
    # Check if the output file already exists
    if os.path.exists(output_file):
        print(f"File {output_file} already exists. Skipping.")
        continue

    dates_get = get_dates_15_days_around(target_doy, -15, 15, string_format_forecast_File='-%m-%dT%H:%M:%S.%fZ', file_interval=0.5)

    # Forecast times from any year whose month-day-time falls inside the window
    DD1 = DS1.sel(time=find_intersections(dates_get, DS1_time))
    DD2 = DS2.sel(time=find_intersections(dates_get, DS2_time))

    weights1 = gaussian_weights([str(date) for date in DD1['time'].values], target_doy, sigma)
    weights2 = gaussian_weights([str(date) for date in DD2['time'].values], target_doy, sigma)

    # Normalise jointly so the weights of both stores sum to one.
    total_weight = np.sum(weights1) + np.sum(weights2)
    if total_weight == 0:
        print(f"No forecast times found around day {target_doy:03}. Skipping.")
        continue

    # Label the weights with the time coordinate so xarray broadcasts them by
    # dimension name. The previous positional [:, None, None, None] weighting
    # assumed every data variable was 4-D with time as its leading axis.
    w1 = xr.DataArray(weights1 / total_weight, dims='time', coords={'time': DD1['time'].values})
    w2 = xr.DataArray(weights2 / total_weight, dims='time', coords={'time': DD2['time'].values})

    # concat (unlike combine_by_coords) also works if the two stores share time stamps.
    combined = xr.concat([DD1 * w1, DD2 * w2], dim='time').sum('time')
    combined = combined.persist()  # Evaluate the weighted sum on the Dask workers before writing

    # Write under a temporary name and rename only once complete, so the
    # existence check above never mistakes a partially written file for a finished one.
    root, ext = os.path.splitext(output_file)
    partial_file = f'{root}.partial{ext}'
    combined.to_netcdf(partial_file)
    os.replace(partial_file, output_file)


In [None]:
# NOTE(review): the loop above writes files named IFS_forecast_climo_XXX.nc;
# this expression uses the prefix Forecast_climo_ instead — confirm which is intended.
'/glade/derecho/scratch/wchapman/IFS_forecast_climo/Forecast_climo_{:03}.nc'.format(target_doy)

In [None]:
# Tear down the Dask client created in the setup cell; per the Dask docs,
# Client.shutdown also stops the scheduler and workers it is connected to.
client.shutdown()

In [None]:
# Remove the dask-worker* files left in the working directory (presumably the
# PBS job logs of the Dask workers — confirm before deleting). `!` runs the
# shell directly, so this does not depend on IPython's `rm` alias/automagic,
# and `-f` avoids an error when no such files exist.
!rm -f dask-worker*