In [1]:
import sys
import os
import numpy as np
import pandas as pd
import random
sys.path.append('../source')
from utils import *

In [2]:

def image_filter(images: dict, n: int = 10):
    """Return a random subset of at most ``n`` entries from ``images``.

    Keys are sorted before sampling so that, for a fixed state of the global
    ``random`` module, the same subset is chosen regardless of dict insertion
    order.

    Args:
        images (dict): mapping of tile name to tile.
        n (int, optional): number of entries to keep. Defaults to 10. If fewer
            than ``n`` entries are available, all of them are kept instead of
            raising ``ValueError``.

    Returns:
        dict: the sampled ``{name: tile}`` entries (a new dict).
    """
    keys = random.sample(sorted(images.keys()), min(n, len(images)))
    return {k: images[k] for k in keys}

def mask_filter(masks: dict):
    """Keep only the masks whose ``get_percent_coverage`` value is exactly 0.

    Args:
        masks (dict): mapping of tile name to mask tile.

    Returns:
        dict: a new dict containing the retained ``{name: mask}`` entries. The
        input dict is left unmodified.
    """
    # Coverage is computed once per mask; keep-condition matches the original
    # "delete when coverage != 0" rule.
    return {
        key: mask
        for key, mask in masks.items()
        if get_percent_coverage(mask) == 0
    }

def _retile_all(paths: list, loader, patch_size: tuple, tile_filter) -> dict:
    """Load each file with ``loader``, retile it, optionally filter, and merge all tiles.

    Each file's tiles are named from that file's own basename via
    ``retile_and_name``. NOTE(review): tiles from files sharing a basename
    would collide, and later ones overwrite earlier ones.
    """
    tiles = {}
    for path in paths:
        basename = os.path.splitext(os.path.basename(path))[0]
        file_tiles = retile_and_name(loader(path), basename, patch_size)
        if tile_filter is not None:
            file_tiles = tile_filter(file_tiles)
        tiles.update(file_tiles)
    return tiles


def generate_lst_dataset(
    images: list, 
    masks: list, 
    patch_size: tuple, 
    image_filter,
    mask_filter,
    save_path: str,
    limit=5000
    ):
    """Generate an LST dataset by retiling images and masks and masking every pair.

    Images are loaded with ``get_lst_day`` and masks with ``get_cloud_mask``.
    Both are then split into named tiles with ``retile_and_name``. Every image
    tile is multiplied by every mask tile. Files written under ``save_path``:

    - ``ground_truth/<image_tile>.npy``: the unmasked image tiles
    - ``masks/<mask_tile>.npy``: the mask tiles
    - ``masked_images/<mask_tile>_<image_tile>.npy``: image * mask
    - ``collection.csv``: columns ``sample`` (masked filename) and
      ``ground_truth`` (image filename)

    Args:
        images (list): paths of hdf images to retile
        masks (list): paths of hdf masks to retile
        patch_size (tuple): size of patches
        image_filter (function or None): takes and returns a ``{name: tile}``
            dict; applied to each image file's tiles
        mask_filter (function or None): takes and returns a ``{name: tile}``
            dict; applied to each mask file's tiles
        save_path (str): directory to save images and masked images; must be
            missing or empty
        limit (int, optional): set limit of images to save. Defaults to 5000.

    Raises:
        AssertionError: if the number of image x mask pairs exceeds ``limit``,
            if either tile set is empty, if ``save_path`` already contains
            files, or if a masked tile's shape differs from its image tile.
    """
    # Accumulate tiles across ALL input files (each file named by its own path).
    lst_tiles = _retile_all(images, get_lst_day, patch_size, image_filter)
    mask_tiles = _retile_all(masks, get_cloud_mask, patch_size, mask_filter)

    n_total = len(mask_tiles) * len(lst_tiles)
    assert n_total <= limit, 'Limit Exceeded ({} total images)'.format(n_total)
    assert len(mask_tiles) > 0 and len(lst_tiles) > 0, 'Either masks or images are empty'
    # Refuse to write into a populated directory so earlier outputs never mix in.
    assert not os.path.isdir(save_path) or not os.listdir(save_path), 'Save path not empty'

    gt_dir = os.path.join(save_path, 'ground_truth')
    masked_dir = os.path.join(save_path, 'masked_images')
    mask_dir = os.path.join(save_path, 'masks')
    for directory in (gt_dir, masked_dir, mask_dir):
        os.makedirs(directory, exist_ok=True)

    # Each mask is saved exactly once, independent of how many images it is paired with.
    for mask_name, mask in mask_tiles.items():
        np.save(os.path.join(mask_dir, mask_name + '.npy'), mask)

    records = []
    for image_name, image in lst_tiles.items():
        image_filename = image_name + '.npy'
        np.save(os.path.join(gt_dir, image_filename), image)
        for mask_name, mask in mask_tiles.items():
            masked_lst = np.multiply(image, mask)
            # Guard against np.multiply silently broadcasting mismatched tile shapes.
            assert masked_lst.shape == np.shape(image), (
                'Masked tile {}_{} has shape {}, expected {}'.format(
                    mask_name, image_name, masked_lst.shape, np.shape(image)))
            masked_lst_filename = mask_name + '_' + image_name + '.npy'
            np.save(os.path.join(masked_dir, masked_lst_filename), masked_lst)
            records.append((masked_lst_filename, image_filename))

    # Build the frame once instead of growing it row by row.
    collection = pd.DataFrame(records, columns=['sample', 'ground_truth'])
    collection.to_csv(os.path.join(save_path, 'collection.csv'))

In [3]:
data_path = '../data/'
raw_dir = os.path.join(data_path, 'LST_miniset', 'raw')
# os.listdir order is arbitrary; sort so the same file is picked on every run.
sample = [os.path.join(raw_dir, sorted(os.listdir(raw_dir))[0])]

In [4]:
# image_filter samples tiles with the global `random` module; seed it so the
# generated dataset is reproducible under Restart & Run All.
random.seed(0)
# image_filter is also used as the mask filter so that 10 random mask tiles are
# drawn as well (the checks below expect 10 x 10 = 100 masked images).
generate_lst_dataset(
    images=sample, 
    masks=sample, 
    patch_size=(36, 36), 
    image_filter=image_filter, 
    mask_filter=image_filter,
    save_path=os.path.join(data_path, 'test')
    )

In [30]:
# 10 sampled image tiles x 10 sampled mask tiles -> 100 masked patches expected.
assert 100 == len(os.listdir(
    os.path.join(data_path, 'test', 'masked_images')
))

In [11]:
# index_col=0: the first CSV column is the DataFrame index written by to_csv.
ds = pd.read_csv(
    os.path.join(data_path, 'test', 'collection.csv'),
    index_col=0,
)

In [22]:
# Every masked-image filename must be distinct, and exactly 10 distinct
# ground-truth tiles must appear in the collection.
assert np.unique(ds['sample']).size == ds['sample'].size, 'Fail'
assert np.unique(ds['ground_truth']).size == 10, 'Fail'

10
