In [1]:
import os
import glob
import json
import numpy as np
from tqdm.notebook import tqdm

from astropy.io import fits
from astropy.wcs import WCS
from astropy.wcs.utils import skycoord_to_pixel
from astropy.coordinates import SkyCoord
import astropy.units as u
from astropy.coordinates import Angle

# from sunpy.physics.differential_rotation import diff_rot, solar_rotate_coordinate
from sunpy.coordinates import frames , transformations

from copy import deepcopy

import concurrent.futures
from itertools import repeat
import multiprocessing

from matplotlib import rc
import matplotlib.patches as patches 
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import cv2
import skimage.io as io

import ipywidgets as widgets
%matplotlib widget

import importlib
import tracking_utilities as utils #import the module here, so that it can be reloaded.
importlib.reload(utils)
import Class2Bbox as c2bb
importlib.reload(c2bb)


import MeanShift as MS
import clustering_utilities as c_utils

%reload_ext autoreload
%autoreload 2



# Process Predictions

In [2]:

class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.int64):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super(NpEncoder, self).default(obj)

In [3]:
# wl_dir = "/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/image"
wl_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/all"  
# wl_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/all"  
# wl_dir = "/Users/nielssayez/Documents/Deepsun/Classification_dataset/2002-2019/all"  
# wl_dir = "/Users/nielssayez/Documents/Deepsun/Classification_dataset/ManualAnnotation/"  
wl_list = sorted(glob.glob(os.path.join(wl_dir, '**/*.FTS'),recursive=True))
# wl_basenames = [ os.path.basename(wl) for wl in wl_list ]
wl_basenames = [ os.path.basename(wl).split('.')[0] for wl in wl_list ]


masks_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/T425-T375-T325_fgbg'
# masks_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/feb2023/T425-T375-T325_fgbg'
# masks_dir = '/home/ucl/elen/nsayez/bio-blueprints/outputs/2023-01-22/01-18-26_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run0/predictions_4TTA_ManualAnnotation'
# masks_dir = '/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/GroundTruth'
# masks_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/T400-T350-Alternating_pen_um'
# masks_dir = '/Users/nielssayez/Documents/Deepsun/Classification_dataset/ManualAnnotation/GroundTruth'

sqlite_db_path = "/globalscratch/users/n/s/nsayez/Classification_dataset/drawings_sqlite.sqlite"
# sqlite_db_path = "/Users/nielssayez/Documents/Deepsun/Classification_dataset/drawings_sqlite.sqlite"
database = sqlite_db_path
print(len(wl_list), )

3150


In [4]:
# wl_basenames

In [5]:
db_drawings_datetimes = [c_utils.db_string_to_datetime(item[0]) for item in c_utils.get_unique_drawing_datetimes(database,'drawings')]
dr_basenames = [c_utils.datetime_to_drawing_name(item) for item in db_drawings_datetimes]

In [6]:
huge_db_dict = c_utils.wl_list2dbGroups(wl_list[:], dr_basenames, database)
# print(list(huge_db_dict.items())[0])

  0%|          | 0/3150 [00:00<?, ?it/s]

In [7]:
# # root_dir = '/Users/nielssayez/Documents/Deepsun/Classification_dataset'
# root_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset'

# # tmp = root_dir+'/2002-2019/T400_transunet_wl_list2dbGroups_Classification.json'
# tmp = root_dir+'/2002-2019/jan2023/wl_list2dbGroups_Classification.json'
# # tmp = root_dir+'/ManualAnnotation/wl_list2dbGroups_Classification.json'

# root_dir = './json_output'
# # root_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset'
# tmp = root_dir+'/ManualAnnotation/wl_list2dbGroups_Classification.json'


# root_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/feb2023'
root_dir = '/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/'
tmp = root_dir+'/wl_list2dbGroups_Classification.json'

In [8]:
import json
print(tmp)
with open(tmp, 'w') as f:
    json.dump(huge_db_dict, f, cls=NpEncoder)

/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2//wl_list2dbGroups_Classification.json


In [9]:
import json
with open(tmp, 'r') as f:
    huge_db_dict = json.load(f)

In [10]:
# indexes of images that should not be taken into account as the segmentation is really bad
# rotten_list = [ ]
rotten_list = [
    
    37,38,39,40,52, 64,65,69,70,
    
    72,97,99,100,101,102,103,104,142,159,160,161,169,187,190,211,212,218,264,300,312,314,316,319,322,327,339,
    343,353,356,387,408,413,414,418,424,425,448,473,474,493,512,508,611,614,666,675,696,726,330,747,750,758,
    761,784,804,823,832,840,855,914,935,940,948,990,1013,
    
    1025,1039,1040,1089,1172,1303,1332,1345,1397,1409,1413,1414,1421,1440,1444,1468,1469,1488,1576,1646,1692,
    1735,1815,1840,1867,1893,1900,1905,1919,1924,1925,1930,1953,1969,1992,
    
    2007,2039,2043,2045,2049,2050,2078,2121,2133,2143,2185,2208,2220,2254,2266,2272,2298,2344,3262,3274,2375,
    2445,2454,2468,2492,2494,2495,2500,2501,2503,2516,2518,2536,2568,2574,2598,2604,2633,2635,2749,2763,2815,
    2818,2820,2821,2834,2835,2851,2857,2867,2896,2899,2848,2951,2952,2956,2964,2980,2981,2994,
    
    3018,3092,3093,3097,3099,3101,3106,3118,3122,3123,3124,3140,3148
    
    
]
# rotten_list = [                    
#                     69,75,93,94,201,203,311,315,337,341,
#                     403,409,420,441,613,615,668,726,743,
#                     755,778,779,976,996,
                        
#                     1036,1066,1081,1138,1255,1296,1337,1379,1398,1471,
#                     1688,1735,1823,1925,1958,1989,1990,1995,
    
#                     2018,2030,2073,2078,2104,2139,2151,2204,2205,2206,2230,2300,
#                     2312,2325,2327,2349,2376,2452,2460,2461,2559,2673,2778,
#                     2790,2793,2826,2836,2894,2897,2962,
    
#                     3035,3142,3172,3178,3196,3217,3224,3242,3292,3296,
#                     3312,3317,3335,3353,3389,3390,3415,3431,3510,3535,
#                     3551,3571,3614,3664,3720,
#               ]
print(len(rotten_list))

183


In [11]:
print(len(huge_dict.keys()))

NameError: name 'huge_dict' is not defined

In [82]:
import json
with open(root_dir+'/meanshift_Classification.json', 'w') as f:
    json.dump(huge_dict, f, cls=NpEncoder)

In [9]:
import json
with open(root_dir+'/meanshift_Classification.json', 'r') as f:
    huge_dict = json.load(f)
print(len(huge_dict.keys()))

3653


In [12]:
huge_dict = {}

In [13]:
def generate_batch(lst, batch_size):
    """  Yields bacth of specified size """
    if batch_size<=0:
       return
    for i in range(0, len(lst), batch_size):
        yield lst[i : i + batch_size]
        
# [len(batch) for batch in generate_batch(wl_list, 100)]



In [15]:
import multiprocessing as mp

look_distance = .1  # How far to look for neighbours.
kernel_bandwidthLon = .35  # Longitude Kernel parameter.
kernel_bandwidthLat = .08  # Latitude Kernel parameter.
n_iterations = 20 # Number of iterations

# num_cpu = 15
# num_cpu = 7
num_cpu = multiprocessing.cpu_count() // 2

input_type = 'confidence_map'
# input_type = 'mask'

print(num_cpu)

for batch in generate_batch(range(len(wl_list)), 400):
    idx_start, idx_end = batch[0], batch[-1]
#     if idx_start < 2000:
#         continue

    print(idx_start, '-->', idx_end)
    
    
    with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
    #     with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
#     with concurrent.futures.ThreadPoolExecutor(max_workers=int(num_cpu)) as executor:
#     with mp.Pool(num_cpu) as executor:
        for result_key, result_dict in tqdm(executor.map(c_utils.process_one_image, 
                                                                    wl_list[idx_start:idx_end],
                                                                    repeat(huge_db_dict),
                                                                    repeat(deepcopy(huge_dict)),
                                                                    repeat(wl_list),
                                                                    repeat(rotten_list),
                                                                    repeat(masks_dir),
                                                                    repeat(look_distance),
                                                                    repeat(kernel_bandwidthLon),
                                                                    repeat(kernel_bandwidthLat),
                                                                    repeat(n_iterations),
                                                                    repeat(input_type),)
#                                                                     chunksize=max(round(len(wl_list[idx_start:idx_end])//num_cpu),1
                                                                , 
                                                    ):
                    if not len(list(result_dict.keys())) == 0:
#                         print(result_key)
                        huge_dict[result_key] = deepcopy(result_dict)

32
0 --> 399


0it [00:00, ?it/s]

400 --> 799


0it [00:00, ?it/s]

800 --> 1199


0it [00:00, ?it/s]

1200 --> 1599


0it [00:00, ?it/s]

1600 --> 1999


0it [00:00, ?it/s]

2000 --> 2399


0it [00:00, ?it/s]

2400 --> 2799


0it [00:00, ?it/s]

2800 --> 3149


0it [00:00, ?it/s]

In [33]:
# import multiprocessing as mp

# look_distance = .1  # How far to look for neighbours.
# kernel_bandwidthLon = .35  # Longitude Kernel parameter.
# kernel_bandwidthLat = .08  # Latitude Kernel parameter.
# n_iterations = 20 # Number of iterations

# # num_cpu = 20
# # num_cpu = 7
# num_cpu = multiprocessing.cpu_count() - 4

# input_type = 'confidence_map'
# # input_type = 'mask'



# for batch in generate_batch(range(len(wl_list)), 100):
#     idx_start, idx_end = batch[0], batch[-1]
#     if idx_start < 3600:
#         continue
#     print(idx_start, '-->', idx_end)
    
    
#     with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
#     #     with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
#     #     with concurrent.futures.ThreadPoolExecutor(max_workers=int(num_cpu)) as executor:
# #     with mp.Pool(num_cpu) as executor:
#                 for result_key, result_dict in tqdm(executor.map(c_utils.process_one_image, 
#                                                                     wl_list[idx_start:idx_end],
#                                                                     repeat(huge_db_dict),
#                                                                     repeat(huge_dict),
#                                                                     repeat(wl_list),
#                                                                     repeat(rotten_list),
#                                                                     repeat(masks_dir),
#                                                                     repeat(look_distance),
#                                                                     repeat(kernel_bandwidthLon),
#                                                                     repeat(kernel_bandwidthLat),
#                                                                     repeat(n_iterations),
#                                                                     repeat(input_type),
# #                                                                     chunksize=max(round(len(wl_list[idx_start:idx_end])//num_cpu),1)
#                                                                 ), 
#                                                     ):
#                     if not len(list(result_dict.keys())) == 0:
#                         print(result_key)
#                         huge_dict[result_key] = deepcopy(result_dict)

In [11]:


# look_distance = .1  # How far to look for neighbours.
# kernel_bandwidthLon = .35  # Longitude Kernel parameter.
# kernel_bandwidthLat = .08  # Latitude Kernel parameter.
# n_iterations = 20 # Number of iterations

# num_cpu = 20
# # num_cpu = 48
# # num_cpu = multiprocessing.cpu_count()

# input_type = 'confidence_map'
# # input_type = 'mask'

# for batch in generate_batch(wl_list, 100):
#     with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
# #     with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
# #     with concurrent.futures.ThreadPoolExecutor(max_workers=int(num_cpu)) as executor:
#             for result_key, result_dict in tqdm(executor.map(c_utils.process_one_image, 
#                                                                 batch[:],
#                                                                 repeat(huge_db_dict),
#                                                                 repeat(huge_dict),
#                                                                 repeat(wl_list),
#                                                                 repeat(rotten_list),
#                                                                 repeat(masks_dir),
#                                                                 repeat(look_distance),
#                                                                 repeat(kernel_bandwidthLon),
#                                                                 repeat(kernel_bandwidthLat),
#                                                                 repeat(n_iterations),
#                                                                 repeat(input_type),
#                                                                 chunksize=100), 
#                                                 ):
#                 if not len(list(result_dict.keys())) == 0:
#                     print(result_key)
#                     huge_dict[result_key] = deepcopy(result_dict)

In [None]:
for im in tqdm(wl_list[:]):
    result_key, result_dict = c_utils.process_one_image(
                                                                im,
                                                                huge_db_dict,
                                                                huge_dict,
                                                                wl_list,
                                                                rotten_list,
                                                                masks_dir,
                                                                look_distance,
                                                                kernel_bandwidthLon,
                                                                kernel_bandwidthLat,
                                                                n_iterations,
                                                                input_type
                                                        )
#     print(result_key)
    if not len(list(result_dict.keys())) == 0:
        huge_dict[result_key] = result_dict

  0%|          | 0/3750 [00:00<?, ?it/s]

In [12]:
# huge_dict

In [16]:
import json
# with open('/Users/nielssayez/Documents/Deepsun/Classification_dataset/ManualAnnotation/meanshift_Classification.json', 'w') as f:
# with open(root_dir+'/2002-2019/jan2023/meanshift_Classification.json', 'w') as f:
# with open(root_dir+'/ManualAnnotation/meanshift_Classification.json', 'w') as f:
# with open(root_dir+'/meanshift_Classification.json', 'w') as f:
with open(root_dir+'/test/meanshift_Classification.json', 'w') as f:
    json.dump(huge_dict, f, cls=NpEncoder)

In [17]:
import json
# with open('/Users/nielssayez/Documents/Deepsun/Classification_dataset/ManualAnnotation/meanshift_Classification.json', 'r') as f:
# with open(root_dir+'/2002-2019/jan2023/meanshift_Classification.json', 'r') as f:
# with open(root_dir+'/ManualAnnotation/meanshift_Classification.json', 'r') as f:
with open(root_dir+'/test/meanshift_Classification.json', 'r') as f:
    huge_dict = json.load(f)

In [17]:
print(len(list(huge_dict.keys())))

2961


In [19]:
print()




# View Results

In [17]:
colors = ['tab:blue','tab:orange','tab:green','tab:red',
          'tab:purple','tab:brown','tab:pink','tab:gray',
          'tab:olive','tab:cyan']

from matplotlib.colors import ListedColormap, LinearSegmentedColormap   
cmap_gt = cm.autumn
cmap_gt = cmap_gt(range(255))
cmap_gt = ListedColormap([(0, 0, 0, 0), *cmap_gt])

cmap = plt.get_cmap('autumn')
cmap.set_under((0,0,0,0))

def my_refresh(value):
    ccc = fig.gca()
    xlims0 = ccc.get_xlim()
    ylims0 = ccc.get_ylim()
    
    ccc.set_visible(False) 
    
    basename = list(huge_dict.keys())[img_selector.value]
    tmp_idx = wl_basenames.index(basename)
    wl = wl_list[tmp_idx]
     
    #####################
    
    cur_db_dict = huge_db_dict[basename]
    
    group_list = cur_db_dict["group_list"]
    drawing_radius_mm = cur_db_dict["dr_radius_mm"]
    drawing_radius_px = cur_db_dict["dr_radius_px"]
    date2 = cur_db_dict["dr_date"]
    date = cur_db_dict["wl_date"] 
    
    #####################
    m, h = utils.open_and_add_celestial(wl_list[tmp_idx])
#     m, h = utils.open_and_add_celestial2(wl_list[tmp_idx], date_obs=date)
#     print(h)

    corrected = False
    
    if not 'DATE-OBS' in h:
        m, h = utils.open_and_add_celestial2(wl_list[tmp_idx], date_obs=date)
        corrected = True
        
    wcs = WCS(h)
    wcs.heliographic_observer = m.observer_coordinate
    origin = m.data.shape[0]//2, m.data.shape[1]//2
    
    if input_type == "mask":
        mask = io.imread(os.path.join(masks_dir,basename+".png"))
    elif input_type == "confidence_map":
        mask = np.load(os.path.join(masks_dir,basename+"_proba_map.npy"))
        
    print(m.data.shape , mask.shape)
    
        # raise AssertionError("Not implemented yet")
#     mask[mask>0] = 1 

    m_rot = m.rotate(angle=-h["SOLAR_P0"] * u.deg)
    top_right = SkyCoord( 1000 * u.arcsec, 1000 * u.arcsec, frame=m_rot.coordinate_frame)
    bottom_left = SkyCoord(-1000 * u.arcsec, -1000 * u.arcsec, frame=m_rot.coordinate_frame)
    m_rot_submap = m_rot.submap(bottom_left, top_right=top_right)
    m_rot_submap_shape = m_rot_submap.data.shape
    m_rot_shape = m_rot.data.shape
    deltashapeX = np.abs(m_rot_shape[0]-m_rot_submap_shape[0])
    deltashapeY = np.abs(m_rot_shape[1]-m_rot_submap_shape[1]) 

    h2 = m_rot_submap.fits_header 

    h2.append(('CTYPE1', 'HPLN-TAN'))
    h2.append(('CTYPE2', 'HPLT-TAN'))
    wcs2 = WCS(h2)
    wcs2.heliographic_observer = m_rot_submap.observer_coordinate
    origin = m_rot_submap.data.shape[0]//2, m_rot_submap.data.shape[1]//2
    
#     if not m_rot_submap.date == date:
#         m_rot_submap.date = date
#         print(m_rot_submap.date)
    
    axClustering = fig.add_subplot(projection=m_rot_submap)
    m_rot_submap.plot(axes=axClustering, interpolation='None')
    m_rot_submap.draw_grid()   
    
    disp_mask = c_utils.rotate_CV_bound(mask, angle=h["SOLAR_P0"], interpolation=cv2.INTER_NEAREST) #rotate(mask, angle=h["SOLAR_P0"], reshape=True)
    disp_mask = disp_mask[deltashapeX//2:disp_mask.shape[0]-deltashapeX//2,
                          deltashapeY//2:disp_mask.shape[1]-deltashapeY//2] 
    axClustering.imshow(disp_mask, cmap=cmap, interpolation="None", alpha=.5 )
#     axClustering.imshow(disp_mask, cmap=cmap, interpolation="None", alpha=.5, vmin=.001 )
    
    axClustering.set_title(axClustering.get_title()+ f' [corrected: {corrected}]'  + f'-> drawing: {date2}' )
    
    ####################
    dr_obstime = date+'.000'  
    all_sks = []
    all_pixels = []
    for item in group_list:
        cur_sk = SkyCoord(item["Longitude"]*u.rad, item["Latitude"]*u.rad , frame=frames.HeliographicCarrington,
                      obstime=dr_obstime, observer="earth") 
        coords_wl = skycoord_to_pixel(cur_sk, wcs2, origin=0)
        all_sks.append(cur_sk)
        all_pixels.append(coords_wl)
        
    bboxes, bboxes_wl, rectangles, rectangles_wl = c_utils.grouplist2bboxes_and_rectangles(group_list, 
                                                                                   drawing_radius_px,
                                                                                   h["SOLAR_R"],
                                                                                   all_pixels)
    for cur_sk in all_sks:
        axClustering.plot_coord(cur_sk, "o", color='b', markersize=2)

    for i, r in enumerate(rectangles_wl):
        axClustering.add_patch(r)
        if areas_cb.value:
            axClustering.text(bboxes_wl[i][0]+3, bboxes_wl[i][1]+3, 
                          f' {group_list[i]["McIntosh"]} : {group_list[i]["area_muHem"]}',color='b') 
    
    ####################
    
    sunspots_sk, sunspots_areas = c_utils.get_sunspots3(h,m, mask>0, sky_coords=True)
    sk_Lon = sunspots_sk.lon.rad
    sk_Lat = sunspots_sk.lat.rad
    sk_LatLon = np.stack((sk_Lat,sk_Lon),axis=1)
    
    nan_indexes = np.unique(np.argwhere(np.isnan(sk_LatLon))[:,0])
    clean = (~np.isnan(sk_Lon) & ~np.isnan(sk_Lat))
    if len(nan_indexes) > 0:
        sunspots_sk = sunspots_sk[clean]
        sunspots_areas = (np.array(sunspots_areas)[clean]).tolist()
        sk_LatLon = sk_LatLon [clean]
    
    cur_dict = huge_dict[basename]
    
    ######################
    if ms_cb.value:
        cur_ms = cur_dict["meanshift"]

        ms_centroids = np.array(cur_ms['centroids'])
        ms_areas = np.array(cur_ms['areas'])
        ms_groups = np.array(cur_ms['groups'])
        
        my_ms = MS.Mean_Shift(look_distance, kernel_bandwidthLon, kernel_bandwidthLat, sunspots_sk.radius.km[0] ,  n_iterations)
        ##########
        my_ms.set_centroids(ms_centroids)
        # my_ms.fit(sk_LatLon, sunspots_areas)
        # ms_centroids = my_ms.centroids
        ###########
        ms_classifications = my_ms.predict(sk_LatLon)

        # print(m.date, '  ', date )
        sk_sequ_meanshift = SkyCoord(ms_centroids[:,1]*u.rad, ms_centroids[:,0]*u.rad , frame=frames.HeliographicCarrington,
                              obstime=m.date, observer="earth")
        
        pix_centers_meanshift = skycoord_to_pixel(sk_sequ_meanshift, wcs2, origin=0)

        for i, sk in enumerate(sk_sequ_meanshift):
                axClustering.plot_coord(sk, "X", color=colors[i%len(colors)], markersize=8)
                if areas_cb.value:
                    axClustering.text(pix_centers_meanshift[0][i], pix_centers_meanshift[1][i],  '%.2f' % ms_areas[i] ,va='top',c=colors[i%len(colors)])
        for i in range(len(sk_LatLon)):
            c = colors[ms_classifications[i]%len(colors)]
            axClustering.plot_coord(sunspots_sk[i], "o", color=c, markersize=2)

look_distance = .1 # How far to look for neighbours.
kernel_bandwidthLon = .3  # Longitude Kernel parameter.
kernel_bandwidthLat = .08  # Latitude Kernel parameter.
n_iterations = 20 # Number of iterations


input_type = "confidence_map"
# input_type = "mask"

huge_dict = {k: huge_dict[k] for k in sorted(list(huge_dict.keys()))}

db_drawings_datetimes = [c_utils.db_string_to_datetime(item[0]) for item in c_utils.get_unique_drawing_datetimes(database,'drawings')]
dr_basenames = [c_utils.datetime_to_drawing_name(item) for item in db_drawings_datetimes]

# print(huge_dict.keys())
    
img_selector = widgets.IntSlider(value=158, min=0, max=len(list(huge_dict.keys()))-1)
areas_cb = widgets.Checkbox(value=False, description="Show Areas")
km_cb = widgets.Checkbox(value=False, description="Show KMeans")
ms_cb = widgets.Checkbox(value=True, description="Show MeanShift")

img_selector.observe(my_refresh)
areas_cb.observe(my_refresh)
ms_cb.observe(my_refresh)
km_cb.observe(my_refresh)

plt.ioff()
fig, ax_widget = plt.subplots(nrows=1, ncols=1, figsize=(8,8))
my_refresh(0)
plt.ion()

widgets.VBox([widgets.HBox([img_selector, areas_cb, ms_cb]),fig.canvas])

(2048, 2048) (2048, 2048)


VBox(children=(HBox(children=(IntSlider(value=158, max=2960), Checkbox(value=False, description='Show Areas'),…

(2048, 2048) (2048, 2048)


TypeError: Overlay grids can only be plotted on WCSAxes plots.

In [11]:
from scipy.ndimage import rotate as rotate_image

def rotate_img_opencv(image, angle, interpolation):
    (h, w) = image.shape[:2]
    (cX, cY) = (w // 2, h // 2)

    M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
    rotated = cv2.warpAffine(image, M, (w, h))
    return rotated

def my_refresh(value):
    
    basename = list(huge_dict.keys())[img_selector.value]
    tmp_idx = wl_basenames.index(basename)
    wl = wl_list[tmp_idx]
    
    m, h = utils.open_and_add_celestial(wl_list[tmp_idx])
     
    if input_type == "mask":
        mask = io.imread(os.path.join(masks_dir,basename+".png"))
    elif input_type == "confidence_map":
#         print('here')
        mask = np.load(os.path.join(masks_dir,basename+"_proba_map.npy"))
    
    ax2_widget[0].imshow(m.data, cmap='gray')
    ax2_widget[0].imshow(mask, cmap=cmap, alpha = .5, vmin=.1)
    ax2_widget[0].set_title(basename)
    ax2_widget[0].invert_yaxis()
    
    m_rot = m.rotate(angle=-h["SOLAR_P0"] * u.deg)
    top_right = SkyCoord( 1000 * u.arcsec, 1000 * u.arcsec, frame=m_rot.coordinate_frame)
    bottom_left = SkyCoord(-1000 * u.arcsec, -1000 * u.arcsec, frame=m_rot.coordinate_frame)
    m_rot_submap = m_rot.submap(bottom_left, top_right=top_right)
    m_rot_submap_shape = m_rot_submap.data.shape
    m_rot_shape = m_rot.data.shape
    deltashapeX = np.abs(m_rot_shape[0]-m_rot_submap_shape[0])
    deltashapeY = np.abs(m_rot_shape[1]-m_rot_submap_shape[1]) 
    
    
#     rot_img = rotate_image(m_rot_submap.data, angle=h["SOLAR_P0"]) #rotate(mask, angle=h["SOLAR_P0"], reshape=True)
#     disp_mask = rotate_image(mask, angle=h["SOLAR_P0"]) #rotate(mask, angle=h["SOLAR_P0"], reshape=True)

    rot_img = c_utils.rotate_CV_bound(m_rot_submap.data, angle=h["SOLAR_P0"], interpolation=cv2.INTER_NEAREST) #rotate(mask, angle=h["SOLAR_P0"], reshape=True)
    disp_mask = c_utils.rotate_CV_bound(mask, angle=h["SOLAR_P0"], interpolation=cv2.INTER_NEAREST) #rotate(mask, angle=h["SOLAR_P0"], reshape=True)
    disp_mask = disp_mask[deltashapeX//2:disp_mask.shape[0]-deltashapeX//2,
                          deltashapeY//2:disp_mask.shape[1]-deltashapeY//2] 
    
    ax2_widget[1].imshow(m_rot_submap.data, cmap='gray')
    ax2_widget[1].imshow(disp_mask, cmap=cmap, alpha = .5, vmin=.1)
    ax2_widget[1].set_title(h["SOLAR_P0"])
    ax2_widget[1].invert_yaxis()
    

input_type = "confidence_map"

look_distance = .1 # How far to look for neighbours.
kernel_bandwidthLon = .3  # Longitude Kernel parameter.
kernel_bandwidthLat = .08  # Latitude Kernel parameter.
n_iterations = 20 # Number of iterations


db_drawings_datetimes = [c_utils.db_string_to_datetime(item[0]) for item in c_utils.get_unique_drawing_datetimes(database,'drawings')]
dr_basenames = [c_utils.datetime_to_drawing_name(item) for item in db_drawings_datetimes]

# print(huge_dict.keys())
huge_dict = {k: huge_dict[k] for k in sorted(list(huge_dict.keys()))}

    
img_selector = widgets.IntSlider(value=3000, min=0, max=len(list(huge_dict.keys()))-1)
mask_cb = widgets.Checkbox(value=True, description="Show Masks")

img_selector.observe(my_refresh)

plt.ioff()
fig2, ax2_widget = plt.subplots(nrows=1, ncols=2, figsize=(10,4))
my_refresh(0)
plt.ion()

widgets.VBox([widgets.HBox([img_selector]),fig2.canvas])


NameError: name 'cmap' is not defined

# Associate groups

In [18]:
image_out_dict = {}
image_out_dict_stats = {}

# for basename in list(huge_dict.keys())[3400:3401]:
# for basename in list(huge_dict.keys())[:1]:
#     cur_image_dict = huge_dict[basename]
# for basename in ["UPH20120912082643",]:
for basename in tqdm(list(huge_dict.keys())[:]):
# for basename in ["UPH20170623120332"]:
    cur_image_dict = huge_dict[basename]
    
    angle = cur_image_dict["SOLAR_P0"]
    deltashapeX = cur_image_dict["deltashapeX"]
    deltashapeY = cur_image_dict["deltashapeY"]
    
    drawing_radius_px = huge_db_dict[basename]["dr_radius_px"]
    
    group_list = cur_image_dict['db']
    
    
    ms_dict = cur_image_dict['meanshift']
    
    centroids = np.array(ms_dict["centroids"])
    centroids_px = np.array(ms_dict["centroids_px"])
#     print('centroids_px', centroids_px)
    
#     db_bboxes = grouplist2bboxes(group_list, drawing_radius_px)
    db_classes = [{"Zurich":item['Zurich'], "McIntosh":item['McIntosh'] } for item in group_list]
    db_bboxes = [np.array(item['bbox_wl']) for item in group_list]
    db_centers_px = np.array([[(b[2]+b[0])/2,(b[3]+b[1])/2] for b in db_bboxes])
    
#     print("db_centers_px", db_centers_px)
#     print("db_bboxes", db_bboxes)
    
        
    # check that current bbox is does not overlap any
    isolated_bboxes_bool = np.array(c_utils.get_intersecting_db_bboxes(db_bboxes)) == 0
    isolated_bboxes_indices = np.where(isolated_bboxes_bool == True)[0]
#     print("isolated_bboxes_bool",isolated_bboxes_bool)
#     print(isolated_bboxes_indices)
    
    cur_out_stats = {
        # General info
        'num_DB_groups':len(db_bboxes),
        'num_MS_groups':len(centroids_px),
        'num_DB_isolated_groups':len(isolated_bboxes_indices),
        'num_DB_overlaping_bboxes':len(db_bboxes) - len(isolated_bboxes_indices),
        # MS with DB matching info
        "num_MSmatchesDB":0,
        # MS with DB rejection info
        "num_noMS_but_DB_reject":0,
        "num_singleMS_multipleDB_reject":0,
        "num_oneDBbbox_multipleMSoverlap_ambiguity_reject":0,
        # MS with DB no match info
        "num_noDB_but_MS":0,
        }
    cur_out_groups = []
    for i, (db_bbox, db_center, db_class) in enumerate(
                                            zip([db_bboxes[a] for a in isolated_bboxes_indices.tolist() ],
                                                [db_centers_px[a] for a in isolated_bboxes_indices.tolist()],
                                                [db_classes[a] for a in isolated_bboxes_indices.tolist()],
                                               )):
        
        
        ms_centroids, ms_members = centroids_px, ms_dict['groups_px']
        
        intersect = c_utils.contains_ms_groups(db_bbox, db_center, ms_centroids, ms_members)
        
        if sum(intersect) == 0: # Il n'y a eu aucune détection dans cette zone
            cur_out_stats['num_noMS_but_DB_reject'] += 1
            pass
        elif sum(intersect) == 1: # il n'y a de l'overlap qu'avec un seul groupe meanshift            
#             print('hit')
            idx = intersect.index(True)
#             print(idx)
            # vérifier que le groupe meanshift n'intersecte aucune autre bbox
            num_intersections = np.sum(c_utils.count_group_intersections(ms_members[idx], db_bboxes))
            if num_intersections > 1:
                cur_out_stats['num_singleMS_multipleDB_reject'] += 1
                continue
            
            Rmm = huge_db_dict[basename]['dr_radius_mm']
            R_pixel = huge_db_dict[basename]['dr_radius_px']
            sun_center = huge_db_dict[basename]['dr_center_px']
            dr_pixpos = np.array([group_list[i]['posx'], group_list[i]['posy']])
            
            angular_excentricity =  c_utils.get_angle2(dr_pixpos, R_pixel, sun_center)
            
            cur_group_dict={
                            "centroid_px": centroids_px[idx],
                            "centroid_Lat": centroids[idx][0],
                            "centroid_Lon": centroids[idx][1],
                            "angular_excentricity_rad": angular_excentricity,
                            "angular_excentricity_deg": np.rad2deg(angular_excentricity),
                            "Zurich":   db_class["Zurich"],
                            "McIntosh": db_class["McIntosh"],
                            "members": ms_members[idx],
                            "members_mean_px": np.mean(ms_members[idx], axis=0),
                           }
            
            
            cur_out_groups.append(cur_group_dict)
            cur_out_stats['num_MSmatchesDB'] += 1

        else: # db_bbox intersecte plusieurs groupes meanshift
            cur_out_stats['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'] += 1
            pass
            
    if len(cur_out_groups) > 0:
        image_out_dict[basename] = { "angle": angle,
                                     "deltashapeX":deltashapeX,
                                     "deltashapeY":deltashapeY,
                                     "groups": cur_out_groups,
                                   }

    # count the number of MS groups that do not have any overlap with the DB
    num_intersections_per_group = [np.sum(c_utils.count_group_intersections(ms_members[idx], db_bboxes)) for idx in range(len(ms_members))]
    num_MS_without_DB_overlap = len(np.where(np.array(num_intersections_per_group) == 0)[0])
    cur_out_stats['num_noDB_but_MS'] = num_MS_without_DB_overlap
    
    image_out_dict_stats[basename] = cur_out_stats

    

print('num_images: ', len(list(image_out_dict.keys())))
num_groups = 0
for k,v in image_out_dict.items():
    num_groups += len(v['groups'])
print("num_groups: ",num_groups)
# print(image_out_dict)


  0%|          | 0/2961 [00:00<?, ?it/s]

num_images:  1856
num_groups:  3889


In [19]:
# Plot some stats
# General info
num_DB_groups_per_image = np.array([item['num_DB_groups'] for k,item in image_out_dict_stats.items()])
num_MS_groups_per_image = np.array([item['num_MS_groups'] for k,item in image_out_dict_stats.items()])

diff_num_groups = num_DB_groups_per_image - num_MS_groups_per_image # should be shown on histogram , closer to 0 is better

# number of isolated DB groups per image and mean number in dataset
num_isolated_DB_groups = np.array([v['num_DB_isolated_groups'] for k,v in image_out_dict_stats.items()])

# numbers of rejects due to single MS group overlapping multiple DB groups (# minimize this)
num_singleMS_multipleDB_reject = np.array([v['num_singleMS_multipleDB_reject'] for k,v in image_out_dict_stats.items()])
# numbers of rejects due to single DB group overlapping multiple MS groups (# minimize this)
num_oneDBbbox_multipleMSoverlap_ambiguity_reject = np.array([v['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'] for k,v in image_out_dict_stats.items()])
# numbers of rejects due to no MS group overlapping DB group (# cannot do anything to this, examples are discarded)
num_noMS_but_DB_reject = np.array([v['num_noMS_but_DB_reject'] for k,v in image_out_dict_stats.items()])

num_optimizable_rejects = num_singleMS_multipleDB_reject + num_oneDBbbox_multipleMSoverlap_ambiguity_reject
num_rejects_all = num_optimizable_rejects + num_noMS_but_DB_reject

# rate of optimizable rejects
rate_optimizable_rejects = num_optimizable_rejects / num_rejects_all

#rate of 


# Numbers over whole dataset
num_DB_groups = np.sum(num_DB_groups_per_image)
num_MS_groups = np.sum(num_MS_groups_per_image)

num_isolated_DB_groups_total = np.sum(num_isolated_DB_groups) # cannot do anything to this, maximum value possible
rate_isolated_DB_groups_total = num_isolated_DB_groups_total / num_DB_groups
num_overlapping_DB_groups_total = num_DB_groups - num_isolated_DB_groups_total # cannot do anything to this, examples are discarded
rate_overlapping_DB_groups_total = num_overlapping_DB_groups_total / num_DB_groups

num_singleMS_multipleDB_reject_total = np.sum(num_singleMS_multipleDB_reject) # minimize this
num_oneDBbbox_multipleMSoverlap_ambiguity_reject_total = np.sum(num_oneDBbbox_multipleMSoverlap_ambiguity_reject) # minimize this




In [20]:
# Plot some stats
# General info
num_DB_groups_per_image = np.array([item['num_DB_groups'] for k,item in image_out_dict_stats.items()])
num_MS_groups_per_image = np.array([item['num_MS_groups'] for k,item in image_out_dict_stats.items()])

diff_num_groups = num_DB_groups_per_image - num_MS_groups_per_image # should be shown on histogram , closer to 0 is better

# number of isolated DB groups per image and mean number in dataset
num_isolated_DB_groups = np.array([v['num_DB_isolated_groups'] for k,v in image_out_dict_stats.items()])
mean_num_isolated_DB_groups = np.mean(num_isolated_DB_groups) # not very useful
rate_isolated_DB_groups = num_isolated_DB_groups / num_DB_groups_per_image # not very useful
mean_rate_isolated_DB_groups = np.mean(rate_isolated_DB_groups) # not very useful
std_rate_isolated_DB_groups = np.std(rate_isolated_DB_groups) # not very useful

# Numbers over whole dataset
num_DB_groups = np.sum(num_DB_groups_per_image)
num_MS_groups = np.sum(num_MS_groups_per_image)
num_isolated_DB_groups_total = np.sum(num_isolated_DB_groups)
rate_isolated_DB_groups_total = num_isolated_DB_groups_total / num_DB_groups

#


# number of MS groups without DB overlap per image
num_MS_groups_without_DB_overlap = np.array([v['num_noDB_but_MS'] for k,v in image_out_dict_stats.items()])
mean_num_MS_groups_without_DB_overlap = np.mean(num_MS_groups_without_DB_overlap) # not very useful
rate_MS_groups_without_DB_overlap = num_MS_groups_without_DB_overlap / num_MS_groups_per_image
mean_rate_MS_groups_without_DB_overlap = np.mean(rate_MS_groups_without_DB_overlap)
std_rate_MS_groups_without_DB_overlap = np.std(rate_MS_groups_without_DB_overlap)

#number of DB groups without MS overlap per image
num_DB_groups_without_MS_overlap = np.array([v['num_noMS_but_DB_reject'] for k,v in image_out_dict_stats.items()])
mean_num_DB_groups_without_MS_overlap = np.mean(num_DB_groups_without_MS_overlap) # not very useful
rate_DB_groups_without_MS_overlap = num_DB_groups_without_MS_overlap / num_DB_groups_per_image
mean_rate_DB_groups_without_MS_overlap = np.mean(rate_DB_groups_without_MS_overlap)
std_rate_DB_groups_without_MS_overlap = np.std(rate_DB_groups_without_MS_overlap)


# total number of rejects per image
total_rejects = np.sum([v['num_noMS_but_DB_reject'] + 
                        v['num_singleMS_multipleDB_reject'] +
                        v['num_oneDBbbox_multipleMSoverlap_ambiguity_reject'] 
                        for k,v in image_out_dict_stats.items()])

# # total number of DB groups 
# reject_rates = total_rejects / 

# # mean percentage of rejects per image
# mean_percentage_rejects = np.mean([ (( v['num_noMS_but_DB_reject'] +
#                                       v['num_singleMS_multipleDB_reject'] + 
#                                       v['num_oneDBbbox_multipleMSoverlap_ambiguity_reject']
#                                     )/ v['num_DB_groups'] * 100) if v['num_DB_groups'] > 0 else 0
#                                     for k,v in image_out_dict_stats.items()])


In [21]:
# print("len(huge_dict.keys()): ", len(huge_dict.keys()))
# print("len(image_out_dict.keys()): ", len(image_out_dict.keys()))

# # get the keys in huge_dict not present in image_out_dict

# missing_keys = []
# for k in huge_dict.keys():
#     # print(k)
#     if k not in image_out_dict.keys():
#         missing_keys.append(k)
# print("missing_keys: ", missing_keys)

#print the number of groups per image in huge_db_dict
for i,(k,v) in enumerate(huge_db_dict.items()):
    # print(v)
    # print(huge_dict[k].keys())
    # print(huge_dict[k]["meanshift"]['centroids'], len(huge_dict[k]["meanshift"]['centroids']))
    idx =  i
    img_name = k
    db_num_groups = len(v['group_list'])
    found_groups = len(huge_dict[k]["meanshift"]['centroids']) if k in huge_dict.keys() else 0
    matched_groups = len(image_out_dict[k]['groups']) if k in image_out_dict.keys() else 0

    
    print(f'idx {idx}: {img_name} DB num_groups: {db_num_groups} Found Groups: {found_groups} Matched Groups: {matched_groups}')
    # print( f'idx {i}: ', k, "DB num_groups:" ,len(v['group_list']), 'Matched Groups', len(image_out_dict[k]['groups']) )

# for k,v in image_out_dict.items():
#     print(k, len(v['groups']))


idx 0: UPH20020424160712 DB num_groups: 13 Found Groups: 16 Matched Groups: 2
idx 1: UPH20020427094522 DB num_groups: 11 Found Groups: 11 Matched Groups: 2
idx 2: UPH20020428170429 DB num_groups: 7 Found Groups: 11 Matched Groups: 1
idx 3: UPH20020502111048 DB num_groups: 12 Found Groups: 11 Matched Groups: 5
idx 4: UPH20020507112606 DB num_groups: 14 Found Groups: 18 Matched Groups: 2
idx 5: UPH20020508155035 DB num_groups: 15 Found Groups: 9 Matched Groups: 4
idx 6: UPH20020513152850 DB num_groups: 10 Found Groups: 12 Matched Groups: 1
idx 7: UPH20020515170013 DB num_groups: 7 Found Groups: 9 Matched Groups: 1
idx 8: UPH20020521184732 DB num_groups: 9 Found Groups: 8 Matched Groups: 2
idx 9: UPH20020527140851 DB num_groups: 12 Found Groups: 12 Matched Groups: 2
idx 10: UPH20020529170600 DB num_groups: 11 Found Groups: 5 Matched Groups: 1
idx 11: UPH20020531163352 DB num_groups: 11 Found Groups: 14 Matched Groups: 6
idx 12: UPH20020604170808 DB num_groups: 13 Found Groups: 16 Matched 

In [22]:
root_dir

'/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/'

In [23]:
# image_out_dict

In [24]:
out_json_dir = 'feb2023'
out_json_filename = 'dataset'

# out_json_filename = 'dataset'

In [25]:
# with open(f'{root_dir}/2002-2019/{out_json_dir}/{out_json_filename}.json', 'w') as f:
# with open(f'{root_dir}/ManualAnnotation/{out_json_filename}.json', 'w') as f:
# with open(f'{root_dir}/{out_json_filename}.json', 'w') as f:
with open(f'{root_dir}/test/{out_json_filename}.json', 'w') as f:
    json.dump(image_out_dict, f, cls=NpEncoder)

In [28]:
import json
# with open(f'{root_dir}/{out_json_filename}.json', 'r') as f:
with open(f'{root_dir}/test/{out_json_filename}.json', 'r') as f:
# with open(f'{root_dir}/2002-2019/{out_json_dir}/{out_json_filename}.json', 'r') as f:
# with open(f'{root_dir}/ManualAnnotation/{out_json_filename}.json', 'r') as f:
    image_out_dict = json.load(f)

# Dataset Statistics

In [26]:
# 1) analyse the distributions

classes = ['A','B','C','D','E','F','G','H','I','J','X']

distribs = {c: 0 for c in classes}
group_types = {}
group_types2 = {c:{} for c in classes}

for bn, img_dict in tqdm(image_out_dict.items()):
    groups = img_dict['groups']
    
    for i, g in enumerate(groups):
        cur_c = g["Zurich"]
        distribs[cur_c] +=1
        
        new_group_infos= {
            'angle': img_dict['angle'],
            'deltashapeX': img_dict['deltashapeX'],
            'deltashapeY':img_dict['deltashapeY'],
            'centroid_px': g['centroid_px'],
            'centroid_Lat': g['centroid_Lat'],
            'centroid_Lon': g['centroid_Lon'],
            'members': g['members'],
            'members_mean_px': g['members_mean_px'],
            'angular_excentricity_rad': g['angular_excentricity_rad'],
            'angular_excentricity_deg': g['angular_excentricity_deg'],
            'Zurich': g['Zurich'],
            'McIntosh': g['McIntosh'],   
        }
        
        new_goup_id = bn + '_' + str(i)
        group_types[new_goup_id] = new_group_infos
        group_types2[cur_c][new_goup_id] = new_group_infos
        
    
print(distribs)
print()
# print(group_types2)

  0%|          | 0/1856 [00:00<?, ?it/s]

{'A': 421, 'B': 520, 'C': 805, 'D': 798, 'E': 177, 'F': 16, 'G': 79, 'H': 315, 'I': 0, 'J': 743, 'X': 15}



In [27]:
group_based_dataset = deepcopy(group_types)
group_based_dataset2 = deepcopy(group_types2)

In [28]:
First2superFirst = {"A":"A",
                    "B":"B",
                    "C":"C",
                    "D":"SuperGroup",
                    "E":"SuperGroup",
                    "F":"SuperGroup",
                    "H":"H",
                    "X":"X"
                   }
Second2superSecond = {"x":"x",
                      "r":"r",
                      "s": "sym",
                      "h": "sym",
                      "a": "asym",
                      "k": "asym",
                     }
Third2superThird = {"x": "x",
                    "o": "o",
                    "i": "frag",
                    "c": "frag",
                   }

def add_superclasses(group_dict):
    cpy = deepcopy(group_dict)    
    # print(cpy)

    cpy["SuperClass"] = {
        "1": First2superFirst[group_dict["McIntosh"][0]],
        "2": Second2superSecond[group_dict["McIntosh"][1]],
        "3": Third2superThird[group_dict["McIntosh"][2]],
    }
    
    return cpy
    
    

grp_to_remove = []
group_based_dataset_superclasses = {}
for g in tqdm(group_based_dataset):
    try : 
        group = group_based_dataset[g]
        # print(group)
        group = add_superclasses(group)
#         print(group)
        group_based_dataset_superclasses[g] = group
        
    except KeyError:
        print(g)
        print(group_based_dataset[g])
        if group_based_dataset[g]["McIntosh"] == '   ':
            print( "error")
            grp_to_remove.append((g,group_based_dataset[g]['Zurich']))

for k,k_type in grp_to_remove:
    group_based_dataset.pop(k)
    group_based_dataset2[k_type].pop(k)

# group_based_dataset_superclasses

    

  0%|          | 0/3889 [00:00<?, ?it/s]

UPH20040722080136_1
{'angle': 6.950210373029591, 'deltashapeX': 226, 'deltashapeY': 226, 'centroid_px': array([1763.98468838,  789.79392226]), 'centroid_Lat': -0.19138280535545324, 'centroid_Lon': 0.8858795170464564, 'members': [[789.1715265866209, 1766.2315608919382], [805.4, 1705.56], [804.0, 1700.5]], 'members_mean_px': array([ 799.5238422 , 1724.09718696]), 'angular_excentricity_rad': 0.8332981619101842, 'angular_excentricity_deg': 47.74446775346269, 'Zurich': 'G', 'McIntosh': '   '}
error


## Split per type

In [29]:
distribs2 = {c:0 for c in classes}
group_types2 = {c:{} for c in classes}

for grp_id, grp_dict in tqdm(group_based_dataset_superclasses.items()):
    cur_c = grp_dict["Zurich"]
    group_types2[cur_c][grp_id] = grp_dict
        
# group_types2

  0%|          | 0/3888 [00:00<?, ?it/s]

## Step 2:  Split groups among train, val, test sets

In [30]:
import random


def splitPerc(l, perc):
    # Turn percentages into values between 0 and 1
    splits = np.cumsum(perc)/100.

    if splits[-1] != 1:
        raise ValueError("percents don't add up to 100")

    # Split doesn't need last percent, it will just take what is left
    splits = splits[:-1]

    # Turn values into indices
    splits *= len(l)

    # Turn double indices into integers.
    # CAUTION: numpy rounds to closest EVEN number when a number is halfway
    # between two integers. So 0.5 will become 0 and 1.5 will become 2!
    # If you want to round up in all those cases, do
    # splits += 0.5 instead of round() before casting to int
    splits = splits.round().astype(np.int)

    return np.split(l, splits)

splits = ['train', 'val', 'test']

splits_percentages = [70, 15, 15]

assert np.array(splits_percentages).sum() == 100

group_based_dataset_superclasses_splits = {sp:{} for sp in splits}

for t, type_dict in group_types2.items():
    list_type_groups = list(type_dict.keys()) 
    # shuffle
    random.shuffle(list_type_groups)
    
    indices = np.array(range(len(list_type_groups)))
    
    s = splitPerc(indices, splits_percentages)

    # take percentage and fill group_based_dataset
    for i, sp in enumerate(splits):
        split_indices = s[i]
        split_groups = [list_type_groups[j] for j in split_indices]
        
#         print(split_groups)
        for g in  split_groups:
            group_based_dataset_superclasses_splits[sp][g] = type_dict[g] 
            

In [31]:
# group_based_dataset_superclasses_splits

In [32]:
final_json = "dataset_nosplits"
# with open(f'{root_dir}/2002-2019/{out_json_dir}/{final_json}.json', 'w') as f:
# with open(f'{root_dir}/ManualAnnotation/{final_json}.json', 'w') as f:
# with open(f'{root_dir}/{final_json}.json', 'w') as f:
with open(f'{root_dir}/test/{final_json}.json', 'w') as f:
    json.dump(group_based_dataset_superclasses, f, cls=NpEncoder)
final_json_split = "dataset_final"
# with open(f'{root_dir}/2002-2019/{out_json_dir}/{final_json_split}.json', 'w') as f:
# with open(f'{root_dir}/ManualAnnotation/{final_json_split}.json', 'w') as f:
with open(f'{root_dir}/test/{final_json_split}.json', 'w') as f:
    json.dump(group_based_dataset_superclasses_splits, f, cls=NpEncoder)

## Compare model inputs

In [33]:
from omegaconf import OmegaConf
from hydra.utils import instantiate

In [34]:
cropper_transform_idx = 0

#load the config file
cfg1_orig = OmegaConf.load(f'test.yaml')

cfg2_focus_noRand = OmegaConf.load(f'test.yaml')
cfg2_focus_noRand.transforms[cropper_transform_idx].focus_on_group = True

cfg3_nofocus_rand = OmegaConf.load(f'test.yaml')
cfg3_nofocus_rand.transforms[cropper_transform_idx].focus_on_group = False
cfg3_nofocus_rand.transforms[cropper_transform_idx].random_move = True
cfg3_nofocus_rand.transforms[cropper_transform_idx].random_move_percent = .2

cfg4_focus_rand = OmegaConf.load(f'test.yaml')
cfg4_focus_rand.transforms[cropper_transform_idx].focus_on_group = True
cfg4_focus_rand.transforms[cropper_transform_idx].random_move = True
cfg4_focus_rand.transforms[cropper_transform_idx].random_move_percent = .2

In [35]:
tr_orig = [instantiate(t) for t in cfg1_orig.transforms]
tr_focus_noRand = [instantiate(t) for t in cfg2_focus_noRand.transforms]
tr_nofocus_rand = [instantiate(t) for t in cfg3_nofocus_rand.transforms]
tr_focus_rand = [instantiate(t) for t in cfg4_focus_rand.transforms]


In [36]:
root_dir

'/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/'

In [48]:
%load_ext autoreload
%autoreload 2
from transform_utilities import *

def create_circular_mask(h, w, center=None, radius=None):

    if center is None: # use the middle of the image
        center = (int(w/2), int(h/2))
    if radius is None: # use the smallest distance between the center and image walls
        radius = min(center[0], center[1], w-center[0], h-center[1])

    Y, X = np.ogrid[:h, :w]
    dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)

    mask = dist_from_center <= radius
    return mask

# get mask contaoining only components at given coordinates using regionprops
def get_mask_from_coords(mask, coords):
    mask = mask.copy()
    mask[~mask] = 0
    mask = mask.astype(np.uint8)
    props = regionprops(mask)
    for p in props:
        if p.centroid not in coords:
            mask[p.coords[:,0], p.coords[:,1]] = 0
    return mask


from datetime import datetime, timedelta

def create_sample( grp_info, img_path, mask_path, conf_path):
    sample = {}
    sample['image'] = (io.imread(img_path)).astype(float) 
    header = fits.open(img_path)[0].header
    center = np.array(sample['image'].shape)//2
    radius = header['SOLAR_R']
    
    
    basename = os.path.basename(img_path).split('.')[0]
    sample['name'] = basename
    
    flip_time = "2003-03-08T00:00:00"
    date = c_utils.whitelight_to_datetime(basename)
    datetime_str = c_utils.datetime_to_db_string(date).replace(' ', 'T')
    print(datetime_str)
    should_flip = (datetime.fromisoformat(datetime_str) - datetime.fromisoformat(flip_time)) < timedelta(0)
    sample['should_flip'] = should_flip
    

    sample['solar_disk'] = create_circular_mask( sample['image'].shape[1], sample['image'].shape[0] ,center,radius)
    sample['mask'] = io.imread(mask_path)#.astype(float)
    
    sample['members'] = np.array(grp_info['members'])
    sample['members_mean_px'] = np.array(grp_info['members_mean_px'])
    
    sample['confidence_map'] = np.load(conf_path)#.astype(float)


    sample['solar_angle'] = grp_info['angle']
    sample['deltashapeX'] = grp_info['deltashapeX']
    sample['deltashapeY'] = grp_info['deltashapeY']
    
  
    
    
    sample['angular_excentricity'] = np.array([grp_info["angular_excentricity_deg"]])
    sample['centroid_px'] = np.array(grp_info["centroid_px"])
    sample['excentricity_map'] = sample['solar_disk'].copy()

    sample['centroid_Lat'] = np.array([grp_info["centroid_Lat"]])
    print(should_flip)
#     should_flip=True
    if should_flip:
        sample['image'] = np.flip(sample['image'],axis=0)
        sample['solar_disk'] = np.flip(sample['solar_disk'],axis=0)
        sample['mask'] = np.flip(sample['mask'],axis=0)
        sample['confidence_map'] = np.flip(sample['confidence_map'],axis=0)
        sample['excentricity_map'] = np.flip(sample['excentricity_map'],axis=0)
    
    

    return sample

def grp_refresh(value):
    grp_name = list(group_based_dataset_superclasses.keys())[grp_slider.value]
    grp = group_based_dataset_superclasses[grp_name]
    grp_superclass = grp['SuperClass']

    # print(grp)

    img_bn = grp_name.split('_')[0]
#     img_path = f'{root_dir}/2002-2019/image/{img_bn}.FTS'
#     mask_path = f'{root_dir}/2002-2019/T400-T350-Alternating_pen_um/{img_bn}.png'
#     img_path = f'/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/image/{img_bn}.FTS'
    img_path = f'/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2/image/{img_bn}.FTS'
    mask_path = f'{root_dir}/T425-T375-T325_fgbg/{img_bn}.png'
    conf_path = f'{root_dir}/T425-T375-T325_fgbg/{img_bn}_proba_map.npy'
    # img_path = f'/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/image/{img_bn}.FTS'
    # mask_path = f'/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/GroundTruth/{img_bn}.png'

    sample = create_sample(grp, img_path, mask_path, conf_path)
    
    sampleV1 = sample.copy()
    for t in tr_orig:
        sampleV1 = t(**sampleV1)
    sampleV2 = sample.copy()
    for t in tr_focus_noRand:
        sampleV2 = t(**sampleV2)
    sampleV3 = sample.copy()
    for t in tr_nofocus_rand:
        sampleV3 = t(**sampleV3)
    sampleV4 = sample.copy()
    for t in tr_focus_rand:
        sampleV4 = t(**sampleV4)

    # print(sample)
    ax2[0].set_title(f"{grp_name}:\n {grp_superclass['1']}{grp_superclass['2']}{grp_superclass['3']}")
    ax2[0].imshow(sampleV1['image'], origin='lower', cmap='gray', vmin=0)
    ax2[1].set_title(f"orig")
    ax2[1].imshow(sampleV1['image'], origin='lower', cmap='gray', vmin=0)
    ax2[2].set_title(f"focus_noRand")
    ax2[2].imshow(sampleV2['image'], origin='lower', cmap='gray', vmin=0)
    ax2[3].set_title(f"nofocus_rand")
    ax2[3].imshow(sampleV3['image'], origin='lower', cmap='gray', vmin=0)
    ax2[4].set_title(f"focus_rand")
    ax2[4].imshow(sampleV4['image'], origin='lower', cmap='gray', vmin=0)
    
    if msk_cb.value:
        # pass
        
        
        ax2[1].imshow(sampleV1['confidence_map'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
        ax2[2].imshow(sampleV2['confidence_map'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
        ax2[3].imshow(sampleV3['confidence_map'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
        ax2[4].imshow(sampleV4['confidence_map'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')

#         ax2[1].imshow(sampleV1['mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
#         ax2[2].imshow(sampleV2['mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
#         ax2[3].imshow(sampleV3['mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
#         ax2[4].imshow(sampleV4['mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
    if grp_msk_cb.value:
        # pass

        ax2[1].imshow(sampleV1['group_mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
        ax2[2].imshow(sampleV2['group_mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
        ax2[3].imshow(sampleV3['group_mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
        ax2[4].imshow(sampleV4['group_mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
#         ax2[1].imshow(sampleV1['group_mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
#         ax2[2].imshow(sampleV2['group_mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
#         ax2[3].imshow(sampleV3['group_mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')
#         ax2[4].imshow(sampleV4['group_mask'], origin='lower', alpha=0.5, cmap='jet', interpolation='none')


# 1773
grp_slider = widgets.IntSlider(min=0, max=len(group_based_dataset_superclasses)-1, step=1, value=1667, description='Group')
msk_cb = widgets.Checkbox(value=True, description='Mask', disabled=False, indent=False)
grp_msk_cb = widgets.Checkbox(value=True, description='GroupMask', disabled=False, indent=False)

grp_slider.observe(grp_refresh, 'value')
msk_cb.observe(grp_refresh, 'value')
grp_msk_cb.observe(grp_refresh, 'value')

plt.ioff()
fig2,ax2 = plt.subplots(1,5,figsize=(8,4))
grp_refresh(None)
plt.ion()

display(widgets.VBox([widgets.HBox([grp_slider, msk_cb, grp_msk_cb]), fig2.canvas]))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
2010-07-01T08:00:00
False


VBox(children=(HBox(children=(IntSlider(value=1667, description='Group', max=3887), Checkbox(value=True, descr…

2011-10-02T07:19:00
False
2011-10-02T07:19:00
False
2011-10-02T07:19:00
False
2011-10-03T07:50:00
False
2011-10-03T07:50:00
False
2011-10-03T07:50:00
False
2011-10-03T07:50:00
False
2011-10-07T08:33:00
False
2011-10-07T08:33:00
False
2011-10-07T08:33:00
False
2011-10-13T08:49:00
False
2011-10-13T08:49:00
False
2011-10-13T08:49:00
False
2011-10-13T08:49:00
False
2011-10-13T08:49:00
False
2011-10-13T08:49:00
False
2011-10-13T08:49:00
False
2011-10-13T08:49:00
False
2011-10-13T08:49:00
False
2011-10-14T08:29:00
False
2011-10-14T08:29:00
False
2011-10-14T08:29:00
False
2011-10-14T08:29:00
False
2011-10-14T08:29:00
False
2011-10-14T08:29:00
False
2011-10-14T08:29:00
False
2011-10-15T08:43:00
False
2011-10-15T08:43:00
False
2011-10-15T08:43:00
False
2011-10-15T08:43:00
False
2011-10-15T08:43:00
False
2011-10-15T08:43:00
False
2011-10-15T08:43:00
False
2011-10-16T09:47:00
False
2011-10-16T09:47:00
False
2011-10-16T09:47:00
False
2011-10-16T09:47:00
False
2011-10-16T09:47:00
False
2011-10-16T0

In [37]:
#erreur pr index 1773

In [40]:
# for grp_slder in range(0, len(group_based_dataset_superclasses.keys())-1):
def grp_refresh(value):
    grp_name = list(group_based_dataset_superclasses.keys())[grp_slider.value]
    grp = group_based_dataset_superclasses[grp_name]
    grp_superclass = grp['SuperClass']

    # print(grp)

    img_bn = grp_name.split('_')[0]
#     img_path = f'{root_dir}/2002-2019/image/{img_bn}.FTS'
#     mask_path = f'{root_dir}/2002-2019/T400-T350-Alternating_pen_um/{img_bn}.png'
    img_path = f'/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019/image/{img_bn}.FTS'
    mask_path = f'{root_dir}/T425-T375-T325_fgbg/{img_bn}.png'
    # img_path = f'/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/image/{img_bn}.FTS'
    # mask_path = f'/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/GroundTruth/{img_bn}.png'

    sample = create_sample(grp, img_path, mask_path)
    
    sampleV1 = sample.copy()
    for t in tr_orig:
        sampleV1 = t(**sampleV1)

    # get the sunspots that are in sampleV1['mask'] but not in sampleV1['group_mask'] 
    diff_mask = sampleV1['mask'] - sampleV1['group_mask']
    if np.sum(diff_mask) > 0:
        print(f"Group {grp_name} has sunspots that are not in the group mask")

    ax2[0].set_title(f"{grp_name}:\n {grp_superclass['1']}{grp_superclass['2']}{grp_superclass['3']}")
    ax2[0].imshow(sampleV1['image'], origin='lower', cmap='gray', vmin=0)
    ax2[1].set_title(f"Whole Mask")
    ax2[1].imshow(sampleV1['mask'], origin='lower', cmap='gray', vmin=0)
    ax2[2].set_title(f"Group Mask")
    ax2[2].imshow(sampleV1['group_mask'], origin='lower', cmap='gray', vmin=0)
    ax2[3].set_title(f"Unwanted Sunspots")
    ax2[3].imshow(diff_mask, origin='lower', cmap='gray', vmin=0)

grp_slider = widgets.IntSlider(min=0, max=len(group_based_dataset_superclasses)-1, step=1, value=1773, description='Group')
msk_cb = widgets.Checkbox(value=True, description='Mask', disabled=False, indent=False)
grp_msk_cb = widgets.Checkbox(value=True, description='GroupMask', disabled=False, indent=False)

grp_slider.observe(grp_refresh, 'value')
msk_cb.observe(grp_refresh, 'value')
grp_msk_cb.observe(grp_refresh, 'value')

plt.ioff()
fig2,ax2 = plt.subplots(1,4,figsize=(8,4))
grp_refresh(None)
plt.ion()

display(widgets.VBox([widgets.HBox([grp_slider, msk_cb, grp_msk_cb]), fig2.canvas]))

    

Group UPH20111019081228_1 has sunspots that are not in the group mask


VBox(children=(HBox(children=(IntSlider(value=1773, description='Group', max=3337), Checkbox(value=True, descr…

In [34]:
import concurrent.futures

# for idx in tqdm(range(0, len(group_based_dataset_superclasses.keys())-1)):
def is_isolated(idx):
    grp_name = list(group_based_dataset_superclasses.keys())[idx]
    grp = group_based_dataset_superclasses[grp_name]
    grp_superclass = grp['SuperClass']

    # print(grp)

    img_bn = grp_name.split('_')[0]
    img_path = f'{root_dir}/2002-2019/image/{img_bn}.FTS'
    mask_path = f'{root_dir}/2002-2019/T400-T350-Alternating_pen_um/{img_bn}.png'
    # img_path = f'/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/image/{img_bn}.FTS'
    # mask_path = f'/globalscratch/users/n/s/nsayez/deepsun_bioblue/ManualAnnotation/GroundTruth/{img_bn}.png'

    sample = create_sample(grp, img_path, mask_path)
    
    sampleV1 = sample.copy()
    for t in tr_orig:
        sampleV1 = t(**sampleV1)

    # get the sunspots that are in sampleV1['mask'] but not in sampleV1['group_mask'] 
    diff_mask = sampleV1['mask'] - sampleV1['group_mask']
    if np.sum(diff_mask>0) > 0:
        return True
        # print(f"Group {grp_name} has sunspots that are not in the group mask")
    return False

num_NotIsolated_groups = 0
NotIsolated_groups = []

num_cpu = 16
with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:

    for i in tqdm(executor.map(is_isolated, range(0, len(group_based_dataset_superclasses.keys())-1))):
        if i:
            num_NotIsolated_groups += 1
            NotIsolated_groups.append(i)


proportion = num_NotIsolated_groups / len(group_based_dataset_superclasses.keys())
print(f"Proportion of cases with sunspots that are not in the group mask: {proportion}")

0it [00:00, ?it/s]

Proportion of cases with sunspots that are not in the group mask: 0.15133601324190116


#### To get this proportion the following must be taken into account:

#### We search for the presence of mask elements in the crop that do not belong to the mask of group elements 

#### This means that the counter increases in case of:

    1. There is an actual other group in the crop
    2. There is a segmentation FP in the crop

#### On the other hand, The counter does not increases in case of:

    1. Segmentation FN external to the group.

# Vérifier que pour les images de classes B, il n'y a pas de pénombre.

Parcourir les groupes de classe B, 

regarder les masques de groupe
`
verifier que lle masque de groupe ne contient pas de pénombre.

# Degrade masks

# Poubelle

In [31]:
import matplotlib

colors = ['tab:blue','tab:orange','tab:green','tab:red',
          'tab:purple','tab:brown','tab:pink',
          'tab:olive','tab:cyan']

def new_refresh(value):
    basename = os.path.basename(wl_list[img_slider.value]).split(".")[0]
    m, h = utils.open_and_add_celestial(wl_list[img_slider.value])
    mask = io.imread(os.path.join(masks_dir,basename+".png"))
    mask2 = mask.copy()
    mask2[mask2>0] = 1

    sunspots_sk, sunspots_areas = get_sunspots3(h,m, mask2, sky_coords=True)
    sunspots_pixel, _ = get_sunspots3(h, m, mask2, sky_coords=False)

    sk_Lon = sunspots_sk.lon.rad
    sk_Lat = sunspots_sk.lat.rad
    sk_LatLon = np.stack((sk_Lat,sk_Lon),axis=1)

    nan_indexes = np.unique(np.argwhere(np.isnan(sk_LatLon))[:,0])
    clean = (~np.isnan(sk_Lon) & ~np.isnan(sk_Lat))
    if len(nan_indexes) > 0:
        sunspots_sk = sunspots_sk[clean]
        sunspots_areas = (np.array(sunspots_areas)[clean]).tolist()
        sk_LatLon = sk_LatLon [clean]

    global ms_model
    ms_model = Mean_Shift(look_distance, kernel_bandwidthLon, kernel_bandwidthLat, sunspots_sk.radius.km[0], n_iterations)
    ms_model.fit(sk_LatLon, sunspots_areas)
    ms_centroids = ms_model.centroids

    sk_sequ_meanshift = SkyCoord(ms_centroids[:,1]*u.rad, ms_centroids[:,0]*u.rad , frame=frames.HeliographicCarrington,
                        obstime=m.date, observer="earth")
    wcs2 = WCS(h)
    pix_centers_meanshift = np.array(skycoord_to_pixel(sk_sequ_meanshift, wcs2, origin=0)).T.tolist()

    ms_classifications = ms_model.predict(sk_LatLon)
    # print(ms_classifications.shape, np.unique(ms_classifications))

    ms_group_sunspots = [(sk_LatLon[ms_classifications == c].tolist()) for c in np.unique(ms_classifications)] 
    ms_group_sunspots_px = [sunspots_pixel[ms_classifications == c].tolist() for c in np.unique(ms_classifications)]
    # print(len(ms_group_sunspots))

    #latitude and longitudes in radians
    extreme_values = -np.pi, np.pi, 0, 2*np.pi

    # print(sk_LatLon.shape)
    ax3[0].clear(), ax3[1].clear(), ax3[2].clear()
    ax3[0].set_title(basename)
    ax3[0].imshow(m.data,cmap='gray')
    ax3[0].imshow(mask,cmap='jet',alpha=0.5)
    ax3[0].invert_yaxis()

    ax3[1].imshow(m.data,cmap='gray')
    ax3[1].invert_yaxis()
    for i in range(len(ms_group_sunspots)):
        c = colors[np.unique(ms_classifications)[i]%len(colors)]
        # c = colors[ms_classifications[i]%len(colors)]
        cur = np.array(ms_group_sunspots_px[i])
        ms_centers = pix_centers_meanshift[i]
        ax3[1].scatter(ms_centers[0],ms_centers[1], s=10, c=c, marker='x')
        ax3[1].scatter(cur[:,1], cur[:,0], color=c, s=2)

    ax3[2].scatter(ms_centroids[:,1], ms_centroids[:,0], s=1)
    ax3[2].set_ylim(extreme_values[0], extreme_values[1])
    ax3[2].set_xlim(extreme_values[2], extreme_values[3])

    # ax5[1].scatter(sk_LatLon[:,1], sk_LatLon[:,0], s=1)
    # ax5[1].set_xlim(0, 2*np.pi)
    # ax5[1].set_ylim(-np.pi/2, np.pi/2)
    
    
    hist_refresh(None)

def hist_refresh(change):
    # if (xlims0, ylims0) != (0, 1):
    xlims0 = ax5[0].get_xlim()
    ylims0 = ax5[0].get_ylim()
    # print(xlims0, ylims0)

    global ms_model
    # print(ms_model.history)
    step = hist_slider.value
    ax5[0].clear()
    ax5[0].set_title('History step {}'.format(step))
    for i in range(len(ms_model.history[step])):
        cur_width = ms_model.get_area_weighted_ellipsis_width( ms_model.areas[i], ms_model.areas)
        cur_height = kernel_bandwidthLat
        # if (ms_model.data[i][0] <= 0.28 and ms_model.data[i][0] >= 0.25) and (ms_model.data[i][1] >= 1.45 and ms_model.data[i][1] <= 1.48):
        #     print(cur_width, cur_height)
        ellipsis = matplotlib.patches.Ellipse((ms_model.history[step][i,1], ms_model.history[step][i,0]), 2*cur_width, 2*cur_height, fill=False, color='red')
        # ellipsis = matplotlib.patches.Ellipse((ms_model.history[step][i,1], ms_model.history[step][i,0]), 2*kernel_bandwidthLon, 2*kernel_bandwidthLat, fill=False, color='red')
        ax5[0].add_patch(ellipsis)
    ax5[0].scatter(ms_model.history[step][:,1], ms_model.history[step][:,0], s=1)
    ax5[0].set_ylim(-np.pi/2, np.pi/2)
    ax5[0].set_xlim(0, 2*np.pi)
    # ax5.set_xlim(np.min(ms_model.data[:,1]), np.max(ms_model.data[:,1]))
    if (xlims0, ylims0) != ((0., 1.),(0., 1.)):
        ax5[0].set_xlim(xlims0)
        ax5[0].set_ylim(ylims0)
    

look_distance = .1 # How far to look for neighbours.
kernel_bandwidthLon = .35  # Longitude Kernel parameter.
kernel_bandwidthLat = .08  # Latitude Kernel parameter.
n_iterations = 20 # Number of iterations
    
ms_model = None

img_slider = widgets.IntSlider(min=0, max=len(wl_list)-1, step=1, value=3, description='Group')
msk_cb = widgets.Checkbox(value=False, description='Mask', disabled=False, indent=False)

img_slider.observe(new_refresh, 'value')
msk_cb.observe(new_refresh, 'value')

hist_slider = widgets.IntSlider(min=0, max=n_iterations-1, step=1, value=0, description='History')
hist_slider.observe(hist_refresh, 'value')

plt.ioff()
fig3,ax3 = plt.subplots(1,3,figsize=(12,3))
fig5,ax5 = plt.subplots(1,2,figsize=(10,3))
new_refresh(None)
hist_refresh(None)
plt.ion()


display(widgets.VBox([widgets.HBox([img_slider, msk_cb]), fig3.canvas]))
display(widgets.VBox([widgets.HBox([hist_slider]), fig5.canvas]))

NameError: name 'wl_list' is not defined