# In [None]:  (Jupyter notebook cell marker, kept as a comment so the file parses as Python)
import os
import re
import numpy as np
import pytz
from datetime import datetime
from netCDF4 import Dataset
from scipy.spatial import cKDTree
from timezonefinder import TimezoneFinder

def get_variable(dataset, variable_names, potential_paths):
    """Return the data of the first matching variable in a (possibly grouped) dataset.

    The root of ``dataset`` is searched first. If none of ``variable_names``
    is present there, each entry of ``potential_paths`` is interpreted as a
    ``/``-separated chain of group names and followed from the root; a path
    is skipped as soon as one segment is empty or is not a group of the
    current level.

    Args:
        dataset: Object exposing ``variables`` and ``groups`` mappings.
        variable_names: Candidate variable names, tried in order.
        potential_paths: Group paths to search after the root.

    Returns:
        The full contents (``[:]``) of the first variable found.

    Raises:
        KeyError: If no candidate name exists at the root or in any path.
    """
    not_found = object()  # sentinel: a variable's data may itself be falsy

    def read_first(container):
        # Try every candidate name against one container, in order.
        for name in variable_names:
            if name in container.variables:
                return container.variables[name][:]
        return not_found

    data = read_first(dataset)
    if data is not not_found:
        return data

    for path in potential_paths:
        node = dataset
        resolved = True
        for segment in path.split('/'):
            if not segment or segment not in node.groups:
                resolved = False
                break
            node = node.groups[segment]
        if not resolved:
            continue
        data = read_first(node)
        if data is not not_found:
            return data

    raise KeyError(f"None of the variable names {variable_names} were found in any of the specified paths.")

def convert_utc_to_local(utc_dt, lat, lon):
    """Convert a naive UTC datetime to naive local time at a geographic point.

    Args:
        utc_dt: Naive ``datetime`` interpreted as UTC.
        lat: Latitude of the reference point.
        lon: Longitude of the reference point; values above 180 are shifted
            down by 360 before the timezone lookup.

    Returns:
        A naive ``datetime`` in the timezone found for the point, or
        ``utc_dt`` unchanged when the lookup yields no timezone.
    """
    finder = TimezoneFinder()
    wrapped_lon = lon - 360 if lon > 180 else lon
    zone_name = finder.timezone_at(lng=wrapped_lon, lat=lat)
    if zone_name is None:
        # No timezone for this location: fall back to the UTC value as-is.
        return utc_dt
    zone = pytz.timezone(zone_name)
    aware_utc = utc_dt.replace(tzinfo=pytz.utc)
    localized = aware_utc.astimezone(zone)
    # Drop tzinfo so callers can subtract other naive datetimes.
    return localized.replace(tzinfo=None)

def find_adcirc_points_in_pixel(adcirc_tree, pixel_lat_min, pixel_lat_max, pixel_lon_min, pixel_lon_max):
    """Find ADCIRC nodes that lie inside a satellite pixel's bounding rectangle.

    The previous implementation ran two radius-0 ball queries at the two box
    corners, so it only returned nodes coinciding exactly with a corner. This
    version returns every node whose coordinates fall inside the closed
    rectangle ``[lat_min, lat_max] x [lon_min, lon_max]``. For a degenerate
    box (min == max on both axes) the result is the same as before: the nodes
    located exactly at that point.

    Args:
        adcirc_tree: KD-tree built over ``(lat, lon)`` rows; must expose
            ``query_ball_point`` and ``data``.
        pixel_lat_min: Lower latitude bound of the pixel.
        pixel_lat_max: Upper latitude bound of the pixel.
        pixel_lon_min: Lower longitude bound of the pixel.
        pixel_lon_max: Upper longitude bound of the pixel.

    Returns:
        A set of integer row indices into the tree's data.
    """
    # Tolerate swapped bounds instead of silently returning nothing.
    lat_lo, lat_hi = sorted((float(pixel_lat_min), float(pixel_lat_max)))
    lon_lo, lon_hi = sorted((float(pixel_lon_min), float(pixel_lon_max)))

    center_lat = (lat_lo + lat_hi) / 2.0
    center_lon = (lon_lo + lon_hi) / 2.0
    # Largest distance from the centre to any edge; taking the max of the
    # actual differences (rather than half the width) guarantees the edges
    # themselves are covered despite floating-point rounding.
    radius = max(center_lat - lat_lo, lat_hi - center_lat,
                 center_lon - lon_lo, lon_hi - center_lon)

    # With the infinity norm a "ball" is an axis-aligned square, which
    # encloses the rectangle; the exact rectangle test follows below.
    candidates = adcirc_tree.query_ball_point(
        [center_lat, center_lon], r=radius, p=float('inf')
    )
    if not candidates:
        return set()

    coords = adcirc_tree.data[candidates]
    inside = (
        (coords[:, 0] >= lat_lo) & (coords[:, 0] <= lat_hi)
        & (coords[:, 1] >= lon_lo) & (coords[:, 1] <= lon_hi)
    )
    return {int(index) for index, keep in zip(candidates, inside) if keep}

def process_satellite_file(filepath, adcirc_tree, bounding_box):
    """
    Process a single satellite file and map its valid SSH pixels onto ADCIRC nodes.

    For every pixel that is unmasked, carries a zero quality flag and lies
    inside ``bounding_box``, the corrected height
    ``ssh_karin + height_cor_xover - geoid`` is assigned to each ADCIRC node
    returned by ``find_adcirc_points_in_pixel``. Pixels are visited in
    row-major order and the first value assigned to a node wins.

    Args:
        filepath: Path of the satellite netCDF file.
        adcirc_tree: KD-tree over ADCIRC ``(lat, lon)`` node coordinates.
        bounding_box: ``[lat_min, lat_max, lon_min, lon_max]``.

    Returns:
        Dict mapping ADCIRC node index to its SSH value. Any failure
        (including failure to open the file) is printed and yields ``{}``.
    """
    # Stays None when opening fails, so the ``finally`` clause does not raise
    # on an unbound name and hide the error that was already handled.
    satellite_data = None
    try:
        satellite_data = Dataset(filepath, 'r')
        paths = ['', 'data_01', 'data_01/ku']

        # Extract relevant satellite data (ssh_karin, height_cor_xover, geoid, and its quality flag)
        ssh_karin = get_variable(satellite_data, ['ssh_karin'], paths)
        height_cor_xover = get_variable(satellite_data, ['height_cor_xover'], paths)
        geoid = get_variable(satellite_data, ['geoid'], paths)
        ssh_karin_qual = get_variable(satellite_data, ['ssh_karin_qual'], paths)
        satellite_lats = get_variable(satellite_data, ['latitude', 'lat'], paths)
        satellite_lons = get_variable(satellite_data, ['longitude', 'lon'], paths)

        # Raw fill value, kept as a fallback for arrays that arrive unmasked.
        fill_value = 2147483647

        # A pixel is usable only if none of its inputs is masked. Comparing a
        # masked element with the raw fill value is not a reliable test, so
        # the masks are checked explicitly.
        usable = np.ones(np.shape(ssh_karin), dtype=bool)
        for field in (ssh_karin, height_cor_xover, geoid,
                      ssh_karin_qual, satellite_lats, satellite_lons):
            usable &= ~np.ma.getmaskarray(field)

        ssh_raw = np.ma.getdata(ssh_karin)
        lats = np.ma.getdata(satellite_lats)
        lons = np.ma.getdata(satellite_lons)
        lons = np.where(lons > 180, lons - 360, lons)

        usable &= ssh_raw != fill_value
        usable &= np.ma.getdata(ssh_karin_qual) == 0

        lat_lo, lat_hi, lon_lo, lon_hi = bounding_box[:4]
        usable &= (lats >= lat_lo) & (lats <= lat_hi)
        usable &= (lons >= lon_lo) & (lons <= lon_hi)

        # Apply corrections: add height_cor_xover, subtract geoid
        final_ssh = (ssh_raw + np.ma.getdata(height_cor_xover)) - np.ma.getdata(geoid)
        # Never emit NaN/inf heights.
        usable &= np.isfinite(final_ssh)

        closest_nodes = {}
        # np.nonzero walks the array in row-major order, matching a nested
        # "for line / for pixel" loop, so first-assignment-wins is preserved.
        for i, j in zip(*np.nonzero(usable)):
            pixel_lat = lats[i, j]
            pixel_lon = lons[i, j]
            # NOTE(review): the pixel is passed as a zero-size box (min == max),
            # so only nodes located exactly at the pixel coordinate can match.
            # The true pixel footprint is not available here; supply real
            # bounds once the pixel size is known.
            nodes = find_adcirc_points_in_pixel(
                adcirc_tree, pixel_lat, pixel_lat, pixel_lon, pixel_lon
            )
            if not nodes:
                continue
            ssh_value = float(final_ssh[i, j])
            for node in nodes:
                closest_nodes.setdefault(node, ssh_value)

        return closest_nodes
    except Exception as e:
        print(f"Failed to process {filepath}: {str(e)}")
        return {}
    finally:
        if satellite_data is not None:
            satellite_data.close()

def parse_filename_datetime(filename):
    """Extract the first ``_YYYYMMDDTHHMMSS_`` timestamp found in a filename.

    Returns the timestamp as a naive ``datetime``, or ``None`` when the
    filename contains no underscore-delimited timestamp of that form.
    """
    # The start datetime looks like "_20240731T232256_".
    found = re.search(r'_(\d{8}T\d{6})_', filename)
    if found is None:
        return None
    stamp = found.group(1)
    return datetime.strptime(stamp, '%Y%m%dT%H%M%S')

def main():
    """Write SWOT swath observations, mapped onto ADCIRC nodes, to a text file.

    Steps:
      1. Build a KD-tree over the ADCIRC ``(y, x)`` node coordinates, with
         longitudes above 180 shifted down by 360.
      2. List the satellite directory, keep the files whose name carries a
         parseable timestamp, and sort them chronologically.
      3. Write a three-line header, then for each file emit ``##`` marker
         lines for the hours elapsed since ``start_time`` followed by one
         ``"<node> <value>"`` line per mapped node.
    """
    satellite_directory = '/work2/07174/soelem/stampede3/Paper-3/karin_data'
    adcirc_data_path = '/scratch/07174/soelem/global_2-20km/fort.63.nc'
    output_file_path = 'swot_swath.dat'
    # Layout: [lat_min, lat_max, lon_min, lon_max]. The previous value
    # [-90, 90, 180, -188] made the test "lon_min <= lon <= lon_max"
    # unsatisfiable, so no observation could ever be written.
    bounding_box = [-90, 90, -180, 180]
    start_time = datetime(2024, 8, 1)

    adcirc_data = Dataset(adcirc_data_path, 'r')
    try:
        adcirc_lats = adcirc_data.variables['y'][:]
        adcirc_x = adcirc_data.variables['x'][:]  # read once, reused below
    finally:
        # Close the mesh file even if reading the coordinates fails.
        adcirc_data.close()
    adcirc_lons = np.where(adcirc_x > 180, adcirc_x - 360, adcirc_x)
    adcirc_tree = cKDTree(np.column_stack((adcirc_lats, adcirc_lons)))

    # Parse filenames, drop those without a timestamp, and sort by time.
    files_datetimes = [(f, parse_filename_datetime(f)) for f in os.listdir(satellite_directory)]
    files_datetimes = [fd for fd in files_datetimes if fd[1] is not None]
    files_datetimes.sort(key=lambda x: x[1])

    with open(output_file_path, 'w') as file:
        file.write("# SWOT Swath flagged observations\n3600.0\n0.0\n")
        last_written_hour = None  # None until the first file has been handled

        for filename, file_datetime in files_datetimes:
            filepath = os.path.join(satellite_directory, filename)

            # NOTE(review): local time is taken at the bounding-box corner
            # (lat_min, lon_min), not at the observation location — confirm
            # this is the intended reference point.
            local_time = convert_utc_to_local(file_datetime, bounding_box[0], bounding_box[2])
            current_hour = int((local_time - start_time).total_seconds() / 3600)

            if last_written_hour is None:
                # First file: pad from hour 0 and adopt its hour as the
                # baseline, even when that hour is negative.
                if current_hour > 0:
                    file.write("##\n" * current_hour)
                last_written_hour = current_hour
            elif current_hour > last_written_hour:
                file.write("##\n" * (current_hour - last_written_hour))
                last_written_hour = current_hour

            closest_nodes = process_satellite_file(filepath, adcirc_tree, bounding_box)
            # NOTE(review): node is the 0-based row index of the KD-tree;
            # verify whether the consumer of this file expects 1-based IDs.
            for node, topo_rounded in closest_nodes.items():
                file.write(f"{node} {round(topo_rounded, 4)}\n")

if __name__ == "__main__":
    main()