In [9]:
"""
Map elevation, ERA5 annual mean temperature/pressure, and FY-4B scan time to the LUT grid.

========================================================================================
Function
----------------------------------------------------------------------------------------
map_elev_era5_scan_to_lut_safe()

========================================================================================
Purpose
----------------------------------------------------------------------------------------
This function maps high-resolution terrain elevation (m), ERA5 annual mean
2-meter air temperature (K) and surface pressure (Pa), as well as FY-4B AGRI
scan time (UTC), onto the grid defined by the FY-4B GEO Look-Up Table (LUT).
The resulting dataset is saved as a unified NetCDF file.

========================================================================================
Input Files and Variables
----------------------------------------------------------------------------------------
elev_file : str  
    Path to the elevation dataset (NetCDF format).  
    Example: `"world_ll_elev_0.05deg.nc4"`  
    Variables:
        - `latitude`  [degrees_north]  
        - `longitude` [degrees_east]  
        - `elevation` [m]

era5_file : str  
    Path to the ERA5 monthly mean dataset (NetCDF format).  
    Example: `"2024monthly.nc"`  
    Variables:
        - `t2m` [K]  → 2-meter air temperature  
        - `sp`  [Pa] → surface pressure  
    Dimensions:
        `(month, latitude, longitude)`  
    The function calculates the annual mean for each variable.

lut_file : str  
    Path to the FY-4B GEO LUT file (binary .raw format).  
    Example: `"FY4B-_DISK_1050E_GEO_NOM_LUT_20240227000000_4000M_V0001.raw"`  
    Description:
        - Binary file containing latitude and longitude pairs  
        - Shape: (2748, 2748, 2)  
        - Each element corresponds to a pixel location on the FY-4B disk view  
    Units:
        - degrees_north, degrees_east

hdf_file : str  
    Path to the FY-4B AGRI Level-1 observation file (HDF format).  
    Example:  
        `"FY4B-_AGRI--_N_DISK_1050E_L1-_FDI-_MULT_NOM_20240923040000_20240923041459_4000M_V0001.HDF"`  
    Dataset:
        - `"NOMObs/NOMObsTime"` → Start and end UTC time of each scan line  
          Format: integer `YYYYMMDDHHMMSSsss` (UTC, with millisecond precision)

out_file : str  
    Path to the output NetCDF file where all mapped variables will be saved.  
    Example: `"result.nc"`

========================================================================================
Output File (NetCDF4 format)
----------------------------------------------------------------------------------------
Dimensions:
    - `y = 2748`
    - `x = 2748`

Variables:
    - `latitude(y, x)`   [degrees_north] → LUT grid latitude  
    - `longitude(y, x)`  [degrees_east]  → LUT grid longitude  
    - `altitude(y, x)`   [m]             → Mapped elevation  
    - `t2m(y, x)`        [K]             → ERA5 annual mean 2-m air temperature  
    - `sp(y, x)`         [Pa]            → ERA5 annual mean surface pressure  
    - `scan_time(y, x)`  [int64, UTC]    → Per-pixel scan time interpolated
                                           from start and end line times,
                                           formatted as `YYYYMMDDHHMMSSsss`

========================================================================================
"""

import numpy as np
from netCDF4 import Dataset
import h5py
import os
from datetime import datetime, timedelta, timezone
from scipy.interpolate import RegularGridInterpolator

def map_elev_era5_scan_to_lut_safe(
    elev_file: str,
    era5_file: str,
    lut_file: str,
    hdf_file: str,
    out_file: str
):
    """
    Map elevation, ERA5 annual means and FY-4B scan time onto the LUT grid.

    Parameters
    ----------
    elev_file : str
        NetCDF file with 1-D ``latitude``/``longitude`` axes and a 2-D
        ``elevation`` field indexed as ``[latitude, longitude]``.
    era5_file : str
        NetCDF file with ``t2m`` and ``sp`` (averaged over their first axis)
        and 1-D ``latitude``/``longitude`` axes.
    lut_file : str
        Raw binary file read with NumPy's default dtype as interleaved
        (lat, lon) pairs for a 2748 x 2748 grid. Values above 999 mark
        pixels that have no ground location.
    hdf_file : str
        HDF file holding ``NOMObs/NOMObsTime`` (one start/end pair per scan
        line). The file name must carry the start and end timestamps as its
        4th- and 3rd-last ``_``-separated fields.
    out_file : str
        Path of the NetCDF4 file to write (opened in ``'w'`` mode).

    Returns
    -------
    None
        All results are written to ``out_file``.
    """

    # ---------------- 1. Read elevation data ----------------
    with Dataset(elev_file, 'r') as nc:
        lat_elev = nc.variables['latitude'][:]   # degrees_north
        lon_elev = nc.variables['longitude'][:]  # degrees_east
        elev     = nc.variables['elevation'][:]  # meters

    # ---------------- 2. Read FY-4B LUT (lat/lon grid) ----------------
    raw_image = np.fromfile(lut_file)
    raw_image = raw_image.reshape(-1, 2)
    lat_lut   = raw_image[:, 0].reshape(2748, 2748)
    lon_lut   = raw_image[:, 1].reshape(2748, 2748)
    mask = (lat_lut > 999) | (lon_lut > 999)
    valid = ~mask
    ny, nx = lat_lut.shape

    # ---------------- 3. Map elevation to LUT grid ----------------
    # Use nearest neighbor index mapping
    def nearest_idx(src_coords, target_coords):
        """Return, per target value, the index of the closest source coordinate."""
        src = np.asarray(src_coords, dtype=np.float64)
        # searchsorted needs an ascending axis; sort first and map the result
        # back so that descending (north-to-south) axes are handled as well.
        order = np.argsort(src, kind="stable")
        src_sorted = src[order]
        idx = np.searchsorted(src_sorted, target_coords, side="left")
        idx = np.clip(idx, 1, len(src_sorted) - 1)
        left_is_closer = (np.abs(target_coords - src_sorted[idx - 1])
                          < np.abs(target_coords - src_sorted[idx]))
        idx = idx - left_is_closer
        return order[idx]

    # NOTE(review): no longitude wrapping is applied; this assumes the LUT and
    # the source grids share one longitude convention -- confirm for pixels
    # beyond the dateline.
    lat_idx_elev = nearest_idx(lat_elev, lat_lut)
    lon_idx_elev = nearest_idx(lon_elev, lon_lut)
    elev_resampled = np.full(lat_lut.shape, np.nan)
    elev_resampled[valid] = elev[lat_idx_elev[valid], lon_idx_elev[valid]]

    # ---------------- 4. Read ERA5 monthly data and compute annual mean ----------------
    with Dataset(era5_file, 'r') as nc:
        t2m_monthly = nc.variables['t2m'][:]  # [K]
        sp_monthly  = nc.variables['sp'][:]   # [Pa]
        lat_era5    = nc.variables['latitude'][:]
        lon_era5    = nc.variables['longitude'][:]

    t2m_annual = np.mean(t2m_monthly, axis=0)
    sp_annual  = np.mean(sp_monthly, axis=0)

    # ---------------- 5. Interpolate ERA5 data to LUT grid ----------------
    # Reverse the latitude axis (and the data rows) only when it is actually
    # descending, so the interpolator always receives an ascending axis.
    if lat_era5[0] > lat_era5[-1]:
        lat_axis = lat_era5[::-1]
        t2m_grid = t2m_annual[::-1, :]
        sp_grid  = sp_annual[::-1, :]
    else:
        lat_axis = lat_era5
        t2m_grid = t2m_annual
        sp_grid  = sp_annual

    t2m_interp = RegularGridInterpolator(
        (lat_axis, lon_era5),
        t2m_grid,
        bounds_error=False,
        fill_value=np.nan
    )
    sp_interp = RegularGridInterpolator(
        (lat_axis, lon_era5),
        sp_grid,
        bounds_error=False,
        fill_value=np.nan
    )

    # Out-of-range points (including the > 999 LUT sentinels) come back as NaN.
    points = np.stack([lat_lut.ravel(), lon_lut.ravel()], axis=-1)
    t2m_resampled = t2m_interp(points).reshape(lat_lut.shape)
    sp_resampled  = sp_interp(points).reshape(lat_lut.shape)

    # ---------------- 6. Read FY-4B HDF scan times ----------------
    with h5py.File(hdf_file, 'r') as f:
        scan_time_raw = f["NOMObs/NOMObsTime"][:]  # shape: (ny_raw, 2)

    # Extract start and end timestamps from filename
    filename = os.path.basename(hdf_file)
    parts = filename.split('_')
    start_str = parts[-4]
    end_str   = parts[-3]
    ts_start = int(start_str + "000")   # append milliseconds
    te_end   = int(end_str + "000")

    # Replace invalid scan times (marked as 9999): lines in the first half of
    # the image fall back to the file start time, the rest to the end time.
    ny_raw = scan_time_raw.shape[0]
    invalid_line = (scan_time_raw[:, 0] == 9999) & (scan_time_raw[:, 1] == 9999)
    first_half = np.arange(ny_raw) < ny_raw // 2
    scan_time_raw[invalid_line & first_half] = ts_start
    scan_time_raw[invalid_line & ~first_half] = te_end

    # Helper functions for time conversion
    def parse_time(ts_int):
        """Convert FY-4B integer timestamp to Python datetime."""
        ts_str = f"{int(ts_int):017d}"
        dt = datetime.strptime(ts_str[:14], "%Y%m%d%H%M%S")
        ms = int(ts_str[14:])
        return dt + timedelta(milliseconds=ms)

    def datetime_to_unix(dt):
        """Convert datetime to UNIX seconds (float)."""
        return dt.replace(tzinfo=timezone.utc).timestamp()

    def unix_to_int(ts):
        """Convert UNIX seconds to FY-4B integer timestamp format."""
        dt = datetime.fromtimestamp(ts, tz=timezone.utc)
        ms = int(dt.microsecond / 1000)
        return int(dt.strftime("%Y%m%d%H%M%S") + f"{ms:03d}")

    # ---------------- 7. Interpolate scan time across each scan line ----------------
    scan_time_sec = np.full((ny, nx), np.nan, dtype=np.float64)
    for i in range(ny):
        ts_int, te_int = scan_time_raw[i]
        ts_sec = datetime_to_unix(parse_time(ts_int))
        te_sec = datetime_to_unix(parse_time(te_int))

        # Spread the line's start..end interval evenly over its valid pixels.
        valid_cols = np.where(valid[i, :])[0]
        n_valid = len(valid_cols)
        if n_valid > 0:
            scan_time_sec[i, valid_cols] = np.linspace(ts_sec, te_sec, n_valid)

    # Convert interpolated UNIX seconds back to FY-4B integer timestamp format.
    # An explicit output type keeps np.vectorize usable when nothing is valid.
    scan_time_int = np.full((ny, nx), 0, dtype=np.int64)
    valid_idx = ~np.isnan(scan_time_sec)
    to_fy4b_int = np.vectorize(unix_to_int, otypes=[np.int64])
    scan_time_int[valid_idx] = to_fy4b_int(scan_time_sec[valid_idx])

    # ---------------- 8. Write all mapped data to NetCDF ----------------
    # The context manager closes the file even if a write fails part-way.
    with Dataset(out_file, 'w', format='NETCDF4') as ds_out:
        ds_out.createDimension('y', ny)
        ds_out.createDimension('x', nx)

        # Define variables
        lat_var  = ds_out.createVariable('latitude',  'f8', ('y','x'))
        lon_var  = ds_out.createVariable('longitude', 'f8', ('y','x'))
        elev_var = ds_out.createVariable('altitude',  'f8', ('y','x'), fill_value=np.nan)
        t2m_var  = ds_out.createVariable('t2m',       'f8', ('y','x'), fill_value=np.nan)
        sp_var   = ds_out.createVariable('sp',        'f8', ('y','x'), fill_value=np.nan)
        scan_var = ds_out.createVariable('scan_time', 'i8', ('y','x'), fill_value=0)

        # Assign data
        lat_var[:, :]   = lat_lut
        lon_var[:, :]   = lon_lut
        elev_var[:, :]  = elev_resampled
        t2m_var[:, :]   = t2m_resampled
        sp_var[:, :]    = sp_resampled
        scan_var[:, :]  = scan_time_int

        # Add metadata
        lat_var.units  = 'degrees_north'
        lon_var.units  = 'degrees_east'
        elev_var.units = 'm'
        t2m_var.units  = 'K'
        sp_var.units   = 'Pa'
        scan_var.units = 'YYYYMMDDHHMMSSsss (UTC)'

        elev_var.description = 'High-resolution elevation mapped to LUT grid'
        t2m_var.description  = 'ERA5 annual mean 2-meter temperature mapped to LUT grid'
        sp_var.description   = 'ERA5 annual mean surface pressure mapped to LUT grid'
        scan_var.description = 'FY-4B scan time interpolated from line start/end times, UTC with millisecond precision'

    print("✅ Saved successfully:", out_file)


# ================= Example Usage =================
# Output product written by the mapping step.
out_file  = "E:/fengyun/code/output/result.nc"

# Input datasets: FY-4B geolocation LUT and L1 granule, elevation, ERA5 monthly means.
lut_file  = "E:/fengyun/code/input/FY4B-_DISK_1050E_GEO_NOM_LUT_20240227000000_4000M_V0001.raw"
hdf_file  = "E:/fengyun/code/input/FY4B-_AGRI--_N_DISK_1050E_L1-_FDI-_MULT_NOM_20240923040000_20240923041459_4000M_V0001.HDF"
elev_file = "E:/fengyun/code/input/world_ll_elev_0.05deg.nc4"
era5_file = "E:/fengyun/code/input/2024monthly.nc"

map_elev_era5_scan_to_lut_safe(
    elev_file=elev_file,
    era5_file=era5_file,
    lut_file=lut_file,
    hdf_file=hdf_file,
    out_file=out_file,
)


✅ Saved successfully: E:/fengyun/code/output/result.nc


In [10]:
"""
Fast Solar Position Computation using pvlib SPA
-----------------------------------------------

This script computes solar position parameters (e.g., zenith, azimuth, and equation of time)
for each pixel in a Look-Up Table (LUT) stored as a NetCDF file. It employs pvlib’s
high-accuracy Solar Position Algorithm (SPA) and supports parallel computation for
large LUT datasets. The implementation is optimized for modern NumPy versions
(>= 1.24) and ensures fast, accurate, and memory-efficient solar geometry calculation.

------------------------------------------------------------------------------
🔹 INPUT
------------------------------------------------------------------------------
The function reads the following variables from the input NetCDF LUT file:

- **latitude**   (2D array, degrees)  
- **longitude**  (2D array, degrees)  
- **altitude**   (2D array, meters above sea level)  
- **t2m**        (2D array, air temperature in Kelvin, K)  
- **sp**         (2D array, surface air pressure in Pascals, Pa)  
- **scan_time**  (2D array, 17-digit integer UTC timestamps,  
                  e.g., `20250101053045123` → 2025-01-01 05:30:45.123 UTC)

Optional parameters:
- **delta_t** : float  
  Difference between Terrestrial Time (TT) and UT1 in seconds.  
  Typical value ≈ 67–70 s (default: 67.0).  

- **numthreads** : int  
  Number of CPU threads used for parallel computation (default: 4).

------------------------------------------------------------------------------
🔹 OUTPUT
------------------------------------------------------------------------------
The computed solar geometry parameters are written **back into the same NetCDF file**.
New variables will be created if they do not already exist:

1. **apparent_zenith**       (degrees)  
2. **zenith**                (degrees)  
3. **azimuth**               (degrees)  
4. **equation_of_time**      (minutes)

Each variable is stored as a 2D array with the same dimensions as the latitude/longitude grids.

"""

import os
import numpy as np
import pandas as pd
from netCDF4 import Dataset
from joblib import Parallel, delayed
import pvlib

def compute_solarposition_nc_fast(nc_file, delta_t=67.0, numthreads=4):
    """
    Compute pvlib SPA solar angles per pixel and store them in ``nc_file``.

    Parameters
    ----------
    nc_file : str
        NetCDF file opened in ``'r+'`` mode. It must contain ``latitude``,
        ``longitude``, ``altitude``, ``scan_time``, ``t2m`` and ``sp``.
    delta_t : float, optional
        Value forwarded to ``pvlib.solarposition.spa_python`` as ``delta_t``.
    numthreads : int, optional
        Upper bound on the number of worker threads. The value actually used
        is limited to the CPU count and is never below 1.

    Returns
    -------
    None
        ``apparent_zenith``, ``zenith``, ``azimuth`` and ``equation_of_time``
        are created (or overwritten) in ``nc_file``. Nothing is written when
        no pixel has a positive ``scan_time``.
    """
    # The context manager guarantees the file is closed on every exit path.
    with Dataset(nc_file, 'r+') as ds:
        # ---------------- 1. Read variables ----------------
        lat = ds['latitude'][:]
        lon = ds['longitude'][:]
        elev = ds['altitude'][:]
        scan_time_int = ds['scan_time'][:]
        t2m = ds['t2m'][:]   # K
        sp  = ds['sp'][:]    # Pa
        # Reuse the grid's own dimension names for any variable created below.
        grid_dims = ds['latitude'].dimensions

        ny, nx = lat.shape
        valid_mask = ~np.isnan(lat) & (scan_time_int > 0)
        scan_time_valid = scan_time_int[valid_mask].astype(np.int64)

        # ---------------- 2. Accurate vectorized time parsing ----------------
        # Ensure all values are 17-digit zero-padded strings
        ts_str = np.char.zfill(scan_time_valid.astype(str), 17)

        # Extract first 14 digits (datetime part) and last 3 digits (milliseconds)
        dt_part = np.array([s[:14] for s in ts_str])
        ms_part = np.array([s[14:] for s in ts_str]).astype(np.int64)

        # Convert to datetime + milliseconds
        time_flat = pd.to_datetime(dt_part, format='%Y%m%d%H%M%S', utc=True)
        time_flat = time_flat + pd.to_timedelta(ms_part, unit='ms')

        # ---------------- 3. Flatten valid variables ----------------
        lat_flat = lat[valid_mask]
        lon_flat = lon[valid_mask]
        elev_flat = elev[valid_mask]
        pressure_flat = sp[valid_mask]
        temperature_flat = t2m[valid_mask] - 273.15  # K -> °C

        # ---------------- 4. Parallel chunked SPA computation ----------------
        n = len(time_flat)
        if n == 0:
            print("⚠️ No valid scan_time entries found. Exiting.")
            return

        # Clamp to at least one worker: a zero or negative count would yield
        # an empty chunk list and make the concatenation below fail.
        n_jobs = max(1, min(numthreads, os.cpu_count() or 1))
        chunk_size = int(np.ceil(n / n_jobs))
        chunks = [(i, min(i + chunk_size, n)) for i in range(0, n, chunk_size)]

        def compute_chunk(ch):
            """Run SPA on one contiguous slice of the flattened pixel arrays."""
            s, e = ch
            return pvlib.solarposition.spa_python(
                time=time_flat[s:e],
                latitude=lat_flat[s:e],
                longitude=lon_flat[s:e],
                altitude=elev_flat[s:e],
                pressure=pressure_flat[s:e],
                temperature=temperature_flat[s:e],
                delta_t=delta_t,
                how='numpy'
            )

        results = Parallel(n_jobs=n_jobs, prefer='threads')(
            delayed(compute_chunk)(ch) for ch in chunks
        )

        # Chunks are contiguous and ordered, so concatenation restores the
        # flattened pixel order used by valid_mask.
        solpos_flat = pd.concat(results, ignore_index=True)

        # ---------------- 5. Reconstruct 2D grids ----------------
        def to_grid(colname):
            """Scatter one result column back onto the (ny, nx) grid."""
            # float64 matches the 'f8' variables written below; a float32
            # buffer would silently discard precision of the SPA output.
            grid = np.full((ny, nx), np.nan, dtype=np.float64)
            grid[valid_mask] = solpos_flat[colname].values
            return grid

        output_vars = {
            'apparent_zenith':      ('degrees', to_grid('apparent_zenith')),
            'zenith':               ('degrees', to_grid('zenith')),
            'azimuth':              ('degrees', to_grid('azimuth')),
            'equation_of_time':     ('minutes', to_grid('equation_of_time')),
        }

        # ---------------- 6. Write results back to NetCDF ----------------
        for name, (units, data) in output_vars.items():
            if name in ds.variables:
                ds[name][:] = data
            else:
                var = ds.createVariable(
                    name, 'f8', grid_dims,
                    zlib=True, complevel=4, fill_value=np.nan
                )
                var.units = units
                var[:, :] = data

    print(f"✅ Fast & accurate solar position computation completed and saved to: {nc_file}")


# Example usage: runs only when this code is executed as the main program.
if __name__ == "__main__":
    nc_file = "E:/fengyun/code/output/result.nc"
    # Override the function defaults for delta_t and the thread count.
    compute_solarposition_nc_fast(
        nc_file=nc_file,
        delta_t=69.1322,
        numthreads=8,
    )


✅ Fast & accurate solar position computation completed and saved to: E:/fengyun/code/output/result.nc


In [11]:
"""
FY-4B Satellite Viewing Angle Computation
-----------------------------------------

This script defines a function to compute the *viewing geometry* (zenith and azimuth angles)
from the FY-4B geostationary satellite to each ground point. The computation is based on
the latitude, longitude, and terrain elevation (altitude) grids stored in a NetCDF file.

The results represent the line-of-sight direction from the ground to the satellite
and are written back to the same NetCDF file.

------------------------------------------------------------------------------
🔹 INPUT
------------------------------------------------------------------------------
The function reads the following variables from the input NetCDF file:

- **latitude**  (2D array, degrees)
- **longitude** (2D array, degrees)
- **altitude**  (2D array, meters above sea level, m)

Optional parameters:
- **sat_lon_deg** : Satellite sub-point longitude in degrees (default = 105.0° for FY-4B)
- **sat_height**  : Satellite orbital height in meters (default = 35,786,000 m)

------------------------------------------------------------------------------
🔹 OUTPUT
------------------------------------------------------------------------------
The function writes two new variables back into the *same* NetCDF file:

1. **satellite_zenith_angle**  
   - Unit: degrees  
   - Description: Angle between the local zenith and the satellite line-of-sight  
     (0° = directly overhead, 90° = horizon)

2. **satellite_azimuth_angle**  
   - Unit: degrees  
   - Description: Azimuth of the satellite direction on the local horizontal plane,  
     measured clockwise from north (0–360°)

Both variables are 2D arrays matching the latitude/longitude grid shape.
"""

import math
import numpy as np 
from netCDF4 import Dataset

def compute_satellite_angles(nc_file: str, sat_lon_deg: float = 105.0, sat_height: float = 35786000.0):
    """
    Compute viewing zenith angle (θ) and azimuth angle (φ) from FY-4B satellite
    to the ground, and write them to the input NetCDF file.

    Parameters:
    -----------
    nc_file : str
        Input/output NetCDF file containing 'latitude', 'longitude', and 'altitude'.
    sat_lon_deg : float, optional
        Satellite longitude in degrees (default is 105.0° for FY-4B).
    sat_height : float, optional
        Satellite orbit height in meters (default is 35786000 m).

    Returns:
    --------
    None
        'satellite_zenith_angle' and 'satellite_azimuth_angle' are created or
        overwritten in `nc_file`. Pixels without a usable coordinate are NaN.
    """

    # ---------------- Step 1: Open NetCDF file ----------------
    # The context manager closes the file even when a step below raises.
    with Dataset(nc_file, 'r+') as nc:

        # ---------------- Step 2: Read latitude, longitude, and altitude ----------------
        lat_arr = np.array(nc.variables['latitude'][:], dtype=np.float64)
        lon_arr = np.array(nc.variables['longitude'][:], dtype=np.float64)
        height_arr = np.array(nc.variables['altitude'][:], dtype=np.float64)

        # ---------------- Step 3: Satellite and Earth parameters (GRS80) ----------------
        a = 6378137.0                   # Equatorial radius (m)
        f = 1 / 298.257222101           # Flattening
        E2 = 2 * f - f**2               # Square of eccentricity

        # Satellite position in the equatorial plane (z = 0).
        sat_lon_rad = math.radians(sat_lon_deg)
        x_s = (a + sat_height) * math.cos(sat_lon_rad)
        y_s = (a + sat_height) * math.sin(sat_lon_rad)
        z_s = 0.0

        # ---------------- Step 4: Initialize output arrays ----------------
        zenith_angle_arr = np.full(lat_arr.shape, np.nan, dtype=np.float64)
        azimuth_angle_arr = np.full(lat_arr.shape, np.nan, dtype=np.float64)

        # ---------------- Step 5: Valid point mask ----------------
        # Besides NaN and exact zeros, reject coordinates outside the physical
        # range: sentinel values stored for pixels with no ground location
        # would otherwise produce meaningless angles.
        # NOTE(review): the `!= 0` tests also drop pixels lying exactly on the
        # equator or prime meridian -- kept as in the original; confirm intent.
        valid_mask = (
            (~np.isnan(lat_arr)) & (~np.isnan(lon_arr))
            & (lat_arr != 0) & (lon_arr != 0)
            & (np.abs(lat_arr) <= 90.0) & (np.abs(lon_arr) <= 360.0)
        )

        lat_valid = lat_arr[valid_mask]
        lon_valid = lon_arr[valid_mask]
        height_valid = height_arr[valid_mask]

        # ---------------- Step 6: Convert to radians and compute prime vertical radius N ----------------
        lat_rad = np.radians(lat_valid)
        lon_rad = np.radians(lon_valid)
        sin_lat = np.sin(lat_rad)
        cos_lat = np.cos(lat_rad)
        N = a / np.sqrt(1 - E2 * sin_lat**2)

        # ---------------- Step 7: Ground point Cartesian coordinates ----------------
        x_p = (N + height_valid) * cos_lat * np.cos(lon_rad)
        y_p = (N + height_valid) * cos_lat * np.sin(lon_rad)
        z_p = ((1 - E2) * N + height_valid) * sin_lat

        # ---------------- Step 8: Direction vector d = S - P ----------------
        x_d = x_s - x_p
        y_d = y_s - y_p
        z_d = z_s - z_p

        # ---------------- Step 9: Local ENU coordinates ----------------
        sin_lon = np.sin(lon_rad)
        cos_lon = np.cos(lon_rad)
        e = -sin_lon * x_d + cos_lon * y_d
        n = -sin_lat * cos_lon * x_d - sin_lat * sin_lon * y_d + cos_lat * z_d
        u =  cos_lat * cos_lon * x_d + cos_lat * sin_lon * y_d + sin_lat * z_d

        # ---------------- Step 10: Compute zenith (θ) and azimuth (φ) angles ----------------
        theta_deg = 90.0 - np.degrees(np.arctan2(u, np.sqrt(e**2 + n**2)))
        phi_deg = np.degrees(np.arctan2(e, n))
        # arctan2 yields (-180, 180]; shift negatives into the 0-360 range.
        phi_deg = np.where(phi_deg < 0, 360.0 + phi_deg, phi_deg)

        # ---------------- Step 11: Assign to output arrays ----------------
        zenith_angle_arr[valid_mask] = theta_deg
        azimuth_angle_arr[valid_mask] = phi_deg

        # ---------------- Step 12: Write to NetCDF ----------------
        # Zenith angle
        if 'satellite_zenith_angle' in nc.variables:
            zen_var = nc.variables['satellite_zenith_angle']
        else:
            lat_dim, lon_dim = nc.variables['latitude'].dimensions
            zen_var = nc.createVariable('satellite_zenith_angle', 'f8', (lat_dim, lon_dim))
        zen_var[:, :] = zenith_angle_arr
        zen_var.units = 'degree'
        zen_var.long_name = 'Viewing Zenith Angle (θ)'
        zen_var.description = 'Angle between local zenith and line-of-sight to FY-4B satellite'

        # Azimuth angle
        if 'satellite_azimuth_angle' in nc.variables:
            azi_var = nc.variables['satellite_azimuth_angle']
        else:
            lat_dim, lon_dim = nc.variables['latitude'].dimensions
            azi_var = nc.createVariable('satellite_azimuth_angle', 'f8', (lat_dim, lon_dim))
        azi_var[:, :] = azimuth_angle_arr
        azi_var.units = 'degree'
        azi_var.long_name = 'Viewing Azimuth Angle (φ)'
        azi_var.description = 'Azimuth of satellite direction in local horizon plane (0–360°)'

    # ---------------- Step 13: File is closed on leaving the `with` block ----------------
    print("✅ Viewing zenith (θ) and azimuth (φ) angles have been successfully written to the NetCDF file.")


# ---------------- Example Usage ----------------
# Add the satellite viewing angles to the product file, using the default
# sub-satellite longitude and orbit height.

nc_file_path = "E:/fengyun/code/output/result.nc"
compute_satellite_angles(nc_file=nc_file_path)


✅ Viewing zenith (θ) and azimuth (φ) angles have been successfully written to the NetCDF file.


In [12]:
"""
Atmospheric Refraction Correction for Satellite Zenith Angle
------------------------------------------------------------

This script defines functions to correct satellite viewing zenith angles
for the effect of atmospheric refraction, based on local air temperature
and pressure from a NetCDF file (e.g., FY-4B LUT).

Atmospheric refraction bends the light path toward the surface, making the
observed satellite appear slightly higher in the sky than its true
geometric position. This correction ensures more accurate satellite
geometry and radiative transfer modeling.

------------------------------------------------------------------------------
🔹 INPUT
------------------------------------------------------------------------------
The NetCDF file must contain the following variables:

- **satellite_zenith_angle**  (2D array, degrees)
  Uncorrected zenith angle from the ground to the satellite (0° = overhead, 90° = horizon)

- **t2m**  (2D array, Kelvin, K)
  2-meter air temperature

- **sp**  (2D array, Pascals, Pa)
  Surface air pressure

------------------------------------------------------------------------------
🔹 OUTPUT
------------------------------------------------------------------------------
The function creates or overwrites the following variable in the same NetCDF file:

- **satellite_zenith_angle_corrected**  (2D array, degrees)
  Satellite zenith angle corrected for atmospheric refraction effects.
  Units: degrees
  Dimensions: same as `satellite_zenith_angle`
  Fill value: NetCDF default for float32

------------------------------------------------------------------------------
"""

import numpy as np
from netCDF4 import Dataset, default_fillvals

def atmospheric_refraction_correction(local_pressure, local_temp,
                                      topocentric_elevation_angle_wo_atmosphere,
                                      atmos_refract=0.5667):
    """
    Return the refraction offset to add to a geometric elevation angle.

    Scalars and NumPy arrays are both accepted. The caller in this file
    passes pressure in hPa, temperature in degrees Celsius and elevation in
    degrees, and adds the result to the elevation angle.

    The offset is forced to zero wherever the elevation lies below
    ``-(0.26667 + atmos_refract)``.
    """
    elevation = topocentric_elevation_angle_wo_atmosphere

    # Elevations under this limit receive no correction at all.
    lower_limit = -1.0 * (0.26667 + atmos_refract)
    apply_correction = elevation >= lower_limit

    # Scale factors relative to the reference pressure and temperature.
    pressure_ratio = local_pressure / 1010.0
    temperature_ratio = 283.0 / (273 + local_temp)

    # Elevation-dependent denominator of the refraction formula.
    bent_angle = np.radians(elevation + 10.3 / (elevation + 5.11))
    denominator = 60 * np.tan(bent_angle)

    offset = pressure_ratio * temperature_ratio * 1.02 / denominator
    return offset * apply_correction


def correct_atmospheric_refraction(nc_file_path):
    """
    Apply atmospheric refraction correction to satellite zenith angle in a NetCDF file.

    Parameters
    ----------
    nc_file_path : str
        Path to the NetCDF file containing 'satellite_zenith_angle', 't2m', and 'sp'.

    Returns
    -------
    None
        'satellite_zenith_angle_corrected' (float32, degrees) is created or
        overwritten in the same file. Points where any input is NaN stay NaN.
    """
    # Open the file through a context manager so the handle is released even
    # if a variable is missing or a later step raises.
    with Dataset(nc_file_path, "r+") as nc:
        # Read variables
        zenith   = nc.variables["satellite_zenith_angle"][:]  # degrees
        temp_K   = nc.variables["t2m"][:]                     # Kelvin
        press_Pa = nc.variables["sp"][:]                      # Pascals

        # ---- Unit conversion ----
        temp_C    = temp_K - 273.15      # Kelvin → Celsius
        press_hPa = press_Pa / 100.0     # Pascals → hPa

        # ---- Convert to elevation angle ----
        elev = 90.0 - zenith

        # ---- Valid mask: all three inputs are not NaN ----
        valid = (~np.isnan(zenith)) & (~np.isnan(temp_C)) & (~np.isnan(press_hPa))

        # ---- Initialize corrected zenith angle array as NaN ----
        zenith_corrected = np.full_like(zenith, np.nan, dtype=np.float32)

        # ---- Apply correction only to valid grid points ----
        if np.any(valid):
            delta_e = atmospheric_refraction_correction(
                local_pressure=press_hPa[valid],
                local_temp=temp_C[valid],
                topocentric_elevation_angle_wo_atmosphere=elev[valid]
            )
            # The correction is added to the elevation angle, which lowers
            # the zenith angle by the same amount.
            elev_corrected = elev[valid] + delta_e
            zenith_corrected[valid] = 90.0 - elev_corrected

        # ---- Write the corrected zenith angle as a new variable ----
        if "satellite_zenith_angle_corrected" not in nc.variables:
            zen_var = nc.createVariable(
                "satellite_zenith_angle_corrected", "f4",
                nc.variables["satellite_zenith_angle"].dimensions,
                fill_value=default_fillvals["f4"]
            )
            zen_var.units = "degrees"
            zen_var.long_name = "satellite zenith angle corrected for atmospheric refraction"

        nc.variables["satellite_zenith_angle_corrected"][:] = zenith_corrected


# ==== Usage example ====
# Add the refraction-corrected zenith angle to the product file.
correct_atmospheric_refraction(nc_file_path="E:/fengyun/code/output/result.nc")
