In [None]:
"""Check SPEAR-MED precipitation output for rogue pixels.

The script scans every netCDF (*.nc) file in the input directory and records
each processed file in a log file. A rogue pixel is any grid value above the
configured threshold. When one is found, the file name and timestep are flagged
in the log and that timestep is plotted on a map so the pixel's location can be
seen."""

# Import modules
import os
import glob
import logging
from datetime import timedelta
import netCDF4
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from tqdm import tqdm

# Run-time settings for the rogue-pixel scan (planned to move into a YAML file).
CONFIG = dict(
    # --- Input / output locations ---
    input_directory="data",      # directory searched for *.nc files
    output_directory="outputs",  # destination for ALERT_*.png plots
    log_file="output.log",       # per-run log of processed files and alerts
    # --- Names of the netCDF variable and its coordinates ---
    variable_name="precip",
    time_dim="time",
    lat_dim="lat",
    lon_dim="lon",
    # Values strictly greater than this are flagged. The comparison is done in
    # the file's native units (main() reports them as kg/m2/s).
    threshold=1000.0,
    # --- Plot appearance ---
    dpi=150,
    cmap="Blues",
    figure_size=(12, 8),
)

# The plotting function
def create_alert_plot(data, lons, lats, date, filename, t_index, *, accumulation_hours=6):
    """Save a global map of a single timestep that exceeded the threshold.

    The figure is written to ``CONFIG["output_directory"]`` as
    ``ALERT_<filename>_timestep_<t_index:05d>.png``. Failures are logged with
    a traceback instead of being raised, so one bad plot does not abort the
    scan of the remaining timesteps/files.

    Args:
        data: Precipitation field for one timestep, passed to ``pcolormesh``
            together with ``lons``/``lats``. Assumed to be a rate in
            kg m-2 s-1 (the unit main() reports); confirm against the files.
        lons: Longitude coordinates of ``data``.
        lats: Latitude coordinates of ``data``.
        date: Start of the timestep. Must support ``+ timedelta`` and
            ``strftime`` (e.g. the datetime/cftime objects from ``num2date``).
        filename: Source file name; used in the title and the output file name.
        t_index: Timestep index; used in the output file name and log messages.
        accumulation_hours: Length of one timestep in hours (default 6). It is
            used both to convert the rate into an accumulation and to label
            the time window, so the two cannot drift apart.
    """
    # Defined before the try so the finally clause never touches an unbound name
    # when the failure happens before the figure exists.
    fig = None
    try:
        # A rate in kg m-2 s-1 times the step length in seconds gives kg m-2,
        # which equals mm of water depth accumulated over the timestep.
        precip_total_mm = data * (accumulation_hours * 3600)

        fig = plt.figure(figsize=CONFIG["figure_size"])
        ax = plt.axes(projection=ccrs.Robinson())
        ax.set_global()

        mesh = ax.pcolormesh(
            lons, lats, precip_total_mm,
            transform=ccrs.PlateCarree(),  # data coordinates are plain lon/lat
            cmap=CONFIG["cmap"],
        )

        ax.coastlines()
        ax.gridlines(draw_labels=False)

        cbar = plt.colorbar(mesh, orientation='vertical', pad=0.02, aspect=30, shrink=0.8)
        cbar.set_label(f'Precipitation (mm/{accumulation_hours}-hr)')

        start_time = date
        end_time = start_time + timedelta(hours=accumulation_hours)
        title = (
            f"High Precipitation Found\n"
            f"Source: {filename} ({start_time.strftime('%Y-%m-%d %H:%M')}-{end_time.strftime('%H:%M')})"
        )
        ax.set_title(title, pad=20)

        plot_filename = f"ALERT_{filename}_timestep_{t_index:05d}.png"
        save_path = os.path.join(CONFIG["output_directory"], plot_filename)
        fig.savefig(save_path, dpi=CONFIG["dpi"], bbox_inches='tight')

    except Exception as e:
        # Best-effort plotting: record the failure with its traceback and let
        # the caller continue with the next timestep.
        logging.exception(f"Failed to create plot for {filename} at timestep {t_index}: {e}")
    finally:
        # Always release the figure so long scans do not accumulate open figures.
        if fig is not None:
            plt.close(fig)

# Main flagging function
def _scan_file(filepath):
    """Scan one netCDF file and plot every timestep that contains a flagged value.

    A timestep is flagged when any element of ``CONFIG["variable_name"]`` at
    that time exceeds ``CONFIG["threshold"]``. Each flagged timestep is logged
    and handed to ``create_alert_plot``.

    Args:
        filepath: Path to the netCDF file.

    Returns:
        True if at least one value above the threshold was found.

    Raises:
        Whatever netCDF4 raises for an unreadable file. Also, e.g., KeyError
        when a configured variable is missing, or AttributeError when the time
        variable has no ``units`` attribute. The caller logs and skips the file.
    """
    basename = os.path.basename(filepath)
    threshold = CONFIG["threshold"]
    found = False
    with netCDF4.Dataset(filepath, 'r') as ds:
        precip_var = ds.variables[CONFIG["variable_name"]]  # "precip" by default
        lats = ds.variables[CONFIG["lat_dim"]][:]
        lons = ds.variables[CONFIG["lon_dim"]][:]
        time_var = ds.variables[CONFIG["time_dim"]]
        # Fall back to the 'standard' calendar when the attribute is absent.
        calendar = getattr(time_var, 'calendar', 'standard')
        dates = netCDF4.num2date(time_var[:], time_var.units, calendar)

        for t_index, date in enumerate(dates):
            # Slice one timestep at a time rather than reading the whole variable.
            data_slice = precip_var[t_index, :, :]
            if np.any(data_slice > threshold):
                found = True
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f"High value found in {basename} at timestep {t_index}.")
                logging.info(f"ALERT in {basename}: Plotting timestep {t_index}.")
                create_alert_plot(data_slice, lons, lats, date, basename, t_index)
    return found


def main():
    """Scan every ``*.nc`` file in the input directory for rogue pixels.

    Writes one summary line per file to ``CONFIG["log_file"]`` (opened with
    filemode 'w'). Writes an ALERT plot into ``CONFIG["output_directory"]``
    for every flagged timestep. A file that cannot be processed is logged
    with its traceback and skipped, and the scan continues with the next file.
    """
    os.makedirs(CONFIG["output_directory"], exist_ok=True)

    # NOTE: basicConfig is a no-op if the root logger already has handlers
    # (e.g. when main() is re-run in the same notebook kernel).
    logging.basicConfig(
        filename=CONFIG["log_file"],
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        filemode='w',
    )

    # Sorted so files are processed, and logged, in a reproducible order.
    search_path = os.path.join(CONFIG["input_directory"], '*.nc')
    file_list = sorted(glob.glob(search_path))

    if not file_list:
        message = f"No .nc files found in '{CONFIG['input_directory']}'."
        print(f"{message} Exiting.")
        logging.warning(message)
        return

    print(f"Found {len(file_list)} files to process.")

    threshold = CONFIG["threshold"]
    for filepath in tqdm(file_list, desc="Processing files"):
        basename = os.path.basename(filepath)
        try:
            high_value_found_in_file = _scan_file(filepath)
        except Exception as e:
            # Top-level boundary: record the failure (with traceback) and move on.
            error_message = f"Failed to process {basename}. Error: {e}"
            tqdm.write(error_message)
            logging.exception(error_message)
            continue

        if high_value_found_in_file:
            log_message = f"{basename} processed, values greater than {threshold} kg/m2/s found!"
        else:
            log_message = f"{basename} processed, no value greater than {threshold} kg/m2/s found"
        logging.info(log_message)

    print(f"\nProcessing complete. Check '{CONFIG['log_file']}' for details.")


if __name__ == "__main__":
    main()