<font size=6>Ocean Data Analysis: Analyzing Sea Surface Temperature with Variational Inference</font>

Related research (Nature Communications):
https://www.nature.com/articles/s41467-018-08066-0

In [None]:
import warnings
warnings.filterwarnings('ignore')

import xarray as xr
!pip install netcdf4
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 

import sklearn 
import sklearn.mixture as mix 
import scipy.stats as stats 

import seaborn as sns
sns.set()

In [None]:
# Buoy in Gulf of Maine (NDBC station 44007): standard meteorological record,
# read remotely from the NDBC THREDDS server and trimmed to the study period.
buoy_url = 'https://dods.ndbc.noaa.gov/thredds/dodsC/data/stdmet/44007/44007.ncml'
study_period = slice('2021-3-30', '2023-5-12')

ds = xr.open_dataset(buoy_url).sel(time=study_period)
print(ds)

# Flatten the dataset into a table indexed by its dimension coordinates, then
# turn those index levels into ordinary columns.
noaa_df = ds.to_dataframe()
noaa_df = noaa_df.reset_index()
print(noaa_df)

In [None]:
# Extract the sea-surface temperature (SST) record and smooth it with a
# trailing 14-day mean.
# The window is time-based rather than a fixed row count: dropna() leaves gaps
# in the record, so a window of 24*14 rows spans 14 days only for a gap-free
# hourly series.
smoothing_window = pd.Timedelta(days=14)

cov = (
    noaa_df[['time', 'sea_surface_temperature']]
    .dropna()
    .sort_values('time')        # offset-based rolling needs a monotonic time index
    .reset_index(drop=True)
)
cov['sea_surface_temperature'] = (
    cov.set_index('time')['sea_surface_temperature']
    .rolling(smoothing_window)
    .mean()
    .to_numpy()                 # assign positionally, not by index alignment
)
# Discard the warm-up period, whose windows do not yet cover a full 14 days.
# Using min() (rather than iloc[0]) keeps an empty frame empty instead of raising.
cov = cov[cov['time'] >= cov['time'].min() + smoothing_window]
print(cov)

In [None]:
# Fit a Dirichlet-process Gaussian mixture (variational inference) to the
# smoothed SST and label every sample with its most likely component ("state").
num_components = 2
random_seed = 0  # fixed so the fit, and therefore the state labels, are reproducible
dpgmm_model = mix.BayesianGaussianMixture(
    n_components=num_components,
    weight_concentration_prior_type='dirichlet_process',
    n_init=1,
    max_iter=100,
    random_state=random_seed)
p = dpgmm_model.fit_predict(cov['sea_surface_temperature'].to_numpy().reshape(-1, 1))

# scikit-learn reports non-convergence only as a warning, which active warning
# filters can hide, so check the fitted flag explicitly.
if not dpgmm_model.converged_:
    print(f'Warning: mixture model did not converge within {dpgmm_model.max_iter} iterations')

# Count samples per state; components that received no samples count as 0.
state_counts = np.bincount(p, minlength=num_components)
print(state_counts)

In [None]:
# Plot the smoothed SST over time, shading every sample assigned to the same
# mixture component (state) as the most recent sample.
fig2, ax2 = plt.subplots()

times = cov['time'].to_numpy()
sst_values = cov['sea_surface_temperature'].to_numpy()
in_current_state = p == p[-1]

# A single vlines call replaces one axvline per sample (thousands of artists).
# The blended x-axis transform puts y in axes coordinates, so 0..1 spans the
# full plot height, as axvline does.
ax2.vlines(times[in_current_state], 0, 1, transform=ax2.get_xaxis_transform(),
           color='black', alpha=0.002)
ax2.plot(times, sst_values, alpha=0.7)
ax2.set_xlabel('time')
ax2.set_ylabel('sst')
plt.show()

In [None]:
# Distribution of smoothed SST for samples in the current state (the mixture
# component of the most recent sample), restricted to values >= `thresh`.
fig, ax = plt.subplots()

# NOTE(review): mixture component indices are arbitrary, so means_[1] is not
# guaranteed to be the warmer component, nor the component selected by p[-1].
# Confirm which component's mean is intended as the threshold.
thresh = dpgmm_model.means_[1]

sst_values = cov['sea_surface_temperature'].to_numpy()
in_current_state = p == p[-1]
data = list(sst_values[in_current_state & (sst_values >= thresh)])

# histplot replaces the deprecated distplot; density normalisation plus a KDE
# overlay reproduces distplot's default presentation.
sns.histplot(data, stat='density', kde=True, ax=ax)
ax.set_xlabel('sst')
ax.set_ylabel('density')
plt.show()