In [None]:
import ee
import geemap
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

# ==========================================
# 1. INITIALIZATION & SETUP
# ==========================================
# Google Cloud project id passed to every ee.Initialize() call below.
MY_PROJECT = "gen-lang-client-0426799622"
# File the trained network weights are read from when the model is built.
MODEL_PATH = "models/erosion_model_hybrid.pth"

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Try to start Earth Engine with whatever credentials are already available;
# if that fails for any reason, run the authentication flow and try once more.
try:
    ee.Initialize(project=MY_PROJECT)
except Exception:
    ee.Authenticate()
    ee.Initialize(project=MY_PROJECT)

# ==========================================
# 2. LOAD THE AI MODEL
# ==========================================
class MultiClassUNet(nn.Module):
    """Small two-level U-Net that returns per-pixel class logits.

    Layout: one full-resolution encoder block, one half-resolution encoder
    block, a bilinear upsample back to full resolution, a skip connection
    (channel concatenation) and one decoder block followed by a 1x1
    classification convolution.

    Attribute names and the layer order inside each block are kept stable
    because they determine the ``state_dict`` keys of saved checkpoints.

    Args:
        in_channels: Number of channels in the input tensor.
        num_classes: Number of output classes (channels of the logits).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.enc1 = self.conv_block(in_channels, 16)
        self.enc2 = self.conv_block(16, 32)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # The decoder sees the upsampled 32 channels plus the 16 skip channels.
        self.dec1 = self.conv_block(32 + 16, 16)
        self.final = nn.Conv2d(16, num_classes, kernel_size=1)

    def conv_block(self, in_c, out_c):
        """Return two stacked 3x3 Conv -> ReLU -> BatchNorm stages.

        Args:
            in_c: Input channel count of the first convolution.
            out_c: Output channel count of both convolutions.

        Returns:
            An ``nn.Sequential`` that preserves the spatial size (padding=1).
        """
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(), nn.BatchNorm2d(out_c),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(), nn.BatchNorm2d(out_c),
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape (batch, in_channels, height, width).

        Returns:
            Tensor of shape (batch, num_classes, height, width).
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        d1 = self.up(e2)
        # Pooling floors odd sizes, so the upsampled map can come back one
        # pixel short. Compare only the spatial dims: the channel counts
        # (32 vs 16) always differ, so comparing full shapes would resize on
        # every call.
        if d1.shape[2:] != e1.shape[2:]:
            d1 = torch.nn.functional.interpolate(d1, size=e1.shape[2:])
        d1 = torch.cat([d1, e1], dim=1)
        out = self.dec1(d1)
        return self.final(out)

# Build the 5-channel, 3-class network on the selected device and load the
# trained weights when the checkpoint file is present.
model = MultiClassUNet(in_channels=5, num_classes=3).to(device)
if os.path.exists(MODEL_PATH):
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
else:
    print(f"⚠️ Warning: Model not found at {MODEL_PATH}. Predictions will come from untrained (random) weights.")
# Always switch to inference mode, even without a checkpoint: otherwise the
# BatchNorm layers would normalise with per-batch statistics and update their
# running averages on every prediction.
model.eval()

# ==========================================
# 3. INTERACTIVE UI ELEMENTS
# ==========================================
# Interactive map the user draws the region of interest on; it opens at
# zoom level 7 with center [19.0, 75.0] and the "HYBRID" basemap added.
Map = geemap.Map(zoom=7, center=[19.0, 75.0])
Map.add_basemap("HYBRID")

# Output area that receives the messages and figures of each analysis run.
output_console = widgets.Output()

# Button that starts the analysis of the rectangle currently drawn on the map.
analyze_btn = widgets.Button(
    icon="search",
    button_style="primary",
    description="Analyze Selected Region",
    tooltip="Draw a rectangle on the map, then click here",
)

# ==========================================
# 4. PREDICTION LOGIC
# ==========================================
def run_prediction(b):
    """Analyze the rectangle drawn on the map and plot three comparison maps.

    Click handler for ``analyze_btn``. Every message and the final figure are
    written into ``output_console``. The handler:

    1. Checks that a region was drawn and that its bounding box is at most
       5.5 km x 5.5 km; an oversized box is removed and the run is aborted.
    2. Builds the Earth Engine layers: a cloud-masked Sentinel-2 median
       composite, SRTM slope, a RUSLE soil-loss classification and a
       WorldCover-based "ignore" mask.
    3. Downloads the stacked layers at 30 m, runs the U-Net on the first five
       bands and plots the satellite view, the RUSLE map and the model map.

    Class codes used in both classified maps: 0, 1 and 2 are the risk
    classes (for RUSLE: soil loss < 5, 5 to < 20, >= 20) and 3 marks pixels
    excluded by the ignore mask (drawn grey).

    Args:
        b: The button instance passed in by ``on_click``; not used.

    Returns:
        None.
    """
    with output_console:
        clear_output(wait=True)

        # 1. Check if the user actually drew a box.
        if Map.user_roi is None:
            print("❌ ERROR: Please draw a rectangle on the map first using the drawing tools on the left.")
            return

        # Work on the bounding box of whatever shape was drawn.
        roi = Map.user_roi.bounds()

        # STRICT SAFEGUARD: enforce the 5.5 km x 5.5 km limit (area in m^2).
        area_sq_meters = roi.area(maxError=1).getInfo()
        max_allowed_area = (2750 * 2) ** 2

        def force_clear_map():
            """Best-effort removal of the drawn shape and the stored ROI.

            Each step is attempted on its own so that a failure in one of
            them (for example a geemap version without that attribute, or
            with a read-only ``user_roi``) does not skip the remaining ones.
            Failures are ignored on purpose: the cleanup is cosmetic and must
            never abort an analysis.
            """
            def reset_draw_control():
                # Wipe the drawing tool's own memory of the shape.
                if getattr(Map, "draw_control", None) is not None:
                    Map.draw_control.clear()

            def remove_drawn_layer():
                # Remove the visual layer of the drawn shape.
                Map.remove_drawn_features()

            def reset_user_roi():
                # Forget the stored region so the next click needs a new box.
                Map.user_roi = None

            for step in (reset_draw_control, remove_drawn_layer, reset_user_roi):
                try:
                    step()
                except Exception:
                    continue

        if area_sq_meters > max_allowed_area:
            print(f"❌ ERROR: Box is too large! ({area_sq_meters / 1e6:.2f} sq km)")
            print(f"Maximum allowed size is {max_allowed_area / 1e6:.2f} sq km to prevent memory crashes.")
            print("The invalid box has been removed. Please draw a smaller one.")
            force_clear_map()
            return

        centroid = roi.centroid(maxError=1).getInfo()['coordinates']
        lon, lat = centroid[0], centroid[1]

        print(f"📍 Valid box selected near Lat {lat:.4f}, Lon {lon:.4f}. Fetching satellite data...")

        # The geometry is already captured in `roi`, so the drawing can be
        # removed right away and the user may start a new box.
        force_clear_map()

        def mask_s2_clouds(image):
            """Mask pixels flagged in QA60 bits 10/11 and scale by 1/10000."""
            qa = image.select('QA60')
            mask = qa.bitwiseAnd(1 << 10).eq(0).And(qa.bitwiseAnd(1 << 11).eq(0))
            return image.updateMask(mask).divide(10000)

        # 2. Fetch visual data (2024 composite).
        # NOTE(review): filterDate treats the end date as exclusive, so
        # 30 June is not part of this window. Left unchanged because this
        # composite is the model input - confirm against the training setup.
        s2 = (ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
              .filterBounds(roi).filterDate("2024-01-01", "2024-06-30")
              .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 20))
              .map(mask_s2_clouds).median().clip(roi))

        dem = ee.Image("USGS/SRTMGL1_003").clip(roi)
        slope_deg = ee.Terrain.slope(dem)

        # 3. Calculate the pure RUSLE formula: soil loss = R * K * LS * C.
        # The end date is exclusive, so the full year 2023 needs 2024-01-01;
        # an end of 2023-12-31 would drop the last day from the annual sum.
        rain = ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY").filterDate('2023-01-01', '2024-01-01').sum().clip(roi)
        R = rain.pow(1.61).multiply(0.0483)

        # Soil texture classes 1..12 are mapped to K values; 0.3 is the
        # fallback for any other class value.
        soil = ee.Image("OpenLandMap/SOL/SOL_TEXTURE-CLASS_USDA-TT_M/v02").clip(roi)
        K = soil.remap([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                       [0.05, 0.15, 0.2, 0.25, 0.3, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6], 0.3)

        LS = slope_deg.divide(100).pow(1.3).multiply(2)

        ndvi = s2.normalizedDifference(['B8', 'B4'])
        C = ndvi.expression("exp(-2 * (ndvi / (1 - ndvi)))", {'ndvi': ndvi})
        # Keep the cover factor inside [0, 1].
        C = C.where(C.gt(1), 1).where(C.lt(0), 0)

        soil_loss = R.multiply(K).multiply(LS).multiply(C)
        rusle_class = ee.Image(0).where(soil_loss.gte(5).And(soil_loss.lt(20)), 1).where(soil_loss.gte(20), 2).rename('rusle')

        # 4. Create the smart mask. WorldCover classes 50, 80 and 100
        # (built-up, permanent water and moss/lichen in the WorldCover
        # legend) are ignored; cropland is NOT masked, so farms are analyzed.
        landcover = ee.ImageCollection("ESA/WorldCover/v100").first().clip(roi)
        ignore_mask = landcover.eq(50).Or(landcover.eq(80)).Or(landcover.eq(100)).rename('ignore')

        # Band order of the download: B4, B3, B2, B8, slope/90 (model inputs),
        # then the ignore mask and the RUSLE class.
        inputs = s2.select(["B4", "B3", "B2", "B8"]).addBands(slope_deg.divide(90))
        download_stack = inputs.addBands(ignore_mask).addBands(rusle_class)

        # Download happens here.
        pixel_data = geemap.ee_to_numpy(download_stack, region=roi, scale=30)

        # Reject missing, empty or incomplete downloads before indexing bands.
        if (
            pixel_data is None
            or pixel_data.ndim != 3
            or 0 in pixel_data.shape[:2]
            or pixel_data.shape[2] < 7
        ):
            print("❌ Error: Data unavailable for this region.")
            return

        pixel_data = np.nan_to_num(pixel_data, nan=0.0)

        # 5. Extract data layers: (H, W, bands) -> (bands, H, W) for PyTorch.
        data_tensor = torch.from_numpy(pixel_data).float().permute(2, 0, 1)
        X = data_tensor[:5, :, :].unsqueeze(0).to(device)
        ignored_pixels = pixel_data[:, :, 5] == 1
        # Copy so that marking ignored pixels does not modify pixel_data.
        rusle_map = pixel_data[:, :, 6].copy()

        # 6. Run AI model prediction.
        print("🧠 Running AI prediction...")
        # Make sure BatchNorm uses its running statistics during inference.
        model.eval()
        with torch.no_grad():
            logits = model(X)
            # Drop only the batch axis; a bare squeeze() would also collapse
            # a height or width of 1 and break the mask indexing below.
            ai_prediction_map = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

        # Class 3 marks the ignored land-cover pixels in both maps.
        ai_prediction_map[ignored_pixels] = 3
        rusle_map[ignored_pixels] = 3

        # ==========================================
        # 7. PLOT ALL 3 MAPS
        # ==========================================
        # Colors for classes 0..3: green, gold, red, grey.
        cmap = ListedColormap(["#228B22", "#FFD700", "#FF0000", "#808080"])
        plt.figure(figsize=(18, 5))

        # Bands 0..2 are B4, B3, B2; stretch them to [0, 1] for display.
        rgb = pixel_data[:, :, 0:3]
        rgb = (rgb - np.min(rgb)) / (np.max(rgb) - np.min(rgb) + 1e-8)

        plt.subplot(1, 3, 1)
        plt.imshow(rgb)
        plt.title(f"1. Satellite View\nCenter Lat {lat:.2f}, Lon {lon:.2f}")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(rusle_map, cmap=cmap, vmin=0, vmax=3)
        plt.title("2. Pure RUSLE Formula (Math)\n(Theoretical Risk)")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.imshow(ai_prediction_map, cmap=cmap, vmin=0, vmax=3)
        plt.title("3. Hybrid AI Model (Reality)\n(Theory + Historical Degradation)")
        plt.axis("off")

        plt.tight_layout()
        plt.show()
        print("✅ Analysis Complete! The map has been cleared. You can draw a new box immediately.")
# Bind the button click to the analysis handler.
analyze_btn.on_click(run_prediction)

# ==========================================
# 5. RENDER THE APP
# ==========================================
# Usage instructions shown above the widgets. The emoji are written as real
# Unicode characters so they render correctly instead of as garbled bytes.
print("🌍 INTERACTIVE EROSION DETECTOR (Bounding Box Edition)")
print("1. Look at the toolbar on the left side of the map.")
print("2. Click the Rectangle Tool (⬛) and draw a box over the area you want to analyze.")
print("3. Keep the box reasonably sized (around 5x5 km) to prevent memory crashes.")
print("4. Click the blue 'Analyze Selected Region' button below.")

# Display widgets: button first, then the map, then the output area.
display(analyze_btn)
display(Map)
display(output_console)

🌍 INTERACTIVE EROSION DETECTOR (Bounding Box Edition)
1. Look at the toolbar on the left side of the map.
2. Click the Rectangle Tool (⬛) and draw a box over the area you want to analyze.
3. Keep the box reasonably sized (around 5x5 km) to prevent memory crashes.
4. Click the blue 'Analyze Selected Region' button below.


Button(button_style='primary', description='Analyze Selected Region', icon='search', style=ButtonStyle(), tool…

Map(center=[19.0, 75.0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', …

Output()