In [None]:
import tensorflow as tf
import numpy as np
import rasterio
import cv2
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import Polygon
import ee
import geemap


import json

import shapely

try:
    ee.Initialize()
except Exception:
    # Not yet authenticated (or credentials expired): authenticate, then retry.
    ee.Authenticate()
    ee.Initialize()


# Area of interest, read from a shapefile and converted to an Earth Engine geometry.
shp_path = "E:/shashank/shp/Tupul.shp"
gdf = gpd.read_file(shp_path)
# Earth Engine reads GeoJSON coordinates as lon/lat (EPSG:4326) by default, so
# reproject first; a projected shapefile (e.g. UTM metres) would otherwise give a
# meaningless AOI.
# NOTE(review): when the shapefile has no .prj (crs is None) the coordinates are
# assumed to already be lon/lat — confirm for this dataset.
if gdf.crs is not None:
    gdf = gdf.to_crs(epsg=4326)
# union_all() returns a MultiPolygon when features do not touch, and `.exterior`
# only exists on a single Polygon (and drops holes). Going through GeoJSON handles
# Polygon and MultiPolygon alike; force_2d drops any Z values the shapefile carries.
geometry = shapely.force_2d(gdf.geometry.union_all())
ee_geometry = ee.Geometry(json.loads(shapely.to_geojson(geometry)))
aoi = ee.FeatureCollection(ee_geometry)

def mask_clouds(image):
    """Return *image* with cloudy pixels masked out.

    A pixel is kept only when its ``MSK_CLDPRB`` value is strictly below 5;
    every other pixel is masked via ``updateMask``.
    """
    keep = image.select("MSK_CLDPRB").lt(5)
    return image.updateMask(keep)

# Post-Landslide Events
# Median Sentinel-2 surface-reflectance composite over the post-event window.
# NOTE: ee.ImageCollection.filterDate treats the end date as exclusive, so this
# window covers 2024-01-01 through 2024-01-29.
start_date_post = "2024-01-01"
end_date_post = "2024-01-30"
s2_post = (
    ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
    .filterBounds(aoi)
    .filterDate(start_date_post, end_date_post)
    # Drop mostly-cloudy scenes before the per-pixel cloud mask is applied.
    .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 10))
    .map(mask_clouds)
    .median()
    # Image.clip is documented for a Geometry or Feature, so clip to the AOI's
    # geometry rather than passing the FeatureCollection itself.
    .clip(aoi.geometry())
)

def get_stretch_params(image, aoi, bands, n_sigma=3, scale=10):
    """Compute a mean ± n_sigma·stdDev display stretch for each band over the AOI.

    Args:
        image: Earth Engine image to sample.
        aoi: object exposing ``.geometry()`` (here an ee.FeatureCollection);
            its geometry is the reduction region.
        bands: band names to compute stretch limits for.
        n_sigma: half-width of the stretch in standard deviations (default 3).
        scale: nominal pixel scale in metres passed to ``reduceRegion``.

    Returns:
        dict mapping each band name to ``{"min": mean - n_sigma*std,
        "max": mean + n_sigma*std}``.

    Raises:
        ValueError: if a band has no mean/stdDev over the AOI (e.g. every pixel
            is masked), instead of failing later with an opaque TypeError.
    """
    stats = image.select(bands).reduceRegion(
        # sharedInputs=True yields per-band keys "<band>_mean" / "<band>_stdDev".
        reducer=ee.Reducer.mean().combine(ee.Reducer.stdDev(), None, True),
        geometry=aoi.geometry(),
        scale=scale,
        maxPixels=1e9,
    ).getInfo()

    stretch_params = {}
    for band in bands:
        mean = stats.get(f"{band}_mean")
        std = stats.get(f"{band}_stdDev")
        if mean is None or std is None:
            raise ValueError(
                f"No valid pixels for band {band!r} in the AOI; "
                "cannot compute stretch parameters."
            )
        stretch_params[band] = {"min": mean - n_sigma * std, "max": mean + n_sigma * std}
    return stretch_params

# True-colour bands: used both for the display stretch and for the export below.
bands = ["B4", "B3", "B2"]
stretch_params_post = get_stretch_params(s2_post, aoi, bands)
vis_params_post = {
    "bands": bands,
    "min": [stretch_params_post[band]["min"] for band in bands],
    "max": [stretch_params_post[band]["max"] for band in bands],
}

# Export s2_post Image
# Export only the RGB bands. The full composite starts with B1, B2, B3 (aerosol,
# blue, green), so GeoTIFF bands 1-3 -- which predict_landslide reads and the
# plot shows as "RGB" -- would not be red/green/blue. The extra bands also
# inflate the direct download considerably.
# NOTE(review): assumes the model was trained on R,G,B (B4,B3,B2) order -- confirm.
export_path = "E:/shashank/landslides/predictions/s2_post.tif"
geemap.ee_export_image(
    s2_post.select(bands), filename=export_path, scale=10, region=aoi.geometry()
)

# Load Trained U-Net Model
# Inference only: compile=False skips rebuilding the training loss/metrics/
# optimizer, so a model trained with a custom loss or metric loads without
# custom_objects. model.predict works on an uncompiled model.
# NOTE: the path is relative, so it resolves against the current working directory.
MODEL_PATH = "unet_landslide_detector_tif.h5"
model = tf.keras.models.load_model(MODEL_PATH, compile=False)

def predict_landslide(image_path, model, input_size=(256, 256), bands=(1, 2, 3),
                      scale=5000.0, threshold=0.5):
    """Run the segmentation model on a raster and return a binary landslide mask.

    The selected bands are resized to ``input_size``, divided by ``scale`` and
    clipped to [0, 1] before prediction.

    Args:
        image_path: path of the raster to read with rasterio.
        model: Keras model; its output is indexed as ``[0, :, :, 0]``.
        input_size: ``(width, height)`` passed to ``cv2.resize``.
        bands: 1-based raster band indexes to read (default: first three).
        scale: divisor applied before clipping (default 5000.0).
        threshold: probability above which a pixel is labelled 1.

    Returns:
        ``(mask, image)``: ``mask`` is a uint8 array of 0/1 at the model's
        input resolution (not the raster's); ``image`` is the preprocessed
        float32 ``(H, W, C)`` array that was fed to the model.
    """
    with rasterio.open(image_path) as src:
        img = src.read(list(bands))  # (bands, rows, cols)
    img = np.moveaxis(img, 0, -1)  # -> (rows, cols, bands)
    # Masked pixels can come back as NaN in float exports; without this they
    # would spread through the interpolation and into the prediction.
    img = np.nan_to_num(img, nan=0.0)
    img = cv2.resize(img, input_size)
    if img.ndim == 2:
        # cv2.resize drops a trailing length-1 channel axis; restore it.
        img = img[..., np.newaxis]
    img = np.clip(img.astype(np.float32) / scale, 0, 1)

    pred = model.predict(np.expand_dims(img, axis=0))[0, :, :, 0]
    mask = (pred > threshold).astype(np.uint8)
    return mask, img

def mask_to_polygon(mask, transform, crs="EPSG:32643"):
    """Trace the outer contours of a binary mask into georeferenced polygons.

    Args:
        mask: 2-D array; any nonzero pixel counts as foreground.
        transform: affine transform mapping (col, row) pixel coordinates to map
            coordinates (e.g. a rasterio dataset's ``transform``).
        crs: CRS to label the output with. It should be the CRS of the raster
            that ``transform`` came from (e.g. ``src.crs``). The default keeps
            the historical hard-coded UTM zone.

    Returns:
        GeoDataFrame with one polygon per external contour with >= 3 vertices.
        Vertices are pixel corners mapped through ``transform``.
    """
    # OpenCV only accepts single-channel uint8 here (bool masks are rejected).
    mask = np.ascontiguousarray(mask, dtype=np.uint8)
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); [-2] is the contour list in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    polygons = [
        Polygon([tuple(transform * (int(col), int(row))) for col, row in contour[:, 0, :]])
        for contour in contours
        # A polygon needs at least three vertices; this does not filter by area.
        if len(contour) > 2
    ]
    return gpd.GeoDataFrame(geometry=polygons, crs=crs)

# Paths
PREDICTION_SAVE_PATH = "E:/shashank/landslides/predictions/s2_post_mask.tif"
OUTPUT_POLYGON_PATH = "E:/shashank/landslides/predictions/s2_post_polygons.geojson"

# Predict Landslide (the mask comes back at the model's input resolution)
predicted_mask, preprocessed_img = predict_landslide(export_path, model)

# Read the georeferencing of the exported scene once.
with rasterio.open(export_path) as src:
    meta = src.meta.copy()
    raster_transform = src.transform
    raster_crs = src.crs
    raster_shape = (src.height, src.width)

# predict_landslide resizes the scene to the model input size, so the raw mask
# does not match the exported raster's grid. Writing it with the source
# width/height would mismatch, and tracing it with the source transform would
# misplace the polygons. Resample it back with nearest-neighbour, which keeps
# the 0/1 labels intact.
if predicted_mask.shape != raster_shape:
    predicted_mask = cv2.resize(
        predicted_mask,
        (raster_shape[1], raster_shape[0]),  # cv2 expects (width, height)
        interpolation=cv2.INTER_NEAREST,
    )

# Save Predicted Mask as TIF
# Drop the source nodata value: it may not be representable as uint8, and the
# 0/1 mask has no nodata pixels.
meta.update(dtype=rasterio.uint8, count=1, nodata=None)
with rasterio.open(PREDICTION_SAVE_PATH, "w", **meta) as dst:
    dst.write(predicted_mask, 1)
print("Landslide mask saved at:", PREDICTION_SAVE_PATH)

# Convert Mask to Polygon and Save
landslide_polygons = mask_to_polygon(predicted_mask, raster_transform)
# The coordinates are in the exported raster's CRS, not necessarily the fixed
# CRS mask_to_polygon labels them with by default, so relabel them.
if raster_crs is not None:
    landslide_polygons = landslide_polygons.set_crs(raster_crs, allow_override=True)
if landslide_polygons.empty:
    # Writing an empty GeoDataFrame can raise, depending on the I/O engine.
    print("No landslide polygons detected; nothing written to", OUTPUT_POLYGON_PATH)
else:
    landslide_polygons.to_file(OUTPUT_POLYGON_PATH)
    print("Landslide polygons saved at:", OUTPUT_POLYGON_PATH)

# Visualize Prediction
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(preprocessed_img)
plt.title("s2_post Image (RGB)")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(predicted_mask, cmap="gray")
plt.title("Predicted Landslide Mask")
plt.axis("off")
plt.show()

# Optionally Visualize in geemap
Map = geemap.Map()
Map.centerObject(aoi, 12)
Map.addLayer(s2_post, vis_params_post, "s2_post (TCC)")
Map.addLayer(aoi, {"color": "red"}, "AOI")
Map.add_raster(PREDICTION_SAVE_PATH, layer_name="Predicted Landslide", colormap="Reds", opacity=0.7)
Map