Skip to content

Commit

Permalink
dimension_issue (#55)
Browse files Browse the repository at this point in the history
* dimension_issue

* update readme

* requested changes

* sort issue
  • Loading branch information
Ci Zhang authored and rabernat committed May 24, 2017
1 parent 1eca7ff commit f338a9e
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Transcoding is done via the `floater_convert` script, which is installed with th
$ floater_convert
usage: floater_convert [-h] [--float_file_prefix PREFIX] [--float_buf_dim N]
[--progress] [--input_dir DIR] [--output_format FMT]
[--keep_fields FIELDS] [--ref_time RT]
[--keep_fields FIELDS] [--ref_time RT] [--pkl_path PP]
[--output_dir OD] [--output_prefix OP]
output_file
```
Expand Down
19 changes: 14 additions & 5 deletions floater/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def subset_floats_from_mask(self, xx, yy):


def npart_to_2D_array(self, ds1d):
"""Constructs 2D Dataset from 1D DataArray/DataSet of single or multi-variable.
"""Constructs 2D Dataset from 1D DataArray/Dataset of single or multi-variable.
PARAMETERS
----------
Expand All @@ -353,26 +353,35 @@ def npart_to_2D_array(self, ds1d):
if type(ds1d) == xr.core.dataarray.DataArray:
ds1d = ds1d.to_dataset()
df = ds1d.to_dataframe()
var_list = list(df)
var_list = list(df.columns)
index_dict = {'index': range(1, Nt+1)}
var_dict = {var: np.zeros(Nt) for var in var_list}
frame_dict = {}
frame_dict.update(index_dict)
frame_dict.update(var_dict)
frame = pd.DataFrame(frame_dict)
framei = frame.set_index('index')
framei.columns = var_list
if self.model_grid is not None:
ocean_bools = self.ocean_bools
else:
ocean_bools = np.zeros(Nt, dtype=bool)==False
framei.loc[ocean_bools==True] = df.values.astype(np.float32)
framei.loc[ocean_bools==False] = np.float32('nan')
data_vars = {}
dim_list = list(ds1d.dims)
dim_list.remove('npart')
dim_len = len(dim_list)
new_shape = (1,)*dim_len + (Ny, Nx)
new_dims = dim_list + ['lat', 'lon']
for var in var_list:
frameir = framei[var].values.reshape(Ny, Nx)
data_vars.update({var: (['lat', 'lon'], frameir)})
frameir = framei[var].values
frameir.shape = new_shape
data_vars.update({var: (new_dims, frameir)})
coords = {}
lon = np.float32(self.x)
lat = np.float32(self.y)
coords = {'lat': (['lat'], lat), 'lon': (['lon'], lon)}
coords.update({dim: ([dim], ds1d[dim].values) for dim in dim_list})
coords.update({'lat': (['lat'], lat), 'lon': (['lon'], lon)})
ds2d = xr.Dataset(data_vars=data_vars, coords=coords)
return ds2d
29 changes: 19 additions & 10 deletions floater/test/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def test_npart_to_2D_array():
land_mask.shape = (len(lat), len(lon))
land_mask[:,0:2] = False
model_grid = {'lon': lon, 'lat': lat, 'land_mask': land_mask}
fs_none = gen.FloatSet(xlim=(0, 9), ylim=(-4, 5), dx=1, dy=1)
fs_mask = gen.FloatSet(xlim=(0, 9), ylim=(-4, 5), dx=1, dy=1, model_grid=model_grid)
fs_none = gen.FloatSet(xlim=(0, 9), ylim=(-4, 5), dx=1.0, dy=1.0)
fs_mask = gen.FloatSet(xlim=(0, 9), ylim=(-4, 5), dx=1.0, dy=1.0, model_grid=model_grid)
# dataarray/dataset
var_list = ['test_01', 'test_02', 'test_03']
values_list_none = []
Expand All @@ -205,15 +205,21 @@ def test_npart_to_2D_array():
data_vars_mask = {}
for var in var_list:
values_none = np.random.random(81)
values_none.shape = (1, 1, 81)
values_mask = np.random.random(69)
values_mask.shape = (1, 1, 69)
values_list_none.append(values_none)
values_list_mask.append(values_mask)
data_vars_none.update({var: (['npart'], values_none)})
data_vars_mask.update({var: (['npart'], values_mask)})
data_vars_none.update({var: (['date', 'loc', 'npart'], values_none)})
data_vars_mask.update({var: (['date', 'loc', 'npart'], values_mask)})
npart_none = np.linspace(1, 81, 81, dtype=np.int32)
npart_mask = np.linspace(1, 69, 69, dtype=np.int32)
coords_none = {'npart': (['npart'], npart_none)}
coords_mask = {'npart': (['npart'], npart_mask)}
coords_none = {'date': (['date'], np.array([np.datetime64('2000-01-01')])),
'loc': (['loc'], np.array(['New York'])),
'npart': (['npart'], npart_none)}
coords_mask = {'date': (['date'], np.array([np.datetime64('2000-01-01')])),
'loc': (['loc'], np.array(['New York'])),
'npart': (['npart'], npart_mask)}
ds1d_none = xr.Dataset(data_vars=data_vars_none, coords=coords_none)
ds1d_mask = xr.Dataset(data_vars=data_vars_mask, coords=coords_mask)
da1d_none = ds1d_none['test_01']
Expand All @@ -228,20 +234,23 @@ def test_npart_to_2D_array():
da2d = fs.npart_to_2D_array(da1d)
ds2d = fs.npart_to_2D_array(ds1d)
# shape test
assert da2d.to_array().values.shape == (1, fs.Ny, fs.Nx)
assert ds2d.to_array().values.shape == (3, fs.Ny, fs.Nx)
assert da2d.to_array().values.shape == (1, 1, 1, fs.Ny, fs.Nx)
assert ds2d.to_array().values.shape == (3, 1, 1, fs.Ny, fs.Nx)
# dimension test
assert da2d.dims == {'date': 1, 'loc': 1, 'lat': 9, 'lon': 9}
assert ds2d.dims == {'date': 1, 'loc': 1, 'lat': 9, 'lon': 9}
# coordinates test
np.testing.assert_allclose(da2d.lon.values, fs.x)
np.testing.assert_allclose(da2d.lat.values, fs.y)
np.testing.assert_allclose(ds2d.lon.values, fs.x)
np.testing.assert_allclose(ds2d.lat.values, fs.y)
# values test
da1d_values = values_list[0]
da1d_values = values_list[0][0][0]
da2d_values_full = da2d.to_array().values[0].ravel()
da2d_values = da2d_values_full[~np.isnan(da2d_values_full)]
np.testing.assert_allclose(da2d_values, da1d_values)
for i in range(3):
ds1d_values = values_list[i]
ds1d_values = values_list[i][0][0]
ds2d_values_full = ds2d.to_array().values[i].ravel()
ds2d_values = ds2d_values_full[~np.isnan(ds2d_values_full)]
np.testing.assert_allclose(ds2d_values, ds1d_values)
Expand Down
17 changes: 10 additions & 7 deletions floater/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,20 @@ def test_floats_to_netcdf(tmpdir, mitgcm_float_datadir_csv):
"""Test that we can convert MITgcm float data into NetCDF format.
"""
import xarray as xr
from floater.generators import FloatSet

input_dir = str(mitgcm_float_datadir_csv) + '/'
output_dir = str(tmpdir) + '/'
input_dir = str(mitgcm_float_datadir_csv)
output_dir = str(tmpdir)
os.chdir(input_dir)
fs = FloatSet(xlim=(-5, 5), ylim=(-2, 2), dx=1.0, dy=1.0)
fs.to_pickle('./fs.pkl')

# least options
utils.floats_to_netcdf(input_dir=input_dir, output_fname='test')
# most options
utils.floats_to_netcdf(input_dir=input_dir, output_fname='test',
float_file_prefix='float_trajectories',
ref_time='1993-01-01',
ref_time='1993-01-01', pkl_path='./fs.pkl',
output_dir=output_dir, output_prefix='prefix_test')

# filename prefix test
Expand All @@ -112,17 +115,17 @@ def test_floats_to_netcdf(tmpdir, mitgcm_float_datadir_csv):
mfdm = xr.open_mfdataset('test_netcdf/prefix_test.*.nc')

# dimensions test
dims = {'npart': 40, 'time': 2}
assert mfdl.dims == dims
assert mfdm.dims == dims
dims = [{'time': 2, 'npart': 40}, {'time': 2, 'lat': 4, 'lon': 10}]
assert mfdl.dims == dims[0]
assert mfdm.dims == dims[1]

# variables and values test
vars_values = [('x', 0.3237109375000000e+03), ('y', -0.7798437500000000e+02),
('z', -0.4999999999999893e+00), ('u', -0.5346306607990328e-02),
('v', -0.2787361934305595e-02), ('vort', 0.9160626946271506e-10)]
for var, value in vars_values:
np.testing.assert_almost_equal(mfdl[var].values[0][0], value, 8)
np.testing.assert_almost_equal(mfdm[var].values[0][0], value, 8)
np.testing.assert_almost_equal(mfdm[var].values[0][0][0], value, 8)

# times test
times = [(0, 0, np.datetime64('1993-01-01', 'ns')), (1, 86400, np.datetime64('1993-01-02', 'ns'))]
Expand Down
9 changes: 7 additions & 2 deletions floater/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def floats_to_castra(input_dir, output_fname, progress=False, **kwargs):
def floats_to_netcdf(input_dir, output_fname,
float_file_prefix='float_trajectories',
ref_time=None, output_dir='./',
output_prefix='float_trajectories'):
output_prefix='float_trajectories',
pkl_path=None):
"""Convert MITgcm float data to NetCDF format.
Parameters
Expand All @@ -242,14 +243,15 @@ def floats_to_netcdf(input_dir, output_fname,
"""
import dask.dataframe as dd
import xarray as xr
from floater.generators import FloatSet
from glob import glob
from tqdm import tqdm

output_fname = _maybe_add_suffix(output_fname, '_netcdf')

match_pattern = float_file_prefix + '.*.csv'
float_files = glob(os.path.join(input_dir, match_pattern))
float_timesteps = set(sorted([int(float_file[-22:-12]) for float_file in float_files]))
float_timesteps = sorted(list({int(float_file[-22:-12]) for float_file in float_files}))

float_columns = ['npart', 'time', 'x', 'y', 'z', 'u', 'v', 'vort']
var_names = float_columns[2:]
Expand All @@ -270,6 +272,9 @@ def floats_to_netcdf(input_dir, output_fname,
var_shape = (1, len(npart))
data_vars = {var_name: (['time', 'npart'], dfcs[var_name].values.astype(np.float32).reshape(var_shape)) for var_name in var_names}
ds = xr.Dataset(data_vars, coords={'time': time, 'npart': npart})
if pkl_path is not None:
fs = FloatSet(load_path=pkl_path)
ds = fs.npart_to_2D_array(ds)
output_path = os.path.join(output_dir, output_fname)
if not os.path.exists(output_path):
os.makedirs(output_path)
Expand Down
6 changes: 5 additions & 1 deletion scripts/floater_convert
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ parser.add_argument('--output_prefix', default='float_trajectories', metavar='OP
parser.add_argument('--ref_time', default=None, metavar='RT',
help='reference time, format: YYYY-MM-DD')

parser.add_argument('--pkl_path', default=None, metavar='PP',
help='path to the pickle file of the FloatSet used to generate initial positions')

parser.add_argument('output_file',
help='the output filename')

Expand Down Expand Up @@ -69,4 +72,5 @@ elif args.output_format=='netcdf':
float_file_prefix=args.float_file_prefix,
ref_time=args.ref_time,
output_dir=args.output_dir,
output_prefix=args.output_prefix)
output_prefix=args.output_prefix,
pkl_path=args.pkl_path)

0 comments on commit f338a9e

Please sign in to comment.