# In [None]:
# ================================================================================================
#   PYTHON SCRIPT: HIGH-PERFORMANCE NETCDF TIME-SERIES EXTRACTOR (SEQUENTIAL & RESILIENT)
# ================================================================================================

# ================================================================================================
#   STEP 1: IMPORT MODULES
# ================================================================================================

import xarray as xr
import pandas as pd
import numpy as np
import pathlib
import time
import shutil
import os
import glob
from typing import Dict, Any, List, Tuple
from tqdm import tqdm # Used for clear status bars

# ================================================================================================
#   STEP 2: USER-EDITABLE SETTINGS (MAXIMUM FLEXIBILITY)
# ================================================================================================

# --- A. FILE PATHS & DIRECTORIES ---

# 1. INPUT DIRECTORY: folder that contains the NetCDF files to process.
# NOTE(review): "dscharge" looks like a typo for "discharge"; only change the string
# if the folder on disk is renamed as well.
NC_INPUT_DIRECTORY: str = "Observed_daily_dscharge/Raw_netcdf"

# 2. POINT FILE PATH: CSV with one row per outlet/subbasin holding its ID, longitude
#    and latitude (the column names are configured in section C).
POINT_CSV_FILE_PATH: str = "Outlet_subbasin.csv"

# --- DERIVED OUTPUT PATHS (DO NOT EDIT THESE) ---
# Both output folders live inside the input directory. NC_FILE_PATTERN is matched
# non-recursively, so the CSVs written there are never picked up as NetCDF inputs.
INPUT_ROOT_DIR: pathlib.Path = pathlib.Path(NC_INPUT_DIRECTORY)
# Intermediate per-file CSVs, laid out as <YEARLY_DATA_DIR>/<nc file stem>/<point id>.csv.
# Created on demand; removed at the end of a run only when CLEANUP_YEARLY_DIR is True.
YEARLY_DATA_DIR: str = str(INPUT_ROOT_DIR / "01_Intermediate_Yearly_CSVs")
# Final output: one merged time-series CSV per point, named <point id>.csv.
FINAL_MERGED_OUTPUT_DIR: str = str(INPUT_ROOT_DIR / "02_Final_Merged_Streamflow")

# --- B. NETCDF FILE SELECTION ---
# Glob pattern for the NetCDF files inside NC_INPUT_DIRECTORY (e.g. "*.nc" for all
# files, or "ERA5L_*.nc").
NC_FILE_PATTERN: str = "*.nc"
# Optional subset of file NAMES (not paths) to process. Leave as [] to process every
# file matching NC_FILE_PATTERN, or list names to restrict the run, e.g.
# ["ERA5L_9km_BG_daily_streamflow_1951.nc", "ERA5L_9km_BG_daily_streamflow_1952.nc"]
# for a quick test.
NC_FILES_TO_PROCESS: List[str] = [] # [] = all matching files

# --- C. INPUT DATA CONFIGURATION ---
# Column in the point CSV that uniquely identifies each subbasin/outlet. Its value,
# converted to a string, becomes the output CSV file name.
POINT_ID_FIELD: str = "NAME"
# Columns in the point CSV holding the coordinates. They are matched (nearest grid
# cell) against the NetCDF coordinates named 'lon' and 'lat', so they must use the
# same coordinate convention as the NetCDF grid.
LONGITUDE_FIELD: str = "LONG"
LATITUDE_FIELD: str = "LAT"
# Name of the data variable extracted from each NetCDF file.
NETCDF_VARIABLE: str = "flw"

# --- D. OUTPUT DATA CONFIGURATION ---
# Renames the internal 'time'/'value' columns to the headers of the final merged CSVs.
OUTPUT_COLUMN_NAMES: Dict[str, str] = {"time": "DATE", "value": "DISCHARGE"}

# --- E. RESILIENCE & CLEANUP ---
# True: re-extract and re-merge everything, overwriting intermediate and final CSVs.
# False: smart resume. Existing intermediate CSVs are reused, and final CSVs that pass
# the record-count check are not rebuilt.
OVERWRITE_EXISTING: bool = False
# If True, YEARLY_DATA_DIR is removed at the end of main() (never after a fatal error).
CLEANUP_YEARLY_DIR: bool = False

# Module-level cache of opened xarray Datasets keyed by file path (filled by
# open_nc_dataset_cached), so each NetCDF file is opened only once.
worker_datasets: Dict[str, xr.Dataset] = {}

# ================================================================================================
#   STEP 3: CORE WORKER FUNCTIONS (SEQUENTIAL PROCESSING LOGIC)
# ================================================================================================

def open_nc_dataset_cached(nc_file_path: str) -> xr.Dataset:
    """
    Return the xarray Dataset for ``nc_file_path``, opening the file at most once.

    Opened datasets are memoised in the module-level ``worker_datasets`` dict
    (an ordinary process-wide dict, not a thread-local store), keyed by the path
    string. Later calls with the same path return the already-open dataset.
    This function never closes the handles it caches.
    """
    # Fast path: the file was opened by an earlier call.
    if nc_file_path in worker_datasets:
        return worker_datasets[nc_file_path]

    dataset = xr.open_dataset(nc_file_path, engine='netcdf4')
    worker_datasets[nc_file_path] = dataset
    return dataset

def calculate_expected_length(nc_file_paths: List[str]) -> int:
    """
    Return the total number of time steps across ``nc_file_paths``.

    Each file is opened through ``open_nc_dataset_cached`` and the length of its
    ``time`` variable is summed. If any file cannot be opened or has no ``time``
    variable, the total is unknowable and ``-1`` is returned instead. Callers
    treat ``-1`` as "skip record-count checks".
    """
    record_count = 0
    for path in nc_file_paths:
        try:
            record_count += len(open_nc_dataset_cached(path)['time'])
        except Exception:
            return -1
    return record_count

def extraction_worker(point_data: pd.Series, nc_file_path: str) -> Tuple[bool, str, str]:
    """
    Phase 1 worker: extract the time series of one point from one NetCDF file.

    The grid cell nearest to the point's coordinates is selected from
    ``NETCDF_VARIABLE`` and written to
    ``YEARLY_DATA_DIR/<nc file stem>/<point id>.csv`` with the columns
    ``time`` and ``value``.

    Args:
        point_data: One row of the points table. It must contain
            ``POINT_ID_FIELD``, ``LONGITUDE_FIELD`` and ``LATITUDE_FIELD``.
        nc_file_path: Path of the NetCDF file to read.

    Returns:
        ``(success, point_id, status_message)``. ``success`` is False when an
        error occurred after the point ID was read. In that case the message
        names the exception class and its text, and the error is not re-raised.
    """
    point_id = str(point_data[POINT_ID_FIELD])
    nc_stem = pathlib.Path(nc_file_path).stem
    label = f"{point_id} & {nc_stem}.csv"

    try:
        lon = point_data[LONGITUDE_FIELD]
        lat = point_data[LATITUDE_FIELD]

        yearly_dir = pathlib.Path(YEARLY_DATA_DIR) / nc_stem
        yearly_dir.mkdir(parents=True, exist_ok=True)
        yearly_output_path = yearly_dir / f"{point_id}.csv"

        # Resume support. The CSV only ever appears through the atomic rename
        # below, so an existing file is complete and can be reused.
        if not OVERWRITE_EXISTING and yearly_output_path.exists():
            return True, point_id, f"{label} | Status: SUCCESS | Action: SKIPPED"

        ds = open_nc_dataset_cached(nc_file_path)

        # Nearest-neighbour lookup; the dataset must have coordinates named lon/lat.
        data_slice = ds[NETCDF_VARIABLE].sel(lon=lon, lat=lat, method='nearest')

        # to_dataframe() also carries the selected lon/lat as columns; only
        # time/value are written.
        df = data_slice.to_dataframe(NETCDF_VARIABLE).reset_index()
        df = df.rename(columns={NETCDF_VARIABLE: 'value'})

        # Write to a temporary file and rename it into place. An interrupted run
        # then never leaves a truncated CSV that the resume check would skip.
        tmp_path = yearly_output_path.with_name(yearly_output_path.name + ".tmp")
        df[['time', 'value']].to_csv(tmp_path, index=False)
        os.replace(tmp_path, yearly_output_path)

        return True, point_id, f"{label} | Status: SUCCESS | Action: EXTRACTED"

    except Exception as e:
        return False, point_id, f"{label} | Status: FAILURE | Action: {e.__class__.__name__} ({e})"

def merge_worker(point_id: str, nc_file_stems: List[str], expected_total_records: int) -> Tuple[bool, str, str]:
    """
    Phase 2 worker: merge all yearly CSVs of one point into a single time series.

    Reads ``YEARLY_DATA_DIR/<stem>/<point_id>.csv`` for every stem and
    concatenates them in ``nc_file_stems`` order. The rows are then sorted by
    time, and duplicate timestamps are dropped so that the row from the later
    stem wins. Columns are renamed via ``OUTPUT_COLUMN_NAMES`` and the result is
    written to ``FINAL_MERGED_OUTPUT_DIR/<point_id>.csv``.

    Args:
        point_id: Identifier of the point; also the CSV file stem.
        nc_file_stems: Stems of the NetCDF files, i.e. the yearly sub-directory names.
        expected_total_records: Expected number of data rows in the merged file.
            A value <= 0 means unknown and disables the record-count checks.

    Returns:
        ``(success, point_id, message)``. On success the message is either the
        output path or starts with "Skipped" when an already verified final
        file exists. A record-count mismatch after merging is a failure, but
        the merged file is still written.
    """
    final_output_path = pathlib.Path(FINAL_MERGED_OUTPUT_DIR) / f"{point_id}.csv"

    # Smart resume: keep an existing final file only if its line count matches.
    if not OVERWRITE_EXISTING and expected_total_records > 0 and final_output_path.exists():
        try:
            with open(final_output_path, 'r', encoding='utf-8') as f:
                actual_lines = sum(1 for _ in f)
        except (OSError, UnicodeDecodeError):
            actual_lines = -1  # unreadable file -> rebuild it below
        if actual_lines == expected_total_records + 1:  # +1 for the header row
            return True, point_id, f"Skipped (Record count verified: {final_output_path.name})"

    try:
        yearly_dfs = []
        missing_stems = []
        for stem in nc_file_stems:
            yearly_path = pathlib.Path(YEARLY_DATA_DIR) / stem / f"{point_id}.csv"
            if yearly_path.exists():
                yearly_dfs.append(pd.read_csv(yearly_path, parse_dates=['time'], dtype={'value': np.float64}))
            else:
                missing_stems.append(stem)

        if not yearly_dfs:
            return False, point_id, f"Error: No yearly data found for Point ID {point_id}."

        merged_df = pd.concat(yearly_dfs, ignore_index=True)
        # A stable sort keeps rows with equal timestamps in concatenation order, so
        # keep='last' deterministically prefers the later file. The default
        # quicksort gives no such guarantee.
        merged_df = merged_df.sort_values(by='time', kind='stable').drop_duplicates(subset='time', keep='last')
        merged_df = merged_df.rename(columns=OUTPUT_COLUMN_NAMES)

        final_output_path.parent.mkdir(parents=True, exist_ok=True)
        # Atomic replace: a crash mid-write must not leave a truncated final file.
        tmp_path = final_output_path.with_name(final_output_path.name + ".tmp")
        merged_df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, final_output_path)

        # Post-merge data-integrity check.
        if expected_total_records > 0 and len(merged_df) != expected_total_records:
            detail = f" (missing yearly files: {', '.join(missing_stems)})" if missing_stems else ""
            return False, point_id, (
                f"Merge completed but record count mismatch: Expected {expected_total_records}, "
                f"Found {len(merged_df)}{detail}"
            )

        return True, point_id, str(final_output_path)

    except Exception as e:
        return False, point_id, f"Error merging data: {e}"

# ================================================================================================
#   STEP 4: MAIN WORKFLOW PHASES (SEQUENTIAL EXECUTION)
# ================================================================================================

def format_bytes(size_in_bytes: int) -> str:
    """
    Render a byte count as a short human-readable string.

    Sizes of at least 1048576 bytes are shown in "MB" and sizes of at least
    1024 bytes in "KB", both with two decimals. Smaller sizes are shown as
    "<n> Bytes". ``None`` or a negative value yields "N/A".
    """
    kib = 1024
    mib = kib * kib
    if size_in_bytes is None or size_in_bytes < 0:
        return "N/A"
    # Largest unit first so the first match is the most compact representation.
    for unit_size, unit_name in ((mib, "MB"), (kib, "KB")):
        if size_in_bytes >= unit_size:
            return f"{size_in_bytes / unit_size:.2f} {unit_name}"
    return f"{size_in_bytes} Bytes"

def print_verification_report(report_data: List[Dict[str, str]]):
    """
    Print the verification rows as a fixed-width table.

    Every row needs the keys 'FileName', 'Size', 'Status' and 'Action'. The
    action text is upper-cased for display and the size is right-aligned.
    """
    name_w, size_w, status_w, action_w = 45, 12, 9, 15
    # Four columns plus three " | " separators of three characters each.
    rule = "-" * (name_w + size_w + status_w + action_w + 9)

    def render(name, size, status, action) -> str:
        return f"{name:<{name_w}} | {size:>{size_w}} | {status:<{status_w}} | {action:<{action_w}}"

    print("\n\n✅ FINAL VERIFICATION REPORT")
    print(rule)
    print(render('File Name', 'Size', 'Status', 'Action/Reason'))
    print(rule)
    for entry in report_data:
        print(render(entry['FileName'], entry['Size'], entry['Status'], entry['Action'].upper()))
    print(rule)

def prepare_data_and_files() -> Tuple[pd.DataFrame, List[str]]:
    """
    Load the points, discover the NetCDF inputs and decide which points still need extraction.

    Steps:
      1. Read POINT_CSV_FILE_PATH and check that the ID and coordinate columns exist.
         The ID column is converted to str, because IDs become output file names.
      2. Glob NC_FILE_PATTERN inside NC_INPUT_DIRECTORY and optionally restrict the
         result to the names in NC_FILES_TO_PROCESS. Requested names that are not
         found are reported. The list is then sorted.
      3. Print a per-file report with the number of time steps.
      4. With OVERWRITE_EXISTING False, drop the points whose final merged CSV already
         looks complete (see ``_find_completed_point_ids``).

    Returns:
        ``(points_to_process, nc_files)``: the points that still need extraction
        and the sorted NetCDF paths.

    Raises:
        ValueError: if a configured column is missing from the point CSV.
    """
    points_df = pd.read_csv(POINT_CSV_FILE_PATH)

    # Fail fast with a clear message instead of one KeyError per extraction task.
    required_columns = [POINT_ID_FIELD, LONGITUDE_FIELD, LATITUDE_FIELD]
    missing_columns = [c for c in required_columns if c not in points_df.columns]
    if missing_columns:
        raise ValueError(
            f"Point CSV '{POINT_CSV_FILE_PATH}' is missing required column(s) {missing_columns}; "
            f"available columns: {list(points_df.columns)}"
        )

    points_df[POINT_ID_FIELD] = points_df[POINT_ID_FIELD].astype(str)
    print(f"Loaded {len(points_df)} points from {pathlib.Path(POINT_CSV_FILE_PATH).name}.")

    all_nc_files = glob.glob(str(pathlib.Path(NC_INPUT_DIRECTORY) / NC_FILE_PATTERN))

    if NC_FILES_TO_PROCESS:
        wanted_names = set(NC_FILES_TO_PROCESS)
        nc_files = [f for f in all_nc_files if pathlib.Path(f).name in wanted_names]
        found_names = {pathlib.Path(f).name for f in nc_files}
        not_found = [name for name in NC_FILES_TO_PROCESS if name not in found_names]
        if not_found:
            print(f"⚠️ Requested NetCDF file(s) not found in {NC_INPUT_DIRECTORY} "
                  f"(pattern '{NC_FILE_PATTERN}'): {not_found}")
    else:
        nc_files = all_nc_files
    nc_files.sort()

    total_expected_records = _print_input_file_report(nc_files)
    print(f"Total Expected Records (All Files): {total_expected_records}")
    print(f"Found {len(nc_files)} NetCDF files for processing.")
    if not nc_files:
        print(f"⚠️ No NetCDF files matched '{NC_FILE_PATTERN}' in {NC_INPUT_DIRECTORY}.")

    if OVERWRITE_EXISTING:
        return points_df.copy(), nc_files

    # --- Smart Resume Filtering ---
    processed_set = _find_completed_point_ids(total_expected_records)
    # Keep the points whose FINAL merged file is missing or invalid.
    points_to_process = points_df[~points_df[POINT_ID_FIELD].isin(processed_set)].copy()

    if len(points_to_process) < len(points_df):
        skipped_count = len(points_df) - len(points_to_process)
        print(f"Smart Resume Enabled: Skipping {skipped_count}/{len(points_df)} points for initial extraction.")
    else:
        print("Smart Resume: No valid final merged files found for initial skipping.")

    return points_to_process, nc_files


def _print_input_file_report(nc_files: List[str]) -> int:
    """Print the Task 1.1 table (name, time steps, open status) and return the time steps of the files that opened."""
    col_widths = [45, 12, 10]
    rule = "-" * (sum(col_widths) + 6)

    print("\n" + "--- [ Task 1.1: Input NetCDF File Report ] ---")
    print(rule)
    print(f"{'File Name':<{col_widths[0]}} | {'Records':>{col_widths[1]}} | {'Status':<{col_widths[2]}}")
    print(rule)

    total_expected_records = 0
    open_errors = []
    for nc_path in nc_files:
        nc_file_name = pathlib.Path(nc_path).name
        try:
            num_records = len(open_nc_dataset_cached(nc_path)['time'])
        except Exception as e:
            # Keep going so every file is listed; the reason is printed under the table.
            num_records, status = "ERROR", "FAILED"
            open_errors.append(f"{nc_file_name}: {e.__class__.__name__}: {e}")
        else:
            total_expected_records += num_records
            status = "LOADED"
        print(f"{nc_file_name:<{col_widths[0]}} | {str(num_records):>{col_widths[1]}} | {status:<{col_widths[2]}}")

    print(rule)
    for error in open_errors:
        print(f"  ! {error}")
    return total_expected_records


def _find_completed_point_ids(total_expected_records: int) -> set:
    """
    Return the stems of final merged CSVs that look complete.

    A file counts as complete when it has exactly ``total_expected_records + 1`` lines
    (data plus header). If the expected count is unknown (<= 0), a size above
    1024 bytes is used as a heuristic instead.
    """
    final_dir = pathlib.Path(FINAL_MERGED_OUTPUT_DIR)
    if not final_dir.exists():
        return set()

    expected_lines = total_expected_records + 1 if total_expected_records > 0 else -1
    completed = set()
    for file_path in final_dir.glob("*.csv"):
        if expected_lines > 0:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    actual_lines = sum(1 for _ in f)
            except (OSError, UnicodeDecodeError):
                continue  # unreadable -> treat as incomplete so the point is rebuilt
            if actual_lines == expected_lines:
                completed.add(file_path.stem)
        elif file_path.stat().st_size > 1024:
            completed.add(file_path.stem)
    return completed

def run_extraction_phase(points_to_process: pd.DataFrame, nc_files: List[str]) -> bool:
    """
    Phase 1: run ``extraction_worker`` for every (point, NetCDF file) pair, sequentially.

    The tqdm bar description shows the most recent worker message. That text is
    overwritten by the next task, so every failed task is also written above the
    bar (``tqdm.write`` keeps the bar intact). Otherwise the reason would be lost.

    Args:
        points_to_process: Points that still need extraction, one row per point.
        nc_files: NetCDF file paths to read.

    Returns:
        True when no task failed, including when there was nothing to do.
    """
    print("\n" + "="*80)
    print("PHASE 1: SEQUENTIAL DATA EXTRACTION")
    print("="*80)

    # Every point is paired with every NetCDF file; the row label is not needed.
    extraction_tasks = [
        (point_data, nc_file_path)
        for _, point_data in points_to_process.iterrows()
        for nc_file_path in nc_files
    ]

    total_tasks = len(extraction_tasks)
    if total_tasks == 0:
        print("Sub-Process: No points or NC files to process. Skipping Phase 1.")
        return True

    print(f"Sub-Process: Initializing {total_tasks} extraction tasks for sequential execution...")

    # --- Real-Time Report Header ---
    print("\n✅ REAL-TIME EXTRACTION STATUS")
    print("-" * 80)
    print("| {:<50} | {:<10} | {:<14}".format('Point ID & File Stem', 'Status', 'Action'))
    print("-" * 80)

    successful_tasks = 0
    failed_messages: List[str] = []

    with tqdm(extraction_tasks, total=total_tasks, desc="Extraction Progress", unit="task", leave=True) as pbar:
        for point_data, nc_file_path in pbar:
            success, _, message = extraction_worker(point_data, nc_file_path)
            # Message format: "<point id> & <nc stem>.csv | Status: ... | Action: ..."
            pbar.set_description(f"Extraction Progress | LAST: {message}")
            if success:
                successful_tasks += 1
            else:
                failed_messages.append(message)
                pbar.write(f"| {message}")

    if failed_messages:
        print(f"\n⚠️ Phase 1 completed with {len(failed_messages)} failed tasks (listed above).")
    else:
        print(f"\n✅ Phase 1 successfully completed. Total tasks: {successful_tasks}.")

    return not failed_messages

def run_merging_phase(unique_point_ids: np.ndarray, nc_files: List[str]) -> bool:
    """
    Phase 2: merge the yearly CSVs of every point into one final CSV, sequentially.

    Args:
        unique_point_ids: IDs of the points to merge.
        nc_files: NetCDF paths. Their stems name the yearly sub-directories, and their
            time steps give the expected record count.

    Returns:
        True when no merge failed, including when there was nothing to do. A merge
        whose record count does not match the expectation counts as a failure.
        Failure reasons are printed as they occur.
    """
    print("\n" + "="*80)
    print("PHASE 2: SEQUENTIAL TIME-SERIES MERGING")
    print("="*80)

    if not pathlib.Path(YEARLY_DATA_DIR).exists():
        print("Sub-Process: Intermediate yearly data directory does not exist. Skipping Phase 2.")
        return True

    nc_file_stems = [pathlib.Path(f).stem for f in nc_files]

    expected_records = calculate_expected_length(nc_files)
    if expected_records > 0:
        print(f"Sub-Process: Calculated expected time steps: {expected_records} records.")
    else:
        print("Sub-Process: Could not determine expected time steps. Skipping robust record count check for merges.")

    merging_tasks = [(str(point_id), nc_file_stems, expected_records) for point_id in unique_point_ids]
    total_tasks = len(merging_tasks)

    if total_tasks == 0:
        print("Sub-Process: No points require merging. Skipping Phase 2.")
        return True

    print(f"Sub-Process: Initializing {total_tasks} merging tasks for sequential execution...")

    successful_merges = 0
    skipped_merges = 0
    failures: List[Tuple[str, str]] = []

    with tqdm(merging_tasks, total=total_tasks, desc="Merging Progress") as pbar:
        for task in pbar:
            success, point_id, message = merge_worker(*task)
            if not success:
                failures.append((point_id, message))
                # merge_worker's reason is shown nowhere else, so surface it now.
                pbar.write(f"  ✗ {point_id}: {message}")
            elif message.startswith("Skipped"):
                skipped_merges += 1
            else:
                successful_merges += 1

    if failures:
        print(f"\n⚠️ Phase 2 completed with {len(failures)} failed merge tasks (reasons listed above).")
    else:
        print(f"\n✅ Phase 2 successfully completed. Merged {successful_merges} new files.")
    if skipped_merges:
        print(f"Sub-Process: Skipped {skipped_merges} final files whose record count was already verified.")

    return not failures

def run_verification_phase(points_df: pd.DataFrame, nc_files: List[str]) -> List[Dict[str, str]]:
    """
    Phase 3: check the final merged CSV of every point and build the report rows.

    Status/action per unique point ID:
      * SUCCESS / VERIFIED    - line count equals expected records + 1 (header);
      * WARNING / INCOMPLETE  - line count differs from that expectation;
      * SUCCESS / MERGED      - expected count unknown and file larger than 1024 bytes;
      * FAILURE / NOT FOUND, ZERO SIZE, TOO SMALL - missing, empty, or (count unknown)
        at most 1024 bytes;
      * ERROR / READ FAIL     - the file could not be read.

    A small file whose line count verifies is no longer flagged as a failure.

    Returns:
        Report rows with the keys FileName, Size, Status and Action, sorted by point ID.
        The list is empty when FINAL_MERGED_OUTPUT_DIR does not exist.
    """
    print("\n" + "="*80)
    print("PHASE 3: FINAL VERIFICATION & REPORTING")
    print("="*80)

    report_data: List[Dict[str, str]] = []
    final_output_dir = pathlib.Path(FINAL_MERGED_OUTPUT_DIR)

    if not final_output_dir.exists():
        print("Sub-Process: Final output directory does not exist. Cannot run verification.")
        return report_data

    # Sorted for a stable report order. Duplicate IDs share one output file, so
    # completeness is measured against the unique IDs.
    all_point_ids = sorted(set(points_df[POINT_ID_FIELD].astype(str)))
    total_input_points = len(all_point_ids)
    found_files = {file_path.stem: file_path for file_path in final_output_dir.glob("*.csv")}

    expected_records = calculate_expected_length(nc_files)
    expected_lines = expected_records + 1 if expected_records > 0 else -1

    for point_id in all_point_ids:
        file_path = found_files.get(point_id)
        status, action, file_size_str = "FAILURE", "NOT FOUND", "0 Bytes"

        if file_path is not None and file_path.exists():
            try:
                size_bytes = file_path.stat().st_size
                file_size_str = format_bytes(size_bytes)

                if size_bytes == 0:
                    action = "ZERO SIZE"
                elif expected_lines > 0:
                    # The record count is the authoritative check whenever it is known.
                    with open(file_path, 'r', encoding='utf-8') as f:
                        actual_lines = sum(1 for _ in f)
                    if actual_lines == expected_lines:
                        status, action = "SUCCESS", "VERIFIED"
                    else:
                        status, action = "WARNING", f"INCOMPLETE ({actual_lines}/{expected_lines})"
                elif size_bytes > 1024:
                    # Count unknown: fall back to the same size heuristic as smart resume.
                    status, action = "SUCCESS", "MERGED"
                else:
                    action = "TOO SMALL"

            except (OSError, UnicodeDecodeError) as e:
                status = "ERROR"
                action = f"READ FAIL ({e.__class__.__name__})"

        report_data.append({
            "FileName": f"{point_id}.csv",
            "Size": file_size_str,
            "Status": status,
            "Action": action,
        })

    print("\nSub-Process: Internal Completeness Checks...")
    final_file_count = sum(1 for row in report_data if row['Status'] == "SUCCESS")

    if final_file_count == total_input_points:
        status_text = "✅ COMPLETE"
        message = f"All {total_input_points} input points have a final merged/verified file."
    else:
        status_text = "⚠️ PARTIAL"
        message = f"Found {final_file_count}/{total_input_points} final merged/verified files."

    print(f"  - Completeness Status: {status_text}")
    print(f"  - {message}")

    return report_data

# ================================================================================================
#   STEP 5: MAIN EXECUTION FLOW
# ================================================================================================

def main():
    """
    Run the whole workflow in order.

    The steps are discovery, extraction (Phase 1), merging (Phase 2), verification
    (Phase 3), the report, and optional cleanup.

    Any exception is reported together with its traceback instead of propagating.
    Cached NetCDF handles are always closed before returning. The intermediate
    yearly directory is deleted only when CLEANUP_YEARLY_DIR is set AND both
    Phase 1 and Phase 2 reported success, so a failed run keeps its partial data
    for the next resume.
    """
    import traceback

    overall_start_time = time.time()

    print("\n" + "#"*80)
    print("           UNIFIED NETCDF TIME-SERIES EXTRACTION AND MERGING WORKFLOW")
    print("#"*80)

    try:
        # --- Task 1: Preparation ---
        print("\n" + "--- [ Task 1: Initialization and File Discovery ] ---")
        points_to_process, nc_files = prepare_data_and_files()

        # --- Task 2: Run Extraction (Phase 1) ---
        extraction_ok = run_extraction_phase(points_to_process, nc_files)

        # --- Task 3: Run Merging (Phase 2) ---
        # Merge and verify ALL points, not only the ones extracted during this run.
        points_df_all = pd.read_csv(POINT_CSV_FILE_PATH)
        all_point_ids = points_df_all[POINT_ID_FIELD].astype(str).unique()
        merging_ok = run_merging_phase(all_point_ids, nc_files)

        # --- Task 4: Verification (Phase 3) ---
        final_report_data = run_verification_phase(points_df_all, nc_files)

        # --- Task 5: Print Final Report ---
        if final_report_data:
            print_verification_report(final_report_data)

        # --- Task 6: Cleanup ---
        if CLEANUP_YEARLY_DIR and pathlib.Path(YEARLY_DATA_DIR).exists():
            print("\n--- [ Task 6: Cleanup ] ---")
            if extraction_ok and merging_ok:
                shutil.rmtree(YEARLY_DATA_DIR)
            else:
                print("Cleanup skipped: Phase 1 or Phase 2 reported failures; keeping the intermediate yearly CSVs.")

    except Exception as e:
        print(f"\n\n{'='*80}\nFATAL ERROR: The workflow stopped unexpectedly.\nError Details: {e}\n{'='*80}")
        traceback.print_exc()
        return

    finally:
        # Release the NetCDF handles cached by open_nc_dataset_cached.
        for dataset in worker_datasets.values():
            try:
                dataset.close()
            except Exception:
                pass  # best effort: a failed close must not mask the run's outcome
        worker_datasets.clear()

    # --- FINAL NOTIFICATION ---
    overall_end_time = time.time()
    print(f"\n\n{'='*80}")
    print("✅ ALL TASKS COMPLETE")
    print(f"Total elapsed time: {overall_end_time - overall_start_time:.2f} seconds.")
    print(f"Results saved to: {FINAL_MERGED_OUTPUT_DIR}")
    print('='*80)

# Entry point. In a Jupyter/IPython kernel the cell namespace's __name__ is also
# "__main__", so executing this cell starts the workflow as well.
if __name__ == "__main__":
    main()