In [None]:
import os
import numpy as np
import pickle
import rasterio
from collections import Counter
import json

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def get_grid_nums(home, country, source):
    """
    Lists the grid ids of the '.npy' files stored for one country and source.

    The id of a file is the text after the last underscore of its file name,
    with the '.npy' suffix removed (the whole stem when the name contains no
    underscore).

    Args:
      home - (str) the base directory of data

      country - (str) name of the country directory inside home

      source - (str) name of the source directory inside the country directory

    Returns:
      (list of str) grid ids, sorted as strings (lexicographic, not numeric)
    """
    suffix = '.npy'
    cur_path = os.path.join(home, country, source)
    # Parse the id from the bare file name, not the joined path: directory
    # names may themselves contain underscores and must not leak into the id.
    grid_numbers = [
        fname[:-len(suffix)].split('_')[-1]
        for fname in os.listdir(cur_path)
        if fname.endswith(suffix)
    ]
    grid_numbers.sort()
    return grid_numbers

In [None]:
# Find the grid ids that are present for only one of the two Ghana sources.
home = '/home/data'

# get_grid_nums already returns each list sorted, so no extra sort is needed.
s1_grid_numbers = get_grid_nums(home, 'Ghana', 's1_64x64_npy')
s2_grid_numbers = get_grid_nums(home, 'Ghana', 's2_64x64_npy')

# Sets give constant-time membership instead of scanning the lists per id.
s1_grid_set = set(s1_grid_numbers)
s2_grid_set = set(s2_grid_numbers)
all_grid_nums = s1_grid_set | s2_grid_set

# Symmetric difference = ids missing from exactly one source; sorted so the
# printed output is stable between runs.
only_in_one = sorted(s1_grid_set ^ s2_grid_set)

print(only_in_one)

In [None]:
# Re-sort the s2 ids in place, then show the s1 ids.
# NOTE(review): the list sorted here (s2) is not the one printed (s1) -- confirm which was intended.
s2_grid_numbers[:] = sorted(s2_grid_numbers)
print(s1_grid_numbers)

In [None]:
def get_empty_grids(home, countries, sources, verbose):
    """
    Counts the positive pixels of every mask tile and reports which source
    tiles share a grid id with an empty mask.

    For each country, every '.tif' file in <home>/<country>/raster_64x64 is
    opened with rasterio and the number of values > 0 in the array returned by
    read() is recorded. A mask with a count of 0 is "empty". The grid ids of
    the '.tif' files found in the given source directories are then intersected
    with that country's empty mask ids to build the "delete me" list.

    Mask ids are the text after the last underscore of the mask file name
    (without '.tif'); source grid ids are the second-to-last underscore
    separated field of the source file name.

    Args:
      home - (str) the base directory of data

      countries - (list of str) list of strings that point to the directory names
                  of the different countries (i.e. ['Ghana', 'Tanzania', 'SouthSudan'])

      sources - (list of str) directory names, inside each country directory,
                whose '.tif' tiles are checked against the empty masks
                (i.e. ['s1_64x64', 's2_64x64']). Ids from all listed sources
                are pooled per country. May be empty.

      verbose - (boolean) prints outputs from function

    Returns:
      (list of tuple) one (mask_id, valid_pixels) pair per mask tile,
      accumulated over all countries in the order they were processed
    """

    mask_pixels_list = []
    for country in countries:
        mask_dir = os.path.join(home, country, 'raster_64x64')
        mask_names = sorted(f for f in os.listdir(mask_dir) if f.endswith('.tif'))

        # Empty masks are tracked per country so ids from one country can never
        # flag another country's tiles for deletion.
        empty_masks = []
        for mask_name in mask_names:
            # Derive the id from this very file name; sorting names and ids
            # independently and zipping them could pair a mask with the wrong id.
            mask_id = mask_name[:-len('.tif')].split('_')[-1]
            with rasterio.open(os.path.join(mask_dir, mask_name)) as src:
                cur_mask = src.read()
            valid_pixels = np.sum(cur_mask > 0)
            mask_pixels_list.append((mask_id, valid_pixels))
            if valid_pixels == 0:
                empty_masks.append(mask_id)

        # Pool grid ids over every requested source (an empty `sources` simply
        # yields no ids instead of failing on an undefined variable).
        grid_numbers = set()
        for source in sources:
            cur_path = os.path.join(home, country, source)
            for fname in os.listdir(cur_path):
                if not fname.endswith('.tif'):
                    continue
                # Split the bare file name so underscores in directory names
                # cannot shift which field is picked.
                fields = fname.split('_')
                if len(fields) >= 2:
                    grid_numbers.add(fields[-2])

        # Source tiles whose mask has no positive pixel.
        delete_me = sorted(grid_numbers.intersection(empty_masks))

        if verbose:
            print('valid pixels list: ', len(mask_pixels_list))
            print('empty masks: ', len(empty_masks))
            print('delete me: ', len(delete_me))
            print('delete me: ', delete_me)

    return mask_pixels_list


In [None]:
# Run the empty-mask check once per source directory (s2 first, then s1),
# separating the two reports with a dashed line. After the cell has run,
# `sources` and `mask_pixels_list` hold the values of the last (s1) run.
home = '/home/data'
countries = ['Ghana']
verbose = 1

for run_idx, source_name in enumerate(('s2_64x64', 's1_64x64')):
    if run_idx > 0:
        print('--------------------------')
    sources = [source_name]
    mask_pixels_list = get_empty_grids(home, countries, sources, verbose)

In [None]:
# Turn the list of (mask_id, valid_pixels) pairs into an array and show its shape.
mask_pixels_arr = np.array(list(mask_pixels_list))
mask_pixels_arr.shape

In [None]:
# Second column holds the per-tile pixel counts; keep only the non-zero ones.
all_pixel_counts = mask_pixels_arr[:, 1].astype(int)
has_pixels = all_pixel_counts != 0
mask_pix_sub = all_pixel_counts[has_pixels]
mask_pix_sub.shape

In [None]:
# Distribution of the non-zero pixel counts, drawn through an explicit Axes.
fig, ax = plt.subplots()
ax.hist(mask_pix_sub, bins=50)
ax.set_title("Histogram for # of valid pixels in each grid")
plt.show()

In [None]:
# For each threshold, count the tiles whose pixel count is above zero but
# below the threshold.
valid_pix_numbers = mask_pixels_arr[:, 1].astype(int)
is_non_empty = valid_pix_numbers > 0
for threshold in (10, 20, 30, 40, 50, 100):
    below_threshold = valid_pix_numbers < threshold
    print(f'Less than {threshold} pixels: ', np.sum(below_threshold & is_non_empty))