In [1]:
from email.mime import base
from tokenize import group
from bioblue.dataset.transform.pipelines import Compose
import collections
from functools import partial

from torch.utils.data import Dataset, DataLoader
import os
from hydra.utils import call, instantiate
from pathlib import Path
import skimage.io as io
import json

import numpy as np

from tqdm.notebook import tqdm

import traceback
import time

from albumentations.core.transforms_interface import DualTransform

from astropy.io import fits
from bioblue.dataset.utils import *

from omegaconf import DictConfig, OmegaConf
import matplotlib.pyplot as plt


from datetime import datetime, timedelta

from copy import deepcopy

import pickle

import concurrent.futures
from itertools import repeat
import multiprocessing

%matplotlib widget
# %matplotlib inline

In [2]:
def print_elapsed_time(st, msg):
        end_time = time.time()
        print(f'Elapsed time {msg}: {end_time - st}')

class ClassificationDatasetSuperclasses(Dataset):

    def __init__(
        self, root_dir, partition, dtypes, classes,
        first_classes, second_classes, third_classes, 
        json_file, classification='SuperClass' , transforms=None) -> None:
        super().__init__()
        if isinstance(transforms, collections.Mapping):
            transforms = partial(call, config=transforms)
        elif isinstance(transforms, collections.Sequence):
            transforms_init = []
            for transform in transforms:
                transforms_init.append(instantiate(transform))
            transforms = Compose(transforms_init)

        self.transforms = transforms
        self.root_dir = Path(root_dir)

        self.main_dtype = dtypes[0]
        self.mask_dtype = dtypes[1]

        self.json_file = self.root_dir / json_file 
        # print(self.json_file)
        self.partition_dict = None

        self.c1_mapper = {c: i for i,c in enumerate(first_classes)}
        self.c2_mapper = {c: i for i,c in enumerate(second_classes)}
        self.c3_mapper = {c: i for i,c in enumerate(third_classes)}


        with open(self.json_file, 'r') as f:
            self.partition_dict = json.load(f)[partition]

        
        assert (classification == 'Zurich') or (classification == 'McIntosh') or (classification == 'SuperClass')

        self.classification = classification

        self.FirstClass_mapper = {c: i for i,c in enumerate(first_classes)}
        self.SecondClass_mapper = {c: i for i,c in enumerate(second_classes)}
        self.ThirdClass_mapper = {c: i for i,c in enumerate(third_classes)}

        # print(classes)
        self.files = {}
        for i, bn in enumerate(sorted(list(self.partition_dict.keys()))):
            bn = bn.split('_')[0]
            # print(bn)
            cur = {}
            image_basename = bn + '.FTS'
            image_filename = self.root_dir / self.main_dtype / image_basename

            sun_mask_filename = self.root_dir / 'sun_mask' / (bn + '.png')


            mask_basename = bn + '.png'
            mask_filename = self.root_dir / self.mask_dtype / mask_basename

            conf_map_basename = bn + '_proba_map.npy'
            conf_map_filename = self.root_dir / self.mask_dtype / conf_map_basename

            cur["name"] = bn
            cur[self.main_dtype] = image_filename
            cur[self.mask_dtype] = mask_filename
            cur[self.mask_dtype+"_conf_map"] = conf_map_filename
            cur["sun_mask"] = sun_mask_filename

            self.files[bn] = cur

        # print(self.files)

        self.partition_dict

        # print(list(self.partition_dict.values())[0])

        self.groups = {}
        for k,v in self.partition_dict.items():

            if v[self.classification]["1"] in classes:
                    self.groups[k] = v
            else:
                # print(v[self.classification]["1"])
                pass

      

        self.dataset_length = len(list(self.groups.keys()))



    def __len__(self) -> int:
        # raise NotImplementedError
        # print(self.dataset_length)
        return self.dataset_length
        # return 10
    
    def __getitem__(self, index: int, do_transform=True):

        # st = time.time()

        sample = {} # dictionnary with 'image', 'class', 'angular_excentricity', 'centroid_lat'

        # basename = self.files[index]["name"]
        k = sorted(list(self.groups.keys()))[index]
        # print(k)
        basename = k.split('_')[0]


        # image_out_dict = self.partition_dict[basename]
        group_dict = self.groups[k]

        # print(group_dict)

        img_name = self.files[basename][self.main_dtype] # path of FITS file
        mask_name = self.files[basename][self.mask_dtype]
        conf_map_name = self.files[basename][self.mask_dtype+"_conf_map"]

        # print(img_name)
        # st =  time.time()
        hdulst:fits.HDUList = fits.open(img_name)
        image = hdulst[0]
        header = image.header
        center = np.array(image.shape)//2
        radius = header['SOLAR_R']
        
        # st = time.time()
        sample['solar_disk'] = io.imread(self.files[basename]["sun_mask"])
        # print_elapsed_time(st, 'load sun mask')



        # st = time.time()
        sample['excentricity_map'] = create_excentricity_map(sample['solar_disk'], radius, value_outside=-1)
        # print_elapsed_time(st, 'create_excentricity_map')

        # st = time.time()
        sample['mask'] = io.imread(mask_name)#.astype(float)
        sample['confidence_map'] = np.load(conf_map_name)
        # print_elapsed_time(st, 'load mask and conf map')

        # st = time.time()
        sample['image'] = (image.data).astype(float)

        sample['members'] = np.array(group_dict['members']) if 'members' in group_dict else np.array([0])
        sample['members_mean_px'] = np.array(group_dict['members_mean_px']) if 'members_mean_px' in group_dict else np.array([0])

        sample['name'] = basename
        sample['group_name'] = k

        sample['solar_angle'] = group_dict['angle']
        sample['deltashapeX'] = group_dict['deltashapeX']
        sample['deltashapeY'] = group_dict['deltashapeY']
        
        sample['angular_excentricity'] = np.array([group_dict["angular_excentricity_deg"]])
        sample['centroid_px'] = np.array(group_dict["centroid_px"])
        # print(sample['centroid_px'])
        sample['centroid_Lat'] = np.array([group_dict["centroid_Lat"]])

        # sample['class'] = np.array([self.classes_mapper[group_dict[self.classification]]])
        sample['class1'] = group_dict[self.classification]['1']
        sample['class2'] = group_dict[self.classification]['2']
        sample['class3'] = group_dict[self.classification]['3']
        # sample['class1'] = np.array([self.FirstClass_mapper[group_dict[self.classification]['1']]])
        # sample['class2'] = np.array([self.SecondClass_mapper[group_dict[self.classification]['2']]])
        # sample['class3'] = np.array([self.ThirdClass_mapper[group_dict[self.classification]['3']]])
        # print_elapsed_time(st, 'remaining operations')

        if sample["image"].shape == (1024,1024):
            fig,ax = plt.subplots(2, 5, figsize=(10, 4) )
            ax[0,0].imshow(sample["image"], cmap='gray', interpolation='none')
            ax[0,1].imshow(sample["mask"], cmap='gray', interpolation='none')
            ax[0,2].imshow(sample["confidence_map"], cmap='gray', interpolation='none')
            ax[0,3].imshow(sample["solar_disk"], cmap='gray', interpolation='none')
            ax[0,4].imshow(sample["excentricity_map"], cmap='gray', interpolation='none')
            # scatter the centroid
            ax[0,0].scatter(sample['centroid_px'][0], sample['centroid_px'][1], c='r', s=10)

            # print(sample["image"].shape)
            # double the resolution to 2048x2048 of all visual data
            sample["image"] = np.repeat(np.repeat(sample["image"], 2, axis=0), 2, axis=1)
            sample["mask"] = np.repeat(np.repeat(sample["mask"], 2, axis=0), 2, axis=1)
            sample["confidence_map"] = np.repeat(np.repeat(sample["confidence_map"], 2, axis=0), 2, axis=1)
            sample["solar_disk"] = np.repeat(np.repeat(sample["solar_disk"], 2, axis=0), 2, axis=1)
            sample["excentricity_map"] = np.repeat(np.repeat(sample["excentricity_map"], 2, axis=0), 2, axis=1)
            # print(sample["image"].shape)
            
            # # also double the delta shape values
            sample['deltashapeX'] = sample['deltashapeX']*2
            sample['deltashapeY'] = sample['deltashapeY']*2

            # also double the centroid values
            sample['centroid_px'] = sample['centroid_px']*2


            ax[1,0].imshow(sample["image"], cmap='gray', interpolation='none')
            ax[1,1].imshow(sample["mask"], cmap='gray', interpolation='none')
            ax[1,2].imshow(sample["confidence_map"], cmap='gray', interpolation='none')
            ax[1,3].imshow(sample["solar_disk"], cmap='gray', interpolation='none')
            ax[1,4].imshow(sample["excentricity_map"], cmap='gray', interpolation='none')
            # scatter the centroid
            ax[1,0].scatter(sample['centroid_px'][0], sample['centroid_px'][1], c='r', s=10)



            fig.tight_layout()

            # show differences
            
        flip_time = "2003-03-08T00:00:00"
        date = whitelight_to_datetime(basename)
        datetime_str = datetime_to_db_string(date).replace(' ', 'T')
        # print(datetime_str)
        should_flip = (datetime.fromisoformat(datetime_str) - datetime.fromisoformat(flip_time)) < timedelta(0)
        sample['should_flip'] = should_flip

        if should_flip:
            sample['image'] = np.flip(sample['image'],axis=0)
            sample['solar_disk'] = np.flip(sample['solar_disk'],axis=0)
            sample['mask'] = np.flip(sample['mask'],axis=0)
            sample['confidence_map'] = np.flip(sample['confidence_map'],axis=0)
            sample['excentricity_map'] = np.flip(sample['excentricity_map'],axis=0)


#         st = time.time()
        if self.transforms is not None and do_transform:
            sample = self.transforms(**sample)
#         print_elapsed_time(st, 'transform')

        # fig,ax = plt.subplots(1, 5, figsize=(10, 4) )
        # ax[0].imshow(sample["image"], cmap='gray', interpolation='none')
        # ax[1].imshow(sample["mask"], cmap='gray', interpolation='none')
        # ax[2].imshow(sample["confidence_map"], cmap='gray', interpolation='none')
        # ax[3].imshow(sample["solar_disk"], cmap='gray', interpolation='none')
        # ax[4].imshow(sample["excentricity_map"], cmap='gray', interpolation='none')
        

        return sample
        

In [3]:
# root_dir, partition, dtypes, classes,
# first_classes, second_classes, third_classes, 
# json_file, classification='SuperClass' , transforms=None
# root_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019"
root_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2"
partition = 'train'
dtypes = ['image', 'T425-T375-T325_fgbg']
# dtypes = ['image', 'feb2023/T425-T375-T325_fgbg']
classes = ['A','B','C','SuperGroup','H']
first_classes = [ 'A','B','C','SuperGroup','H']
second_classes= [ 'x','r','sym','asym']
third_classes= [ "x","o","frag"]
json_file = 'test/dataset_final.json'
classification = 'SuperClass'
transforms = OmegaConf.load('/home/ucl/elen/nsayez/bio-blueprints/bioblue/conf/exp/Classification_Superclasses4.yaml').dataset.train_dataset.transforms

transforms[3].standard_height = 350
transforms[3].standard_width = 350

print(transforms)

dataset =  ClassificationDatasetSuperclasses(root_dir, partition, dtypes, classes, first_classes, second_classes, third_classes, json_file, classification, transforms)

[{'_target_': 'bioblue.transforms.DeepsunScaleWhitelight'}, {'_target_': 'bioblue.transforms.DeepsunScaleExcentricityMap'}, {'_target_': 'bioblue.transforms.DeepsunScaleConfidenceMap'}, {'_target_': 'bioblue.transforms.DeepsunRotateAndCropAroundGroup_Focus_Move', 'standard_height': 350, 'standard_width': 350, 'focus_on_group': '${focus_on_group}', 'random_move': '${random_move}', 'random_move_percent': '${random_move_percent}'}, {'_target_': 'bioblue.transforms.DeepsunMcIntoshScaleAdditionalInfo'}]


Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.10 it will stop working
Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.10 it will stop working


In [4]:
# dataset[259]
# pass

In [5]:
partition_samples = {}

In [6]:
def get_sample(idx, dataset):
    tmp = dataset[idx]
    center  = tmp['image'].shape[0]//2 , tmp['image'].shape[1]//2
    slice_x = center[0]-( 256 //2 ), center[0]+ (256 //2 )
    slice_y = center[1]-( 256 //2 ), center[1]+ (256 //2 )
    center_region = tmp['confidence_map'][slice_x[0]:slice_x[1], slice_y[0]:slice_y[1]]
    if np.sum(center_region) == 0:
        print('skipping sample', idx)
        return None
    return dataset[idx]

num_cpu = 15


with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
    for sample in tqdm(executor.map(get_sample, range(len(dataset)) , repeat(deepcopy(dataset)))):
    # for sample in tqdm(executor.map(get_sample, range(10) , repeat(deepcopy(dataset)))):
        if sample is not None:
            if sample['group_name'] not in partition_samples:
                partition_samples[sample['group_name']] = sample


0it [00:00, ?it/s]

skipping sample 288
skipping sample 473
skipping sample 477
skipping sample 707
skipping sample 777
skipping sample 946
skipping sample 1130
skipping sample 1166
skipping sample 2326


In [None]:
npy_file = os.path.join(root_dir, 'test', f'all_samples_{partition}.npy')

np.save(npy_file, partition_samples)

In [173]:
# loaded = np.load(npy_file, allow_pickle=True).item()

# for k, v in loaded.items():
#     print(k)
#     print(v.keys())
#     print(v['class1'])

# Join the train, val, test .npy files

Je sais pas pourqueoi mais le kernel meurt à chaque fois pendant le dump.... à la place, exécuter le script 'dataset_joiner.py' cans ce dossier.

In [4]:
all_samples ={'train':{},'val':{},'test':{}}
for p in all_samples.keys():
    print('loading', p)
    st = time.time()    
    filename = os.path.join(root_dir,'test',f'all_samples_{p}.npy' )
    tmp = np.load(filename, allow_pickle=True).item()
    print('Elapsed time', time.time()-st)
    
    all_samples[p] = tmp

loading train
Elapsed time 15.252368927001953
loading val
Elapsed time 3.4969048500061035
loading test
Elapsed time 3.5309314727783203


In [5]:
print("Dumping")
st = time.time()    
tot_npy_file = os.path.join(root_dir, 'test', f'all_samples.npy')
np.save(tot_npy_file, all_samples)
print('Elapsed time', time.time()-st)

Dumping


: 

: 

# Test

In [42]:
def print_elapsed_time(st, msg):
        end_time = time.time()
        print(f'Elapsed time {msg}: {end_time - st}')
        
class Deepsun_Focus_Move(DualTransform):
    def __init__(self, standard_height=256, standard_width=256,
                        focus_on_group=True,
                        random_move=False, random_move_percent=0.1,  
                        always_apply=False, p=1.0) -> None:

        super().__init__(always_apply, p)

        self.standard_height = standard_height
        self.standard_width = standard_width
        
        self.focus_on_group = focus_on_group
        self.random_move = random_move
        self.random_move_percent = random_move_percent

    def get_bounding_box_around_group_with_padding(self, mask, offset):
        # Get the bounding box around non-zero pixels in mask
        x, y = np.nonzero(mask)
        # print(x, y)
        x1, x2 = (np.min(x), np.max(x)) if len(x) > 0 else (None, None)
        y1, y2 = (np.min(y), np.max(y)) if len(y) > 0 else (None, None)

        if (x1 is None) or (y1 is None):
            return 0, mask.shape[0]-1, 0, mask.shape[1]-1

        # Add padding
        x1 -= offset
        x2 += offset
        y1 -= offset
        y2 += offset


        # Make sure the bounding box is not outside the image
        x1 = max(x1, 0)
        x2 = min(x2, mask.shape[0])
        y1 = max(y1, 0)
        y2 = min(y2, mask.shape[1])

        return x1, x2, y1, y2

    def adapt_bbox_to_image_size(self, bbox, image_size):
        bbox_center = ((bbox[0] + bbox[1]) // 2, (bbox[2] + bbox[3]) // 2)
        bbox_size = (bbox[1] - bbox[0], bbox[3] - bbox[2])

        # if bbox is too small, expand it
        minimal_percentage = .4

        bbox_size = (max(bbox_size[0], image_size[0] * minimal_percentage),
                     max(bbox_size[1], image_size[1] * minimal_percentage))
        
        return (int(bbox_center[0] - bbox_size[0] // 2), int(bbox_center[0] + bbox_size[0] // 2),
                int(bbox_center[1] - bbox_size[1] // 2), int(bbox_center[1] + bbox_size[1] // 2))

    def crop_img(self, img, bbox):
        # Crop image
        x1, x2, y1, y2 = bbox
        img = img[x1:x2, y1:y2]
        return img
        
    def padding(self, array, xx, yy):
        """
        :param array: numpy array
        :param xx: desired height
        :param yy: desirex width
        :return: padded array
        """

        h = array.shape[0]
        w = array.shape[1]

        a = (xx - h) // 2
        aa = xx - a - h

        b = (yy - w) // 2
        bb = yy - b - w

        # print(a,aa,b,bb)

        a = max(a,0)
        b = max(b,0)
        aa = max(aa,0)
        bb = max(bb,0)


        # print('->',a,aa,b,bb)

        return np.pad(array, pad_width=((a, aa), (b, bb)), mode='constant')

    def crop_and_pad(self, img, bbox, image_size):
        # print('crop_and_pad')
        # print(f'bbox: {bbox}, image_size: {image_size}, img.shape: {img.shape}')
        # Crop image
        img = self.crop_img(img, bbox)
        # Pad image
        img = self.padding(img, image_size[0], image_size[1])
        return img

    def data_aug_random_move(self, bbox, max_offset):
        '''
        Randomly move the bounding box
        param bbox: bounding box
        param max_offset: maximum offset in portion of the bbox size
        '''
        # Randomly move the bounding box
        x1, x2, y1, y2 = bbox
        horizontal_offset = (np.random.random(1) * 2*max_offset) - max_offset
        vertical_offset = (np.random.random(1) * 2*max_offset) - max_offset
        # print(f'horizontal_offset: {horizontal_offset}, vertical_offset: {vertical_offset}')

        x1 += int(horizontal_offset * (bbox[1] - bbox[0]))
        x2 += int(horizontal_offset * (bbox[1] - bbox[0]))
        y1 += int(vertical_offset * (bbox[3] - bbox[2]))
        y2 += int(vertical_offset * (bbox[3] - bbox[2]))
        
        return x1, x2, y1, y2
        
    def __call__(self, *args, force_apply=False, **kwargs):

        img_group_crop = kwargs['image'].copy()
        msk_group_crop = kwargs['mask'].copy()
        grp_msk_group_crop = kwargs['group_mask'].copy()
        disk_group_crop = kwargs['solar_disk'].copy()
        excentricity_group_crop = kwargs['excentricity_map'].copy()
        confidence_group_crop = kwargs['confidence_map'].copy()
        grp_confidence_group_crop = kwargs['group_confidence_map'].copy()
        
        shape  = img_group_crop.shape
        # minX, maxX, minY, maxY =  ((shape[0]//2)-self.standard_height//2, 
        #                             (shape[0]//2)+self.standard_height//2, 
        #                             (shape[1]//2)-self.standard_width//2, 
        #                             (shape[1]//2)+self.standard_width//2)

        # bbox format = x1, x2, y1, y2
        bbox = self.get_bounding_box_around_group_with_padding((grp_confidence_group_crop>0), 10)

        # print(bbox)
        # print(bbox[1]-bbox[0], bbox[3]-bbox[2])
        
        minX, maxX, minY, maxY =  (
                                    ((bbox[1]+bbox[0])//2)-(self.standard_height//2), 
                                    ((bbox[1]+bbox[0])//2)+(self.standard_height//2), 
                                    ((bbox[3]+bbox[2])//2)-(self.standard_width//2), 
                                    ((bbox[3]+bbox[2])//2)+(self.standard_width//2)
                                    
                                    )
        # print("new_shape minmax ",minX, maxX, minY, maxY)
        
            
        if self.focus_on_group:
            # focus on the group
            # print('focus on group')
            # Modify the bounding box if data augmentation is enabled
            if self.random_move:
                # print('random_move')
                bbox = self.data_aug_random_move(bbox, max_offset=self.random_move_percent)
                                    
                # Make sure the bounding box is not outside the image
                x1, x2, y1, y2 = bbox
                x1 = max(x1, 0)
                x2 = min(x2, self.standard_width)
                y1 = max(y1, 0)
                y2 = min(y2, self.standard_height)
                bbox = x1, x2, y1, y2
            else:
                # print('no random_move')
                pass
                
            bbox = self.adapt_bbox_to_image_size( bbox, (self.standard_height, self.standard_width))
            # print(bbox)
            img_group_crop = self.crop_and_pad(img_group_crop, bbox, (self.standard_height, self.standard_width))
            msk_group_crop = self.crop_and_pad(msk_group_crop, bbox, (self.standard_height, self.standard_width))
            grp_msk_group_crop = self.crop_and_pad(grp_msk_group_crop, bbox, (self.standard_height, self.standard_width))
            disk_group_crop = self.crop_and_pad(disk_group_crop, bbox, (self.standard_height, self.standard_width))
            excentricity_group_crop = self.crop_and_pad(excentricity_group_crop, bbox, (self.standard_height, self.standard_width))
            confidence_group_crop = self.crop_and_pad(confidence_group_crop, bbox, (self.standard_height, self.standard_width))
            grp_confidence_group_crop = self.crop_and_pad(grp_confidence_group_crop, bbox, (self.standard_height, self.standard_width))
        else:
            # print('NO focus on group')
            if self.random_move:
                # print('random_move')
                frac = np.max([(bbox[1]-bbox[0]) /self.standard_height, (bbox[3]-bbox[2]) /self.standard_width])
                frac = np.sqrt(frac)
                # print(minX, maxX, minY, maxY , frac, self.random_move_percent)
                bbox = self.data_aug_random_move([minX,maxX,minY,maxY], max_offset=self.random_move_percent*frac)
                # print('bbox',bbox)
                if not ((bbox[1] > shape[0]) or (bbox[3] > shape[1]) or (bbox[0] < 0) or (bbox[2] < 0)):
                    bbox = [bbox[0], bbox[0]+self.standard_height, bbox[2], bbox[3]]
                    # x1, x2, y1, y2 = bbox
                    # x1 = max(x1, 0)
                    # x2 = min(x2, self.standard_width)
                    # y1 = max(y1, 0)
                    # y2 = min(y2, self.standard_height)
                    minX,maxX,minY,maxY = bbox
                # print(minX, maxX, minY, maxY )
            else:
                # print('no random_move')
                pass
                
            img_group_crop = img_group_crop[minX:maxX,minY:maxY]
            msk_group_crop = msk_group_crop[minX:maxX,minY:maxY]
            grp_msk_group_crop = grp_msk_group_crop[minX:maxX,minY:maxY]
            disk_group_crop = disk_group_crop[minX:maxX,minY:maxY]
            excentricity_group_crop = excentricity_group_crop[minX:maxX,minY:maxY]
            confidence_group_crop = confidence_group_crop[minX:maxX,minY:maxY]
            grp_confidence_group_crop = grp_confidence_group_crop[minX:maxX,minY:maxY]

        if img_group_crop.shape != (self.standard_height, self.standard_width):
            img_group_crop = self.padding(img_group_crop, self.standard_height, self.standard_width)
            msk_group_crop = self.padding(msk_group_crop, self.standard_height, self.standard_width)
            grp_msk_group_crop = self.padding(grp_msk_group_crop, self.standard_height, self.standard_width)
            disk_group_crop = self.padding(disk_group_crop, self.standard_height, self.standard_width)
            excentricity_group_crop = self.padding(excentricity_group_crop, self.standard_height, self.standard_width)
            confidence_group_crop = self.padding(confidence_group_crop, self.standard_height, self.standard_width)
            grp_confidence_group_crop = self.padding(grp_confidence_group_crop, self.standard_height, self.standard_width)


        assert img_group_crop.shape == (self.standard_height, self.standard_width)
        assert msk_group_crop.shape == (self.standard_height, self.standard_width)
        assert grp_msk_group_crop.shape == (self.standard_height, self.standard_width)
        assert disk_group_crop.shape == (self.standard_height, self.standard_width)
        assert excentricity_group_crop.shape == (self.standard_height, self.standard_width)
        assert confidence_group_crop.shape == (self.standard_height, self.standard_width)
        assert grp_confidence_group_crop.shape == (self.standard_height, self.standard_width)

        # self.print_elapsed_time(st, 'focusMove Operations')
        # print('after focus_move call',img_group_crop.shape)


        kwargs['image'] = img_group_crop.copy()
        kwargs['mask'] = msk_group_crop.copy()
        kwargs['group_mask'] = grp_msk_group_crop.copy()
        kwargs['solar_disk'] = disk_group_crop.copy()
        kwargs['excentricity_map'] = excentricity_group_crop.copy()
        kwargs['confidence_map'] = confidence_group_crop.copy()
        kwargs['group_confidence_map'] = grp_confidence_group_crop.copy()
        
        return kwargs

 


class ClassificationDatasetSuperclasses_fast(Dataset):

    def __init__(
        self, root_dir, partition, dtypes, classes,
        first_classes, second_classes, third_classes, 
        dataset_file, classification='SuperClass' , transforms=None) -> None:
        super().__init__()
        if isinstance(transforms, collections.Mapping):
            transforms = partial(call, config=transforms)
        elif isinstance(transforms, collections.Sequence):
            transforms_init = []
            for transform in transforms:
                transforms_init.append(instantiate(transform))
            transforms = Compose(transforms_init)

        self.transforms = transforms
        self.root_dir = Path(root_dir)

        self.main_dtype = dtypes[0]
        self.mask_dtype = dtypes[1]

        # self.json_file = self.root_dir / json_file 
        self.disk_file = self.root_dir / dataset_file.replace('.', '_'+partition+'.')
        # print(self.json_file)
        self.partition_dict = None

        self.c1_mapper = {c: i for i,c in enumerate(first_classes)}
        self.c2_mapper = {c: i for i,c in enumerate(second_classes)}
        self.c3_mapper = {c: i for i,c in enumerate(third_classes)}
        
        st = time.time()
        # print('Loading npy dataset')
        dataset = np.load(self.disk_file, allow_pickle=True).item()
        # print_elapsed_time(st, 'Loading npy dataset')
        
        # print(dataset.keys())
    
        # self.partition_dict = dataset[partition]
        self.partition_dict = dataset

        self.classification = classification

        self.FirstClass_mapper = {c: i for i,c in enumerate(first_classes)}
        self.SecondClass_mapper = {c: i for i,c in enumerate(second_classes)}
        self.ThirdClass_mapper = {c: i for i,c in enumerate(third_classes)}

        # print(self.files)

        # print(list(self.partition_dict.values())[0])

        self.groups = {}
        for k,v in self.partition_dict.items():

            if v["class1"] in classes:
                    self.groups[k] = v
            else:
                # print(v[self.classification]["1"])
                pass

        self.dataset_length = len(list(self.groups.keys()))


    def __len__(self) -> int:
        # raise NotImplementedError
        # print(self.dataset_length)
        return self.dataset_length
        # return 10
    
    def __getitem__(self, index: int, do_transform=True):

        idx = list(self.groups.keys())[index]
        
        sample = deepcopy(self.groups[idx])

        # ['image']
        # ['mask']
        # ['group_mask']
        # ['solar_disk']
        # ['excentricity_map']
        # ['confidence_map']
        # ['group_confidence_map']

        # print(sample['group_name'])

        sample['class1'] = np.array([self.FirstClass_mapper[sample['class1']]])
        sample['class2'] = np.array([self.SecondClass_mapper[sample['class2']]])
        sample['class3'] = np.array([self.ThirdClass_mapper[sample['class3']]])

       
        fig,ax = plt.subplots(2, 5, figsize=(10, 4) )
        ax[0,0].set_title(sample['group_name'])
        ax[0,0].imshow(sample["image"], cmap='gray', interpolation='none')
        ax[0,1].imshow(sample["mask"], cmap='gray', interpolation='none')
        ax[0,2].imshow(sample["confidence_map"], cmap='gray', interpolation='none')
        ax[0,3].imshow(sample["solar_disk"], cmap='gray', interpolation='none')
        ax[0,4].imshow(sample["excentricity_map"], cmap='gray', interpolation='none')


        tmp = np.argwhere(sample['confidence_map'] > 0)
        ax[0,2].scatter(tmp[:,1], tmp[:,0], s=1, c='r', alpha=0.5)

        sample['image'][sample['excentricity_map'] < 0] = 0
        sample['mask'][sample['excentricity_map'] < 0] = 0
        sample['group_mask'][sample['excentricity_map'] < 0] = 0
        sample['solar_disk'][sample['excentricity_map'] < 0] = 0
        sample['confidence_map'][sample['excentricity_map'] < 0] = 0
        sample['group_confidence_map'][sample['excentricity_map'] < 0] = 0
        sample['group_confidence_map'][sample['excentricity_map'] > 0.95] = 0
        
        sample['excentricity_map'][sample['excentricity_map'] < 0] = 0
        
        # st = time.time()
        if self.transforms is not None and do_transform:
            sample = self.transforms(**sample)
        # print_elapsed_time(st, 'transform')

        

        ax[1,0].imshow(sample["image"], cmap='gray', interpolation='none')
        ax[1,1].imshow(sample["mask"], cmap='gray', interpolation='none')
        ax[1,2].imshow(sample["confidence_map"], cmap='gray', interpolation='none')
        ax[1,3].imshow(sample["solar_disk"], cmap='gray', interpolation='none')
        ax[1,4].imshow(sample["excentricity_map"], cmap='gray', interpolation='none')
        
        fig.tight_layout()

        tmp = np.argwhere(sample['confidence_map'] > 0)
        ax[1,2].scatter(tmp[:,1], tmp[:,0], s=1, c='r', alpha=0.5)

        return sample
        

In [43]:

root_dir = "/globalscratch/users/n/s/nsayez/Classification_dataset/2002-2019_2"
partition='train'
dtypes = ['image', 'T425-T375-T325_fgbg']

classes = ['A','B','C','SuperGroup','H']
first_classes = [ 'A','B','C','SuperGroup','H']
second_classes= [ 'x','r','sym','asym']
third_classes= [ "x","o","frag"]
classification='SuperClasses'

transforms2 = [{'_target_': Deepsun_Focus_Move, 'standard_height': 256, 'standard_width': 256,
                 'focus_on_group': False, 'random_move': False, 'random_move_percent': .2}]
dataset_file = os.path.join(root_dir, 'test', f'all_samples.npy')
d2 = ClassificationDatasetSuperclasses_fast(root_dir, partition, dtypes, classes, first_classes, second_classes, third_classes, 
                                                dataset_file, classification, transforms2)
# for i in range(5):
#     d2[i]
# pass

In [44]:
d2[882]
# d2[2227]
# d2[2649]
pass

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
for i in tqdm(range(len(d2))):
    try:
        tmp = d2[i]
    except:
        print(tmp['image'].shape)
        print(i)
        pass

  0%|          | 0/2704 [00:00<?, ?it/s]

(256, 256)
882
(256, 256)
2227
(256, 256)
2649


In [6]:
# VAL avait des erruers aux indices 94 / 196
# d2[94]
# d2[96]
# TRAIN avait des erreurs aux indices 288 / 473 / 477 / 707 / 777 / 946 / 1130 / 1166 / 2326
# d2[288]
# d2[473]
# d2[477]
# d2[707]
# d2[777]
# d2[946]
# d2[1130]
# d2[1166]
# d2[2326]


pass

In [5]:
for i in range(10):
    d2[i]
pass

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# TEST

In [5]:

few_samples = [dataset[i] for i in range(2000,2010)]

Elapsed time transform: 0.4652729034423828
Elapsed time transform: 0.43789052963256836
Elapsed time transform: 0.4377424716949463
Elapsed time transform: 0.3622429370880127
Elapsed time transform: 0.35805320739746094
Elapsed time transform: 0.43010497093200684
Elapsed time transform: 0.3954179286956787
Elapsed time transform: 0.36830782890319824
Elapsed time transform: 0.3276553153991699
Elapsed time transform: 0.311431884765625


In [171]:
# # create a pickle file containing the samples in few_samples
# with open('few_samples.pkl', 'wb') as f:
#     pickle.dump(few_samples, f)

npy_file = 'few_samples.npy'

my_dict ={'train': {s["group_name"]: s for s in few_samples}}
np.save(npy_file, my_dict)


In [172]:
! mv few_samples.npy $root_dir
# np.load(npy_file, allow_pickle=True)