In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to source files on disk are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os

# Directory holding the monthly bsose NetCDF files. The location can be
# overridden by setting the BSOSE_MONTHLY_DIR environment variable; when it is
# not set, the original local path is used so existing runs are unaffected.
main_dir = os.environ.get("BSOSE_MONTHLY_DIR", "/Users/simon/bsose_monthly/")

# Full paths to the "Salt" and "Theta" files read by the cells below.
# os.path.join copes with main_dir being given with or without a trailing slash.
salt = os.path.join(main_dir, "bsose_i106_2008to2012_monthly_Salt.nc")
theta = os.path.join(main_dir, "bsose_i106_2008to2012_monthly_Theta.nc")

In [25]:
import numpy as np
import xarray as xr
xr.set_options(keep_attrs=True)
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature
import matplotlib.path as mpath
import pyxpcm
from pyxpcm.models import pcm


def pcm_pca_out(time_i=42, K=4, maxvar=2):
    """Fit a pyxpcm ``pcm`` model on one time slice and return the results.

    The module-level ``salt`` and ``theta`` paths are opened, the same time
    index is selected from each, and the two are merged. Values are kept only
    where the ``Depth`` coordinate exceeds 2000 (elsewhere they become NaN),
    several auxiliary variables are dropped, and the result is interpolated
    onto a 60 x 240 (``YC`` x ``XC``) grid spanning the original coordinate
    range. The model is then fitted and its outputs are added to the dataset.

    Parameters
    ----------
    time_i : int
        Positional index along ``time`` selected from both input files.
    K : int
        Forwarded to ``pcm`` as ``K``.
    maxvar : int
        Forwarded to ``pcm`` as ``maxvar``.

    Returns
    -------
    tuple
        ``(ds, m)``: the interpolated dataset after the model methods have
        been applied to it with ``inplace=True``, and the ``pcm`` instance.
    """
    depth_limit = 2000

    # Vertical axis given to the PCM for both features: from -300 towards
    # -depth_limit (exclusive) in steps of -10.
    z_axis = np.arange(-300., -depth_limit, -10.)
    pcm_axes = {'THETA': z_axis, 'SALT': z_axis}
    # Maps each PCM feature name onto the dataset variable of the same name.
    feature_map = {'THETA': 'THETA', 'SALT': 'SALT'}
    shared_kw = {'features': feature_map, 'dim': 'Z'}

    # Same time index from each file, combined into a single dataset.
    merged = xr.merge([
        xr.open_dataset(salt).isel(time=time_i),
        xr.open_dataset(theta).isel(time=time_i),
    ])

    # Mask everything where Depth is not greater than depth_limit, then
    # discard variables that are not wanted downstream.
    beyond_limit = merged.coords['Depth'] > depth_limit
    unwanted = ['iter', 'Depth', 'rA', 'drF', 'hFacC']
    both_nc = merged.where(beyond_limit).drop(unwanted)

    # Remember every coordinate's attributes so they can be put back on the
    # final dataset once interpolation and the model steps are done.
    saved_attrs = {name: both_nc.coords[name].attrs
                   for name in both_nc.coords}

    # Regular target grid covering the original XC / YC extent.
    lon_grid = np.linspace(both_nc.XC.min(), both_nc.XC.max(), 60*4)
    lat_grid = np.linspace(both_nc.YC.min(), both_nc.YC.max(), 60)
    ds = both_nc.interp(coords={'YC': lat_grid, 'XC': lon_grid})

    m = pcm(K=K,
            features=pcm_axes,
            maxvar=maxvar,
            timeit=True,
            timeit_verb=1)

    # NOTE(review): add_pca_to_xarray and find_i_metric are not defined in this
    # notebook; presumably they come from the installed pyxpcm build - confirm.
    ds = m.add_pca_to_xarray(ds, inplace=True, **shared_kw)
    m.fit(ds, **shared_kw)
    m.predict(ds, inplace=True, **shared_kw)
    m.predict_proba(ds, inplace=True, **shared_kw)
    m.find_i_metric(ds, inplace=True)

    # Put the original coordinate attributes back.
    for name, attrs in saved_attrs.items():
        ds.coords[name].attrs = attrs

    # Remove the '_pyXpcm_cleanable' attribute from each model output variable.
    for var_name in ('PCM_LABELS', 'PCM_POST', 'PCM_RANK', 'PCA_VALUES'):
        del ds[var_name].attrs['_pyXpcm_cleanable']

    return ds, m

In [41]:
# Run the workflow with K=5 (time_i and maxvar stay at their defaults) and
# show the resulting dataset as the cell output.
ds, m = pcm_pca_out(K=5)
ds

  fit.1-preprocess.1-mask: 19 ms
[-2.100e+00 -6.700e+00 -1.215e+01 -1.855e+01 -2.625e+01 -3.525e+01
 -4.500e+01 -5.500e+01 -6.500e+01 -7.500e+01 -8.500e+01 -9.500e+01
 -1.050e+02 -1.150e+02 -1.250e+02 -1.350e+02 -1.465e+02 -1.615e+02
 -1.800e+02 -2.000e+02 -2.200e+02 -2.400e+02 -2.600e+02 -2.800e+02
 -3.010e+02 -3.270e+02 -3.610e+02 -4.025e+02 -4.500e+02 -5.000e+02
 -5.515e+02 -6.140e+02 -7.000e+02 -8.000e+02 -9.000e+02 -1.000e+03
 -1.100e+03 -1.225e+03 -1.400e+03 -1.600e+03 -1.800e+03 -2.010e+03
 -2.270e+03 -2.610e+03 -3.000e+03 -3.400e+03 -3.800e+03 -4.200e+03
 -4.600e+03 -5.000e+03 -5.400e+03 -5.800e+03]
  fit.1-preprocess.2-feature_THETA.1-ravel: 20 ms
[-2.100e+00 -6.700e+00 -1.215e+01 -1.855e+01 -2.625e+01 -3.525e+01
 -4.500e+01 -5.500e+01 -6.500e+01 -7.500e+01 -8.500e+01 -9.500e+01
 -1.050e+02 -1.150e+02 -1.250e+02 -1.350e+02 -1.465e+02 -1.615e+02
 -1.800e+02 -2.000e+02 -2.200e+02 -2.400e+02 -2.600e+02 -2.800e+02
 -3.010e+02 -3.270e+02 -3.610e+02 -4.025e+02 -4.500e+02 -5.000e+02


  predict.1-preprocess.2-feature_SALT.4-scale_transform: 30 ms
  predict.1-preprocess.2-feature_SALT.5-reduce_fit: 0 ms
  predict.1-preprocess.2-feature_SALT.6-reduce_transform: 6 ms
  predict.1-preprocess.2-feature_SALT.total: 71 ms
  predict.1-preprocess: 71 ms
  predict.1-preprocess.3-homogeniser: 0 ms
  predict.1-preprocess.4-xarray: 1 ms
  predict.1-preprocess: 235 ms
  predict.predict: 4 ms
  predict.score: 4 ms
  predict.xarray: 24 ms
  predict: 270 ms
  predict_proba.1-preprocess.1-mask: 11 ms
[-2.100e+00 -6.700e+00 -1.215e+01 -1.855e+01 -2.625e+01 -3.525e+01
 -4.500e+01 -5.500e+01 -6.500e+01 -7.500e+01 -8.500e+01 -9.500e+01
 -1.050e+02 -1.150e+02 -1.250e+02 -1.350e+02 -1.465e+02 -1.615e+02
 -1.800e+02 -2.000e+02 -2.200e+02 -2.400e+02 -2.600e+02 -2.800e+02
 -3.010e+02 -3.270e+02 -3.610e+02 -4.025e+02 -4.500e+02 -5.000e+02
 -5.515e+02 -6.140e+02 -7.000e+02 -8.000e+02 -9.000e+02 -1.000e+03
 -1.100e+03 -1.225e+03 -1.400e+03 -1.600e+03 -1.800e+03 -2.010e+03
 -2.270e+03 -2.610e+03 -

In [31]:
# Write the dataset held in `ds` to pca_trial.nc using the netcdf4 engine.
ds.to_netcdf('pca_trial.nc', engine='netcdf4')

In [32]:
# NOTE(review): `X` is not assigned by any cell visible in this notebook, so
# this cell raises NameError on a fresh kernel. It looks like a leftover
# interactive probe - confirm and remove.
X

In [40]:
# Display the raw XC coordinate values of the opened file.
# NOTE(review): the variable is called salt_nc but it is opened from the
# `theta` path - confirm which file was intended, or rename the variable.
salt_nc = xr.open_dataset(theta)
salt_nc.coords['XC'].values

array([8.3333336e-02, 2.5000000e-01, 4.1666669e-01, ..., 3.5958334e+02,
       3.5975000e+02, 3.5991669e+02], dtype=float32)

In [2]:
import xarray as xr
xr.set_options(keep_attrs=True)

def merge_whole_density_netcdf():
    """Combine every NetCDF file in ``nc/labels/`` into ``nc/pcm_labels.nc``.

    The input files are opened together with one ``time`` step per chunk,
    combined according to their coordinate values, cast to ``float32`` and
    written out in NETCDF4 format. Nothing is returned.
    """
    # `concat_dim` is intentionally not passed: with combine='by_coords' it has
    # no effect, and recent xarray releases raise ValueError when both are
    # given. The `with` block closes the input files once the write is done.
    with xr.open_mfdataset('nc/labels/*.nc',
                           combine='by_coords',
                           chunks={'time': 1},
                           data_vars='minimal',
                           # parallel=True,
                           coords='minimal',
                           compat='override') as labels_ds:
        pca_ds = labels_ds.astype('float32')

        # Original author's note: this is too intense for memory.
        # The write has to happen inside the `with` block because the data
        # is still being read from the open input files at this point.
        xr.save_mfdataset([pca_ds], ['nc/pcm_labels.nc'], format='NETCDF4')

# Run the merge; this reads nc/labels/*.nc and writes nc/pcm_labels.nc.
merge_whole_density_netcdf()

            with self._context('predict_proba.predict', self._context_args):
                post_values = self._classifier.predict_proba(X)

            with self._context('predict_proba.score', self._context_args):
                self._props['llh'] = self._classifier.score(X)

            # Create a xarray with posteriors:
            with self._context('predict_proba.xarray', self._context_args):
                P = list()
                for k in range(self.K):
                    X = post_values[:, k]
                    x = self.unravel(ds, sampling_dims, X)
                    P.append(x)
                
                da = xr.concat(P, dim=classdimname).rename(name)
                da.attrs['long_name'] = 'PCM posteriors'
                da.attrs['units'] = ''
                da.attrs['valid_min'] = 0
                da.attrs['valid_max'] = 1
                da.attrs['llh'] = self._props['llh']

            # Add posteriors to the dataset:
            if inplace:
                return ds.pyxpcm.add(da)
            else:
                return da
