In [None]:
import numpy as np
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor, as_completed
from joblib import Parallel, delayed
import xarray as xr
import time
import os
import pygmt

import AbFab as af

%load_ext autoreload
%autoreload 2


In [None]:
# Input grids. The age-grid path is defined once and reused as the `region`
# argument of grdsample, so switching to another age grid cannot leave the
# sediment grid sampled against a different file.
AGE_GRID_PATH = '/Users/simon/Data/AgeGrids/2020/age.2020.1.GeeK2007.2m.nc'
#AGE_GRID_PATH = '/Users/simon/Data/AgeGrids/2020/age.2020.1.GeeK2007.6m.nc'
SEDIMENT_GRID_PATH = '/Users/simon/GIT/pyBacktrack/pybacktrack/bundle_data/sediment_thickness/GlobSed.nc'

age_da = xr.open_dataarray(AGE_GRID_PATH)
sed_da = pygmt.grdsample(SEDIMENT_GRID_PATH, region=AGE_GRID_PATH)

# Extend both grids in longitude and keep the -190..190 window.
# NOTE(review): presumably this wraps the grid across the dateline so chunks near
# +/-180 have neighbouring data - confirm against af.extend_longitude_range.
age_da = af.extend_longitude_range(age_da).sel(lon=slice(-190,190))
sed_da = af.extend_longitude_range(sed_da).sel(lon=slice(-190,190))

# Replace non-finite sediment values with 100 and cap everything at 500
# (same units as the sediment grid).
sed_da = sed_da.where(np.isfinite(sed_da), 100.)
sed_da = sed_da.where(sed_da<500., 500.)

# Random field with the same shape and coordinates as the age grid.
# NOTE(review): no seed is set here, so the field presumably differs on every
# run - confirm whether af.generate_random_field offers a way to seed it.
rand_da = age_da.copy()
rand_da.data = af.generate_random_field(rand_da.data.shape)

sed_da.plot(vmin=0,vmax=1000)

In [None]:

# Size of the full age grid along its first and second dimensions
full_ny, full_nx = age_da.shape

# Parameters handed to the synthetic bathymetry generator
params = dict(
    H=500,    # Base rms height in meters
    kn=0.05,  # Characteristic width in km^-1 (normal to ridge)
    ks=0.5,   # Characteristic length in km^-1 (parallel to ridge)
    D=2.2,    # Fractal dimension
)

# Wall-clock reference for the elapsed time printed once the chunks are processed
start = time.time()



def process_bathymetry_chunk(coord, chunksize, chunkpad):
    """Generate synthetic bathymetry for one padded chunk of the global grids.

    The chunk starts at index ``coord`` and spans ``chunksize + chunkpad`` cells
    along each axis of the module-level ``age_da``, ``sed_da`` and ``rand_da``
    grids (fewer where the end of the grid truncates the slice). After the
    chunk has been processed, ``chunkpad // 2`` cells are trimmed from every
    side so that only the interior of the padded chunk is returned.

    Parameters
    ----------
    coord : tuple of int
        Index of the chunk's first cell along the first and second grid axes.
    chunksize : int
        Number of cells per axis in the chunk, excluding padding.
    chunkpad : int
        Total padding per axis; half of it is trimmed from each side of the
        result. A value of 0 disables trimming.

    Returns
    -------
    xarray.DataArray
        Array named ``'z'`` carrying the coordinates of the trimmed chunk.
        Chunks whose ages are all NaN are returned as all-NaN without running
        the bathymetry generator. A dimension may have length zero when the
        grid edge truncates the chunk to no more than ``chunkpad`` cells.
    """
    print(coord, chunksize)

    row0, col0 = coord
    span = chunksize + chunkpad
    window = (slice(row0, row0 + span), slice(col0, col0 + span))

    chunk_age = age_da[window]

    if np.all(np.isnan(chunk_age.data)):
        print('Empty Chunk')
        # Nothing to model: keep the NaNs, but give the chunk the same name and
        # trimming as every other chunk so the results do not overlap.
        bathymetry = chunk_age.data
    else:
        chunk_sed = sed_da[window]
        chunk_random = rand_da[window]

        # Generate the synthetic bathymetry
        # NOTE(review): the random field is passed as a DataArray while age and
        # sediment are passed as raw arrays - confirm this is what
        # af.generate_bathymetry_spatial_filter expects.
        bathymetry = af.generate_bathymetry_spatial_filter(chunk_age.data,
                                                           chunk_sed.data,
                                                           params,
                                                           chunk_random)

    # `-half or None` keeps the full extent when chunkpad is 0; a plain
    # slice(0, -0) would be slice(0, 0) and return an empty array.
    half = chunkpad // 2
    trim = slice(half, -half or None)

    return xr.DataArray(bathymetry, coords=chunk_age.coords, name='z')[trim, trim]
    

chunksize = 50
chunkpad = 20

# Force the padding to an even number so it splits equally between both sides
chunkpad = 2 * int(np.round(chunkpad / 2))

# Start indices of every chunk as (first-axis, second-axis) pairs; the second-axis
# start varies slowest, the first-axis start fastest.
row_starts = np.arange(0, full_ny-1, chunksize)
col_starts = np.arange(0, full_nx-1, chunksize)
coords = [(row0, col0) for col0 in col_starts for row0 in row_starts]

# Process the chunks in parallel worker processes; results keep the order of `coords`.
# For debugging, a serial alternative is:
#   results = [process_bathymetry_chunk(coord, chunksize, chunkpad) for coord in coords]
num_cpus = 4
chunk_jobs = (delayed(process_bathymetry_chunk)(coord, chunksize, chunkpad) for coord in coords)
results = Parallel(n_jobs=num_cpus)(chunk_jobs)

# Elapsed wall-clock seconds since `start` was recorded
print(time.time() - start)


In [None]:
# Drop chunks that ended up with a zero-length dimension
resss = list(filter(lambda chunk: 0 not in chunk.shape, results))

# Draw every remaining chunk onto one shared figure with a common colour scale
mesh_style = dict(vmin=-0.5, vmax=0.5, cmap='magma')

fig, ax = plt.subplots(figsize=(30, 16))
for res in resss:
    plt.pcolormesh(res.lon, res.lat, res.data, **mesh_style)

# Zoom to longitude -50..0, latitude -30..0
plt.axis([-50, 0, -30, 0])
plt.colorbar()
plt.show()


In [None]:
# Difference between the first and third retained chunks.
# NOTE(review): these chunks come from different grid positions - confirm they
# share coordinates, otherwise the difference may have nothing to plot.
chunk_diff = resss[0] - resss[2]
chunk_diff.plot()