In [36]:
import os
import xarray as xr
import matplotlib.pyplot as plt
from sklearn.decomposition import NMF
import tensorly as tl
import numpy as np

In [37]:
MVBS_path = '/Users/wu-jung/code_git/ooi_sonar/zplsc_data_2015fall/nc_MVBS_envFromFile/'
MVBS_path = '../data/reproduced_MVBS_files/'
MVBS_file = '20150817-20151017_MVBS_time_from_Sv_rangeBin5_all.nc'

In [38]:
MVBS = xr.open_dataset(os.path.join(MVBS_path, MVBS_file))
MVBS

In [39]:
#MVBS_PCP_path = '/Users/wu-jung/code_git/ooi_sonar/zplsc_data_2015fall/nc_PCP_envFromFile/'
MVBS_PCP_path = '../data/reproduced_MVBS_files/'
MVBS_rpca_file = '20150817-20151017_MVBS_time_from_Sv_rangeBin5_rpca.nc'

In [40]:
MVBS_rpca = xr.open_dataset(os.path.join(MVBS_PCP_path, MVBS_rpca_file))
MVBS_rpca

In [41]:
low_rank = MVBS_rpca['low_rank']
low_rank.shape

(62, 3, 37, 144)

In [42]:
#low_rank = low_rank.sel(frequency=[38000])

In [43]:
n_observations, n_frequencies, n_depth_levels, n_pins = low_rank.shape

In [44]:
low_rank = low_rank.values.reshape([n_observations,-1])
low_rank.shape

(62, 15984)

In [45]:
low_rank_nonneg = low_rank - low_rank.min()

In [46]:
low_rank_nonneg_scaled = (low_rank_nonneg/np.std(low_rank_nonneg.T, axis=1)).T

## Classic NMF

In [47]:
model = NMF(n_components=3, init='random', random_state=0)

In [None]:
H = model.fit_transform(low_rank_nonneg_scaled.T)
W = model.components_

In [None]:
(W.shape, H.shape)

In [None]:
hlines = plt.plot(H)

In [None]:
W_reorg = W.reshape(3,37,144)

In [None]:
fig, ax = plt.subplots(1,3,figsize=(15,6))
for icomp in range(3):
    ax[icomp].imshow(tl.unfold(tl.tensor(W_reorg[icomp,:,:].squeeze()),mode=1).T,
                     aspect='auto')

In [None]:
fig, ax = plt.subplots(1,3,figsize=(15,6))
for icomp in range(3):
    ax[icomp].imshow(tl.unfold(tl.tensor(W_reorg[icomp,:,:,:].squeeze()),mode=2).T,
                     aspect='auto')

In [None]:
model_scaled = NMF(n_components=3, init='random', random_state=0)
H = model_scaled.fit_transform(low_rank_nonneg)
W = model_scaled.components_
W_reorg = W.reshape(3,3,37,144)

In [None]:
hlines = plt.plot(H)
fig, ax = plt.subplots(1,3,figsize=(15,6))
for icomp in range(3):
    ax[icomp].imshow(tl.unfold(tl.tensor(W_reorg[icomp,:,:,:].squeeze()),mode=2).T,
                     aspect='auto')

## Check similarity between days

In [None]:
from scipy.spatial.distance import pdist, squareform

In [None]:
# Normalize the activation coefficients
k = H.T
k_norm = k.T-k.min(axis=1)
k_norm = k_norm/k_norm.max(axis=0)
D = pdist(k_norm, 'euclidean')
D_square = squareform(D)
similarity_m = 1-D_square/D_square.max()

# Check similarity between any two days within the observation period
fig = plt.figure(figsize=(6,4))
ax = fig.add_subplot(111)
plt.imshow(similarity_m,cmap='RdYlBu_r')
plt.xticks(np.arange(0,62,10),fontsize=14)
plt.yticks(np.arange(0,62,10),fontsize=14)
plt.xlabel('Day',fontsize=16)
plt.ylabel('Day',fontsize=16)

cbaxes = fig.add_axes([0.8, 0.125, 0.03, 0.755]) 
cbar = plt.colorbar(cax = cbaxes)  
cbar.ax.tick_params(labelsize=14) 
cbar.ax.set_ylabel('Similarity', rotation=90, fontsize=16)
plt.show()


## Check reconstruction error

In [None]:
recon = (W.T@H.T).T

In [None]:
recon_da = xr.DataArray(np.moveaxis(recon.reshape([62, 3, 37, 144]),[0,1,2],[2,0,1]).reshape([3,37,-1]),
                        coords=[('frequency', MVBS_rpca['frequency']),
                                ('depth', MVBS_rpca['depth']),
                                ('ping_time', MVBS['ping_time'])])

In [None]:
rpca_da = xr.DataArray(np.moveaxis(MVBS_rpca['low_rank'].values,[0,1,2],[2,0,1]).reshape([3,37,-1])-
                       MVBS_rpca['low_rank'].values.min(),
                       coords=[('frequency', MVBS_rpca['frequency']),
                                ('depth', MVBS_rpca['depth']),
                                ('ping_time', MVBS['ping_time'])])

In [None]:
fig, ax = plt.subplots(3, 1, figsize=(20,6), sharex=True)
for ifreq, freq in enumerate([38000,120000,200000]):
    (recon_da-rpca_da).sel(frequency=freq).plot(ax=ax[ifreq], yincrease=False)

# Smooth NMF

Next we run the smooth NMF which imposes smoothness (in time) on the activations by adding a Tikhonov regularization term on the gradient of $H$.

In [None]:
# perform this step once to install the ssnmf package
!pip install --upgrade git+https://github.com/valentina-s/ss-nmf.git

In [48]:
!pip install --upgrade ~/projects/ss-nmf

Processing /Users/valentina/projects/ss-nmf
Building wheels for collected packages: ss-nmf
  Building wheel for ss-nmf (setup.py) ... [?25ldone
[?25h  Created wheel for ss-nmf: filename=ss_nmf-VERSION-cp37-none-any.whl size=8526711 sha256=cf0f340d831c45dacaea0c1929be0c7c02eb0c87a5c9d387f17cfe604ec6859d
  Stored in directory: /Users/valentina/Library/Caches/pip/wheels/d7/7c/ba/b54f0f3eb5c7145fb79e786f902b02752012e20b2c512b1619
Successfully built ss-nmf
Installing collected packages: ss-nmf
  Found existing installation: ss-nmf VERSION
    Uninstalling ss-nmf-VERSION:
      Successfully uninstalled ss-nmf-VERSION
Successfully installed ss-nmf-VERSION


In [49]:
import ssnmf
model = ssnmf.smoothNMF(n_components=3, max_iter=200, smoothness=5000000)

In [50]:
!mkdir checkpoints

mkdir: checkpoints: File exists


In [87]:
%%time
model.fit(low_rank_nonneg.T, init='random', checkpoint_idx=range(201), checkpoint_dir='./checkpoints', random_state=1)

CPU times: user 5.21 s, sys: 2.37 s, total: 7.57 s
Wall time: 2.32 s


In [88]:
ls checkpoints

chkpt-2020-01-04-19:25:54.294395.db  chkpt-2020-01-05-22:33:21.660402.db
chkpt-2020-01-04-19:28:22.910514.db  chkpt-2020-01-05-22:42:08.063138.db
chkpt-2020-01-05-04:56:34.282931.db  chkpt-2020-01-05-23:10:17.235639.db
chkpt-2020-01-05-05:07:15.705589.db  chkpt-2020-01-06-00:16:01.430192.db
chkpt-2020-01-05-22:00:09.694790.db  chkpt-2020-01-06-00:29:21.598527.db
chkpt-2020-01-05-22:01:15.207486.db  chkpt-2020-01-06-01:46:41.186396.db
chkpt-2020-01-05-22:23:41.778270.db  chkpt-2020-01-06-02:01:37.713040.db
chkpt-2020-01-05-22:26:40.479871.db  chkpt-2020-01-06-02:06:54.938689.db
chkpt-2020-01-05-22:26:47.706272.db


In [89]:
from glob import glob
last_chkpt = sorted(glob('checkpoints/chkpt-*'))[-1][:-3]
print(last_chkpt)

checkpoints/chkpt-2020-01-06-02:06:54.938689


In [90]:
import shelve
# note: do not use the .db extension when opening
chkpt_data = shelve.open(last_chkpt)


In [99]:
# test first iteration
from ssnmf import check_random_state

rng = check_random_state(1)
W = rng.uniform(0, 1, size=(low_rank_nonneg.T.shape[0], 3))
H = rng.uniform(0, 1, size=(3, low_rank_nonneg.T.shape[1]))
np.linalg.norm(chkpt_data[str(0)]['H'] - H)

0.0

In [93]:
# test last iteration (note to get the last one we need checkpoint_idx=range(201))
np.linalg.norm(chkpt_data[str(200)]['H'] - model.H)

0.0

In [61]:
# Note the length is not calculated correctly due to a bug
len(chkpt_data)

179

In [62]:
# but the data seems to be there
# display H and W for  eachiteration 

for it in range(200):
    print(chkpt_data[str(it)])

{'H': array([[0.30124358, 0.2060183 , 0.369306  , 0.92220307, 0.06026278,
        0.85097617, 0.67264426, 0.41984316, 0.66284384, 0.13642608,
        0.20417431, 0.17372092, 0.61665235, 0.89044779, 0.94673414,
        0.08315065, 0.44708085, 0.15831286, 0.30028286, 0.13484988,
        0.095977  , 0.85301793, 0.29237805, 0.92112578, 0.28185691,
        0.95166369, 0.863764  , 0.61112653, 0.95046007, 0.96514901,
        0.05870271, 0.72779551, 0.36174043, 0.17883863, 0.35614566,
        0.37198603, 0.06183442, 0.02192633, 0.46612253, 0.67749808,
        0.51436968, 0.03815606, 0.08375638, 0.05174836, 0.12848583,
        0.10473978, 0.19386304, 0.48667639, 0.50641727, 0.36322305,
        0.84603002, 0.15310394, 0.79668244, 0.11733593, 0.53194877,
        0.5559848 , 0.07565229, 0.26592183, 0.02787784, 0.66057164,
        0.59324505, 0.80862184],
       [0.13583141, 0.92691674, 0.94547887, 0.94447631, 0.06853416,
        0.81978468, 0.9351946 , 0.20111382, 0.41532185, 0.69200392,
        0

       [12.28528962,  8.6852744 , 31.04676231]])}
{'H': array([[0.45658018, 0.44129805, 0.43395061, 0.44779322, 0.45343524,
        0.48394423, 0.48700698, 0.47126051, 0.48201806, 0.50714749,
        0.51467091, 0.51482201, 0.51637839, 0.51169676, 0.52108173,
        0.51030566, 0.48833429, 0.49633198, 0.50714517, 0.52113311,
        0.53367359, 0.56435814, 0.59495586, 0.62033483, 0.64400371,
        0.65867028, 0.68313216, 0.67084528, 0.61689511, 0.56237579,
        0.52673864, 0.4598459 , 0.38864401, 0.36343848, 0.3350678 ,
        0.31880955, 0.30961423, 0.32101713, 0.3316225 , 0.34833303,
        0.29992515, 0.24162483, 0.19811267, 0.19235478, 0.19426833,
        0.18956804, 0.19724702, 0.23784934, 0.28574051, 0.35107646,
        0.40318584, 0.47352702, 0.50672337, 0.50549215, 0.47402195,
        0.46413942, 0.48171758, 0.47910139, 0.4549474 , 0.4611731 ,
        0.46506477, 0.42984247],
       [0.69795869, 0.69063067, 0.67083079, 0.66298046, 0.64953413,
        0.65390303, 0.64167

       [12.87838183,  4.16589289, 35.21850932]])}
{'H': array([[0.47412909, 0.45058217, 0.43364493, 0.43724858, 0.43134118,
        0.4591775 , 0.46593752, 0.46255391, 0.4931923 , 0.54024268,
        0.56061069, 0.56587453, 0.56833827, 0.5628728 , 0.57372412,
        0.56833774, 0.55267961, 0.56889154, 0.5835081 , 0.59214166,
        0.58981806, 0.59698899, 0.59739342, 0.59192641, 0.58361905,
        0.57273693, 0.58319458, 0.56551079, 0.51671598, 0.47862356,
        0.46872386, 0.4242011 , 0.3687227 , 0.35976267, 0.34244766,
        0.33777396, 0.32820777, 0.3340013 , 0.33664207, 0.34487047,
        0.28196967, 0.21996741, 0.18762346, 0.19760442, 0.20638672,
        0.1866438 , 0.1675433 , 0.18216625, 0.20598628, 0.25785963,
        0.30653889, 0.39160694, 0.44248767, 0.46089192, 0.4511516 ,
        0.46863164, 0.511219  , 0.52048929, 0.49561312, 0.49247267,
        0.48237941, 0.43527924],
       [0.67874778, 0.68471805, 0.67778537, 0.6823385 , 0.68197489,
        0.69417179, 0.69055

In [59]:
# cannot display the items due to a bug
for k, v in chkpt_data.items():
    print(k)

SystemError: Negative size passed to PyBytes_FromStringAndSize

In [60]:
# cannot convert to a list
list(chkpt_data)

SystemError: Negative size passed to PyBytes_FromStringAndSize

In [None]:
%%time
# test random state
model.fit(low_rank_nonneg.T, init='random', random_state=1)

In [None]:
# set W and H
H = model.H.T
W = model.W.T

In [None]:
(W.shape, H.shape)

In [None]:
hlines = plt.plot(H)

In [None]:
low_rank_nonneg.shape

In [None]:
W_reorg = W.reshape(3,3,37,144)

In [None]:
fig, ax = plt.subplots(1,3,figsize=(15,6))
for icomp in range(3):
    ax[icomp].imshow(tl.unfold(tl.tensor(W_reorg[icomp,:,:,:].squeeze()),mode=2).T,
                     aspect='auto')

## Check similarity between days

In [None]:
from scipy.spatial.distance import pdist, squareform

In [None]:
# Normalize the activation coefficients
k = H.T
k_norm = k.T-k.min(axis=1)
k_norm = k_norm/k_norm.max(axis=0)
D = pdist(k_norm, 'euclidean')
D_square = squareform(D)
similarity_m = 1-D_square/D_square.max()

# Check similarity between any two days within the observation period
fig = plt.figure(figsize=(6,4))
ax = fig.add_subplot(111)
plt.imshow(similarity_m,cmap='RdYlBu_r')
plt.xticks(np.arange(0,62,10),fontsize=14)
plt.yticks(np.arange(0,62,10),fontsize=14)
plt.xlabel('Day',fontsize=16)
plt.ylabel('Day',fontsize=16)

cbaxes = fig.add_axes([0.8, 0.125, 0.03, 0.755]) 
cbar = plt.colorbar(cax = cbaxes)  
cbar.ax.tick_params(labelsize=14) 
cbar.ax.set_ylabel('Similarity', rotation=90, fontsize=16)
plt.show()

## Check reconstruction error

In [None]:
recon = (W.T@H.T).T


In [None]:
recon_da = xr.DataArray(np.moveaxis(recon.reshape([62, 3, 37, 144]),[0,1,2],[2,0,1]).reshape([3,37,-1]),
                        coords=[('frequency', MVBS_rpca['frequency']),
                                ('depth', MVBS_rpca['depth']),
                                ('ping_time', MVBS['ping_time'])])

In [None]:
rpca_da = xr.DataArray(np.moveaxis(MVBS_rpca['low_rank'].values,[0,1,2],[2,0,1]).reshape([3,37,-1])-
                       MVBS_rpca['low_rank'].values.min(),
                       coords=[('frequency', MVBS_rpca['frequency']),
                                ('depth', MVBS_rpca['depth']),
                                ('ping_time', MVBS['ping_time'])])

In [None]:
fig, ax = plt.subplots(3, 1, figsize=(20,6), sharex=True)
for ifreq, freq in enumerate([38000,120000,200000]):
    (recon_da-rpca_da).sel(frequency=freq).plot(ax=ax[ifreq], yincrease=False)