In [2]:
import ee
import geemap
import pandas as pd
import geopandas as gpd
from shapely.geometry import Polygon
import matplotlib.pyplot as plt
import numpy as np

import geopandas as gpd
from shapely.geometry import box
from datetime import datetime, timedelta

import os
from os import path as osp
from tqdm.auto import tqdm
import math

import geemap.colormaps as cm

# Run the Earth Engine authentication flow.
ee.Authenticate()

# Initialize the Earth Engine client library for the project used by this notebook.
EE_PROJECT = 'ee-georgethulsey'
ee.Initialize(project=EE_PROJECT)

#### aux functions

In [3]:
def replace_values_with_null(image, values_to_replace):
    """
    Mask out every pixel of an Earth Engine image that equals one of the given values.

    :param image: ee.Image object
    :param values_to_replace: Iterable of values whose pixels should become null (masked)
    :return: ee.Image whose mask has been updated so the listed values are hidden
    """
    # Accumulate "pixel equals one of the values" starting from an all-zero image.
    matches_any = ee.Image.constant(0)
    for target in values_to_replace:
        is_target = image.eq(target)
        matches_any = matches_any.Or(is_target)

    # Flip it: 1 marks pixels to keep, 0 marks pixels to hide.
    keep = matches_any.Not()

    return image.updateMask(keep)

def feature_collection_to_arrays(feature_collection,num = 5):
    """Fetch features client-side and convert their properties to NumPy arrays.

    Args:
        feature_collection: An EE feature collection.
        num: Maximum number of features to fetch. It is passed as the `count`
          argument of `toList`, which the original call omitted.

    Returns:
        A dict keyed by feature id. Each value is a dict mapping a property
        name to `np.array(<property value>)`. The properties 'system:index',
        'date', 'geometry' and 'detection' are skipped.
    """
    # `toList` needs an explicit element count; `getInfo` pulls the result
    # to the client as plain Python dicts.
    features = feature_collection.toList(num).getInfo()
    properties_to_ignore = {'system:index', 'date', 'geometry', 'detection'}

    results = {}
    for feature in tqdm(features):
        results[feature['id']] = {
            name: np.array(value)
            for name, value in feature['properties'].items()
            if name not in properties_to_ignore
        }

    return results

def plot_dict_grid(info):
    """Show every item of a dict as an image in a near-square grid of subplots.

    Args:
        info: Dict mapping a title to a value accepted by `Axes.imshow`.
          An empty dict draws nothing and returns immediately (the original
          raised ZeroDivisionError when computing the grid).

    Returns:
        None. The figure is displayed with `plt.show()`.
    """
    n = len(info)
    if n == 0:
        # Nothing to plot; also avoids dividing by a zero column count below.
        return

    # Grid dimensions: as square as possible, with enough cells for n items.
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)

    # squeeze=False always yields a 2-D array of axes, so a single subplot
    # needs no special casing before flattening.
    fig, axs = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
    fig.suptitle('Grid Plot of Dictionary Items', fontsize=16)
    axs = axs.flatten()

    # Plot each item with its own colorbar.
    for ax, (key, value) in zip(axs, info.items()):
        im = ax.imshow(value, cmap='viridis')
        ax.set_title(key)
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Remove the grid cells left over after the last item.
    for unused_ax in axs[n:]:
        fig.delaxes(unused_ax)

    plt.tight_layout()
    plt.show()

### Load FBFM image from Google Earth Engine

This asset must be uploaded via the GEE command line interface. 

In [4]:
# Load the uploaded asset as an image whose band is renamed to 'fbfm', then
# mask the listed values so they are treated as null.
landfire = replace_values_with_null(
    ee.Image("projects/ee-georgethulsey/assets/landfire/landfire2019").rename('fbfm'),
    [32767, -32768, -9999],
)

In [8]:
# Read the metadata table and discard the row labelled 0, without mutating in place.
lf_meta = pd.read_csv('fuel_autoencoder/landfire_metadata.csv').drop(index=0)

# Lookup from each 'VALUE' entry to its 'FBFM40' entry.
landfire_fuel_classes = dict(zip(lf_meta['VALUE'], lf_meta['FBFM40']))

In [9]:
# Visualization parameters for the FBFM layer.
palette = cm.palettes.dem
viz = dict(min=91, max=204, palette=palette)

# Interactive map with the FBFM image added as a layer.
Map = geemap.Map(center=[39, -120.5], zoom=7)
Map.addLayer(landfire, viz, 'FBFM')

Map

Map(center=[39, -120.5], controls=(WidgetControl(options=['position', 'transparent_bg'], widget=SearchDataGUI(…

### Prepare export functions

In [10]:
def _verify_feature_collection(
    feature_collection
):
    """Checks that a feature collection can be evaluated and reports its size.

    The size is fetched with `size().getInfo()`. If that raises
    `ee.EEException`, the collection is treated as invalid and is replaced by
    a new empty feature collection.

    Args:
        feature_collection: An EE feature collection.

    Returns:
        `(feature_collection, size)`: the input collection and its size as an
        int, or `(<empty feature collection>, 0)` when the size could not be
        fetched.
    """
    try:
        n_features = int(feature_collection.size().getInfo())
    except ee.EEException:
        # Invalid collection: hand back a fresh, empty one.
        return ee.FeatureCollection([]), 0
    return feature_collection, n_features

def extract_samples(
    image,
    geometry,
    sampling_limit_per_call=60,
    resolution=30,
    seed=123,
    numPixels = 8200, # default 16^2*32 = 8192 < 8200
    n_samplings = 100,
    mode = 'train'
):
    """
    Samples an EE image with repeated `image.sample` calls and merges the results.

    `ceil(n_samplings / sampling_limit_per_call)` calls are issued. Call `i`
    (starting at 0) uses seed `seed + i`. The original passed the same seed to
    every call; as its own comment noted, the sampling is deterministic, so
    identical arguments merged the same draw over and over.

    Args:
        image: The EE image to extract samples from.
        geometry: The EE geometry over which to sample.
        sampling_limit_per_call: Divisor used to turn `n_samplings` into the
          number of `image.sample` calls.
        resolution: The resolution in meters at which to scale.
        seed: Seed of the first sampling call; later calls use `seed + 1`,
          `seed + 2`, ...
        numPixels: Passed to each `image.sample` call as `numPixels`.
        n_samplings: Together with `sampling_limit_per_call`, determines how
          many sampling calls are made.
        mode: Currently unused by this function; kept for callers that pass it.

    Returns:
        An EE feature collection with all the extracted samples, carrying
        `geometry` under the 'overall_geometry' property.
    """
    n_calls = math.ceil(n_samplings / sampling_limit_per_call)

    feature_collection = ee.FeatureCollection([])
    for call_index in range(n_calls):
        samples = image.sample(
            region=geometry,
            numPixels=numPixels,
            scale=resolution,
            # A distinct seed per call so successive calls do not repeat the same draw.
            seed=seed + call_index,
        )
        feature_collection = feature_collection.merge(samples)

    # Add overall geometry as metadata to the feature collection
    return feature_collection.set('overall_geometry', geometry)

In [11]:
def _export_dataset(
    bucket,
    folder,
    prefix,
    geometry,
    kernel_size,
    sampling_scale,
    num_samples_per_file,
    n_samplings
):
    """Samples the FBFM image and exports TFRecord files for train/test/val.

    Reads the module-level `landfire` image, converts it with
    `ee_utils.convert_features_to_arrays`, then samples and exports it once
    per mode ('train', 'test', 'val').

    NOTE(review): `ee_utils` is not imported in the visible notebook cells;
    it must already be available in the kernel — confirm and add the import.

    Changes relative to the original implementation:
      * Each mode starts from its own seed, so the three splits are no longer
        produced by identical sampling calls.
      * `sampling_scale` is now forwarded as the sampling resolution
        (previously a hard-coded 30 was used and the argument ignored).
      * The result of the "export the remainder" step is kept, so the
        collection and file counter are reset before the next mode instead of
        carrying one mode's samples into the next.

    Args:
        bucket: Google Cloud bucket
        folder: Folder to which to export the TFRecords.
        prefix: Export file name prefix.
        geometry: EE geometry from which to export the data.
        kernel_size: Size of the exported tiles (square).
        sampling_scale: Resolution at which to sample the data (in meters).
        num_samples_per_file: Size threshold above which the collected
          samples are exported in the first export attempt of each mode.
        n_samplings: Sampling budget for the 'train' mode; 'test' and 'val'
          use `int(n_samplings * 0.3)`. Passed to `extract_samples`.
    """
    #########################################################
    def _verify_and_export_feature_collection(
        num_samples_per_export,
        feature_collection,
        file_count,
        features,
        mode = 'train',
    ):
        """Wraps the verification and export of the feature collection.

        Verifies the size of the feature collection and triggers the export when
        it is larger than `num_samples_per_export`. Resets the feature collection
        and increments the file count at each export.

        Args:
          num_samples_per_export: Approximate number of samples per export.
          feature_collection: The EE feature collection to export.
          file_count: The TFRecord file count for naming the files.
          features: Names of the features to export.
          mode: Split name appended to the export description.

        Returns:
          `(feature_collection, file_count)` tuple of the current feature collection
            and file count.
        """
        feature_collection, size_count = _verify_feature_collection(
            feature_collection)
        if size_count > num_samples_per_export:
            ee_utils.export_feature_collection(
                feature_collection,
                description=prefix + '_{:03d}'.format(file_count)+'_'+mode,
                bucket=bucket,
                folder=folder,
                bands=features,
            )
            file_count += 1
            feature_collection = ee.FeatureCollection([])
        return feature_collection, file_count
    ############################################################

    sampling_limit_per_call = 60
    # Seed for the first mode; advanced after every mode (see the loop below).
    seed = 123

    file_count = 0
    feature_collection = ee.FeatureCollection([])

    image_list = [landfire]
    features = ['fbfm']

    arrays = ee_utils.convert_features_to_arrays(image_list, kernel_size)

    for (mode,factor) in [('train',1),('test',0.3),('val',0.3)]:
        print("Sampling in mode",mode)
        mode_samplings = int(n_samplings*factor)
        samples = extract_samples(
            arrays,
            geometry,
            sampling_limit_per_call=sampling_limit_per_call,
            resolution=sampling_scale,
            seed=seed,
            numPixels = 8200, # default 16^2*32 = 8192 < 8200
            n_samplings = mode_samplings,
            mode = mode
        )
        print('samples extracted')

        # Move the seed past the number of sampling calls this mode makes, so
        # the next mode never starts from a seed that was already used here.
        seed += max(1, math.ceil(mode_samplings / sampling_limit_per_call))

        feature_collection = feature_collection.merge(samples)

        feature_collection, file_count = _verify_and_export_feature_collection(
            num_samples_per_file, feature_collection, file_count, features,
            mode = mode)
        # Export the remaining feature collection for this mode, keeping the
        # returned state so nothing leaks into the next mode's export.
        feature_collection, file_count = _verify_and_export_feature_collection(
            0, feature_collection, file_count, features, mode = mode)

### Perform dataset export (currently to a single TFRecord file)

In [144]:
# Rectangular regions of interest, each built from one coordinate list.
# NOTE(review): presumably [west, south, east, north] in degrees — confirm.
oregon = ee.Geometry.Rectangle(
    [-124.6, 41.9, -116.4, 46.3]
)
conus_west = ee.Geometry.Rectangle(
    [-125, 26, -100, 49]
)

In [145]:
# Export destination and file naming.
bucket = 'scott_burgan_fuel_data'
folder = 'fbfm_conus_west'
prefix = 'fbfm40'

# Region and sampling settings handed to `_export_dataset`.
geometry = conus_west
kernel_size = 16
sampling_scale = 30
num_samples_per_file = 64
n_samplings = 10_000

In [146]:
# Launch the export with the configuration defined in the previous cell.
_export_dataset(
    bucket=bucket,
    folder=folder,
    prefix=prefix,
    geometry=geometry,
    kernel_size=kernel_size,
    sampling_scale=sampling_scale,
    num_samples_per_file=num_samples_per_file,
    n_samplings=n_samplings,
)

Sampling in mode train
samples extracted
Sampling in mode test
samples extracted
Sampling in mode val
samples extracted
