In [3]:
import xarray as xr
import numpy as np
from scipy.interpolate import griddata
import pandas as pd
import os
from datetime import datetime 
import dask
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator

In [13]:
def reproject_mosdac_to_regular_grid(input_file, output_file,
                                     lat_min=16.0, lat_max=21.0, lon_min=68.0, lon_max=74.0,
                                     lat_res=0.25, lon_res=0.25, chunk_size=10,
                                     target_time_start='2019-01-01', target_time_end='2024-12-01',
                                     target_time_freq='MS'):
    """
    Reproject MOSDAC SST data from irregular grid to regular lat-lon grid (EPSG:4326)
    using chunked processing to avoid memory issues.
   
    Parameters:
    -----------
    input_file : str
        Path to the input MOSDAC netCDF file
    output_file : str
        Path to the output regular grid netCDF file
    lat_min, lat_max, lon_min, lon_max : float
        Boundaries of the target grid
    lat_res, lon_res : float
        Resolution of the target grid
    chunk_size : int
        Number of time steps to process at once
    target_time_start, target_time_end : str
        Start and end dates for the target time axis (format: 'YYYY-MM-DD')
    target_time_freq : str
        Frequency for target time axis (e.g., 'MS' for month start, 'D' for daily)
    """
    print(f"Starting reprojection of {input_file} at {datetime.now()}")
   
    # Create target grid
    target_lats = np.arange(lat_min, lat_max + lat_res/2, lat_res)
    target_lons = np.arange(lon_min, lon_max + lon_res/2, lon_res)
   
    # Create target time coordinates matching other datasets
    target_times = pd.date_range(start=target_time_start, end=target_time_end, freq=target_time_freq)
    print(f"Created target time axis with {len(target_times)} time steps from {target_time_start} to {target_time_end}")
   
    # Define time_dim variable - this was missing in the original code
    time_dim = len(target_times)
   
    # Open the dataset with dask to enable chunking
    with xr.open_dataset(input_file, chunks={'time': chunk_size}) as ds:
        print(f"Dataset opened. Original has {len(ds.time)} time steps. Processing in chunks of {chunk_size}")
       
        # Create output dataset with target coordinates
        out_ds = xr.Dataset(
            coords={
                'time': target_times,
                'lat': target_lats,
                'lon': target_lons
            }
        )
       
        # Initialize the output data array
        out_sst = np.zeros((time_dim, len(target_lats), len(target_lons)), dtype=np.float32)
        out_sst.fill(np.nan)  # Fill with NaN initially
       
        # Create a mapping between target time indices and source time indices
        # This appears to be used but not defined in the original code
        time_mapping = {}
        processed_target_indices = set()
       
        # Map source times to target times based on closest match
        source_times = ds.time.values
        for target_idx, target_time in enumerate(target_times):
            # Convert target_time to numpy datetime64 for comparison if needed
            if not isinstance(target_time, np.datetime64):
                target_time = np.datetime64(target_time)
           
            # Handle different datetime types by converting to a compatible format
            # Convert both to datetime64[ns] for comparison
            source_times_ns = source_times.astype('datetime64[ns]')
            target_time_ns = target_time.astype('datetime64[ns]')
           
            # Calculate time differences using datetime64[ns] format
            time_diffs = np.abs(source_times_ns - target_time_ns)
            closest_idx = np.argmin(time_diffs)
            time_mapping[target_idx] = closest_idx
       
        # Process by chunks to avoid memory issues
        for t_start in range(0, len(ds.time), chunk_size):
            t_end = min(t_start + chunk_size, len(ds.time))
            print(f"Processing time steps {t_start} to {t_end-1}")
           
            # Get chunk of source data
            chunk = ds.isel(time=slice(t_start, t_end))
           
            # Find which target times are mapped to this chunk's source times
            target_indices_in_chunk = [
                target_idx for target_idx, source_idx in time_mapping.items()
                if t_start <= source_idx < t_end
            ]
           
            if not target_indices_in_chunk:
                print(f"  No target times map to this source chunk, skipping")
                continue
               
            print(f"  Processing {len(target_indices_in_chunk)} target time steps in this chunk")
           
            # Process each mapped time step
            for target_idx in target_indices_in_chunk:
                if target_idx in processed_target_indices:
                    continue  # Skip if already processed
                   
                source_idx = time_mapping[target_idx]
                source_rel_idx = source_idx - t_start  # Index relative to chunk
               
                print(f"  Processing target time index {target_idx} (using source time index {source_idx})")
               
                # Get coordinates and data for this time step
                time_slice = chunk.isel(time=source_rel_idx)
               
                # Check if lat and lon are 3D arrays
                if len(time_slice.lat.shape) == 3:
                    # Extract lat/lon as 2D arrays for this time step
                    src_lats = time_slice.lat.values
                    src_lons = time_slice.lon.values
                else:
                    # Use as is if they're already 2D
                    src_lats = time_slice.lat.values
                    src_lons = time_slice.lon.values
               
                # Extract SST data
                # Handle the case where SST has different dimensions
                if 'dim_0' in time_slice.sst.dims:
                    # For the specific structure with dim_0, dim_1, dim_2
                    sst_values = time_slice.sst.values[0]  # Assuming dim_0 is single-valued
                else:
                    # Regular case
                    sst_values = time_slice.sst.values
               
                # Flatten the arrays for interpolation
                valid_mask = ~np.isnan(sst_values)
                if not np.any(valid_mask):
                    print(f"  Warning: No valid data for time step {target_idx}")
                    continue
               
                # Only use valid data points for interpolation
                points = np.column_stack((src_lats[valid_mask].flatten(),
                                          src_lons[valid_mask].flatten()))
                values = sst_values[valid_mask].flatten()
               
                # Create target grid points
                target_lon_grid, target_lat_grid = np.meshgrid(target_lons, target_lats)
                target_points = np.column_stack((target_lat_grid.flatten(), target_lon_grid.flatten()))
               
                # Perform interpolation
                print(f"  Interpolating {len(values)} points to {len(target_points)} target points")
                interp_sst = griddata(points, values, target_points, method='linear')
               
                # Reshape to grid and store in target time slot
                out_sst[target_idx] = interp_sst.reshape(len(target_lats), len(target_lons))
               
                # Mark as processed
                processed_target_indices.add(target_idx)
           
            # Force garbage collection to free memory
            import gc
            gc.collect()
       
        # Add the reprojected data to the output dataset
        out_ds['sst'] = xr.DataArray(
            data=out_sst,
            dims=['time', 'lat', 'lon'],
            attrs={
                'long_name': 'Sea Surface Temperature',
                'units': ds.sst.attrs.get('units', 'K'),
                'grid_mapping': 'crs'
            }
        )
       
        # Add CRS information
        out_ds['crs'] = xr.DataArray(
            data=np.array(0),
            attrs={
                'grid_mapping_name': 'latitude_longitude',
                'longitude_of_prime_meridian': 0.0,
                'semi_major_axis': 6378137.0,
                'inverse_flattening': 298.257223563,
                'spatial_ref': 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["degree",0.0174532925199433]]',
                'proj4text': '+proj=longlat +datum=WGS84 +no_defs'
            }
        )
       
        # Set metadata for coordinates
        out_ds.lat.attrs = {
            'standard_name': 'latitude',
            'long_name': 'Latitude',
            'units': 'degrees_north',
            'axis': 'Y'
        }
       
        out_ds.lon.attrs = {
            'standard_name': 'longitude',
            'long_name': 'Longitude',
            'units': 'degrees_east',
            'axis': 'X'
        }
       
        # Add global attributes
        out_ds.attrs = {
            'title': 'Reprojected MOSDAC SST data',
            'source': os.path.basename(input_file),
            'history': f'Created on {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
            'description': 'SST data reprojected to regular lat-lon grid (EPSG:4326)',
            'Conventions': 'CF-1.8'
        }
       
        # Write to netCDF file
        print(f"Writing output to {output_file}")
        comp = dict(zlib=True, complevel=5)
        encoding = {var: comp for var in out_ds.data_vars}
        out_ds.to_netcdf(output_file, format='NETCDF4', encoding=encoding)
       
    print(f"Reprojection completed at {datetime.now()}")
    return output_file

if __name__ == "__main__":
    
    input_file = r'C:\Users\Admin\RIYA PROJECT\DATASETS\mosdac_mon_mean.nc'
    output_file = r"C:\Users\Admin\RIYA PROJECT\DATASETS\mosdac_reproj.nc"
   
    # Set parameters to match your other datasets
    lat_min = 16.0
    lat_max = 21.0
    lon_min = 68.0
    lon_max = 74.0
    lat_res = 0.25
    lon_res = 0.25
   
    # Set time parameters to match your other datasets
    # This ensures alignment with NOAA and ECMWF datasets
    target_time_start = '2019-01-01'
    target_time_end = '2024-12-01'
    target_time_freq = 'MS'  # Month Start frequency
   
    # Adjust chunk_size based on your available RAM
    # Smaller chunk_size = less memory but slower processing
    reproject_mosdac_to_regular_grid(
        input_file,
        output_file,
        lat_min=lat_min,
        lat_max=lat_max,
        lon_min=lon_min,
        lon_max=lon_max,
        lat_res=lat_res,
        lon_res=lon_res,
        chunk_size=5,  # Process 5 time steps at once
        target_time_start=target_time_start,
        target_time_end=target_time_end,
        target_time_freq=target_time_freq
    )

Starting reprojection of C:\Users\Admin\RIYA PROJECT\DATASETS\mosdac_mon_mean.nc at 2025-04-02 13:20:31.260904
Created target time axis with 72 time steps from 2019-01-01 to 2024-12-01
Dataset opened. Original has 12 time steps. Processing in chunks of 5
Processing time steps 0 to 4
  No target times map to this source chunk, skipping
Processing time steps 5 to 9
  No target times map to this source chunk, skipping
Processing time steps 10 to 11
  Processing 72 target time steps in this chunk
  Processing target time index 0 (using source time index 11)
  Interpolating 2699597 points to 525 target points
  Processing target time index 1 (using source time index 11)
  Interpolating 2699597 points to 525 target points
  Processing target time index 2 (using source time index 11)
  Interpolating 2699597 points to 525 target points
  Processing target time index 3 (using source time index 11)
  Interpolating 2699597 points to 525 target points
  Processing target time index 4 (using source

  Processing target time index 67 (using source time index 11)
  Interpolating 2699597 points to 525 target points
  Processing target time index 68 (using source time index 11)
  Interpolating 2699597 points to 525 target points
  Processing target time index 69 (using source time index 11)
  Interpolating 2699597 points to 525 target points
  Processing target time index 70 (using source time index 11)
  Interpolating 2699597 points to 525 target points
  Processing target time index 71 (using source time index 11)
  Interpolating 2699597 points to 525 target points
Writing output to C:\Users\Admin\RIYA PROJECT\DATASETS\mosdac_reproj2.nc
Reprojection completed at 2025-04-02 14:22:58.743807


In [14]:
crs2 = xr.open_dataset(r"C:\Users\Admin\RIYA PROJECT\DATASETS\mosdac_reproj.nc")
crs2

In [9]:
aa = xr.open_dataset(r'C:\Users\Admin\RIYA PROJECT\DATASETS\mosdac_mon_mean.nc')
aa['sst'].values

array([[[[nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         ...,
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan]]],


       [[[nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         ...,
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan]]],


       [[[nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         ...,
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan]]],


       ...,


       [[[nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [na