In [1]:
%load_ext autoreload
%autoreload 2

main_dir = '/Users/simon/bsose_monthly/'
salt = main_dir + 'bsose_i106_2008to2012_monthly_Salt.nc'
theta = main_dir + 'bsose_i106_2008to2012_monthly_Theta.nc'

In [40]:
import numpy as np
import xarray as xr
xr.set_options(keep_attrs=True)
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature
import matplotlib.path as mpath
import pyxpcm
from pyxpcm.models import pcm


def pcm_fit_and_post(time_i=42, K=4, maxvar=2):  
    # Define features to use
    # Instantiate the PCM
    
    max_depth = 2000
    z = np.arange(0., -max_depth, -10.)
    features_pcm = {'THETA': z, 'SALT': z}
    features = {'THETA': 'THETA', 'SALT': 'SALT'}
    salt_nc = xr.open_dataset(salt).isel(time=time_i)
    theta_nc = xr.open_dataset(theta).isel(time=time_i)
    big_nc = xr.merge([salt_nc, theta_nc])
    both_nc = big_nc.where(big_nc.coords['Depth'] > 
                           max_depth).drop(['iter', 'Depth', 
                                            'rA', 'drF', 'hFacC'])   
 
    attr_d = {}

    for coord in both_nc.coords:
        attr_d[coord] = both_nc.coords[coord].attrs
        
    lons_new = np.linspace(both_nc.XC.min(), both_nc.XC.max(), 60*4)
    lats_new = np.linspace(both_nc.YC.min(), both_nc.YC.max(), 60)
    # ds = both_nc # .copy(deep=True)
    ds = both_nc.interp(coords={'YC': lats_new, 'XC': lons_new})#, method='cubic')
        
    m = pcm(K=K, features=features_pcm, maxvar=maxvar, 
            timeit=True, timeit_verb=1)
    m.fit(ds, features=features, dim='Z') #, inplace=True)
    m.predict(ds, features=features, dim='Z', inplace=True)
    m.predict_proba(ds, features=features, dim='Z', inplace=True)
    m.find_i_metric(ds, inplace=True)
    
    def sanitize():
        del ds.PCM_LABELS.attrs['_pyXpcm_cleanable']
        del ds.PCM_POST.attrs['_pyXpcm_cleanable']
        del ds.PCM_RANK.attrs['_pyXpcm_cleanable']
        
    for coord in attr_d:
        ds.coords[coord].attrs = attr_d[coord]
    
    sanitize()
    return ds, m

  class pyXpcmDataSetAccessor:


In [41]:
ds, m = pcm_fit_and_post(K=5)
ds

  fit.1-preprocess.1-mask: 21 ms
  fit.1-preprocess.2-feature_THETA.1-ravel: 23 ms
  fit.1-preprocess.2-feature_THETA.2-interp: 42 ms
  fit.1-preprocess.2-feature_THETA.3-scale_fit: 78 ms
  fit.1-preprocess.2-feature_THETA.4-scale_transform: 33 ms
  fit.1-preprocess.2-feature_THETA.5-reduce_fit: 125 ms
  fit.1-preprocess.2-feature_THETA.6-reduce_transform: 11 ms
  fit.1-preprocess.2-feature_THETA.total: 317 ms
  fit.1-preprocess: 317 ms
  fit.1-preprocess.3-homogeniser: 2 ms
  fit.1-preprocess.2-feature_SALT.1-ravel: 20 ms
  fit.1-preprocess.2-feature_SALT.2-interp: 18 ms
  fit.1-preprocess.2-feature_SALT.3-scale_fit: 46 ms
  fit.1-preprocess.2-feature_SALT.4-scale_transform: 29 ms
  fit.1-preprocess.2-feature_SALT.5-reduce_fit: 52 ms
  fit.1-preprocess.2-feature_SALT.6-reduce_transform: 7 ms
  fit.1-preprocess.2-feature_SALT.total: 176 ms
  fit.1-preprocess: 176 ms
  fit.1-preprocess.3-homogeniser: 1 ms
  fit.1-preprocess.4-xarray: 1 ms
  fit.1-preprocess: 523 ms
  fit.fit: 323 ms
  f

In [42]:
ds.THETA

In [43]:
ds.SALT

In [79]:
import gsw

def return_rho(pt_values, salt_values, lon_values, z_values):
    ct_values = gsw.conversions.CT_from_pt(salt_values, pt_values)
    # gsw.conversions.SA_from_SP(SP, p, lon, lat)
    lon_mesh, z_mesh = np.meshgrid(lon_values, z_values)
    print(z_values)
    print(lon_values)
    pressure_mesh = gsw.p_from_z(z_mesh, lon_mesh)
    pressure_values = np.zeros(np.shape(ct_values))
    for i in range(np.shape(ct_values)[2]):
        pressure_values[:, :, i] = pressure_mesh[:, :]
    rho_values = gsw.density.rho(salt_values, ct_values, pressure_values)
    print(np.shape(rho_values))
    return rho_values, ct_values, pressure_values
    
rho_values, ct_values, pressure_values = return_rho(ds.THETA.values, ds.SALT.values, ds.YC.values, ds.Z.values)

def rho_values_dataset(dataset, rho_values, propagate=True):
    if propagate:
        attr_d = {}
        for coord in dataset.coords:
            attr_d[coord] = dataset.coords[coord].attrs
    
    da = xr.DataArray(rho_values, dims=["Z", "YC", "XC"], 
                      coords=[ds.Z.values, ds.YC.values, ds.XC.values]).rename('rho')
    
    da.attrs['long_name'] = 'density'
    da.attrs['units'] = 'kg m^{-3}'
    da.attrs['valid_min'] = 1000 
    da.attrs['valid_max'] = 1100
    
    if propagate:
        for coord in attr_d:
            if coord != 'time':
                da.coords[coord].attrs = attr_d[coord]
    dataset[da.name] = da
    
    if propagate:
        for coord in attr_d:
            dataset.coords[coord].attrs = attr_d[coord]

    return dataset, da


def ct_values_dataset(dataset, ct_values, propagate=True):
    if propagate:
        attr_d = {}
        for coord in dataset.coords:
            attr_d[coord] = dataset.coords[coord].attrs
    
    da = xr.DataArray(ct_values, dims=["Z", "YC", "XC"], 
                      coords=[ds.Z.values, ds.YC.values, ds.XC.values]).rename('ct')
    
    da.attrs['long_name'] = 'conservative temperature'
    da.attrs['units'] = 'degC'
    #da.attrs['valid_min'] = -
    #da.attrs['valid_max'] = 1100
    
    if propagate:
        for coord in attr_d:
            if coord != 'time':
                da.coords[coord].attrs = attr_d[coord]
    dataset[da.name] = da
    
    if propagate:
        for coord in attr_d:
            dataset.coords[coord].attrs = attr_d[coord]

    return dataset, da


def pressure_values_dataset(dataset, pressure_values, propagate=True):
    if propagate:
        attr_d = {}
        for coord in dataset.coords:
            attr_d[coord] = dataset.coords[coord].attrs
    
    da = xr.DataArray(pressure_values, dims=["Z", "YC", "XC"], 
                      coords=[ds.Z.values, ds.YC.values, ds.XC.values]).rename('pressure')
    
    da.attrs['long_name'] = 'pressure'
    da.attrs['units'] = 'Pa'
    #da.attrs['valid_min'] = -
    #da.attrs['valid_max'] = 1100
    
    if propagate:
        for coord in attr_d:
            if coord != 'time':
                da.coords[coord].attrs = attr_d[coord]
    dataset[da.name] = da
    
    if propagate:
        for coord in attr_d:
            dataset.coords[coord].attrs = attr_d[coord]

    return dataset, da


ds, da = rho_values_dataset(ds, rho_values)
ds, da = ct_values_dataset(ds, ct_values)
ds, da = pressure_values_dataset(ds, pressure_values)


[-2.100e+00 -6.700e+00 -1.215e+01 -1.855e+01 -2.625e+01 -3.525e+01
 -4.500e+01 -5.500e+01 -6.500e+01 -7.500e+01 -8.500e+01 -9.500e+01
 -1.050e+02 -1.150e+02 -1.250e+02 -1.350e+02 -1.465e+02 -1.615e+02
 -1.800e+02 -2.000e+02 -2.200e+02 -2.400e+02 -2.600e+02 -2.800e+02
 -3.010e+02 -3.270e+02 -3.610e+02 -4.025e+02 -4.500e+02 -5.000e+02
 -5.515e+02 -6.140e+02 -7.000e+02 -8.000e+02 -9.000e+02 -1.000e+03
 -1.100e+03 -1.225e+03 -1.400e+03 -1.600e+03 -1.800e+03 -2.010e+03
 -2.270e+03 -2.610e+03 -3.000e+03 -3.400e+03 -3.800e+03 -4.200e+03
 -4.600e+03 -5.000e+03 -5.400e+03 -5.800e+03]
[-77.98265076 -77.16456035 -76.34646994 -75.52837954 -74.71028913
 -73.89219872 -73.07410832 -72.25601791 -71.4379275  -70.6198371
 -69.80174669 -68.98365629 -68.16556588 -67.34747547 -66.52938507
 -65.71129466 -64.89320425 -64.07511385 -63.25702344 -62.43893303
 -61.62084263 -60.80275222 -59.98466181 -59.16657141 -58.348481
 -57.53039059 -56.71230019 -55.89420978 -55.07611937 -54.25802897
 -53.43993856 -52.6218481

In [69]:
np.shape(ds.rho.values)

(52, 60, 240)

In [80]:
ds.to_netcdf('density_2.nc', engine='netcdf4')

In [74]:
ds.rho

In [77]:
ds.ct


In [78]:
ds.pressure

In [60]:
da