# Clustering analysis

1. Load distance data
2. Keep all samples
    - Remove outliers (`k=8, N=10`)
    - Dendrogram
    - Large partition (`k=4`)
        - Map
    - Analyse tropical subtree (`k=6`, `N=10`)
        - Fine-grained partition map (`k=7`)
    - Depth graph
3. Filter `<25m depth` samples, remove outliers
    - Dendrogram
4. Large partitioning `k=4`
    - Map
5. Fine-grained partitioning `k=10`
    - Map
6. Export cluster tables

In [None]:
import os
import sys
import itertools

import pandas as pd
import numpy as np
import seaborn as sns

import matplotlib as mpl
from matplotlib import pyplot as plt

from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import linkage, dendrogram, cut_tree, set_link_color_palette
from scipy import stats

sys.path.insert(0, "/local/path/to/scripts/")
from plotting_utils import plot_colored_markers, palettes

# Append a TeX Live bin directory to PATH; text.usetex is switched on below,
# presumably relying on these binaries (confirm the path on your machine).
os.environ['PATH'] += ':/apps/easybuild-2022/easybuild/software/Compiler/GCC/11.3.0/texlive/20230313/bin/x86_64-linux/'
plt.rcParams.update({
    'font.family': 'sans-serif',
    'text.usetex': True,
})


def get_cmap_ix(ix, cmap):
    """Return the hex string of entry ``ix`` in the colour list of colormap ``cmap``.

    Reads the colormap's ``.colors`` attribute, so the colormap must expose one.
    """
    return mpl.colors.rgb2hex(plt.get_cmap(cmap).colors[ix])


# 1. Load distance data

In [None]:
# Distance table between samples; the first CSV column becomes the row index.
df = pd.read_csv("/local/path/to/data/distances/sourmash.csv", index_col=0)
# Sample metadata, indexed by its first CSV column.
md = pd.read_csv("~/biogo-hub/provinces_final/data/metadata_2132.csv", index_col=0)

# Select rows and columns of the distance table in metadata order, so both share one sample ordering.
df = df.loc[md.index, md.index]

In [None]:
md["coords"].drop_duplicates()

# 2. Keep all samples
## Remove outliers

In [None]:
# Average-linkage tree over all samples, built from the square distance table.
X = squareform(df.values)
Z = linkage(X, method='average')
print_clusters = False

# Use this to determine k
if print_clusters:
    for k in range(5, 11):
        K = cut_tree(Z, n_clusters=k)
        labels = pd.DataFrame(K, index=df.index)[0]
        # Size of every cluster at this k, computed once and reused below.
        cluster_sizes = labels.value_counts()
        print(f"k = {k}")
        k_gt_ten = int((cluster_sizes > 10).sum())
        print(f"Groups larger than 10: {k_gt_ten}.")
        print(cluster_sizes)
        print()
else:
    # Value used when the exploration loop is skipped.
    k_gt_ten = 4

In [None]:
# Cut the tree into k clusters and relabel every cluster with fewer than N samples as 99.
k, N = 8, 10
K = cut_tree(Z, n_clusters=k)
cut_value = Z[-(k - 1), 2]
labels = pd.DataFrame(K, index=df.index)[0]
label_name = f"sourmash_k_{k_gt_ten}_{len(labels)}"
md[label_name] = labels

# Cluster sizes before any relabelling.
sizes_before = md[label_name].value_counts()
outlier_provs = sizes_before[sizes_before < N].index
outlier_samples = md[md[label_name].isin(outlier_provs)]
md.loc[outlier_samples.index, label_name] = 99

# Sizes are recounted after relabelling.
# NOTE(review): label 99 itself is kept as "robust" if at least N samples were relabelled — confirm this is intended.
sizes_after = md[label_name].value_counts()
robust_provs = sizes_after[sizes_after >= N].index

In [None]:
outlier_samples

## Dendrogram

In [None]:
def plot_dendrogram_all_samples(label_name):
    """Draw the full, untruncated dendrogram with one labelled leaf per sample.

    Reads the module-level ``Z``, ``cut_value`` and ``md`` at call time, so all
    three must describe the same set of samples when this is called.

    Parameters
    ----------
    label_name : str
        Column of ``md`` holding the cluster id. Each leaf is labelled
        ``B<cluster id>_<sample id>``.
    """
    fig, ax = plt.subplots(figsize=(20, 240))

    # scipy looks leaf labels up by integer position. A Series indexed by
    # sample id does not support that reliably, so pass a plain list.
    leaf_labels = ("B" + md[label_name].astype(str) + "_" + md.index).tolist()

    dendrogram(Z,
               color_threshold=cut_value,
               labels=leaf_labels,
               orientation='left',
               leaf_font_size=8,
               ax=ax)
    ax.set_xlim(1, 0.75)

# plot_dendrogram_all_samples(label_name)

In [None]:
fig, ax = plt.subplots(figsize=(4, 10))

# Truncated view of the tree in Z: truncate_mode="lastp" with p=7 collapses it to 7 leaves,
# and show_leaf_counts annotates each collapsed leaf with its sample count.
D = dendrogram(Z,
               color_threshold=0.992,
               orientation='left',
               leaf_font_size=12,
               ax=ax,
               show_leaf_counts=True,
               truncate_mode="lastp",
               p=7)
# Number of drawn links per colour.
clusters = pd.Series(D["color_list"]).value_counts()
# Restrict the distance axis to 1 .. 0.99 (limits given high-to-low), with ticks every 0.002.
_ = ax.set_xlim(1, 0.99)
_ = ax.set_xticks(np.arange(1, 0.99, -0.002))

## Large partitioning map

In [None]:
md_robust = md[md[label_name].isin(robust_provs)]
plot_colored_markers(md_robust, color_category=label_name, jitter=1, cmap="Dark2")

## Analyse subtree

In [None]:
md

In [None]:
md_subtree = md[md[label_name] == 3]
len(md_subtree) / 2129

In [None]:
# Re-cluster only the samples of one top-level cluster and give them finer labels.
subtree = 3
# Explicit copy: md_subtree is written to below and must be independent of md.
md_subtree = md[md[label_name] == subtree].copy()
subtree = md_subtree.index  # from here on: the sample ids of that cluster

X = df.loc[subtree, subtree]
X = squareform(X.values)
Z = linkage(X, method='average')

if print_clusters:
    for k in range(5, 21):
        K = cut_tree(Z, n_clusters=k)
        labels = pd.DataFrame(K, index=subtree)[0]
        # Size of every sub-cluster at this k, computed once and reused below.
        cluster_sizes = labels.value_counts()
        print(f"k = {k}")
        k_gt_ten = int((cluster_sizes > N).sum())
        print(f"Groups larger than {N}: {k_gt_ten}.")
        print(cluster_sizes)
        print()
else:
    k_gt_ten = 3

k = 6
K = cut_tree(Z, n_clusters=k)
cut_value = Z[-(k - 1), 2]
labels = pd.DataFrame(K, index=subtree)[0]

# Identify labels from subtree
# Sub-cluster ids are shifted by 20, presumably to keep them apart from the top-level ids.
labels = labels + 20
md_subtree.loc[labels.index, label_name] = labels
subtree_sizes = md_subtree[label_name].value_counts()
outlier_provs = subtree_sizes[subtree_sizes < N].index
outlier_samples = md_subtree[md_subtree[label_name].isin(outlier_provs)]
md_subtree.loc[outlier_samples.index, label_name] = 99

# Join subtree with others
# md receives the shifted labels as-is; small sub-clusters are excluded by the >= N filter below.
md.loc[labels.index, label_name] = labels
merged_sizes = md[label_name].value_counts()
robust_provs = merged_sizes[merged_sizes >= N].index

In [None]:
# Use the active label column rather than a hard-coded name, so this keeps working if the name changes.
md_subtree[label_name].value_counts()

### Map

In [None]:
# Subtree only
# plot_colored_markers(md_subtree[md_subtree[label_name].isin(robust_provs)], color_category=label_name, jitter=1, cmap="Dark2")

md_robust = md[md[label_name].isin(robust_provs)]

# All samples
plot_colored_markers(md_robust, color_category=label_name, jitter=1, cmap="Dark2")

In [None]:
# Use the active label column rather than a hard-coded name, so this keeps working if the name changes.
md_robust[label_name].value_counts()

In [None]:
(1203 + 39 + 237) - 1482

### Depth plots

In [None]:
plot_data = md_robust

In [None]:
sns.boxplot(data = plot_data, x=label_name, y="depth", showfliers=False, color="lightgray", width=0.5)
sns.stripplot(data = plot_data, x=label_name, y="depth", jitter=0.15, alpha=0.5)

In [None]:
# Pair each distinct cluster id in plot_data (cast to int) with a Dark2 hex colour.
# NOTE(review): zip() stops at the shorter input, so cluster ids beyond the palette length get no colour — confirm the counts fit.
boxenplot_colors = dict(zip(plot_data[label_name].astype(int).unique(), sns.color_palette("Dark2").as_hex()))

In [None]:
sns.boxenplot(data = plot_data, x=label_name, y="depth", showfliers=False, color="gray", width=0.5, palette=boxenplot_colors, hue=label_name, legend=False)

In [None]:
# Tag each sample "MLD" when MixedLayerDepth <= depth, otherwise "SUR".
# Vectorised comparison: rows with a missing value compare False and get "SUR", as with the row-wise version.
md["depth_cat"] = np.where(md["MixedLayerDepth"] <= md["depth"], "MLD", "SUR")
# Stacked bar chart of the share of each depth category within every robust cluster.
md[md[label_name].isin(robust_provs)].groupby(label_name)["depth_cat"].value_counts(normalize=True).unstack().plot(kind="bar", stacked=True)

In [None]:
# Export
# Rename the active label column (not a hard-coded name) so the export still works if label_name changes.
md.rename(columns={label_name: "sourmash_k_7_2132"}).to_csv("~/biogo-hub/provinces_final/data/sourmash_k_7_2132.csv")

# 3. Filter samples to 25m depth

In [None]:
len(md)

In [None]:
# Keep samples whose depth value is below 25 and discard the label column of the all-sample run.
is_shallow = md["depth"] < 25
md = md.loc[is_shallow].drop(columns=label_name)
# Restrict the distance table to the remaining samples.
df = df.loc[md.index, md.index]

In [None]:
len(df)

## Remove outliers

In [None]:
# Average-linkage tree over the depth-filtered samples, built from the square distance table.
X = squareform(df.values)
Z = linkage(X, method='average')
print_clusters = False

# Use this to determine k
if print_clusters:
    for k in range(5, 11):
        K = cut_tree(Z, n_clusters=k)
        labels = pd.DataFrame(K, index=df.index)[0]
        # Size of every cluster at this k, computed once and reused below.
        cluster_sizes = labels.value_counts()
        print(f"k = {k}")
        k_gt_ten = int((cluster_sizes > 10).sum())
        print(f"Groups larger than 10: {k_gt_ten}.")
        print(cluster_sizes)
        print()
else:
    # Value used when the exploration loop is skipped.
    k_gt_ten = 4

In [None]:
# Cut the tree into k clusters and relabel every cluster with fewer than N samples as 99.
k, N = 6, 10
K = cut_tree(Z, n_clusters=k)
cut_value = Z[-(k - 1), 2]
labels = pd.DataFrame(K, index=df.index)[0]
# NOTE(review): the name is built from k_gt_ten + 2 rather than k — confirm this is intended.
label_name = f"sourmash_k_{k_gt_ten + 2}_{len(labels)}_25m"
md[label_name] = labels

# Cluster sizes before any relabelling.
sizes_before = md[label_name].value_counts()
outlier_provs = sizes_before[sizes_before < N].index
outlier_samples = md[md[label_name].isin(outlier_provs)]
md.loc[outlier_samples.index, label_name] = 99

# Sizes are recounted after relabelling.
# NOTE(review): label 99 itself is kept as "robust" if at least N samples were relabelled — confirm this is intended.
sizes_after = md[label_name].value_counts()
robust_provs = sizes_after[sizes_after >= N].index
is_robust = md[label_name].isin(robust_provs)
robust_samples = md[is_robust].index

In [None]:
print(len(robust_samples))
print(len(outlier_samples))

## Dendrogram

In [None]:
# plot_dendrogram_all_samples(label_name)

In [None]:
def leaf_label_func(id):
    """Return the label ``"<name> (<sample count>)"`` for a collapsed-dendrogram leaf.

    ``id`` is the node id handed over by ``dendrogram``. Only the four node ids
    listed below are known; any other id raises ``KeyError``. The sample count
    is read from the module-level ``md``, ``robust_samples`` and ``label_name``
    at call time.

    NOTE(review): the node ids and cluster labels are hard-coded and presumably
    only valid for one specific clustering run — re-check them if the input changes.
    """
    # node id -> (display name, cluster label in md[label_name])
    leaf_info = {
        3013: ("Temperate", 1),
        3005: ("Polar", 0),
        3015: ("Tropical", 2),
        2998: ("Baltic Sea", 5),
    }
    name, cluster = leaf_info[id]
    # Count once, and only for the requested leaf, so a missing label on another
    # leaf cannot break this one.
    n_samples = md.loc[robust_samples][label_name].value_counts().loc[cluster]
    return name + f" ({n_samples})"

### Collapsed dendrogram (k=4)

In [None]:
# Rebuild the average-linkage tree from the robust samples only.
X = squareform(df.loc[robust_samples, robust_samples].values)
Z = linkage(X, method='average')

fig, ax = plt.subplots(figsize=(4, 10))

# Truncated view: truncate_mode="lastp" with p=4 collapses the tree to 4 leaves,
# each labelled through leaf_label_func.
# NOTE(review): leaf_label_func only knows fixed node ids — confirm they match this tree.
D = dendrogram(Z,
               color_threshold=0.992,
               orientation='left',
               leaf_font_size=12,
               ax=ax,
               show_leaf_counts=True,
               leaf_label_func=leaf_label_func,
               truncate_mode="lastp",
               p=4)
# Number of drawn links per colour.
clusters = pd.Series(D["color_list"]).value_counts()
# Restrict the distance axis to 1 .. 0.99 (limits given high-to-low), with ticks every 0.002.
_ = ax.set_xlim(1, 0.99)
_ = ax.set_xticks(np.arange(1, 0.99, -0.002))

## Large partitioning map

In [None]:
plot_colored_markers(md[md[label_name].isin(robust_provs)], color_category=label_name, jitter=1, cmap="Dark2")

In [None]:
md[md[label_name].isin(robust_provs)].to_csv(f"~/biogo-hub/provinces_final/data/{label_name}.csv")

## Analyse subtree

In [None]:
len(subtree)

In [None]:
# Re-cluster only the samples of one top-level cluster and give them finer labels.
subtree = 2
# Explicit copy: md_subtree is written to below and must be independent of md.
md_subtree = md[md[label_name] == subtree].copy()
subtree = md_subtree.index  # from here on: the sample ids of that cluster

X = df.loc[subtree, subtree]
X = squareform(X.values)
Z = linkage(X, method='average')

if print_clusters:
    for k in range(5, 21):
        K = cut_tree(Z, n_clusters=k)
        labels = pd.DataFrame(K, index=subtree)[0]
        # Size of every sub-cluster at this k, computed once and reused below.
        cluster_sizes = labels.value_counts()
        print(f"k = {k}")
        k_gt_ten = int((cluster_sizes > N).sum())
        print(f"Groups larger than {N}: {k_gt_ten}.")
        print(cluster_sizes)
        print()
else:
    k_gt_ten = 3

k = 14
K = cut_tree(Z, n_clusters=k)
cut_value = Z[-(k - 1), 2]
labels = pd.DataFrame(K, index=subtree)[0]

# Identify labels from subtree
# Sub-cluster ids are shifted by 20, presumably to keep them apart from the top-level ids.
labels = labels + 20
md_subtree.loc[labels.index, label_name] = labels
subtree_sizes = md_subtree[label_name].value_counts()
outlier_provs = subtree_sizes[subtree_sizes < N].index
outlier_samples = md_subtree[md_subtree[label_name].isin(outlier_provs)]
md_subtree.loc[outlier_samples.index, label_name] = 99

# Join subtree with others
# md receives the shifted labels as-is; small sub-clusters are excluded by the >= N filter below.
md.loc[labels.index, label_name] = labels
merged_sizes = md[label_name].value_counts()
robust_provs = merged_sizes[merged_sizes >= N].index
robust_samples = md[md[label_name].isin(robust_provs)].index

### Map

In [None]:
print(len(robust_samples))
print(len(outlier_samples))

In [None]:
# # Subtree only
# plot_colored_markers(md_subtree[md_subtree[label_name].isin(robust_provs)], color_category=label_name, jitter=1, cmap="Dark2")

# All samples
plot_colored_markers(md.loc[robust_samples], color_category=label_name, jitter=1, cmap="Dark2")

In [None]:
len(md.loc[robust_samples])

## Finer-grained partition

### Collapsed dendrogram (k=10)

In [None]:
# Re-cluster the samples that passed the previous robustness filter.
X = squareform(df.loc[robust_samples, robust_samples].to_numpy())
Z = linkage(X, method='average')

k, N = 18, 20
K = cut_tree(Z, n_clusters=k)
cut_value = Z[-(k - 1), 2]
labels = pd.DataFrame(K, index=robust_samples)[0]

# Keep clusters with strictly more than N members.
# NOTE(review): this uses `> N`, whereas the earlier cells use `>= N` — confirm the difference is intended.
cluster_sizes = labels.value_counts()
robust_provs = cluster_sizes[cluster_sizes > N].index
robust_samples = labels[labels.isin(robust_provs)].index
k_gt_ten = len(robust_provs)
label_name = f"sourmash_k_{k_gt_ten}_{len(labels)}_25m"

# Default every sample to 99, then overwrite the kept samples with their cluster id.
md[label_name] = 99
md.loc[robust_samples, label_name] = labels.astype(int)

# Separate linkage over the kept samples only.
df_prime = df.loc[robust_samples, robust_samples]
X_prime = squareform(df_prime.to_numpy())
Z_prime = linkage(X_prime, method='average')

In [None]:
1487 - len(robust_samples)

In [None]:
get_cmap_ix = lambda ix, cmap: mpl.colors.rgb2hex(plt.get_cmap(cmap).colors[ix])

def leaf_label_func(id, color=False):
    """Describe one leaf of the collapsed dendrogram.

    ``id`` is a dendrogram node id and must be one of the keys below, otherwise
    ``KeyError`` is raised. With ``color=True`` the leaf's hex colour is
    returned; otherwise the text ``"B<id> - <label> - <counts>"``.
    """
    leaves = {
        2877: {"label": "BALT", "category": "BALT", "counts": 51, "color": get_cmap_ix(16, "tab20"), "id": 16},
        2862: {"label": "PEQD", "category": "TROP", "counts": 54, "color": get_cmap_ix(4, "tab20"), "id": 11},
        2896: {"label": "TCON", "category": "TROP", "counts": 161, "color": get_cmap_ix(6, "tab20"), "id": 5},
        2897: {"label": "TROP", "category": "TROP", "counts": 818, "color": get_cmap_ix(7, "tab20"), "id": 9},
        2880: {"label": "APLR", "category": "POLR", "counts": 30, "color": get_cmap_ix(0, "tab20"), "id": 0},
        2881: {"label": "BPLR", "category": "POLR", "counts": 42, "color": get_cmap_ix(3, "tab20"), "id": 14},
        2891: {"label": "UPWL", "category": "TEMP", "counts": 139, "color": get_cmap_ix(8, "tab20"), "id": 10},
        2893: {"label": "SSTC", "category": "TEMP", "counts": 43, "color": get_cmap_ix(14, "tab20"), "id": 2},
        2888: {"label": "NADR", "category": "TEMP", "counts": 34, "color": get_cmap_ix(15, "tab20"), "id": 3},
        2894: {"label": "MEDI", "category": "TEMP", "counts": 82, "color": get_cmap_ix(12, "tab20"), "id": 7},
    }

    leaf = leaves[id]
    if not color:
        return f"B{leaf['id']} - {leaf['label']} - {leaf['counts']}"
    return leaf["color"]


# Cluster id -> province metadata (description, short label, category, sample count, hex colour).
# NOTE(review): the tab20 indices for SSTC (15) and NADR (14) are the reverse of those in
# leaf_label_func (14 and 15) — confirm which assignment is intended.
pmetadata = {
    16:{"description": "Baltic Sea", "label": "BALT", "category": "BALT", "counts": 51, "color": get_cmap_ix(16, "tab20")},
    11:{"description": "Pacific Equatorial Divergence/Countercurrent", "label": "PEQD", "category": "TROP", "counts": 54, "color": get_cmap_ix(4, "tab20") },
    5:{"description": "Tropical Convergence", "label": "TCON", "category": "TROP", "counts": 161, "color": get_cmap_ix(6, "tab20") },
    9:{"description": "Broad Tropical", "label": "TROP", "category": "TROP", "counts": 818, "color": get_cmap_ix(7, "tab20") },
    0:{"description": "Antarctic Polar", "label": "APLR", "category": "POLR", "counts": 30, "color": get_cmap_ix(0, "tab20") },
    14:{"description": "Arctic Polar", "label": "BPLR", "category": "POLR", "counts": 42, "color": get_cmap_ix(3, "tab20") },
    10:{"description": "Upwelling Areas", "label": "UPWL", "category": "TEMP", "counts": 139, "color": get_cmap_ix(8, "tab20") },
    2:{"description": "S. Subtropical Convergence", "label": "SSTC", "category": "TEMP", "counts": 43, "color": get_cmap_ix(15, "tab20") },
    3:{"description": "North Atlantic Drift/Agulhas", "label": "NADR", "category": "TEMP", "counts": 34, "color": get_cmap_ix(14, "tab20") },
    # Description filled in to match the other MEDI entries in this notebook.
    7:{"description": "Mediterranean", "label": "MEDI", "category": "TEMP", "counts": 82, "color": get_cmap_ix(12, "tab20") }
}

In [None]:
# Re-key the metadata by the string form of each cluster id, ordered by those strings
# (string order, so "10" sorts before "2").
pmetadata_sorted = dict(sorted((str(cluster_id), meta) for cluster_id, meta in pmetadata.items()))

In [None]:
# Print the cluster ids, then their quoted hex colours, one per line.
# Plain loops: the printing is a side effect, so no throwaway list of None is built or displayed.
for key in pmetadata_sorted:
    print(key)
for meta in pmetadata_sorted.values():
    print("'" + meta["color"] + "'")

In [None]:
fig, ax = plt.subplots(figsize=(4, 10))

D = dendrogram(Z_prime,
               color_threshold=0.99,
               orientation='left',
               leaf_font_size=8,
               ax=ax,
               truncate_mode="lastp",
               p=7,
               get_leaves=True,
               # leaf_label_func=leaf_label_func,
               )
_ = ax.set_xlim(1, 0.97)

In [None]:
fig, ax = plt.subplots(figsize=(4, 10))

D = dendrogram(Z_prime,
               color_threshold=0.99,
               orientation='left',
               leaf_font_size=8,
               ax=ax,
               truncate_mode="lastp",
               p=8,
               get_leaves=True,
               # leaf_label_func=leaf_label_func,
               )
_ = ax.set_xlim(1, 0.97)

In [None]:
fig, ax = plt.subplots(figsize=(4, 10))

D = dendrogram(Z_prime,
               color_threshold=0.99,
               orientation='left',
               leaf_font_size=8,
               ax=ax,
               truncate_mode="lastp",
               p=9,
               get_leaves=True,
               # leaf_label_func=leaf_label_func,
               )
_ = ax.set_xlim(1, 0.97)

In [None]:
fig, ax = plt.subplots(figsize=(4, 10))

D = dendrogram(Z_prime,
               color_threshold=0.99,
               orientation='left',
               leaf_font_size=8,
               ax=ax,
               truncate_mode="lastp",
               p=10,
               get_leaves=True,
               leaf_label_func=leaf_label_func,
               )
_ = ax.set_xlim(1, 0.97)

In [None]:
9470 / 4792

In [None]:
fig, ax = plt.subplots(figsize=(16, 4))

D = dendrogram(Z_prime,
               color_threshold=0.99,
               orientation='top',
               leaf_font_size=8,
               ax=ax,
               truncate_mode="lastp",
               p=10,
               get_leaves=True,
               leaf_label_func=leaf_label_func,
               )
_ = ax.set_ylim(0.97, 1)

In [None]:
type(Z_prime)

In [None]:
# Z_prime.tofile("/local/path/to/data/clustering/z_collapsed_dend_k10.csv")

In [None]:
Z_prime_load = np.fromfile("/local/path/to/data/clustering/z_collapsed_dend_k10.csv")

In [None]:
# Reshape the flat array into 4 columns; -1 infers the row count instead of hard-coding 1453.
Z_prime_load.reshape(-1, 4)

In [None]:
!ls ~/biogo-hub/provinces_final/data/clustering

In [None]:
D

In [None]:
dir(D)

### Metadata
```python
        16:{"description": "Baltic Sea", "label": "BALT", "category": "BALT", "counts": 51, "color": get_cmap_ix(10, "tab20")},
        11:{"description": "Pacific Equatorial Divergence/Countercurrent", "label": "PEQD", "category": "TROP", "counts": 54, "color": get_cmap_ix(8, "tab20") },
        5:{"description": "Tropical Convergence", "label": "TCON", "category": "TROP", "counts": 161, "color": get_cmap_ix(1, "tab20") },
        9:{"description": "Broad Tropical", "label": "TROP", "category": "TROP", "counts": 818, "color": get_cmap_ix(0, "tab20") },
        0:{"description": "Antarctic Polar", "label": "APLR", "category": "POLR", "counts": 30, "color": get_cmap_ix(15, "tab20") },
        14:{"description": "Arctic Polar", "label": "BPLR", "category": "POLR", "counts": 42, "color": get_cmap_ix(14, "tab20") },
        10:{"description": "Upwelling Areas", "label": "UPWL", "category": "TEMP", "counts": 139, "color": get_cmap_ix(4, "tab20") },
        2:{"description": "S. Subtropical Convergence", "label": "SSTC", "category": "TEMP", "counts": 43, "color": get_cmap_ix(6, "tab20") },
        3:{"description": "North Atlantic Drift/Agulhas", "label": "NADR", "category": "TEMP", "counts": 34, "color": get_cmap_ix(7, "tab20") },
        7:{"description": "Mediterranean", "label": "MEDI", "category": "TEMP", "counts": 82, "color": get_cmap_ix(6, "tab20") }
```

In [None]:
def link_color_func(x):
    """Return the colour for dendrogram link id ``x``.

    Links 2904 and 2905 are black; the other known ids map to a tab20 entry.
    Any id not listed raises ``KeyError``.
    """
    if x in (2904, 2905):
        return "black"

    # link id -> index into the tab20 colour list
    tab20_index = {
        2898: 0,
        2902: 8,
        2901: 14,
        2899: 6,
        2900: 6,
        2903: 4,
        2906: 10,
    }
    return get_cmap_ix(tab20_index[x], "tab20")

In [None]:
# Install a custom link palette for this plot only. The palette is global scipy state,
# so it is reset in `finally` even if plotting raises.
set_link_color_palette([get_cmap_ix(0, "tab20"), get_cmap_ix(14, "tab20"), get_cmap_ix(6, "tab20")])
try:
    fig, ax = plt.subplots(figsize=(16, 4))

    D = dendrogram(Z_prime,
                   color_threshold=0.99,
                   orientation='top',
                   leaf_font_size=8,
                   ax=ax,
                   truncate_mode="lastp",
                   p=10,
                   get_leaves=True,
                   leaf_label_func=leaf_label_func,
                   link_color_func=link_color_func,
                   above_threshold_color="black"
                   )
    _ = ax.set_ylim(0.975, 1)
finally:
    # Passing None resets scipy to its default palette.
    set_link_color_palette(None)

In [None]:
D

### Map

In [None]:
plot_colored_markers(md[md[label_name].isin(robust_provs)], color_category=label_name, jitter=1, cmap="tab10")

In [None]:
md[md[label_name].isin(robust_provs)].to_csv(f"~/biogo-hub/provinces_final/data/{label_name}.csv")

In [None]:
robust_md_1454 = md[md[label_name].isin(robust_provs)].copy()

In [None]:
robust_md_1454["sourmash_k_8_joint_temp_1487_25m"] = robust_md_1454[label_name]

In [None]:
# Fold clusters 2 and 3 into cluster 7; every other value is left unchanged.
joint_col = robust_md_1454["sourmash_k_8_joint_temp_1487_25m"]
robust_md_1454["sourmash_k_8_joint_temp_1487_25m"] = joint_col.mask(joint_col.isin([2, 3]), 7)
# robust_md_1454["sourmash_k_8_joint_temp_1487_25m"] = robust_md_1454["sourmash_k_8_joint_temp_1487_25m"].apply(lambda x: 0 if x == 14 else x)

In [None]:
# Set entry 7 of the shared "k_10" palette to a merged "Temperate" record.
palettes["k_10"][7] = {'description': 'Temperate',
 'label': 'Temp',
 'category': 'TEMP',
 'counts': 159,
 'color': '#756BB2'}

get_cmap_ix_tab20 = lambda ix: mpl.colors.rgb2hex(plt.get_cmap("tab20").colors[ix])
# NOTE(review): named "tab20b" but reads the "tab20" colormap — confirm which is intended.
get_cmap_ix_tab20b = lambda ix: mpl.colors.rgb2hex(plt.get_cmap("tab20").colors[ix])

# Cluster id -> province metadata. This is the only palette definition kept: the earlier
# draft palettes in this cell were overwritten before being used.
# NOTE(review): clusters 2 and 7 both use tab20 index 6 — confirm the shared colour is intended.
palette = {
        16:{"description": "Baltic Sea", "label": "BALT", "category": "BALT", "counts": 51, "color": get_cmap_ix(10, "tab20")},
        11:{"description": "Pacific Equatorial Divergence/Countercurrent", "label": "PEQD", "category": "TROP", "counts": 54, "color": get_cmap_ix(8, "tab20") },
        5:{"description": "Tropical Convergence", "label": "TCON", "category": "TROP", "counts": 161, "color": get_cmap_ix(1, "tab20") },
        9:{"description": "Broad Tropical", "label": "TROP", "category": "TROP", "counts": 818, "color": get_cmap_ix(0, "tab20") },
        0:{"description": "Antarctic Polar", "label": "APLR", "category": "POLR", "counts": 30, "color": get_cmap_ix(15, "tab20") },
        14:{"description": "Arctic Polar", "label": "BPLR", "category": "POLR", "counts": 42, "color": get_cmap_ix(14, "tab20") },
        10:{"description": "Upwelling Areas", "label": "UPWL", "category": "TEMP", "counts": 139, "color": get_cmap_ix(4, "tab20") },
        2:{"description": "S. Subtropical Convergence", "label": "SSTC", "category": "TEMP", "counts": 43, "color": get_cmap_ix(6, "tab20") },
        3:{"description": "North Atlantic Drift/Agulhas", "label": "NADR", "category": "TEMP", "counts": 34, "color": get_cmap_ix(7, "tab20") },
        7:{"description": "Mediterranean", "label": "MEDI", "category": "TEMP", "counts": 82, "color": get_cmap_ix(6, "tab20") }
    }

# Flatten to cluster id -> hex colour.
palette = {k: v["color"] for k, v in palette.items()}

In [None]:
{k: palette[k] for k in sorted(palette.keys())}

In [None]:
plot_colored_markers(robust_md_1454.drop_duplicates("coords"), color_category="sourmash_k_8_joint_temp_1487_25m", jitter=1, cmap="tab10", palette=palette)

In [None]:
robust_md_1454.rename(columns={"sourmash_k_8_joint_temp_1487_25m": "sourmash_k_8_richter_polar_balt"}).to_csv("~/biogo-hub/provinces_final/data/sourmash_k_8_richter_polar_balt.csv")

In [None]:
robust_md_1454.to_csv(f"~/biogo-hub/provinces_final/data/sourmash_k_8_richter_upwl_balt_1454.csv")

In [None]:
# For every coordinate, count the distinct cluster labels among its samples
# (a missing label counts as its own value).
labels_per_coord = robust_md_1454.groupby('coords')['sourmash_k_10_1487_25m'].nunique(dropna=False)

# Keep only the coordinates whose samples carry more than one label.
result = labels_per_coord[labels_per_coord > 1]

# Plot drafts

In [None]:
import xarray as xr
import cartopy.crs as ccrs

In [None]:
ds = xr.open_dataset("/data/scratch/projects/punim1293/vini/data/bio-oracle/ph_baseline_2000_2018_depthsurf_xr.nc")

In [None]:
ds.to_dataframe()

In [None]:
ds["ph_max"].plot()

In [None]:
df = pd.read_csv("~/biogo-hub/data/models/model_data/sourmash_k_10_1487_25m_not_scaled_16270117_points_X_test.csv")

In [None]:
ds = df.set_index(["time", "latitude", "longitude"]).to_xarray()

In [None]:
ds["Salinity"].plot()

In [None]:
ds.dropna(dim="latitude", how="all").dropna(dim="longitude", how="all")

In [None]:
ds['province_obj'] = ds['province'].astype(str)

In [None]:
ds["province"].plot(levels=8, subplot_kws=dict(projection=ccrs.Orthographic()), transform=ccrs.PlateCarree())

In [None]:
ax = plt.axes()
ds["0"].plot(cmap="Blues", add_colorbar=False)
ax.set_facecolor("k")

In [None]:
ax = plt.axes()
ds["5"].plot(cmap="Blues", add_colorbar=False)
ax.set_facecolor("k")

In [None]:
ax = plt.axes()
ds["7"].plot(cmap="Blues", add_colorbar=False)
ax.set_facecolor("k")

In [None]:
ax = plt.axes()
ds["9"].plot(cmap="Blues", add_colorbar=False)
ax.set_facecolor("k")

In [None]:
ax = plt.axes()
ds["10"].plot(cmap="Blues", add_colorbar=False)
ax.set_facecolor("k")

In [None]:
ax = plt.axes()
ds["11"].plot(cmap="Blues", add_colorbar=False)
ax.set_facecolor("k")

In [None]:
ax = plt.axes()
ds["14"].plot(cmap="Blues", add_colorbar=False)
ax.set_facecolor("k")

In [None]:
def create_sequential_colormap(color_hex):
    """Build a colormap that runs linearly from ``color_hex`` (at 0.0) to white (at 1.0).

    Parameters
    ----------
    color_hex : str
        Hex colour as ``RRGGBB``. A leading ``#`` is accepted and ignored.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Colormap named ``'sequential_colormap'``.

    Raises
    ------
    ValueError
        If the hex digits cannot be parsed.
    """
    # Accept both "1f77b4" and "#1f77b4" so callers need not strip the prefix themselves.
    digits = color_hex.lstrip('#')
    r, g, b = (int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))

    # Per channel: start at the colour's component at position 0.0 and reach 1.0 (white) at position 1.0.
    colormap_dict = {
        channel: ((0.0, start, start), (1.0, 1.0, 1.0))
        for channel, start in (('red', r), ('green', g), ('blue', b))
    }

    return mpl.colors.LinearSegmentedColormap('sequential_colormap', colormap_dict)

In [None]:
palette = {
        16:{"description": "Baltic Sea", "label": "BALT", "category": "BALT", "counts": 51, "color": get_cmap_ix(10, "tab20")},
        11:{"description": "Pacific Equatorial Divergence/Countercurrent", "label": "PEQD", "category": "TROP", "counts": 54, "color": get_cmap_ix(8, "tab20") },
        5:{"description": "Tropical Convergence", "label": "TCON", "category": "TROP", "counts": 161, "color": get_cmap_ix(1, "tab20") },
        9:{"description": "Broad Tropical", "label": "TROP", "category": "TROP", "counts": 818, "color": get_cmap_ix(0, "tab20") },
        0:{"description": "Antarctic Polar", "label": "APLR", "category": "POLR", "counts": 30, "color": get_cmap_ix(15, "tab20") },
        14:{"description": "Arctic Polar", "label": "BPLR", "category": "POLR", "counts": 42, "color": get_cmap_ix(14, "tab20") },
        10:{"description": "Upwelling Areas", "label": "UPWL", "category": "TEMP", "counts": 139, "color": get_cmap_ix(4, "tab20") },
        2:{"description": "S. Subtropical Convergence", "label": "SSTC", "category": "TEMP", "counts": 43, "color": get_cmap_ix(6, "tab20") },
        3:{"description": "North Atlantic Drift/Agulhas", "label": "NADR", "category": "TEMP", "counts": 34, "color": get_cmap_ix(7, "tab20") },
        7:{"description": "Mediterranean", "label": "MEDI", "category": "TEMP", "counts": 82, "color": get_cmap_ix(6, "tab20") }
    }

In [None]:
ds

In [None]:
p = "0"
ax = plt.axes()
ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
ax.set_facecolor("k")

In [None]:
p = "5"
ax = plt.axes()
ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
ax.set_facecolor("k")

In [None]:
p = "7"
ax = plt.axes()
ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
ax.set_facecolor("k")

In [None]:
p = "9"
ax = plt.axes()
ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
ax.set_facecolor("k")

In [None]:
p = "10"
ax = plt.axes()
ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
ax.set_facecolor("k")

In [None]:
p = "11"
ax = plt.axes()
ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
ax.set_facecolor("k")

In [None]:
p = "14"
ax = plt.axes()
ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
ax.set_facecolor("k")

In [None]:
p = "16"
ax = plt.axes()
ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
ax.set_facecolor("k")

In [None]:
for p in "0 5 7 9 10 11 14 16".split():
    fig, ax = plt.subplots(figsize=(40, 40))
    ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False)
    ax.set_facecolor("k")

In [None]:
fig, ax = plt.subplots(figsize=(40, 40))
ax.set_facecolor("k")

for p in "0 5 7 9 10 11 14 16".split():
    ds[p].plot(cmap=create_sequential_colormap(palette[int(p)]["color"][1:]).reversed(), add_colorbar=False, ax=ax)