In [1]:
from email.mime import base
from tokenize import group
from sunscc.dataset.transform.pipelines import Compose
import collections
from functools import partial

from torch.utils.data import Dataset, DataLoader
import os
import random
from hydra.utils import call, instantiate
from pathlib import Path
import skimage.io as io
import skimage
from skimage.measure import label, regionprops
import skimage.exposure
from numpy.random import default_rng

import json

import numpy as np

from tqdm.notebook import tqdm

import traceback
import time

from albumentations.core.transforms_interface import DualTransform

from astropy.io import fits
from sunscc.dataset.utils import *

from omegaconf import DictConfig, OmegaConf
import matplotlib.pyplot as plt


from datetime import datetime, timedelta

from copy import deepcopy

import pickle

import concurrent.futures
from itertools import repeat
import multiprocessing

from numpy.random import default_rng

%matplotlib widget
# %matplotlib inline

In [2]:
def print_elapsed_time(st, msg):
    """Print how many seconds have passed since the timestamp ``st``.

    Args:
        st: start time, as returned by ``time.time()``.
        msg: label inserted into the printed line.
    """
    elapsed = time.time() - st
    print(f'Elapsed time {msg}: {elapsed}')

class ClassificationDatasetSuperclasses(Dataset):
    """Sunspot-group classification dataset built from a partition JSON file.

    Each item is one sunspot group: a key of ``json[partition]`` whose
    ``[classification]['1']`` class is listed in ``classes``. ``__getitem__``
    loads the full-disk FITS image of the observation the group belongs to,
    its mask, confidence map and solar-disk mask, and attaches the group
    metadata and classes from the JSON.

    Files are looked up under ``root_dir`` as::

        <dtypes[0]>/<basename>.FTS
        <dtypes[1]>/<basename>.png
        <dtypes[1]>/<basename>_proba_map.npy
        sun_mask/<basename>.png

    where ``<basename>`` is the part of the group key before the first ``_``.
    """

    # Observations strictly before this instant have their arrays flipped along axis 0.
    FLIP_TIME = datetime.fromisoformat("2003-03-08T00:00:00")

    # Full-disk arrays that are upsampled / flipped together (also the plotting column order).
    _VISUAL_KEYS = ('image', 'mask', 'confidence_map', 'solar_disk', 'excentricity_map')

    def __init__(
        self, root_dir, partition, dtypes, classes,
        first_classes, second_classes, third_classes,
        json_file, classification='SuperClass', transforms=None,
        debug_plot=False) -> None:
        """
        Args:
            root_dir: dataset root directory.
            partition: top-level key of the JSON file to read (e.g. 'train').
            dtypes: sequence whose first two entries are the sub-directory names
                of the images and of the masks / confidence maps.
            classes: accepted values of ``group[classification]['1']``; groups with
                any other value are dropped.
            first_classes, second_classes, third_classes: ordered class
                vocabularies, turned into ``name -> index`` mappers.
            json_file: path of the partition JSON, relative to ``root_dir``.
            classification: one of 'Zurich', 'McIntosh', 'SuperClass'.
            transforms: None; a mapping (wrapped with hydra ``call``); or a sequence
                of hydra configs, each instantiated and chained with ``Compose``.
            debug_plot: if True, ``__getitem__`` draws a before/after figure for
                1024x1024 images that get upsampled. Off by default because every
                call would otherwise create a new, never-closed matplotlib figure.

        Raises:
            AssertionError: if ``classification`` is not a supported scheme.
        """
        super().__init__()
        # Fail fast, before touching the filesystem.
        assert classification in ('Zurich', 'McIntosh', 'SuperClass'), \
            f'Unsupported classification scheme: {classification!r}'

        if isinstance(transforms, collections.abc.Mapping):
            transforms = partial(call, config=transforms)
        elif isinstance(transforms, collections.abc.Sequence):
            transforms = Compose([instantiate(t) for t in transforms])

        self.transforms = transforms
        self.root_dir = Path(root_dir)
        self.main_dtype = dtypes[0]
        self.mask_dtype = dtypes[1]
        self.json_file = self.root_dir / json_file
        self.classification = classification
        self.debug_plot = debug_plot

        # Both spellings of the class -> index mappers are kept for existing callers.
        self.c1_mapper = {c: i for i, c in enumerate(first_classes)}
        self.c2_mapper = {c: i for i, c in enumerate(second_classes)}
        self.c3_mapper = {c: i for i, c in enumerate(third_classes)}
        self.FirstClass_mapper = dict(self.c1_mapper)
        self.SecondClass_mapper = dict(self.c2_mapper)
        self.ThirdClass_mapper = dict(self.c3_mapper)

        with open(self.json_file, 'r') as f:
            self.partition_dict = json.load(f)[partition]

        # One record per observation; several group keys can share the same basename.
        self.files = {}
        for key in sorted(self.partition_dict):
            bn = key.split('_')[0]
            self.files[bn] = self._file_record(bn)

        self.groups = {
            k: v for k, v in self.partition_dict.items()
            if v[self.classification]["1"] in classes
        }
        # Sorted once here instead of on every __getitem__ call.
        self._group_keys = sorted(self.groups)
        self.dataset_length = len(self._group_keys)

    def _file_record(self, bn):
        """Return the dict of file paths that belong to observation ``bn``."""
        mask_dir = self.root_dir / self.mask_dtype
        return {
            "name": bn,
            self.main_dtype: self.root_dir / self.main_dtype / (bn + '.FTS'),
            self.mask_dtype: mask_dir / (bn + '.png'),
            self.mask_dtype + "_conf_map": mask_dir / (bn + '_proba_map.npy'),
            "sun_mask": self.root_dir / 'sun_mask' / (bn + '.png'),
        }

    def __len__(self) -> int:
        """Number of groups kept after class filtering."""
        return self.dataset_length

    def __getitem__(self, index: int, do_transform=True):
        """Load the group at position ``index`` (groups ordered by sorted key).

        Before transforms, the sample dict holds: 'solar_disk',
        'excentricity_map', 'mask', 'confidence_map', 'image', 'members',
        'members_mean_px', 'name', 'group_name', 'solar_angle', 'deltashapeX',
        'deltashapeY', 'angular_excentricity', 'centroid_px', 'centroid_Lat',
        'class1', 'class2', 'class3' and 'should_flip'.

        1024x1024 images (and their companion arrays) are upsampled x2 by pixel
        repetition, with 'deltashapeX', 'deltashapeY' and 'centroid_px' doubled.
        Observations before ``FLIP_TIME`` have their arrays flipped along axis 0.
        NOTE(review): 'centroid_px' is not flipped here; presumably downstream
        transforms use 'should_flip' to account for it. Confirm.

        Returns:
            The sample dict, or whatever ``self.transforms(**sample)`` returns
            when transforms are set and ``do_transform`` is True.
        """
        k = self._group_keys[index]
        basename = k.split('_')[0]
        group_dict = self.groups[k]
        record = self.files[basename]

        # Read everything needed inside the context manager so the file handle is
        # released; astype(float) makes an in-memory copy of the pixel data.
        with fits.open(record[self.main_dtype]) as hdulst:
            primary = hdulst[0]
            radius = primary.header['SOLAR_R']
            image = primary.data.astype(float)

        sample = {}
        sample['solar_disk'] = io.imread(record["sun_mask"])
        sample['excentricity_map'] = create_excentricity_map(sample['solar_disk'], radius, value_outside=-1)
        sample['mask'] = io.imread(record[self.mask_dtype])
        sample['confidence_map'] = np.load(record[self.mask_dtype + "_conf_map"])
        sample['image'] = image

        sample['members'] = np.array(group_dict.get('members', [0]))
        sample['members_mean_px'] = np.array(group_dict.get('members_mean_px', [0]))

        sample['name'] = basename
        sample['group_name'] = k

        sample['solar_angle'] = group_dict['angle']
        sample['deltashapeX'] = group_dict['deltashapeX']
        sample['deltashapeY'] = group_dict['deltashapeY']

        sample['angular_excentricity'] = np.array([group_dict["angular_excentricity_deg"]])
        sample['centroid_px'] = np.array(group_dict["centroid_px"])
        sample['centroid_Lat'] = np.array([group_dict["centroid_Lat"]])

        group_classes = group_dict[self.classification]
        sample['class1'] = group_classes['1']
        sample['class2'] = group_classes['2']
        sample['class3'] = group_classes['3']

        if sample["image"].shape == (1024, 1024):
            if self.debug_plot:
                fig, ax = plt.subplots(2, 5, figsize=(10, 4))
                self._plot_visual_data(ax[0], sample)
            self._upsample_sample(sample, factor=2)
            if self.debug_plot:
                self._plot_visual_data(ax[1], sample)
                fig.tight_layout()

        sample['should_flip'] = self._should_flip(basename)
        if sample['should_flip']:
            for key in self._VISUAL_KEYS:
                sample[key] = np.flip(sample[key], axis=0)

        if self.transforms is not None and do_transform:
            sample = self.transforms(**sample)

        return sample

    @classmethod
    def _upsample_sample(cls, sample, factor=2):
        """Upsample the visual arrays in place by pixel repetition (``factor`` on both axes).

        Pixel-based metadata ('deltashapeX', 'deltashapeY', 'centroid_px') is
        scaled by the same factor so it stays aligned with the arrays.
        """
        for key in cls._VISUAL_KEYS:
            sample[key] = np.repeat(np.repeat(sample[key], factor, axis=0), factor, axis=1)
        sample['deltashapeX'] = sample['deltashapeX'] * factor
        sample['deltashapeY'] = sample['deltashapeY'] * factor
        sample['centroid_px'] = sample['centroid_px'] * factor

    @classmethod
    def _plot_visual_data(cls, axes, sample):
        """Draw each visual array on one axis of ``axes`` and mark the group centroid on the image."""
        for axis, key in zip(axes, cls._VISUAL_KEYS):
            axis.imshow(sample[key], cmap='gray', interpolation='none')
        axes[0].scatter(sample['centroid_px'][0], sample['centroid_px'][1], c='r', s=10)

    def _should_flip(self, basename):
        """Return True when the observation ``basename`` is strictly before ``FLIP_TIME``."""
        date = whitelight_to_datetime(basename)
        # The DB string separates date and time with a space; swap it for 'T' to get ISO format.
        datetime_str = datetime_to_db_string(date).replace(' ', 'T')
        return datetime.fromisoformat(datetime_str) < self.FLIP_TIME
        

# Create NPY files for faster dataset loading during training and test

For each dataset variant (`sunscc_all_revised` / `sunscc_overlaps_only` / `sunscc_no_overlap`):

Set `sub_dir` and `partition` in the next cell, then run all the following cells to create the `.npy` file for that split (train/val/test).



In [3]:
# Root of the classification datasets and the dataset variant to export.
root_dir = "../../datasets/classification/2002-2019_2"

sub_dir = 'sunscc_all_revised'
# sub_dir = 'sunscc_OverlapOnly'

# Split to export: run the notebook once per split.
partition = 'train'
# partition = 'val'
# partition = 'test'

dtypes = ['image', 'T425-T375-T325_fgbg']

classes = ['A','B','C','SuperGroup','H']
first_classes = [ 'A','B','C','SuperGroup','H']
second_classes= [ 'x','r','sym','asym']
third_classes= [ "x","o","frag"]
json_file = f'{sub_dir}/dataset_revised.json'

classification = 'SuperClass'
transforms = OmegaConf.load('../../sunscc/conf/exp/Classification_Superclasses4.yaml').dataset.train_dataset.transforms

# Override the crop size of the rotate-and-crop transform. It is located by its
# `_target_` rather than by a hard-coded list position, so a reordered YAML cannot
# silently patch the wrong transform.
CROP_TARGET = 'sunscc.transforms.DeepsunRotateAndCropAroundGroup_Focus_Move'
CROP_SIZE = 350
crop_indices = [i for i, t in enumerate(transforms) if t.get('_target_') == CROP_TARGET]
assert len(crop_indices) == 1, f'expected exactly one {CROP_TARGET} transform, found {len(crop_indices)}'
transforms[crop_indices[0]].standard_height = CROP_SIZE
transforms[crop_indices[0]].standard_width = CROP_SIZE

print(transforms)

dataset =  ClassificationDatasetSuperclasses(root_dir, partition, dtypes, classes, first_classes, second_classes, third_classes, json_file, classification, transforms)

[{'_target_': 'sunscc.transforms.DeepsunScaleWhitelight'}, {'_target_': 'sunscc.transforms.DeepsunScaleExcentricityMap'}, {'_target_': 'sunscc.transforms.DeepsunScaleConfidenceMap'}, {'_target_': 'sunscc.transforms.DeepsunRotateAndCropAroundGroup_Focus_Move', 'standard_height': 350, 'standard_width': 350, 'focus_on_group': '${focus_on_group}', 'random_move': '${random_move}', 'random_move_percent': '${random_move_percent}'}, {'_target_': 'sunscc.transforms.DeepsunMcIntoshScaleAdditionalInfo'}]


In [4]:
# Collected samples keyed by group name (the first sample seen for each group is kept).
partition_samples = dict()

In [5]:
def get_sample(idx, dataset, center_size=256):
    """Load ``dataset[idx]`` and return it, or None when its centre is empty.

    The check sums the confidence map over a ``center_size`` x ``center_size``
    window around the centre of the (transformed) sample. Samples whose window
    sums to 0 are reported and skipped.

    The sample is loaded only once, so the returned sample is exactly the one
    that was checked. If the transforms are random (e.g. ``random_move``), a
    second ``dataset[idx]`` call could produce a different crop.
    """
    sample = dataset[idx]
    half = center_size // 2
    cx, cy = sample['image'].shape[0] // 2, sample['image'].shape[1] // 2
    center_region = sample['confidence_map'][cx - half:cx + half, cy - half:cy + half]
    if np.sum(center_region) == 0:
        print('skipping sample', idx)
        return None
    return sample

# Never request more workers than there are CPUs.
num_cpu = min(15, os.cpu_count() or 1)


with concurrent.futures.ProcessPoolExecutor(max_workers=int(num_cpu)) as executor:
    # Every task receives the same dataset copy. With chunksize > 1, pickle's memo
    # serialises that shared object roughly once per chunk instead of once per sample.
    results = executor.map(get_sample, range(len(dataset)), repeat(deepcopy(dataset)), chunksize=16)
    for sample in tqdm(results, total=len(dataset)):
        if sample is not None and sample['group_name'] not in partition_samples:
            partition_samples[sample['group_name']] = sample


0it [00:00, ?it/s]

skipping sample 74
skipping sample 489
skipping sample 488
skipping sample 739
skipping sample 743
skipping sample 751
skipping sample 822
skipping sample 1031
skipping sample 1038
skipping sample 1068
skipping sample 1141
skipping sample 1367
skipping sample 1583
skipping sample 1632
skipping sample 3205


In [6]:
# Output path: <root_dir>/<sub_dir>/all_samples_<partition>.npy
npy_file = os.path.join(
    root_dir,
    sub_dir,
    f'all_samples_{partition}.npy',
)

# np.save stores the dict as a 0-d object array (pickled), so read it back with
# np.load(npy_file, allow_pickle=True).item().
np.save(npy_file, partition_samples)