In [77]:
import os 
import sys
import ismn
import pandas as pd
import numpy as np
import xarray as xr
from multiprocessing import Pool, cpu_count # 
from pathlib import Path
from datetime import datetime
import warnings
from ismn.interface import ISMN_Interface
import matplotlib.pyplot as plt

In [78]:
from utils import longest_available_after_removing_long_gaps,trim_to_surface_valid_period_and_keep_well_covered_depths,gapfill_by_monthday_mean_with_feb29_fallback

In [79]:
# Root folder of the local ISMN soil-moisture download.
insitu_dir = "/home/khanalp/data/ISMNsoilMoisture/"

In [80]:
# Read the ISMN archive with ISMN_Interface (reuses cached metadata if present).
# The path is built from `insitu_dir` so the data root is defined in one place only.
ismn_archive_dir = os.path.join(
    insitu_dir,
    "Data_separate_files_header_20140101_20251231_13107_18mx_20260208",
)
ds = ISMN_Interface(ismn_archive_dir, parallel=True)

Using the existing ismn metadata in /home/khanalp/data/ISMNsoilMoisture/Data_separate_files_header_20140101_20251231_13107_18mx_20260208/python_metadata/Data_separate_files_header_20140101_20251231_13107_18mx_20260208.csv to set up ISMN_Interface. 
If there are issues with the data reader, you can remove the metadata csv file to repeat metadata collection.


In [81]:
# Show the ISMN collection summary (networks and their stations) as a quick sanity check.
ds

ismn.base.IsmnRoot Unzipped at /home/khanalp/data/ISMNsoilMoisture/Data_separate_files_header_20140101_20251231_13107_18mx_20260208
with Networks[Stations]:
------------------------
  AMMA-CATCH: ['Banizoumbou', 'Belefoungou-Mid', 'Belefoungou-Top', 'Nalohou-Mid', 'Nalohou-Top', 'Tondikiboro', 'Wankama'],
  ARM: ['Anthony', 'Ashton', 'Byron', 'Lamont-CF1', 'Lamont-CF2', 'MapleCity', 'Marshall', 'Medford', 'Morrison', 'Newkirk', 'Okmulgee', 'Omega', 'Pawhuska', 'Pawnee', 'Ringwood', 'Tryon', 'Tyro', 'Waukomis'],
  BDF_Saxony: ['Hilbersdorf', 'Koellitsch', 'Lippen', 'Schmorren'],
  BFG_Nw: ['BFG-Niederwerth-1-weighable'],
  BIEBRZA_S-1: ['grassland-soil-1', 'grassland-soil-2', 'grassland-soil-3', 'grassland-soil-4', 'grassland-soil-5', 'grassland-soil-6', 'grassland-soil-7', 'grassland-soil-8', 'grassland-soil-9', 'marshland-soil-11', 'marshland-soil-12', 'marshland-soil-13', 'marshland-soil-14', 'marshland-soil-15', 'marshland-soil-16', 'marshland-soil-17', 'marshland-soil-18', 'marshla

In [82]:
# 1. Build a flat (network, station) inventory of every station in the archive.
# NOTE: explicit loops are kept on purpose: the loop variables `network` and
# `station` stay defined after this cell, and later exploratory cells read them.
networks_stations = []
for network in ds.collection.networks:
    stations_in_network = ds.collection[network].stations
    for station in stations_in_network:
        networks_stations.append(dict(network=network, station=station))


In [None]:
# Tabulate the inventory and report its size.
df_stations = pd.DataFrame(networks_stations)
n_stations_total = df_stations.shape[0]
print(f"Total stations: {n_stations_total}")
df_stations

In [None]:
# 2. Pick a single station for interactive testing (edit the two names below).
# To use the first station of the inventory instead:
#   test_network, test_station = df_stations.iloc[0][["network", "station"]]
test_network, test_station = "PBO_H2O", "MIDDLEGATE"

print(f"\n--- Testing with: {test_network} - {test_station} ---\n")
station_data = ds[test_network][test_station].to_xarray()
station_data  # inspect the raw per-sensor structure

In [None]:
# Folder (relative to the working directory) that receives the per-station NetCDF files.
output_dir = Path("processed_soil_moisture")
if not output_dir.is_dir():
    output_dir.mkdir()

In [None]:
# Read the raw data of the station chosen in the test-selection cell.
# The exploratory cells below refer to `network`/`station` (metadata attrs, file
# name). Until now those names were leftovers from the inventory loop, i.e. the
# *last* station of the archive. Point them at the test station so the data and
# its labels agree.
network, station = test_network, test_station
station_data = ds[network][station].to_xarray()

In [None]:
# Keep readings flagged "G" (good) or with any flag starting with "D" (dubious);
# every other reading becomes NaN.
flags = station_data["soil_moisture_flag"]
is_good = flags == "G"
is_dubious = flags.astype(str).str.startswith("D")
mask = is_good | is_dubious

soil_moisture_masked = station_data["soil_moisture"].where(mask)

# Tag each sensor with its depth_to so sensors can be grouped by depth.
# depth_to is used instead of depth_from because some stations report
# depth_from = 0 for every sensor, while depth_to still tells them apart.
sensor_depths = station_data["depth_to"].values
soil_moisture_masked = soil_moisture_masked.assign_coords(depth_group=("sensor", sensor_depths))

In [None]:
# Exploratory, cell-by-cell version of the aggregation done by `process_station`.

# Average all sensors that share the same depth_to (NaNs ignored).
depth_averaged = (
    soil_moisture_masked.groupby("depth_group")
    .mean(dim="sensor", skipna=True)
    .rename({"depth_group": "depth"})
)

# Drop depths without a single valid reading.
depths_with_data = depth_averaged.count(dim="date_time") > 0
depth_averaged = depth_averaged.where(depths_with_data, drop=True)
if len(depth_averaged.depth) == 0:
    # process_station returns None here; in a notebook cell, stop with a clear message
    # instead of an opaque failure in np.nanmax further down.
    raise ValueError(f"No depth with valid data for {network}/{station}")

# Daily mean per depth, plus the number of valid readings behind each daily mean.
daily = depth_averaged.resample(date_time="1D").mean(dim="date_time", skipna=True)
count = depth_averaged.resample(date_time="1D").count(dim="date_time")
# Treat missing counts as zero observations before the integer cast.
count = count.fillna(0).astype(int)

# Keep only days backed by enough valid readings.
MIN_OBS_PER_DAY = 6
daily_filtered = daily.where(count >= MIN_OBS_PER_DAY)

# ---- depth-bin averaging AFTER daily filtering (no count saved) ----
depth_vals = daily_filtered["depth"].values.astype(float)
# Heuristic: a maximum depth <= 3 is taken to mean the depths are in metres.
depth_cm = depth_vals * 100.0 if np.nanmax(depth_vals) <= 3 else depth_vals

depth_bin = pd.cut(
    depth_cm,
    bins=[0.0, 5.0, 20.0, 50.0, np.inf],
    labels=["0-5", "5-20", "20-50", ">50"],
    right=True,
    include_lowest=True,
)
daily_filtered = daily_filtered.assign_coords(depth_bin=("depth", depth_bin.astype(str)))

daily_binned = (
    daily_filtered.groupby("depth_bin")
    .mean(dim="depth", skipna=True)
    .rename({"depth_bin": "depth"})
)

result_ds = xr.Dataset({"soil_moisture": daily_binned})

# Metadata attributes. station_data exposes the coordinates as attrs "lat"/"lon".
result_ds.attrs["network"] = network
result_ds.attrs["station"] = station
result_ds.attrs["latitude"] = float(station_data.attrs.get("lat", np.nan))
result_ds.attrs["longitude"] = float(station_data.attrs.get("lon", np.nan))
# Largest raw depth_to with data, in the source units (not converted to cm).
result_ds.attrs["max_depth"] = float(np.nanmax(depth_vals))

In [None]:
# True if any (date, depth) soil-moisture value is still missing after the daily QC.
# (notnull().any() would instead answer "is there any valid value?", which
# contradicts the variable name.)
has_any_nan_values = result_ds["soil_moisture"].isnull().any().item()
has_any_nan_values

In [None]:
# Station metadata carried on the processed dataset.
station_attrs = result_ds.attrs
lat, lon = (station_attrs.get(key, np.nan) for key in ("latitude", "longitude"))
depths = result_ds["depth"].values.tolist()

# Remove gaps longer than 7 days; the helper returns a dict keyed by depth whose
# values are the longest available run after removing those gaps.
longest_available = longest_available_after_removing_long_gaps(result_ds, max_gap_days=7)

# Trim to the period where the surface (0-5) depth is valid and keep well-covered
# depths (min_frac=0.95; see the utils helper for the exact rule).
ds_clean = trim_to_surface_valid_period_and_keep_well_covered_depths(
    result_ds,
    longest_available,
    surface_depth="0-5",
    min_frac=0.95,
)

In [None]:
# Longest available run per depth after removing gaps longer than 7 days.
longest_available

In [None]:
# Output of trim_to_surface_valid_period_and_keep_well_covered_depths (see utils).
ds_clean

In [None]:
# Gap-fill the cleaned dataset and decide whether it can be exported.
# The file path is built first because the messages below reference it.
valid_dates = ds_clean["soil_moisture"].dropna(dim="date_time", how="all").date_time
if len(valid_dates) == 0:
    raise ValueError(f"No valid dates left for {network}/{station} after cleaning")

start_date = pd.to_datetime(valid_dates.min().values).strftime("%Y%m%d")
end_date = pd.to_datetime(valid_dates.max().values).strftime("%Y%m%d")

# Create filename
filename = f"{network}_{station}_{start_date}_{end_date}.nc"
filepath = Path(output_dir) / filename

ds_gap_filled = gapfill_by_monthday_mean_with_feb29_fallback(ds_clean)

# Export ONLY if there are zero NaNs left after gap-filling.
n_remaining_nan = int(ds_gap_filled["soil_moisture"].isnull().sum())
if n_remaining_nan == 0:
    # ds_gap_filled.to_netcdf(filepath)  # writing is disabled in this exploratory cell
    print(f"Ready to export (writing disabled): {filepath}")
else:
    print(f"Skipped (still has {n_remaining_nan} NaNs): {filepath}")

In [None]:


# Plot the daily, depth-binned soil moisture of the test station.
da = result_ds["soil_moisture"]
if tuple(da.dims) != ("date_time", "depth"):
    da = da.transpose("date_time", "depth")  # enforce a consistent (time, depth) order

t = pd.to_datetime(da["date_time"].values)
depths = [str(d) for d in da["depth"].values]

fig, ax = plt.subplots(figsize=(11, 5.5), dpi=150)
for depth_value, depth_label in zip(da["depth"].values, depths):
    ax.plot(t, da.sel(depth=depth_value).values, linewidth=1.0, label=depth_label)

ax.set_xlabel("Date")
ax.set_ylabel("Soil moisture")
ax.set_title("Soil moisture time series by depth")
ax.grid(True, linewidth=0.4, alpha=0.5)

# Two-column legend placed outside the axes so that many depths still fit.
ax.legend(
    title="Depth",
    ncol=2,
    fontsize=8,
    title_fontsize=9,
    frameon=False,
    loc="upper left",
    bbox_to_anchor=(1.02, 1.0),
)
fig.tight_layout()
plt.show()


In [None]:
# Longest available run per depth for result_ds (gaps > 7 days removed).
longest_avail = longest_available_after_removing_long_gaps(result_ds, max_gap_days=7)
longest_avail

In [None]:
# Trim result_ds to the period where the 0-5 depth is valid, using min_frac=0.95 for the other depths.
clean_ds = trim_to_surface_valid_period_and_keep_well_covered_depths(result_ds, longest_avail, surface_depth="0-5", min_frac=0.95)

In [None]:


# -----------------------
# Synthetic test data for the gap helpers: daily, 2 depths, 2019-2021
# -----------------------
t = pd.date_range("2019-01-01", "2021-12-31", freq="D")
depth = ["0-5", "5-20"]

# Seeded generator so the synthetic values (and anything derived from them) are
# reproducible on Restart & Run All; same U[0, 1) distribution as np.random.rand.
rng = np.random.default_rng(42)
vals = rng.random((len(depth), len(t))).astype("float32")

da = xr.DataArray(
    vals,
    dims=("depth", "date_time"),
    coords={"depth": depth, "date_time": t},
    name="soil_moisture",
)

# Inject NaN gaps as (depth, first day, last day); both ends are inclusive.
# depth 0-5:  2019: 5 and 3 days; 2020: 10 days; 2021: 20 days -> longest for this depth
# depth 5-20: 2019: 15 days -> longest for this depth; 2020: 2, 4 and 1 days
nan_gaps = [
    ("0-5", "2019-02-01", "2019-02-05"),
    ("0-5", "2019-07-10", "2019-07-12"),
    ("0-5", "2020-03-01", "2020-03-10"),
    ("0-5", "2021-10-01", "2021-10-20"),
    ("5-20", "2019-11-01", "2019-11-15"),
    ("5-20", "2020-01-05", "2020-01-06"),
    ("5-20", "2020-06-10", "2020-06-13"),
    ("5-20", "2020-12-31", "2020-12-31"),
]
for gap_depth, gap_start, gap_end in nan_gaps:
    da.loc[dict(depth=gap_depth, date_time=slice(gap_start, gap_end))] = np.nan

trial_ds = xr.Dataset({"soil_moisture": da})


In [None]:
# Longest-run helper on the synthetic dataset (expected gaps are listed in the test-data cell).
longest_avail = longest_available_after_removing_long_gaps(trial_ds, max_gap_days=7)

In [None]:
# Inspect the per-depth result for the synthetic dataset.
longest_avail

In [None]:
# Trimming helper on the synthetic dataset (surface_depth left at the helper's default).
clean_ds = trim_to_surface_valid_period_and_keep_well_covered_depths(trial_ds, longest_avail, min_frac=0.95)

In [None]:
# Inspect the trimmed synthetic dataset.
clean_ds

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# --- Trial dataset for the gap-filling check: daily 2019-2021 (includes leap year 2020) ---
t = pd.date_range("2019-01-01", "2021-12-31", freq="D")
depth = ["0-5", "5-20"]

rng = np.random.default_rng(0)
vals = rng.normal(loc=0.25, scale=0.05, size=(len(depth), len(t))).astype("float32")

da = xr.DataArray(
    vals,
    dims=("depth", "date_time"),
    coords={"depth": depth, "date_time": t},
    name="soil_moisture",
)

# NaN gaps as (depth, first day, last day), both ends inclusive. Both depths also
# miss 2020-02-29, the only Feb 29 in the period, so the Feb-29 fallback is exercised.
nan_gaps = [
    ("0-5", "2019-02-10", "2019-02-15"),
    ("0-5", "2020-02-10", "2020-02-12"),
    ("5-20", "2021-08-01", "2021-08-07"),
    ("0-5", "2020-02-29", "2020-02-29"),
    ("5-20", "2020-02-29", "2020-02-29"),
]
for gap_depth, gap_start, gap_end in nan_gaps:
    da.loc[dict(depth=gap_depth, date_time=slice(gap_start, gap_end))] = np.nan

trial_ds = xr.Dataset({"soil_moisture": da})

# --- gap-fill with month-day mean + Feb-29 fallback (Feb-28 then Mar-01) ---
def gapfill_by_monthday_mean_with_feb29_fallback(
    ds: xr.Dataset,
    var: str = "soil_moisture",
    time_dim: str = "date_time",
):
    """Fill NaNs in ``ds[var]`` with the multi-year mean of the same calendar day.

    The climatology is the mean over all years of each month-day ("MM-DD"),
    with NaNs skipped. Feb 29 occurs only in leap years, so wherever its
    climatology is NaN it falls back to the Feb-28 value, then to Mar-01. The
    fallback is applied element-wise (e.g. per depth), so one depth missing
    Feb 29 is filled even if other depths have it.

    Note: this notebook definition shadows the helper of the same name imported
    from ``utils`` for every cell executed after it.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset holding ``var`` with a datetime ``time_dim`` dimension.
    var : str
        Name of the variable to gap-fill.
    time_dim : str
        Name of the datetime dimension.

    Returns
    -------
    xr.Dataset
        Copy of ``ds`` with ``var`` gap-filled and an added int8
        ``gapfill_flag`` variable: 0 = original value, 1 = filled from the
        climatology, 2 = still NaN (no climatology value available).
    """
    da = ds[var]
    orig_nan = da.isnull()

    md = da[time_dim].dt.strftime("%m-%d")
    md_dim = md.name  # name of the grouping dimension of `clim`
    clim = da.groupby(md).mean(time_dim, skipna=True)
    clim_days = set(clim[md_dim].values.tolist())

    # Feb-29 fallback: fillna only touches NaN entries, so existing Feb-29 means are kept.
    # Neighbour days missing from the record are skipped instead of raising KeyError.
    if "02-29" in clim_days:
        feb29 = clim.sel({md_dim: "02-29"})
        for neighbour in ("02-28", "03-01"):
            if neighbour in clim_days:
                feb29 = feb29.fillna(clim.sel({md_dim: neighbour}))
        clim.loc[{md_dim: "02-29"}] = feb29

    # Look up each timestamp's month-day climatology and use it only where data is missing.
    fill_vals = clim.sel({md_dim: md})
    filled = da.where(~orig_nan, fill_vals)

    still_nan = filled.isnull()
    flag = xr.zeros_like(da, dtype=np.int8)
    flag = flag.where(~orig_nan, other=1)   # originally missing -> filled
    flag = flag.where(~still_nan, other=2)  # no climatology either -> still NaN

    out = ds.copy()
    out[var] = filled
    out["gapfill_flag"] = flag
    return out

# Gap-fill the trial dataset (adds a `gapfill_flag` variable next to soil_moisture).
filled_ds = gapfill_by_monthday_mean_with_feb29_fallback(trial_ds)



In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd

# Overlay each year of one depth on a common Jan-Dec axis; "x" marks gap-filled days.
depth_sel = "0-5"

fill = filled_ds["soil_moisture"].sel(depth=depth_sel).to_series()
flag = filled_ds["gapfill_flag"].sel(depth=depth_sel).to_series()  # 1 = filled

fig, ax = plt.subplots(figsize=(11, 5.5), dpi=200)

for year in sorted(fill.index.year.unique()):
    year_fill = fill[fill.index.year == year]
    year_flag = flag[flag.index.year == year]

    # Map every day onto year 2000, a leap year, so 02-29 is a valid date.
    x = pd.to_datetime("2000-" + year_fill.index.strftime("%m-%d"))
    ax.plot(x, year_fill.values, linewidth=1.0, label=str(year))

    filled_days = year_flag.index[year_flag.values == 1]
    if len(filled_days) > 0:
        x_filled = pd.to_datetime("2000-" + filled_days.strftime("%m-%d"))
        ax.scatter(x_filled, year_fill.loc[filled_days].values, s=18, marker="x")

ax.xaxis.set_major_locator(mdates.MonthLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter("%b"))
ax.set_xlim(pd.Timestamp("2000-01-01"), pd.Timestamp("2000-12-31"))

ax.set_xlabel("Month")
ax.set_ylabel("Soil moisture")
ax.set_title(f"{depth_sel}: yearly series on Jan–Dec axis (x = gap-filled)")
ax.grid(True, linewidth=0.4, alpha=0.5)
ax.legend(title="Year", ncol=3, frameon=False, fontsize=8, title_fontsize=9)
fig.tight_layout()
plt.show()

To learn about the ISMN data quality flags, see https://ismn.earth/en/data/flag-overview/.
In short: G = good, D = dubious, C = outside the plausible range.

Inside `process_station`:
1. Create a mask of usable measurements: `soil_moisture_flag == "G"` (good) or any flag starting with "D" (dubious).
2. Apply the mask to `soil_moisture`, so all other values become NaN.
3. Add a per-sensor coordinate `depth_group` taken from `depth_to`, so sensors can be grouped by depth. `depth_from` is not used because some stations report `depth_from = 0` for every sensor.
4. Group by `depth_group` and average across sensors (depth average), ignoring NaNs.
5. Rename `depth_group` to `depth`.
6. Count the valid values of each depth over all times and drop depths that are entirely NaN.
7. If no depths remain, return `None`.
8. Resample to daily mean soil moisture (per depth).
9. Resample to a daily count of valid observations (per depth).
10. Fill missing daily counts with 0 and convert the counts to integers.
11. Mask daily soil moisture where the daily count is < 6 (those days become NaN).
12. Average the filtered daily series into the depth bins 0-5, 5-20, 20-50 and >50 cm.
13. Return an `xarray.Dataset` holding only `soil_moisture` (daily, filtered, depth-binned). The daily count is not saved.
14. Add metadata attributes: network, station, latitude, longitude and max_depth.

In [None]:
# Function to process each station
def process_station(station_data, network, station):
    """
    Process a single station dataset:
    - Depth average soil_moisture where flag = 'G'
    - Drop depths with no valid data
    - Convert to daily if >= 6 valid observations per day
    """
    # Create mask where soil_moisture_flag == "G" and also include "D" (Dubious) flags as valid data.
    # mask = station_data['soil_moisture_flag'] == "G"
    mask = (station_data["soil_moisture_flag"] == "G") | station_data["soil_moisture_flag"].astype(str).str.startswith("D")
    
    # Apply mask to soil_moisture
    soil_moisture_masked = station_data['soil_moisture'].where(mask)
    
    # Assign depth_from as a coordinate for grouping
    soil_moisture_masked = soil_moisture_masked.assign_coords(
        depth_group=('sensor', station_data['depth_to'].values) # Use depth_to instead of depth_from because some stations have depth_from = 0 for all sensors, but depth_to varies and can be used to group sensors by depth.
    )
    
    # Group by depth and average across sensors
    depth_averaged = soil_moisture_masked.groupby('depth_group').mean(dim='sensor', skipna=True)
    depth_averaged = depth_averaged.rename({'depth_group': 'depth'})
    
    # Drop depths that have ALL NaN values (no valid data)
    valid_count_per_depth = depth_averaged.count(dim='date_time')
    depths_with_data = valid_count_per_depth > 0
    depth_averaged = depth_averaged.where(depths_with_data, drop=True)
    
    # If no valid depths remain, return None
    if len(depth_averaged.depth) == 0:
        return None
    
    # Resample to daily
    daily = depth_averaged.resample(date_time='1D').mean(dim='date_time', skipna=True)
    
    # Count valid observations per day
    count = depth_averaged.resample(date_time='1D').count(dim='date_time')
    
    # Handle the casting more gracefully
    try:
        count = count.fillna(0).astype(int)
    except:
        count = count.astype(float)
    
    # Mask out days with < 6 valid observations
    daily_filtered = daily.where(count >= 6)
    
        # ---- depth-bin averaging AFTER daily filtering (no count saved) ----
    depth_vals = daily_filtered["depth"].values.astype(float)
    depth_cm = depth_vals * 100.0 if np.nanmax(depth_vals) <= 3 else depth_vals

    depth_bin = pd.cut(
        depth_cm,
        bins=[0.0, 5.0, 20.0, 50.0, np.inf],
        labels=["0-5", "5-20", "20-50", ">50"],
        right=True,
        include_lowest=True
    )

    daily_filtered = daily_filtered.assign_coords(depth_bin=("depth", depth_bin.astype(str)))

    daily_binned = (
        daily_filtered.groupby("depth_bin")
        .mean(dim="depth", skipna=True)
        .rename({"depth_bin": "depth"})
    )

    result_ds = xr.Dataset({"soil_moisture": daily_binned})
    
    # Add metadata as attributes
    result_ds.attrs['network'] = network
    result_ds.attrs['station'] = station
    result_ds.attrs['latitude'] = float(station_data.attrs.get('lat', np.nan)) #The attrs lat and variables latitude are latitude.
    result_ds.attrs['longitude'] = float(station_data.attrs.get('lon', np.nan)) # same for lon. 
    result_ds.attrs["max_depth"] = float(np.nanmax(depth_vals))
    return result_ds


## `process_single_station` is a thin wrapper used for parallel processing.
What it does:
1. Reads the station data into xarray.
2. Calls `process_station`, which applies the quality-flag mask, aggregates the sub-daily readings to daily values and returns `result_ds`.
3. If `result_ds` is `None`, the station is skipped. Otherwise:
4. It collects metadata such as network, station, longitude, latitude, depths (or maximum depth), start_date and end_date.
5. It saves the dataset as `f"{network}_{station}_{start_date}_{end_date}.nc"` in `output_dir`.
6. It returns the metadata.


In [None]:
def process_single_station(args):
    """Process one station end-to-end and save it as NetCDF (wrapper for Pool.map).

    Reads the module-level ``ds`` (ISMN interface) and ``output_dir``.
    NOTE: a later cell redefines ``process_single_station``; once that cell
    has run, this version is shadowed.

    Parameters
    ----------
    args : tuple
        ``(network, station, idx, total)``; ``idx``/``total`` are only used in
        progress messages.

    Returns
    -------
    dict or None
        Metadata (network, station, latitude, longitude, depths, n_depths,
        start_date, end_date, n_days, filename). Returns ``None`` when the
        station has no valid data or processing raised; errors are printed,
        not propagated.
    """
    network, station, idx, total = args

    try:
        print(f"[{idx+1}/{total}] Processing: {network}/{station}")

        # Read station data
        station_data = ds[network][station].to_xarray()

        # QC, daily aggregation and depth binning
        result_ds = process_station(station_data, network, station)

        # Skip if no valid data
        if result_ds is None:
            print(f"[{idx+1}/{total}] Skipped (no valid data): {network}/{station}")
            return None

        # Get metadata
        lat = result_ds.attrs.get('latitude', np.nan)
        lon = result_ds.attrs.get('longitude', np.nan)
        depths = result_ds.depth.values.tolist()

        # Date range covered by at least one depth
        valid_dates = result_ds['soil_moisture'].dropna(dim='date_time', how='all').date_time
        if len(valid_dates) == 0:
            print(f"[{idx+1}/{total}] Skipped (no valid dates): {network}/{station}")
            return None

        start_date = pd.to_datetime(valid_dates.min().values).strftime('%Y%m%d')
        end_date = pd.to_datetime(valid_dates.max().values).strftime('%Y%m%d')

        # Path() guards against `output_dir` having been rebound to a plain str
        # by a later cell, where `str / str` would raise TypeError.
        filename = f"{network}_{station}_{start_date}_{end_date}.nc"
        filepath = Path(output_dir) / filename

        # Save to netCDF
        result_ds.to_netcdf(filepath)

        print(f"[{idx+1}/{total}] Success: {network}/{station} -> {filename}")

        return {
            'network': network,
            'station': station,
            'latitude': lat,
            'longitude': lon,
            'depths': str(depths),  # Convert list to string for CSV
            'n_depths': len(depths),
            'start_date': start_date,
            'end_date': end_date,
            'n_days': len(valid_dates),
            'filename': filename
        }

    except Exception as e:
        # Best-effort batch processing: report and move on to the next station.
        print(f"[{idx+1}/{total}] Error processing {network}/{station}: {e}")
        return None


In [None]:
def process_single_station(args):
    """Process one station end-to-end and save it as NetCDF (wrapper for Pool.map).

    Reads the module-level ``ds`` (ISMN interface) and ``output_dir``. This
    definition replaces the earlier one of the same name.

    Parameters
    ----------
    args : tuple
        ``(network, station, idx, total)``; ``idx``/``total`` are only used in
        progress messages.

    Returns
    -------
    dict or None
        Metadata (network, station, latitude, longitude, max_depth (cm),
        start_date, end_date, n_days, filename). Returns ``None`` when the
        station has no valid data or processing raised; errors are printed,
        not propagated.
    """
    network, station, idx, total = args

    try:
        print(f"[{idx+1}/{total}] Processing: {network}/{station}")

        # Read station data
        station_data = ds[network][station].to_xarray()

        # QC, daily aggregation and depth binning
        result_ds = process_station(station_data, network, station)

        # Skip if no valid data
        if result_ds is None:
            print(f"[{idx+1}/{total}] Skipped (no valid data): {network}/{station}")
            return None

        # Get metadata
        lat = result_ds.attrs.get('latitude', np.nan)
        lon = result_ds.attrs.get('longitude', np.nan)

        # process_station stores max_depth in the raw units and treats depths
        # <= 3 as metres. Mirror that heuristic so depths already in cm are not
        # multiplied by 100 a second time. NaN stays NaN.
        raw_max_depth = result_ds.attrs.get("max_depth", np.nan)
        max_depth_cm = raw_max_depth * 100.0 if raw_max_depth <= 3 else raw_max_depth

        # Date range covered by at least one depth
        valid_dates = result_ds['soil_moisture'].dropna(dim='date_time', how='all').date_time
        if len(valid_dates) == 0:
            print(f"[{idx+1}/{total}] Skipped (no valid dates): {network}/{station}")
            return None

        start_date = pd.to_datetime(valid_dates.min().values).strftime('%Y%m%d')
        end_date = pd.to_datetime(valid_dates.max().values).strftime('%Y%m%d')

        # Path() guards against `output_dir` having been rebound to a plain str
        # by a later cell, where `str / str` would raise TypeError.
        filename = f"{network}_{station}_{start_date}_{end_date}.nc"
        filepath = Path(output_dir) / filename

        # Save to netCDF
        result_ds.to_netcdf(filepath)

        print(f"[{idx+1}/{total}] Success: {network}/{station} -> {filename}")

        return {
            'network': network,
            'station': station,
            'latitude': lat,
            'longitude': lon,
            'max_depth (cm)': max_depth_cm,
            'start_date': start_date,
            'end_date': end_date,
            'n_days': len(valid_dates),
            'filename': filename
        }

    except Exception as e:
        # Best-effort batch processing: report and move on to the next station.
        print(f"[{idx+1}/{total}] Error processing {network}/{station}: {e}")
        return None


In [None]:
# # Inspect the data structure
# print("Data structure:")
# print(ds_test)
# print("\n" + "="*60 + "\n")

# # Check dimensions
# print("Dimensions:")
# print(ds_test.dims)
# print("\n" + "="*60 + "\n")

# # Check coordinates
# print("Coordinates:")
# print(ds_test.coords)
# print("\n" + "="*60 + "\n")


In [None]:
# Collect all station tasks
all_tasks = []
for network in ds.collection.networks:
    for station in ds.collection[network].stations:
        all_tasks.append((network, station))

In [None]:

# ============================================
# CONFIGURE TEST RUN HERE
# ============================================
TEST_MODE = True  # Set to False to process all stations
N_TEST_STATIONS = 1  # Number of stations to test with

tasks = all_tasks[:N_TEST_STATIONS] if TEST_MODE else all_tasks
if TEST_MODE:
    # Report the real count: all_tasks may hold fewer than N_TEST_STATIONS entries.
    print(f"=== TEST MODE: Processing {len(tasks)} stations ===")
else:
    print(f"=== FULL MODE: Processing all {len(tasks)} stations ===")

# Each task carries its 0-based index and the total so workers can print progress.
tasks_with_idx = [(net, sta, i, len(tasks)) for i, (net, sta) in enumerate(tasks)]

print(f"Output directory: {output_dir}/")
print(f"Using {cpu_count()} CPU cores available")

In [None]:

# ============================================
# CONFIGURE PARALLELIZATION HERE
# ============================================
USE_PARALLEL = False  # True = multiprocessing Pool; False = sequential (easier debugging)
N_WORKERS = 10  # Number of parallel workers (adjust as needed)

if not USE_PARALLEL:
    print("Running sequentially (no parallelization)")
    results = list(map(process_single_station, tasks_with_idx))
else:
    # NOTE(review): Pool with notebook-defined functions presumably relies on the
    # 'fork' start method (Linux default) — confirm on other platforms.
    print(f"Running in parallel with {N_WORKERS} workers")
    with Pool(N_WORKERS) as pool:
        results = pool.map(process_single_station, tasks_with_idx)

In [None]:

# Keep only stations that produced a file (process_single_station returns None otherwise).
metadata_list = [r for r in results if r is not None]

if metadata_list:
    df_metadata = pd.DataFrame(metadata_list)

    print(f"\n{'='*60}")
    print("Processing complete!")
    print(f"Successfully processed: {len(df_metadata)}/{len(tasks)} stations")
    print(f"Files saved to: {output_dir}/")
    print("\nMetadata summary:")
    print(df_metadata.head(20))

    # Save metadata (Path() in case output_dir was rebound to a plain str).
    metadata_file = Path(output_dir) / 'station_metadata.csv'
    df_metadata.to_csv(metadata_file, index=False)
    print(f"\nMetadata saved to: {metadata_file}")

    # Summary statistics. The metadata columns differ between the two versions of
    # process_single_station ("n_depths" vs "max_depth (cm)"), so summarise
    # whichever column is present instead of failing with a KeyError.
    print("\nSummary:")
    print(f"  Networks: {df_metadata['network'].nunique()}")
    print(f"  Stations: {len(df_metadata)}")
    print(f"  Date range: {df_metadata['start_date'].min()} to {df_metadata['end_date'].max()}")
    if "n_depths" in df_metadata.columns:
        print(f"  Depth range: {df_metadata['n_depths'].min()}-{df_metadata['n_depths'].max()} depths per station")
    if "max_depth (cm)" in df_metadata.columns:
        max_depths = df_metadata["max_depth (cm)"]
        print(f"  Max depth range: {max_depths.min()}-{max_depths.max()} cm")

else:
    print("\nNo valid data found across all stations")

In [None]:
# Reload the station metadata CSV from disk.
# NOTE(review): absolute path; presumably the same file as output_dir / "station_metadata.csv"
# when the notebook runs from /home/khanalp/code/PhD/soilMoisture — confirm.
df_metadata = pd.read_csv('/home/khanalp/code/PhD/soilMoisture/processed_soil_moisture/station_metadata.csv')

In [None]:
# Display the reloaded metadata table.
df_metadata

In [70]:
# Root folder of the processed station files and a sub-folder for their plots.
# Keep output_dir a Path (as defined earlier): rebinding it to a str breaks the
# `output_dir / filename` expressions in the processing cells.
output_dir = Path("/home/khanalp/code/PhD/soilMoisture/processed_soil_moisture/")
plot_dir = output_dir / "plots_for_gap_filled"
plot_dir.mkdir(parents=True, exist_ok=True)

In [71]:
# All processed NetCDF files, sorted so task indices (and progress output) are
# deterministic — glob order is filesystem-dependent.
nc_files = sorted(Path(output_dir).glob("*.nc"))
len(nc_files)

58

In [85]:
# ---------------- Plot style (your template) ----------------
import matplotlib.pyplot as plt
import matplotlib as mpl
import scienceplots  # registers 'science', 'no-latex', etc.


# Publication style: SciencePlots base style (without LaTeX) plus these overrides.
PLOT_RC = {
    # Fonts
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
    "font.size": 14,
    "axes.titlesize": 16,
    "axes.titleweight": "bold",
    "axes.labelsize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "lines.linewidth": 2,
    "legend.fontsize": 12,
    "figure.dpi": 300,
    # Ticks
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.major.size": 6,
    "ytick.major.size": 6,
    "xtick.minor.size": 3,
    "ytick.minor.size": 3,
    # Spines
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Embed TrueType fonts in PDF/PS output
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
}

plt.style.use(["science", "no-latex"])
mpl.rcParams.update(PLOT_RC)


In [86]:
# Task tuples (nc_path, plot_dir, index, total), following the same pattern as the station-processing tasks.
nc_files_with_idx = [(str(f), str(plot_dir), i, len(nc_files)) for i, f in enumerate(nc_files)]

In [87]:
# Parallelisation settings for the plotting run below.
USE_PARALLEL = True   # False = sequential (easier debugging)
N_WORKERS = 50        # adjust as needed

In [88]:
def plot_single_file(task):
    """Plot soil moisture per depth from one processed NetCDF file and save it as PNG.

    Parameters
    ----------
    task : tuple
        ``(nc_path, plot_dir, index, total)``; ``index``/``total`` are only used
        for progress messages.

    Returns
    -------
    str or None
        Path of the saved PNG, or ``None`` if the file has no ``soil_moisture``
        variable. If anything raised, returns ``"FAILED: <file> -> <error>"``.
    """
    f, plot_dir, i, total = task
    fig = None
    try:
        # The context manager closes the NetCDF file even when plotting fails.
        with xr.open_dataset(f) as ds:
            if "soil_moisture" not in ds.data_vars:
                return None

            station = ds.attrs.get("station", Path(f).stem)
            network = ds.attrs.get("network", "")
            title = f"Soil Moisture - {station}" + (f" ({network})" if network else "")

            fig, ax = plt.subplots(figsize=(14, 6))

            if ("depth" in ds.dims) or ("depth" in ds.coords):
                for depth in ds["depth"].values:
                    ds["soil_moisture"].sel(depth=depth).plot(ax=ax, label=str(depth))
                ax.legend(title="Depth")
            else:
                ds["soil_moisture"].plot(ax=ax, label="soil_moisture")
                ax.legend()

            ax.set_ylabel("Soil Moisture")
            ax.set_xlabel("Date")
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
            fig.tight_layout()

            out_png = Path(plot_dir) / f"{Path(f).stem}_soil_moisture_by_depth.png"
            fig.savefig(out_png, dpi=300, bbox_inches="tight")

        if (i + 1) % 50 == 0:
            print(f"[{i+1}/{total}] done")

        return str(out_png)

    except Exception as e:
        return f"FAILED: {f} -> {e}"
    finally:
        # Always release the figure, also on failure, so long batch runs do not leak memory.
        if fig is not None:
            plt.close(fig)

In [89]:
# Render one PNG per NetCDF file, in parallel or sequentially.
if not USE_PARALLEL:
    print("Running sequentially (no parallelization)")
    results = list(map(plot_single_file, nc_files_with_idx))
else:
    print(f"Running in parallel with {N_WORKERS} workers")
    with Pool(N_WORKERS) as pool:
        results = pool.map(plot_single_file, nc_files_with_idx)

# Optional: quick summary (failures come back as "FAILED: ..." strings, skips as None)
n_ok = len([r for r in results if r is not None and not str(r).startswith("FAILED:")])
n_fail = len([r for r in results if isinstance(r, str) and r.startswith("FAILED:")])
print(f"Saved {n_ok} plots, failed {n_fail}. Output: {plot_dir}")

Running in parallel with 50 workers
[50/58] done
Saved 58 plots, failed 0. Output: /home/khanalp/code/PhD/soilMoisture/processed_soil_moisture/plots_for_gap_filled


In [83]:
# Reload and display the station metadata, including status and remark for every station.
df_metadata = pd.read_csv("/home/khanalp/code/PhD/soilMoisture/processed_soil_moisture/station_metadata.csv")
df_metadata

Unnamed: 0,network,station,status,remark,latitude,longitude,start_date,end_date,n_days,max_depth_cm,filename,filepath
0,GROW,pwvxhaqg,skipped,<1 year of valid daily data,40.88741,25.85531,20180818.0,20190520.0,276.0,,,
1,COSMOS,GLEES,saved,,41.36440,-106.23940,20151014.0,20170912.0,683.0,10.0,COSMOS_GLEES_20151014_20170912.nc,/home/khanalp/code/PhD/soilMoisture/processed_...
2,SCAN,MonoclineRidge,saved,,36.54417,-120.55463,20141030.0,20211101.0,2560.0,101.6,SCAN_MonoclineRidge_20141030_20211101.nc,/home/khanalp/code/PhD/soilMoisture/processed_...
3,Ru_CFR,Fyodorovskoyedrysprucestand,skipped,<1 year of valid daily data,56.44760,32.90188,,,0.0,,,
4,REMEDHUS,LasEritas,saved,,41.20548,-5.41558,20141111.0,20241231.0,3704.0,5.0,REMEDHUS_LasEritas_20141111_20241231.nc,/home/khanalp/code/PhD/soilMoisture/processed_...
...,...,...,...,...,...,...,...,...,...,...,...,...
95,SNOTEL,TogwoteePass,skipped,<1 year of valid daily data,43.74902,-110.05780,20141119.0,20150825.0,280.0,,,
96,iRON,SkyMountain,saved,,39.22098,-106.91313,20140101.0,20220607.0,3080.0,50.0,iRON_SkyMountain_20140101_20220607.nc,/home/khanalp/code/PhD/soilMoisture/processed_...
97,SNOTEL,PebbleCreek,saved,,42.76740,-112.10648,20211013.0,20251230.0,1540.0,50.8,SNOTEL_PebbleCreek_20211013_20251230.nc,/home/khanalp/code/PhD/soilMoisture/processed_...
98,SNOTEL,HartsPass,saved,,48.72047,-120.65860,20230916.0,20250420.0,583.0,50.8,SNOTEL_HartsPass_20230916_20250420.nc,/home/khanalp/code/PhD/soilMoisture/processed_...


In [None]:
# Open one processed station file for inspection.
ds_trial = xr.open_dataset('/home/khanalp/code/PhD/soilMoisture/processed_soil_moisture/WSMN_WSMN-8_20140613_20151020.nc')

In [None]:
# Inspect the opened station file.
ds_trial

In [67]:
# Reload the station metadata for the status queries below.
df = pd.read_csv('/home/khanalp/code/PhD/soilMoisture/processed_soil_moisture/station_metadata.csv')

In [84]:
# Stations that did not produce an output file (status "skipped" or "error").
df[df["status"].isin(["skipped", "error"])]

Unnamed: 0,network,station,status,remark,latitude,longitude,start_date,end_date,n_days,max_depth_cm,filename,filepath
0,GROW,pwvxhaqg,skipped,<1 year of valid daily data,40.88741,25.85531,20180818.0,20190520.0,276.0,,,
3,Ru_CFR,Fyodorovskoyedrysprucestand,skipped,<1 year of valid daily data,56.4476,32.90188,,,0.0,,,
6,GROW,br4tnkrw,skipped,<1 year of valid daily data,40.891,25.85091,20180515.0,20180909.0,118.0,,,
10,COSMOS,NebField3,skipped,no surface run for 0-10 after removing long gaps,41.1649,-96.4701,,,,,,
11,COSMOS,Mapungubwe,skipped,no surface run for 0-10 after removing long gaps,-22.1917,29.3926,,,,,,
12,GROW,4w257zzr,skipped,<1 year of valid daily data,47.87487,20.58087,20170905.0,20180619.0,288.0,,,
13,PTSMN,Site-4,skipped,no surface run for 0-10 after removing long gaps,-40.73939,175.8639,,,,,,
14,ROMPS,Taragay,skipped,NaNs remain after gap-filling,41.729,77.80481,20220525.0,20230601.0,365.0,,,
16,COSMOS,Hermosillo,skipped,no surface run for 0-10 after removing long gaps,29.7369,-110.5054,,,,,,
18,PBO_H2O,MIDDLEGATE,skipped,no surface run for 0-10 after removing long gaps,39.3056,-117.9848,,,,,,


In [66]:
# Number of stations whose processed file was saved.
len(df[(df["status"] == "saved")])

58