# Usage Guide : pygtm
## netCDF IO functions

In [1]:
import sys
import os
import h5py
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cmocean
import cartopy.feature as cfeature
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
%matplotlib inline

In [2]:
sys.path.insert(0, '../')
from pygtm.physical import physical_space
from pygtm.matrix import matrix_space
from pygtm.dataset import trajectory
from pygtm.tools import export_nc, import_nc

# Load drifter trajectories
## Using the drogued drifters from the GDP database in the North Atlantic ([*download* the data here](https://miamiedu-my.sharepoint.com/:u:/g/personal/pxm498_miami_edu/EfaPVkKsPABJrJtKyctKrAUBpp7XzNHHrOLUhFow3pMkZw?e=MXzrcG)); see this [notebook](https://github.com/philippemiron/notebooks/blob/master/gdp.ipynb) for more info

In [3]:
# Read the drifter IDs, positions and times from the MATLAB (HDF5) file.
# Index 0 along the first axis of each dataset is kept.
filename = 'data/gdp-north-atlantic-drogued.mat'
with h5py.File(filename, 'r') as f:
    d_id = f['id'][0]
    x = f['x'][0]  # [deg]
    y = f['y'][0]  # [deg]
    t = f['t'][0]  # [day]
del f

# Drifter IDs are reused in the GDP, so make sure there is never
# more than 10 days between consecutive data points that share
# the same drifter ID.
# I holds the index of the last sample of each run of identical IDs;
# -1 is prepended and the last index appended so that consecutive
# entries of I delimit one run: (I[i] + 1) .. I[i + 1] inclusive.
I = np.where(abs(np.diff(d_id, axis=0)) > 0)[0]
I = np.insert(I, [0, len(I)], [-1, len(d_id) - 1])
max_id = np.max(d_id)  # new IDs are allocated above the largest existing one

for i in range(0, len(I) - 1):
    # indices of all samples belonging to the i-th run of identical IDs
    range_i = np.arange(I[i] + 1, I[i + 1] + 1)
    t_diff = np.diff(t[range_i])

    # If there is a big gap in time, change the ID and treat
    # each resulting segment as a separate drifter.
    jump = np.where(t_diff > 10)[0]  # gaps longer than 10 days
    if len(jump) > 0:
        # same boundary trick as for I, but relative to range_i
        jump = np.insert(jump, [0, len(jump)], [-1, len(range_i) - 1])

        for j in range(0, len(jump) - 1):
            range_j = np.arange(jump[j] + 1, jump[j + 1] + 1)
            # every sub-segment (including the first one) receives a fresh ID
            d_id[range_i[range_j]] = np.ones(len(range_j)) * (max_id + 1)  # in-place update of d_id
            max_id += 1

In [4]:
T = 5  # transition time [days]
# bin size handed to physical_space — presumably in km; TODO confirm units
spatial_dis = 75
# longitude and latitude limits of the domain — presumably [deg], like x and y
lon = [-98, 35]
lat = [-5, 80]

# create the grid and bins
d = physical_space(lon, lat, spatial_dis)

# build the trajectory object from the positions, times and drifter IDs,
# then create segments ready to plot with add_collection()
data = trajectory(x,y,t,d_id)
data.create_segments(T)

In [5]:
# create the matrix object on the domain d
tm = matrix_space(d)
# fill the transition matrix from the trajectory segments
tm.fill_transition_matrix(data)
# compute left and right eigenvectors; 20 is presumably the number
# of eigenvectors requested — TODO confirm against pygtm.matrix
tm.left_and_right_eigenvectors(20)

Domain contains 1808 bins. (1596 bins were removed)


# Import and Export to netCDF

In [6]:
# export the trajectory segments, the domain and the matrix to 'test.nc'
export_nc('test.nc', data, d, tm)

In [7]:
# read the three objects back from the file written by export_nc;
# they are compared with the originals in the validation section
data2, d2, tm2 = import_nc('test.nc')

# Validate the IO functions

In [8]:
# function to compare before writing and reading
def compare_domain(dom, dom2):
    """Check that two domain objects carry identical attributes.

    Used to validate the netCDF export/import round trip: `dom` is the
    object before writing and `dom2` the one read back.

    Parameters
    ----------
    dom, dom2 : objects exposing the attributes ``lon``, ``lat``,
        ``resolution``, ``nx``, ``ny``, ``coords``, ``bins``, ``vx``,
        ``vy``, ``dx``, ``dy``, ``id``, ``N0`` and ``id_og``.

    Returns
    -------
    bool
        True when every listed attribute is equal in both objects,
        False otherwise.
    """
    # Compare the two *arguments*: the previous version read the notebook
    # globals `d` and `d2`, so the function ignored whatever was passed in.
    attributes = (
        'lon', 'lat', 'resolution', 'nx', 'ny', 'coords', 'bins',
        'vx', 'vy', 'dx', 'dy', 'id', 'N0', 'id_og',
    )
    # np.array_equal handles scalars and arrays alike and returns False
    # (instead of broadcasting) when the shapes differ.
    return all(
        np.array_equal(getattr(dom, name), getattr(dom2, name))
        for name in attributes
    )

def compare_data(data, data2):
    """Check that two trajectory objects hold the same segment end points.

    Parameters
    ----------
    data, data2 : objects exposing the arrays ``x0``, ``y0``, ``xt``
        and ``yt``.

    Returns
    -------
    bool
        True when the four arrays are equal in shape and in value,
        False otherwise.
    """
    # np.array_equal (already used by the sibling comparison functions)
    # checks the shape first; `np.all(a == b)` would silently broadcast,
    # e.g. report [1, 1, 1] and [1] as equal.
    return all(
        np.array_equal(getattr(data, name), getattr(data2, name))
        for name in ('x0', 'y0', 'xt', 'yt')
    )
    
def compare_matrix(tm, tm2):
    """Check that two matrix objects carry identical attributes.

    Parameters
    ----------
    tm, tm2 : objects exposing ``N``, ``B``, ``P``, ``M``, ``fi``,
        ``fo``, ``eigL``, ``L``, ``eigR`` and ``R``.

    Returns
    -------
    bool
        True when every listed attribute is equal in both objects,
        False otherwise.
    """
    if not np.array_equal(tm.N, tm2.N):
        return False

    # B is compared entry by entry (presumably one index array per bin,
    # possibly of different lengths — TODO confirm against pygtm.matrix).
    # This replaces np.equal(..., dtype=np.object): the `np.object` alias
    # was removed in NumPy 1.24, and an element-wise `==` fails when two
    # entries differ in length instead of reporting a mismatch.
    if len(tm.B) != len(tm2.B):
        return False
    if not all(np.array_equal(b, b2) for b, b2 in zip(tm.B, tm2.B)):
        return False

    return all(
        np.array_equal(getattr(tm, name), getattr(tm2, name))
        for name in ('P', 'M', 'fi', 'fo', 'eigL', 'L', 'eigR', 'R')
    )

In [9]:
# validate io: the domain, the trajectory segments and the matrix
# must all be unchanged by the netCDF round trip
round_trip_ok = (
    compare_domain(d, d2)
    and compare_data(data, data2)
    and compare_matrix(tm, tm2)
)
message = (
    'Objects before and after IO are equals.'
    if round_trip_ok
    else 'Error during the writing or reading phase.'
)
print(message)

# delete the temporary file so it is not saved to GitHub
os.remove('test.nc')

Objects before and after IO are equals.
