# In [None]:
import ee
import xarray as xr
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import joblib

# Connect to Google Earth Engine. If the first attempt fails for any reason,
# run the authentication flow once and retry; a second failure propagates.
try:
    ee.Initialize()
except Exception:
    ee.Authenticate()
    ee.Initialize()

# Study-area polygon (example placeholder): one closed ring of [x, y] vertices
# whose first and last points coincide.
amazon_basin = ee.Geometry.Polygon(
    [[[-80, -10], [-80, 5], [-65, 5], [-65, -10], [-80, -10]]]
)

# Load datasets
def load_datasets():
    """Open the raw Earth Engine assets used by the pipeline.

    Returns:
        Tuple ``(sentinel1, sentinel2, landsat, dem, gedi)``. The three image
        collections and the GEDI feature collection are restricted to
        ``amazon_basin`` via ``filterBounds``; the SRTM DEM image is returned
        unfiltered.
    """
    def _in_study_area(collection):
        return collection.filterBounds(amazon_basin)

    # NOTE(review): no date filter is applied, so downstream reductions run
    # over every scene in each collection -- confirm this is intended.
    # NOTE(review): verify the asset IDs against the Earth Engine catalog
    # (whether "COPERNICUS/S2" should be the harmonized collection, and
    # whether the GEDI ID loads directly as a FeatureCollection) -- TODO confirm.
    radar = _in_study_area(ee.ImageCollection("COPERNICUS/S1_GRD"))
    optical = _in_study_area(ee.ImageCollection("COPERNICUS/S2"))
    landsat8 = _in_study_area(ee.ImageCollection("LANDSAT/LC08/C02/T1_L2"))
    gedi_shots = _in_study_area(ee.FeatureCollection("LARSE/GEDI/GEDI04_A_002"))
    elevation = ee.Image("USGS/SRTMGL1_003")

    return radar, optical, landsat8, elevation, gedi_shots

# Preprocess Sentinel-1 (reduce mean)
def preprocess_sentinel1(sentinel1):
    """Reduce a Sentinel-1 GRD collection to a mean VV/VH image.

    Args:
        sentinel1: ee.ImageCollection of Sentinel-1 GRD scenes.

    Returns:
        ee.Image with bands 'VV' and 'VH': the per-pixel mean over the scenes
        that carry both polarisations.
    """
    # The GRD collection mixes scenes acquired with different polarisation
    # combinations. Selecting a band that a scene does not have fails when the
    # result is computed, so keep only scenes that provide both VV and VH.
    polarisation = 'transmitterReceiverPolarisation'
    dual_pol = (
        sentinel1
        .filter(ee.Filter.listContains(polarisation, 'VV'))
        .filter(ee.Filter.listContains(polarisation, 'VH'))
    )
    return dual_pol.select(['VV', 'VH']).mean()

# Preprocess Sentinel-2 (cloud masking and mean reduction)
def preprocess_sentinel2(sentinel2):
    """Cloud-mask a Sentinel-2 collection and average its 10 m bands.

    Args:
        sentinel2: ee.ImageCollection whose images carry a 'QA60' band.

    Returns:
        ee.Image holding the per-pixel mean of B2, B3, B4 and B8 over all
        pixels whose QA60 bits 10 and 11 are both clear.
    """
    opaque_cloud_bit = 1 << 10
    cirrus_bit = 1 << 11

    def mask_clouds(image):
        qa_band = image.select('QA60')
        no_opaque_cloud = qa_band.bitwiseAnd(opaque_cloud_bit).eq(0)
        no_cirrus = qa_band.bitwiseAnd(cirrus_bit).eq(0)
        return image.updateMask(no_opaque_cloud.And(no_cirrus))

    masked_scenes = sentinel2.map(mask_clouds)
    return masked_scenes.select(['B2', 'B3', 'B4', 'B8']).mean()

# Preprocess Landsat 8 (cloud masking and mean reduction)
def preprocess_landsat(landsat):
    """Cloud-mask a Landsat 8 Collection 2 Level-2 collection and average it.

    Args:
        landsat: ee.ImageCollection whose images carry a 'QA_PIXEL' band.

    Returns:
        ee.Image holding the per-pixel mean of SR_B2, SR_B3, SR_B4 and SR_B5
        over all pixels not flagged as cloud-contaminated.
    """
    # Collection 2 QA_PIXEL flags: bit 1 dilated cloud, bit 2 cirrus,
    # bit 3 cloud, bit 4 cloud shadow. Bits 5 and 7 are snow and water, so
    # testing those does not remove clouds.
    contaminated_bits = (1 << 1) | (1 << 2) | (1 << 3) | (1 << 4)

    def mask_clouds(image):
        qa = image.select('QA_PIXEL')
        # Keep a pixel only when none of the contamination flags is set.
        return image.updateMask(qa.bitwiseAnd(contaminated_bits).eq(0))

    masked_scenes = landsat.map(mask_clouds)
    return masked_scenes.select(['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5']).mean()

# Combine all datasets into a single image
def create_composite(sentinel1, sentinel2, landsat, dem):
    """Stack the preprocessed predictors into one multi-band image.

    Args:
        sentinel1: Raw Sentinel-1 collection, passed to ``preprocess_sentinel1``.
        sentinel2: Raw Sentinel-2 collection, passed to ``preprocess_sentinel2``.
        landsat: Raw Landsat collection, passed to ``preprocess_landsat``.
        dem: Elevation image appended as-is.

    Returns:
        The Sentinel-1 result with the Sentinel-2, Landsat and DEM bands
        appended in that order.
    """
    radar = preprocess_sentinel1(sentinel1)
    optical = preprocess_sentinel2(sentinel2)
    surface_reflectance = preprocess_landsat(landsat)

    stacked = radar
    for layer in (optical, surface_reflectance, dem):
        stacked = stacked.addBands(layer)
    return stacked

# Extract training data from GEDI
def extract_training_data(gedi, composite):
    """Sample the composite image at the GEDI features.

    Args:
        gedi: Feature collection supplying the sample locations and the
            'agbd' property (assumed to be the GEDI biomass field).
        composite: Image exposing ``sampleRegions``.

    Returns:
        Whatever ``composite.sampleRegions`` returns for a 30 m sampling scale
        with the 'agbd' property copied onto each sample.
    """
    sampling_args = {
        "collection": gedi,
        "properties": ['agbd'],  # Assume GEDI AGBD field is 'agbd'
        "scale": 30,
    }
    return composite.sampleRegions(**sampling_args)

# Batch process samples to avoid memory issues
def batch_samples_to_xarray(samples, batch_size=1000):
    """Convert a client-side feature collection dict into an xarray Dataset.

    Args:
        samples: Mapping with a 'features' list; every entry must carry a
            'properties' dict (the layout this script gets from ``getInfo()``).
        batch_size: Number of features converted per intermediate DataFrame.
            Must be at least 1.

    Returns:
        xarray.Dataset with one data variable per property, indexed by a
        0..n-1 row index. An empty Dataset is returned when there are no
        features.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If 'features' or a feature's 'properties' is missing.
    """
    # A zero step breaks range() and a negative step would silently drop
    # every feature, so reject both up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    records = samples['features']
    if not records:
        # pd.concat() refuses an empty list; hand back an empty Dataset instead.
        return xr.Dataset()

    # Build each frame straight from the property dicts; this avoids creating
    # one pandas Series per row, which is what apply(pd.Series) does.
    frames = [
        pd.DataFrame(
            [record['properties'] for record in records[start:start + batch_size]]
        )
        for start in range(0, len(records), batch_size)
    ]
    return xr.Dataset.from_dataframe(pd.concat(frames, ignore_index=True))

# Train AGB Prediction Model
def train_model(features, target):
    """Fit a random-forest regressor and print its hold-out MSE.

    Args:
        features: Predictor matrix accepted by ``train_test_split``.
        target: Target values aligned with ``features``.

    Returns:
        The fitted ``RandomForestRegressor`` (100 trees, seed 42), trained on
        70 % of the rows; the remaining 30 % are used only for the printed MSE.
    """
    split = train_test_split(features, target, test_size=0.3, random_state=42)
    train_x, holdout_x, train_y, holdout_y = split

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(train_x, train_y)

    mse = mean_squared_error(holdout_y, forest.predict(holdout_x))
    print(f"Model MSE: {mse}")
    return forest

# Main Execution
sentinel1, sentinel2, landsat, dem, gedi = load_datasets()
composite = create_composite(sentinel1, sentinel2, landsat, dem)
# NOTE(review): getInfo() pulls the whole sample table to the client in one
# request; a basin-wide sample may exceed Earth Engine's interactive limits
# and need an export or a .limit() instead -- confirm.
samples = extract_training_data(gedi, composite).getInfo()
data = batch_samples_to_xarray(samples)

# Split features and target
# Dataset.to_array() stacks the data variables along a new leading axis, so
# the raw array is (n_features, n_samples). scikit-learn expects
# (n_samples, n_features), hence the transpose.
features = data.drop_vars('agbd').to_array().to_numpy().T
target = data['agbd'].to_numpy()

# Train model
model = train_model(features, target)

# Save model
joblib.dump(model, "agb_model.pkl")

print("Model training completed and saved.")
