In [1]:
# Notebook setup: plotting, live module reloading and working directory.
import matplotlib.pyplot as plt
# autoreload mode 2 re-imports modules before code is executed.
%load_ext autoreload
%autoreload 2

%matplotlib inline

# The data paths used further down are relative, so move two levels up first.
# NOTE(review): the target is relative to the current directory, so running
# this cell again without a kernel restart moves up two more levels.
%cd ../..

# Print the name of the machine the notebook runs on.
!hostname

/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818-3d
jrc0002


In [2]:
import os

import re
import pandas as pd
import numpy as np

import h5py as h5

import pli
import pli.image as im

from tqdm import tqdm

In [3]:
# Get mask and feature info

# Names of the model variants to evaluate; each one is used as the name of a
# feature folder by load_features
methods = [
    "cl-2d_all_augmentations_circle_small",
    "cl-2d_only_flip_circle_small",
    "cl-3d_no_augmentations_sphere_small",
    "cl-3d_only_scale-attenuation_sphere_small",
    "cl-2d_no_augmentations_circle_small",
    "cl-2d_only_scale-attenuation_circle_small",
    "cl-3d_only_affine_sphere_small",
    "cl-3d_only_scale-thickness_sphere_small",
    "cl-2d_only_affine_circle_small",
    "cl-2d_only_scale-thickness_circle_small",
    "cl-3d_only_blur_sphere_small",
    "cl-2d_only_blur_circle_small",
    "cl-3d_all_augmentations_sphere_small",
    "cl-3d_only_flip_sphere_small",
]

# Folder that is scanned for mask files
mask_path = "data/aa/masks/cortex/"

# --- HDF5 group names -------------------------------------------------------
# Group of the cluster in the H5 files
cluster_group = "Image"
# Group of the features in the H5 files
feature_group = "PCA"
# Group of the mask in the mask files
mask_group = "Image"

# --- Masks ------------------------------------------------------------------
# Forwarded to read_masked_features as `mask_pyramid`
mask_pyramid = 6
background_cluster = 0
# Mask value whose positions are excluded from the valid features
background_class = 3

# --- Smoothing of the features ----------------------------------------------
# Sigma of the Gaussian filter; values <= 0 switch the smoothing off
smooth_sigma = 1.0

###

###

from skimage import filters
from vervet1818_3d.utils.io import read_masked_features

from tqdm import tqdm

# File names are accepted when they contain "s" + a four-digit number + "_"
# and a later "h5"; the four digits are captured as the numeric id.
p = re.compile('.*s([0-9]{4})_.*h5')

# One record per matching mask file. A comprehension is used so that no
# notebook-level variable named `id` is left behind to shadow the builtin.
mask_list = [
    {'id': int(hit[1]), 'file_mask': os.path.join(mask_path, name)}
    for name, hit in ((name, p.match(name)) for name in sorted(os.listdir(mask_path)))
    if hit
]
# Explicit columns keep the later merge on 'id' working even if no file matched
mask_df = pd.DataFrame(mask_list, columns=['id', 'file_mask'])

# NOTE(review): not referenced by any visible cell; kept in case other code reads it
files_dataframes = {}

def load_features(model_name):
    """Load the masked (and optionally smoothed) feature vectors of one model.

    Lists ``data/aa/pca_80/{model_name}``, pairs every matching feature file
    with the mask file of the same id (inner join with the global ``mask_df``),
    reads each pair with ``read_masked_features`` and keeps only the positions
    whose mask value differs from ``background_class``.

    Relies on the notebook globals ``p``, ``mask_df``, ``feature_group``,
    ``mask_group``, ``mask_pyramid``, ``smooth_sigma`` and ``background_class``.

    Parameters
    ----------
    model_name : str
        Name of the sub-folder of ``data/aa/pca_80`` holding the feature files.

    Returns
    -------
    numpy.ndarray
        All non-background feature vectors stacked row-wise, in order of
        ascending id.

    Raises
    ------
    ValueError
        If no feature file has a matching mask file.
    AssertionError
        If the first two dimensions of a feature array differ from the shape
        of its mask.
    """
    feature_folder = f"data/aa/pca_80/{model_name}"

    feature_list = []
    for file_name in sorted(os.listdir(feature_folder)):
        match = p.match(file_name)
        if not match:
            continue
        file_path = os.path.join(feature_folder, file_name)
        # Only attributes are read, so open read-only explicitly (older h5py
        # versions default to a read/write mode when none is given)
        with h5.File(file_path, 'r') as h5f:
            attrs = h5f[feature_group].attrs
            spacing = attrs['spacing']
            origin = attrs['origin']
        feature_list.append({'id': int(match[1]), 'spacing': spacing, 'origin': origin,
                             'file_features': file_path})
    # Explicit columns so that the merge below also works for an empty folder
    feature_df = pd.DataFrame(feature_list, columns=['id', 'spacing', 'origin', 'file_features'])

    files_df = mask_df.merge(feature_df, on='id', how='inner').sort_values('id').reset_index(drop=True)
    if files_df.empty:
        raise ValueError(f"No feature file in {feature_folder} has a matching mask file")

    valid_features = []
    for file_features, file_mask in zip(files_df['file_features'], files_df['file_mask']):
        features, mask = read_masked_features(
            file_features,
            file_mask,
            mask_pyramid=mask_pyramid,
            data_group=feature_group,
            mask_group=mask_group
        )
        assert features.shape[:2] == mask.shape, f"{features.shape[:2]} differs from {mask.shape}"

        # Smooth features a bit
        # NOTE(review): `multichannel` is deprecated since scikit-image 0.19 in
        # favour of `channel_axis=-1`; switch once the installed version is confirmed.
        if smooth_sigma > 0.:
            features = filters.gaussian(features, multichannel=True, sigma=smooth_sigma)

        # Reduce to the non-background vectors right away so the full arrays
        # of all files do not have to stay in memory until the end
        valid_features.append(features[mask != background_class])

    valid_features = np.vstack(valid_features)

    print(f"Valid features have shape {valid_features.shape}")

    return valid_features

In [8]:
# The best cluster count differs between feature embeddings, so several
# counts are evaluated for every method
n_clusters = [2, 8, 32, 128]

# Number of repeated fits per cluster count
n_iter = 10
# Iteration limit handed to every KMeans fit
max_iter = 300

# Number of feature vectors drawn for each fit
n_samples = 10_000
# Base seed; repetition n uses seed + n
seed = 299_792_458

###

from collections import namedtuple

from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans
from tqdm import tqdm

# One record per (method, cluster count) holding the scores of all repetitions
Result = namedtuple("Result", ['method', 'n_clusters', 'silhouette_scores'])
results = []

for m in tqdm(methods):
    print("Load features ...")
    valid_features = load_features(m)
    n_valid = len(valid_features)

    for nc in n_clusters:
        print(f"Fit {nc} clusters {n_iter} times")
        silhouette_scores = []

        for n in range(n_iter):
            # Every repetition gets its own seed, shared by sampling and KMeans
            run_seed = seed + n
            np.random.seed(run_seed)

            # Draw the sub-sample (with replacement). Passing the population
            # size gives the same draws as passing np.arange(n_valid), without
            # allocating that index array on every repetition.
            ix = np.random.choice(n_valid, n_samples)
            # Index once and reuse the sample for fitting and scoring
            sample = valid_features[ix]

            km = KMeans(nc, n_init=1, max_iter=max_iter, tol=1e-4, random_state=run_seed, verbose=False)
            cluster_labels = km.fit_predict(sample)

            silhouette_scores.append(silhouette_score(sample, cluster_labels))

        results.append(Result(m, nc, silhouette_scores))

  0%|          | 0/14 [00:00<?, ?it/s]

Load features ...
Valid features have shape (9171147, 15)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


  7%|▋         | 1/14 [01:04<14:03, 64.86s/it]

Load features ...
Valid features have shape (9171147, 21)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 14%|█▍        | 2/14 [02:40<16:38, 83.24s/it]

Load features ...
Valid features have shape (9171147, 48)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 21%|██▏       | 3/14 [04:46<18:45, 102.33s/it]

Load features ...
Valid features have shape (9171147, 50)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 29%|██▊       | 4/14 [06:49<18:28, 110.83s/it]

Load features ...
Valid features have shape (9171147, 23)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 36%|███▌      | 5/14 [08:30<16:05, 107.33s/it]

Load features ...
Valid features have shape (9171147, 19)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 43%|████▎     | 6/14 [10:15<14:09, 106.25s/it]

Load features ...
Valid features have shape (9171147, 42)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 50%|█████     | 7/14 [12:11<12:47, 109.63s/it]

Load features ...
Valid features have shape (9171147, 49)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 57%|█████▋    | 8/14 [14:12<11:19, 113.33s/it]

Load features ...
Valid features have shape (9171147, 17)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 64%|██████▍   | 9/14 [15:41<08:47, 105.46s/it]

Load features ...
Valid features have shape (9171147, 24)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 71%|███████▏  | 10/14 [17:07<06:37, 99.48s/it]

Load features ...
Valid features have shape (9171147, 50)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 79%|███████▊  | 11/14 [19:03<05:13, 104.60s/it]

Load features ...
Valid features have shape (9171147, 23)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 86%|████████▌ | 12/14 [20:37<03:22, 101.37s/it]

Load features ...
Valid features have shape (9171147, 29)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


 93%|█████████▎| 13/14 [22:13<01:39, 99.74s/it] 

Load features ...
Valid features have shape (9171147, 54)
Fit 2 clusters 10 times
Fit 8 clusters 10 times
Fit 32 clusters 10 times
Fit 128 clusters 10 times


100%|██████████| 14/14 [24:11<00:00, 103.64s/it]


In [9]:
json_path = "doc/ablation/silhouette.json"

###

import pandas as pd


def standard_error(values):
    """Return the standard error of the mean: sample std (ddof=1) / sqrt(len)."""
    return np.std(values, ddof=1) / np.sqrt(len(values))


# One row per collected result, extended by summary statistics of its score list
result_frame = pd.DataFrame(results).assign(
    mean_silhouette=lambda frame: frame['silhouette_scores'].apply(np.mean),
    std_err_silhouette=lambda frame: frame['silhouette_scores'].apply(standard_error),
)

# Store the table so it can be read back later in the notebook
result_frame.to_json(json_path)

result_frame.head()

Unnamed: 0,method,n_clusters,silhouette_scores,mean_silhouette,std_err_silhouette
0,cl-2d_all_augmentations_circle_small,2,"[0.32832813, 0.33136317, 0.33770776, 0.3330199...",0.333661,0.000944
1,cl-2d_all_augmentations_circle_small,8,"[0.24040191, 0.25170472, 0.2393448, 0.22192466...",0.249659,0.00389
2,cl-2d_all_augmentations_circle_small,32,"[0.21511704, 0.2193213, 0.21206336, 0.21128039...",0.213862,0.001722
3,cl-2d_all_augmentations_circle_small,128,"[0.18830153, 0.18426615, 0.1895832, 0.18970875...",0.184597,0.001353
4,cl-2d_only_flip_circle_small,2,"[0.20351152, 0.2023205, 0.20528288, 0.19764656...",0.199221,0.004088


In [10]:
# Read stored values
# Loads the table from json_path and replaces the in-memory result_frame;
# presumably so the cells below can run without repeating the clustering.

result_frame = pd.read_json(json_path)

result_frame.head()

Unnamed: 0,method,n_clusters,silhouette_scores,mean_silhouette,std_err_silhouette
0,cl-2d_all_augmentations_circle_small,2,"[0.3283281326, 0.33136317130000004, 0.33770775...",0.333661,0.000944
1,cl-2d_all_augmentations_circle_small,8,"[0.2404019088, 0.2517047226, 0.2393448055, 0.2...",0.249659,0.00389
2,cl-2d_all_augmentations_circle_small,32,"[0.2151170373, 0.21932129560000002, 0.21206335...",0.213862,0.001722
3,cl-2d_all_augmentations_circle_small,128,"[0.1883015335, 0.18426615000000002, 0.18958319...",0.184597,0.001353
4,cl-2d_only_flip_circle_small,2,"[0.20351152120000002, 0.2023205012, 0.20528288...",0.199221,0.004088


In [15]:
# Wide table: one row per method and one column per (statistic, n_clusters)
# pair. The per-run score lists are dropped after pivoting.
pivot_df = (
    result_frame
    .pivot(index='method', columns='n_clusters')
    .drop(columns=['silhouette_scores'])
    # Reset the index to make 'method' a column again
    .reset_index()
    .sort_values(['method'])
)

# Round every mean column to two decimals. The columns are taken from the
# table itself, so the rounding follows whichever cluster counts are present
# instead of a hard-coded list.
mean_decimals = {col: 2 for col in pivot_df.columns if col[0] == 'mean_silhouette'}
pivot_df = pivot_df.round(mean_decimals)


pivot_df.head()

Unnamed: 0_level_0,method,mean_silhouette,mean_silhouette,mean_silhouette,mean_silhouette,std_err_silhouette,std_err_silhouette,std_err_silhouette,std_err_silhouette
n_clusters,Unnamed: 1_level_1,2,8,32,128,2,8,32,128
0,cl-2d_all_augmentations_circle_small,0.33,0.25,0.21,0.18,0.000944,0.00389,0.001722,0.001353
1,cl-2d_no_augmentations_circle_small,0.15,0.18,0.17,0.15,0.00128,0.001217,0.000748,0.000407
2,cl-2d_only_affine_circle_small,0.18,0.25,0.22,0.18,0.001864,0.002453,0.001005,0.000849
3,cl-2d_only_blur_circle_small,0.15,0.18,0.17,0.16,0.001047,0.001791,0.000952,0.000864
4,cl-2d_only_flip_circle_small,0.2,0.21,0.18,0.15,0.004088,0.003956,0.000939,0.00074


In [16]:
# Print the pivoted summary as a Markdown table; the two-level column labels
# show up as tuples in the header row.
print(pivot_df.to_markdown())

|    | ('method', '')                            |   ('mean_silhouette', 2) |   ('mean_silhouette', 8) |   ('mean_silhouette', 32) |   ('mean_silhouette', 128) |   ('std_err_silhouette', 2) |   ('std_err_silhouette', 8) |   ('std_err_silhouette', 32) |   ('std_err_silhouette', 128) |
|---:|:------------------------------------------|-------------------------:|-------------------------:|--------------------------:|---------------------------:|----------------------------:|----------------------------:|-----------------------------:|------------------------------:|
|  0 | cl-2d_all_augmentations_circle_small      |                     0.33 |                     0.25 |                      0.21 |                       0.18 |                 0.000943793 |                 0.00389019  |                  0.00172193  |                   0.00135305  |
|  1 | cl-2d_no_augmentations_circle_small       |                     0.15 |                     0.18 |                      0.17 |             

|    | ('method', '')                            |   ('mean_silhouette', 2) |   ('mean_silhouette', 8) |   ('mean_silhouette', 32) |   ('mean_silhouette', 128) |   ('std_err_silhouette', 2) |   ('std_err_silhouette', 8) |   ('std_err_silhouette', 32) |   ('std_err_silhouette', 128) |
|---:|:------------------------------------------|-------------------------:|-------------------------:|--------------------------:|---------------------------:|----------------------------:|----------------------------:|-----------------------------:|------------------------------:|
|  0 | cl-2d_all_augmentations_circle_small      |                     0.33 |                     0.25 |                      0.21 |                       0.18 |                 0.000943793 |                 0.00389019  |                  0.00172193  |                   0.00135305  |
|  1 | cl-2d_no_augmentations_circle_small       |                     0.15 |                     0.18 |                      0.17 |                       0.15 |                 0.00127975  |                 0.00121702  |                  0.000747546 |                   0.000407132 |
|  2 | cl-2d_only_affine_circle_small            |                     0.18 |                     0.25 |                      0.22 |                       0.18 |                 0.00186429  |                 0.00245258  |                  0.00100536  |                   0.000848703 |
|  3 | cl-2d_only_blur_circle_small              |                     0.15 |                     0.18 |                      0.17 |                       0.16 |                 0.00104724  |                 0.00179105  |                  0.00095238  |                   0.000864283 |
|  4 | cl-2d_only_flip_circle_small              |                     0.2  |                     0.21 |                      0.18 |                       0.15 |                 0.00408838  |                 0.0039557   |                  0.00093934  |                   0.000739809 |
|  5 | cl-2d_only_scale-attenuation_circle_small |                     0.16 |                     0.18 |                      0.18 |                       0.15 |                 0.00168025  |                 0.00135849  |                  0.00123078  |                   0.000958443 |
|  6 | cl-2d_only_scale-thickness_circle_small   |                     0.16 |                     0.17 |                      0.18 |                       0.15 |                 0.00368238  |                 0.0022219   |                  0.00151036  |                   0.000147596 |
|  7 | cl-3d_all_augmentations_sphere_small      |                     0.21 |                     0.2  |                      0.17 |                       0.13 |                 0.00542504  |                 0.00290505  |                  0.000671907 |                   0.000676948 |
|  8 | cl-3d_no_augmentations_sphere_small       |                     0.26 |                     0.14 |                      0.12 |                       0.1  |                 0.0025538   |                 0.00207227  |                  0.000677404 |                   0.000763315 |
|  9 | cl-3d_only_affine_sphere_small            |                     0.16 |                     0.15 |                      0.14 |                       0.13 |                 0.00191893  |                 0.00267411  |                  0.0010629   |                   0.000432022 |
| 10 | cl-3d_only_blur_sphere_small              |                     0.18 |                     0.13 |                      0.12 |                       0.11 |                 0.00257565  |                 0.000830295 |                  0.00139488  |                   0.000826995 |
| 11 | cl-3d_only_flip_sphere_small              |                     0.13 |                     0.14 |                      0.13 |                       0.11 |                 0.00270418  |                 0.00132503  |                  0.000947729 |                   0.000566278 |
| 12 | cl-3d_only_scale-attenuation_sphere_small |                     0.15 |                     0.14 |                      0.12 |                       0.1  |                 0.00145784  |                 0.00233601  |                  0.000963543 |                   0.000484283 |
| 13 | cl-3d_only_scale-thickness_sphere_small   |                     0.16 |                     0.15 |                      0.12 |                       0.1  |                 0.000670484 |                 0.00270984  |                  0.00166129  |                   0.00052618  |
