In [1]:
import os
import numpy as np
import glob
import json

import matplotlib.patches as patches

from astropy.io import fits
from astropy.wcs import WCS
from astropy.wcs.utils import skycoord_to_pixel
from astropy.coordinates import SkyCoord
import astropy.units as u
from astropy.coordinates import Angle

from sunpy.coordinates import frames
import sunpy.map as sunmap

import skimage.io as io

import MeanShift as MS

from copy import deepcopy
import pandas as pd

import clustering_utilities as c_utils
import tracking_utilities as t_utils

import importlib
importlib.reload(c_utils)

import ipywidgets as widgets
import matplotlib.pyplot as plt
import seaborn as sns

import panel as pn

from sklearn.model_selection import ParameterGrid
import concurrent.futures
from itertools import repeat
import multiprocessing

from tqdm.notebook import tqdm

from datetime import datetime, timedelta

%matplotlib ipympl

%load_ext autoreload
%autoreload 2





In [2]:
import matplotlib
print(matplotlib.__version__)


3.4.3


In [3]:
# wl_dir = "/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/image"
# wl_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/all"  
wl_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/all"  
# wl_dir = "/Users/nielssayez/Documents/Deepsun/Classification_dataset/2002-2019/all"  
# wl_dir = "/Users/nielssayez/Documents/Deepsun/Classification_dataset/ManualAnnotation/"  
wl_list = sorted(glob.glob(os.path.join(wl_dir, '**/*.FTS'),recursive=True))
wl_basenames = [ os.path.basename(wl) for wl in wl_list ]

# masks_dir = '/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/GroundTruth'
# masks_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/feb2023/T425-T375-T325_fgbg'
masks_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/T425-T375-T325_fgbg'
# masks_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/T400-T350-Alternating_pen_um'
# masks_dir = '/Users/nielssayez/Documents/Deepsun/Classification_dataset/ManualAnnotation/GroundTruth'

sqlite_db_path = "/globalscratch/users/n/s/nsayez/Classification_dataset/drawings_sqlite.sqlite"
# sqlite_db_path = "/Users/nielssayez/Documents/Deepsun/Classification_dataset/drawings_sqlite.sqlite"
database = sqlite_db_path
print(len(wl_list), )

3150


In [4]:
# rotten_list = [ ]
rotten_list = [                    
                    69,75,93,94,201,203,311,315,337,341,
                    403,409,420,441,613,615,668,726,743,
                    755,778,779,976,996,
                        
                    1036,1066,1081,1138,1255,1296,1337,1379,1398,1471,
                    1688,1735,1823,1925,1958,1989,1990,1995,
    
                    2018,2030,2073,2078,2104,2139,2151,2204,2205,2206,2230,2300,
                    2312,2325,2327,2349,2376,2452,2460,2461,2559,2673,2778,
                    2790,2793,2826,2836,2894,2897,2962,
    
                    3035,3142,3172,3178,3196,3217,3224,3242,3292,3296,
                    3312,3317,3335,3353,3389,3390,3415,3431,3510,3535,
                    3551,3571,3614,3664,3720,
              ]

In [5]:
# root_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset"
# tmp = root_dir+'/ManualAnnotation/wl_list2dbGroups_Classification.json'


# root_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/feb2023'
root_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2'
tmp = root_dir+'/wl_list2dbGroups_Classification.json'

In [6]:
huge_db_dict = { }
with open(tmp, 'r') as f:
    huge_db_dict = json.load(f)

In [7]:
print(len(huge_db_dict.keys()))
print(huge_db_dict[list(huge_db_dict.keys())[0]])

3150
{'group_list': [{'id': 164028, 'Latitude': 0.125985, 'Longitude': 1.413721, 'posx': 151.0, 'posy': 680.0, 'Lcm': -55.31, 'Zurich': 'J', 'McIntosh': 'Hax', 'angle': 57.08, 'area_px': 86.0, 'area_muHem': 103.43}, {'id': 164029, 'Latitude': 0.177233, 'Longitude': 1.204326, 'posx': 200.0, 'posy': 622.0, 'Lcm': -43.31, 'Zurich': 'D', 'McIntosh': 'Dao', 'angle': 45.38, 'area_px': 332.0, 'area_muHem': 316.74}, {'id': 164030, 'Latitude': -0.050828, 'Longitude': 0.8324, 'posx': 378.0, 'posy': 652.0, 'Lcm': -22.0, 'Zurich': 'A', 'McIntosh': 'Axx', 'angle': 22.04, 'area_px': 7.0, 'area_muHem': 4.95}, {'id': 164031, 'Latitude': -0.251044, 'Longitude': 1.076417, 'posx': 335.0, 'posy': 786.0, 'Lcm': -35.99, 'Zurich': 'C', 'McIntosh': 'Cro', 'angle': 36.57, 'area_px': 44.0, 'area_muHem': 36.03}, {'id': 164032, 'Latitude': -0.374039, 'Longitude': 1.620676, 'posx': 238.0, 'posy': 915.0, 'Lcm': -67.17, 'Zurich': 'H', 'McIntosh': 'Hkx', 'angle': 67.81, 'area_px': 133.0, 'area_muHem': 229.49}, {'id':

In [8]:
huge_dict = { }



In [9]:
def add_rejected_to_distributions(distributions, rejected_class):
    '''
    Add rejected to distributions according to the class
    '''
    if rejected_class not in distributions:
        distributions[rejected_class] = 0
    distributions[rejected_class] += 1

    return

def matching_in_wl(basename, huge_dict, ax= None):
    cur_image_dict = huge_dict[basename]
    
    angle = cur_image_dict["SOLAR_P0"]
    deltashapeX = cur_image_dict["deltashapeX"]
    deltashapeY = cur_image_dict["deltashapeY"]
    
    drawing_radius_px = huge_db_dict[basename]["dr_radius_px"]
    
    group_list = cur_image_dict['db']
    
    ms_dict = cur_image_dict['meanshift']
    
#     print(ms_dict)
    
    centroids = np.array(ms_dict["centroids"])
    centroids_px = np.array(ms_dict["centroids_px"])
    
    
    ms_centroids, ms_members = centroids_px, ms_dict['groups_px']
#     print('ms_members', ms_members)
    
    db_classes = [{"Zurich":item['Zurich'], "McIntosh":item['McIntosh'] } for item in group_list]
    db_bboxes = [np.array(item['bbox_wl']) for item in group_list]


    db_centers_px = np.array([[(b[2]+b[0])/2,(b[3]+b[1])/2] for b in db_bboxes])
    
    ########
    
    if ax is None:
        fig, ax = plt.subplots(figsize=(5,5))
        
    for i, bbox in enumerate(db_bboxes):
        ax.scatter(bbox[0],bbox[1])
        linestyle = '-'
        ax.add_patch(patches.Rectangle((bbox[0],bbox[1]),bbox[2]-bbox[0],bbox[3]-bbox[1],linewidth=1,edgecolor='b',facecolor='none', linestyle=linestyle))
    
    for g in ms_members:
        g = np.array(g)
        ax.scatter(g[:,1],g[:,0])
    ax.set_xlim(0,2048)
    ax.set_ylim(0,2048)

    for c in ms_centroids:
        ax.scatter(c[0],c[1], c='r', marker='x')
    # ax.invert_xaxis()
    # fig.show()
    ########
        
    # check that current bbox is does not overlap any
    isolated_bboxes_bool = np.array(c_utils.get_intersecting_db_bboxes(db_bboxes)) == 0
    isolated_bboxes_indices = np.where(isolated_bboxes_bool == True)[0]
#     print("isolated_bboxes_bool",isolated_bboxes_bool)
#     print(isolated_bboxes_indices)
    
    cur_rejected_class_distibutions = { 
            'noMS_but_DB': {},
            'singleMS_multipleDB': {}, 
            'num_oneDBbbox_multipleMSoverlap_ambiguity':{},
            # 'noDB_but_MS': {},
        }
    cur_out_stats = {
        # General info
        'num_DB_groups':len(db_bboxes),
        'num_MS_groups':len(centroids_px),
        'num_DB_isolated_groups':len(isolated_bboxes_indices),
        'num_DB_overlaping_bboxes':len(db_bboxes) - len(isolated_bboxes_indices),
        # MS with DB matching info
        "num_MSmatchesDB":0,
        # MS with DB rejection info
        "num_noMS_but_DB_reject":0,
        "num_singleMS_multipleDB_reject":0,
        "num_oneDBbbox_multipleMSoverlap_ambiguity_reject":0,
        # MS with DB no match info
        "num_noDB_but_MS":0,
        }
    cur_out_groups = []
    for i, (db_bbox, db_center, db_class) in enumerate(
                                            zip([db_bboxes[a] for a in isolated_bboxes_indices.tolist() ],
                                                [db_centers_px[a] for a in isolated_bboxes_indices.tolist()],
                                                [db_classes[a] for a in isolated_bboxes_indices.tolist()],
                                            )):
        
        
        
        intersect = c_utils.contains_ms_groups(db_bbox, db_center, ms_centroids, ms_members)
        
        if sum(intersect) == 0: # Il n'y a eu aucune détection dans cette zone
            cur_out_stats['num_noMS_but_DB_reject'] += 1
            cause = 'noMS_but_DB'
            add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
            pass
        elif sum(intersect) == 1: # il n'y a de l'overlap qu'avec un seul groupe meanshift            
#             print('hit')
            idx = intersect.index(True)
#             print(idx)
            # vérifier que le groupe meanshift n'intersecte aucune autre bbox
            num_intersections = np.sum(c_utils.count_group_intersections(ms_members[idx], db_bboxes))
            if num_intersections > 1:
                cur_out_stats['num_singleMS_multipleDB_reject'] += 1
                cause = 'singleMS_multipleDB'
                add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
                continue
            
            Rmm = huge_db_dict[basename]['dr_radius_mm']
            R_pixel = huge_db_dict[basename]['dr_radius_px']
            sun_center = huge_db_dict[basename]['dr_center_px']
            dr_pixpos = np.array([group_list[i]['posx'], group_list[i]['posy']])
            
            angular_excentricity =  c_utils.get_angle2(dr_pixpos, R_pixel, sun_center)
            
            cur_group_dict={
                            "centroid_px": centroids_px[idx],
                            "centroid_Lat": centroids[idx][0],
                            "centroid_Lon": centroids[idx][1],
                            "angular_excentricity_rad": angular_excentricity,
                            "angular_excentricity_deg": np.rad2deg(angular_excentricity),
                            "Zurich":   db_class["Zurich"],
                            "McIntosh": db_class["McIntosh"],
                            "members": ms_members[idx],
                            "members_mean_px": np.mean(ms_members[idx], axis=0),
                        }
            
            
            cur_out_groups.append(cur_group_dict)
            cur_out_stats['num_MSmatchesDB'] += 1

        else: # db_bbox intersecte plusieurs groupes meanshift
            cur_out_stats['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'] += 1
            cause = 'num_oneDBbbox_multipleMSoverlap_ambiguity'
            add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
            pass

    if len(centroids>0):
        # print(centroids)
        # print()
        # count the number of MS groups that do not have any overlap with the DB
        num_intersections_per_group = [np.sum(c_utils.count_group_intersections(ms_members[idx], db_bboxes)) for idx in range(len(ms_members))]
        num_MS_without_DB_overlap = len(np.where(np.array(num_intersections_per_group) == 0)[0])
        cur_out_stats['num_noDB_but_MS'] = num_MS_without_DB_overlap
    
    # print(cur_rejected_class_distibutions)
    cur_out_stats['rejected_class_distributions'] = deepcopy(cur_rejected_class_distibutions)
            

    return cur_out_groups, cur_out_stats,  cur_rejected_class_distibutions


In [10]:
import matplotlib
colors = ['tab:blue','tab:orange','tab:green','tab:red',
          'tab:purple','tab:brown','tab:pink',
          'tab:olive','tab:cyan']


def rotate_pt_around_center(point, rotation_pt, angle):
    '''
    @param point: (x,y) tuple of the point to rotate
    @param rotation_pt: (x,y) tuple of the point to rotate around
    @param angle: angle to rotate in degrees
    '''
    angle = np.deg2rad(angle)
    x, y = point
    x0, y0 = rotation_pt
    x1 = x0 + np.cos(angle) * (x - x0) - np.sin(angle) * (y - y0)
    y1 = y0 + np.sin(angle) * (x - x0) + np.cos(angle) * (y - y0)
    return x1, y1


def new_refresh(value):
#     print('new_refresh',1)
#     print('here')
    look_distance = look_distance_slider.value
    kernel_bandwidthLon = kernel_bandwidthLon_slider.value
    kernel_bandwidthLat = kernel_bandwidthLat_slider.value
    n_iterations = n_iterations_slider.value
    
#     print('new_refresh',1)
    basename = os.path.basename(wl_list[img_slider.value]).split(".")[0]
###########
    huge_dict = {}
    result_key, result_dict =  c_utils.process_one_image( wl_list[img_slider.value],
                                huge_db_dict,
                                huge_dict,
                                wl_list,
                                rotten_list,
                                masks_dir,
                                look_distance,
                                kernel_bandwidthLon,
                                kernel_bandwidthLat,
                                n_iterations,
                                input_type=input_type
                            )
    
#     print(basename, result_key)
#     print(result_dict)
    
    huge_dict[result_key] = result_dict
#     print(huge_dict)
    ax3[2].clear()
    matchings, matchings_stats, rejects_distrib = matching_in_wl(basename,huge_dict,ax3[2])
    # print(matchings)
    # print(matchings_stats)
    # print(rejects_distrib)
    matchings_ta.value =  result_key + '\n' + json.dumps(matchings_stats,sort_keys=True, indent=4)

##################
    cur_db_dict = huge_db_dict[basename]
    date = cur_db_dict["wl_date"] 

    m, h = t_utils.open_and_add_celestial(wl_list[img_slider.value])
    corrected = False
    if not 'DATE-OBS' in h:
        # print('corrected')
        m, h = t_utils.open_and_add_celestial2(wl_list[img_slider.value], date_obs=date)
        corrected = True
    # print('radius =', h['SOLAR_R'])
    
    
#     mask = io.imread(os.path.join(masks_dir,basename+".png"))
    
    if input_type == "mask":
        mask = io.imread(os.path.join(masks_dir,basename+".png"))
    elif input_type == "confidence_map":
        mask = np.load(os.path.join(masks_dir,basename+"_proba_map.npy"))
    mask2 = mask.copy()
    mask2[mask2>0] = 1
    
    flip_time = "2003-03-08T00:00:00"
    should_flip = (datetime.fromisoformat(date) - datetime.fromisoformat(flip_time)) < timedelta(0)
    if should_flip:
        m = sunmap.Map(np.flip(m.data,axis=0), h)      
        mask = np.flip(mask,axis=0)
        mask2 = np.flip(mask2,axis=0)


    cur_db_dict = huge_db_dict[basename]
    group_list = cur_db_dict["group_list"]
    drawing_radius_px = cur_db_dict["dr_radius_px"]
    date = cur_db_dict["wl_date"]
    
    Rmm = cur_db_dict["dr_radius_mm"]

    sunspots_sk, sunspots_areas_muHem = c_utils.get_sunspots4(h,m, mask2, Rmm, sky_coords=True)
    sunspots_pixel, sunspots_areas = c_utils.get_sunspots4(h, m, mask2, Rmm, sky_coords=False)
#     sunspots_sk, sunspots_areas = c_utils.get_sunspots3(h,m, mask2, sky_coords=True)
#     sunspots_pixel, _ = c_utils.get_sunspots3(h, m, mask2, sky_coords=False)

    wcs2 = WCS(h)
    
    if sunspots_pixel is not None:


        sk_Lon = sunspots_sk.lon.rad
        sk_Lat = sunspots_sk.lat.rad
        sk_LatLon = np.stack((sk_Lat,sk_Lon),axis=1)

        sunspots_areas_muHem = np.array(sunspots_areas_muHem)
#         print(sunspots_sk)
#         print(sunspots_pixel)
#         print(sunspots_areas_muHem)


        nan_indexes = np.unique(np.argwhere(np.isnan(sk_LatLon))[:,0])
#         print('nan_indexes', nan_indexes)
        clean = (~np.isnan(sk_Lon) & ~np.isnan(sk_Lat))
        if len(nan_indexes) > 0:
            sunspots_sk = sunspots_sk[clean]
            sunspots_areas = (np.array(sunspots_areas)[clean]).tolist()
            sunspots_areas_muHem = np.array(sunspots_areas_muHem)[clean]
            sunspots_pixel = sunspots_pixel[clean]
            sk_LatLon = sk_LatLon[clean]
#             print(len(sk_LatLon),len(sunspots_areas_muHem))



        global ms_model
    #     print(sunspots_sk.radius.km[0])
        ms_model = MS.Mean_Shift(look_distance, kernel_bandwidthLon, kernel_bandwidthLat, sunspots_sk.radius.km[0], n_iterations, max_scaled_area_muHem=200)

    #     ms_model.fit(sk_LatLon, sunspots_areas)
        ms_model.fit(sk_LatLon, sunspots_areas_muHem)
        
#         print(ms_model.history)

        ms_centroids = ms_model.centroids

        sk_sequ_meanshift = SkyCoord(ms_centroids[:,1]*u.rad, ms_centroids[:,0]*u.rad , frame=frames.HeliographicCarrington,
                            obstime=m.date, observer="earth")

        pix_centers_meanshift = np.array(skycoord_to_pixel(sk_sequ_meanshift, wcs2, origin=0)).T.tolist()

        ms_classifications = ms_model.predict(sk_LatLon)
#         print(ms_classifications.shape, np.unique(ms_classifications))

        ms_group_sunspots = [(sk_LatLon[ms_classifications == c].tolist()) for c in np.unique(ms_classifications)] 
        ms_group_sunspots_px = [sunspots_pixel[ms_classifications == c].tolist() for c in np.unique(ms_classifications)]

        ms_group_sunspots_areas = [sunspots_areas_muHem[ms_classifications == c].tolist() for c in np.unique(ms_classifications)]
        # print(len(ms_group_sunspots))

    #latitude and longitudes in radians
    extreme_values = -np.pi, np.pi, 0, 2*np.pi

    dr_obstime = date+'.000'  
    all_sks = []
    all_pixels = []
    for item in group_list:
        cur_sk = SkyCoord(item["Longitude"]*u.rad, item["Latitude"]*u.rad , frame=frames.HeliographicCarrington,
                      obstime=dr_obstime, observer="earth") 
        coords_wl = skycoord_to_pixel(cur_sk, wcs2, origin=0)
        all_sks.append(cur_sk)
        all_pixels.append(coords_wl)

    bboxes, bboxes_wl, rectangles, rectangles_wl = c_utils.grouplist2bboxes_and_rectangles(group_list, 
                                                                                   drawing_radius_px,
                                                                                   h["SOLAR_R"],
                                                                                   all_pixels)

    # rotated_bboxes = [plt.Rectangle((bbox_wl[0], bbox_wl[1]),
    #                                 bbox_wl[2]-bbox_wl[0], bbox_wl[3]-bbox_wl[1],
    #                               color='b', fill=False,) for bbox_wl in bboxes_wl]
    #                             #   color='b', fill=False, angle=h["SOLAR_R"]) for bbox_wl in bboxes_wl]
                                
    rotated_bbox_ref = [ rotate_pt_around_center((bbox_wl[0], bbox_wl[1]), 
                                                (bbox_wl[0]+(bbox_wl[2] - bbox_wl[0] )/2 , 
                                                    bbox_wl[1]+(bbox_wl[3] - bbox_wl[1] )/2), 
                                                -h["SOLAR_R"]) for bbox_wl in bboxes_wl]
    rotated_bboxes = [plt.Rectangle((rotated_bbox_ref[i][0], rotated_bbox_ref[i][1]),
                                                 bbox_wl[3]-bbox_wl[1], bbox_wl[2]-bbox_wl[0],
                                                angle=-h["SOLAR_R"],color='b', fill=False,) 
                                        for i,bbox_wl in enumerate(bboxes_wl)]
                                                                                   
    # print(sk_LatLon.shape)
    ax3[0].clear(), ax3[1].clear()#, ax3[2].clear()
    ax3[0].set_title(basename)
    ax3[0].imshow(m.data,cmap='gray')
    ax3[0].imshow(mask,cmap='jet',alpha=0.5)
    for i, r in enumerate(rotated_bboxes):
        ax3[0].add_patch(r)
        if info_cb.value:
            ax3[0].text(rotated_bbox_ref[i][0], rotated_bbox_ref[i][1], 
                          f' {group_list[i]["McIntosh"]} : {group_list[i]["area_muHem"]}',color='b')
    ax3[0].invert_yaxis()

    ax3[1].imshow(m.data,cmap='gray')
    ax3[1].invert_yaxis()
    
    ax3[1].set_title(f'angle: {h["SOLAR_P0"]}, should_flip: {should_flip}')
    if sunspots_pixel is not None:
        for i in range(len(ms_group_sunspots)):
            c = colors[np.unique(ms_classifications)[i]%len(colors)]
            # c = colors[ms_classifications[i]%len(colors)]
            cur = np.array(ms_group_sunspots_px[i])
            ms_centers = pix_centers_meanshift[i]
            ax3[1].scatter(ms_centers[0],ms_centers[1], s=10, c=c, marker='x')
            ax3[1].scatter(cur[:,1], cur[:,0], color=c, s=2)
    #         print(ms_group_sunspots_areas[i])
            if info_cb.value:
                for j in range(len(cur)):
                    ax3[1].text(cur[j,1], cur[j,0],  '%.2f' % ms_group_sunspots_areas[i][j] ,va='top',c=c)

    #     ax3[2].scatter(ms_centroids[:,1], ms_centroids[:,0], s=1)
    #     ax3[2].set_ylim(extreme_values[0], extreme_values[1])
    #     ax3[2].set_xlim(extreme_values[2], extreme_values[3])

        # ax5[1].scatter(sk_LatLon[:,1], sk_LatLon[:,0], s=1)
        # ax5[1].set_xlim(0, 2*np.pi)
        # ax5[1].set_ylim(-np.pi/2, np.pi/2)

    
    
    hist_refresh(None)
#create a list of matplotlib colors
colors = [ "red", "green", "blue", "orange", "purple", "brown", "pink", "gray", "olive", "cyan", "magenta", "yellow"] 

def hist_refresh(change):
    # if (xlims0, ylims0) != (0, 1):
    xlims0 = ax5.get_xlim()
    ylims0 = ax5.get_ylim()
    # print(xlims0, ylims0)
    
    look_distance = look_distance_slider.value
    kernel_bandwidthLon = kernel_bandwidthLon_slider.value
    kernel_bandwidthLat = kernel_bandwidthLat_slider.value
    n_iterations = n_iterations_slider.value

    global ms_model
    
        
    # print(ms_model.history)
    step = hist_slider.value
    ax5.clear()
    # ax5.set_title('History step {}'.format(step))
    if ms_model is not None:
        for i in range(len(ms_model.history[step])):
            cur_width = ms_model.get_area_weighted_ellipsis_width( ms_model.areas[i], ms_model.areas)
            cur_height = kernel_bandwidthLat
            cur_color = colors[i%len(colors)]
            # if (ms_model.data[i][0] <= 0.28 and ms_model.data[i][0] >= 0.25) and (ms_model.data[i][1] >= 1.45 and ms_model.data[i][1] <= 1.48):
            #     print(cur_width, cur_height)
            ellipsis = matplotlib.patches.Ellipse((ms_model.history[step][i,1], ms_model.history[step][i,0]), 2*cur_width, 2*cur_height, fill=False, color=cur_color)
            # ellipsis = matplotlib.patches.Ellipse((ms_model.history[step][i,1], ms_model.history[step][i,0]), 2*cur_width, 2*cur_height, fill=False, color='red')
            # ellipsis = matplotlib.patches.Ellipse((ms_model.history[step][i,1], ms_model.history[step][i,0]), 2*kernel_bandwidthLon, 2*kernel_bandwidthLat, fill=False, color='red')
            ax5.add_patch(ellipsis)
        # ax5.scatter(ms_model.history[0][:,1], ms_model.history[0][:,0], s=4, marker='X', c='g')
        # ax5.scatter(ms_model.history[step][:,1], ms_model.history[step][:,0], s=3)
        
        for j in range(len(ms_model.history[0])):
            cur_color = colors[j%len(colors)]
            print(cur_color)
            ax5.scatter(ms_model.history[0][j,1], ms_model.history[0][j,0], s=20, marker='o', color=cur_color, alpha = 0.5)

        for i in range(len(ms_model.history[step])):
            cur_color = colors[i%len(colors)]
            ax5.scatter(ms_model.history[step][i,1], ms_model.history[step][i,0], s=30, marker='+', c=cur_color)

    ax5.set_ylim(-np.pi/2, np.pi/2)
    ax5.set_xlim(0, 2*np.pi)
    ax5.set_xlabel('Longitude [rad]')
    ax5.set_ylabel('Latitude [rad]')
    # ax5.set_xlim(np.min(ms_model.data[:,1]), np.max(ms_model.data[:,1]))
    if (xlims0, ylims0) != ((0., 1.),(0., 1.)):
        ax5.set_xlim(xlims0)
        ax5.set_ylim(ylims0)
    

# look_distance = .1 # How far to look for neighbours.
# kernel_bandwidthLon = .2  # Longitude Kernel parameter.
# kernel_bandwidthLat = .1  # Latitude Kernel parameter.
# n_iterations = 20 # Number of iterations
# input_type = 'confidence_map'
input_type = 'mask'
    
ms_model = None
# 830, 536, 509, 24, 31, 183
#30
# img_slider = widgets.IntSlider(min=0, max=len(wl_list)-1, step=1, value=24, description='Image')
#début des problèmes 349
img_slider = widgets.IntSlider(min=0, max=len(wl_list)-1, step=1, value=2183, description='Image')
# img_slider = widgets.IntSlider(min=0, max=len(wl_list)-1, step=1, value=1505, description='Image')
info_cb = widgets.Checkbox(value=False, description='Info', disabled=False, indent=False)

max_n_iterations = 20
look_distance_slider = widgets.FloatSlider(min=.1,max=.1, description='look_distance')
kernel_bandwidthLon_slider = widgets.FloatSlider(min=.35,max=.45,step=.05, description='kernel_bandwidthLon')
kernel_bandwidthLat_slider = widgets.FloatSlider(min=.08,max=.2,step=.02, description='kernel_bandwidthLat')
n_iterations_slider = widgets.IntSlider(min=20,max=max_n_iterations, description='look_distance')

img_slider.observe(new_refresh, 'value')
info_cb.observe(new_refresh, 'value')

look_distance_slider.observe(new_refresh,'value')
kernel_bandwidthLon_slider.observe(new_refresh,'value')
kernel_bandwidthLat_slider.observe(new_refresh,'value')
n_iterations_slider.observe(new_refresh,'value')

hist_slider = widgets.IntSlider(min=0, max=max_n_iterations-1, step=1, value=0, description='History')
hist_slider.observe(hist_refresh, 'value')

matchings_ta = widgets.Textarea(description='Matchings', value='',layout=widgets.Layout(height="100%", width="100%"))
def matchings_refresh(change):
    matchings_ta.rows = matchings_ta.value.count('\n') + 1
matchings_ta.observe(matchings_refresh, 'value')

plt.ioff()
fig3,ax3 = plt.subplots(1,3,figsize=(9,3))
fig5,ax5 = plt.subplots(1,1,figsize=(5,2.5))
fig3.tight_layout()
# print('ICI1')
new_refresh(None)
# print('ICI2')
hist_refresh(None)
plt.ion()

rotten_list = [
    
    37,38,39,40,52, 64,65,69,70,
    
    72,97,99,100,101,102,103,104,142,159,160,161,169,187,190,210,211,212,218,264,300,312,314,316,319,322,327,339,
    343,353,356,387,408,413,414,418,424,425,448,473,474,493,512,508,611,614,666,675,696,726,330,747,750,758,
    761,784,804,823,832,840,855,914,935,940,948,990,1013,
    
    1025,1039,1040,1089,1172,1303,1332,1345,1397,1409,1413,1414,1421,1440,1444,1468,1469,1488,1576,1646,1692,
    1735,1815,1840,1867,1893,1900,1905,1919,1924,1925,1930,1953,1969,1992,
    
    2007,2039,2043,2045,2049,2050,2078,2121,2133,2143,2185,2208,2220,2254,2266,2272,2298,2344,3262,3274,2375,
    2445,2454,2468,2492,2494,2495,2500,2501,2503,2516,2518,2536,2568,2574,2598,2604,2633,2635,2749,2763,2815,
    2818,2820,2821,2834,2835,2851,2857,2867,2896,2899,2848,2951,2952,2956,2964,2980,2981,2994,
    
    3018,3092,3093,3097,3099,3101,3106,3118,3122,3123,3124,3140,3148
    
    
]



display(widgets.HBox([widgets.VBox([look_distance_slider,
                          kernel_bandwidthLon_slider,
                          kernel_bandwidthLat_slider,
                          n_iterations_slider,img_slider, info_cb], layout=widgets.Layout(width='25%',object_position='bottom') ),
                      fig3.canvas]))
# display(widgets.HBox([matchings_ta, widgets.VBox([widgets.HBox([hist_slider]), fig5.canvas])]))
display(widgets.HBox([widgets.HBox([matchings_ta],layout=widgets.Layout(width='50%')), 
                      widgets.HBox([widgets.VBox([widgets.HBox([hist_slider]), fig5.canvas])])
                     ]))
# display(matchings_ta)

# text_area_input = pn.widgets.input.TextAreaInput(name='Text Area Input', placeholder='Enter a string here...')
# display(text_area_input)

red
green
blue
orange
purple
brown
pink
gray
olive
cyan
magenta
yellow
red
green
blue
orange
purple
brown
pink
gray
red
green
blue
orange
purple
brown
pink
gray
olive
cyan
magenta
yellow
red
green
blue
orange
purple
brown
pink
gray


HBox(children=(VBox(children=(FloatSlider(value=0.1, description='look_distance', max=0.1, min=0.1), FloatSlid…

HBox(children=(HBox(children=(Textarea(value='UPH20110413081253\n{\n    "num_DB_groups": 5,\n    "num_DB_isola…

In [11]:
param_grid_values = {
    'look_distance' : [0.1],
    # 'kernel_bandwidthLon' : [ 0.05 , 0.1, 0.15, .2,.21,.22,.23,.24,.25, .3,.35,.45],
    # 'kernel_bandwidthLat' : [.08,],
    'kernel_bandwidthLon' : [ 0.1, 0.15, .2, .25, .3, .35, .4, .45, .5, .55, .6, .65, .7, .75],
#     'kernel_bandwidthLon' : [ .25],
#     'kernel_bandwidthLat' : [ 0.02, 0.04 , 0.06, .08, .1, .12, .16, .18, .2],
    'kernel_bandwidthLat' : [ 0.02, 0.04 , 0.06, .08, .1, .12],
    # 'kernel_bandwidthLat' : [ 0.05 , 0.06, 0.07, .08, .09, .1, .11, .12],
    'n_iterations' : [20],
}

param_grid = ParameterGrid(param_grid_values)
num_cpu = 20
# num_cpu = 30
input_type = "mask"

grid_huge_dict = {str(i): {} for i,_ in enumerate(param_grid)}

In [None]:

for i, params in tqdm(enumerate(param_grid)):
    # if i > 0:
    #     break
    print(i, params)
    look_distance = params['look_distance']  # How far to look for neighbours.
    kernel_bandwidthLon = params['kernel_bandwidthLon']  # Longitude Kernel parameter.
    kernel_bandwidthLat = params['kernel_bandwidthLat']  # Latitude Kernel parameter.
    n_iterations = params['n_iterations'] # Number of iterations

    
    cur_huge_dict = { } if str(i) not in grid_huge_dict else grid_huge_dict[str(i)]
    print(len(list(cur_huge_dict.keys())))    
#     grid_huge_dict[i] = {'a':0}


    with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
        for result_key, result_dict in tqdm(executor.map(c_utils.process_one_image, 
                                                            wl_list[:],
                                                            repeat(huge_db_dict),
                                                            repeat(cur_huge_dict),
                                                            repeat(wl_list),
                                                            repeat(rotten_list),
                                                            repeat(masks_dir),
                                                            repeat(look_distance),
                                                            repeat(kernel_bandwidthLon),
                                                            repeat(kernel_bandwidthLat),
                                                            repeat(n_iterations),
                                                            repeat(input_type)
                                                            )):
            # print(result_key)
            if not len(list(result_dict.keys())) == 0:
                print('here')
#                 cur_huge_dict[result_key] = deepcopy(result_dict)
                grid_huge_dict[str(i)][result_key] = deepcopy(result_dict)

#     grid_huge_dict[i] = deepcopy(cur_huge_dict)




    
    

In [11]:
class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.int64):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super(NpEncoder, self).default(obj)

In [9]:

    
    
import json 
tmp = './grid_search_huge_dict_2002-19.json'
with open(tmp, 'w') as f:
    json.dump(grid_huge_dict, f, cls=NpEncoder)

NameError: name 'grid_huge_dict' is not defined

In [16]:
param_grid_values = {
    'look_distance' : [0.1],
    # 'kernel_bandwidthLon' : [ 0.05 , 0.1, 0.15, .2,.21,.22,.23,.24,.25, .3,.35,.45],
    'kernel_bandwidthLat' : [.08,],
#     'kernel_bandwidthLon' : [ 0.1, 0.15, .2, .25, .3, .35, .4, .45, .5, .55, .6, .65, .7, .75],
    'kernel_bandwidthLon' : [ .35],
#     'kernel_bandwidthLat' : [ 0.02, 0.04 , 0.06, .08, .1, .12, .16, .18, .2],
    'kernel_bandwidthLat' : [ .08,],
#     'kernel_bandwidthLat' : [ 0.02, 0.04 , 0.06, .08, .1, .12],
    # 'kernel_bandwidthLat' : [ 0.05 , 0.06, 0.07, .08, .09, .1, .11, .12],
    'n_iterations' : [20],
}

param_grid = ParameterGrid(param_grid_values)

import json 
tmp2 = './grid_search_huge_dict.json'
with open(tmp2, 'r') as f:
    grid_huge_dict = json.load(f)

In [14]:
param_grid_values

{'look_distance': [0.1],
 'kernel_bandwidthLat': [0.08],
 'kernel_bandwidthLon': [0.35],
 'n_iterations': [20]}

# Functions

In [12]:
import cv2
from matplotlib import pyplot as plt
import matplotlib.patches as patches

def compute_IoUs(db_bboxes, ms_bboxes):
    '''Compute IoUs between all bboxes in db_bboxes and ms_bboxes.'''
    db_bboxes = np.array(db_bboxes)
    ms_bboxes = np.array(ms_bboxes)
    # print(db_bboxes.shape)
    # print(ms_bboxes.shape)

    # if db_bboxes is single bbox, convert to 2D array.
    if db_bboxes.ndim == 1:
        db_bboxes = np.array([db_bboxes])
    # if ms_bboxes is single bbox, convert to 2D array.
    if ms_bboxes.ndim == 1:
        ms_bboxes = np.array([ms_bboxes])

    # if db_bboxes is empty, return empty 2D array.
    if db_bboxes.size == 0:
        return np.array([[]])

    # print(db_bboxes.shape)
    # print(ms_bboxes.shape)
    # db_bboxes has shape (N, 4) and ms_bboxes has shape (M, 4).
    # each element has the form [x1, y1, x2, y2].
    # We want to compute IoUs between all N and M boxes.
    db_bboxes = db_bboxes.astype(np.float)
    ms_bboxes = ms_bboxes.astype(np.float)
    # print(db_bboxes.shape)
    # print(ms_bboxes.shape)

    # Compute areas of all db_bboxes and ms_bboxes.
    area_db = (db_bboxes[:, 2] - db_bboxes[:, 0]) * (db_bboxes[:, 3] - db_bboxes[:, 1])
    area_ms = (ms_bboxes[:, 2] - ms_bboxes[:, 0]) * (ms_bboxes[:, 3] - ms_bboxes[:, 1])
    # print(area_db.shape, area_ms.shape)

    # compute intersections
    # intersections has shape (N, M) and intersections[i, j] is the intersection
    # between db_bboxes[i] and ms_bboxes[j].
    intersections = np.zeros((db_bboxes.shape[0], ms_bboxes.shape[0]))
    for i in range(db_bboxes.shape[0]):
        for j in range(ms_bboxes.shape[0]):
            x1 = max(db_bboxes[i, 0], ms_bboxes[j, 0])
            y1 = max(db_bboxes[i, 1], ms_bboxes[j, 1])
            x2 = min(db_bboxes[i, 2], ms_bboxes[j, 2])
            y2 = min(db_bboxes[i, 3], ms_bboxes[j, 3])
            intersections[i, j] = max(x2 - x1, 0) * max(y2 - y1, 0)
    # print(intersections.shape)
    # print(intersections)

    # compute unions
    unions = area_db[:, np.newaxis] + area_ms[np.newaxis, :] - intersections

    # compute IoUs
    ious = intersections / unions


    return ious

def compute_distances(db_bboxes, ms_bboxes):
    '''Compute distances between all bboxes in db_bboxes and ms_bboxes.'''
    db_bboxes = np.array(db_bboxes)
    ms_bboxes = np.array(ms_bboxes)

    # if db_bboxes is single bbox, convert to 2D array.
    if db_bboxes.ndim == 1:
        db_bboxes = np.array([db_bboxes])
    # if ms_bboxes is single bbox, convert to 2D array.
    if ms_bboxes.ndim == 1:
        ms_bboxes = np.array([ms_bboxes])

    # if db_bboxes is empty, return empty 2D array.
    if db_bboxes.size == 0:
        return np.array([[]])

    # db_bboxes has shape (N, 4) and ms_bboxes has shape (M, 4).
    # each element has the form [x1, y1, x2, y2].
    # We want to compute distances between all N and M boxes.
    db_bboxes = db_bboxes.astype(np.float)
    ms_bboxes = ms_bboxes.astype(np.float)

    # Compute centers of all db_bboxes and ms_bboxes.
    center_db = (db_bboxes[:, 2:] + db_bboxes[:, :2]) / 2
    center_ms = (ms_bboxes[:, 2:] + ms_bboxes[:, :2]) / 2
    # print('here')
    # print(center_db.shape)
    # print(center_ms.shape)

    distances = np.repeat( center_ms[ np.newaxis,:,:], center_db.shape[0], axis=0)

    # print(distances.shape)

    distances = np.sqrt(np.sum((distances - center_db[:,np.newaxis,:])**2, axis=2))

    # print(distances.shape)
    # print(distances)

    return distances


def find_closest_ms_bbox(db_bboxes, ms_bboxes, maximum_distance=300):
    '''Find the closest ms_bbox to db_bbox.'''
    ious = compute_IoUs(db_bboxes, ms_bboxes)
    distances = compute_distances(db_bboxes, ms_bboxes)
    

    # if ious is empty, closest_ms_bbox is empty.
    if (ious.size == 0) and (distances.size == 0):
        closest_ms_bbox = np.array([])
        closest_ms_bbox_idx = np.array([])
        
        closest_db_bbox = np.array([])
        closest_db_bbox_idx = np.array([])
    else:
        # Find the closest ms_bbox to each db_bbox.
        closest_ms_bbox = np.min(distances, axis=1)
        closest_ms_bbox_idx = np.argmin(distances, axis=1)
        # print("closest_ms_bbox_idx", closest_ms_bbox_idx)
        
        
        # Find the closest db_bbox to each ms_bbox.
        closest_db_bbox = np.min(distances.T, axis=1)
        closest_db_bbox_idx = np.argmin(distances.T, axis=1)
    
    # print(closest_ms_bbox.shape)
    # print(closest_ms_bbox_idx.shape)
    # print(closest_ms_bbox)
    # print(closest_ms_bbox_idx)

    # replace the closest ms_bbox with -1 if it is too far away.
    closest_ms_bbox[closest_ms_bbox > maximum_distance] = -1
    closest_ms_bbox_idx[closest_ms_bbox == -1 ] = -1
     
    # replace the closest db_bbox with -1 if it is too far away.
    closest_db_bbox[closest_db_bbox > maximum_distance] = -1
    closest_db_bbox_idx[closest_db_bbox == -1 ] = -1

    # count the number of db_bboxes that are too far away.
    db_too_far = closest_ms_bbox == -1
    num_db_too_far = np.sum(db_too_far)
    db_too_far_idx = np.where(db_too_far)[0]
    # count the number of ms_bboxes that are too far away.
    ms_too_far = closest_db_bbox == -1
    num_ms_too_far = np.sum(ms_too_far)
    ms_too_far_idx = np.where(ms_too_far)[0]
    

    # get the ious of the closest ms_bbox to each db_bbox.
    closest_ms_bbox_iou = np.array([ious[i,closest_ms_bbox_idx[i]] for i in range(len(db_bboxes))])
    closest_ms_bbox_iou[closest_ms_bbox == -1] = -1


    multiDB_singleMS_idx = []
    num_multiDB_singleMS = 0
    # if two db_bboxes have the same closest ms_bbox, keep the one with the highest IoU.
    for i in np.unique(closest_ms_bbox_idx):
        if i == -1:
            continue
        # get the indexes of the db_bboxes that have the same closest ms_bbox.
        idx = np.where(closest_ms_bbox_idx == i)[0]
        if len(idx) > 1:
            # keep the db_bbox with the highest IoU.
            highest_iou_idx = np.argmax(closest_ms_bbox_iou[idx])
            # add the indexes of the other db_bboxes to multiDB_singleMS.
            multiDB_singleMS_idx.extend(idx[np.arange(len(idx)) != highest_iou_idx])

            # set the closest ms_bbox of the other db_bboxes to -1.
            closest_ms_bbox_idx[idx[np.arange(len(idx)) != highest_iou_idx]] = -1
            closest_ms_bbox[idx[np.arange(len(idx)) != highest_iou_idx]] = -1
            closest_ms_bbox_iou[idx[np.arange(len(idx)) != highest_iou_idx]] = -1
            # get the number of indexes that were set to -1.
            num_multiDB_singleMS += len(idx) - 1
    assert num_multiDB_singleMS == len(multiDB_singleMS_idx)
    multiDB_singleMS_idx = np.array(multiDB_singleMS_idx)
    

    # get the indexes of the ms_bboxes that appear as the closest ms_bbox to some db_bbox.
    ms_bbox_idx = np.unique(closest_ms_bbox_idx)
    
    # remove the -1 index.
    ms_bbox_idx = ms_bbox_idx[ms_bbox_idx != -1]
    # print("ms_bbox_idx", ms_bbox_idx)

    # print(closest_ms_bbox, closest_ms_bbox_idx)
    # print(ms_bbox_idx)
    
    candidates_indexes = np.array(range(len(ms_bboxes)))
    unmatched_ms = np.setdiff1d(candidates_indexes, ms_bbox_idx)

    if (ious.size == 0) and (distances.size == 0):
        bad_ms = np.array([])
    else:   
        # get the ms_bboxes in unmatched_ms that have iou > 0. with some db_bbox.
        bad_ms = unmatched_ms[np.max(ious[:, unmatched_ms], axis=0) > 0.]

    # remove bad_ms from unmatched_ms.
    unmatched_ms = np.setdiff1d(unmatched_ms, bad_ms)

    # unmatched_db contains: 
    # 1) the indexes of the db_bboxes that are too far away from any ms_bbox. 
    # 2) the indexes of the db_bboxes that have the same closest ms_bbox as another db_bbox but that have a lower IoU.
    unmatched_db = np.where(closest_ms_bbox == -1)[0]

    # print("unmatched_db", unmatched_db)
    # print("multiDB_singleMS", multiDB_singleMS)
    # print("db_too_far", db_too_far)


    # 2 types of 'rejected'  db bboxes (candidate too far + candidate better matched by another db bbox) : 
    # Make sure that the number of unmatched db + multiDB_singleMS is equal to the number of too far db.
    assert len(unmatched_db) == num_db_too_far + num_multiDB_singleMS
#     assert len(unmatched_ms) == num_ms_too_far +
    
    # Make sure that the number of unmatched db + ms_bbox_idx (number of 1 to 1 matches) is equal to the number of db_bboxes.
    assert len(unmatched_db) + len(ms_bbox_idx) == len(db_bboxes)
    # Make sure that the number of bad ms + unmatched ms + ms_bbox_idx is equal to the number of ms_bboxes.
    assert len(unmatched_ms) + len(bad_ms) + len(ms_bbox_idx) == len(ms_bboxes)

    matched_db = np.where(closest_ms_bbox != -1)[0]
    matches = closest_ms_bbox_idx

    
    return  (ious, distances, 
            unmatched_ms, bad_ms, ms_too_far_idx, 
            unmatched_db, multiDB_singleMS_idx, db_too_far_idx, 
            matched_db, matches)

def find_matchings_one_image(cur_huge_dict, basename, wl_dir, mask_dir, input_type, show=False):
        cur_image_dict = cur_huge_dict[basename]
        
        angle = cur_image_dict["SOLAR_P0"]
        deltashapeX = cur_image_dict["deltashapeX"]
        deltashapeY = cur_image_dict["deltashapeY"]
        
        drawing_radius_px = huge_db_dict[basename]["dr_radius_px"]
        
        group_list = cur_image_dict['db']
        
        ms_dict = cur_image_dict['meanshift']

        ms_members = ms_dict['groups_px']

        # print('ms_dict: ', ms_dict)
        
        centroids = np.array(ms_dict["centroids"])
        centroids_px = np.array(ms_dict["centroids_px"])
        
        db_classes = [{"Zurich":item['Zurich'], "McIntosh":item['McIntosh'] } for item in group_list]
        # Attention: bbox_wl is in the form [lat1, Lon1, lat2, Lon2] -> [y1, x1, y2, x2]
        # x1 is 
        db_bboxes = [np.array(item['bbox_wl']) for item in group_list]
        db_centers_px = np.array([[(b[2]+b[0])/2,(b[3]+b[1])/2] for b in db_bboxes])

        # open the image
        image = np.array(io.imread(os.path.join(wl_dir, basename + '.FTS')))
        image = c_utils.rotate_CV_bound(image, angle, interpolation=cv2.INTER_NEAREST)
        image = image[deltashapeX//2:image.shape[0]-deltashapeX//2,
                            deltashapeY//2:image.shape[1]-deltashapeY//2]

        # open the mask
        # mask = np.array(io.imread(os.path.join(masks_dir, basename + '.png')))
        if input_type == "mask":
            mask = io.imread(os.path.join(masks_dir,basename+".png"))
        elif input_type == "confidence_map":
#             print("here")
            mask = np.load(os.path.join(masks_dir,basename+"_proba_map.npy"))
            mask[mask>0] = 1   
            
        msk = c_utils.expand_small_spots(mask)

        # rotate the mask
        mask = c_utils.rotate_CV_bound(mask, angle, interpolation=cv2.INTER_NEAREST)
        mask = mask[deltashapeX//2:mask.shape[0]-deltashapeX//2,
                            deltashapeY//2:mask.shape[1]-deltashapeY//2] 

        group_masks = [c_utils.get_mask_from_coords(mask, members) for members in ms_dict['groups_px']]
         
        try:
            groups_bboxes = [c_utils.get_bbox_from_mask(mask) for mask in group_masks]
            groups_bboxes = [(b[1], b[0], b[3], b[2]) for b in groups_bboxes]
        except:
            print("il y a un souci")
            print(basename)
            print(centroids)
        
            raise

        res = find_closest_ms_bbox(db_bboxes, groups_bboxes)
        ious, distances, unmatched_ms, bad_ms, ms_too_far, unmatched_db, multiDB_singleMS, db_too_far, matched_db, matches = res
        
        unmatched_ms = unmatched_ms.tolist()
        unmatched_db = unmatched_db.tolist()
        bad_ms = bad_ms.tolist()
        multiDB_singleMS = multiDB_singleMS.tolist()
        db_too_far = db_too_far.tolist()
        ms_too_far = ms_too_far.tolist()
        

        cur_out_stats = {
            # General info
            'num_DB_groups':len(db_bboxes),
            'num_MS_groups':len(centroids_px),

            'matches':matches,

            # MS with DB matching info
            'unmatched_db':unmatched_db,
            'multiDB_singleMS': multiDB_singleMS,
            'db_too_far':db_too_far,


            'unmatched_ms':unmatched_ms,
            'bad_ms':bad_ms,
            'ms_too_far': ms_too_far,

            "ious":ious.tolist(),
            "distances":distances.tolist()

            }
        
        Rmm = huge_db_dict[basename]['dr_radius_mm']
        R_pixel = huge_db_dict[basename]['dr_radius_px']
        sun_center = huge_db_dict[basename]['dr_center_px']
        
        cur_out_groups = []
        for i, match in enumerate(matches):
            # print('i: ', i, 'match: ', match)
            if match != -1:
                db_class = db_classes[i]

                for pt in ms_members[match]:
                    # print('pt: ', pt)
                    assert c_utils.contains_sunspot(groups_bboxes[match],pt), "pt: {} not in bbox: {}".format(pt, groups_bboxes[match])


                dr_pixpos = np.array([group_list[i]['posx'], group_list[i]['posy']])
                
                angular_excentricity =  c_utils.get_angle2(dr_pixpos, R_pixel, sun_center)
                
                cur_group_dict={
                                "centroid_px": centroids_px[match],
                                "centroid_Lat": centroids[match][0],
                                "centroid_Lon": centroids[match][1],
                                "angular_excentricity_rad": angular_excentricity,
                                "angular_excentricity_deg": np.rad2deg(angular_excentricity),
                                "Zurich":   db_class["Zurich"],
                                "McIntosh": db_class["McIntosh"],
                                "members": ms_members[match],
                                "members_mean_px": np.mean(ms_members[match], axis=0),
                            }
                
                
                cur_out_groups.append(cur_group_dict)



        out_groups = {}
        if len(cur_out_groups) > 0:
            out_groups = { "angle": angle,
                                        "deltashapeX":deltashapeX,
                                        "deltashapeY":deltashapeY,
                                        "groups": cur_out_groups,
                                    }


        ############## SHOW THE RESULTS ################
        if show:
            print('unmatched db: ', unmatched_db)
            print('multiDB_singleMS: ', multiDB_singleMS)
            print('db_too_far: ', db_too_far)

            print('unmatched ms', unmatched_ms)
            print('bad_ms: ', bad_ms)

            # if ious in not empty
            if ious.size != 0:
                # get the best iou for each db_bbox
                best_ious = np.max(ious, axis=1)
                best_ious_idx = np.argmax(ious, axis=1)
            # show the mask
            # plt.figure()
            fig, ax = plt.subplots(1,1, figsize=(5,5))
            ax.imshow(image,cmap='gray')
            ax.imshow(mask, alpha=0.5)
            # show the db_bboxes
            for i, bbox in enumerate(db_bboxes):
                linestyle = '-'
                if i in unmatched_db:
                    linestyle = '--'
                ax.add_patch(patches.Rectangle((bbox[0],bbox[1]),bbox[2]-bbox[0],bbox[3]-bbox[1],linewidth=1,edgecolor='b',facecolor='none', linestyle=linestyle))
                #format best iou to .2f
                b=best_ious[i]
                b = "{:.2f}".format(b)
                ax.text(bbox[0],bbox[1], b, color='b')
            # show the groups_bboxes
            for i, bbox in enumerate(groups_bboxes):
                color = 'g'
                if i in bad_ms:
                    color = 'r'
                elif i in unmatched_ms:
                    color = 'y'
                ax.add_patch(patches.Rectangle((bbox[0],bbox[1]), bbox[2]-bbox[0], bbox[3]-bbox[1],linewidth=1,edgecolor=color,facecolor='none'))
                if i in bad_ms:
                    ax.text(bbox[0],bbox[1], 'bad', color='r')

            for i, match in enumerate(matches):
                if match != -1:
                    ax.plot([db_bboxes[i][0], groups_bboxes[match][0]], [db_bboxes[i][1], groups_bboxes[match][1]], color='g')
            
            plt.show()

        return basename, out_groups, cur_out_stats


# Nouvelle méthode de comptage de TP et FN

In [22]:
import json 
tmp = './grid_image_out_dict2.json'
with open(tmp, 'w') as f:
    json.dump(grid_image_out_dict, f, cls=NpEncoder)
tmp = './grid_image_out_dict_stats2.json'
with open(tmp, 'w') as f:
    json.dump(grid_image_out_dict_stats, f, cls=NpEncoder)

In [13]:
tmp2 = './grid_image_out_dict2.json'
with open(tmp2, 'r') as f:
    grid_image_out_dict = json.load(f)
tmp2 = './grid_image_out_dict_stats2.json'
with open(tmp2, 'r') as f:
    grid_image_out_dict_stats = json.load(f)

print(len(grid_image_out_dict))
print(grid_image_out_dict.keys())
print(len(grid_image_out_dict_stats))
print(grid_image_out_dict_stats.keys())

84
dict_keys(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83'])
84
dict_keys(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '

In [14]:
print(list(cur_huge_dict.keys())[:])
print()
print(wl_dir)
print(masks_dir)
print(input_type)
print(show)

NameError: name 'cur_huge_dict' is not defined

In [18]:
%matplotlib inline
grid_image_out_dict = { }
grid_image_out_dict_stats = { }
num_cpu = 1
show=False

# input_type = "mask"
input_type = "confidence_map"

#####
param_optim_folder = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/param_optimization'
#####

for param_idx, params in tqdm(enumerate(param_grid)):
    print(param_idx, params)

    image_out_dict = {}
    image_out_dict_stats = {}

#     {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}

#     cur_huge_dict = deepcopy(grid_huge_dict[str(param_idx)])
    fn = f'cur_dict_2002-19_dist{params["look_distance"]}_Lon{params["kernel_bandwidthLon"]}_lat{params["kernel_bandwidthLat"]}_iter{params["n_iterations"]}.json'  
    print(fn)
    
#     raise
    
    cur_huge_dict_filename = os.path.join(param_optim_folder,fn)
    with open(cur_huge_dict_filename,'r') as f:
        cur_huge_dict = json.load(f)
        

    with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
        for result_key, result_dict, result_dict_stats in tqdm(executor.map(find_matchings_one_image, 
                                                            repeat(cur_huge_dict),
                                                            list(cur_huge_dict.keys())[:],
                                                            repeat(wl_dir),
                                                            repeat(masks_dir),
                                                            repeat(input_type),
                                                            repeat(show)
                                                            )):
            # print(result_key)
#             print(result_dict_stats["ms_too_far"])
            image_out_dict[result_key] = result_dict
            image_out_dict_stats[result_key] = result_dict_stats
            # image_out_dict_stats[result_key] = result_dict['stats']
#             break
   
        
        

    print('num_images: ', len(list(image_out_dict.keys())))
    num_groups = 0
    for k,v in image_out_dict.items():
        # print(k,v)
        if v:
            num_groups += len(v['groups'])
    print("num_groups: ",num_groups)
    # print(image_out_dict)

    grid_image_out_dict[param_idx] = deepcopy(image_out_dict)
    grid_image_out_dict_stats[param_idx] = deepcopy(image_out_dict_stats)

#     break           


0it [00:00, ?it/s]

0 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}
cur_dict_2002-19_dist0.1_Lon0.35_lat0.08_iter20.json


0it [00:00, ?it/s]

il y a un souci
UPH20020424160712
[[-1.4253026   4.83547862]
 [-1.09247329  0.72597214]
 [-0.61633114  1.65385455]
 [-0.51818533  1.95031579]
 [-0.36382902  1.62707913]
 [-0.27493048  0.47727966]
 [-0.24746782  1.10254169]
 [-0.22808715  0.66157459]
 [-0.18240187  1.89199828]
 [ 0.16808046  0.25863839]
 [ 0.1849908   1.2200027 ]
 [ 0.18746656  0.08924538]
 [ 0.1990925   5.64325923]
 [ 0.21406713  5.44660203]
 [ 0.23233274  0.72213837]
 [ 1.02138304  1.67071476]]
il y a un souci
UPH20020427094522
[[-0.29370982  6.04161915]
 [-0.16415706  1.29068314]
 [ 0.07858173  0.28114037]
 [ 0.20015943  0.0639768 ]
 [ 0.21289656  5.58437773]
 [ 0.23160884  6.07792489]
 [ 0.52092379  0.62975058]
 [ 0.53234301  4.61078865]
 [ 0.92223714  5.74547774]
 [ 1.02188899  1.03517208]
 [ 1.35315494  0.48461629]]
il y a un souci
UPH20020428170429
[[-1.39270956  4.05804618]
 [-1.23364673  5.67914648]
 [-0.556701    0.94172172]
 [-0.16441495  0.96301111]
 [ 0.10065892  0.38853537]
 [ 0.20247049  0.10830652]
 [ 0.

ValueError: zero-size array to reduction operation minimum which has no identity

In [17]:
import collections
from collections import Counter

cur_out_stats = {
    # # General info
    # 'num_DB_groups': 0,
    # 'num_MS_groups': 0,
    # 'num_DB_isolated_groups': 0,
    # 'num_DB_overlaping_bboxes': 0,
    # # MS with DB matching info
    # "num_MSmatchesDB":0,
    # # MS with DB rejection info
    # "num_noMS_but_DB_reject":0,
    # "num_singleMS_multipleDB_reject":0,
    # "num_oneDBbbox_multipleMSoverlap_ambiguity_reject":0,
    # # MS with DB no match info
    # "num_noDB_but_MS":0,
            # General info
            'num_DB_groups':0,
            'num_MS_groups':0,

            'matches':0,

            # MS with DB matching info
            'unmatched_db':0,
            'multiDB_singleMS': 0,
            'db_too_far':0,


            'unmatched_ms':0,
            'bad_ms':0,
            'ms_too_far':0,

            "ious":0,
            "distances":0

    }

stats_keys = cur_out_stats.keys()

should_check =  list(stats_keys) + [
                                        'num_rejects_all',
                                        'num_optimizable_rejects',
                                        'rate_optimizable_rejects',
                                    ]

gridsearch_dict_per_img = { k : np.zeros((len(param_grid_values['kernel_bandwidthLon']), 
                                   len(param_grid_values['kernel_bandwidthLat']),
                                    len(wl_list)
                                   ))
                    for k in should_check}

should_check_global = list(stats_keys) + [
                                            'num_rejects_all',
                                            'num_optimizable_rejects',
                                            'rate_optimizable_rejects',
                                         ]
gridsearch_dict_total = { k : np.zeros((len(param_grid_values['kernel_bandwidthLon']), 
                                   len(param_grid_values['kernel_bandwidthLat'])
                                   ))
                    for k in should_check_global}

# grid_search_reject_distribs = { k : {} for param_idx, params in enumerate(param_grid)}
grid_search_reject_distribs = {}


for param_idx, params in enumerate(param_grid):
    param_idx= str(param_idx)
    image_out_dict = grid_image_out_dict[param_idx]
    image_out_dict_stats = grid_image_out_dict_stats[param_idx]
    
    # General info
    num_DB_groups_per_image = np.array([item['num_DB_groups'] for k,item in image_out_dict_stats.items()])
    num_MS_groups_per_image = np.array([item['num_MS_groups'] for k,item in image_out_dict_stats.items()])

    diff_num_groups = num_DB_groups_per_image - num_MS_groups_per_image # should be shown on histogram , closer to 0 is better


    # matches
    num_MSmatchesDB = np.array([len(np.where(v['matches'] != -1)[0]) for k,v in image_out_dict_stats.items()])
    
    num_unmatched_db = np.array([len(v['unmatched_db']) for k,v in image_out_dict_stats.items()])
    num_multiDB_singleMS = np.array([len(v['multiDB_singleMS']) for k,v in image_out_dict_stats.items()])
    num_db_too_far = np.array([len(v['db_too_far']) for k,v in image_out_dict_stats.items()])


    num_unmatchedMS = np.array([len(v['unmatched_ms']) for k,v in image_out_dict_stats.items()])
    num_badMS = np.array([len(v['bad_ms']) for k,v in image_out_dict_stats.items()])
    num_ms_too_far = np.array([len(v['ms_too_far']) for k,v in image_out_dict_stats.items()])
    
  
    num_optimizable_rejects = num_badMS + num_multiDB_singleMS 
    num_rejects_all = num_optimizable_rejects + num_db_too_far

    # get indices of current params
    kernel_bandwidthLon_idx = param_grid_values['kernel_bandwidthLon'].index(params['kernel_bandwidthLon'])
    kernel_bandwidthLat_idx = param_grid_values['kernel_bandwidthLat'].index(params['kernel_bandwidthLat'])
    ############################
    # Fill per-image dictinnary 
    ############################

    # number of DB groups
    gridsearch_dict_per_img['num_DB_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_DB_groups_per_image
    # number of MS groups
    gridsearch_dict_per_img['num_MS_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_MS_groups_per_image

    # number of matches
    # print(num_MSmatchesDB.shape)
    # print(num_MSmatchesDB)
    gridsearch_dict_per_img['matches'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_MSmatchesDB
    # number of unmatched DB groups
    gridsearch_dict_per_img['unmatched_db'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_unmatched_db
    # number of multi DB groups matched to single MS group
    gridsearch_dict_per_img['multiDB_singleMS'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_multiDB_singleMS
    # number of DB groups too far from MS group
    gridsearch_dict_per_img['db_too_far'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_db_too_far
    # number of unmatched MS groups
    gridsearch_dict_per_img['unmatched_ms'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_unmatchedMS
    # number of bad MS groups
    gridsearch_dict_per_img['bad_ms'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_badMS
    # number of MS groups too far from DB group
    gridsearch_dict_per_img['ms_too_far'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_ms_too_far


    gridsearch_dict_per_img['num_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_optimizable_rejects
    gridsearch_dict_per_img['num_rejects_all'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_rejects_all
     
    ############################
    # Fill total dictionnary
    ############################
    # number of DB groups
    gridsearch_dict_total['num_DB_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_DB_groups_per_image)
    # number of MS groups
    gridsearch_dict_total['num_MS_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_MS_groups_per_image)

    # number of matches
    gridsearch_dict_total['matches'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_MSmatchesDB)
    # number of unmatched DB groups
    gridsearch_dict_total['unmatched_db'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_unmatched_db)
    # number of multi DB groups matched to single MS group
    gridsearch_dict_total['multiDB_singleMS'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_multiDB_singleMS)
    # number of DB groups too far from MS group
    gridsearch_dict_total['db_too_far'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_db_too_far)
    # number of unmatched MS groups
    gridsearch_dict_total['unmatched_ms'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_unmatchedMS)
    # number of bad MS groups
    gridsearch_dict_total['bad_ms'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_badMS)
    # number of MS groups too far from DB group
    gridsearch_dict_total['ms_too_far'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_ms_too_far)
    
    

    gridsearch_dict_total['num_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_optimizable_rejects)
    gridsearch_dict_total['num_rejects_all'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_rejects_all)

#     break
    

# gridsearch = { k : pd.DataFrame(v) }


KeyError: '0'

In [33]:
%matplotlib ipympl
# accuracy
# all_isolated = gridsearch_dict_total['num_DB_isolated_groups'] 
# all_overlapping = gridsearch_dict_total['num_DB_overlaping_bboxes'] 

# print('all_isolated: ', all_isolated[0,0],'all_overlapping: ' ,all_overlapping[0,0])
# all_noMS_but_DB = gridsearch_dict_total['num_noMS_but_DB_reject']

all_groups = gridsearch_dict_total['num_DB_groups']

all_too_far = gridsearch_dict_total['db_too_far']

to_count = all_groups - all_too_far

all_optimizable_rejects = gridsearch_dict_total['num_optimizable_rejects']

all_matches = gridsearch_dict_total['matches']

all_unmatched_ms = gridsearch_dict_total['unmatched_ms']
all_ms_too_far = gridsearch_dict_total['ms_too_far']



accuracy = all_matches / (all_matches + all_optimizable_rejects )

accuracy_df = pd.DataFrame(accuracy)
accuracy_df.columns = param_grid_values['kernel_bandwidthLat']
accuracy_df.index = param_grid_values['kernel_bandwidthLon']
accuracy_df = accuracy_df.transpose()


fig_, ax_ = plt.subplots(1,1, figsize=(8*1,5*1))
ax_ = sns.heatmap(accuracy_df, ax=ax_, annot=True)
# ax_ = sns.heatmap(accuracy_df, ax=ax_, vmin=0, annot=True)
ax_.set_title('Accuracy')
ax_.set_ylabel('kernel_bandwidthLat')
ax_.set_xlabel('kernel_bandwidthLon')

fig_.tight_layout()


############
print(all_unmatched_ms.T)
print(all_ms_too_far.T)
print((all_unmatched_ms - all_ms_too_far).T)

accuracy = all_matches / (all_matches + all_optimizable_rejects + (all_unmatched_ms - all_ms_too_far)  )
accuracy_df = pd.DataFrame(accuracy)
accuracy_df.columns = param_grid_values['kernel_bandwidthLat']
accuracy_df.index = param_grid_values['kernel_bandwidthLon']
accuracy_df = accuracy_df.transpose()

fig2_, ax2_ = plt.subplots(1,1, figsize=(8*1,5*1))
ax2_ = sns.heatmap(accuracy_df, ax=ax2_, annot=True)
# ax2_ = sns.heatmap(accuracy_df, ax=ax_, vmin=0, annot=True)
ax2_.set_title('Accuracy')
ax2_.set_ylabel('kernel_bandwidthLat')
ax2_.set_xlabel('kernel_bandwidthLon')

fig2_.tight_layout()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[[11. 12. 12. 12. 12. 12. 12. 12. 12. 12. 11. 11. 11. 11.]
 [12. 12. 12. 12. 12. 12. 12. 11. 11. 11. 11. 11. 11. 11.]
 [12. 12. 12. 12. 12. 11. 11. 11. 11. 11. 10. 10.  9.  9.]
 [12. 12. 12. 12. 11. 10.  9.  9.  8.  8.  8.  8.  7.  7.]
 [11. 11. 11.  9.  8.  7.  6.  6.  6.  6.  6.  6.  6.  6.]
 [11. 10.  7.  7.  7.  6.  6.  6.  6.  6.  6.  6.  6.  5.]]
[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 2.]]
[[8. 9. 9. 9. 9. 9. 9. 9. 9. 9. 8. 8. 8. 8.]
 [9. 9. 9. 9. 9. 9. 9. 8. 8. 8. 8. 8. 8. 8.]
 [9. 9. 9. 9. 9. 8. 8. 8. 8. 8. 7. 7. 6. 6.]
 [9. 9. 9. 9. 8. 7. 6. 6. 5. 5. 5. 5. 4. 4.]
 [8. 8. 8. 6. 5. 4. 3. 3. 3. 3. 3. 3. 3. 3.]
 [8. 7. 4. 4. 4. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Ancienne méthode de comptage

In [25]:

grid_image_out_dict = { }
grid_image_out_dict_stats = { }

for param_idx, params in enumerate(param_grid):
    print(param_idx, params)

    image_out_dict = {}
    image_out_dict_stats = {}

    cur_huge_dict = deepcopy(grid_huge_dict[param_idx])

    for basename in tqdm(list(cur_huge_dict.keys())[:]):
        cur_image_dict = cur_huge_dict[basename]
        
        angle = cur_image_dict["SOLAR_P0"]
        deltashapeX = cur_image_dict["deltashapeX"]
        deltashapeY = cur_image_dict["deltashapeY"]
        
        drawing_radius_px = huge_db_dict[basename]["dr_radius_px"]
        
        group_list = cur_image_dict['db']
        
        ms_dict = cur_image_dict['meanshift']
        
        centroids = np.array(ms_dict["centroids"])
        centroids_px = np.array(ms_dict["centroids_px"])
        
        db_classes = [{"Zurich":item['Zurich'], "McIntosh":item['McIntosh'] } for item in group_list]
        db_bboxes = [np.array(item['bbox_wl']) for item in group_list]
        db_centers_px = np.array([[(b[2]+b[0])/2,(b[3]+b[1])/2] for b in db_bboxes])
           
        # check that current bbox is does not overlap any
        isolated_bboxes_bool = np.array(c_utils.get_intersecting_db_bboxes(db_bboxes)) == 0
        isolated_bboxes_indices = np.where(isolated_bboxes_bool == True)[0]
    #     print("isolated_bboxes_bool",isolated_bboxes_bool)
    #     print(isolated_bboxes_indices)
        
        cur_rejected_class_distibutions = { 
                'noMS_but_DB': {},
                'singleMS_multipleDB': {}, 
                'num_oneDBbbox_multipleMSoverlap_ambiguity':{},
                # 'noDB_but_MS': {},
            }
        cur_out_stats = {
            # General info
            'num_DB_groups':len(db_bboxes),
            'num_MS_groups':len(centroids_px),
            'num_DB_isolated_groups':len(isolated_bboxes_indices),
            'num_DB_overlaping_bboxes':len(db_bboxes) - len(isolated_bboxes_indices),
            # MS with DB matching info
            "num_MSmatchesDB":0,
            # MS with DB rejection info
            "num_noMS_but_DB_reject":0,
            "num_singleMS_multipleDB_reject":0,
            "num_oneDBbbox_multipleMSoverlap_ambiguity_reject":0,
            # MS with DB no match info
            "num_noDB_but_MS":0,
            }
        cur_out_groups = []
        for i, (db_bbox, db_center, db_class) in enumerate(
                                                zip([db_bboxes[a] for a in isolated_bboxes_indices.tolist() ],
                                                    [db_centers_px[a] for a in isolated_bboxes_indices.tolist()],
                                                    [db_classes[a] for a in isolated_bboxes_indices.tolist()],
                                                )):
            
            
            ms_centroids, ms_members = centroids_px, ms_dict['groups_px']
            
            intersect = c_utils.contains_ms_groups(db_bbox, db_center, ms_centroids, ms_members)
            
            if sum(intersect) == 0: # Il n'y a eu aucune détection dans cette zone
                cur_out_stats['num_noMS_but_DB_reject'] += 1
                cause = 'noMS_but_DB'
                add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
                pass
            elif sum(intersect) == 1: # il n'y a de l'overlap qu'avec un seul groupe meanshift            
    #             print('hit')
                idx = intersect.index(True)
    #             print(idx)
                # vérifier que le groupe meanshift n'intersecte aucune autre bbox
                num_intersections = np.sum(c_utils.count_group_intersections(ms_members[idx], db_bboxes))
                if num_intersections > 1:
                    cur_out_stats['num_singleMS_multipleDB_reject'] += 1
                    cause = 'singleMS_multipleDB'
                    add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
                    continue
                
                Rmm = huge_db_dict[basename]['dr_radius_mm']
                R_pixel = huge_db_dict[basename]['dr_radius_px']
                sun_center = huge_db_dict[basename]['dr_center_px']
                dr_pixpos = np.array([group_list[i]['posx'], group_list[i]['posy']])
                
                angular_excentricity =  c_utils.get_angle2(dr_pixpos, R_pixel, sun_center)
                
                cur_group_dict={
                                "centroid_px": centroids_px[idx],
                                "centroid_Lat": centroids[idx][0],
                                "centroid_Lon": centroids[idx][1],
                                "angular_excentricity_rad": angular_excentricity,
                                "angular_excentricity_deg": np.rad2deg(angular_excentricity),
                                "Zurich":   db_class["Zurich"],
                                "McIntosh": db_class["McIntosh"],
                                "members": ms_members[idx],
                                "members_mean_px": np.mean(ms_members[idx], axis=0),
                            }
                
                
                cur_out_groups.append(cur_group_dict)
                cur_out_stats['num_MSmatchesDB'] += 1

            else: # db_bbox intersecte plusieurs groupes meanshift
                cur_out_stats['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'] += 1
                cause = 'num_oneDBbbox_multipleMSoverlap_ambiguity'
                add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
                pass
                
        if len(cur_out_groups) > 0:
            image_out_dict[basename] = { "angle": angle,
                                        "deltashapeX":deltashapeX,
                                        "deltashapeY":deltashapeY,
                                        "groups": cur_out_groups,
                                    }

        # count the number of MS groups that do not have any overlap with the DB
        num_intersections_per_group = [np.sum(c_utils.count_group_intersections(ms_members[idx], db_bboxes)) for idx in range(len(ms_members))]
        num_MS_without_DB_overlap = len(np.where(np.array(num_intersections_per_group) == 0)[0])
        cur_out_stats['num_noDB_but_MS'] = num_MS_without_DB_overlap
        
        # print(cur_rejected_class_distibutions)
        cur_out_stats['rejected_class_distributions'] = deepcopy(cur_rejected_class_distibutions)
        # print(cur_out_stats)
        image_out_dict_stats[basename] = deepcopy(cur_out_stats)

        

    print('num_images: ', len(list(image_out_dict.keys())))
    num_groups = 0
    for k,v in image_out_dict.items():
        num_groups += len(v['groups'])
    print("num_groups: ",num_groups)
    # print(image_out_dict)

    grid_image_out_dict[param_idx] = deepcopy(image_out_dict)
    grid_image_out_dict_stats[param_idx] = deepcopy(image_out_dict_stats)



0 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  94
1 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  92
2 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  93
3 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  94
4 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  95
5 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  96
6 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  94
7 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  95
8 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  96
9 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  96
10 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  97
11 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  98
12 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  98
13 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  98
14 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  89
15 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  93
16 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  96
17 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  100
18 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  102
19 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  104
20 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  105
21 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  103
22 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  103
23 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  105
24 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  109
25 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  109
26 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  108
27 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  108
28 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  87
29 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  92
30 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  99
31 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  102
32 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  106
33 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  106
34 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  107
35 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  109
36 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  111
37 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  111
38 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  111
39 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  111
40 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  110
41 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  110
42 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  87
43 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  96
44 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  101
45 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  105
46 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  108
47 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  110
48 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  111
49 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  112
50 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  112
51 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  112
52 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  108
53 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  108
54 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  106
55 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  106
56 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  86
57 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  98
58 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  103
59 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  109
60 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  108
61 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  109
62 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  111
63 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  110
64 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  108
65 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  107
66 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  105
67 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  103
68 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  101
69 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  101
70 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  32
num_groups:  85
71 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  33
num_groups:  97
72 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  106
73 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  106
74 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  34
num_groups:  104
75 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  109
76 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  107
77 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  107
78 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  107
79 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  103
80 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  102
81 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  99
82 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  96
83 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  35
num_groups:  94


In [13]:
import collections
from collections import Counter

cur_out_stats = {
    # General info
    'num_DB_groups': 0,
    'num_MS_groups': 0,
    'num_DB_isolated_groups': 0,
    'num_DB_overlaping_bboxes': 0,
    # MS with DB matching info
    "num_MSmatchesDB":0,
    # MS with DB rejection info
    "num_noMS_but_DB_reject":0,
    "num_singleMS_multipleDB_reject":0,
    "num_oneDBbbox_multipleMSoverlap_ambiguity_reject":0,
    # MS with DB no match info
    "num_noDB_but_MS":0,
    }

stats_keys = cur_out_stats.keys()

should_check =  list(stats_keys) + [
                                        'num_rejects_all',
                                        'num_optimizable_rejects',
                                        'rate_optimizable_rejects',
                                    ]

gridsearch_dict_per_img = { k : np.zeros((len(param_grid_values['kernel_bandwidthLon']), 
                                   len(param_grid_values['kernel_bandwidthLat']),
                                    len(wl_list)
                                   ))
                    for k in should_check}

should_check_global = list(stats_keys) + [
                                            'num_rejects_all',
                                            'num_optimizable_rejects',
                                            'rate_optimizable_rejects',
                                         ]
gridsearch_dict_total = { k : np.zeros((len(param_grid_values['kernel_bandwidthLon']), 
                                   len(param_grid_values['kernel_bandwidthLat'])
                                   ))
                    for k in should_check_global}

# grid_search_reject_distribs = { k : {} for param_idx, params in enumerate(param_grid)}
grid_search_reject_distribs = {}





for param_idx, params in enumerate(param_grid):
    image_out_dict = grid_image_out_dict[param_idx]
    image_out_dict_stats = grid_image_out_dict_stats[param_idx]
    
    # General info
    num_DB_groups_per_image = np.array([item['num_DB_groups'] for k,item in image_out_dict_stats.items()])
    num_MS_groups_per_image = np.array([item['num_MS_groups'] for k,item in image_out_dict_stats.items()])

    diff_num_groups = num_DB_groups_per_image - num_MS_groups_per_image # should be shown on histogram , closer to 0 is better

    # number of isolated DB groups per image and mean number in dataset
    num_isolated_DB_groups = np.array([v['num_DB_isolated_groups'] for k,v in image_out_dict_stats.items()])
    # number of DB groups with overlap with other DB groups
    num_DB_overlaping_bboxes = np.array([v['num_DB_overlaping_bboxes'] for k,v in image_out_dict_stats.items()])

    # matches
    num_MSmatchesDB = np.array([v['num_MSmatchesDB'] for k,v in image_out_dict_stats.items()])
    
    # rejects
    # 1) numbers of rejects due to no MS group overlapping DB group (# cannot do anything to this, examples are discarded)
    num_noMS_but_DB_reject = np.array([v['num_noMS_but_DB_reject'] for k,v in image_out_dict_stats.items()])
    # 2) numbers of rejects due to single MS group overlapping multiple DB groups (# minimize this)
    num_singleMS_multipleDB_reject = np.array([v['num_singleMS_multipleDB_reject'] for k,v in image_out_dict_stats.items()])
    # 3) numbers of rejects due to single DB group overlapping multiple MS groups (# minimize this)
    num_oneDBbbox_multipleMSoverlap_ambiguity_reject = np.array([v['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'] for k,v in image_out_dict_stats.items()])

    num_optimizable_rejects = num_singleMS_multipleDB_reject + num_oneDBbbox_multipleMSoverlap_ambiguity_reject
    num_rejects_all = num_optimizable_rejects + num_noMS_but_DB_reject

    # rate of optimizable rejects
    rate_optimizable_rejects = num_optimizable_rejects / num_rejects_all

    # print(num_optimizable_rejects)


    # Aggregate reject distributions of images over the dataset
    dict1 = [v['rejected_class_distributions']['noMS_but_DB'] for k,v in image_out_dict_stats.items()]
    summed_dict1 = sum(map(collections.Counter, dict1),Counter())
    dict2 = [v['rejected_class_distributions']['singleMS_multipleDB'] for k,v in image_out_dict_stats.items()]
    summed_dict2 = sum(map(collections.Counter, dict2),Counter())
    dict3 = [v['rejected_class_distributions']['num_oneDBbbox_multipleMSoverlap_ambiguity'] for k,v in image_out_dict_stats.items()]
    summed_dict3 = sum(map(collections.Counter, dict3),Counter())
    dataset_rejected_class_distibutions = { 
        'noMS_but_DB': summed_dict1,
        'singleMS_multipleDB': summed_dict2, 
        'num_oneDBbbox_multipleMSoverlap_ambiguity': summed_dict3,
    }

    # print(dict1)
    # print(summed_dict1)
    


    # get indices of current params
    kernel_bandwidthLon_idx = param_grid_values['kernel_bandwidthLon'].index(params['kernel_bandwidthLon'])
    kernel_bandwidthLat_idx = param_grid_values['kernel_bandwidthLat'].index(params['kernel_bandwidthLat'])
    ############################
    # Fill per-image dictinnary 
    ############################

    # number of DB groups
    gridsearch_dict_per_img['num_DB_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_DB_groups_per_image
    # number of MS groups
    gridsearch_dict_per_img['num_MS_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_MS_groups_per_image

    # number of isolated DB groups
    gridsearch_dict_per_img['num_DB_isolated_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_isolated_DB_groups
    gridsearch_dict_per_img['num_DB_overlaping_bboxes'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_DB_overlaping_bboxes

    # matches
    gridsearch_dict_per_img['num_MSmatchesDB'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_MSmatchesDB
    
    # not optimizable rejects
    gridsearch_dict_per_img['num_noMS_but_DB_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_noMS_but_DB_reject
    
    # optimizable rejects
    gridsearch_dict_per_img['num_singleMS_multipleDB_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_singleMS_multipleDB_reject
    gridsearch_dict_per_img['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_oneDBbbox_multipleMSoverlap_ambiguity_reject
    gridsearch_dict_per_img['num_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_optimizable_rejects
    
    gridsearch_dict_per_img['num_rejects_all'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_rejects_all
    
    # rate of optimizable rejects
    gridsearch_dict_per_img['rate_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = rate_optimizable_rejects
    
    ############################
    # Fill total dictionnary
    ############################
    # number of DB groups
    gridsearch_dict_total['num_DB_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_DB_groups_per_image)
    # number of MS groups
    gridsearch_dict_total['num_MS_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_MS_groups_per_image)

    # number of isolated DB groups
    gridsearch_dict_total['num_DB_isolated_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_isolated_DB_groups)
    gridsearch_dict_total['num_DB_overlaping_bboxes'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_DB_overlaping_bboxes)

    # matches
    gridsearch_dict_total['num_MSmatchesDB'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_MSmatchesDB)

    # not optimizable rejects
    gridsearch_dict_total['num_noMS_but_DB_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_noMS_but_DB_reject)

    # optimizable rejects
    gridsearch_dict_total['num_singleMS_multipleDB_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_singleMS_multipleDB_reject)
    gridsearch_dict_total['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_oneDBbbox_multipleMSoverlap_ambiguity_reject)
    gridsearch_dict_total['num_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_optimizable_rejects)

    gridsearch_dict_total['num_rejects_all'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_rejects_all)


    rate_optimizable_rejects_all = np.sum(num_optimizable_rejects) / np.sum(num_rejects_all)
    gridsearch_dict_total['rate_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = rate_optimizable_rejects_all

    grid_search_reject_distribs[param_idx] = deepcopy(dataset_rejected_class_distibutions) 

    
    

gridsearch = { k : pd.DataFrame(v) }


In [36]:
# accuracy
all_isolated = gridsearch_dict_total['num_DB_isolated_groups'] 
all_overlapping = gridsearch_dict_total['num_DB_overlaping_bboxes'] 

print('all_isolated: ', all_isolated[0,0],'all_overlapping: ' ,all_overlapping[0,0])

all_noMS_but_DB = gridsearch_dict_total['num_noMS_but_DB_reject']

optimizable = all_isolated - all_noMS_but_DB

print('all_optimizable', optimizable[0,0])


matches = gridsearch_dict_total['num_MSmatchesDB']

accuracy = matches / optimizable

accuracy_df = pd.DataFrame(accuracy)
accuracy_df.columns = param_grid_values['kernel_bandwidthLat']
accuracy_df.index = param_grid_values['kernel_bandwidthLon']


fig_, ax_ = plt.subplots(1,1, figsize=(3.5*1,3*1))
ax_ = sns.heatmap(accuracy_df, ax=ax_, vmin=0, annot=True)
ax_.set_title('Accuracy')
ax_.set_xlabel('kernel_bandwidthLat')
ax_.set_ylabel('kernel_bandwidthLon')

fig_.tight_layout()


all_isolated:  164.0 all_overlapping:  29.0
all_optimizable 117.0


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [37]:

num_rows_2 = 1
num_cols_2 = 2
fig2, ax2 = plt.subplots(num_rows_2,num_cols_2, figsize=(3.5*num_cols_2,3*num_rows_2))


df2 = pd.DataFrame(np.sum(gridsearch_dict_per_img['num_optimizable_rejects'],axis=-1))
df2.columns = param_grid_values['kernel_bandwidthLat']
df2.index = param_grid_values['kernel_bandwidthLon']
ax_ = sns.heatmap(df2, ax=ax2[0], vmin=0, annot=True)
ax_.set_title('Number of optimizable rejects')
ax_.set_xlabel('kernel_bandwidthLat')
ax_.set_ylabel('kernel_bandwidthLon')

df2 = pd.DataFrame(gridsearch_dict_total['rate_optimizable_rejects'])
df2.columns = param_grid_values['kernel_bandwidthLat']
df2.index = param_grid_values['kernel_bandwidthLon']
ax_ = sns.heatmap(df2, ax=ax2[1], vmin=0, annot=True)
ax_.set_title('Rate of optimizable rejects')
ax_.set_xlabel('kernel_bandwidthLat')
ax_.set_ylabel('kernel_bandwidthLon')

fig2.tight_layout()



Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:

# dataset_rejected_class_distibutions = { 
#     'noMS_but_DB': summed_dict1,
#     'singleMS_multipleDB': summed_dict2, 
#     'num_oneDBbbox_multipleMSoverlap_ambiguity': summed_dict3,
# }

# param_grid_values = {
#     'look_distance' : [0.1],
#     'kernel_bandwidthLon' : [ 0.05 , 0.1, 0.15, .2,.21,.22,.23,.24,.25, .3,.35,.45],
#     'kernel_bandwidthLat' : [.08,],
#     # 'kernel_bandwidthLon' : [ 0.05 , 0.1, 0.15, .2, .25, .3,.35,.45],
#     # 'kernel_bandwidthLat' : [ 0.04 , 0.06, .08, .1, .12],
#     # 'kernel_bandwidthLat' : [ 0.05 , 0.06, 0.07, .08, .09, .1, .11, .12],
#     'n_iterations' : [20],
# }


num_rows_distrib = len(param_grid_values['kernel_bandwidthLon'])
num_cols_distrib = len(param_grid_values['kernel_bandwidthLat'])
fig_distrib, ax_distrib = plt.subplots(num_rows_distrib,num_cols_distrib, figsize=(5*num_cols_distrib,3*num_rows_distrib))

# for i in range(num_rows_distrib):
#     for j in range(num_cols_distrib):
#         cur_ax = ax_distrib[i,j]

for param_idx, params in enumerate(param_grid):
#     print(params)
    cur_lat_idx = param_grid_values['kernel_bandwidthLat'].index(params['kernel_bandwidthLat'])
    cur_lon_idx = param_grid_values['kernel_bandwidthLon'].index(params['kernel_bandwidthLon'])

    data = grid_search_reject_distribs[param_idx]
    data_df = pd.DataFrame(data)

    # reorder rows alphabetically
    data_df = data_df.reindex(sorted(data_df.columns), axis=1)
    data_df.sort_index( inplace=True)
    data_df = data_df.fillna(0)
    # data_df = data_df.T


#     display(data_df)

    cur_ax = ax_distrib[cur_lon_idx,cur_lat_idx]
#     cur_ax = ax_distrib[cur_lon_idx]

    # ax_ = data_df['noMS_but_DB'].plot(kind='bar' , ax=cur_ax, stacked=False, alpha=0.5, color='k' )
    # ax_ = data_df['singleMS_multipleDB'].plot(kind='bar' , ax=cur_ax, stacked=False, alpha=0.5 , color='r')
    # ax_ = data_df['num_oneDBbbox_multipleMSoverlap_ambiguity'].plot(kind='bar' , ax=cur_ax, stacked=False, alpha=0.5 , color='b')


    # ax_ = data_df.plot(kind='bar' , ax=cur_ax, stacked=True, alpha=0.5 )
#     deta_df2 = 
    ax_ = data_df.plot(kind='bar' , ax=cur_ax, alpha=0.5, color=['k','r','b'],label=['O','U','I'] )

    ax_.set_title(f'Rejected classes distribution \n Lon = {params["kernel_bandwidthLon"]}, Lat = {params["kernel_bandwidthLat"]}')
    # ax_.set_xlabel(f'kernel_bandwidthLat = {params["kernel_bandwidthLat"]}')
    # ax_.set_ylabel(f'kernel_bandwidthLon = {params["kernel_bandwidthLon"]}')

    # break

    

        


fig_distrib.tight_layout()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [20]:


def on_value_change(change):
   

    # print(ax)
    # print(len(ax.collections))
    for i in range(num_cols):
        if len(ax[i].collections) > 0:
            cb = ax[i].collections[-1].colorbar
            if cb is not None:
                cb.remove()
            ax[i].clear()

    df = pd.DataFrame(gridsearch_dict_per_img['num_optimizable_rejects'][:, :, a_slider.value])
    df.columns = param_grid_values['kernel_bandwidthLat']
    df.index = param_grid_values['kernel_bandwidthLon']
    # ax0 = sns.heatmap(df, ax=ax, cbar=False,)
    ax0 = sns.heatmap(df, ax=ax[0], vmin=0, annot=True)
    ax0.set_title('Number of optimizable rejects')
    ax0.set_xlabel('kernel_bandwidthLat')
    ax0.set_ylabel('kernel_bandwidthLon')

    df = pd.DataFrame(gridsearch_dict_per_img['num_singleMS_multipleDB_reject'][:, :, a_slider.value])
    df.columns = param_grid_values['kernel_bandwidthLat']
    df.index = param_grid_values['kernel_bandwidthLon']
    # ax1 = sns.heatmap(df, ax=ax, cbar=False,)
    ax1 = sns.heatmap(df, ax=ax[1], vmin=0, annot=True)
    ax1.set_title('Only MS candidate covers \n2+ DB bboxes')
    ax1.set_xlabel('kernel_bandwidthLat')
    ax1.set_ylabel('kernel_bandwidthLon')
    
    df = pd.DataFrame(gridsearch_dict_per_img['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'][:, :, a_slider.value])
    df.columns = param_grid_values['kernel_bandwidthLat']
    df.index = param_grid_values['kernel_bandwidthLon']
    # ax0 = sns.heatmap(df, ax=ax, cbar=False,)
    ax2 = sns.heatmap(df, ax=ax[2], vmin=0, annot=True)
    ax2.set_title('DB group matches \nmultiple MS groups')
    ax2.set_xlabel('kernel_bandwidthLat')
    ax2.set_ylabel('kernel_bandwidthLon')

    fig.tight_layout()

    # pass

a_slider = widgets.IntSlider(min=0, max=len(wl_list)-1, step=1, value=0)
a_slider.observe(on_value_change, names='value')

plt.ioff()
num_rows = 1
num_cols = 3
fig, ax = plt.subplots(num_rows,num_cols, figsize=(3*num_cols,3*num_rows))
on_value_change(None)
plt.ion()

display(widgets.VBox([a_slider, fig.canvas]))


VBox(children=(IntSlider(value=0, max=35), Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', '…

In [None]:
look_distance = .05  # How far to look for neighbours.
kernel_bandwidthLon = .25  # Longitude Kernel parameter.
kernel_bandwidthLat = .08  # Latitude Kernel parameter.
n_iterations = 1000 # Number of iterations



In [None]:
#create set of 5 ints
def create_set_of_5_ints():
    return set(range(300,525,25))

def all_pairs(s):
    '''Returns all pairs of elements in the set s such that (a, b) == (b, a), 
    and (a, a) is not included in the result.'''
    return set((min(x, y), max(x, y)) for x in s for y in s if (x != y) and np.abs(x-y)>25)

my_set = create_set_of_5_ints()
print(my_set)
print(all_pairs(my_set))
print(len(all_pairs(my_set)))




# Test overlapping bboxes

In [29]:

grid_image_out_dict_overlapping = { }
grid_image_out_dict_stats_overlapping = { }

for param_idx, params in enumerate(param_grid):
    print(param_idx, params)

    image_out_dict = {}
    image_out_dict_stats = {}

    cur_huge_dict = deepcopy(grid_huge_dict[param_idx])

    for basename in tqdm(list(cur_huge_dict.keys())[:]):
        cur_image_dict = cur_huge_dict[basename]
        
        angle = cur_image_dict["SOLAR_P0"]
        deltashapeX = cur_image_dict["deltashapeX"]
        deltashapeY = cur_image_dict["deltashapeY"]
        
        drawing_radius_px = huge_db_dict[basename]["dr_radius_px"]
        
        group_list = cur_image_dict['db']
        
        ms_dict = cur_image_dict['meanshift']
        
        centroids = np.array(ms_dict["centroids"])
        centroids_px = np.array(ms_dict["centroids_px"])
        
        db_classes = [{"Zurich":item['Zurich'], "McIntosh":item['McIntosh'] } for item in group_list]
        db_bboxes = [np.array(item['bbox_wl']) for item in group_list]
        db_centers_px = np.array([[(b[2]+b[0])/2,(b[3]+b[1])/2] for b in db_bboxes])
           
        # check that current bbox is does not overlap any
        isolated_bboxes_bool = np.array(c_utils.get_intersecting_db_bboxes(db_bboxes)) == 0
        overlapping_bboxes_indices = np.where(isolated_bboxes_bool == False)[0]

        
        cur_rejected_class_distibutions = { 
                'noMS_but_DB': {},
                'singleMS_multipleDB': {}, 
                'num_oneDBbbox_multipleMSoverlap_ambiguity':{},
                # 'noDB_but_MS': {},
            }
        cur_out_stats = {
            # General info
            'num_DB_groups':len(db_bboxes),
            'num_MS_groups':len(centroids_px),
            'num_DB_isolated_groups':len(db_bboxes) - len(overlapping_bboxes_indices),
            'num_DB_overlaping_bboxes':len(overlapping_bboxes_indices),
            # MS with DB matching info
            "num_MSmatchesDB":0,
            # MS with DB rejection info
            "num_noMS_but_DB_reject":0,
            "num_singleMS_multipleDB_reject":0,
            "num_oneDBbbox_multipleMSoverlap_ambiguity_reject":0,
            # MS with DB no match info
            "num_noDB_but_MS":0,
            }
        cur_out_groups = []
        for i, (db_bbox, db_center, db_class) in enumerate(
                                                zip([db_bboxes[a] for a in overlapping_bboxes_indices.tolist() ],
                                                    [db_centers_px[a] for a in overlapping_bboxes_indices.tolist()],
                                                    [db_classes[a] for a in overlapping_bboxes_indices.tolist()],
                                                )):
            
            # matcher avec le ms gorup le plus proche + garde fou de distance
            ms_centroids, ms_members = centroids_px, ms_dict['groups_px']
            
            intersect = c_utils.contains_ms_groups(db_bbox, db_center, ms_centroids, ms_members)
            
            if sum(intersect) == 0: # Il n'y a eu aucune détection dans cette zone
                cur_out_stats['num_noMS_but_DB_reject'] += 1
                cause = 'noMS_but_DB'
                add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
                pass
            elif sum(intersect) == 1: # il n'y a de l'overlap qu'avec un seul groupe meanshift            
    #             print('hit')
                idx = intersect.index(True)
    #             print(idx)
                # vérifier que le groupe meanshift n'intersecte aucune autre bbox
                num_intersections = np.sum(c_utils.count_group_intersections(ms_members[idx], db_bboxes))
                if num_intersections > 1:
                    cur_out_stats['num_singleMS_multipleDB_reject'] += 1
                    cause = 'singleMS_multipleDB'
                    add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
                    continue
                
                Rmm = huge_db_dict[basename]['dr_radius_mm']
                R_pixel = huge_db_dict[basename]['dr_radius_px']
                sun_center = huge_db_dict[basename]['dr_center_px']
                dr_pixpos = np.array([group_list[i]['posx'], group_list[i]['posy']])
                
                angular_excentricity =  c_utils.get_angle2(dr_pixpos, R_pixel, sun_center)
                
                cur_group_dict={
                                "centroid_px": centroids_px[idx],
                                "centroid_Lat": centroids[idx][0],
                                "centroid_Lon": centroids[idx][1],
                                "angular_excentricity_rad": angular_excentricity,
                                "angular_excentricity_deg": np.rad2deg(angular_excentricity),
                                "Zurich":   db_class["Zurich"],
                                "McIntosh": db_class["McIntosh"],
                                "members": ms_members[idx],
                                "members_mean_px": np.mean(ms_members[idx], axis=0),
                            }
                
                
                cur_out_groups.append(cur_group_dict)
                cur_out_stats['num_MSmatchesDB'] += 1

            else: # db_bbox intersecte plusieurs groupes meanshift
                cur_out_stats['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'] += 1
                cause = 'num_oneDBbbox_multipleMSoverlap_ambiguity'
                add_rejected_to_distributions(cur_rejected_class_distibutions[cause], db_class["McIntosh"][0])
                pass
                
        if len(cur_out_groups) > 0:
            image_out_dict[basename] = { "angle": angle,
                                        "deltashapeX":deltashapeX,
                                        "deltashapeY":deltashapeY,
                                        "groups": cur_out_groups,
                                    }

        # count the number of MS groups that do not have any overlap with the DB
        num_intersections_per_group = [np.sum(c_utils.count_group_intersections(ms_members[idx], db_bboxes)) for idx in range(len(ms_members))]
        num_MS_without_DB_overlap = len(np.where(np.array(num_intersections_per_group) == 0)[0])
        cur_out_stats['num_noDB_but_MS'] = num_MS_without_DB_overlap
        
        # print(cur_rejected_class_distibutions)
        cur_out_stats['rejected_class_distributions'] = deepcopy(cur_rejected_class_distibutions)
        # print(cur_out_stats)
        image_out_dict_stats[basename] = deepcopy(cur_out_stats)

        

    print('num_images: ', len(list(image_out_dict.keys())))
    num_groups = 0
    for k,v in image_out_dict.items():
        num_groups += len(v['groups'])
    print("num_groups: ",num_groups)
    # print(image_out_dict)

    grid_image_out_dict_overlapping[param_idx] = deepcopy(image_out_dict)
    grid_image_out_dict_stats_overlapping[param_idx] = deepcopy(image_out_dict_stats)



0 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
1 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
2 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
3 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
4 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
5 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
6 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
7 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
8 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
9 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
10 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
11 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
12 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
13 {'kernel_bandwidthLat': 0.02, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
14 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
15 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
16 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
17 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
18 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
19 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
20 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
21 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
22 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
23 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
24 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
25 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
26 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
27 {'kernel_bandwidthLat': 0.04, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
28 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
29 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
30 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
31 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
32 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
33 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
34 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
35 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
36 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
37 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
38 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
39 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
40 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
41 {'kernel_bandwidthLat': 0.06, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
42 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
43 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
44 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
45 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
46 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
47 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
48 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
49 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
50 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
51 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
52 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
53 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
54 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
55 {'kernel_bandwidthLat': 0.08, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
56 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
57 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
58 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
59 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
60 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  16
61 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
62 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
63 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
64 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  17
65 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
66 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
67 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
68 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
69 {'kernel_bandwidthLat': 0.1, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  15
70 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.05, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  13
71 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.1, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
72 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.15, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  7
num_groups:  14
73 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.2, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  6
num_groups:  14
74 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.25, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  6
num_groups:  15
75 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.3, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  6
num_groups:  15
76 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.35, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  6
num_groups:  15
77 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.45, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  6
num_groups:  13
78 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.5, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  6
num_groups:  11
79 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.55, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  6
num_groups:  11
80 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.6, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  5
num_groups:  9
81 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.65, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  5
num_groups:  9
82 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.7, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  5
num_groups:  9
83 {'kernel_bandwidthLat': 0.12, 'kernel_bandwidthLon': 0.75, 'look_distance': 0.1, 'n_iterations': 20}


  0%|          | 0/36 [00:00<?, ?it/s]

num_images:  5
num_groups:  9


In [30]:
import collections
from collections import Counter

cur_out_stats = {
    # General info
    'num_DB_groups': 0,
    'num_MS_groups': 0,
    'num_DB_isolated_groups': 0,
    'num_DB_overlaping_bboxes': 0,
    # MS with DB matching info
    "num_MSmatchesDB":0,
    # MS with DB rejection info
    "num_noMS_but_DB_reject":0,
    "num_singleMS_multipleDB_reject":0,
    "num_oneDBbbox_multipleMSoverlap_ambiguity_reject":0,
    # MS with DB no match info
    "num_noDB_but_MS":0,
    }

stats_keys = cur_out_stats.keys()

should_check =  list(stats_keys) + [
                                        'num_rejects_all',
                                        'num_optimizable_rejects',
                                        'rate_optimizable_rejects',
                                    ]

gridsearch_dict_per_img_overlapping = { k : np.zeros((len(param_grid_values['kernel_bandwidthLon']), 
                                   len(param_grid_values['kernel_bandwidthLat']),
                                    len(wl_list)
                                   ))
                    for k in should_check}

should_check_global = list(stats_keys) + [
                                            'num_rejects_all',
                                            'num_optimizable_rejects',
                                            'rate_optimizable_rejects',
                                         ]
gridsearch_dict_total_overlapping = { k : np.zeros((len(param_grid_values['kernel_bandwidthLon']), 
                                   len(param_grid_values['kernel_bandwidthLat'])
                                   ))
                    for k in should_check_global}

# grid_search_reject_distribs = { k : {} for param_idx, params in enumerate(param_grid)}
grid_search_reject_distribs_overlapping = {}





for param_idx, params in enumerate(param_grid):
    image_out_dict = grid_image_out_dict_overlapping[param_idx]
    image_out_dict_stats = grid_image_out_dict_stats_overlapping[param_idx]
    
    # General info
    num_DB_groups_per_image = np.array([item['num_DB_groups'] for k,item in image_out_dict_stats.items()])
    num_MS_groups_per_image = np.array([item['num_MS_groups'] for k,item in image_out_dict_stats.items()])

    diff_num_groups = num_DB_groups_per_image - num_MS_groups_per_image # should be shown on histogram , closer to 0 is better

    # number of isolated DB groups per image and mean number in dataset
    num_isolated_DB_groups = np.array([v['num_DB_isolated_groups'] for k,v in image_out_dict_stats.items()])
    # number of DB groups with overlap with other DB groups
    num_DB_overlaping_bboxes = np.array([v['num_DB_overlaping_bboxes'] for k,v in image_out_dict_stats.items()])

    # matches
    num_MSmatchesDB = np.array([v['num_MSmatchesDB'] for k,v in image_out_dict_stats.items()])
    
    # rejects
    # 1) numbers of rejects due to no MS group overlapping DB group (# cannot do anything to this, examples are discarded)
    num_noMS_but_DB_reject = np.array([v['num_noMS_but_DB_reject'] for k,v in image_out_dict_stats.items()])
    # 2) numbers of rejects due to single MS group overlapping multiple DB groups (# minimize this)
    num_singleMS_multipleDB_reject = np.array([v['num_singleMS_multipleDB_reject'] for k,v in image_out_dict_stats.items()])
    # 3) numbers of rejects due to single DB group overlapping multiple MS groups (# minimize this)
    num_oneDBbbox_multipleMSoverlap_ambiguity_reject = np.array([v['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'] for k,v in image_out_dict_stats.items()])

    num_optimizable_rejects = num_singleMS_multipleDB_reject + num_oneDBbbox_multipleMSoverlap_ambiguity_reject
    num_rejects_all = num_optimizable_rejects + num_noMS_but_DB_reject

    # rate of optimizable rejects
    rate_optimizable_rejects = num_optimizable_rejects / num_rejects_all

    # print(num_optimizable_rejects)


    # Aggregate reject distributions of images over the dataset
    dict1 = [v['rejected_class_distributions']['noMS_but_DB'] for k,v in image_out_dict_stats.items()]
    summed_dict1 = sum(map(collections.Counter, dict1),Counter())
    dict2 = [v['rejected_class_distributions']['singleMS_multipleDB'] for k,v in image_out_dict_stats.items()]
    summed_dict2 = sum(map(collections.Counter, dict2),Counter())
    dict3 = [v['rejected_class_distributions']['num_oneDBbbox_multipleMSoverlap_ambiguity'] for k,v in image_out_dict_stats.items()]
    summed_dict3 = sum(map(collections.Counter, dict3),Counter())
    dataset_rejected_class_distibutions = { 
        'noMS_but_DB': summed_dict1,
        'singleMS_multipleDB': summed_dict2, 
        'num_oneDBbbox_multipleMSoverlap_ambiguity': summed_dict3,
    }

    # print(dict1)
    # print(summed_dict1)
    


    # get indices of current params
    kernel_bandwidthLon_idx = param_grid_values['kernel_bandwidthLon'].index(params['kernel_bandwidthLon'])
    kernel_bandwidthLat_idx = param_grid_values['kernel_bandwidthLat'].index(params['kernel_bandwidthLat'])
    ############################
    # Fill per-image dictinnary 
    ############################

    # number of DB groups
    gridsearch_dict_per_img_overlapping['num_DB_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_DB_groups_per_image
    # number of MS groups
    gridsearch_dict_per_img_overlapping['num_MS_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_MS_groups_per_image

    # number of isolated DB groups
    gridsearch_dict_per_img_overlapping['num_DB_isolated_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_isolated_DB_groups
    gridsearch_dict_per_img_overlapping['num_DB_overlaping_bboxes'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_DB_overlaping_bboxes

    # matches
    gridsearch_dict_per_img_overlapping['num_MSmatchesDB'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_MSmatchesDB
    
    # not optimizable rejects
    gridsearch_dict_per_img_overlapping['num_noMS_but_DB_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_noMS_but_DB_reject
    
    # optimizable rejects
    gridsearch_dict_per_img_overlapping['num_singleMS_multipleDB_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_singleMS_multipleDB_reject
    gridsearch_dict_per_img_overlapping['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_oneDBbbox_multipleMSoverlap_ambiguity_reject
    gridsearch_dict_per_img_overlapping['num_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_optimizable_rejects
    
    gridsearch_dict_per_img_overlapping['num_rejects_all'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = num_rejects_all
    
    # rate of optimizable rejects
    gridsearch_dict_per_img_overlapping['rate_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = rate_optimizable_rejects
    
    ############################
    # Fill total dictionnary
    ############################
    # number of DB groups
    gridsearch_dict_total_overlapping['num_DB_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_DB_groups_per_image)
    # number of MS groups
    gridsearch_dict_total_overlapping['num_MS_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_MS_groups_per_image)

    # number of isolated DB groups
    gridsearch_dict_total_overlapping['num_DB_isolated_groups'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_isolated_DB_groups)
    gridsearch_dict_total_overlapping['num_DB_overlaping_bboxes'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_DB_overlaping_bboxes)

    # matches
    gridsearch_dict_total_overlapping['num_MSmatchesDB'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_MSmatchesDB)

    # not optimizable rejects
    gridsearch_dict_total_overlapping['num_noMS_but_DB_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_noMS_but_DB_reject)

    # optimizable rejects
    gridsearch_dict_total_overlapping['num_singleMS_multipleDB_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_singleMS_multipleDB_reject)
    gridsearch_dict_total_overlapping['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_oneDBbbox_multipleMSoverlap_ambiguity_reject)
    gridsearch_dict_total_overlapping['num_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_optimizable_rejects)

    gridsearch_dict_total_overlapping['num_rejects_all'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = np.sum(num_rejects_all)


    rate_optimizable_rejects_all = np.sum(num_optimizable_rejects) / np.sum(num_rejects_all)
    gridsearch_dict_total_overlapping['rate_optimizable_rejects'][kernel_bandwidthLon_idx,kernel_bandwidthLat_idx] = rate_optimizable_rejects_all

    grid_search_reject_distribs_overlapping[param_idx] = deepcopy(dataset_rejected_class_distibutions) 

    
    

gridsearch = { k : pd.DataFrame(v) }


In [32]:
# accuracy
all_isolated = gridsearch_dict_total_overlapping['num_DB_isolated_groups'] 
all_overlapping = gridsearch_dict_total_overlapping['num_DB_overlaping_bboxes'] 

print('all_isolated: ', all_isolated[0,0],'all_overlapping: ' ,all_overlapping[0,0])

all_noMS_but_DB = gridsearch_dict_total_overlapping['num_noMS_but_DB_reject']

optimizable = all_overlapping - all_noMS_but_DB

print('all_optimizable', optimizable[0,0])


matches = gridsearch_dict_total_overlapping['num_MSmatchesDB']

accuracy = matches / optimizable

accuracy_df = pd.DataFrame(accuracy)
accuracy_df.columns = param_grid_values['kernel_bandwidthLat']
accuracy_df.index = param_grid_values['kernel_bandwidthLon']


fig_, ax_ = plt.subplots(1,1, figsize=(3.5*1,3*1))
ax_ = sns.heatmap(accuracy_df, ax=ax_, vmin=0, annot=True)
ax_.set_title('Accuracy')
ax_.set_xlabel('kernel_bandwidthLat')
ax_.set_ylabel('kernel_bandwidthLon')

fig_.tight_layout()


all_isolated:  164.0 all_overlapping:  29.0
all_optimizable 27.0


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [35]:

def on_value_change_overlapping(change):
   

    # print(ax)
    # print(len(ax.collections))
    for i in range(num_cols):
        if len(ax[i].collections) > 0:
            cb = ax[i].collections[-1].colorbar
            if cb is not None:
                cb.remove()
            ax[i].clear()

    df = pd.DataFrame(gridsearch_dict_per_img_overlapping['num_optimizable_rejects'][:, :, a_slider.value])
    df.columns = param_grid_values['kernel_bandwidthLat']
    df.index = param_grid_values['kernel_bandwidthLon']
    # ax0 = sns.heatmap(df, ax=ax, cbar=False,)
    ax0 = sns.heatmap(df, ax=ax[0], vmin=0, annot=True)
    ax0.set_title('Number of optimizable rejects')
    ax0.set_xlabel('kernel_bandwidthLat')
    ax0.set_ylabel('kernel_bandwidthLon')

    df = pd.DataFrame(gridsearch_dict_per_img_overlapping['num_singleMS_multipleDB_reject'][:, :, a_slider.value])
    df.columns = param_grid_values['kernel_bandwidthLat']
    df.index = param_grid_values['kernel_bandwidthLon']
    # ax1 = sns.heatmap(df, ax=ax, cbar=False,)
    ax1 = sns.heatmap(df, ax=ax[1], vmin=0, annot=True)
    ax1.set_title('Only MS candidate covers \n2+ DB bboxes')
    ax1.set_xlabel('kernel_bandwidthLat')
    ax1.set_ylabel('kernel_bandwidthLon')
    
    df = pd.DataFrame(gridsearch_dict_per_img_overlapping['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'][:, :, a_slider.value])
    df.columns = param_grid_values['kernel_bandwidthLat']
    df.index = param_grid_values['kernel_bandwidthLon']
    # ax0 = sns.heatmap(df, ax=ax, cbar=False,)
    ax2 = sns.heatmap(df, ax=ax[2], vmin=0, annot=True)
    ax2.set_title('DB group matches \nmultiple MS groups')
    ax2.set_xlabel('kernel_bandwidthLat')
    ax2.set_ylabel('kernel_bandwidthLon')

    fig.tight_layout()

    # pass

a_slider = widgets.IntSlider(min=0, max=len(wl_list)-1, step=1, value=0)
a_slider.observe(on_value_change_overlapping, names='value')

plt.ioff()
num_rows = 1
num_cols = 3
fig, ax = plt.subplots(num_rows,num_cols, figsize=(3*num_cols,3*num_rows))
on_value_change_overlapping(None)
plt.ion()

display(widgets.VBox([a_slider, fig.canvas]))


VBox(children=(IntSlider(value=0, max=35), Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', '…

In [4]:
# Toy example
look_distance = .1  # How far to look for neighbours.
kernel_bandwidthLon = .25  # Longitude Kernel parameter.
kernel_bandwidthLat = .08  # Latitude Kernel parameter.
n_iterations = 5 # Number of iterations

radius = 695700

points_latLon = np.array([[1, 0],
                          [0,2.6],
                          [0,2.7]])
areas = [1000, 500 , 510]


ms_model = MS.Mean_Shift(look_distance, kernel_bandwidthLon, kernel_bandwidthLat, radius, n_iterations)
ms_model.fit(points_latLon, areas)
ms_centroids = ms_model.centroids
print()
print(ms_model.get_area_weighted_ellipsis_width(areas[2],areas))

print(ms_centroids)




0.1550124999999984
[[0.         2.65049505]
 [1.         0.        ]]
