In [1]:
import os
import tempfile

import dask.array as da
import ee
import geemap
import h5py
import numpy as np
from google.cloud import storage
from sklearn.ensemble import RandomForestRegressor
from sklearn.externals import joblib
from sklearn.metrics import mean_squared_error

# Initialize Google Earth Engine.
# NOTE: this call sits at module level, so it runs as soon as the module is
# imported or the cell is executed, before any of the functions below.
# Presumably Earth Engine credentials must already be configured — TODO confirm.
ee.Initialize()

# Function to fetch and filter HDF5 LiDAR data from GCS
def fetch_and_filter_lidar_from_gcs(bucket_name, filter_flags):
    """Download every ``.h5`` object from a GCS bucket and collect filtered AGBD.

    Each file is downloaded into a scratch directory, its ``agbd`` and
    ``l4_flag`` datasets are read fully into memory, and the file is removed
    again before the next download.

    Args:
        bucket_name: Either a bare bucket name or ``"bucket/prefix"``. GCS
            bucket names cannot contain ``/``, so everything after the first
            slash is used as the object-name prefix when listing.
        filter_flags: Mapping with the required key ``'l4_flag'`` (the value
            the ``l4_flag`` dataset must equal) and the optional key
            ``'agbd_min'`` (inclusive lower bound on ``agbd``; skipped when
            missing or ``None``).

    Returns:
        A Dask array holding the concatenation of the selected ``agbd``
        values of all files, in listing order.

    Raises:
        ValueError: If no ``.h5`` object is found under the bucket/prefix.
        KeyError: If ``filter_flags`` has no ``'l4_flag'`` entry or a file
            lacks one of the two datasets.
    """
    # "bucket/some/prefix" -> ("bucket", "some/prefix"); a bare name keeps
    # the original behaviour of listing the whole bucket (prefix "").
    bucket_id, _, prefix = bucket_name.partition('/')
    required_flag = filter_flags['l4_flag']
    agbd_min = filter_flags.get('agbd_min')

    client = storage.Client()
    bucket = client.bucket(bucket_id)

    filtered_data = []

    # The scratch directory guarantees cleanup even if a download or read
    # fails part-way through.
    with tempfile.TemporaryDirectory() as scratch_dir:
        for blob in bucket.list_blobs(prefix=prefix):
            if not blob.name.endswith('.h5'):
                continue

            local_file = os.path.join(scratch_dir, os.path.basename(blob.name))
            blob.download_to_filename(local_file)

            with h5py.File(local_file, 'r') as f:
                agbd = f['agbd'][:]
                l4_flag = f['l4_flag'][:]

            keep = (l4_flag == required_flag) & (~np.isnan(agbd))
            if agbd_min is not None:
                # NaNs are already excluded above; silence the comparison
                # warning some NumPy versions emit for them.
                with np.errstate(invalid='ignore'):
                    keep &= agbd >= agbd_min
            filtered_data.append(agbd[keep])

            # Data is in memory now; free the disk space right away so that
            # objects sharing a basename cannot clash either.
            os.remove(local_file)

    if not filtered_data:
        raise ValueError(
            f"No .h5 objects found in bucket {bucket_id!r} under prefix {prefix!r}"
        )

    return da.concatenate(filtered_data, axis=0)

# Fetch DEM data from GEE
def fetch_dem_data(region):
    """Fetch the ``USGS/SRTMGL1_003`` image clipped to *region* as an array.

    The clipped image is handed to ``geemap.ee_to_numpy`` together with the
    same *region* and ``default_value=-9999`` (presumably the fill used for
    pixels without data — confirm against the geemap documentation).
    """
    srtm = ee.Image("USGS/SRTMGL1_003")
    clipped = srtm.clip(region)
    return geemap.ee_to_numpy(clipped, region=region, default_value=-9999)

# Fetch Landsat Data from GEE
def fetch_landsat_data(region, start_date, end_date):
    """Fetch a median Landsat 8 TOA composite over *region* as an array.

    The ``LANDSAT/LC08/C02/T1_TOA`` collection is restricted to *region* and
    to the ``start_date``/``end_date`` window, reduced to bands B2-B7, and
    collapsed with a per-pixel median before being clipped and converted
    with ``geemap.ee_to_numpy`` (``default_value=-9999``).

    NOTE(review): Earth Engine's ``filterDate`` is documented as treating the
    end date as exclusive — confirm callers pass the intended end date.
    """
    band_names = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7']
    scenes = (
        ee.ImageCollection("LANDSAT/LC08/C02/T1_TOA")
        .filterBounds(region)
        .filterDate(start_date, end_date)
        .select(band_names)
    )
    median_image = scenes.median().clip(region)
    return geemap.ee_to_numpy(median_image, region=region, default_value=-9999)

# Fetch Sentinel-1 Data for Seasonal Timeframes
def fetch_sentinel_data_seasons(region, seasons=None):
    """Fetch one median Sentinel-1 VV composite per season over *region*.

    Args:
        region: Earth Engine geometry-like object used for ``filterBounds``,
            ``clip`` and as the ``region`` of ``geemap.ee_to_numpy``.
        seasons: Optional mapping ``name -> (start_date, end_date)``. Dates
            are passed to ``filterDate``, whose end date is exclusive. When
            omitted, three contiguous 2021 windows are used (spring,
            late spring, summer).

    Returns:
        A dict with the same keys as *seasons*, in the same order, whose
        values are the arrays returned by ``geemap.ee_to_numpy`` (called
        with ``default_value=-9999``).
    """
    if seasons is None:
        # End dates are the first day AFTER each window because filterDate
        # excludes its end date; the previous '…-05-31' / '…-06-30' /
        # '…-09-30' ends silently dropped the last day of every season.
        seasons = {
            'spring': ('2021-03-01', '2021-06-01'),
            'late_spring': ('2021-06-01', '2021-07-01'),
            'summer': ('2021-07-01', '2021-10-01'),
        }

    # Loop-invariant: only keep acquisitions that actually carry a VV band.
    has_vv = ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')

    seasonal_data = {}

    for season, (start_date, end_date) in seasons.items():
        s1 = (
            ee.ImageCollection("COPERNICUS/S1_GRD")
            .filterBounds(region)
            .filterDate(start_date, end_date)
            .filter(has_vv)
            .select('VV')
        )
        composite = s1.median().clip(region)
        seasonal_data[season] = geemap.ee_to_numpy(
            composite, region=region, default_value=-9999
        )

    return seasonal_data

# Preprocess and combine features
def _as_band_stack(name, array, spatial_chunk=1000):
    """Return *array* as a Dask array shaped ``(rows, cols, bands)``.

    Accepts NumPy or Dask input. A 2-D grid gets a trailing band axis of
    length one; anything that is not 2-D or 3-D raises ``ValueError``.
    """
    stack = da.asarray(array)  # unlike from_array, also accepts Dask arrays
    if stack.ndim == 2:
        stack = stack[..., None]
    if stack.ndim != 3:
        raise ValueError(
            f"{name} must be a (rows, cols) or (rows, cols, bands) array, "
            f"got shape {stack.shape}"
        )
    # Tile spatially, but keep all bands of a pixel in the same chunk.
    return stack.rechunk((spatial_chunk, spatial_chunk, -1))


def preprocess_data(lidar, dem, landsat, sentinel_seasons):
    """Stack all input layers into one ``(rows, cols, bands)`` Dask array.

    Band order is: lidar, dem, landsat bands, then one block per entry of
    *sentinel_seasons* in the mapping's iteration order.

    Args:
        lidar: Gridded LiDAR layer, NumPy or Dask, 2-D or 3-D.
        dem: Elevation layer, NumPy or Dask, 2-D or 3-D.
        landsat: Optical layer, NumPy or Dask, 2-D or 3-D.
        sentinel_seasons: Mapping ``season name -> array`` (2-D or 3-D).

    Returns:
        Dask array with the bands of every layer concatenated on the last
        axis, chunked 1000 x 1000 spatially with all bands per chunk.

    Raises:
        ValueError: If a layer is not 2-D/3-D or the layers do not share the
            same ``(rows, cols)`` extent. This includes a 1-D vector of
            per-footprint LiDAR values, which has to be rasterised onto the
            grid first.

    NOTE(review): the LiDAR layer is stacked as a feature although ``main``
    hints at using it as the reference AGB too — confirm this is intended.
    """
    named_layers = [('lidar', lidar), ('dem', dem), ('landsat', landsat)]
    named_layers.extend(
        (f"sentinel[{season}]", grid) for season, grid in sentinel_seasons.items()
    )

    stacks = [(name, _as_band_stack(name, layer)) for name, layer in named_layers]

    ref_name, ref_stack = stacks[0]
    for name, stack in stacks[1:]:
        if stack.shape[:2] != ref_stack.shape[:2]:
            raise ValueError(
                f"{name} grid {stack.shape[:2]} does not match "
                f"{ref_name} grid {ref_stack.shape[:2]}"
            )

    combined_features = da.concatenate([stack for _, stack in stacks], axis=-1)
    return combined_features

# Predict AGB and Calculate RMSE
def predict_agb(features, model, true_agb=None):
    """Predict one AGB value per pixel and optionally print the RMSE.

    Args:
        features: Array-like shaped ``(..., n_features)``; NumPy and Dask
            arrays are both accepted (a Dask array is computed in full).
        model: Object exposing ``predict(X)`` for a 2-D ``X`` of shape
            ``(n_samples, n_features)`` and returning ``n_samples`` values.
        true_agb: Optional reference values (NumPy or Dask) with as many
            elements as there are pixels. When given, the RMSE against the
            prediction is printed.

    Returns:
        NumPy array of predictions shaped like ``features`` without its
        last axis.
    """
    # np.asarray materialises a Dask array and passes NumPy input through,
    # so callers are not forced to wrap plain arrays in Dask first.
    features = np.asarray(features)
    flat_pixels = features.reshape(-1, features.shape[-1])
    agb_pred = np.asarray(model.predict(flat_pixels)).reshape(features.shape[:-1])

    if true_agb is not None:
        true_agb = np.asarray(true_agb)
        rmse = np.sqrt(mean_squared_error(true_agb.ravel(), agb_pred.ravel()))
        print(f"RMSE: {rmse}")

    return agb_pred

# Save AGB Map to HDF5
def save_agb_map(agb_map, output_file):
    """Write *agb_map* to *output_file* as the HDF5 dataset ``'agb'``.

    The file is opened in ``'w'`` mode, so an existing file at that path is
    replaced.
    """
    with h5py.File(output_file, 'w') as h5_out:
        h5_out.create_dataset('agb', data=agb_map)

# Main Workflow
def main():
    """Run the data-fetching and feature-building stages of the AGB workflow.

    Model loading, prediction, RMSE reporting and saving are currently
    disabled (kept below as commented-out steps), so this only fetches the
    LiDAR, DEM, Landsat and Sentinel-1 inputs and builds the feature stack.
    """
    output_file = '/tmp/agb_map.h5'  # only referenced by the disabled save step
    bucket_name = 'test-agb-bucket/GEDIL4A2023'

    # Selection criteria handed to the LiDAR loader.
    filter_flags = {
        'l4_flag': 1,  # required value of the l4_flag dataset
        'agbd_min': 0  # minimum AGBD threshold
    }

    # Region of interest: an Earth Engine feature collection asset.
    # Alternative rectangular ROI, kept for reference:
    # region = ee.Geometry.Polygon([
    #     [[-60, -10], [-60, 0], [-50, 0], [-50, -10], [-60, -10]]
    # ])
    region = ee.FeatureCollection("projects/test-project-agb/assets/AmazonBasinLimits-master")

    print("Fetching and filtering LiDAR data...")
    lidar = fetch_and_filter_lidar_from_gcs(bucket_name, filter_flags)

    print("Fetching DEM data...")
    dem = fetch_dem_data(region)

    print("Fetching Landsat data...")
    landsat = fetch_landsat_data(region, '2021-01-01', '2021-12-31')

    print("Fetching Sentinel-1 data for seasons...")
    sentinel_seasons = fetch_sentinel_data_seasons(region)

    print("Preprocessing data...")
    features = preprocess_data(lidar, dem, landsat, sentinel_seasons)

    print("Loading model...")
    # Disabled downstream steps (model load, prediction, save):
    # model = joblib.load('agb_model.pkl')
    #
    # # Assuming true AGB values are available for RMSE calculation
    # true_agb = lidar  # Replace with actual true AGB values if different
    #
    # print("Predicting AGB map...")
    # agb_map = predict_agb(features, model, true_agb)
    #
    # print(f"Saving AGB map to {output_file}...")
    # save_agb_map(agb_map, output_file)

    print("AGB prediction complete.")


# Only run the workflow when executed as a script, not on import.
if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'dask'