---
---
# [0] Setup

## [0.1] Imports

In [1]:
from pathlib import Path

import dask.array as da
import numpy as np
import rioxarray
import xarray as xr

import matplotlib.pyplot as plt
import plotly.express as px

from redplanet.DatasetManager.hash import _calculate_hash_from_file, get_available_algorithms


from contextlib import contextmanager
import time
@contextmanager
def timer():
    """Print the elapsed time spent inside the `with` block, formatted as '<seconds> seconds'.

    Uses `time.perf_counter()` (monotonic, high resolution) rather than `time.time()`, and
    reports the elapsed time even when the timed block raises.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        ## `finally` so a failing (possibly long) cell still tells you how long it ran.
        print(f'{time.perf_counter() - start:.3f} seconds')

---
## [0.2] Inputs

The datasets can be downloaded here:
- 463m (2 GB): https://astrogeology.usgs.gov/search/map/mars_mgs_mola_dem_463m
- 200m (10.6 GB): https://astrogeology.usgs.gov/search/map/mars_mgs_mola_mex_hrsc_blended_dem_global_200m

See README for more info.

In [2]:
## Per-resolution metadata for the two supported DEM products (keyed by '463m' / '200m'):
## - 'fpath': path of the source GeoTIFF. NOTE(review): absolute, machine-specific path -- edit for your machine.
## - 'tif_hash' / 'memmap_hash': expected xxh3_64 digests of the input TIF and of the memmap written in
##   section [2]; both are checked with `assert` later in the notebook.
## - 'target_crs_wkt' (optional, only 463m): if present, the TIF is reprojected to this CRS in section [1.1].
## - 'shape' / 'dtype': layout used to re-open the raw memmap in section [3] (rows = lat, cols = lon).
## - 'nan_value': sentinel written in place of the TIF's fill value in section [1.1].
## - 'lons_approx' / 'lats_approx': linear fits (intercept + slope * index) of the coordinate axes, derived
##   in section [1.3]; their lengths equal shape[1] / shape[0].
dem_info = {
    '463m': {
        'fpath': Path('/home/lain/root/100_work/110_projects/111_mars/raw_data/Mars_MGS_MOLA_DEM_mosaic_global_463m.tif'),
        'tif_hash': {
            'xxh3_64': 'a0dc027e687f855f',
        },
        'memmap_hash': {
            'xxh3_64': '6eed1a19495d736f',
        },
        'target_crs_wkt': 'GEOGCS["GCS_Mars_2000_Sphere",DATUM["Mars_2000_(Sphere)",SPHEROID["Mars_2000_Sphere_IAU_IAG",3396190,0]],PRIMEM["Reference_Meridian",0],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST]]',
        ## Note: the native coordinates in the 463m DEM TIF file are just `x` and `y` (units unclear), so we reproject using the CRS WKT taken straight from the 200m DEM TIF file (accessed via `dat_dem_xr.spatial_ref.attrs['crs_wkt']`).
        'shape': (23041, 46081),
        'dtype': np.int16,
        'nan_value': -99_999, ## data is stored as int16 which doesn't support `np.nan`, so we use this sentinel value.
        ## NOTE(review): -99_999 lies outside the int16 range [-32768, 32767], so it cannot be stored as-is in an
        ##   int16 array. What actually lands in the memmap depends on how numpy/xarray cast/promote it in the
        ##   `.where(...)` call of section [1.1] -- TODO confirm. The recorded 'memmap_hash' values were presumably
        ##   produced with the current behavior, so changing this value would change them.
        # 'na_value': -32768,
        ## Linear approximations of the lon/lat axes (intercept + slope * index), from section [1.3].
        'lons_approx': -179.9960938347692 + 0.007812330461578525 * np.arange(46081),
        'lats_approx': -89.99376946560506 + 0.00781206004494716  * np.arange(23041),
    },
    '200m': {
        'fpath': Path('/home/lain/root/100_work/110_projects/111_mars/raw_data/Mars_HRSC_MOLA_BlendDEM_Global_200mp_v2.tif'),
        'tif_hash': {
            'xxh3_64': 'dafb191af5826c66',
        },
        'memmap_hash': {
            'xxh3_64': 'e8cc649a36ea4fab',
        },
        'shape': (53347, 106694),
        'dtype': np.int16,
        'nan_value': -99_999, ## data is stored as int16 which doesn't support `np.nan`, so we use this sentinel value.
        ## NOTE(review): same int16-range caveat as the 463m entry above.
        # 'na_value': -32768, ## note: 39957 nan values, ~3.25 mins to compute
        ## Linear approximations of the lon/lat axes (intercept + slope * index), from section [1.3].
        'lons_approx': -179.9983129395848 + 0.0033741208306410017 * np.arange(106694),
        'lats_approx': -89.99753689179012 + 0.0033741208306410004 * np.arange(53347),
    },
}



## Select which DEM the rest of the notebook processes (must be a key of `dem_info`).
choice = '463m'
# choice = '200m'

assert choice in dem_info, f'Invalid choice: {choice}'

---
---
# [1] Extract info from TIF

---
## [1.1] TIF -> xarray

In [3]:
## Verify the source TIF matches the expected digest before doing any (slow) processing.
tif_hash_actual = _calculate_hash_from_file(dem_info[choice]['fpath'], 'xxh3_64')
tif_hash_expected = dem_info[choice]['tif_hash']['xxh3_64']
assert tif_hash_actual == tif_hash_expected, (
    f'TIF hash mismatch for {choice}: got {tif_hash_actual}, expected {tif_hash_expected}'
)

In [None]:
## Open the source GeoTIFF with rioxarray.
dat_dem_xr = rioxarray.open_rasterio(filename=dem_info[choice]['fpath'])

## Only DEMs with a 'target_crs_wkt' entry (the 463m one) are reprojected into the Mars 2000 geographic CRS.
target_crs_wkt = dem_info[choice].get('target_crs_wkt')
if target_crs_wkt:
    ## bottleneck (~1-3 mins)
    dat_dem_xr = dat_dem_xr.rio.reproject(target_crs_wkt)

## Keep band 1 as a 2D grid, rename dims to lon/lat, flip lat so it increases with index,
## name the array after the source file, and chunk it for dask.
dat_dem_xr = dat_dem_xr.sel(band=1).drop_vars(['band'])
dat_dem_xr = dat_dem_xr.rename({'x': 'lon', 'y': 'lat'}).isel(lat=slice(None, None, -1))
dat_dem_xr = dat_dem_xr.rename(dem_info[choice]['fpath'].stem).chunk({'lon': 'auto', 'lat': 'auto'})

## Replace the raster's own fill value with our sentinel (`nan_value`).
dat_dem_xr = dat_dem_xr.where(
    cond  = dat_dem_xr != dat_dem_xr._FillValue,
    other = dem_info[choice]['nan_value'],
)

dat_dem_xr

---
## [1.2] Inspect TIF dataset

At this point, `dat_dem_xr` looks like this:

- 463m:

![](https://files.catbox.moe/nb3cva.png)

- 200m:

![](https://files.catbox.moe/0hblcg.png)

Info about the CRS is stored in `dat_dem_xr.spatial_ref.attrs` as a dict, if you're curious. The CRS WKT for both is:

- `'GEOGCS["GCS_Mars_2000_Sphere",DATUM["Mars_2000_(Sphere)",SPHEROID["Mars_2000_Sphere_IAU_IAG",3396190,0]],PRIMEM["Reference_Meridian",0],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST]]'`

---
## [1.3] Longitude/latitude approximations

- The original TIF file contains two 1D arrays representing longitude/latitude values for each axis of the 2D grid. 
- However, there's no straightforward/clean way to store these in the primary numpy memmap file since the memmap data type is int16 (i.e., we'd need to store/distribute/load an extra array of coordinates for each DEM file).
- Therefore, the best approach is applying a linear fit and using that to generate the longitude/latitude series upon loading the memmap file.
    - => The max error with this method is on the order of $10^{-14}$.

In [None]:
"""calculate approximation from actual data"""

## Coordinate axes exactly as stored in the (lat-flipped) DataArray.
lons_actual = dat_dem_xr.lon.values
lats_actual = dat_dem_xr.lat.values


def _linear_fit_over_index(coords):
    """Least-squares line through `coords` vs. their integer index.

    Returns `(idx, fit)` where `idx = np.arange(len(coords))` and `fit` is a degree-1
    `Polynomial` whose domain is pinned to `[0, len(coords) - 1]`.
    """
    idx = np.arange(len(coords))
    fit = np.polynomial.polynomial.Polynomial.fit(
        x      = idx,
        y      = coords,
        deg    = 1,
        domain = [0, idx[-1]],
    )
    return idx, fit


idx_lons, fit_lons = _linear_fit_over_index(lons_actual)
idx_lats, fit_lats = _linear_fit_over_index(lats_actual)

## `.convert()` maps each fit back to the unscaled index domain, so `coef` is [intercept, slope]
## (the same `intercept + slope * index` form used for `dem_info[...]['lons_approx' / 'lats_approx']`).
fit_lons_coefs, fit_lats_coefs = (fit.convert().coef for fit in (fit_lons, fit_lats))

print(f'lon LSR fit: {fit_lons_coefs[0]} + {fit_lons_coefs[1]} * x')
print(f'lat LSR fit: {fit_lats_coefs[0]} + {fit_lats_coefs[1]} * x')
print()

## Evaluate the fits and report the worst-case deviation from the true coordinates.
lons_approx, lats_approx = fit_lons(idx_lons), fit_lats(idx_lats)

lons_error = np.abs(lons_actual - lons_approx)
lats_error = np.abs(lats_actual - lats_approx)

print(f'{lons_error.max() = }')
print(f'{lats_error.max() = }')

In [None]:
"""test/verify approximation from the `dem_info` dict"""

lons_approx = dem_info[choice]['lons_approx']
lats_approx = dem_info[choice]['lats_approx']

lons_error = np.abs(lons_actual - lons_approx)
lats_error = np.abs(lats_actual - lats_approx)

print(f'{lons_error.max() = }')
print(f'{lats_error.max() = }')

## Fail loudly (not just print) if the hard-coded fits in `dem_info` drift from the real axes.
## The tolerance is orders of magnitude below the grid spacing (~0.0034-0.0078 deg), so the
## approximate axes are effectively indistinguishable from the true ones for pixel lookups.
coord_error_tol = 1e-9
assert lons_error.max() < coord_error_tol, f'lons_approx deviates by up to {lons_error.max()} deg'
assert lats_error.max() < coord_error_tol, f'lats_approx deviates by up to {lats_error.max()} deg'

---
---
# [2] Save to numpy memmap

In [None]:
## extract dask array from the xr.DataArray
dask_array = dat_dem_xr.data

## metadata, must be known when loading!!
## Section [3] re-opens the raw memmap using `dem_info[choice]['dtype']` / `['shape']` (a raw memmap stores
## no metadata), so fail here -- before creating/truncating the output file -- if they disagree.
dtype = dask_array.dtype
shape = dask_array.shape
print(f'{dtype = }')
print(f'{shape = }')

expected_dtype = np.dtype(dem_info[choice]['dtype'])
expected_shape = tuple(dem_info[choice]['shape'])
assert dtype == expected_dtype, f'dtype mismatch: data is {dtype}, dem_info says {expected_dtype}'
assert shape == expected_shape, f'shape mismatch: data is {shape}, dem_info says {expected_shape}'

## set up file names/paths
dirpath_out = Path.cwd() / 'output'
dirpath_out.mkdir(parents=True, exist_ok=True)
fpath_memmap = dirpath_out / dem_info[choice]['fpath'].with_suffix('.memmap').name
memmap_array = np.memmap(
    fpath_memmap,
    mode  = 'w+',
    dtype = dtype,
    shape = shape,
)

with timer():
    ## save dask array to numpy memmap file
    da.store(dask_array, memmap_array)
    ## "Write any changes in the array to the file on disk." -- makes sure the data is on disk before it is hashed.
    memmap_array.flush()

In [None]:
## Report the memmap's name/size, then check its digest against the value recorded in `dem_info`.
print(f'{fpath_memmap.name}')
print(f'{fpath_memmap.stat().st_size / 1e9:.2f} GB')

## Uncomment to print digests for every available algorithm (how the "Actual hashes" list was produced).
# for alg in get_available_algorithms():
#     print(f'- {alg}: {_calculate_hash_from_file(fpath_memmap, alg)}')

alg = 'xxh3_64'
calculated_hash = _calculate_hash_from_file(fpath_memmap, alg)
known_hash = dem_info[choice]['memmap_hash'][alg]

print(f'- {alg}: {calculated_hash}')

assert calculated_hash == known_hash, (
    f'memmap hash mismatch for {choice} ({alg}): got {calculated_hash}, expected {known_hash}'
)

---

## Actual hashes:

Mars_HRSC_MOLA_BlendDEM_Global_200mp_v2.memmap (11.38 GB)
- xxh3_64: e8cc649a36ea4fab
- md5: 74cedd82aaf200b62ebb64affffe0e7e
- sha1: 0c7704155a3e9fb6bef284980fdb37aa559457c5
- sha256: 691b6ce6a1cacc5fcea4b95ef1832fac50e421e1ec8f7fb33e5c791396aa4a4f

...

Mars_MGS_MOLA_DEM_mosaic_global_463m.memmap (2.12 GB)
- xxh3_64: 6eed1a19495d736f
- md5: 0f2378e55a01c217b2662b7ba07a3f27
- sha1: f3547e5423bd447179e5126e37b262e4136adcac
- sha256: 7788fa9287c633456fbf2de8b0e674a7e375014d2b58731b45f991be284879c4

---
---
# [3] Demo/test: load from scratch

In [None]:
## Re-open the memmap written in section [2], read-only. A raw memmap file carries no metadata,
## so dtype and shape come from `dem_info`.
dirpath_out = Path.cwd().joinpath('output')
fpath_memmap = dirpath_out.joinpath(dem_info[choice]['fpath'].with_suffix('.memmap').name)

dat_dem_memmap = np.memmap(
    filename = fpath_memmap,
    dtype    = dem_info[choice]['dtype'],
    mode     = 'r',
    shape    = dem_info[choice]['shape'],
)



def get(
    lons: float | np.ndarray,
    lats: float | np.ndarray,
    return_exact_coords: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Nearest-neighbor lookup of DEM values on the grid spanned by `lons` x `lats`.

    Args:
        lons: Longitude(s) to sample, on the same axis as `dem_info[choice]['lons_approx']`.
        lats: Latitude(s) to sample, on the same axis as `dem_info[choice]['lats_approx']`.
        return_exact_coords: If True, also return the grid coordinates actually sampled.

    Returns:
        2D array of shape `(len(lats), len(lons))` (rows = lat, cols = lon) read from
        `dat_dem_memmap`. If `return_exact_coords` is True, the tuple
        `(dat, lons_exact, lats_exact)`, where `lons_exact` / `lats_exact` are the grid
        coordinates (from the linear approximations) nearest to each requested value.
    """
    lons = np.atleast_1d(lons)
    lats = np.atleast_1d(lats)

    grid_lons = dem_info[choice]['lons_approx']
    grid_lats = dem_info[choice]['lats_approx']

    ## Snap each requested coordinate to its nearest grid index (out-of-range values clamp to an edge).
    idx_lons = find_closest_indices(grid_lons, lons)
    idx_lats = find_closest_indices(grid_lats, lats)

    ## Outer-product indexing: every requested lat is paired with every requested lon.
    dat = dat_dem_memmap[np.ix_(idx_lats, idx_lons)]

    if return_exact_coords:
        return (dat, grid_lons[idx_lons], grid_lats[idx_lats])

    return dat



def find_closest_indices(sorted_array, target_values):
    """Index of the nearest element of `sorted_array` for each value in `target_values`.

    Args:
        sorted_array: 1D sequence sorted in ascending order (e.g. a lon or lat grid).
        target_values: Scalar or array of query values.

    Returns:
        Integer index array with the same shape as `target_values`. Ties resolve to the
        lower index; values outside the grid clamp to the first/last element.

    Raises:
        ValueError: If `sorted_array` is empty.
    """
    sorted_array = np.asarray(sorted_array)
    target_values = np.asarray(target_values)

    n = len(sorted_array)
    if n == 0:
        raise ValueError('sorted_array must contain at least one element')
    if n == 1:
        ## Only one candidate; the general path below would clip to [1, 0] and return index -1.
        return np.zeros(target_values.shape, dtype=np.intp)

    ## The nearest element sits just left or just right of each target's insertion point;
    ## clipping to [1, n - 1] keeps both neighbors in bounds (and clamps out-of-range targets).
    insertion_indices = np.clip(np.searchsorted(sorted_array, target_values), 1, n - 1)

    left_neighbors = sorted_array[insertion_indices - 1]
    right_neighbors = sorted_array[insertion_indices]

    closest_indices = np.where(
        np.abs(target_values - left_neighbors) <= np.abs(target_values - right_neighbors),
        insertion_indices - 1,
        insertion_indices,
    )

    return closest_indices





print(f'{choice = }')

## Smallest distinct step between consecutive grid coordinates along each axis.
lon_spacing, lat_spacing = (
    np.unique(np.diff(dem_info[choice][axis_key]))[0]
    for axis_key in ('lons_approx', 'lats_approx')
)
print(f'{lon_spacing = }')
print(f'{lat_spacing = }')

## Coarser of the two axis spacings.
grid_spacing = np.max([lon_spacing, lat_spacing])

In [None]:
## global
lons = np.arange(-180, 180, 1)
lats = np.arange(-90, 90, 1)

## Henry crater
# center = (23.5, 10.8)
# radius = 3
# resolution = 0.01
# lons = np.arange(center[0] - radius, center[0] + radius, resolution)
# lats = np.arange(center[1] - radius, center[1] + radius, resolution)

# ## Valles Marineris
# center = (-33, 0)
# radius = 25
# resolution = 0.1
# lons = np.arange(center[0] - radius, center[0] + radius, resolution)
# lats = np.arange(center[1] - radius, center[1] + radius, resolution)

# ## Hellas
# center = (70, -43)
# radius = 30
# resolution = 0.1
# lons = np.arange(center[0] - radius, center[0] + radius, resolution)
# lats = np.arange(center[1] - radius, center[1] + radius, resolution)

# ## Argyre
# center = (-43.30980, -49.84058)
# radius = 20
# resolution = 0.05
# lons = np.arange(center[0] - radius, center[0] + radius, resolution)
# lats = np.arange(center[1] - radius, center[1] + radius, resolution)

# ## drippy
# center = (-21.09126, 2.65345)
# radius = 7
# resolution = 0.005
# lons = np.arange(center[0] - radius, center[0] + radius, resolution)
# lats = np.arange(center[1] - radius, center[1] + radius, resolution)

# ## Gusev crater
# center = (175.52437, -14.53081)
# radius = 4
# resolution = 0.001
# lons = np.arange(center[0] - radius, center[0] + radius, resolution)
# lats = np.arange(center[1] - radius, center[1] + radius, resolution)


with timer():
    dat, lons_exact, lats_exact = get(lons, lats, return_exact_coords=True)

# px.imshow(
#     dat,
#     origin='lower',
#     color_continuous_scale='viridis',
#     x = lons,
#     y = lats,
# )


plt.figure(figsize=(5, 5))

plt.imshow(
    dat,
    origin = 'lower',
    aspect = 'equal',
    # cmap='viridis',
    extent = [ lons_exact[0], lons_exact[-1], lats_exact[0], lats_exact[-1] ],
)

---
---
# [Footnote] — Alternative reprojection method to `rio.reproject` with `rasterio` WarpedVRT

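Note: for reprojecting the 463m DEM, here is an alternative implementation that yields identical results and loads instantly (presumably at the cost of speed later, since the reprojection is computed on the fly whenever data is read). The code is rough (and probably missing some things), but it should be easy to adapt if the need arises. In the snippet, `fpath_dem_tif` stands for the path to the 463m TIF (i.e. `dem_info['463m']['fpath']`), and `rasterio.vrt` may need to be imported explicitly (`import rasterio.vrt`).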

&nbsp;

```python
import rasterio

target_crs_wkt = 'GEOGCS["GCS_Mars_2000_Sphere",DATUM["Mars_2000_(Sphere)",SPHEROID["Mars_2000_Sphere_IAU_IAG",3396190,0]],PRIMEM["Reference_Meridian",0],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST]]'

# vrt = rasterio.vrt.WarpedVRT(rasterio.open(custom_DEM_fpath), crs=target_crs_wkt)
dat_dem_xr_v1 = (
    rioxarray.open_rasterio(
        filename = rasterio.vrt.WarpedVRT(
            src_dataset = rasterio.open(fpath_dem_tif), 
            crs         = target_crs_wkt, 
        ), 
        chunks = {'x': 'auto', 'y': 'auto'}, 
    )
    .sel(band=1).drop_vars(['band', 'spatial_ref'])
    .rename({'x': 'lon', 'y': 'lat'})
    .sortby('lat', ascending=True)
)


dat_dem_xr_v1
```