In [None]:
%matplotlib notebook
import numpy as np
import pandas

import strawb
import os
import matplotlib.pyplot as plt

import cv2

# Load files from the ONC server
Be careful: depending on the amount of data, this can take a while!

In [None]:
# Open the local sync database and refresh its file index from the ONC server.
db = strawb.SyncDBHandler(file_name='Default')  # loads the db
db.load_onc_db_update(save_db=True)  # update the DB, could take some time if it has to load info. from ONC

Print some info from the DB

In [None]:
# columns available in the DB table
print(db.dataframe.columns)

# these are the available device codes
print(db.dataframe.deviceCode.unique())

# Different measurement types for PMTSPEC and LIDAR.
# Works only if the hdf5 attributes are imported from files on disc, therefore disabled here:
#print(db.dataframe.measurement_type.unique())

# these are the parts of each module that produce data
print(db.dataframe.dataProductCode.unique())

### Select (a) file(s) of interest

In [None]:
mask = (db.dataframe.deviceCode == 'TUMPMTSPECTROMETER001')  # that's the pmtspec module
mask &= (db.dataframe.dataProductCode == 'MSSCD')  # that's the camera data of the pmtspec module

# Reference time of the bioluminescence event. The files are selected with a
# window from 4 hours before to 3 hours after this timestamp.
# timestamp = pandas.Timestamp('2022-03-04T23:44:09', tz='UTC')  # gain = 30
timestamp = pandas.Timestamp('2021-09-04T23:44:09', tz='UTC')  # gain = 50

# `hours=` is used instead of the string alias '4H': the upper-case 'H' unit is
# deprecated in recent pandas versions, the keyword form works in all of them.
# NOTE(review): presumably the helper selects files whose [dateFrom, dateTo]
# range overlaps the window - confirm against strawb.tools.
mask_time = strawb.tools.pd_timestamp_mask_between(
    db.dataframe.dateFrom,
    db.dataframe.dateTo,
    timestamp - pandas.Timedelta(hours=4),
    timestamp + pandas.Timedelta(hours=3),
)

# Alternative selection: LED run of TUMPMTSPECTROMETER001
# timestamp = pandas.Timestamp('2022-08-02T09:30:00', tz='UTC')  # LED images
# mask_time = strawb.tools.pd_timestamp_mask_between(
#     db.dataframe.dateFrom,
#     db.dataframe.dateTo,
#     timestamp - pandas.Timedelta(hours=1),
#     timestamp + pandas.Timedelta(hours=2),
# )

mask &= mask_time

# show the DB entries selected by device, data product and time window
db.dataframe[mask]

### Download the missing files which aren't synced so far from `db.dataframe[mask]`

In [None]:
# Only contact the server when at least one of the selected files is not on disc yet.
if not db.dataframe.loc[mask, 'synced'].all():
    db.update_db_and_load_files(db.dataframe[mask],
                                output=True,    # print output to console
                                download=True,  # download the files
                                save_db=True)   # update the DB

In [None]:
# Display the DB table as it is after the download/update step.
db.dataframe

# Load the file into the Camera module

In [None]:
# select the Camera file(s) -> dataProductCode == 'MSSCD'
item = db.dataframe[mask & (db.dataframe.dataProductCode == 'MSSCD')]
# item = pandas.DataFrame({'fullPath': glob.glob(os.path.expanduser('~/Downloads/*CAMERA.hdf5'))})

if item.empty:
    raise ValueError("No 'MSSCD' camera file matches the current selection mask.")

# Close the file of a previous run of this cell before it is opened again.
try:
    camera.file_handler.close()
except NameError:
    pass  # first run of the cell: there is no `camera` instance yet
except Exception as err:
    # stay best-effort, but report the problem instead of hiding it
    print(f'Could not close the previously opened camera file: {err}')

# generate a virtual hdf5 to combine the datasets if there are multiple files selected
if len(item) > 1:
    vhdf5 = strawb.VirtualHDF5('MSSCD_event_view.hdf5', item.fullPath.to_list())
    file_name = vhdf5.file_name
else:
    # positional access: the filtered frame keeps the DB index, so the label 0
    # is in general not the label of the remaining row
    file_name = item.fullPath.iloc[0]

# create an instance of the Camera
camera = strawb.Camera(file_name)

### Print some parameters

In [None]:
# Summary of the opened camera file: module, frame count, covered time range
# and the exposure times that occur.
print('\n'.join([
    f'Module: {camera.file_handler.module}',
    f'Number of Frames: {camera.file_handler.exposure_time.shape[0]}',
    f'Date: {np.min(camera.file_handler.time.asdatetime()[:])} - {np.max(camera.file_handler.time.asdatetime()[:])}',
    f'Exposure Times [s]: {np.unique(camera.file_handler.exposure_time)}',
]))

### Mask images to export (here only one)

In [None]:
# Set to True to rank all frames by their brightness (may take some time).
# If the time period isn't changed, index 170 is the bright event.
rank_frames_by_brightness = False

if rank_frames_by_brightness:
    charge = camera.images.integrated_minus_dark
    # Frames above the threshold which are also flagged valid. A separate name is
    # used on purpose: `mask` is the DB file selection other cells still rely on.
    bright_mask = (charge > 1e6) & camera.images.valid_mask

    index = np.argsort(charge)[::-1]  # sort by charge [max,...,min]
    index = index[bright_mask[index]]  # keep only the bright, valid frames
else:
    index = [170]
print(index)

In [None]:
# Load the raw frame(s) selected by `index` and keep the first one.
raw = camera.images.load_raw(index=index)[0]
# Convert the raw frame to RGB after scaling its values down by a factor of 100.
rgb = camera.images.frame_raw_to_rgb(raw/100)

In [None]:
plt.figure()
# Reload the selected raw frame, halve its values before the RGB conversion and show it.
raw = camera.images.load_raw(index=index)[0]
plt.imshow(camera.images.frame_raw_to_rgb(raw/2)[:,:]/2**16)  # /2**16 scales the values for imshow (presumably 16 bit data -> 0..1, TODO confirm)

## Show one image here

In [None]:
plt.figure()
plt.imshow(camera.images.load_rgb(index=index)[0,:,:]/2**16)  # first loaded frame; /2**16 scales the values for imshow (presumably 16 bit data -> 0..1, TODO confirm)
# optional export of the figure:
#plt.savefig("figures/biolumi_demo.pdf", backend="pdf")
#plt.savefig("figures/biolumi_demo.png", dpi=120)

# Show RAW Data, mosaic and demosaicing

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

import matplotlib.colors as colors

def invert_colors(x, gamma=1, axis=-1):
    """Invert the brightness of colours while keeping their direction in RGB space.

    The first three entries along `axis` are read as an (r, g, b) vector. The
    vector is expressed by its length (brightness) and two angles, the length is
    replaced by ``(sqrt(3) - length) ** gamma`` and the vector is converted back.
    Finally the colour channels are divided by their global maximum. Entries
    beyond the third one along `axis` (e.g. alpha) are returned unchanged.

    Parameters
    ----------
    x : numpy.ndarray
        Colour array with at least three entries along `axis` and at least one
        further dimension.
    gamma : float, optional
        Exponent applied to the inverted brightness.
    axis : int, optional
        Axis holding the colour channels. The channel indexing assumes that all
        axes before `axis` are sample axes; the default is the last axis.

    Returns
    -------
    numpy.ndarray
        New float64 array with the shape of `x`; the input is not modified.
    """
    x = x.astype(np.float64)  # astype copies, the caller's array stays untouched

    if axis < 0:
        axis = x.ndim + axis

    lead = (slice(None),) * axis      # full slices for all axes in front of `axis`
    rgb = (*lead, slice(0, 3))        # selects the three colour channels

    # colour vector -> length and angles
    brightness = np.sqrt(np.sum(x[rgb] ** 2, axis=axis))
    phi = np.arctan2(x[(*lead, 1)], x[(*lead, 0)])

    nonzero = brightness != 0
    theta = np.zeros_like(phi)
    # clip: rounding can push the ratio marginally outside [-1, 1] -> NaN in arccos
    theta[nonzero] = np.arccos(np.clip(x[(*lead, 2)][nonzero] / brightness[nonzero], -1., 1.))

    # black has no direction; use the direction of (1, 1, 1) so it inverts to white
    phi[~nonzero] = np.pi/4.
    theta[~nonzero] = 0.9553166181245092  # the angle for (1,1,1)

    # invert length aka. brightness
    brightness = np.sqrt(3) - brightness
    brightness = brightness**gamma

    # back to colors
    x[(*lead, 0)] = brightness * np.sin(theta) * np.cos(phi)
    x[(*lead, 1)] = brightness * np.sin(theta) * np.sin(phi)
    x[(*lead, 2)] = brightness * np.cos(theta)

    # normalise to the brightest channel; skipped when everything is black
    # (e.g. an all-white input), where the division would give 0/0 = NaN
    peak = x[rgb].max()
    if peak != 0:
        x[rgb] /= peak
    return x

def cmap_cut(name, vmax=1, invert=True):
    """Create a ListedColormap from the lower part of a named colormap.

    The colormap `name` is sampled at its ``N`` native colours, optionally passed
    through `invert_colors`, and only the first ``int(N / vmax)`` colours are
    kept. ``vmax=1`` keeps the whole colormap, ``vmax=2`` keeps the lower half,
    ``vmax=1.1`` drops roughly the upper 9 %. Values below 1 cannot extend the
    colormap, they keep all colours as well.

    Parameters
    ----------
    name : str
        Name of a registered matplotlib colormap.
    vmax : float, optional
        Cut factor, must be positive.
    invert : bool, optional
        If True, the colours are inverted with `invert_colors` before the cut.

    Returns
    -------
    matplotlib.colors.ListedColormap
        The new colormap, named ``new_<name of the source colormap>``.

    Raises
    ------
    ValueError
        If `vmax` is not positive.
    """
    if vmax <= 0:
        raise ValueError(f'vmax must be positive, got {vmax!r}')

    cmap = plt.get_cmap(name)
    x = np.array([cmap(x_i) for x_i in np.linspace(0, 1, cmap.N)])
    if invert:
        x = invert_colors(x)

    # keep at least one colour, an empty colour list is not a usable colormap
    n_keep = max(1, int(cmap.N / vmax))
    return colors.ListedColormap(x[:n_keep], f'new_{cmap.name}')

# set up data
raw = camera.file_handler.raw[index][0]
# cut the margins from the array
raw = camera.images.cut2effective_pixel(raw)
raw_i = raw/raw.max()  # normalised to the brightest pixel

raw_ma = np.ma.array(raw_i).astype(np.float64)


# Plot parameter
vmax = 1.1  # cut factor handed to cmap_cut
gray = False

figsize = np.array([9,3])

# PLOT
# The width ratios stem from a three panel layout [1.05, 1.15, 1.00]
# (grayscale, mosaic, demosaiced); the grayscale panel is disabled.
fig, ax = plt.subplots(ncols=2,nrows=1, figsize=figsize, squeeze=False, sharex='row', sharey='row',
                       gridspec_kw=dict(width_ratios=[1.05, 1.15, 1.00][1:]))
ax = ax.flatten()

i_ax = 0
# # FIRST plot (disabled)
# ax[i_ax].set_title('RAW - Grayscale')
# cs = ax[i_ax].imshow(raw_i, cmap=cmap_cut('gray_r'))
#
# # set colobar size
# divider = make_axes_locatable(ax[i_ax])
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(cs, cax=cax)
# i_ax += 1

# SECOND plot
ax[i_ax].set_title('RAW - RGB mosaic')

# The colorbars belong to the mosaic panel, so the divider has to split the axes
# the mosaic is drawn on. (It was hard-coded to ax[1], which was the mosaic
# panel only while the grayscale panel was enabled.)
divider = make_axes_locatable(ax[i_ax])

for color_i in ['Greens', 'Blues', 'Reds']:
    # rgb_mask needs EffMargins to adopted if margins are cut from data
    raw_ma.mask = ~camera.images.get_raw_rgb_mask(
        raw_i.shape,
        color_i,
        eff_margin=camera.file_handler.EffMargins[:])
    cs = ax[i_ax].imshow(
        raw_ma.filled(np.nan),  # pixels of the other colours become NaN
        vmin=0, vmax=1,
        # the blue colormap is used uncut; the name has to be 'Blues' to match
        # the loop values ('Blue' never matched, so blue was cut as well)
        cmap='gray_r' if gray else cmap_cut(color_i,
                                            vmax=1 if color_i == 'Blues' else vmax),
        interpolation='nearest')

    # add colobar to divider
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(cs, cax=cax)
    if color_i != 'Reds':  # remove ticks for green and blue
        cbar.set_ticks([])
i_ax += 1


ax[i_ax].set_title('Image - demosaiced')
rgb = camera.images.load_rgb(index=index, subtract_dark=False)
ax[i_ax].imshow((rgb[0]/2**16))  # first frame; /2**16 scales the values for imshow

# ax[-1].invert_yaxis()

plt.tight_layout()
plt.savefig('image_demosaic.pdf')

# Less zoom version
# set_ylim(low, high) puts the low value at the bottom; invert_yaxis restores
# the image orientation. The axes are shared, so both panels zoom together.
ax[-1].set_xlim(20, 600)
ax[-1].set_ylim(250, 700)
ax[-1].invert_yaxis()
plt.tight_layout()
plt.savefig('image_demosaic_zoom1.pdf')

# Zoom version
scale = .3
ax[-1].set_xlim(295, 300 + int((418-295)*scale))
ax[-1].set_ylim(450, 450 + int((550-450)*scale/.8))
ax[-1].invert_yaxis()
plt.tight_layout()
plt.savefig('image_demosaic_zoom2.pdf')

## How to access the raw data (numpy)

In [None]:
# The raw pixel values are NOT loaded into the module by default, to save RAM.
# They are read directly from the file by index; a full slice reads all images.
a = camera.file_handler.raw[[1,3]]  # direct h5py access, allows only sorted (non-duplicate) index access.
print(a.shape)

a = camera.file_handler.raw.getunsorted([1,3])  # STRAWb helper to access it unsorted (and duplicate) by index
print(a.shape)

raw = camera.file_handler.raw[:]  # get all images
# an array of images is returned by default, even if only one element is accessed
print(f'raw shape: {raw.shape}') # shape of images, n_pic x 2D shape of picture
print(f'raw picture: {raw[0].shape}') # 2D shape of picture

# cut all images down to the effective pixel area
raw = camera.images.cut2effective_pixel(raw)
print(f'shape reduced to effective pixel: {raw[0].shape}') # 2D shape of picture

# Now we have the raw-data with valid pixel range 

In [None]:
# Value distribution over all frames for a handful of single pixels.
bins = np.arange(0, 2**16, 1000)

plt.figure()
for i in range(1,10,2):
    # walk diagonally through the image: pixel (1200-50*i, 900-50*i)
    offset = 50*i
    pixel_values = raw[:, 1200-offset, 900-offset]
    plt.hist(pixel_values, bins=bins, histtype='step')

plt.yscale('log')
plt.show()

# Load or generate ImageClusterDB
Detecting the clusters and calculating their parameters **takes some time, ~5-10 min**. Therefore, a file name is generated first: if that file exists, the DB is loaded from it; otherwise the DB is generated and stored there for faster access the next time.

In [None]:
# first and last frame time of the opened camera file
t_start, t_end = pandas.to_datetime(camera.file_handler.time.asdatetime()[[0,-1]])

# The name encodes the device and the covered time range.
str_formatter = '{dev_code}_{t_start:%Y%m%dT%H%M%S}_{t_end:%Y%m%dT%H%M%S}_test_imagecluster.gz'
formatter_dict = dict(dev_code=camera.config.device_code,
                      t_start=t_start,
                      t_end=t_end)

file_name = str_formatter.format(**formatter_dict)
file_name

In [None]:
# Set to True to recompute the clusters even if a stored DB file exists.
force_update = False

image_cluster_db = strawb.sync_db_handler.ImageClusterDB(
    file_name=file_name,
    device_code=camera.config.device_code,
    load_db=False,
)

if os.path.exists(file_name) and not force_update:
    # reuse the stored result
    image_cluster_db.load_db()
else:
    # detect the clusters (slow) and store the result for the next run
    image_cluster_db.dataframe = camera.find_cluster.df_all()
    image_cluster_db.save_db()

# image_cluster_db = strawb.sync_db_handler.ImageClusterDB(load_db=False)
# image_cluster_db.dataframe = camera.find_cluster.df_all()

image_cluster_db.dataframe

### Get a DataFrame without label 0 and add charge, charge_log

In [None]:
# All clusters except label 0. The explicit copy makes `df` independent of the
# DB frame, so the new columns below are not written to a slice of it
# (SettingWithCopyWarning) and the stored DB stays unchanged.
df = image_cluster_db.dataframe[image_cluster_db.dataframe.label != 0].copy()

# noise corrected charge and its natural logarithm
# (non-positive charges give NaN / -inf in `charge_log`)
df['charge'] = (df.charge_with_noise - df.noise).to_numpy()
df['charge_log'] = np.log(df['charge'])

### Hist of cluster sizes

In [None]:
parameter = df.n_pixel

# Integer bin edges of similar size in log space; np.unique drops the
# duplicates the integer cast creates at the low end.
log_edges = np.geomspace(parameter.min(), parameter.max()*1.1, 100)
bins = np.unique(log_edges.astype(int))

# # linear bins
# bins = np.arange(int(df.n_pixel.min()), int(df.n_pixel.max()*1.1), 1)

count, edges = np.histogram(parameter, bins=bins)

plt.figure()
step_patch = plt.stairs(count, edges, fill=True)
step_patch.zorder = 5  # draw the histogram above the grid
plt.xscale('log')
plt.yscale('log')
plt.xlabel('Cluster Size [pixel]')
plt.ylabel('Counts')
plt.grid()
plt.tight_layout()

## Show images with detected cluster bigger than a threshold

In [None]:
# Select the frames that contain at least one cluster above a limit.

# limit = 2e5
# parameter = 'charge'

limit = 25  # limit from plot
parameter = 'n_pixel'

file_times = camera.file_handler.time.asdatetime()[:]

# group the clusters above the limit by frame time ...
above_limit = df[parameter] > limit
gb = df[above_limit].groupby('time')

# ... and keep, from those frames, every cluster with more than 5 pixels
in_selected_frame = df.time.isin(gb.groups)
df_filter = df[in_selected_frame & (df.n_pixel > 5)]
# indexes = np.unique([np.argwhere(file_times == i) for i in gb.groups])

df_filter

In [None]:
# NOTE: this lowers the limit set for the frame selection; the new value is
# picked up by cells that are executed afterwards.
limit = 20

# One entry per reference point of a cluster: (x column, y column, marker, legend label).
# The *_y columns are plotted on the x-axis of the image and the *_x columns on the y-axis.
cluster_points = [
    ('center_of_mass_y', 'center_of_mass_x', 'o', 'Center of Mass'),
    ('center_of_pix_y', 'center_of_pix_x', 'x', 'Center of Pix'),
    ('box_center_y', 'box_center_x', '>', 'Center of Box'),
]

# images = []  # to store the images later

for t_i, df_i in df_filter.groupby('time'):
    fig, ax = plt.subplots(ncols=2, nrows=1, squeeze=False, figsize=(9,5))
    ax = ax.flatten()

    # position of the frame in the file
    i = np.argwhere(file_times==t_i).flatten()[0]

    rgb = camera.images.load_rgb(index=i)[0]/2**16  # /2**16 scales the values for imshow
#     images.append(rgb)  # to store the images later
#     continue
    ax[0].imshow(rgb)  # left: plain image
    ax[1].imshow(rgb)  # right: image with the cluster overlay

    for _, df_j in df_i.iterrows():
        box_corners = cv2.boxPoints(((df_j.box_center_x, df_j.box_center_y),
                                     (df_j.box_size_x, df_j.box_size_y),
                                     df_j.angle))

        ax[1].plot(strawb.tools.connect_polar(box_corners[:,1]),
                   strawb.tools.connect_polar(box_corners[:,0]),
                   color='w', alpha=.5, label='Min. Box')

        for col_x, col_y, marker, label in cluster_points:
            ax[1].plot(df_j[col_x], df_j[col_y],
                       marker, color='w', alpha=.75, label=label)

    # don't show double labels
    handles, labels = ax[1].get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax[1].legend(by_label.values(), by_label.keys(), loc='upper right')

    fig.suptitle(f'{camera.file_handler.module} - {t_i:%Y-%m-%d %H:%M:%S}', fontsize=12)
    plt.tight_layout()
    plt.savefig(f'{camera.file_handler.module}_{t_i:%Y%m%dT%H%M%S}_camera.pdf')
    
# # in case you can store the images, e.g. with numpy
# np.savez_compressed('LED_images.npz', images=rgb_s)

# # and read it back
# with np.load('LED_images.npz') as f:
#     images = f['images']

### GridSpec
GridSpec can be used to set different sizes for each subplot.
The goal here is to give all subplots the same height. Every subplot keeps an equal aspect ratio, so the individual images can have different width/height ratios.

The width ratios for the GridSpec are simply the list of the individual image ratios $r_i$ (width $w_i$ over height $h_i$):
$$ r_i = \frac{w_i}{h_i}$$

In [None]:
import cv2
import copy

def get_rect(row):
    """Build the rectangle tuple ``((center_x, center_y), (size_x, size_y), angle)``
    from the box columns of a cluster row (the layout handed to `cv2.boxPoints`)."""
    center = (row['box_center_x'], row['box_center_y'])
    size = (row['box_size_x'], row['box_size_y'])
    return center, size, row['angle']
    
for t_i, df_i in df_filter.groupby('time'):
    # clusters of this frame above the limit; each of them gets its own panel
    df_ii = df_i[df_i[parameter] > limit]

    # Frame closest in time to the clusters of this group. A dedicated name is
    # used so that the frame selection `index` of the earlier cells is kept.
    frame_index = np.argmin(np.abs(camera.file_handler.time[:] - df_i.time.iloc[0].timestamp()))
    rgb = camera.images.load_rgb(index=frame_index)[0]/2**16

    # Cut out all clusters first: the sizes of the cut-outs are needed to set up
    # the GridSpec width_ratios before the figure is created.
    image_size_list = [rgb.shape[:2]]
    images_tran = []

    for _, df_j in df_ii.iterrows():
        rect_scale = strawb.camera.rect_scale_pad(get_rect(df_j), scale=1, pad=50)
        img_target, rect_target, t_matrix = strawb.camera.img_rectangle_cut(rgb,
                                                                            rect_scale,
                                                                            angle=None,
                                                                            angle_normalize=False)
        images_tran.append([img_target, rect_target, t_matrix])
        image_size_list.append(img_target.shape[:2])

    # width / height of each image -> all panels end up with the same height
    s = np.array(image_size_list)
    width_ratios = s[:,1]/s[:,0]
    print(width_ratios, len(images_tran))

    # Plot: full frame first, then one panel per cluster
    fig, ax = plt.subplots(figsize=(9,5),
                           nrows=1,
                           ncols=len(images_tran)+1,
                           squeeze=False,
                           gridspec_kw=dict(width_ratios=width_ratios, height_ratios=[1])
                          )
    ax = ax.flatten()

    # Draw image and add cluster
    ax[0].imshow(rgb)

    for n, ((_, df_j), (img_target, rect_target, t_matrix)) in enumerate(zip(df_ii.iterrows(), images_tran)):
        # cluster box, transformed with the matrix of the cut-out
        # NOTE(review): the same transformed box is drawn on the full frame
        # (ax[0]) and on the cut-out - confirm this is intended.
        box_i = strawb.camera.transform_cv2np(cv2.boxPoints(get_rect(df_j)), t_matrix)

        # Draw Box and more for cluster
        label_dict = dict(color='w', alpha=.5)
        ax[0].plot(*strawb.tools.connect_polar(box_i).T[::-1], **label_dict, label='Biolumi. Event')
#         ax[0].plot(*df_j.center_of_mass[::-1],'o', **label_dict, label='Center of Mass')
#         ax[0].plot(*df_j.center_of_pix[::-1], '*', **label_dict, label='Center of Pix')
#         ax[0].plot(*df_j.box_center[::-1],'x', **label_dict, label='Center of Box')

        # cut-out of the cluster with its box and reference points
        ax[n+1].imshow(img_target)
        ax[n+1].plot(*strawb.tools.connect_polar(box_i).T, '-', color='gray', alpha=.75, label='Cluster Box')

        pos = np.array(df_j[['center_of_mass_y', 'center_of_mass_x']])
        ax[n+1].plot(*strawb.camera.transform_np([pos], t_matrix).T, 'o', **label_dict, label='Center Mass')
        pos = np.array(df_j[['center_of_pix_y', 'center_of_pix_x']])
        ax[n+1].plot(*strawb.camera.transform_np([pos], t_matrix).T, '*', **label_dict, label='Center Pixel',)
        pos = np.array(df_j[['box_center_y', 'box_center_x']])
        ax[n+1].plot(*strawb.camera.transform_np([pos], t_matrix).T, 'x', **label_dict, label='Center Box',)

    for axi in ax:
        axi.axis('off')

    # don't show double labels; the handles are taken from the current (last) axes
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    for label_i in by_label:
        # copies of the artists: opaque and gray, so they are visible in the legend
        lin_i = copy.copy(by_label[label_i])
        lin_i.set_alpha(1)
        lin_i.set_c('gray')
        by_label[label_i] = lin_i
#     ax[0].legend(by_label.values(), by_label.keys(), ncol=2)

    leg = fig.legend(by_label.values(), by_label.keys(), loc='upper center', ncol=4)

#     fig.tight_layout()
#     fig.subplots_adjust(right=0.75)
    fig.tight_layout()
    
#     plt.tight_layout()

# ML filtering

In [None]:
# Set up a plotting function for fast df insights
def stair_scatter(df, color=None, size=1, alpha=.1, ax=None, columns=None, log=True, **kwargs):
    """Scatter every column pair of `df` in a lower-triangular ("stair") grid.

    Row ``i`` of the grid shows ``columns[i + 1]`` on the y-axis, column ``j``
    shows ``columns[j]`` on the x-axis; only panels with ``j <= i`` are filled.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot.
    color, size, alpha
        Passed to ``scatter`` as ``c``, ``s`` and ``alpha``.
    ax : 2D array of axes, optional
        Existing grid, indexed as ``ax[row, col]``, to draw into. If None, a new
        grid with shared axes, labels and grid lines is created.
    columns : sequence of str, optional
        Columns to plot; defaults to all columns of `df`.
    log : bool, optional
        Use log scaling for both axes; only applied to a newly created grid.
    **kwargs
        Passed to ``plt.subplots`` when a new grid is created.

    Returns
    -------
    The axes grid that was drawn into.
    """
    if columns is None:
        columns = df.columns
    y_names = columns[1:]
    x_names = columns[:-1]

    # ax[row, col]
    if ax is None:
        fig, ax = plt.subplots(nrows=len(y_names),
                               ncols=len(x_names),
                               sharex='col', sharey='row',
                               squeeze=False,
                               **kwargs)

        # x labels and scale go to the bottom row
        for col, x_name in enumerate(x_names):
            bottom = ax[-1, col]
            if log:
                bottom.set_xscale('log')
            bottom.set_xlabel(x_name.replace('_',' '), rotation=0)

        for row, y_name in enumerate(y_names):
            # y labels and scale go to the first column
            left = ax[row, 0]
            left.set_ylabel(y_name.replace('_',' '), rotation=90)
            if log:
                left.set_yscale('log')

            # panels on and below the diagonal get a grid, the others are hidden
            for col in range(len(x_names)):
                if col <= row:
                    ax[row, col].grid()
                else:
                    ax[row, col].axis('off')

    for row, y_name in enumerate(y_names):
        for col, x_name in enumerate(x_names[:row + 1]):
            ax[row, col].scatter(df[x_name], df[y_name], s=size, c=color, alpha=alpha)

    return ax

In [None]:
import sklearn.cluster

# define a container for a ML-Model
class ClusterModel:
    """Container that fits a clustering model and ranks its classes by size.

    The selected columns of `df` are standardised, the model is fitted on a
    random sample of the rows and then used to label all rows.

    Attributes
    ----------
    name : str
        Name used in plot legends.
    model
        The fitted model.
    labels : numpy.ndarray
        Label of each row of `df` as returned by ``model.predict``.
    clusters, clusters_counts : numpy.ndarray
        The distinct labels and their counts, most frequent label first.
    labels_cs : numpy.ndarray
        Count-sorted label of each row: 0 for the most frequent class, 1 for the
        second most frequent one, and so on.
    df : pandas.DataFrame
        The input frame.
    """

    def __init__(self, model, df, columns, n=1000, name='', *kwargs, seed=None):
        """Fit `model` on a sample of `df[columns]` and label all rows.

        Parameters
        ----------
        model
            Object providing ``fit(X)`` and ``predict(X)``, e.g. ``sklearn.cluster.KMeans``.
        df : pandas.DataFrame
            Input data.
        columns : list of str
            Columns of `df` used as features.
        n : int, optional
            Number of rows used for the fit; capped at the number of rows.
        name : str, optional
            Name of the model for plot legends.
        *kwargs
            Additional positional arguments are accepted for backward
            compatibility and ignored.
        seed : int, optional
            Seed for the row sampling; None gives a different sample on each call.
        """
        # Only `sklearn.cluster` is imported by the notebook; import the
        # submodule explicitly instead of relying on it being loaded implicitly.
        import sklearn.preprocessing

        self.name = name
        # define the model
        self.model = model

        # fit the model on n distinct rows of the standardised features
        X = sklearn.preprocessing.StandardScaler().fit_transform(df[columns].to_numpy())
        rng = np.random.default_rng(seed)
        sample = rng.choice(len(X), size=min(n, len(X)), replace=False)
        self.model.fit(X[sample])

        # label all rows
        self.labels = self.model.predict(X)

        # sort the labels by count, most frequent first
        clusters, clusters_counts = np.unique(self.labels, return_counts=True)
        order = np.argsort(-clusters_counts)
        self.clusters = clusters[order]
        self.clusters_counts = clusters_counts[order]

        # translate each label to its rank in the count-sorted list
        self.labels_cs = np.full(np.shape(self.labels), -1, dtype=int)
        for i, c_i in enumerate(self.clusters):
            self.labels_cs[self.labels == c_i] = i

        self.df = df

    def plot_level_hist(self, ax=None, norm_x=True, norm_y=True):
        """Plot the class counts, most frequent class first.

        Parameters
        ----------
        ax : matplotlib axes, optional
            Axes to draw into. If None, a new figure with a log y-axis and axis
            labels is created.
        norm_x : bool, optional
            If True, the classes are spread over [0, 1]; otherwise the class
            ranks 0, 1, 2, ... are used as x-values.
        norm_y : bool, optional
            If True, the counts are divided by the largest count.

        Returns
        -------
        The axes that was drawn into.
        """
        if ax is None:
            plt.figure()
            ax = plt.gca()
            ax.set_yscale('log')
            ax.set_xlabel('Classification Typ')
            ax.set_ylabel('Count')

        n_classes = len(self.clusters_counts)
        if norm_x:
            class_typ = np.linspace(0, 1, n_classes)
        else:
            # one x-value per class (N+1 values do not match the N counts)
            class_typ = np.arange(n_classes)

        norm_y_s = self.clusters_counts.max() if norm_y else 1

        ax.plot(class_typ, self.clusters_counts/norm_y_s, label=f'{self.name.replace("_","-")} {len(self.clusters)}')
#         plt.hist(self.labels_cs, bins=len(self.clusters_counts))
        return ax

In [None]:
# Train model
columns = ['n_pixel', 'charge']
columns_log = ['n_pixel', 'charge_log']

n = 20000
n_clusters = 4
k_means_4 = ClusterModel(sklearn.cluster.KMeans(n_clusters=n_clusters), df, columns, name='KMeans_4', n=n)
# k_means_4_log = ClusterModel(sklearn.cluster.KMeans(n_clusters=n_clusters), df, columns_log, name='KMeans_4_log', n=n)

n_clusters = 2
k_means_2 = ClusterModel(sklearn.cluster.KMeans(n_clusters=n_clusters), df, columns, name='KMeans_2', n=n)
# k_means_2_log = ClusterModel(sklearn.cluster.KMeans(n_clusters=n_clusters), df, columns_log, name='KMeans_2_log', n=n)

# n = 2000
# aff_pro = ClusterModel(sklearn.cluster.AffinityPropagation(damping=.7), df, columns, name='AffinityPropagation')
# aff_pro_log = ClusterModel(sklearn.cluster.AffinityPropagation(damping=.7), df, columns_log, name='AffinityPropagation_log')

# n = 20000
# mean_shift = ClusterModel(sklearn.cluster.MeanShift(n_jobs=-1), df, columns, name='MeanShift')
# mean_shift_log = ClusterModel(sklearn.cluster.MeanShift(n_jobs=-1), df, columns_log, name='MeanShift_log')



In [None]:
# models to compare; the disabled variants are kept for reference
algorithms = [
    k_means_2,  # k_means_2_log,
    k_means_4,  # k_means_4_log,
    # aff_pro, aff_pro_log,
    # mean_shift, mean_shift_log
]

# The first call creates the axes, the following calls draw into the same axes.
ax = None
for alg_i in algorithms:
    ax = alg_i.plot_level_hist(ax)

plt.legend()
plt.grid()
plt.tight_layout()

In [None]:
# Columns for the pairwise scatter plots
# (further candidates: 'plateau_sizes', 'left_thresholds', 'right_thresholds',
#  'charge_with_noise', 'noise')
columns = ['n_pixel',
           'charge',
           'sn_mean_deviation',
           'sn_mean_deviation_sigma']

# Colour and marker size follow the count-sorted class label, so the rarer
# classes are drawn with larger markers.
alg_i = k_means_4
stair_scatter(alg_i.df,
              columns=columns,
              color=alg_i.labels_cs,
              size=1+10*alg_i.labels_cs,
              alpha=1,
              figsize=(9,9))

plt.tight_layout()

In [None]:
alg_i = k_means_4

# distinct count-sorted labels, ordered from the rarest to the most frequent class
labels, counts = np.unique(alg_i.labels_cs, return_counts=True)
labels_rare_first = labels[np.argsort(counts)]

# Show the rows whose label lies above the label found at one third of that list.
threshold_label = labels_rare_first[len(labels)//3]
df[alg_i.labels_cs > threshold_label]

In [None]:
alg_i = k_means_2

# distinct count-sorted labels, ordered from the rarest to the most frequent class
labels, counts = np.unique(alg_i.labels_cs, return_counts=True)
labels_rare_first = labels[np.argsort(counts)]

# Show the rows whose label lies above the label found at the middle of that list.
threshold_label = labels_rare_first[len(labels)//2]
df[alg_i.labels_cs > threshold_label]