In [None]:
from email.mime import base
from tokenize import group
from sunscc.dataset.transform.pipelines import Compose
import collections
from functools import partial

from torch.utils.data import Dataset, DataLoader
import os
import random
from hydra.utils import call, instantiate
from pathlib import Path
import skimage.io as io
import skimage
from skimage.measure import label, regionprops
import skimage.exposure
from numpy.random import default_rng

import json

import numpy as np

from tqdm.notebook import tqdm

import traceback
import time

from albumentations.core.transforms_interface import DualTransform

from astropy.io import fits
from sunscc.dataset.utils import *

from omegaconf import DictConfig, OmegaConf
import matplotlib.pyplot as plt


from datetime import datetime, timedelta

from copy import deepcopy

import pickle

import concurrent.futures
from itertools import repeat
import multiprocessing

from numpy.random import default_rng

%matplotlib widget
# %matplotlib inline

In [None]:
def print_elapsed_time(st, msg):
    """Print how many seconds have passed since ``st``.

    Args:
        st: Start timestamp, as returned by ``time.time()``.
        msg: Label inserted in the printed line to identify what was timed.
    """
    elapsed = time.time() - st
    print(f'Elapsed time {msg}: {elapsed}')

class ClassificationDatasetSuperclasses(Dataset):
    """Group-level classification dataset that assembles each sample from disk.

    The partition description is read from a JSON file whose keys are group
    names of the form ``<image basename>_<suffix>``. For every group the
    dataset loads the FITS image, the segmentation mask, its confidence map
    and the solar-disk mask that share the image basename, and attaches the
    group metadata stored in the JSON entry.

    Args:
        root_dir: Directory containing the data sub-folders and the JSON file.
        partition: Top-level key of the JSON file to use (e.g. ``'val'``).
        dtypes: ``dtypes[0]`` is the sub-folder holding the ``.FTS`` images,
            ``dtypes[1]`` the sub-folder holding ``.png`` masks and
            ``_proba_map.npy`` confidence maps.
        classes: Groups are kept only if their ``[classification]["1"]``
            label is in this collection.
        first_classes, second_classes, third_classes: Ordered label lists used
            to build the label -> index mappers stored on the instance.
        json_file: Path of the JSON file, relative to ``root_dir``.
        classification: Key of the labelling scheme inside each JSON entry;
            one of ``'Zurich'``, ``'McIntosh'`` or ``'SuperClass'``.
        transforms: ``None``, a mapping (wrapped with ``hydra.utils.call``) or
            a sequence of configs (each instantiated and chained in a
            ``Compose``).
    """

    # 2-D arrays of a sample that are plotted / upsampled / flipped together.
    _ARRAY_KEYS = ('image', 'mask', 'confidence_map', 'solar_disk', 'excentricity_map')

    def __init__(
        self, root_dir, partition, dtypes, classes,
        first_classes, second_classes, third_classes,
        json_file, classification='SuperClass' , transforms=None) -> None:
        # ``collections.Mapping`` / ``collections.Sequence`` were removed in
        # Python 3.10; the ABCs live in ``collections.abc``.
        from collections.abc import Mapping, Sequence

        super().__init__()
        if isinstance(transforms, Mapping):
            transforms = partial(call, config=transforms)
        elif isinstance(transforms, Sequence):
            transforms = Compose([instantiate(transform) for transform in transforms])

        self.transforms = transforms
        self.root_dir = Path(root_dir)

        self.main_dtype = dtypes[0]
        self.mask_dtype = dtypes[1]

        self.json_file = self.root_dir / json_file

        self.c1_mapper = {c: i for i, c in enumerate(first_classes)}
        self.c2_mapper = {c: i for i, c in enumerate(second_classes)}
        self.c3_mapper = {c: i for i, c in enumerate(third_classes)}

        with open(self.json_file, 'r') as f:
            self.partition_dict = json.load(f)[partition]

        assert classification in ('Zurich', 'McIntosh', 'SuperClass'), \
            f"unsupported classification scheme: {classification!r}"
        self.classification = classification

        # Same mappings as c1/c2/c3_mapper, kept under both names for callers.
        self.FirstClass_mapper = {c: i for i, c in enumerate(first_classes)}
        self.SecondClass_mapper = {c: i for i, c in enumerate(second_classes)}
        self.ThirdClass_mapper = {c: i for i, c in enumerate(third_classes)}

        # One entry per image basename (several groups can share one image).
        self.files = {}
        for group_name in sorted(self.partition_dict):
            bn = group_name.split('_')[0]
            self.files[bn] = {
                "name": bn,
                self.main_dtype: self.root_dir / self.main_dtype / (bn + '.FTS'),
                self.mask_dtype: self.root_dir / self.mask_dtype / (bn + '.png'),
                self.mask_dtype + "_conf_map": self.root_dir / self.mask_dtype / (bn + '_proba_map.npy'),
                "sun_mask": self.root_dir / 'sun_mask' / (bn + '.png'),
            }

        # Keep only the groups whose first-level label is requested.
        self.groups = {
            k: v for k, v in self.partition_dict.items()
            if v[self.classification]["1"] in classes
        }

        # Sorted once here instead of on every __getitem__ call.
        self.group_keys = sorted(self.groups)
        self.dataset_length = len(self.groups)

    def __len__(self) -> int:
        """Return the number of groups kept after class filtering."""
        return self.dataset_length

    def _plot_arrays(self, axes_row, sample):
        """Draw the five sample arrays on one row of axes and mark the centroid."""
        for axis, key in zip(axes_row, self._ARRAY_KEYS):
            axis.imshow(sample[key], cmap='gray', interpolation='none')
        axes_row[0].scatter(sample['centroid_px'][0], sample['centroid_px'][1], c='r', s=10)

    def __getitem__(self, index: int, do_transform=True):
        """Load and assemble the sample of the ``index``-th group (sorted by name).

        Args:
            index: Position in the alphabetically sorted list of group names.
            do_transform: When false, ``self.transforms`` is not applied.

        Returns:
            The sample dictionary, or whatever ``self.transforms`` returns
            when transforms are applied.
        """
        k = self.group_keys[index]
        basename = k.split('_')[0]
        group_dict = self.groups[k]
        file_entry = self.files[basename]

        # Read everything needed from the FITS file inside the context manager
        # so the file handle is released as soon as the data is copied.
        with fits.open(file_entry[self.main_dtype]) as hdulst:
            primary = hdulst[0]
            radius = primary.header['SOLAR_R']
            image_data = primary.data.astype(float)

        sample = {}
        sample['solar_disk'] = io.imread(file_entry["sun_mask"])
        sample['excentricity_map'] = create_excentricity_map(sample['solar_disk'], radius, value_outside=-1)
        sample['mask'] = io.imread(file_entry[self.mask_dtype])
        sample['confidence_map'] = np.load(file_entry[self.mask_dtype + "_conf_map"])
        sample['image'] = image_data

        # Optional metadata: fall back to a one-element zero array when absent.
        sample['members'] = np.array(group_dict['members']) if 'members' in group_dict else np.array([0])
        sample['members_mean_px'] = np.array(group_dict['members_mean_px']) if 'members_mean_px' in group_dict else np.array([0])

        sample['name'] = basename
        sample['group_name'] = k

        sample['solar_angle'] = group_dict['angle']
        sample['deltashapeX'] = group_dict['deltashapeX']
        sample['deltashapeY'] = group_dict['deltashapeY']

        sample['angular_excentricity'] = np.array([group_dict["angular_excentricity_deg"]])
        sample['centroid_px'] = np.array(group_dict["centroid_px"])
        sample['centroid_Lat'] = np.array([group_dict["centroid_Lat"]])

        # Labels are returned as stored in the JSON (not mapped to indices).
        labels = group_dict[self.classification]
        sample['class1'] = labels['1']
        sample['class2'] = labels['2']
        sample['class3'] = labels['3']

        if sample["image"].shape == (1024,1024):
            # Half-resolution image: show it before/after and bring every
            # array and pixel-space value to the 2048x2048 scale.
            # NOTE(review): this debug figure is created on every access and
            # is never closed here — confirm it is still wanted.
            fig, ax = plt.subplots(2, 5, figsize=(10, 4) )
            self._plot_arrays(ax[0], sample)

            # Nearest-neighbour 2x upsampling along both axes.
            for key in self._ARRAY_KEYS:
                sample[key] = np.repeat(np.repeat(sample[key], 2, axis=0), 2, axis=1)

            sample['deltashapeX'] = sample['deltashapeX']*2
            sample['deltashapeY'] = sample['deltashapeY']*2
            sample['centroid_px'] = sample['centroid_px']*2

            self._plot_arrays(ax[1], sample)
            fig.tight_layout()

        # Images dated before this instant are flipped vertically.
        flip_time = "2003-03-08T00:00:00"
        date = whitelight_to_datetime(basename)
        datetime_str = datetime_to_db_string(date).replace(' ', 'T')
        should_flip = datetime.fromisoformat(datetime_str) < datetime.fromisoformat(flip_time)
        sample['should_flip'] = should_flip

        if should_flip:
            for key in self._ARRAY_KEYS:
                sample[key] = np.flip(sample[key], axis=0)

        if self.transforms is not None and do_transform:
            sample = self.transforms(**sample)

        return sample
        

In [None]:
# Where the dataset lives and which partition is pre-computed by this notebook.
root_dir = "../../datasets/classification/"
sub_dir = '2002-2019'
partition = 'val'
# dtypes[0]: image sub-folder, dtypes[1]: mask / confidence-map sub-folder.
dtypes = ['image', 'T425-T375-T325_fgbg']

# Label vocabularies for the three classification levels.
classes = ['A', 'B', 'C', 'SuperGroup', 'H']
first_classes = ['A', 'B', 'C', 'SuperGroup', 'H']
second_classes = ['x', 'r', 'sym', 'asym']
third_classes = ["x", "o", "frag"]
json_file = f'{sub_dir}/dataset_revised.json'

classification = 'SuperClass'
transforms = OmegaConf.load(
    '../../sunscc/conf/exp/Classification_Superclasses4.yaml'
).dataset.train_dataset.transforms

# Override the output size of the 4th transform of the training pipeline.
# NOTE(review): presumably this is the crop/focus transform — confirm against the YAML.
transforms[3].standard_height = 350
transforms[3].standard_width = 350

print(transforms)

dataset = ClassificationDatasetSuperclasses(
    root_dir=root_dir,
    partition=partition,
    dtypes=dtypes,
    classes=classes,
    first_classes=first_classes,
    second_classes=second_classes,
    third_classes=third_classes,
    json_file=json_file,
    classification=classification,
    transforms=transforms,
)

In [None]:
# Accumulator filled by the collection cell below: group name -> sample dict.
partition_samples = dict()

In [None]:
def get_sample(idx, dataset, center_size=256):
    """Fetch one sample and discard it if its central window is empty.

    The sample is loaded exactly once: the object that is inspected is the
    object that is returned. (Indexing the dataset a second time would double
    the loading cost and, with random transforms, return a different sample
    from the one that was checked.)

    Args:
        idx: Index passed to ``dataset[idx]``.
        dataset: Indexable object returning a dict with at least the 2-D
            arrays ``'image'`` and ``'confidence_map'``.
        center_size: Side length of the square window, centred on the image,
            in which the confidence map must contain a non-zero sum.

    Returns:
        The sample dict, or ``None`` when the confidence map sums to zero
        inside the central window.
    """
    sample = dataset[idx]

    half = center_size // 2
    row_c = sample['image'].shape[0] // 2
    col_c = sample['image'].shape[1] // 2
    # Clamp the start indices: a negative start would wrap around to the end
    # of the array when the image is smaller than the window.
    row_start, col_start = max(row_c - half, 0), max(col_c - half, 0)
    center_region = sample['confidence_map'][row_start:row_c + half, col_start:col_c + half]

    if np.sum(center_region) == 0:
        print('skipping sample', idx)
        return None
    return sample

num_cpu = 15

# Load every sample in worker processes and keep the first sample seen for
# each group name. ``executor.map`` yields results in index order.
with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
    results = executor.map(get_sample, range(len(dataset)), repeat(deepcopy(dataset)))
    # ``executor.map`` returns a generator with no length, so give tqdm the
    # total explicitly to get a real progress bar and ETA.
    for sample in tqdm(results, total=len(dataset)):
        if sample is not None:
            # Samples rejected by get_sample come back as None and are dropped.
            partition_samples.setdefault(sample['group_name'], sample)


In [None]:
# Persist the collected samples as ``all_samples_<partition>.npy`` next to the
# dataset JSON. The dict is stored as a pickled object array.
npy_file = os.path.join(root_dir, sub_dir, f'all_samples_{partition}.npy')
np.save(file=npy_file, arr=partition_samples)

# Test

In [None]:
def print_elapsed_time(st, msg):
    """Report the time elapsed since ``st`` on standard output.

    Args:
        st: Reference timestamp obtained from ``time.time()``.
        msg: Short description of the timed operation, shown in the output.
    """
    now = time.time()
    duration = now - st
    print(f'Elapsed time {msg}: {duration}')
        
class Deepsun_Focus_Move(DualTransform):
    """Crop all sample arrays to a fixed size around the sunspot group.

    The group position is taken from the non-zero pixels of
    ``group_confidence_map``. Two modes are available:

    * ``focus_on_group=True``: crop the (padded) bounding box of the group,
      enlarged to at least 40 % of the target size, then zero-pad the crop
      to the target size.
    * ``focus_on_group=False``: crop a window of the target size centred on
      the bounding box.

    With ``random_move=True`` the crop window is shifted by a random fraction
    of its size (data augmentation).

    Args:
        standard_height: Output height (axis 0) in pixels.
        standard_width: Output width (axis 1) in pixels.
        focus_on_group: Select the cropping mode described above.
        random_move: Enable the random shift of the crop window.
        random_move_percent: Maximum shift, as a fraction of the window size.
        always_apply, p: Forwarded to ``DualTransform``.
    """

    # Keys of the 2-D arrays that are cropped and padded together.
    _ARRAY_KEYS = (
        'image', 'mask', 'group_mask', 'solar_disk',
        'excentricity_map', 'confidence_map', 'group_confidence_map',
    )

    def __init__(self, standard_height=256, standard_width=256,
                        focus_on_group=True,
                        random_move=False, random_move_percent=0.1,
                        always_apply=False, p=1.0) -> None:

        # Keyword arguments so the call does not depend on the positional
        # order of the base-class signature.
        super().__init__(always_apply=always_apply, p=p)

        self.standard_height = standard_height
        self.standard_width = standard_width

        self.focus_on_group = focus_on_group
        self.random_move = random_move
        self.random_move_percent = random_move_percent

    def get_bounding_box_around_group_with_padding(self, mask, offset):
        """Return ``(x1, x2, y1, y2)`` around the non-zero pixels of ``mask``.

        ``x`` indexes axis 0 and ``y`` axis 1. The box is grown by ``offset``
        pixels on every side and clipped to the array. When the mask is empty
        the box covers the whole array.
        """
        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            return 0, mask.shape[0]-1, 0, mask.shape[1]-1

        x1 = max(np.min(rows) - offset, 0)
        x2 = min(np.max(rows) + offset, mask.shape[0])
        y1 = max(np.min(cols) - offset, 0)
        y2 = min(np.max(cols) + offset, mask.shape[1])

        return x1, x2, y1, y2

    def adapt_bbox_to_image_size(self, bbox, image_size):
        """Enlarge ``bbox`` around its centre to at least 40 % of ``image_size``.

        The returned coordinates are plain ints and may fall outside the
        array (including negative values); ``crop_img`` handles that.
        """
        bbox_center = ((bbox[0] + bbox[1]) // 2, (bbox[2] + bbox[3]) // 2)
        bbox_size = (bbox[1] - bbox[0], bbox[3] - bbox[2])

        # if bbox is too small, expand it
        minimal_percentage = .4

        bbox_size = (max(bbox_size[0], image_size[0] * minimal_percentage),
                     max(bbox_size[1], image_size[1] * minimal_percentage))

        return (int(bbox_center[0] - bbox_size[0] // 2), int(bbox_center[0] + bbox_size[0] // 2),
                int(bbox_center[1] - bbox_size[1] // 2), int(bbox_center[1] + bbox_size[1] // 2))

    def crop_img(self, img, bbox):
        """Return ``img[x1:x2, y1:y2]`` with the box clipped at index 0.

        Negative indices are clamped to 0 because numpy would otherwise
        interpret them as offsets from the end of the axis, which typically
        yields an empty crop for a box overhanging the top/left border.
        Indices past the end need no treatment: slicing already stops there.
        """
        x1, x2, y1, y2 = (max(coord, 0) for coord in bbox)
        return img[x1:x2, y1:y2]

    def padding(self, array, xx, yy):
        """
        Zero-pad ``array`` symmetrically up to the requested size.

        An axis that is already at least as large as requested is left
        untouched (the array is never cropped here).

        :param array: 2-D numpy array
        :param xx: desired height
        :param yy: desired width
        :return: padded array
        """
        h = array.shape[0]
        w = array.shape[1]

        # Split the missing rows/columns between both sides; the extra pixel
        # of an odd difference goes to the bottom/right.
        a = (xx - h) // 2
        aa = xx - a - h

        b = (yy - w) // 2
        bb = yy - b - w

        a = max(a,0)
        b = max(b,0)
        aa = max(aa,0)
        bb = max(bb,0)

        return np.pad(array, pad_width=((a, aa), (b, bb)), mode='constant')

    def crop_and_pad(self, img, bbox, image_size):
        """Crop ``img`` to ``bbox`` and zero-pad the crop to ``image_size``."""
        return self.padding(self.crop_img(img, bbox), image_size[0], image_size[1])

    def data_aug_random_move(self, bbox, max_offset):
        '''
        Randomly move the bounding box
        param bbox: bounding box (x1, x2, y1, y2)
        param max_offset: maximum offset in portion of the bbox size
        return: the shifted (x1, x2, y1, y2); its size is unchanged
        '''
        x1, x2, y1, y2 = bbox
        # Scalar draws uniform in [-max_offset, max_offset). (Calling int()
        # on a 1-element array is deprecated in recent NumPy releases.)
        horizontal_offset = (np.random.random() * 2*max_offset) - max_offset
        vertical_offset = (np.random.random() * 2*max_offset) - max_offset

        shift_x = int(horizontal_offset * (x2 - x1))
        shift_y = int(vertical_offset * (y2 - y1))

        return x1 + shift_x, x2 + shift_x, y1 + shift_y, y2 + shift_y

    def __call__(self, *args, force_apply=False, **kwargs):
        """Crop every array listed in ``_ARRAY_KEYS`` and return the updated kwargs.

        All seven arrays must be present in ``kwargs``; other entries are
        passed through unchanged. Raises ``AssertionError`` if a resulting
        array does not have the target shape (e.g. a group larger than the
        target size in ``focus_on_group`` mode).
        """
        crops = {key: kwargs[key].copy() for key in self._ARRAY_KEYS}
        shape = crops['image'].shape
        target = (self.standard_height, self.standard_width)

        # bbox format = x1, x2, y1, y2
        bbox = self.get_bounding_box_around_group_with_padding(
            crops['group_confidence_map'] > 0, 10)

        if self.focus_on_group:
            if self.random_move:
                x1, x2, y1, y2 = self.data_aug_random_move(bbox, max_offset=self.random_move_percent)
                # Keep the moved box inside the image: x is bounded by the
                # image height, y by its width.
                bbox = (max(x1, 0), min(x2, shape[0]), max(y1, 0), min(y2, shape[1]))

            bbox = self.adapt_bbox_to_image_size(bbox, target)
            crops = {key: self.crop_and_pad(arr, bbox, target) for key, arr in crops.items()}
        else:
            # Window of the target size centred on the group bounding box.
            center_x = (bbox[1]+bbox[0])//2
            center_y = (bbox[3]+bbox[2])//2
            minX, maxX = center_x-(self.standard_height//2), center_x+(self.standard_height//2)
            minY, maxY = center_y-(self.standard_width//2), center_y+(self.standard_width//2)

            if self.random_move:
                # Scale the allowed shift with the (square root of the)
                # fraction of the window occupied by the group.
                frac = np.max([(bbox[1]-bbox[0]) /self.standard_height, (bbox[3]-bbox[2]) /self.standard_width])
                frac = np.sqrt(frac)
                moved = self.data_aug_random_move([minX,maxX,minY,maxY], max_offset=self.random_move_percent*frac)
                # Only accept the move when the shifted window stays inside the image.
                if not ((moved[1] > shape[0]) or (moved[3] > shape[1]) or (moved[0] < 0) or (moved[2] < 0)):
                    minX, maxX, minY, maxY = moved[0], moved[0]+self.standard_height, moved[2], moved[3]

            window = (minX, maxX, minY, maxY)
            crops = {key: self.crop_img(arr, window) for key, arr in crops.items()}

        # Windows cut short by the image border (or by an odd target size)
        # are completed with zeros.
        if crops['image'].shape != target:
            crops = {key: self.padding(arr, self.standard_height, self.standard_width)
                     for key, arr in crops.items()}

        for key, arr in crops.items():
            assert arr.shape == target, f"{key} has shape {arr.shape}, expected {target}"

        for key, arr in crops.items():
            kwargs[key] = arr.copy()

        return kwargs

 


class ClassificationDatasetSuperclasses_fast(Dataset):
    """Classification dataset served from a pre-computed ``.npy`` sample file.

    The file is expected to contain a pickled dict mapping group names to
    sample dicts (as written by ``np.save`` on a dict). Each sample must hold
    its labels under ``'class1'``, ``'class2'`` and ``'class3'``.

    Args:
        root_dir: Directory that contains ``dataset_file``.
        partition: Partition name; ``'_<partition>'`` is inserted before every
            ``'.'`` of ``dataset_file`` to obtain the file that is loaded.
        dtypes: ``dtypes[0]`` / ``dtypes[1]`` are stored as ``main_dtype`` /
            ``mask_dtype``.
        classes: Samples are kept only if their ``'class1'`` is in this
            collection.
        first_classes, second_classes, third_classes: Ordered label lists used
            to map the three labels to integer indices.
        dataset_file: File name pattern relative to ``root_dir``
            (e.g. ``'2002-2019/all_samples.npy'``).
        classification: Stored on the instance; not otherwise used here.
        transforms: ``None``, a mapping (wrapped with ``hydra.utils.call``) or
            a sequence of configs (each instantiated and chained in a
            ``Compose``).
    """

    def __init__(
        self, root_dir, partition, dtypes, classes,
        first_classes, second_classes, third_classes,
        dataset_file, classification='SuperClass' , transforms=None) -> None:
        # ``collections.Mapping`` / ``collections.Sequence`` were removed in
        # Python 3.10; the ABCs live in ``collections.abc``.
        from collections.abc import Mapping, Sequence

        super().__init__()
        if isinstance(transforms, Mapping):
            transforms = partial(call, config=transforms)
        elif isinstance(transforms, Sequence):
            transforms = Compose([instantiate(transform) for transform in transforms])

        self.transforms = transforms
        self.root_dir = Path(root_dir)

        self.main_dtype = dtypes[0]
        self.mask_dtype = dtypes[1]

        # NOTE(review): every '.' of dataset_file is replaced, so the pattern
        # must contain a single dot (the extension) — confirm for new paths.
        self.disk_file = self.root_dir / dataset_file.replace('.', '_'+partition+'.')

        self.c1_mapper = {c: i for i, c in enumerate(first_classes)}
        self.c2_mapper = {c: i for i, c in enumerate(second_classes)}
        self.c3_mapper = {c: i for i, c in enumerate(third_classes)}

        # SECURITY: allow_pickle=True executes arbitrary code embedded in the
        # file; only load files produced by a trusted run of this notebook.
        self.partition_dict = np.load(self.disk_file, allow_pickle=True).item()

        self.classification = classification

        # Same mappings as c1/c2/c3_mapper, kept under both names for callers.
        self.FirstClass_mapper = {c: i for i, c in enumerate(first_classes)}
        self.SecondClass_mapper = {c: i for i, c in enumerate(second_classes)}
        self.ThirdClass_mapper = {c: i for i, c in enumerate(third_classes)}

        # Keep only the samples whose first-level label is requested.
        self.groups = {
            k: v for k, v in self.partition_dict.items() if v["class1"] in classes
        }

        # Built once here instead of on every __getitem__ call; keeps the
        # insertion order of the loaded dict.
        self.group_keys = list(self.groups)
        self.dataset_length = len(self.groups)

    def __len__(self) -> int:
        """Return the number of samples kept after class filtering."""
        return self.dataset_length

    def _plot_arrays(self, axes_row, sample):
        """Draw the five main sample arrays on one row of axes."""
        for axis, key in zip(axes_row, ('image', 'mask', 'confidence_map', 'solar_disk', 'excentricity_map')):
            axis.imshow(sample[key], cmap='gray', interpolation='none')

    def __getitem__(self, index: int, do_transform=True):
        """Return a transformed copy of the ``index``-th stored sample.

        The stored sample is deep-copied, its three labels are replaced by
        one-element index arrays, everything outside the solar disk
        (``excentricity_map < 0``) is zeroed, and ``group_confidence_map`` is
        additionally zeroed where ``excentricity_map > 0.95``. A debug figure
        showing the arrays before and after processing is created on each call.

        Args:
            index: Position in the stored order of the kept groups.
            do_transform: When false, ``self.transforms`` is not applied.
        """
        sample = deepcopy(self.groups[self.group_keys[index]])

        sample['class1'] = np.array([self.FirstClass_mapper[sample['class1']]])
        sample['class2'] = np.array([self.SecondClass_mapper[sample['class2']]])
        sample['class3'] = np.array([self.ThirdClass_mapper[sample['class3']]])

        # Top row: arrays as stored, with the non-zero confidence pixels
        # highlighted in red.
        fig, ax = plt.subplots(2, 5, figsize=(10, 4) )
        ax[0,0].set_title(sample['group_name'])
        self._plot_arrays(ax[0], sample)

        confident_px = np.argwhere(sample['confidence_map'] > 0)
        ax[0,2].scatter(confident_px[:,1], confident_px[:,0], s=1, c='r', alpha=0.5)

        # Negative eccentricity marks pixels outside the solar disk.
        outside_disk = sample['excentricity_map'] < 0
        for key in ('image', 'mask', 'group_mask', 'solar_disk',
                    'confidence_map', 'group_confidence_map'):
            sample[key][outside_disk] = 0
        # Ignore group detections at the extreme limb.
        sample['group_confidence_map'][sample['excentricity_map'] > 0.95] = 0

        sample['excentricity_map'][outside_disk] = 0

        if self.transforms is not None and do_transform:
            sample = self.transforms(**sample)

        # Bottom row: arrays after masking and transforms.
        self._plot_arrays(ax[1], sample)
        fig.tight_layout()

        return sample


class DeepsunClassifHideOtherGroups(DualTransform):
    """Paint over the sunspots that do not belong to the group being classified.

    Pixels with a non-zero ``confidence_map`` but a zero
    ``group_confidence_map`` are dilated and replaced in ``image``. Only the
    ``'image'`` entry of the sample is modified.

    Args:
        always_apply, p: Forwarded to ``DualTransform``.
        method: ``'avg'`` fills the hidden pixels with the mean of the
            background (inside the solar disk, outside ``mask``);
            ``'local_avg'`` fills each hidden connected region with the mean
            of the pixels immediately surrounding it.
    """

    def __init__(self, always_apply=False, p=1.0, method='avg'):
        # The instance itself used to be passed as the first positional
        # argument of the base class; forward the real arguments by name.
        super().__init__(always_apply=always_apply, p=p)
        self.index = 0
        self.method = method

    def __call__(self, *args, force_apply=False, **kwargs):
        """Return ``kwargs`` with ``'image'`` replaced by the cleaned image.

        Requires the ``image``, ``mask``, ``solar_disk``, ``confidence_map``
        and ``group_confidence_map`` entries. The input arrays are not
        modified.

        Raises:
            ValueError: if ``self.method`` is not ``'avg'`` or ``'local_avg'``.
        """
        if self.method not in ('avg', 'local_avg'):
            # Fail with a clear message instead of an unbound-variable error
            # at the end of the method.
            raise ValueError(
                f"unknown method {self.method!r}, expected 'avg' or 'local_avg'")

        img = kwargs['image']
        msk = kwargs['mask']
        disk = kwargs['solar_disk']
        confidence = kwargs['confidence_map']
        grp_confidence = kwargs['group_confidence_map']

        # Pixels detected as sunspot that are not part of the current group.
        hide_mask = (confidence > 0) & (grp_confidence == 0)

        # Dilate so that the borders of the hidden spots are covered too.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9,9))
        hide_mask = cv2.morphologyEx(hide_mask.astype(np.uint8), cv2.MORPH_DILATE, kernel) > 0

        new_img = img.copy()

        if self.method == 'avg':
            # Background = inside the solar disk and outside every segmented spot.
            bg = (disk>0) & (msk == 0)
            new_img[hide_mask] = np.mean(img[bg])
        else:
            # One label per connected region of the hide mask; 0 is background.
            labels = skimage.measure.label(hide_mask, background=0)
            ring_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
            for i in range(1, labels.max()+1):
                shape_mask = labels == i

                # Ring of pixels around the region: dilation minus the region.
                dilated_shape_mask = cv2.morphologyEx(shape_mask.astype(np.uint8), cv2.MORPH_DILATE, ring_kernel)
                around_shape_mask = (dilated_shape_mask & (~shape_mask)).astype(np.uint8)

                # Fill the region with the mean of its surrounding ring.
                new_img[shape_mask] = np.mean(new_img[around_shape_mask != 0])

        kwargs['image'] = new_img

        return kwargs
    


from skimage.segmentation import clear_border
from matplotlib import patches
from shapely.geometry import Polygon


class DeepsunAddRandomPatch(DualTransform):
    """Blend a randomly shaped synthetic patch into the image, away from the labelled group.

    The transform is called with keyword arguments and reads three of them:
    ``image``, ``confidence_map`` and ``solar_disk``.  It places a box at a
    random on-disk pixel so that it does not overlap the bounding box of the
    dilated ``confidence_map``, draws random blobs inside that box, replaces
    the blob pixels of ``image`` with a background estimate and returns the
    kwargs with ``image`` updated.

    Args:
        always_apply: forwarded to ``DualTransform``.
        p: probability of applying the transform on a given call.
        method: ``'avg'`` fills the blobs with the mean of the on-disk pixels
            where ``confidence_map == 0``; ``'local_avg'`` fills each connected
            blob with the mean of the pixels a 3x3 dilation adds around it.
    """

    # Number of random placements tried before giving up on a sample.
    _MAX_PLACEMENT_TRIALS = 20

    def __init__(self, always_apply=False, p=1.0, method='avg'):
        # Forward by keyword: the previous call passed ``self`` positionally,
        # i.e. in the ``always_apply`` slot of the base constructor.
        super().__init__(always_apply=always_apply, p=p)
        self.index = 0
        self.method = method

    def gen_mask(self, shape):
        """Return a random blob mask of size ``shape[:2]`` as an (H, W, 3) uint8 array of 0/255."""
        height, width = tuple(shape)[:2]

        # Unseeded generator: a different pattern on every call.
        rng = default_rng()
        noise = rng.integers(0, 255, (height, width), np.uint8, True)

        # Blurring turns pixel noise into smooth blobs; sigma controls their size.
        blur = cv2.GaussianBlur(noise, (0, 0), sigmaX=5, sigmaY=5, borderType=cv2.BORDER_DEFAULT)

        # Stretch to the full range so the fixed threshold below is meaningful.
        stretch = skimage.exposure.rescale_intensity(
            blur, in_range='image', out_range=(0, 255)).astype(np.uint8)
        thresh = cv2.threshold(stretch, 175, 255, cv2.THRESH_BINARY)[1]

        # Morphological opening removes the smallest specks.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
        mask = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)

        return cv2.merge([mask, mask, mask])

    def remove_intersecting_regions(self, mask1, mask2):
        """Zero every connected component of ``mask1`` that overlaps a non-zero pixel of ``mask2``.

        ``mask1`` is modified in place (as before); a copy of the result is returned.
        """
        labeled_mask1 = skimage.measure.label(mask1)
        # Labels of mask1 that occur under mask2's foreground.
        touched = np.unique(labeled_mask1[np.asarray(mask2) != 0])
        mask1[np.isin(labeled_mask1, touched)] = 0
        return mask1.copy()

    def mask_to_bbox(self, mask):
        """Return the bounding box enclosing all connected components of ``mask``.

        The result is ``(row_min, col_min, row_max, col_max)``, built from the
        first four entries of each ``regionprops`` bbox.  For a mask without
        any component the box is inverted: ``(mask.shape[0], mask.shape[1], 0, 0)``.
        """
        regions = skimage.measure.regionprops(skimage.measure.label(mask))
        row_min = min([mask.shape[0]] + [r.bbox[0] for r in regions])
        col_min = min([mask.shape[1]] + [r.bbox[1] for r in regions])
        row_max = max([0] + [r.bbox[2] for r in regions])
        col_max = max([0] + [r.bbox[3] for r in regions])
        return row_min, col_min, row_max, col_max

    def generate_bbox(self, mask, disk_mask, Lon, Lat, radius):
        """Pick a box centred on a random on-disk pixel that avoids the bbox of ``mask``.

        Args:
            mask: binary mask whose overall bounding box must not be intersected.
            disk_mask: array whose positive pixels are the candidate box centres.
            Lon, Lat: angles in radians; the box width/height are scaled by
                ``|cos(Lon)|`` and ``|cos(Lat)|`` respectively.
            radius: the box size is ``radius`` divided by a random value from
                ``[4, 6, 9, 30]``.

        Returns:
            ``(bbox, bbox_mask)`` where ``bbox`` is
            ``[ymin, xmin, ymax, xmax]`` (floats, not clipped to the image) and
            ``bbox_mask`` is the result of ``mask_to_bbox(mask)``; or
            ``(None, None)`` when no valid placement was found.
        """
        bbox_mask = self.mask_to_bbox(mask)

        # Box sizes are expressed as denominators of the radius.
        bbox_sizes = [4, 6, 9, 30]
        sample_bbox_size = radius / np.random.choice(bbox_sizes)
        bbox_w = sample_bbox_size * np.abs(np.cos(Lon))
        bbox_h = sample_bbox_size * np.abs(np.cos(Lat))

        disk_mask_coords = np.argwhere(disk_mask > 0)
        if disk_mask_coords.shape[0] == 0:
            # No on-disk pixel to centre a box on.
            return None, None

        for _ in range(self._MAX_PLACEMENT_TRIALS):
            index = np.random.randint(0, disk_mask_coords.shape[0])
            bbox_y = disk_mask_coords[index][0]
            bbox_x = disk_mask_coords[index][1]

            bbox = [bbox_y - (bbox_h / 2), bbox_x - (bbox_w / 2),
                    bbox_y + (bbox_h / 2), bbox_x + (bbox_w / 2)]

            # Axis-aligned separation test against the mask bounding box.
            if (bbox[0] > bbox_mask[2] or bbox[1] > bbox_mask[3]
                    or bbox[2] < bbox_mask[0] or bbox[3] < bbox_mask[1]):
                return bbox, bbox_mask

        return None, None

    def modify(self, mask, gen_mask, bbox):
        """Paste ``gen_mask`` onto an empty canvas shaped like ``mask`` at the position given by ``bbox``.

        ``bbox`` is ``(ymin, xmin, ymax, xmax)`` in integer pixels; only its
        top-left corner is used, the pasted extent comes from ``gen_mask``.
        Parts of the patch falling outside the canvas are cropped.
        """
        canvas = np.zeros_like(mask)

        source_height, source_width = gen_mask.shape[:2]
        height, width = mask.shape[:2]
        y1, x1, y2, x2 = bbox

        # Clip the target window to the canvas; the extra max() keeps the end
        # from dropping below the start (which would turn into a negative,
        # wrap-around slice index when the patch lies fully outside).
        target_y_start = max(0, y1)
        target_y_end = max(target_y_start, min(height, y1 + source_height))
        target_x_start = max(0, x1)
        target_x_end = max(target_x_start, min(width, x1 + source_width))

        # Matching window inside the patch.
        source_y_start = target_y_start - y1
        source_y_end = source_y_start + (target_y_end - target_y_start)
        source_x_start = target_x_start - x1
        source_x_end = source_x_start + (target_x_end - target_x_start)

        canvas[target_y_start:target_y_end, target_x_start:target_x_end] = \
            gen_mask[source_y_start:source_y_end, source_x_start:source_x_end]

        return canvas

    def _fill_with_local_average(self, image, patch_mask):
        """Return a copy of ``image`` where each blob of ``patch_mask`` is set to the mean of its surroundings."""
        filled = image.copy()
        labels = skimage.measure.label(patch_mask, background=0)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

        for region_id in range(1, labels.max() + 1):
            shape_mask = labels == region_id
            dilated = cv2.morphologyEx(shape_mask.astype(np.uint8), cv2.MORPH_DILATE, kernel)

            # Pixels gained by the dilation, i.e. the rim around the blob.
            around_shape_mask = (dilated != 0) & ~shape_mask
            if not around_shape_mask.any():
                # Blob has no neighbours to average; leave it untouched.
                continue

            filled[shape_mask] = np.mean(filled[around_shape_mask])

        return filled

    def __call__(self, *args, force_apply=False, **kwargs):
        """Apply the transform to ``kwargs`` and return them with ``image`` replaced.

        The kwargs are returned unchanged when the transform is skipped by
        probability or when no valid patch location is found.

        Raises:
            ValueError: if ``self.method`` is neither ``'avg'`` nor ``'local_avg'``.
        """
        if not force_apply and random.random() > self.p:
            return kwargs

        img = kwargs['image']
        msk = kwargs['confidence_map']
        disk = kwargs['solar_disk']

        # Dilate the group mask so the added patch does not end up right next to it.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
        msk_dilate = cv2.morphologyEx((msk > 0).astype(np.uint8), cv2.MORPH_DILATE, kernel)

        # NOTE(review): Lon, Lat and radius are fixed constants here rather than
        # values derived from the sample — confirm this is intended.
        bbox, _ = self.generate_bbox(msk_dilate, disk, 3.4, .79, radius=450)
        if bbox is None:
            # No group location could be found, abort transform
            return kwargs

        bbox = [int(i) for i in bbox]
        bbox_shape = [bbox[2] - bbox[0], bbox[3] - bbox[1]]
        if min(bbox_shape) < 1:
            # Degenerate box: nothing to draw.
            return kwargs

        # Single-channel 0/1 blob patch of the size of the box.
        gen_m = self.gen_mask(bbox_shape)[:, :, 0] // 255

        # Full-size boolean mask of the patch, restricted to the solar disk.
        patch_mask = (self.modify(msk, gen_m, bbox) * disk) > 0

        if self.method == 'avg':
            # Mean of the on-disk pixels that do not belong to the group.
            bg = (disk > 0) & (msk == 0)
            new_img = img.copy()
            new_img[patch_mask] = np.mean(img[bg])
        elif self.method == 'local_avg':
            new_img = self._fill_with_local_average(img, patch_mask)
        else:
            raise ValueError(
                f"Unknown method {self.method!r}; expected 'avg' or 'local_avg'")

        kwargs['image'] = new_img

        return kwargs
              
        

In [None]:

# --- Dataset location and split ---
root_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2"
sub_dir = 'rebuttal_overlap_only'
partition = 'test'
dtypes = ['image', 'T425-T375-T325_fgbg']

# --- Label vocabularies ---
classes = ['A', 'B', 'C', 'SuperGroup', 'H']
first_classes = list(classes)  # same labels, separate list object
second_classes = ['x', 'r', 'sym', 'asym']
third_classes = ["x", "o", "frag"]
classification = 'SuperClasses'

# --- Transform pipeline: one config dict per transform, applied in order ---
transforms2 = [
    dict(_target_=Deepsun_Focus_Move,
         standard_height=256, standard_width=256,
         focus_on_group=False, random_move=False, random_move_percent=.2),
    dict(_target_=DeepsunClassifHideOtherGroups, method='local_avg'),
    dict(_target_=DeepsunAddRandomPatch, p=1, method='local_avg'),
    # Alternatives tried previously:
    #   dict(_target_=DeepsunAddRandomPatch, p=.5, method='avg'),
    #   dict(_target_=DeepsunClassifHideOtherGroups, method='avg'),
]

dataset_file = os.path.join(root_dir, sub_dir, 'all_samples.npy')
d2 = ClassificationDatasetSuperclasses_fast(
    root_dir, partition, dtypes,
    classes, first_classes, second_classes, third_classes,
    dataset_file, classification, transforms2,
)
# for i in range(5):
#     d2[i]
# pass

In [None]:
# Fetch one sample from d2 to exercise the dataset and its transforms.
# Indices inspected earlier are kept below, commented out.
# d2[150]
# d2[155]
# d2[30]
# d2[31]
# d2[32]
# d2[41]
# d2[55]
d2[119]
# Ending the cell on `pass` keeps the sample from being echoed as cell output.
pass

In [None]:
# Walk through every sample of d2 once; the items themselves are discarded,
# the loop only checks that each one can be produced.
for item in d2:
    pass

In [None]:
# index 430 contains several groups within the crop
d2[430]
# d2[2227]
# d2[2649]
# Ending the cell on `pass` keeps the sample from being echoed as cell output.
pass

In [None]:

class DeepsunAddRandomPatchOLD(DualTransform):
    """Earlier variant of the random-patch transform, with debug plots.

    The transform reads ``image``, ``confidence_map`` and ``solar_disk`` from
    its keyword arguments.  It generates random blobs over the whole image,
    drops those overlapping the dilated ``confidence_map`` or lying off the
    solar disk, keeps the blobs touching one randomly chosen cell of a 3x3
    grid, and replaces those pixels of ``image`` with a background estimate.
    Intermediate results are plotted with matplotlib on every call.

    Args:
        always_apply: forwarded to ``DualTransform``.
        p: probability of applying the transform on a given call.
        method: ``'avg'`` (mean of on-disk pixels where ``confidence_map == 0``)
            or ``'local_avg'`` (mean of the pixels a 3x3 dilation adds around
            each blob).
    """

    def __init__(self, always_apply=False, p=1.0, method='avg'):
        # Forward by keyword: the previous call passed ``self`` positionally,
        # i.e. in the ``always_apply`` slot of the base constructor.
        super().__init__(always_apply=always_apply, p=p)
        self.index = 0
        self.method = method

    def gen_mask(self, shape):
        """Return a random blob mask of size ``shape[:2]`` as an (H, W, 3) uint8 array of 0/255."""
        height, width = tuple(shape)[:2]

        # Unseeded generator: a different pattern on every call.
        rng = default_rng()
        noise = rng.integers(0, 255, (height, width), np.uint8, True)

        # Blurring turns pixel noise into smooth blobs; sigma controls their size.
        blur = cv2.GaussianBlur(noise, (0, 0), sigmaX=5, sigmaY=5, borderType=cv2.BORDER_DEFAULT)

        # Stretch to the full range so the fixed threshold below is meaningful.
        stretch = skimage.exposure.rescale_intensity(
            blur, in_range='image', out_range=(0, 255)).astype(np.uint8)
        thresh = cv2.threshold(stretch, 170, 255, cv2.THRESH_BINARY)[1]

        # Open then close to smooth the blob outlines.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
        mask = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

        return cv2.merge([mask, mask, mask])

    def remove_intersecting_regions(self, mask1, mask2):
        """Zero every connected component of ``mask1`` that overlaps a non-zero pixel of ``mask2``.

        ``mask1`` is modified in place (as before); a copy of the result is returned.
        """
        labeled_mask1 = skimage.measure.label(mask1)
        # Labels of mask1 that occur under mask2's foreground.
        touched = np.unique(labeled_mask1[np.asarray(mask2) != 0])
        mask1[np.isin(labeled_mask1, touched)] = 0
        return mask1.copy()

    def divide_mask(self, mask, N, selected_cell):
        """Keep the connected components whose bounding box overlaps one random cell of an NxN grid.

        Args:
            mask: 2-D mask.
            N: number of grid cells per side.
            selected_cell: accepted for compatibility but not used; the cell is
                drawn at random among the cells that contain mask pixels.

        Returns:
            A uint8 mask with 1 on the kept components, or all zeros when the
            mask is empty.  The selection is also printed and plotted.
        """
        print(mask.shape)
        height, width = mask.shape
        cell_height = height // N
        cell_width = width // N

        # Row-major indices of the grid cells holding at least one mask pixel.
        grid_cells = [
            i * N + j
            for i in range(N)
            for j in range(N)
            if np.any(mask[i * cell_height:(i + 1) * cell_height,
                           j * cell_width:(j + 1) * cell_width])
        ]
        print(f' possible cells: {grid_cells}')

        if len(grid_cells) == 0:
            return np.zeros_like(mask, np.uint8)

        selected_grid_cell_idx = random.choice(grid_cells)
        print(f' selected_cell: {selected_grid_cell_idx}')

        row, col = divmod(selected_grid_cell_idx, N)
        y_start, y_end = row * cell_height, (row + 1) * cell_height
        x_start, x_end = col * cell_width, (col + 1) * cell_width

        labeled_mask = skimage.measure.label(mask)
        selected_labels = []
        for region in skimage.measure.regionprops(labeled_mask):
            minr, minc, maxr, maxc = region.bbox
            # Both boxes are half-open, so strict comparisons are what
            # "overlap by at least one pixel" requires.
            if (minr < y_end and maxr > y_start) and (minc < x_end and maxc > x_start):
                selected_labels.append(region.label)

        connected_mask = np.zeros_like(mask, np.uint8)
        connected_mask[np.isin(labeled_mask, selected_labels)] = 1

        # Debug view: the mask with the selected cell, and the kept components.
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
        ax[0].imshow(mask, interpolation=None, cmap='gray')
        ax[0].imshow(mask, interpolation=None, cmap='jet', alpha=0.5)
        rect = patches.Rectangle((x_start, y_start), cell_width, cell_height,
                                 linewidth=1, edgecolor='r', facecolor='none')
        ax[0].add_patch(rect)
        ax[1].imshow(connected_mask, interpolation=None, cmap='gray')
        plt.show()

        return connected_mask

    def generate_box(self, angle):
        """Placeholder: box generation was never implemented; always returns ``None``."""
        return None

    def box_to_polygon(self, box):
        """Convert an axis-aligned ``(x1, y1, x2, y2)`` box to a shapely ``Polygon``."""
        x1, y1, x2, y2 = box
        return Polygon([(x1, y1), (x1, y2), (x2, y2), (x2, y1)])

    def add_group(self, mask, angle):
        """Keep the connected components of ``mask`` whose bounding box intersects a generated box.

        The generated box is expected as eight values
        ``[x1, y1, x2, y2, x3, y3, x4, y4]`` (four corners).

        Raises:
            NotImplementedError: while ``generate_box`` returns ``None``.
        """
        box = self.generate_box(angle)
        if box is None:
            raise NotImplementedError('generate_box() does not produce a box yet')

        box_arr = np.asarray(box, dtype=float).reshape(4, 2)
        box_poly = Polygon(box_arr)

        connected_mask = np.zeros_like(mask)
        labeled_mask = skimage.measure.label(mask)
        for region in skimage.measure.regionprops(labeled_mask):
            minr, minc, maxr, maxc = region.bbox
            # NOTE(review): the box corners are plotted as (x, y), so the
            # component box is converted from (row, col) to (x, y) order here
            # — confirm this matches the intended box convention.
            component_polygon = self.box_to_polygon((minc, minr, maxc, maxr))
            if component_polygon.intersects(box_poly):
                connected_mask[labeled_mask == region.label] = 1

        # Debug view: the mask with the generated box, and the kept components.
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
        ax[0].imshow(mask, interpolation=None, cmap='gray')
        rect = patches.Polygon(box_arr, linewidth=1, edgecolor='r', facecolor='none')
        ax[0].add_patch(rect)
        ax[1].imshow(connected_mask, interpolation=None, cmap='gray')
        plt.show()

        return connected_mask

    def randomly_select_elements(self, input_list, percentage):
        """Return a random sample of ``int(len(input_list) * percentage)`` items of ``input_list``."""
        num_elements = int(len(input_list) * percentage)
        return random.sample(input_list, num_elements)

    def randomly_remove_regions(self, mask, removal_prob, max_blobs=4):
        """Return a copy of ``mask`` keeping only a random subset of its connected components.

        Despite its name, ``removal_prob`` is the fraction of components that
        is *kept*; at most ``max_blobs`` components survive.
        """
        cleared_mask = mask.copy()
        labeled_mask = skimage.measure.label(cleared_mask)

        # Foreground labels only; do not rely on a background label being present.
        foreground_labels = [v for v in np.unique(labeled_mask).tolist() if v != 0]

        # The sample is already in random order, so truncating it is unbiased.
        to_keep = self.randomly_select_elements(foreground_labels, removal_prob)[:max_blobs]

        cleared_mask[(labeled_mask != 0) & ~np.isin(labeled_mask, to_keep)] = 0
        return cleared_mask

    def _fill_with_local_average(self, image, patch_mask):
        """Return a copy of ``image`` where each blob of ``patch_mask`` is set to the mean of its surroundings."""
        filled = image.copy()
        labels = skimage.measure.label(patch_mask, background=0)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

        for region_id in range(1, labels.max() + 1):
            shape_mask = labels == region_id
            dilated = cv2.morphologyEx(shape_mask.astype(np.uint8), cv2.MORPH_DILATE, kernel)

            # Pixels gained by the dilation, i.e. the rim around the blob.
            around_shape_mask = (dilated != 0) & ~shape_mask
            if not around_shape_mask.any():
                # Blob has no neighbours to average; leave it untouched.
                continue

            filled[shape_mask] = np.mean(filled[around_shape_mask])

        return filled

    def __call__(self, *args, force_apply=False, **kwargs):
        """Apply the transform to ``kwargs`` and return them with ``image`` replaced.

        Raises:
            ValueError: if ``self.method`` is neither ``'avg'`` nor ``'local_avg'``.
        """
        if not force_apply and random.random() > self.p:
            return kwargs

        img = kwargs['image'].copy()
        msk = kwargs['confidence_map'].copy()
        disk = kwargs['solar_disk'].copy()

        # Dilate the group mask so generated blobs close to it are rejected too.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
        msk_dilate = cv2.morphologyEx((msk > 0).astype(np.uint8), cv2.MORPH_DILATE, kernel)

        gen_m_1 = self.gen_mask(img.shape)
        gen_m_2 = self.remove_intersecting_regions(
            (gen_m_1 > 0).astype(np.uint8)[:, :, 0], msk_dilate)

        # remove pixels outside the solar disk
        gen_m_2[disk == 0] = 0

        # Keep only the blobs around one grid cell.
        side = 3
        # randrange excludes the upper bound, so the index is a valid cell.
        sel_cel = random.randrange(side * side)
        gen_m = self.divide_mask(gen_m_2, side, sel_cel) > 0

        # Debug view of every stage of the mask generation.
        fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(12, 3))
        ax[0].imshow(img, interpolation=None, cmap='gray')
        ax[0].set_title('Orig')
        ax[1].imshow(msk > 0, interpolation=None, alpha=0.5)
        ax[1].set_title('mask')
        ax[2].imshow(gen_m_1, interpolation=None, alpha=0.5)
        ax[2].set_title('genereted mask')
        ax[3].imshow(gen_m_2, interpolation=None, alpha=0.5, cmap='gray')
        ax[3].set_title('removed intersection')
        ax[4].imshow(gen_m, interpolation=None, alpha=0.5, cmap='gray')
        ax[4].set_title('removed regions')
        plt.show()

        if self.method == 'avg':
            # Mean of the on-disk pixels that do not belong to the group.
            bg = (disk > 0) & (msk == 0)
            new_img = img.copy()
            new_img[gen_m] = np.mean(img[bg])
        elif self.method == 'local_avg':
            new_img = self._fill_with_local_average(img, gen_m)
        else:
            raise ValueError(
                f"Unknown method {self.method!r}; expected 'avg' or 'local_avg'")

        # Debug view: input, patch mask, result.
        fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))
        ax[0].imshow(img, interpolation=None, cmap='gray')
        ax[1].imshow(gen_m, interpolation=None, alpha=0.5)
        ax[2].imshow(new_img, interpolation=None, cmap='gray')
        fig.tight_layout()
        plt.show()

        kwargs['image'] = new_img

        return kwargs
       

In [None]:
# Smoke-test every sample and report the indices that fail to load/transform.
failed_indices = []
for i in tqdm(range(len(d2))):
    try:
        tmp = d2[i]
    except Exception as err:
        # Report the failing index and the error itself.  `tmp` still holds the
        # previous sample here (or nothing at all on the first iteration), so
        # it is deliberately not inspected.  KeyboardInterrupt is not caught
        # and still stops the loop.
        failed_indices.append(i)
        print(f'{i}: {type(err).__name__}: {err}')

In [None]:
# VAL had errors at indices 94 / 196
# NOTE(review): the line above says 196 but the probe below uses 96 — confirm which one is right.
# d2[94]
# d2[96]
# TRAIN had errors at indices 288 / 473 / 477 / 707 / 777 / 946 / 1130 / 1166 / 2326
# d2[288]
# d2[473]
# d2[477]
# d2[707]
# d2[777]
# d2[946]
# d2[1130]
# d2[1166]
# d2[2326]


pass

In [None]:
# Fetch the first ten samples of d2; the returned items are discarded.
for i in range(10):
    d2[i]
# Ending the cell on `pass` leaves the cell without an output value.
pass

In [None]:

import numpy as np
import matplotlib.pyplot as plt
# Import the module rather than its names: `scipy.ndimage.label` returns a
# (labels, count) tuple and must not replace the skimage `label` used by the
# transform classes.  (`scipy.ndimage` has no `random_uniform`.)
from scipy import ndimage

def generate_random_blobs_mask(size=(100, 100), num_blobs=10, min_blob_size=10, max_blob_size=20):
    """Build a uint8 mask containing randomly placed square patches of random blobs.

    Each patch is a ``blob_size`` x ``blob_size`` field of uniform noise
    thresholded at 0.5 and labelled with 4-connectivity; the label ids are
    written into the mask, so a later patch overwrites whatever it covers.

    Args:
        size: (rows, cols) of the mask; both must exceed the patch size.
        num_blobs: number of patches to paste.
        min_blob_size: smallest patch side (inclusive).
        max_blob_size: largest patch side (exclusive).

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``size``.
    """
    mask = np.zeros(size, dtype=np.uint8)

    # Rank-2, connectivity-1 structure: only edge-sharing pixels are connected.
    struct = ndimage.generate_binary_structure(2, 1)

    for _ in range(num_blobs):
        blob_size = np.random.randint(min_blob_size, max_blob_size)

        noise = np.random.random((blob_size, blob_size))
        labeled, _num_features = ndimage.label(noise > 0.5, structure=struct)

        x = np.random.randint(0, size[0] - blob_size)
        y = np.random.randint(0, size[1] - blob_size)
        mask[x:x+blob_size, y:y+blob_size] = labeled

    return mask

# Generate random blobs mask with the function's default arguments.
# `mask` is a notebook-level name: a later cell rebinds it and another reads it.
mask = generate_random_blobs_mask()

# Display the generated mask on the current pyplot figure, without axes
plt.imshow(mask, cmap='gray')
plt.axis('off')
plt.show()

In [None]:
import cv2
import skimage.exposure
import numpy as np
from numpy.random import default_rng

# Blank 250x250 canvas; only its shape is used below
img = np.zeros((250,250))
height, width = img.shape[0], img.shape[1]

# Unseeded generator, so every run gives a different pattern
rng = default_rng()

# Uniform uint8 noise over the whole canvas (255 included)
noise = rng.integers(low=0, high=255, size=(height, width), dtype=np.uint8, endpoint=True)

# Gaussian blur: sigma sets the typical blob size
blur = cv2.GaussianBlur(noise, (0,0), sigmaX=4, sigmaY=4, borderType = cv2.BORDER_DEFAULT)

# Rescale the blurred noise to the full 0..255 range
stretch = skimage.exposure.rescale_intensity(blur, in_range='image', out_range=(0,255))
stretch = stretch.astype(np.uint8)

# A fixed threshold keeps the brightest areas as blobs
thresh = cv2.threshold(stretch, 175, 255, cv2.THRESH_BINARY)[1]

# Opening followed by closing smooths the blobs; then replicate to 3 channels
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9,9))
mask = cv2.morphologyEx(
    cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel),
    cv2.MORPH_CLOSE,
    kernel,
)
mask = cv2.merge([mask] * 3)

print(type(mask), mask.shape)

# Show the result in a fresh figure, without axes
plt.figure()
plt.imshow(mask, cmap='gray')
plt.axis('off')
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.measure import label
from skimage.segmentation import clear_border
import random

def randomly_select_elements(input_list, percentage):
    """Draw, without replacement, ``int(len(input_list) * percentage)`` random items from ``input_list``."""
    return random.sample(input_list, int(percentage * len(input_list)))

def randomly_remove_regions(mask, removal_prob):
    """Return a copy of ``mask`` with a random fraction of its connected components cleared.

    Args:
        mask: array to label; it is not modified.
        removal_prob: fraction of the foreground components to clear;
            ``int(n_components * removal_prob)`` components are removed.

    Returns:
        A copy of ``mask`` with the selected components set to 0.
    """
    cleared_mask = mask.copy()

    # Label connected components in the mask
    labeled_mask = label(cleared_mask)

    # Sample among foreground labels only: the background (label 0) must not
    # take up a removal slot, and may not even be present in the mask.
    foreground_labels = [v for v in np.unique(labeled_mask).tolist() if v != 0]
    to_remove = randomly_select_elements(foreground_labels, removal_prob)

    if to_remove:
        cleared_mask[np.isin(labeled_mask, to_remove)] = 0

    return cleared_mask


removal_probability = 0.9

# Work on a copy so the original mask can be compared afterwards
masked_copy = np.array(mask, copy=True)

# Randomly remove regions
result_mask = randomly_remove_regions(masked_copy, removal_probability)

# Show the original, the copy handed to the function, and the result side by side
fig, axes = plt.subplots(1, 3, figsize=(8, 4))
panels = (
    (mask, 'Original Mask'),
    (masked_copy, 'Original Mask'),
    (result_mask, 'Mask with Randomly Removed Regions'),
)
for ax, (panel_img, panel_title) in zip(axes, panels):
    ax.imshow(panel_img, cmap='gray')
    ax.set_title(panel_title)
    ax.axis('off')
plt.show()
