In [2]:
import xml.etree.ElementTree as ET

def xml_to_dict(element):

    if len(element) == 0:
        return element.text
    
    result = {}
    
    for child in element:

        child_dict = xml_to_dict(child)

        if child.tag in result:
            if isinstance(result[child.tag], list):
                result[child.tag].append(child_dict)
            else:
                result[child.tag] = [result[child.tag], child_dict]
        else:
            result[child.tag] = child_dict
            
    return result

In [4]:
import os
import rasterio
import numpy as np
import rioxarray as rxr

def paramglint(ang: dict, solar_zn: float, view_zn: float, optical_depth_total: float, nw: float) -> dict:
    
    """
    Returns the Fresnel's reflectance and direct transmittance from atmosphere.
    """
        
    # Incidence angle:
    raa = abs(ang['solar_az'] - ang['view_az'])
    
    raa = np.where(raa > 180, 360 - raa, raa)
    
    cosTheta = np.sqrt((((np.cos(ang['solar_zn'] * (np.pi / 180)) * np.cos(ang['view_zn'] * (np.pi / 180))) + (np.sin(ang['solar_zn'] * (np.pi / 180)) * np.sin(ang['view_zn'] * (np.pi / 180)) * np.cos(raa * (np.pi / 180)))) + 1) / 2) # cos2Theta = 2cos^2(Theta) - 1:
    
    Theta = np.arccos(cosTheta)
    
    # Transmittance angle:
    Theta_t = np.arcsin((1 / nw) * np.sin(Theta))
    
    # Fresnel reflectance coefficient:
    rfresnel = 0.5 * (((np.sin(Theta - Theta_t) / np.sin(Theta + Theta_t)) ** 2) + ((np.tan(Theta - Theta_t) / np.tan(Theta + Theta_t)) ** 2))
    
    # Direct Transmittance:
    Tdir = (np.exp(-float(optical_depth_total) / np.cos(float(solar_zn) * (np.pi / 180)))) * (np.exp(-float(optical_depth_total) / np.cos(float(view_zn) * (np.pi / 180))))
    
    return {'rFresnel': rfresnel, 'Tdir': Tdir}

   
image_path = r"G:\Meu Drive\PhD\ADJCORR\21HVB"
band_path = [i for i in os.listdir(image_path)]

output_path = r"G:\Meu Drive\PhD\ADJCORR\out"

tree = ET.parse(r'G:\Meu Drive\PhD\ADJCORR\21HVB\MTD.xml')

root = tree.getroot()
metadata_6sv = xml_to_dict(root)

# Water refraction index --average per band:
if metadata_6sv["General_Info"]["satellite"] == "MSI_S2":
    w_refIndex = {0: 1.335982329,1: 1.332901801,2: 1.329898916,3: 1.326184102,4: 1.313277663}
    swir_ref_band = ["B11"]
    corr_bands = ["B02", "B2", "B03", "B3", "B04", "B8A", "B11"]
    
else:
    w_refIndex = {0: 1.336521293,1: 1.332967903,2: 1.330137378,3: 1.32620903,4: 1.31328015}
    swir_ref_band = ["B5", "B05"]
    corr_bands = ["B02", "B2", "B03", "B3", "B04", "B5", "B05", "B6", "B06"]
    
# Glint mask based on SWIR band:
swir_band = next((band for band in band_path if "B11" in band or "B6" in band or "B06" in band), None)
red_band = next((band for band in band_path if "B04" in band or "B4" in band), None)

xda_red = rxr.open_rasterio(os.path.join(image_path, red_band))
swir_20 = rxr.open_rasterio(os.path.join(image_path, swir_band))
arr1020 = swir_20.rio.reproject_match(xda_red).values.astype(float)
gmask = np.where(arr1020 >= 0.03, 1, 0) # default --threshold equal to 0.005.

# Angle images -> OAA, OZA, SAA, and SZA:
ang = {}
for i in metadata_6sv["InputData"]["geometry"]["B0"].keys():
    angle_value = metadata_6sv["InputData"]["geometry"]["B0"][i]
    arr = np.full_like(arr1020, angle_value, dtype=float)
    ang[i] = np.where(arr1020 == -9999, np.nan, arr)

# Extracts the fresnel reflectance and transmittance for reference band (1020nm):
filtered_dict = list({key: value for key, value in metadata_6sv["General_Info"]["bandname"].items() if any(band in value for band in swir_ref_band)}.keys())[0]

amtcor_input = metadata_6sv["InputData"]["sixSV_params"][filtered_dict]

reference = paramglint(ang, metadata_6sv["InputData"]["geometry"][filtered_dict]['solar_zn'], metadata_6sv["InputData"]["sixSV_params"][filtered_dict]['view_z'], 
                       amtcor_input['optical_depth_total'], w_refIndex[4])

# Glint correction of each band
filtered_dict = {key: value for key, value in metadata_6sv["General_Info"]["bandname"].items() if any(band in value for band in corr_bands)}
updated_dict = {key: value.replace(".jp2", ".tif") for key, value in filtered_dict.items()}
filtered_dict = updated_dict
    
band_names = list(filtered_dict.values())
band_index = list(filtered_dict.keys())
    
for i in range(len(band_index)):

    band_key = band_index[i]

    target = paramglint(ang, metadata_6sv["InputData"]["geometry"][band_key]['solar_zn'], metadata_6sv["InputData"]["sixSV_params"][band_key]['view_z'], metadata_6sv["InputData"]["sixSV_params"][band_key]['optical_depth_total'], w_refIndex[i])
    
    array_path = os.path.join(image_path, filtered_dict[band_key])
    
    with rasterio.open(array_path) as src:
        arr = src.read(1).astype(float)
        profile = src.profile

    r_glint = arr1020 * (target['Tdir'] / reference['Tdir']) * (target['rFresnel'] / reference['rFresnel'])
    
    m_glint = gmask == 1
    r_corr = np.copy(arr)
    
    m_glint = m_glint[0,:,:]
    r_glint = r_glint[0,:,:]
    
    r_corr[m_glint] = arr[m_glint] - r_glint[m_glint]

    r_corr[(r_corr > -0.2) & (r_corr < 0)] = 0.0001
    
    r_corr_final = np.where((np.isnan(r_corr)) | (r_corr < 0), arr, r_corr)
    
    mask_nodata = np.where((r_corr_final < 0) | (r_corr_final > 1) | (r_corr_final == np.nan), 0, 1)
    arr_integer = (r_corr_final * 10000).astype(np.int16)
    arr_integer = np.where(mask_nodata == 0, -9999, arr_integer)
    
    # Save corrected image
    profile.update(
        dtype=r_corr_final.dtype.name,
        count=1,
        compress='lzw',
        driver='GTiff',
        nodata=-9999
    )
    
    output_path_save = os.path.join(output_path, f"glintcorr_{band_names[i]}")
    
    with rasterio.open(output_path_save, 'w', **profile) as dest_dataset:
        dest_dataset.write(arr_integer, 1)

IndexError: boolean index did not match indexed array along axis 0; size of axis is 5490 but size of corresponding boolean axis is 10980