This notebook will guide you through the process of applying a Gaussian Mixture Model to a set of oceanographic profiles. The code was designed for application to chlorophyll profiles, but places where you might want to make modifications for use with other profiles will be noted. This notebook will also allow you to make sample figures akin to those shown in Echols, Rocap, and Riser (2021). 

This notebook reflects work done by Rosalind Echols as a PhD candidate at the University of Washington School of Oceanography. 

In [1]:
#import basic packages
import numpy as np
import xarray as xr
import random
import gsw

#import packages for PCA and GMM:
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
import pandas as pd

#import stats packages for analysis
from scipy.stats import norm
import scipy.stats

#import plotting packages; yes, I know Basemap is on the way out
import matplotlib.pyplot as plt
import cmocean
from mpl_toolkits.basemap import Basemap
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib as mpl

In [2]:
import warnings
warnings.filterwarnings("ignore")

#set path for saving figures/data if you want:
sfpath='/Users/rosalindechols/Documents/Generals/Self_Shading_Research/Analysis/PCA_GMM/FINAL/8pcs_21clusters/'
save_plots=False
    
depth=250
max_depth=251

#Set to True to build the random subsets used later for testing the BIC criterion.
make_tests=True

#import QCed, filtered, interpolated file: 
f='all_chla_argo_250_5m.nc'

data=xr.open_dataset(f)

Some reported values of chlorophyll from Argo floats are negative; while this is obviously incorrect in the absolute sense, these values may be correct relative to the remainder of the profile. Zero values are also a problem for later steps in which we take the logarithm of the data. To address both of these issues, we are going to create a usable version of our data by offsetting all the profiles by the same amount so that all values are positive.
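
A minimal sketch of the offset on made-up numbers. The toy array and the helper name `offset_to_positive` are hypothetical and not part of the notebook. The sketch assumes that missing values, if any, are stored as NaN and should be ignored when locating the minimum, which is why it uses `np.nanmin`.

```python
import numpy as np

def offset_to_positive(values, floor=0.01):
    """Shift every value by the same constant so the smallest finite value equals `floor`."""
    values = np.asarray(values, dtype=float)
    return values - np.nanmin(values) + floor

# two made-up profiles (rows) with one negative reading and one missing value
toy = np.array([[0.30, 0.10, -0.02],
                [0.50, np.nan, 0.00]])
shifted = offset_to_positive(toy)

# a single global offset leaves the differences between values unchanged
print(shifted)
```

Because the same constant is added everywhere, the shape of each profile is preserved; only the absolute level changes.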

In [3]:
#minimum value is negative so you need to subtract. Could also add the absolute value. 

#check the current minimum and maximum values of the data
print("Original maximum and minimum: ", data['CHLA'].values.max(),data['CHLA'].values.min())

test_data=data['CHLA'].values-data['CHLA'].values.min()+0.01
print("New maximum and minimum: ",test_data.max(),test_data.min())

Original maximum and minimum:  43.20624542847094 -0.019999999552965164
New maximum and minimum:  43.236245428023906 0.01


Depending on the type of variable you are working with, the values associated with the profile at each depth may not be normally distributed. This is particularly true in the case of a variable like chlorophyll, where the values at all depths span multiple orders of magnitude and are skewed towards low values. We can visualize this by plotting histograms at several different depths and comparing them with the results after log-transforming. This is why we need the offset test data set, in which all values are positive.
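
The effect of the log transform can also be checked numerically. The sketch below uses synthetic log-normal values as a stand-in for chlorophyll at a single depth (an assumption made for illustration, not the Argo data) and a hypothetical helper `skewness`.

```python
import numpy as np

def skewness(x):
    """Sample skewness: third central moment divided by the cubed standard deviation."""
    x = np.asarray(x, dtype=float)
    centered = x - x.mean()
    return np.mean(centered**3) / np.mean(centered**2)**1.5

rng = np.random.default_rng(0)
# synthetic stand-in for chlorophyll at one depth: log-normal, so strongly right-skewed
synthetic_chl = rng.lognormal(mean=-1.0, sigma=1.0, size=5000)

skew_raw = skewness(synthetic_chl)
skew_log = skewness(np.log10(synthetic_chl))
print("skewness before log transform:", skew_raw)
print("skewness after log transform: ", skew_log)
```

A skewness near zero after the transform suggests the values are closer to symmetric, which is what the histograms are meant to show visually.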

In [4]:
#plot histograms at several depths: raw values (top row) are skewed, log-transformed values (bottom row) are closer to normal
fig=plt.figure(figsize=(20,8))

for n in range(1,11):
    ax=fig.add_subplot(2,5,n)
    if n<6:
        ax.hist(test_data[:,(n-1)*10],bins=np.arange(0,2,0.1),edgecolor='k',facecolor='lightgray')
        ax.axvline(np.nanmedian(test_data[:,(n-1)*10]),c='r',linestyle='dashed')
        if n==1:
            ax.set_ylabel('Number of profiles',fontsize=14)
    else:
        ax.hist(np.log10(test_data[:,(n-6)*10]),bins=np.arange(-2,1,0.2),edgecolor='k',facecolor='lightgray')
        ax.axvline(np.nanmedian(np.log10(test_data[:,(n-6)*10])),c='r',linestyle='dashed')
        ax.set_title('%dm' %((n-6)*50),fontsize=14)
        if n==6:
            ax.set_ylabel('Number of profiles',fontsize=14)
        elif n==8:
            ax.set_xlabel('log(Chlorophyll (mg m$^{-3}$))',fontsize=18)

plt.tight_layout()

plt.subplots_adjust(top=0.9)
plt.suptitle('Chlorophyll (mg m$^{-3}$)',fontsize=18)

if save_plots==True:
    plt.savefig(sfpath+'concentration_hist.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

Before developing the Gaussian Mixture Model (GMM), we want to perform principal component analysis. This is useful for two reasons:
1) Dimensionality reduction
2) Allows us to sort profiles into clusters based on profile shape rather than chlorophyll values

For chlorophyll profiles, we will be applying principal component analysis to the log-transformed data. For other variables, log-transformation may not be necessary to produce reasonable clusters using GMM.
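
To make the dimensionality reduction concrete, here is a small sketch of PCA computed directly from a singular value decomposition. The profiles are synthetic (two smooth depth patterns plus noise, chosen only for illustration), and the variable names are hypothetical. The notebook itself uses `sklearn.decomposition.PCA`; this is only meant to show what that step does.

```python
import numpy as np

rng = np.random.default_rng(1)
n_profiles, n_depths = 500, 51
depth = np.linspace(0, 250, n_depths)

# synthetic "log-chlorophyll" profiles built from two smooth depth patterns plus noise
pattern1 = np.exp(-depth / 80.0)
pattern2 = np.exp(-((depth - 100.0) / 30.0) ** 2)
scores = rng.normal(size=(n_profiles, 2)) * np.array([2.0, 1.0])
profiles = scores @ np.vstack([pattern1, pattern2]) + 0.05 * rng.normal(size=(n_profiles, n_depths))

# PCA by SVD of the mean-centered data
centered = profiles - profiles.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
explained_ratio = S**2 / np.sum(S**2)
cumulative = np.cumsum(explained_ratio)

# smallest number of components whose cumulative variance reaches 99%
n_keep = int(np.searchsorted(cumulative, 0.99) + 1)
reduced = centered @ Vt[:n_keep].T  # profiles expressed in PC space
print(n_keep, reduced.shape)
```

Each 51-point profile is replaced by a handful of PC scores, and those scores describe the shape of the profile relative to the mean rather than its raw values.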

In [5]:
#Perform principal component analysis, with a target of 99% of the variance
pca=PCA(0.99)
pca.fit(np.log(test_data))

print("Explained variance ratio, by component: ", pca.explained_variance_ratio_)
print("Cumulative explained variance: ", np.cumsum(pca.explained_variance_ratio_))

fig=plt.figure(figsize=(10,6))
plt.plot(np.arange(1,pca.n_components_+1),np.cumsum(pca.explained_variance_ratio_),c='lightgray',zorder=1,lw=3)
plt.scatter(np.arange(1,pca.n_components_+1),np.cumsum(pca.explained_variance_ratio_),s=50,c='r',zorder=3)
plt.tick_params(labelsize=16)
plt.xlabel('Number of PCs', fontsize=18)
plt.ylabel('Explained Variance',fontsize=18)
plt.axhline(0.95,linestyle='--',c='teal',label='95%')
plt.axhline(0.975,linestyle=':',c='teal',label='97.5%')
plt.legend(loc='best',fontsize='x-large')
plt.xlim(1,17)
if save_plots==True:
    plt.savefig(sfpath+'PCA_expl_variance.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

Explained variance ratio, by component:  [0.45247978 0.29244187 0.13169803 0.04863126 0.02260199 0.01314921
 0.00807264 0.00549772 0.00385226 0.00287422 0.00220218 0.00173911
 0.00141432 0.00116742 0.00099608 0.00084743 0.0007502 ]
Cumulative explained variance:  [0.45247978 0.74492165 0.87661967 0.92525093 0.94785292 0.96100214
 0.96907477 0.97457249 0.97842475 0.98129897 0.98350115 0.98524026
 0.98665458 0.987822   0.98881808 0.98966551 0.99041571]


We also want to visualize the principal components. For the dataset provided here, the first 8 principal components explain ~97.5% of the variance and include all the PCs that individually explain >0.5% of the variance. We will plot these and use them for the subsequent clustering; however, it is possible to select a different number of PCs for the later analysis and compare to the results produced by using only 8 PCs.
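
The choice of 8 PCs can be reproduced from the printed output of the PCA cell. The sketch below copies those explained variance ratios by hand and applies the 0.5% threshold; the variable names are hypothetical.

```python
import numpy as np

# explained variance ratios copied from the printed output of the PCA cell above
ratios = np.array([0.45247978, 0.29244187, 0.13169803, 0.04863126, 0.02260199, 0.01314921,
                   0.00807264, 0.00549772, 0.00385226, 0.00287422, 0.00220218, 0.00173911,
                   0.00141432, 0.00116742, 0.00099608, 0.00084743, 0.0007502])

# keep every PC that individually explains more than 0.5% of the variance
n_keep = int(np.sum(ratios > 0.005))
cumulative_at_keep = np.cumsum(ratios)[n_keep - 1]
print("PCs retained:", n_keep)
print("Cumulative explained variance: %.4f" % cumulative_at_keep)
```

With a live `pca` object, `pca.explained_variance_ratio_` could be used in place of the hand-copied array.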

In [6]:
colors = cmocean.cm.matter(np.linspace(0,1,8))
fig=plt.figure(figsize=(16,8))
for i in range(0,8):
    fig.add_subplot(2,4,i+1)
    plt.plot(pca.components_[i],np.arange(0,251,5),linewidth=3,c=colors[i])
    if i==5:
        plt.text(0,310,'Chlorophyll (mg m$^{-3}$)',fontsize=18)
    if i==4:
        plt.text(-0.35,30,'Depth (dbar)',fontsize=18,rotation=90)
    plt.ylim(250,0)
    plt.tick_params(labelsize=14)

if save_plots==True:
    plt.savefig(sfpath+'PCA_comps.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

Gaussian Mixture Models require pre-selection of a number of clusters. In order to select a number of clusters, we need to use a method that evaluates the tradeoffs between number of clusters and quality of fit. Both the Bayesian Information Criterion (BIC) and the Akaike Information Criterion (AIC) accomplish this, although BIC has a harsher penalty as the number of clusters increases. In order to investigate the optimal number of clusters, we will first subset the data, selecting one profile from each 1x1 degree box in which data have been gathered. Using 20 different randomly selected test sets following this approach, we can estimate the trend in BIC as the number of clusters increases as well as the uncertainty.
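
For reference, the two criteria in their standard form are BIC = p ln(n) - 2 ln(L) and AIC = 2p - 2 ln(L), where p is the number of free parameters, n the number of samples, and L the maximized likelihood. The sketch below writes these out for a full-covariance GMM. The helper names and the example sizes (5000 profiles, 8 PCs) are hypothetical and chosen only to show how the penalty terms grow.

```python
import numpy as np

def gmm_n_parameters(n_clusters, n_features):
    """Free parameters of a full-covariance GMM: means, covariances, and mixing weights."""
    means = n_clusters * n_features
    covariances = n_clusters * n_features * (n_features + 1) // 2
    weights = n_clusters - 1  # weights sum to 1, so one is determined by the rest
    return means + covariances + weights

def bic(total_log_likelihood, n_params, n_samples):
    return n_params * np.log(n_samples) - 2.0 * total_log_likelihood

def aic(total_log_likelihood, n_params):
    return 2.0 * n_params - 2.0 * total_log_likelihood

# penalty terms alone (log-likelihood set to 0) for a hypothetical subset of 5000 profiles in 8 PCs
n_samples, n_features = 5000, 8
for k in (5, 10, 20):
    p = gmm_n_parameters(k, n_features)
    print(k, p, bic(0.0, p, n_samples), aic(0.0, p))
```

The BIC penalty per parameter is ln(n), which exceeds the AIC penalty of 2 whenever n is larger than about 7, so for data sets of this size BIC always penalizes extra clusters more heavily.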

In [7]:
#make subsets for future use in testing out the BIC criterion for varying numbers of clusters
if make_tests==True:
    #Initialize array of nans; will remove unused data later
    tlist=np.nan*np.ones((20,len(data['LATITUDE'])//5))
    
    #sort data by latitude
    lat_sort=sorted(data['LATITUDE'].values)
    
    #retain original indices of sorted data
    lat_sort_ind=data['LATITUDE'].values.argsort()
    
    start=0
    count=0
    for n in range(-76, 78):
        if n%10==0:
            print("Current Latitude: ",n)
        #find 1-degree latitude subset
        #the default avoids a StopIteration error if no latitude is >= n+1
        end=next((i for i, j in enumerate(lat_sort) if j>=n+1), len(lat_sort))
        for nn in np.arange(-180,180.1):
            lon_set=[i for i,j in enumerate(data['LONGITUDE'].values[lat_sort_ind[start:end]]) if nn <= j < nn+1]
            if len(lon_set)==0:
                pass
            else:
                for nnn in range(0,20):
                    #randomly select 20 times to create test subsets
                    #for 1x1 boxes with a small number of profiles, this won't do much
                    tlist[nnn,count]=random.choice(lat_sort_ind[start:end][lon_set])
                count+=1
        start=end

    #only retain the values in the array that are actually used
    tlist=np.array(tlist[:,0:count],dtype='int')
    np.savetxt(sfpath+"1x1_test_lists.csv", tlist, delimiter=",")
    print('DONE')

Current Latitude:  -70
Current Latitude:  -60
Current Latitude:  -50
Current Latitude:  -40
Current Latitude:  -30
Current Latitude:  -20
Current Latitude:  -10
Current Latitude:  0
Current Latitude:  10
Current Latitude:  20
Current Latitude:  30
Current Latitude:  40
Current Latitude:  50
Current Latitude:  60
Current Latitude:  70
DONE


Let's plot the locations of one of these subsets just to make sure the code is doing what we think it is doing. We can do this in two ways: first, check to make sure that the profiles are distributed across all latitudes. At the same time, we can see how the distribution of the test profiles compares to the overall distribution of profiles. Latitudes that are better represented in the test set than in the complete data set have data across more longitudes (e.g. the Southern Ocean) than other areas (e.g. the Mediterranean Sea).

In [8]:
fig=plt.figure(figsize=(16,8))
ax1=fig.add_subplot(1,2,1)
ax1.hist(data['LATITUDE'].values,bins=np.arange(-75,76,10),edgecolor='m',linewidth=2,facecolor='gray',label='All Profiles',alpha=0.5,density=1)
ax1.set_ylabel('Fraction of Profiles',fontsize=18)
ax1.set_ylim(0,0.025)
plt.tick_params(labelsize=16)
ax1.set_xlabel('Latitude',fontsize=18)
plt.legend(loc='best',fontsize='x-large')
ax2=fig.add_subplot(1,2,2)
ax2.hist(data['LATITUDE'].values[tlist[0]],bins=np.arange(-75,76,10),edgecolor='m',linewidth=2,facecolor='gray',label='Test Set',alpha=0.5,density=1)
plt.tick_params(labelsize=16)
plt.legend(loc='best',fontsize='x-large')
ax2.set_xlabel('Latitude',fontsize=18)
ax2.set_ylim(0,0.025)
if save_plots==True:
    plt.savefig(sfpath+'latitude_hist.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

A second method is to simply plot the points on a map. If it's working, we should see data points all over the globe (excepting areas where there is no data).

In [9]:
fig=plt.figure(figsize=(12,10))
m = Basemap(projection='robin',lon_0=-180,resolution='c')
m.drawcoastlines()
m.fillcontinents(color='gray')
cp=m.scatter(data['LONGITUDE'].values[tlist[0]],data['LATITUDE'].values[tlist[0]],latlon=True,s=5)
m.drawparallels(np.arange(-90., 81., 30.), labels = [1,0,0,0], fontsize = 16)
m.drawmeridians(np.arange(-180., 181., 90.), labels = [0,0,1,0], fontsize = 16)
if save_plots==True:
    plt.savefig(sfpath+'map_test_set.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()
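
Alongside the two visual checks, a direct numerical check is possible: a valid test set should contain exactly one profile from every occupied 1x1 degree box. The sketch below shows the idea on synthetic positions. The helper `box_ids` and the random positions are hypothetical, and the box-drawing step is a simplified stand-in for the subsetting loop, not the notebook's own code.

```python
import numpy as np

def box_ids(lat, lon):
    """Integer label for the 1x1 degree box containing each (lat, lon) pair."""
    lat_bin = np.floor(lat).astype(int) + 90    # 0..179 for latitudes in [-90, 90)
    lon_bin = np.floor(lon).astype(int) + 180   # 0..360 for longitudes in [-180, 180]
    return lat_bin * 361 + lon_bin

rng = np.random.default_rng(2)
# synthetic positions standing in for data['LATITUDE'] and data['LONGITUDE']
lat = rng.uniform(-10, 10, size=2000)
lon = rng.uniform(100, 120, size=2000)

ids = box_ids(lat, lon)
# draw one random profile index from every occupied box
subset = np.array([rng.choice(np.flatnonzero(ids == b)) for b in np.unique(ids)])

# a valid test set has exactly one profile per occupied box
assert len(np.unique(ids[subset])) == len(subset) == len(np.unique(ids))
print(len(subset), "profiles drawn from", len(np.unique(ids)), "occupied boxes")
```

The same assertion could be applied to a row of `tlist` by computing `box_ids` from the real latitude and longitude arrays.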

We will now transform our data into PC space, and test out a range of clustering options for each of our subsets. This step takes a long time because the GMM algorithm is computationally expensive. The more principal components you decide to retain, the longer this will take. You can also test it out with different numbers of PCs and compare the outcome. Just be prepared to run the cell and walk away for a while.
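
Before committing to the long run, it can help to try the same fit-and-score loop on something cheap. The sketch below uses three well-separated synthetic groups in two dimensions as a stand-in for profiles in PC space; the data and variable names are made up for illustration, while the `GaussianMixture` settings mirror those used in the notebook.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(3)
# synthetic stand-in for profiles in PC space: three well-separated groups in 2 dimensions
centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
points = np.vstack([c + rng.normal(size=(200, 2)) for c in centers])

n_components = np.arange(1, 6)
bic_scores = []
for n in n_components:
    # fit one model per candidate number of clusters and record its BIC
    model = GaussianMixture(n_components=int(n), covariance_type='full', random_state=0).fit(points)
    bic_scores.append(model.bic(points))

best = n_components[int(np.argmin(bic_scores))]
print("BIC by number of clusters:", np.round(bic_scores, 1))
print("Lowest BIC at", best, "clusters")
```

On real profiles the minimum is usually much less sharp than in a toy case like this, which is why the notebook repeats the fit over 20 subsets.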

In [10]:
#Define empty arrays to store AIC and BIC values; we will only end up using the BIC values, but it is
#instructive to see both

#do you want to save the criterion data?
save_BIC=True
#select the number of principal components to use:
pcs=17
#set the maximum number of clusters you want to test:
max_clusters=50

#transform profiles into PC space:
pca=PCA(n_components=pcs)
pca.fit(np.log(test_data))
training=pca.transform(np.log(test_data))

BIC=np.zeros((len(tlist),max_clusters))
AIC=np.zeros((len(tlist),max_clusters))
for i in range(0,len(tlist)):
    if i%5==0:
        print("Currently testing subset %d." %(i+1))
    #test every number of clusters from 1 to max_clusters
    n_components=np.arange(1, max_clusters+1)
    subset=training[tlist[i]]
    models = [GaussianMixture(n, covariance_type='full', random_state=0).fit(subset) for n in n_components]
    BIC[i]=[m.bic(subset) for m in models]
    AIC[i]=[m.aic(subset) for m in models]
    
if save_BIC==True:
    np.savetxt(sfpath+"pca%d_BIC_results.txt" %pcs, BIC, delimiter=",")
    np.savetxt(sfpath+"pca%d_AIC_results.txt" %pcs, AIC, delimiter=",")

Currently testing subset 1.
Currently testing subset 6.
Currently testing subset 11.
Currently testing subset 16.


We can now plot the BIC (and AIC if we want) to see the range of cluster numbers that are ideal for this problem. It is unlikely that a single "best" number of clusters will emerge from using BIC, which means that we will need to "use our judgement" at some point. 
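
One possible aid to that judgement, offered as a heuristic rather than the method used in the notebook, is to list every number of clusters whose mean BIC lies within one standard deviation of the minimum. The BIC table below is synthetic (a shallow parabola with a per-subset offset), and `near_minimum_range` is a hypothetical helper.

```python
import numpy as np

def near_minimum_range(bic_table, first_cluster=1):
    """Cluster counts whose mean BIC lies within one standard deviation of the minimum mean BIC."""
    mean_bic = bic_table.mean(axis=0)
    std_bic = bic_table.std(axis=0)
    i_min = int(np.argmin(mean_bic))
    close = np.flatnonzero(mean_bic <= mean_bic[i_min] + std_bic[i_min])
    return i_min + first_cluster, close + first_cluster

# synthetic BIC table: 20 subsets (rows) by 10 cluster counts (columns), shallow minimum at 6
rng = np.random.default_rng(4)
k = np.arange(1, 11)
subset_offset = rng.normal(scale=60.0, size=(20, 1))   # each subset sits at a slightly different level
bic_table = 20.0 * (k - 6) ** 2 + 75000.0 + subset_offset + rng.normal(scale=2.0, size=(20, 10))

best_k, candidates = near_minimum_range(bic_table)
print("minimum at", best_k, "clusters; candidates:", candidates)
```

With the real results, the same function could be called on the `BIC` array from the previous cell.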

In [11]:
plot_AIC=False

#plot BIC
min_BIC=next([i+1,j] for i,j in enumerate(np.mean(BIC,axis=0)) if j==min(np.mean(BIC,axis=0)))
print("Minimum BIC value = %1.0f at %d clusters" %(min_BIC[1],min_BIC[0]))
fig=plt.figure(figsize=(10,8))
plt.errorbar(n_components,np.mean(BIC,axis=0),yerr=np.std(BIC,axis=0),c='b',label='BIC',lw=3)

#plot AIC
if plot_AIC==True:
    plt.errorbar(n_components,np.mean(AIC,axis=0),yerr=np.std(AIC,axis=0),c='r',label='AIC',lw=3)

plt.legend(loc='best',fontsize='x-large')
plt.xlabel('Number of Clusters',fontsize=18)
plt.ylabel('BIC score',fontsize=18)
plt.tick_params(labelsize=16)
plt.ylim(min(np.mean(BIC,axis=0))-0.1*min(np.mean(BIC,axis=0)),np.mean(BIC,axis=0)[5])
plt.xlim(5,35)

if save_plots==True:
    plt.savefig(sfpath+'BIC_plot.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

Minimum BIC value = 75537 at 20 clusters


This shows that (for 17 PCs) there is a wide approximate minimum between 17 and 23 clusters. For the remaining sections of this notebook, we will use 21 clusters, for reasons that will be discussed in relation to some of the figures. However, all of the sections can be run with a different number of clusters for comparison. 

In [20]:
#Set number of clusters to use taken from decision based on BIC plot
gmm_comps=21

#redoing the PCA here in case you were experimenting with different #s of PCs in previous cells
pca=PCA(0.99)
pca.fit(np.log(test_data))
training=pca.transform(np.log(test_data))

gmm = GaussianMixture(n_components=gmm_comps)
#this is where you train
gmm.fit(training)
probs = gmm.predict_proba(training)
#the labels are the groups; there is no de facto order to the clusters, so you will get different
#cluster numbers each time you run this, even if a particular profile is always assigned to the same
#cluster; we will sort the groups later
labels = gmm.predict(training)

One useful first plot to make is a histogram of how many profiles are in each group. Although the BIC in theory allowed us to select a quasi-optimal number of clusters, it is still helpful to know what this looks like in practice. Are some of the clusters really small? Is there one cluster that has a disproportionately large number of profiles assigned to it? This is a very superficial look, but it starts to help us understand the data.

In [21]:
fig=plt.figure(figsize=(10,8))
plt.hist(labels+1, bins=np.arange(1,gmm_comps+2),edgecolor='k',linewidth=2,facecolor='lightgray')
plt.xlabel('Cluster Number', fontsize=18)
plt.ylabel('Number of Profiles',fontsize=18)
plt.tick_params(labelsize=16)
if save_plots==True:
    plt.savefig(sfpath+'GMM_cluster_hist.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

We now want to start looking at the characteristics of each cluster. Because we did PCA, looking at the PC statistics may not be super informative. Instead, we can start looking at average profile shape, probability of assignment to each cluster, geographical distribution, and so on, all derived from the raw profiles assigned to each group.  

We'll start by looking at general profile characteristics for each group. We can do this in a variety of ways: an average profile; a weighted average profile (weighted by the probability that an individual profile is assigned to a particular group); and a median profile. Because GMM assigns a probability that each item is in a particular group, the weighted average gives more weight to the profiles that have the highest probability of being in a particular group. Profiles that may be bordering between two groups and thus have a lower probability of being in either group will have less weight in constructing that profile. 
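
A small worked example of the difference between the plain and weighted statistics, using three made-up profiles at two depths. The numbers are hypothetical. In this sketch the weighted standard deviation is measured about the weighted mean, which is one common convention and an assumption of the example.

```python
import numpy as np

# three made-up profiles (rows) at two depths, with their probabilities of belonging to one cluster
profiles = np.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [1.0, 2.0]])
membership = np.array([0.5, 0.25, 0.25])

plain_mean = profiles.mean(axis=0)
weighted_mean = np.average(profiles, axis=0, weights=membership)

# weighted standard deviation, measured about the weighted mean
weighted_var = np.average((profiles - weighted_mean) ** 2, axis=0, weights=membership)
weighted_std = np.sqrt(weighted_var)

print("plain mean:   ", plain_mean)
print("weighted mean:", weighted_mean)
print("weighted std: ", weighted_std)
```

The second profile differs from the other two but carries only a quarter of the weight, so the weighted mean sits closer to the other profiles than the plain mean does.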

In [22]:
print('Find averages')
#regular average
group_ave={}
#weighted average
group_wave={}
#median
group_median={}
#standard deviation
group_std={}
#weighted standard deviation
group_wstd={}
for var in ['CHLA']:
    group_ave[var]=np.nan*np.ones((gmm_comps,51))
    group_wave[var]=np.nan*np.ones((gmm_comps,51))
    group_median[var]=np.nan*np.ones((gmm_comps,51))
    group_std[var]=np.nan*np.ones((gmm_comps,51))
    group_wstd[var]=np.nan*np.ones((gmm_comps,51))

for var in ['CHLA']:
    for m in range(gmm_comps):
        d=[i for i,j in enumerate(labels) if j==m]
        #regular average
        group_ave[var][m]=np.nanmean(data[var][d],0)
        #weighted average
        group_wave[var][m]=np.average(data[var][d],0,weights=probs[d,m])
        #standard deviation
        group_std[var][m]=np.nanstd(data[var][d],0)
        #median
        group_median[var][m]=np.nanmedian(data[var][d],0)
        #weighted standard deviation (deviations are taken about the unweighted average, group_ave)
        variance = np.average((data[var][d]-group_ave[var][m])**2,0, weights=probs[d,m])
        group_wstd[var][m]=variance**0.5

Find averages


The labeling of groups coming out of GMM is random, and thus group 0 for any given run does not represent anything particularly special (and would likely be different from group 0 if you ran the exact same code a second time). One way we can get around this is by sorting the average profiles for each group based on some characteristic of that average profile, like the surface chlorophyll. For the purposes of exploring the data, we'll do that here, but there are many other options for how it might make sense to sequence the clusters. 

In [23]:
def sort_chl_profs(data,labels,groups=12,keyword='surf'):
    #resequence CHL groups
    #find surface chl
    new_chl=np.zeros(groups)
    #print(groups,len(data['CHLA']))
    count=0
    for prof in data['CHLA']:
        if keyword=='surf':
            new_chl[count]=prof[0]
        else:
            new_chl[count]=np.trapz(prof)
        count+=1
    
    new_groups=sorted(range(len(new_chl)), key=lambda k: new_chl[k])
    
    new_labels=np.zeros(len(labels))
    for n in range(0,len(labels)):
        new_labels[n]=next(i for i, j in enumerate(new_groups) if j==labels[n])  
    
    return new_groups,new_labels

#select a sorting group (group_ave, group_wave, or group_median) for sorting by the surface chlorophyll
#associated with each group average.
sorting_group=group_wave
sorting_std=group_wstd

new_groups,new_labels=sort_chl_profs(sorting_group,labels,groups=gmm_comps,keyword='surf')

print(new_groups)

[5, 0, 10, 3, 7, 18, 13, 2, 19, 6, 20, 16, 15, 8, 12, 1, 9, 17, 4, 14, 11]
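
The relabelling performed by `sort_chl_profs` can also be sketched with `np.argsort` and an inverse permutation. The surface values and labels below are made up, and the names `order` and `rank` are hypothetical; `order` plays the role of `new_groups`.

```python
import numpy as np

# made-up surface chlorophyll of the average profile for four clusters (cluster 0..3)
surface_chl = np.array([0.8, 0.1, 1.5, 0.3])
# made-up GMM labels for six profiles
labels = np.array([2, 0, 1, 3, 1, 2])

order = np.argsort(surface_chl)          # original cluster numbers, lowest surface value first
rank = np.empty_like(order)
rank[order] = np.arange(len(order))      # rank[c] = position of original cluster c in the sorted sequence
new_labels = rank[labels]

print(order)
print(new_labels)
```

This avoids the per-profile search and gives integer labels directly.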


Now we can start visualizing each of the clusters and exploring the statistics of the GMM output. The clusters are now sequenced based on the surface chlorophyll of whichever central tendency statistic you selected in the previous cell. First, we'll plot the central tendency profile for every cluster in a single figure. This will be busy but will give us a quick snapshot of the overall cluster shapes.

In [24]:
fig=plt.figure(figsize=(21,9))
color = cmocean.cm.matter(np.linspace(0,1,gmm_comps))
for n in range(gmm_comps):
    ax=fig.add_subplot(3,7,n+1)
    ax.plot(sorting_group['CHLA'][new_groups[n]],data['DEPTH'][0],lw=3,c=color[n],zorder=3)
    if n==7:
        ax.set_ylabel('Depth',fontsize=18)
    if n==17:
        ax.set_xlabel('Chlorophyll (mg m$^{-3}$)',fontsize=18)
    ax.tick_params(labelsize=14)
    ax.set_ylim(max_depth,0)
    #set all the x-limits to the same for better comparison (even though
    #some clusters will have much higher concentrations)
    ax.set_xlim(-0.01,1.3)
    
plt.tight_layout()
if save_plots==True:
    plt.savefig(sfpath+'profiles/GMM_all_profs.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

Next up, we can look a little more carefully at what's going on in each cluster. Let's look at several features to start:
1. The central tendency profile (average, weighted average, or median) with a subset of individual profiles
2. The probability distribution for that cluster (i.e. what is the distribution of probabilities for profiles assigned to that group)
3. The difference between the average, weighted average, and median

In [25]:
#select groups to plot below
plot_group=np.arange(21)

color = cmocean.cm.matter(np.linspace(0,1,gmm_comps))

for nnn in plot_group:
    #plot average profile and representative subset
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[20, 10])
    ax1.plot(sorting_group['CHLA'][new_groups[nnn]],data['DEPTH'][0],lw=3,c=color[nnn],zorder=3)
    ax1.plot(sorting_group['CHLA'][new_groups[nnn]]+sorting_std['CHLA'][new_groups[nnn]],data['DEPTH'][0],lw=2,linestyle='--',c=color[nnn],zorder=3)
    ax1.plot(sorting_group['CHLA'][new_groups[nnn]]-sorting_std['CHLA'][new_groups[nnn]],data['DEPTH'][0],lw=2,linestyle='--',c=color[nnn],zorder=3)
    d=[i for i,j in enumerate(labels) if j==new_groups[nnn]]
    for m in range(0,len(d),50):
        ax1.plot(data['CHLA'][d][m],data['DEPTH'][0],lw=0.5,alpha=0.5,color='lightgray',zorder=1)

    #plt.legend(loc='best',fontsize='large')
    ax1.set_ylim(max_depth,0)
    ax1.set_xlim(0-sorting_std['CHLA'][new_groups[nnn]].max(),sorting_group['CHLA'][new_groups[nnn]].max()+2*sorting_std['CHLA'][new_groups[nnn]].max())
    ax1.set_ylabel('Depth (m)',fontsize=18)
    ax1.set_xlabel('Chlorophyll (mg m$^{-3}$)',fontsize=18)
    ax1.tick_params(labelsize=16)
    
    axins = inset_axes(ax1, width=3, height=2,loc=4,borderpad=2)
    axins.hist(probs[d,new_groups[nnn]],bins=np.arange(0,1.05,0.1),edgecolor='k',facecolor='lightgray')
    axins.axhline(0.75*len(d),color='red',linestyle='dotted',lw=0.5)
    axins.axhline(0.9*len(d),color='red',linestyle='dashed',lw=0.5)
    axins.set_ylabel('# of Profiles')
    axins.set_xlabel('Cluster probability')
    axins.xaxis.tick_top()
    axins.xaxis.set_label_position('top') 
    
    #plot each central tendency profile for comparison
    ax2.plot(group_ave['CHLA'][new_groups[nnn]],data['DEPTH'][0],lw=3,c=color[2],zorder=3,label='Average')
    ax2.plot(group_wave['CHLA'][new_groups[nnn]],data['DEPTH'][0],lw=3,c=color[10],zorder=2,label='Weighted Average')
    ax2.plot(group_median['CHLA'][new_groups[nnn]],data['DEPTH'][0],lw=3,c=color[18],zorder=1,label='Median')
    ax2.set_ylim(max_depth,0)
    ax2.set_xlim(0-sorting_std['CHLA'][new_groups[nnn]].max(),sorting_group['CHLA'][new_groups[nnn]].max()+sorting_std['CHLA'][new_groups[nnn]].max())
    ax2.set_xlabel('Chlorophyll (mg m$^{-3}$)',fontsize=18)
    ax2.tick_params(labelsize=16)
    ax2.legend(loc='best',fontsize='x-large')
    
    if save_plots==True:
        plt.savefig(sfpath+'profiles/GMM_mean_prof_cluster%d.pdf' %nnn,format='pdf',bbox_inches='tight')
        plt.close()
    else:
        plt.show()

Another useful way to start understanding the clusters is to look at where they occur. As with the profiles, we can select the cluster numbers to look at and run the code in the following cells, which will produce a map for each cluster showing where those profiles occurred (and in which season).

In [26]:
#first we'll determine the months and seasons for each profile. To be able to look at N/S similarity,
#we'll look at everything in terms of equivalent northern month.

dates = pd.to_datetime(data['JULD'].values).month
print('Find months')
months=np.zeros(len(dates))
for m in range(0,len(months)):
    if data['LATITUDE'].values[m]>=0:
        months[m]=dates[m]
    else:
        if dates[m]>=7:
            months[m]=dates[m]-6
        else:
            months[m]=dates[m]+6
print('Find seasons')
seasons_north=[[12,1,2],[3,4,5],[6,7,8],[9,10,11]]
seasons=np.zeros(len(months))
for n in range(0,len(months)):
    for nn in range(0,len(seasons_north)):
        if months[n] in seasons_north[nn]:
            seasons[n]=nn

Find months
Find seasons
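
The month shift and season lookup in the cell above can also be written without explicit loops. This is a sketch on made-up months and latitudes, with hypothetical variable names, that follows the same rules: southern-hemisphere months are shifted by six, and December through February count as winter.

```python
import numpy as np

# made-up calendar months and latitudes for five profiles
month = np.array([1, 6, 7, 12, 3])
latitude = np.array([45.0, -30.0, -60.0, -10.0, 20.0])

# shift southern-hemisphere months by six so that all months are "northern equivalent"
shifted = (month - 1 + 6) % 12 + 1
equiv_month = np.where(latitude >= 0, month, shifted)

# season index: 0 = winter (Dec-Feb), 1 = spring, 2 = summer, 3 = fall
season = (equiv_month % 12) // 3

print(equiv_month)
print(season)
```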


In [27]:
plot_group=np.arange(21)

for nnn in plot_group:
    subset=[i for i,j in enumerate(labels) if j==new_groups[nnn]]

    cmap = cmocean.cm.haline  # define the colormap
    # extract all colors from the color map
    cmaplist = [cmap(i) for i in range(cmap.N)]

    # create the new segmented (vs. continuous) map
    cmap = mpl.colors.LinearSegmentedColormap.from_list(
        'Custom cmap', cmaplist, cmap.N)
    bounds = np.linspace(-0.5, 3.5, num=5)
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

    fig=plt.figure(figsize=(14,7))
    m = Basemap(projection='robin',lon_0=-0,resolution='c')
    cp=m.scatter(data['LONGITUDE'].values[subset],data['LATITUDE'].values[subset],c=seasons[subset],s=5,cmap=cmap, norm=norm,latlon=True)
    cp.set_rasterized(True)
    m.drawcoastlines()
    m.fillcontinents(color='gray')
    m.drawparallels(np.arange(-60., 61., 30.), labels = [1,0,0,0], fontsize = 16)
    m.drawmeridians(np.arange(-180., 181., 60.), labels = [0,0,0,1], fontsize = 16)
    cbar=fig.colorbar(cp,ticks=[0,1,2,3],shrink=0.75)
    cbar.ax.set_yticklabels(['Winter','Spring','Summer','Fall'],fontsize=14)
    cbar.ax.tick_params(labelsize=14)

    if save_plots==True:
        plt.savefig(sfpath+'maps/GMM_map_cluster_%d.pdf' %nnn,format='pdf',bbox_inches='tight')
        plt.close()
    else:
        plt.show()

It is not always clear from the map plots what the actual distribution across seasons looks like for each cluster, so we can get a better view of this with some histograms. 

In [28]:
#histograms of "months" and "seasons"-->everything is normalized to northern months to make N/S comparison easier. 

fig=plt.figure(figsize=(22,9))
for n in range(gmm_comps):
    d=[i for i,j in enumerate(labels) if j==new_groups[n]]
    ax=fig.add_subplot(3,7,n+1)
    ax.hist(seasons[d],bins=np.arange(0,5),edgecolor='k',facecolor='lightgray',linewidth=2)
    ax.tick_params(labelsize=14)
    if n==7:
        ax.set_ylabel('Number of Profiles',fontsize=18)
    if n==17:
        ax.set_xlabel('Season',fontsize=18)
    if n<=13:
        frame1 = plt.gca()
        frame1.axes.get_xaxis().set_ticks([0.5,1.5,2.5,3.5])
        frame1.axes.get_xaxis().set_ticklabels([])
    elif n>13:
        frame1 = plt.gca()
        frame1.axes.get_xaxis().set_ticks([0.5,1.5,2.5,3.5])
        frame1.axes.get_xaxis().set_ticklabels(['Winter','Spring','Summer','Fall'],rotation=45)
        

plt.tight_layout()

if save_plots==True:
    plt.savefig(sfpath+'GMM_seasons_hist.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

We can also look at some statistics in other ways. To make this as useful as possible, we will need to do some manual sorting. For example, based on the plot of all of the average profiles above, it looks like we have three broad categories of profiles: pure SCM (low surface Chl, peak at depth); well-mixed/sigmoid (looks a lot like a typical mixed layer); and combination (higher surface Chl, but still with a peak at depth). For the 21 clusters we're working with, these can be broken down as follows (if you decided to tinker, you'll need to do this part manually in the next cell):
scm_group = [0,1,2,3,5]
well_mixed = [4,6,7,9,10,11,12,17]
combo = [8,13,14,15,16,18,19,20]

Some of these are a bit of a judgment call and there may not be a rigid delineation in all cases. For the next bit of analysis, we just need to get close in order for it to be useful. First, we'll plot the average profiles for each of the sub-groups on the same axes, which will allow us to see the range of profile shapes within each broad category.

In [29]:
scm_group = [0,1,2,3,5]
well_mixed = [4,6,7,9,10,11,12,17]
combo = [8,13,14,15,16,18,19,20]

fig=plt.figure(figsize=(18,8))
count=1
for g1 in ['scm_group','well_mixed','combo']:
    if g1=='scm_group':
        # test=group1
        g=scm_group
    elif g1=='well_mixed':
        # test=group2
        g=well_mixed
    else:
        # test=group3
        g=combo
    
    color = cmocean.cm.matter(np.linspace(0,1,len(g)))
    ax1=fig.add_subplot(1,3,count)
    nn=0
    for n in g:
        plt.plot(group_wave['CHLA'][new_groups[n]],data['DEPTH'][0],lw=4,c=color[nn])
        nn+=1
    plt.ylim(250,0)
    plt.xlabel('Chlorophyll (mg m$^{-3}$)',fontsize=18)
    plt.ylabel('Depth (m)', fontsize=18)
    plt.tick_params(labelsize=16)
    # plt.legend(loc=4,fontsize='x-large')
    count+=1
plt.tight_layout()

if save_plots==True:
    plt.savefig(sfpath+'all_profs_sorted.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

We're going to make a "heatmap" that allows us to look at the two highest probabilities from the GMM output. This way we can see how the highest probability and second highest probability relate in terms of profile shape. If the algorithm is working well, we would expect the second highest probability cluster to share many of the same characteristics as the highest probability group.
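
Picking out the two most probable clusters for each profile can be sketched on a small made-up probability table. The array `probs_toy` and the other names are hypothetical; in the notebook the input would be the `probs` array returned by `gmm.predict_proba`.

```python
import numpy as np

# made-up posterior probabilities for three profiles (rows) across four clusters (columns)
probs_toy = np.array([[0.70, 0.20, 0.05, 0.05],
                      [0.10, 0.20, 0.60, 0.10],
                      [0.02, 0.48, 0.00, 0.50]])

# column indices of the two largest probabilities in each row, largest first
top_two = np.argsort(probs_toy, axis=1)[:, -1:-3:-1]
top_two_probs = np.take_along_axis(probs_toy, top_two, axis=1)

print(top_two)
print(top_two_probs)
```

The third row is an example of a borderline profile: its first and second choices have nearly equal probability, which is the kind of pairing the heatmap is designed to reveal.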

In [30]:
#make a heatmap
#sort the probabilities for each profile and retain the location of the 
#two highest probabilities; from a previous cell, the probability output
#has the variable name "probs"
comps=[np.argsort(p)[-1:-3:-1] for p in probs]
#select the two highest probabilities
prob=[probs[n][comps[n]] for n in range(0, len(probs))]

#sort groups so that similar shapes are near each other
new_seq=[0,1,2,3,5,4,6,7,9,10,11,12,17,8,13,14,15,16,18,19,20]
ax_lab=[1,2,3,4,6,5,7,8,10,11,12,13,18,9,14,15,16,17,19,20,21]

locs={}
for n in range(0,21):
    #take sorted list and organize by profile shape
    locs[str(new_groups[n])]=next(i for i,j in enumerate(new_seq) if j==n)

pair_array=np.zeros((21,21))
mags=np.zeros(21)

count=0
for c in comps:
#     # print(locs[str(c[0])],locs[str(c[1])])
#     # pair_array[locs[str(c[0])],locs[str(c[1])]]+=1
    pair_array[locs[str(c[0])],locs[str(c[1])]]+=prob[count][1]
    pair_array[locs[str(c[0])],locs[str(c[0])]]+=prob[count][0]
    mags[locs[str(c[0])]]+=1
    count+=1

# # for n in range(0,21):
#     # pair_array[n,n]=np.nan
from matplotlib import colors
fig=plt.figure(figsize=(12,10))

#pass the color limits through the norm; recent matplotlib versions reject vmin/vmax alongside norm
cp=plt.pcolor(pair_array.transpose()/mags,cmap=cmocean.cm.amp,norm=colors.LogNorm(vmin=1e-2,vmax=1))
plt.tick_params(labelsize=16)
#plt.grid()
#plt.xticks(ticks=[0,3,6,9,12,15,18],labels=np.array(new_seq)[0,3,6,9,12,15,18]+1)
#plt.yticks(ticks=[0,3,6,9,12,15,18],labels=np.array(new_seq)[0,3,6,9,12,15,18]+1)
plt.xticks(ticks=np.arange(0.5,21.5,1),labels=ax_lab)
plt.yticks(ticks=np.arange(0.5,21.5,1),labels=ax_lab)
plt.axvline(5,linestyle='dashed',color='k',linewidth=2)
plt.axvline(13,linestyle='dashed',color='k',linewidth=2)
plt.axhline(5,linestyle='dotted',color='k',linewidth=2)
plt.axhline(13,linestyle='dotted',color='k',linewidth=2)
plt.xlabel('Highest Probability Group',fontsize=20)
plt.ylabel('Second Highest Probability Group',fontsize=20)
cbar=fig.colorbar(cp)
cbar.set_label('Relative frequency',fontsize=18)
cbar.ax.tick_params(labelsize=16)

if save_plots==True:
    plt.savefig(sfpath+'GMM_probs_heatmap.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()
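
The top-two selection and the pairwise accumulation in the cell above can also be written with array operations. The following is a minimal sketch on made-up numbers, not the notebook's own implementation: `probs_demo`, `top2`, `pair_demo`, and `counts` are hypothetical names, and the reordering by profile shape (`locs`) is left out for brevity.

```python
import numpy as np

# Made-up posterior probabilities (profiles x components); in the notebook
# the real values are stored in "probs".
probs_demo = np.array([
    [0.70, 0.20, 0.10],
    [0.05, 0.60, 0.35],
    [0.55, 0.05, 0.40],
    [0.15, 0.80, 0.05],
])
n_comps = probs_demo.shape[1]

# Column indices of the two largest probabilities per profile, largest first.
top2 = np.argsort(probs_demo, axis=1)[:, -1:-3:-1]
# The matching probabilities, shape (profiles, 2).
top2_prob = np.take_along_axis(probs_demo, top2, axis=1)

# Accumulate by (highest, second-highest) pair. np.add.at adds every
# occurrence of a repeated index pair, which plain "+=" indexing would not.
pair_demo = np.zeros((n_comps, n_comps))
np.add.at(pair_demo, (top2[:, 0], top2[:, 1]), top2_prob[:, 1])
np.add.at(pair_demo, (top2[:, 0], top2[:, 0]), top2_prob[:, 0])

# Number of profiles whose highest-probability component is each group.
counts = np.bincount(top2[:, 0], minlength=n_comps)

# Mean probability per column (column = highest-probability group);
# groups that never rank first are left at zero instead of dividing by zero.
mean_prob = np.divide(pair_demo.T, counts,
                      out=np.zeros_like(pair_demo), where=counts > 0)
print(mean_prob)
```

Each column of `mean_prob` corresponds to one highest-probability group, which matches the orientation of the heatmap (x axis: highest, y axis: second highest).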

To view all of the data spatially and temporally in a simplified manner, we can make a modified Hovmöller plot. This is useful for gathering a general impression of the data, but it should really just be a precursor to further analysis, given that the whole globe is collapsed into a 1D slice for each time point.

In [32]:
cmap = cmocean.cm.matter  # define the colormap
# extract all colors from the cmocean "matter" colormap
cmaplist = [cmap(i) for i in range(cmap.N)]

# create the new map
cmap = mpl.colors.LinearSegmentedColormap.from_list(
    'Custom cmap', cmaplist, cmap.N)
bounds = np.linspace(0, 22, num=22)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

fig=plt.figure(figsize=(14,6))
cp=plt.scatter(data['LATITUDE'],months,c=new_labels+1,marker='|',s=700,alpha=0.9,cmap=cmap, norm=norm)
plt.tick_params(labelsize=14)
#plt.ylabel('Season',fontsize=16)
plt.xlim(-76,79)
plt.xlabel('Latitude',fontsize=16)
frame1 = plt.gca()
frame1.axes.get_yaxis().set_ticks([1,4,7,10])
frame1.axes.get_yaxis().set_ticklabels(['Winter','Spring','Summer','Fall'])
cbar=fig.colorbar(cp,ticks=[2,4,6,8,10,12,14,16,18,20])
cbar.ax.tick_params(labelsize=14)
# cbar.ax.set_xticklabels()
cbar.set_label('Group',fontsize=14)
plt.tight_layout()

if save_plots==True:
    plt.savefig(sfpath+'GMM_hov.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()
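
The discrete color scale above relies on `BoundaryNorm` assigning each value to the bin between two consecutive edges. As a quick check, here is a small sketch that mimics that binning with `np.digitize` (an assumption made for illustration; it does not call matplotlib) to confirm that the 22 edges give each of the 21 group numbers its own bin. The half-integer edges at the end are a hypothetical alternative, not what the notebook uses.

```python
import numpy as np

# Same edges as in the plotting cell: 22 evenly spaced values from 0 to 22,
# which define 21 bins.
bounds_demo = np.linspace(0, 22, num=22)

# Group numbers as plotted (new_labels + 1), i.e. 1 through 21.
group_numbers = np.arange(1, 22)

# np.digitize returns the index of the bin each value falls in;
# subtracting 1 makes the first bin index 0.
bin_index = np.digitize(group_numbers, bounds_demo) - 1
print(bin_index)

# Hypothetical alternative: half-integer edges place every integer group
# number at the centre of its bin.
centered_bounds = np.arange(0.5, 22.5, 1.0)
centered_index = np.digitize(group_numbers, centered_bounds) - 1
print(centered_index)
```

Both sets of edges separate the 21 groups; the centered version may make the colorbar ticks line up with the middle of each color block.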

This plot might be more useful viewed in a couple of alternative setups: first, separated out by basin; and second, separated out across the three broad groups of profile shapes. This will help us understand which of the broad patterns shown above are global in nature and which are specific to a certain basin. For the basin separation, we will also separate out the Mediterranean and Black Seas.

In [33]:
#define subgroups for each basin
med_black=[]
indian=[]
atlantic=[]
pacific=[]

#sort each profile into the appropriate bin
#the tests are sequenced from easiest to separate to most difficult to separate;
#anything not caught by one of the boxes is assigned to the Pacific
basin_lon=data['LONGITUDE'].values
basin_lat=data['LATITUDE'].values
for n in range(0,len(basin_lat)):
    if 0<basin_lon[n]<42.5 and 30<basin_lat[n]<47:
        med_black.append(n)
    elif 30<basin_lon[n]<120 and -80<basin_lat[n]<30:
        indian.append(n)
    elif -70<basin_lon[n]<30 and -80<basin_lat[n]<=0:
        atlantic.append(n)
    elif -80<basin_lon[n]<30 and 0<basin_lat[n]<80:
        atlantic.append(n)
    else:
        pacific.append(n)
    
print('Number of Mediterranean/Black Sea profiles: %d' %len(med_black))
print('Number of Indian Ocean profiles: %d' %len(indian))
print('Number of Atlantic Ocean profiles: %d' %len(atlantic))
print('Number of Pacific Ocean profiles: %d' %len(pacific))

Number of Mediterranean/Black Sea profiles: 9775
Number of Indian Ocean profiles: 20829
Number of Atlantic Ocean profiles: 21348
Number of Pacific Ocean profiles: 19320
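
For larger datasets, the same longitude/latitude boxes can be applied to whole arrays at once. The sketch below is one possible vectorized version using `np.select`, which takes the first matching condition just like the `if`/`elif` chain above. The function name `classify_basins` and the sample coordinates are made up for illustration.

```python
import numpy as np

BASIN_NAMES = np.array(['med_black', 'indian', 'atlantic', 'pacific'])

def classify_basins(lon, lat):
    """Return a basin name for each (lon, lat) pair, using the same boxes
    and the same order of tests as the loop in the cell above."""
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    conditions = [
        (0 < lon) & (lon < 42.5) & (30 < lat) & (lat < 47),    # Med/Black Sea
        (30 < lon) & (lon < 120) & (-80 < lat) & (lat < 30),   # Indian
        (-70 < lon) & (lon < 30) & (-80 < lat) & (lat <= 0),   # South Atlantic
        (-80 < lon) & (lon < 30) & (0 < lat) & (lat < 80),     # North Atlantic
    ]
    # Integer codes index into BASIN_NAMES; unmatched points default to Pacific.
    codes = np.select(conditions, [0, 1, 2, 2], default=3)
    return BASIN_NAMES[codes]

# Made-up sample points (longitude, latitude).
sample_lon = [15.0, 72.2, -30.0, -40.0, -150.0, 160.0]
sample_lat = [38.0, -48.5, -20.0, 45.0, 10.0, -30.0]
basins = classify_basins(sample_lon, sample_lat)
print(basins)

# Index lists equivalent to the ones built by the loop.
indian_idx = np.flatnonzero(basins == 'indian')
```

Index arrays such as `indian_idx` can be used in the same way as the lists `indian`, `atlantic`, and `pacific` when selecting profiles.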


In [37]:
#now we can repeat the HOV plotting, but with the basins separated out into subplots
cmap = cmocean.cm.matter  # define the colormap
# extract all colors from the cmocean "matter" colormap
cmaplist = [cmap(i) for i in range(cmap.N)]

regions=[indian,atlantic,pacific]
titles=['Indian Ocean','Atlantic Ocean','Pacific Ocean']

# create the new map
cmap = mpl.colors.LinearSegmentedColormap.from_list(
    'Custom cmap', cmaplist, cmap.N)
bounds = np.linspace(0, 22, num=22)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

fig=plt.figure(figsize=(12,10))
for n in range(0,3):
    ax=fig.add_subplot(3,1,n+1)
    cp=ax.scatter(data['LATITUDE'][regions[n]],months[regions[n]],c=new_labels[regions[n]]+1,marker='|',s=700,alpha=0.9,cmap=cmap, norm=norm)
    ax.tick_params(labelsize=14)
    ax.set_xlim(-76,79)
    ax.set_title(titles[n],fontsize=16)
    #season ticks on every panel
    ax.yaxis.set_ticks([1,4,7,10])
    ax.yaxis.set_ticklabels(['Winter','Spring','Summer','Fall'])
    if n==2:
        ax.set_xlabel('Latitude',fontsize=16)
    else:
        #hide the latitude tick labels on the upper panels
        ax.xaxis.set_ticklabels([])

plt.tight_layout()
fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.95, 0.25, 0.03, 0.5])
cbar = fig.colorbar(cp,cax=cbar_ax,orientation='vertical',ticks=[2,4,6,8,10,12,14,16,18,20])
cbar.ax.tick_params(labelsize=14)
# cbar.ax.set_xticklabels()
cbar.set_label('Group',fontsize=14)

if save_plots==True:
    plt.savefig(sfpath+'GMM_hov_basin.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

In [35]:
#this definition of the subgroups is copied from above. As stated earlier, these groups were selected
#manually based on the original data, PCA, and GMM parameters. If you depart significantly from those,
#you may need to change these groups.
scm_group = [0,1,2,3,5]
well_mixed = [4,6,7,9,10,11,12,17]
combo = [8,13,14,15,16,18,19,20]

scm_subset=[i for i,j in enumerate(labels) if j in np.array(new_groups)[scm_group]]
well_mixed_subset=[i for i,j in enumerate(labels) if j in np.array(new_groups)[well_mixed]]
combo_subset=[i for i,j in enumerate(labels) if j in np.array(new_groups)[combo]]

print('Total profiles in SCM group: %d' %len(scm_subset))
print('Total profiles in Well-Mixed group: %d' %len(well_mixed_subset))
print('Total profiles in Combo group: %d' %len(combo_subset))

Total profiles in SCM group: 16631
Total profiles in Well-Mixed group: 24828
Total profiles in Combo group: 29813
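
The subgroup selection above maps positions in the sorted sequence back to the original GMM labels through `new_groups`, then keeps the profiles carrying those labels. A compact way to express the same idea is `np.isin`. The sketch below uses a made-up six-component example; `new_groups_demo`, `labels_demo`, and `subset_indices` are hypothetical names.

```python
import numpy as np

# Hypothetical stand-ins: six GMM components reordered by profile shape.
# new_groups_demo[k] is the original GMM label found at sorted position k.
new_groups_demo = np.array([3, 0, 5, 1, 4, 2])
# One original GMM label per profile.
labels_demo = np.array([0, 3, 3, 5, 2, 1, 4, 0])

# Sorted positions that make up each broad class (made up for this example).
scm_positions = [0, 1]
well_mixed_positions = [2, 3]
combo_positions = [4, 5]

def subset_indices(labels, new_groups, positions):
    """Indices of profiles whose label belongs to the given sorted positions."""
    wanted = np.asarray(new_groups)[positions]
    return np.flatnonzero(np.isin(labels, wanted))

scm_idx = subset_indices(labels_demo, new_groups_demo, scm_positions)
well_mixed_idx = subset_indices(labels_demo, new_groups_demo, well_mixed_positions)
combo_idx = subset_indices(labels_demo, new_groups_demo, combo_positions)
print(scm_idx, well_mixed_idx, combo_idx)
```

As in the notebook, the three subsets should add up to the total number of profiles, which is a useful sanity check after changing the group definitions.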


In [36]:
#we can repeat the HOV plotting again, but with each subgroup separated into subplots
regions=[scm_subset,well_mixed_subset,combo_subset]
titles=['SCM','Well-Mixed','Combo']

# create the new map
cmap = mpl.colors.LinearSegmentedColormap.from_list(
    'Custom cmap', cmaplist, cmap.N)
bounds = np.linspace(0, 22, num=22)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

fig=plt.figure(figsize=(12,10))
for n in range(0,3):
    ax=fig.add_subplot(3,1,n+1)
    cp=ax.scatter(data['LATITUDE'][regions[n]],months[regions[n]],c=new_labels[regions[n]]+1,marker='|',s=700,alpha=0.9,cmap=cmap, norm=norm)
    ax.tick_params(labelsize=14)
    ax.set_xlim(-76,79)
    ax.set_title(titles[n],fontsize=16)
    #season ticks on every panel
    ax.yaxis.set_ticks([1,4,7,10])
    ax.yaxis.set_ticklabels(['Winter','Spring','Summer','Fall'])
    if n==2:
        ax.set_xlabel('Latitude',fontsize=16)
    else:
        #hide the latitude tick labels on the upper panels
        ax.xaxis.set_ticklabels([])

plt.tight_layout()
fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.95, 0.25, 0.03, 0.5])
cbar = fig.colorbar(cp,cax=cbar_ax,orientation='vertical',ticks=[2,4,6,8,10,12,14,16,18,20])
cbar.ax.tick_params(labelsize=14)
# cbar.ax.set_xticklabels()
cbar.set_label('Group',fontsize=14)

if save_plots==True:
    plt.savefig(sfpath+'GMM_hov_subgroup.pdf',format='pdf',bbox_inches='tight')
    plt.close()
else:
    plt.show()

We can also save all of this data (the original profiles along with the PCA, GMM, and average profile data) so that you can work with it offline later.

In [42]:
print('Find averages')
for var in ['DOXY','NITRATE','TEMP','PSAL']:
    group_wave[var]=np.nan*np.ones((gmm_comps,51))
    group_wstd[var]=np.nan*np.ones((gmm_comps,51))

for var in ['DOXY','NITRATE','TEMP','PSAL']:
    #mask missing values so they do not contribute to the group statistics
    masked_data = np.ma.masked_array(data[var], np.isnan(data[var]))
    #masked_data = np.ma.masked_array(masked_data, np.isinf(masked_data))
    for m in range(gmm_comps):
        d=[i for i,j in enumerate(labels) if j==m]
        #weighted average; np.ma.average leaves masked samples out of both
        #the weighted sum and the weight total
        group_wave[var][m]=np.ma.filled(np.ma.average(masked_data[d],0,weights=probs[d,m]),np.nan)
        #weighted standard deviation
        variance = np.ma.average((masked_data[d]-group_wave[var][m])**2,0, weights=probs[d,m])
        group_wstd[var][m]=np.ma.filled(variance,np.nan)**0.5

print('Save Data')
gmm_data={}
gmm_data['gmm_means']={'dims':('n_comps','pcs'),'data':gmm.means_}
gmm_data['gmm_probs']={'dims':('profs','n_comps'),'data':probs}
gmm_data['gmm_cov']={'dims':('n_comps','pcs','pcs'),'data':gmm.covariances_}
gmm_data['gmm_labels']={'dims':('profs',),'data':labels}
gmm_data['gmm_new_groups']={'dims':('n_comps',),'data':new_groups}
gmm_data['gmm_new_labels']={'dims':('profs',),'data':new_labels}
gmm_data['pc_comps']={'dims':('pcs','z'),'data':pca.components_}
gmm_data['pc_transform']={'dims':('profs','pcs'),'data':training}
gmm_data['pc_var']={'dims':('pcs',),'data':pca.explained_variance_ratio_}

for var in ['CHLA','DOXY','TEMP','PSAL','NITRATE']:
    gmm_data[var+'_ave']={'dims':('n_comps','z'),'data':group_wave[var]}
    gmm_data[var+'_std']={'dims':('n_comps','z'),'data':group_wstd[var]}
    

dict_data={}
for v in ['CHLA','DOXY','TEMP','PSAL','NITRATE']:
    dict_data[v]={'dims':('t','z'),'data':data[v]}
    
dict_data['CHLA_test']={'dims':('t','z'),'data':test_data}

for v in ['LATITUDE','LONGITUDE','JULD']:
    dict_data[v]={'dims':('t',),'data':data[v]}

for v in ['PRES','DEPTH']:
    dict_data[v]={'dims':('t','z'),'data':data[v]}

#GMM data
for var in gmm_data.keys():
    dict_data[var]=gmm_data[var]

ds=xr.Dataset.from_dict(dict_data)
print(ds.keys())

filename=sfpath+'gmm_results_%dpcs_%dclusters.nc' %(pcs,gmm_comps)

ds.to_netcdf(filename)

Find averages
Save Data
KeysView(<xarray.Dataset>
Dimensions:         (n_comps: 21, pcs: 17, profs: 71272, t: 71272, z: 51)
Dimensions without coordinates: n_comps, pcs, profs, t, z
Data variables:
    CHLA            (t, z) float64 1.456 1.456 1.456 ... 0.006506 0.004381
    DOXY            (t, z) float64 343.3 343.4 343.5 343.7 ... 167.4 167.0 166.6
    TEMP            (t, z) float64 2.734 2.736 2.734 2.714 ... 13.71 13.71 13.71
    PSAL            (t, z) float64 33.85 33.86 33.86 33.86 ... 38.6 38.6 38.61
    NITRATE         (t, z) float64 nan nan nan nan nan ... nan nan nan nan nan
    CHLA_test       (t, z) float64 1.486 1.486 1.486 ... 0.03535 0.03651 0.03438
    LATITUDE        (t) float64 -48.48 -48.48 -48.49 -48.48 ... 40.96 41.01 41.1
    LONGITUDE       (t) float64 72.2 72.19 72.19 72.18 ... 5.936 5.901 5.922
    JULD            (t) datetime64[ns] 2011-10-29T18:16:44.999980288 ... 2020...
    PRES            (t, z) float64 0.0 5.043 10.09 15.13 ... 242.0 247.1 252.1
    DEPT
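
The probability-weighted group averages computed in the cell above can be checked by hand on a tiny array. The following is a minimal sketch with made-up numbers, using `np.ma.average` so that a missing value drops out of both the weighted sum and the weight total. The names `values`, `weights`, `wmean`, and `wstd` are hypothetical.

```python
import numpy as np

# Three made-up profiles at two depth levels; one value is missing.
values = np.array([
    [1.0, 10.0],
    [3.0, np.nan],
    [5.0, 30.0],
])
# Made-up membership probabilities of each profile in one group.
weights = np.array([0.5, 0.25, 0.25])

# Mask NaNs so they are skipped rather than propagated.
masked = np.ma.masked_invalid(values)

# Weighted mean per depth level; levels with no valid data would become NaN.
wmean = np.ma.filled(np.ma.average(masked, axis=0, weights=weights), np.nan)

# Weighted variance about that mean, then the standard deviation.
wvar = np.ma.average((masked - wmean) ** 2, axis=0, weights=weights)
wstd = np.sqrt(np.ma.filled(wvar, np.nan))

print(wmean)
print(wstd)
```

At the second depth level only the first and third profiles contribute, so their weights (0.5 and 0.25) are renormalized to sum to one before averaging.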

Thanks for playing around with this. Please let me know if you have any feedback, thoughts, or want to collaborate on future related projects!
Rosalind Echols
rechols@uw.edu