In [1]:
import numpy as np
from collections import Counter
from collections import defaultdict
from math import log
import pandas as pd
import xarray as xr
from itertools import product
import hexMinisom
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib import colormaps
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.feature as cf
import pickle
import itertools
import colorsys
import random
import datetime as dt

In [2]:
def save_som(som, fileName):
    """Serialize ``som`` to ``fileName`` with pickle, overwriting any existing file."""
    with open(fileName, mode='wb') as handle:
        handle.write(pickle.dumps(som))

def load_som(fileName):
    """Return the object pickled in ``fileName`` (e.g. a SOM written by ``save_som``).

    Unpickling can execute arbitrary code, so only load files you trust.
    """
    with open(fileName, mode='rb') as handle:
        return pickle.load(handle)

def generate_distinct_colors(n):
    """Return ``n`` '#rrggbb' colour strings with evenly spaced hues, in random order.

    Hues are ``k / n`` for ``k = 0 .. n-1`` with full saturation and value, so
    the colours differ only in hue. The list is shuffled with the global
    ``random`` module; seed it for a reproducible order.
    """
    def _hue_to_hex(hue):
        # HSV -> RGB in [0, 1], then truncate each channel to an int in [0, 255]
        red, green, blue = (int(channel * 255) for channel in colorsys.hsv_to_rgb(hue, 1.0, 1.0))
        return "#{:02x}{:02x}{:02x}".format(red, green, blue)

    palette = [_hue_to_hex(k / n) for k in range(n)]
    random.shuffle(palette)
    return palette

# Get the counts of each WR_label for each node in SOM
def get_WR_counts(l, return_percents=False, indices=None, wr_labels=None, wr_labels_dict=None):
    """Count how many days of each weather regime (WR) are in one SOM node.

    Parameters
    ----------
    l : iterable of int
        Day indices assigned to the node (e.g. one value of ``winmap``).
    return_percents : bool
        If True, return each regime's share of the counted days in percent.
        When no days are counted the percentages are NaN (undefined).
    indices : iterable of int, optional
        If given, only days whose index is in ``indices`` are counted
        (e.g. one season). It is converted to a set once, so membership
        tests are O(1) instead of a scan of the array per day.
    wr_labels : array-like, optional
        Regime id of every day; defaults to the global ``WR_labels``.
    wr_labels_dict : dict, optional
        Regime id -> name. Its keys (in order) are the regimes reported.
        Defaults to the global ``WR_labels_dict``.

    Returns
    -------
    dict
        Every regime id in ``wr_labels_dict`` -> count (or percent),
        including regimes with zero days in the node.
    """
    if wr_labels is None:
        wr_labels = WR_labels
    if wr_labels_dict is None:
        wr_labels_dict = WR_labels_dict

    days = l
    if indices is not None:
        keep = set(indices)
        days = [i for i in days if i in keep]

    # Counter returns 0 for regimes with no days, so no positional zero-filling is needed
    tally = Counter(wr_labels[i] for i in days)
    regimes = list(wr_labels_dict.keys())
    counts = np.array([tally[regime] for regime in regimes])

    if return_percents:
        total = counts.sum()
        if total > 0:
            counts = 100 * counts / total
        else:
            # 0 / 0: percentages are undefined for a node with no selected days
            counts = np.full(len(regimes), np.nan)

    return dict(zip(regimes, counts))

# Plotting functions
def hex_heatmap(som, data, cmap='Blues', title='', cbLabel=''):
    """Draw one hexagon per unmasked SOM node, coloured and labelled by ``data``.

    Parameters
    ----------
    som : trained hexMinisom SOM
        ``get_weights()``, ``get_euclidean_coordinates()`` and ``_mask``
        (0 = active node) are used.
    data : dict
        Node coordinate ``(i, j)`` -> numeric value. Active nodes missing from
        ``data`` are drawn with value 0. ``data`` itself is not modified.
    cmap : str or matplotlib Colormap
        Colormap name (looked up in ``mpl.colormaps``) or a Colormap instance.
    title, cbLabel : str
        Figure title and colour-bar label.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.set_aspect('equal')
    if isinstance(cmap, str):
        cmap = mpl.colormaps[cmap]

    weights = som.get_weights()
    xx, yy = som.get_euclidean_coordinates()

    # A single Normalize drives both the hexagon colours and the colour bar, so a
    # hexagon's colour matches its value on the bar. Substitute fixed vmin/vmax
    # here to compare several heat maps on a common scale.
    values = list(data.values()) or [0]  # max()/min() would raise on an empty dict
    maxCount = max(values)
    minCount = min(values)
    norm = mpl.colors.Normalize(vmin=minCount, vmax=maxCount)

    hex_radius = .85 / np.sqrt(3)
    for i in range(weights.shape[0]):
        for j in range(weights.shape[1]):
            if som._mask[i, j] != 0:
                continue  # masked nodes are not part of the map

            # Nodes without data are still drawn, with value 0
            value = data.get((i, j), 0)

            # Hexagon centre; y is scaled by sqrt(3)/2 for hexagonal row spacing
            x_center = xx[(j, i)]
            y_center = yy[(j, i)] * np.sqrt(3) / 2
            colorWeight = norm(value)

            ax.add_patch(patches.RegularPolygon((x_center, y_center), numVertices=6, radius=hex_radius,
                                                facecolor=cmap(colorWeight), edgecolor='grey'))

            # White text on dark (high-value) hexagons, black otherwise
            textColor = 'white' if colorWeight >= .75 else 'black'
            ax.text(x_center, y_center - .07, f'{value}',
                    horizontalalignment='center', color=textColor, fontsize=17)

    # Align the view so all hexagons are visible, then hide axes
    ax.set_xlim(-1, weights.shape[0] - .5)
    ax.set_ylim(-1, (weights.shape[1] - .5) * np.sqrt(3) / 2)
    ax.axis('off')

    cb = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, location='bottom',
                      anchor=(.56, 2.3), shrink=.65, extend='both')
    cb.set_label(cbLabel, fontsize=17)
    cb.ax.tick_params(labelsize=17)

    ax.set_title(title, fontsize=17, y=.95, x=.515)
    return fig

def hex_frequency_plot(som, winmap=None, samples=None):
    """Hexagonal heat map of how many samples map to each SOM node.

    Parameters
    ----------
    som : trained hexMinisom SOM.
    winmap : dict, optional
        Node -> list of samples (or sample indices) won by that node. When
        omitted it is computed with ``som.win_map(samples)``.
    samples : array-like, optional
        Inputs used only when ``winmap`` is None; defaults to the global
        ``dataarray``.

    Returns
    -------
    matplotlib.figure.Figure from ``hex_heatmap``.
    """
    # `is None` instead of `== None`: `==` is elementwise for array-like objects
    if winmap is None:
        if samples is None:
            samples = dataarray
        winmap = som.win_map(samples)

    counts = {node: len(members) for node, members in winmap.items()}
    return hex_heatmap(som, counts, 'Blues', '(a) SOM Node Frequencies', 'Count')

def hex_plot(som, projection=None, node_labels=None, extent=(-180, -30, 20, 80)):
    """Create a Matplotlib figure with one axis per active SOM node, laid out as a hexagonal grid.

    Parameters
    ----------
    som : trained hexMinisom SOM
        ``som._num`` and ``som._mask`` (0/False = active node) are read.
    projection : cartopy CRS, optional
        Projection of every node axis. When given, each map is zoomed to ``extent``.
    node_labels : dict, optional
        Node ``(row, col)`` -> axis title. Defaults to the global ``node_nums``.
    extent : sequence of 4 numbers
        ``(lon_min, lon_max, lat_min, lat_max)`` in geographic degrees. It is
        interpreted in ``ccrs.PlateCarree()``, which is correct for any
        ``projection`` (passing the target projection as ``crs`` would treat the
        degrees as that projection's native coordinates).

    Returns
    -------
    (fig, axs) where ``axs`` maps node ``(row, col)`` -> Axes.
    """
    if node_labels is None:
        node_labels = node_nums

    n = som._num
    xy = (2 * n) - 1
    mask = som._mask
    rows, cols = np.ma.where(mask == 0)
    active_nodes = set(zip(rows.tolist(), cols.tolist()))  # O(1) membership

    # Each node occupies a 2x2 block of a (2*xy) x (3*xy) subplot grid
    grid_rows = xy * 2
    grid_cols = xy * 3
    fig = plt.figure(figsize=[grid_cols, grid_rows])
    axs = {}

    for x, y in product(range(xy), range(xy)):
        # Node (row, col) = (y, x); mask row y = 0 is drawn at the bottom
        if (y, x) not in active_nodes:
            continue

        # Odd mask rows are shifted one grid column right to form the hexagonal offset
        grid_col = 3 * x + (y % 2)
        grid_row = (grid_rows - 2) - 2 * y

        ax = plt.subplot2grid((grid_rows, grid_cols), (grid_row, grid_col), rowspan=2, colspan=2,
                              projection=projection, fig=fig)
        ax.set_title(node_labels[(y, x)])

        if projection is not None:
            ax.set_extent(list(extent), crs=ccrs.PlateCarree())
        axs[(y, x)] = ax
    return fig, axs

def find_longest_consecutive_index(arr):
    """Return the start index of the longest run of equal consecutive values in ``arr``.

    Ties go to the earliest run. Returns -1 for an empty array
    (``arr`` must expose ``.size``, e.g. a numpy array).
    """
    if not arr.size:
        return -1

    best_start = -1
    best_len = 0
    run_start = 0
    total = len(arr)
    for pos in range(1, total + 1):
        # A run ends at the end of the array or where the value changes
        if pos == total or arr[pos] != arr[pos - 1]:
            run_len = pos - run_start
            if run_len > best_len:
                best_start, best_len = run_start, run_len
            run_start = pos
    return best_start

def trend_significance(group, alpha=.05, iters=10000, seed=None):
    """Resampling bounds for the linear-trend slope of ``group``.

    Each iteration draws ``len(group)`` values from ``group`` with replacement
    (ignoring their order), fits a straight line against the original
    ``group.index`` and keeps the slope. The ``alpha/2`` and ``1 - alpha/2``
    percentiles of those slopes are returned. Because resampling discards the
    time ordering, the bounds describe slopes of randomly reordered data;
    presumably an observed slope outside them is treated as significant (confirm
    at the call site).

    Parameters
    ----------
    group : pandas Series (assumed), whose index is used as the x values.
    alpha : float
        Two-sided significance level.
    iters : int
        Number of resampling iterations (default 10000, the previous fixed value).
    seed : int, optional
        When given, a private ``random.Random(seed)`` is used so results are
        reproducible; otherwise the global ``random`` module is used.

    Returns
    -------
    (lower, upper) slope percentiles.

    Raises
    ------
    ValueError if ``iters`` < 1.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")

    values = list(group)
    x = group.index
    rng = random if seed is None else random.Random(seed)

    slopes = []
    for _ in range(iters):
        # random.choice draws the same stream as the former random.sample(values, 1)
        resample = [rng.choice(values) for _ in range(len(values))]
        slopes.append(np.polyfit(x, resample, 1)[0])

    lower = np.percentile(slopes, (alpha / 2) * 100)
    upper = np.percentile(slopes, (1 - alpha / 2) * 100)
    return lower, upper

In [3]:
def compute_class_spread(winmap, y, n_classes):
    """Mean distance of each class's BMU coordinates from their centroid.

    For every class id in ``range(n_classes)``, collect the coordinate of the
    winning node (BMU) of each sample of that class, average them (centroid),
    and return the mean Euclidean distance to that centroid. Classes with no
    samples get 0.0.

    winmap : dict mapping node coordinate -> list of sample indices.
    y : class label per sample index.
    """
    def _spread(cls):
        coords = [np.array(node) for node, members in winmap.items() for i in members if y[i] == cls]
        if not coords:
            return 0.0
        stacked = np.array(coords)
        offsets = stacked - stacked.mean(axis=0)
        return np.mean(np.linalg.norm(offsets, axis=1))

    return {cls: _spread(cls) for cls in range(n_classes)}

def compute_class_coverage(winmap, y, n_classes):
    """Number of distinct SOM nodes that win at least one sample of each class.

    Returns a dict class id (``range(n_classes)``) -> node count; 0 for classes
    absent from ``winmap``.
    """
    return {
        cls: len({node for node, members in winmap.items() if any(y[i] == cls for i in members)})
        for cls in range(n_classes)
    }
    
def compute_class_entropy(winmap, y, n_classes):
    """Shannon entropy (natural log) of each class's distribution over SOM nodes.

    For class ``cls``, the probabilities are the fraction of that class's
    samples falling in each node that holds at least one of them. Higher
    entropy means the class is spread over more nodes. Classes with no samples
    get 0.0.
    """
    entropies = {}
    for cls in range(n_classes):
        hits_per_node = [sum(1 for i in members if y[i] == cls) for members in winmap.values()]
        occupied = [hits for hits in hits_per_node if hits > 0]
        total = sum(occupied)
        if not total:
            entropies[cls] = 0.0
            continue
        shares = (hits / total for hits in occupied)
        entropies[cls] = -sum(p * log(p) for p in shares if p > 0)
    return entropies
    
def compute_topographic_class_purity(winmap, y, n_classes):
    """Average share of a class within the SOM nodes where it occurs.

    For each class id in ``range(n_classes)``: for every node that contains at
    least one sample of the class, take (samples of the class in the node) /
    (all samples in the node), then average over those nodes. Classes that
    occur in no node get 0.0. Values near 1 mean the class rarely shares nodes
    with other classes.

    Note: this is the class's own share, not the classic node purity
    (share of each node's majority class).
    """
    class_purity = {}
    for cls in range(n_classes):
        shares = []
        for indices in winmap.values():
            labels = [y[i] for i in indices]
            hits = labels.count(cls)
            if hits:
                shares.append(hits / len(labels))
        class_purity[cls] = sum(shares) / len(shares) if shares else 0.0
    return class_purity

def check_node_activation_by_class(winmap, y):
    """Print, for each class label found in ``winmap``, how many SOM nodes contain it.

    Labels are printed in sorted order; nothing is returned.
    """
    nodes_per_class = {}
    for node, members in winmap.items():
        for label in {y[i] for i in members}:
            nodes_per_class.setdefault(label, set()).add(node)

    for label in sorted(nodes_per_class):
        print(f"Class {label} appears in {len(nodes_per_class[label])} SOM nodes")

def compute_weighted_coverage(winmap, y, n_classes, method='effective'):
    """
    Compute weighted coverage of classes over SOM nodes.

    For each class, p is the fraction of its samples in each node that holds
    at least one of them.

    method:
      - 'effective': exp(entropy of p) (effective number of nodes)
      - 'simpson': 1 / sum(p**2) (inverse Simpson index)

    Labels outside ``range(n_classes)`` are ignored, as in the other
    compute_* helpers.

    Returns:
      dict[class] -> weighted coverage score (higher = more spread out);
      0.0 for classes with no samples.

    Raises:
      ValueError if ``method`` is not 'effective' or 'simpson' (checked up front,
      even when there is nothing to count).
    """
    if method not in ('effective', 'simpson'):
        raise ValueError("Method must be 'effective' or 'simpson'")

    # Count samples of each class in each node
    class_node_counts = {cls: defaultdict(int) for cls in range(n_classes)}
    for bmu, indices in winmap.items():
        for i in indices:
            node_counts = class_node_counts.get(y[i])
            if node_counts is not None:
                node_counts[bmu] += 1

    coverage = {}
    for cls, node_counts in class_node_counts.items():
        counts = np.array(list(node_counts.values()), dtype=float)
        if counts.size == 0:
            coverage[cls] = 0.0
            continue
        p = counts / counts.sum()  # proportion per node; strictly > 0, so log is safe

        if method == 'effective':
            coverage[cls] = np.exp(-np.sum(p * np.log(p)))
        else:
            coverage[cls] = 1 / np.sum(p ** 2)
    return coverage

In [4]:
# Daily Z500 anomalies (ERA5). NOTE(review): absolute, site-specific path.
Z500_PATH = '/glade/work/molina/DATA/Z500Anoms_ERA5.nc'
dataset = xr.open_dataarray(Z500_PATH)

# Spatial domain 20N-80N, 180W-30W (longitudes expressed on a 0-360 axis by these bounds)
latSlice = slice(20, 80)
lonSlice = slice(180, 330)

# Flatten each map: (time, lat, lon) -> (time, latlon)
dataarray = (
    dataset
    .sel(lat=latSlice, lon=lonSlice)
    .stack(latlon=['lat', 'lon'])
    .values
)

# Seasonal breakdown: boolean masks over time and the matching integer day indices
day_month = dataset.time.dt.month
DJF = day_month.isin([12, 1, 2])
MAM = day_month.isin([3, 4, 5])
JJA = day_month.isin([6, 7, 8])
SON = day_month.isin([9, 10, 11])
DJF_idxs = np.flatnonzero(np.array(DJF))
MAM_idxs = np.flatnonzero(np.array(MAM))
JJA_idxs = np.flatnonzero(np.array(JJA))
SON_idxs = np.flatnonzero(np.array(SON))

print(dataarray.shape)

(30660, 9211)


In [5]:
# Load the trained SOM. NOTE: load_som unpickles the file — only open trusted files.
som = load_som('SOM40.p')

n = som._num
xy = hexMinisom.xy_using_n(n)

# Active (unmasked) nodes as (row, col) tuples
mask = som._mask
node_indices_xy = np.ma.where(mask == False)
node_indices = list(zip(node_indices_xy[0], node_indices_xy[1]))
# NOTE: one-shot iterator; it is exhausted after the first loop over it
all_nodes = product(range(xy), range(xy))

inputLength = dataarray.shape[1]

# Day indices won by each node: {node: [day index, ...]}
winmap = som.win_map(dataarray, return_indices=True)


def _restrict_winmap(full_winmap, day_idxs):
    """Copy of ``full_winmap`` keeping only day indices that appear in ``day_idxs``."""
    # A set gives O(1) membership; `i in ndarray` scans the whole array for every day
    keep = set(np.asarray(day_idxs).tolist())
    return {node: [i for i in days if i in keep] for node, days in full_winmap.items()}


# Seasonal breakdown for winmap
DJF_winmap = _restrict_winmap(winmap, DJF_idxs)
MAM_winmap = _restrict_winmap(winmap, MAM_idxs)
JJA_winmap = _restrict_winmap(winmap, JJA_idxs)
SON_winmap = _restrict_winmap(winmap, SON_idxs)

# Symmetric colour limits for plotting node weights
w = som._weights
maximum_weight = np.max(np.abs(w))
minimum_weight = -maximum_weight

# Number the active nodes 1..N, starting from the last mask row and moving left to right.
# A separate counter keeps `n` equal to som._num instead of being overwritten.
node_nums = {}
node_num = 1
for i in range(mask.shape[0])[::-1]:
    for j in range(mask.shape[1]):
        # only use non masked nodes
        if som._mask[i, j] == 0:
            node_nums[(i, j)] = node_num
            node_num += 1

color_list = generate_distinct_colors(len(node_indices))

In [6]:
# Import the regime labels
# Import the regime labels (one row per day; assumed aligned with the rows of dataarray)
WR_labels_df = pd.read_csv('df_labels_nocorrfilt_ERA5.csv').rename(columns={'Unnamed: 0': 'date'})
WR_labels_df['date'] = pd.to_datetime(WR_labels_df['date'], format='%Y-%m-%d')
WR_labels_dict = {
    0: 'Polar High', 1: 'Pacific Trough', 2: 'Pacific Ridge', 
    3: 'Alaskan Ridge', 4: 'Atlantic Ridge', 5: 'No WR'
}
WR_labels = np.array(WR_labels_df['WR'])

WRs_by_node = {k: get_WR_counts(v) for k, v in winmap.items()}
WRs_percents = {k: get_WR_counts(v, return_percents=True) for k, v in winmap.items()}

print(WR_labels_df)

# Keep only days whose distance is below the 90th percentile.
# Index sets are passed to get_WR_counts so its membership test is O(1).
percentile90 = np.percentile(WR_labels_df['distances'], 90)
lt90 = (np.array(WR_labels_df['distances']) < percentile90).nonzero()[0]
lt90_days = set(lt90.tolist())
WRs_lt90 = {k: get_WR_counts(v, True, lt90_days) for k, v in winmap.items()}

# Variance of the distances within each WR, kept per regime
WR_indices = {i: (WR_labels == i).nonzero()[0] for i in np.unique(WR_labels)}
WR_distance_variances = {
    WR: np.var(WR_labels_df['distances'].iloc[idxs]) for WR, idxs in WR_indices.items()
}

# Get the WR counts for each specific season
season_days = {
    'DJF': set(DJF_idxs.tolist()),
    'MAM': set(MAM_idxs.tolist()),
    'JJA': set(JJA_idxs.tolist()),
    'SON': set(SON_idxs.tolist()),
}
WRs_DJF = {k: get_WR_counts(v, True, season_days['DJF']) for k, v in winmap.items()}
WRs_MAM = {k: get_WR_counts(v, True, season_days['MAM']) for k, v in winmap.items()}
WRs_JJA = {k: get_WR_counts(v, True, season_days['JJA']) for k, v in winmap.items()}
WRs_SON = {k: get_WR_counts(v, True, season_days['SON']) for k, v in winmap.items()}

            date  WR  distances      corr
0     1940-01-01   0   2.463938  0.518457
1     1940-01-02   0   2.662645  0.565398
2     1940-01-03   0   2.916932  0.552532
3     1940-01-04   0   3.122750  0.495652
4     1940-01-05   0   3.302769  0.394692
...          ...  ..        ...       ...
30655 2023-12-27   1   2.823370  0.857459
30656 2023-12-28   1   2.687266  0.837698
30657 2023-12-29   1   2.394127  0.781903
30658 2023-12-30   1   2.253624  0.655454
30659 2023-12-31   0   2.468174  0.359541

[30660 rows x 4 columns]


In [7]:
# Regime name for every day
names_to_fill = [WR_labels_dict[label] for label in WR_labels_df['WR'].values]

# Node number (from node_nums) for every day; -9999 remains for days won by no active node
array_to_fill = np.full(len(WR_labels_df), -9999, dtype=int)
for node in node_indices:
    for day in winmap[node]:
        array_to_fill[day] = node_nums[node]

WR_labels_df['node'] = array_to_fill
WR_labels_df['WR_name'] = names_to_fill

# Label columns followed by one column per grid point
diversity_df = pd.concat([WR_labels_df, pd.DataFrame(dataarray)], axis=1)

# SOM input features: every column after the 6 label columns
# (date, WR, distances, corr, node, WR_name)
X = diversity_df.iloc[:, 6:]

# Class labels: the WR id of each day
y = diversity_df['WR'].values 

In [8]:
# Regime-spread metrics over the SOM using all days.
# Total samples per regime:
print(np.unique(y, return_counts=True))
print(len(y))

tmp_map = winmap
n_regimes = len(WR_labels_dict)

spread_scores = compute_class_spread(tmp_map, y, n_classes=n_regimes)
coverage_scores = compute_class_coverage(tmp_map, y, n_classes=n_regimes)
entropy_scores = compute_class_entropy(tmp_map, y, n_classes=n_regimes)
purity_scores = compute_topographic_class_purity(tmp_map, y, n_classes=n_regimes)
wcoverage_effective = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='effective')
wcoverage_simpson = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='simpson')

# Coverage = number of nodes a regime occupies; Simpson = inverse Simpson index
# (effective number of nodes). The other scores above remain available for inspection.
# Single-line f-string: multi-line replacement fields need Python >= 3.12 (PEP 701).
for cls in range(n_regimes):
    print(f"Coverage = {coverage_scores[cls]}, Simpson = {wcoverage_simpson[cls]:.2f}: {WR_labels_dict[cls]}")

(array([0, 1, 2, 3, 4, 5]), array([5217, 6164, 5943, 4279, 5821, 3236]))
30660
Coverage = 36, Simpson = 10.57: Polar High
Coverage = 37, Simpson = 21.48: Pacific Trough
Coverage = 37, Simpson = 13.22: Pacific Ridge
Coverage = 37, Simpson = 12.38: Alaskan Ridge
Coverage = 37, Simpson = 30.28: Atlantic Ridge
Coverage = 37, Simpson = 30.41: No WR


In [9]:
# Regime-spread metrics over the SOM using DJF days only.
# DJF_winmap keeps global day indices, so the full `y` is passed.
# Total samples per regime in DJF:
print(np.unique(y[DJF_idxs], return_counts=True))
print(len(y[DJF_idxs]))

tmp_map = DJF_winmap
n_regimes = len(WR_labels_dict)

spread_scores = compute_class_spread(tmp_map, y, n_classes=n_regimes)
coverage_scores = compute_class_coverage(tmp_map, y, n_classes=n_regimes)
entropy_scores = compute_class_entropy(tmp_map, y, n_classes=n_regimes)
purity_scores = compute_topographic_class_purity(tmp_map, y, n_classes=n_regimes)
wcoverage_effective = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='effective')
wcoverage_simpson = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='simpson')

# Coverage = number of nodes a regime occupies; Simpson = inverse Simpson index.
# Single-line f-string: multi-line replacement fields need Python >= 3.12 (PEP 701).
for cls in range(n_regimes):
    print(f"Coverage = {coverage_scores[cls]}, Simpson = {wcoverage_simpson[cls]:.2f}: {WR_labels_dict[cls]}")

(array([0, 1, 2, 3, 4, 5]), array([1186, 1636, 1748,  889, 1352,  749]))
7560
Coverage = 31, Simpson = 8.60: Polar High
Coverage = 37, Simpson = 18.06: Pacific Trough
Coverage = 32, Simpson = 11.53: Pacific Ridge
Coverage = 34, Simpson = 8.69: Alaskan Ridge
Coverage = 37, Simpson = 27.84: Atlantic Ridge
Coverage = 37, Simpson = 26.17: No WR


In [10]:
# Regime-spread metrics over the SOM using MAM days only.
# MAM_winmap keeps global day indices, so the full `y` is passed.
# Total samples per regime in MAM:
print(np.unique(y[MAM_idxs], return_counts=True))
print(len(y[MAM_idxs]))

tmp_map = MAM_winmap
n_regimes = len(WR_labels_dict)

spread_scores = compute_class_spread(tmp_map, y, n_classes=n_regimes)
coverage_scores = compute_class_coverage(tmp_map, y, n_classes=n_regimes)
entropy_scores = compute_class_entropy(tmp_map, y, n_classes=n_regimes)
purity_scores = compute_topographic_class_purity(tmp_map, y, n_classes=n_regimes)
wcoverage_effective = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='effective')
wcoverage_simpson = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='simpson')

# Coverage = number of nodes a regime occupies; Simpson = inverse Simpson index.
# Single-line f-string: multi-line replacement fields need Python >= 3.12 (PEP 701).
for cls in range(n_regimes):
    print(f"Coverage = {coverage_scores[cls]}, Simpson = {wcoverage_simpson[cls]:.2f}: {WR_labels_dict[cls]}")

(array([0, 1, 2, 3, 4, 5]), array([1386, 1553, 1538,  999, 1485,  767]))
7728
Coverage = 32, Simpson = 9.11: Polar High
Coverage = 37, Simpson = 23.05: Pacific Trough
Coverage = 36, Simpson = 11.98: Pacific Ridge
Coverage = 37, Simpson = 11.01: Alaskan Ridge
Coverage = 37, Simpson = 26.55: Atlantic Ridge
Coverage = 36, Simpson = 29.22: No WR


In [11]:
# Regime-spread metrics over the SOM using JJA days only.
# JJA_winmap keeps global day indices, so the full `y` is passed.
# Total samples per regime in JJA:
print(np.unique(y[JJA_idxs], return_counts=True))
print(len(y[JJA_idxs]))

tmp_map = JJA_winmap
n_regimes = len(WR_labels_dict)

spread_scores = compute_class_spread(tmp_map, y, n_classes=n_regimes)
coverage_scores = compute_class_coverage(tmp_map, y, n_classes=n_regimes)
entropy_scores = compute_class_entropy(tmp_map, y, n_classes=n_regimes)
purity_scores = compute_topographic_class_purity(tmp_map, y, n_classes=n_regimes)
wcoverage_effective = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='effective')
wcoverage_simpson = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='simpson')

# Coverage = number of nodes a regime occupies; Simpson = inverse Simpson index.
# Single-line f-string: multi-line replacement fields need Python >= 3.12 (PEP 701).
for cls in range(n_regimes):
    print(f"Coverage = {coverage_scores[cls]}, Simpson = {wcoverage_simpson[cls]:.2f}: {WR_labels_dict[cls]}")

(array([0, 1, 2, 3, 4, 5]), array([1647, 1335, 1060, 1410, 1406,  870]))
7728
Coverage = 36, Simpson = 11.24: Polar High
Coverage = 37, Simpson = 23.61: Pacific Trough
Coverage = 37, Simpson = 16.04: Pacific Ridge
Coverage = 37, Simpson = 16.08: Alaskan Ridge
Coverage = 37, Simpson = 28.24: Atlantic Ridge
Coverage = 37, Simpson = 30.51: No WR


In [12]:
# Regime-spread metrics over the SOM using SON days only.
# SON_winmap keeps global day indices, so the full `y` is passed.
# Total samples per regime in SON:
print(np.unique(y[SON_idxs], return_counts=True))
print(len(y[SON_idxs]))

tmp_map = SON_winmap
n_regimes = len(WR_labels_dict)

spread_scores = compute_class_spread(tmp_map, y, n_classes=n_regimes)
coverage_scores = compute_class_coverage(tmp_map, y, n_classes=n_regimes)
entropy_scores = compute_class_entropy(tmp_map, y, n_classes=n_regimes)
purity_scores = compute_topographic_class_purity(tmp_map, y, n_classes=n_regimes)
wcoverage_effective = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='effective')
wcoverage_simpson = compute_weighted_coverage(tmp_map, y, n_classes=n_regimes, method='simpson')

# Coverage = number of nodes a regime occupies; Simpson = inverse Simpson index.
# Single-line f-string: multi-line replacement fields need Python >= 3.12 (PEP 701).
for cls in range(n_regimes):
    print(f"Coverage = {coverage_scores[cls]}, Simpson = {wcoverage_simpson[cls]:.2f}: {WR_labels_dict[cls]}")

(array([0, 1, 2, 3, 4, 5]), array([ 998, 1640, 1597,  981, 1578,  850]))
7644
Coverage = 29, Simpson = 12.40: Polar High
Coverage = 37, Simpson = 19.93: Pacific Trough
Coverage = 35, Simpson = 13.01: Pacific Ridge
Coverage = 37, Simpson = 10.30: Alaskan Ridge
Coverage = 37, Simpson = 30.00: Atlantic Ridge
Coverage = 37, Simpson = 29.01: No WR
