# 🎯 Sampling Design

**Description**: This notebook facilitates the __sampling design__ process for unbiased area estimation. It preprocesses the geospatial data and creates sample sets based on the specified allocation method and error targets.

In [None]:

import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import time
import random
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import folium
from folium.plugins import MarkerCluster
import pandas as pd


from unbiased_area_estimation.sampling_design import SamplingDesignPipeline

# Global state shared by the step callbacks defined below
step = 1  # number of the wizard step currently active (starts at 1)
results = dict()  # user inputs and pipeline outputs collected so far
out = widgets.Output()
sampling_design_pipeline = None  # assigned in step2_submit
# Widgets created dynamically once the classes / sampling designs are known
expected_accuracy_widgets = list()
samples_per_class_widgets = list()
expected_error_displays = list()
regions = None  # assigned from the preprocessing result in step2_submit

# --- Step 1: Sampling and Allocation Method Selection ---
# The selected value is stored in results["sampling_method"] by step1_submit.
sampling_method = widgets.Dropdown(
    description="Sampling Method:",
    options=["Simple Random", "Stratified", "Two-Stage"],
    style=dict(description_width="initial"),
)
step1_button = widgets.Button(description="Next")
loading1 = widgets.Label("")  # status line for step 1

def step1_submit(b):
    """Handle the step-1 "Next" click.

    Stores the selected sampling method in ``results``, renders the step-2
    form and advances the step counter. Does nothing unless the wizard is
    currently on step 1.

    Args:
        b: The clicked button, as passed by ipywidgets (unused).
    """
    global step
    if step != 1:
        return
    results["sampling_method"] = sampling_method.value
    step2_ui()
    step = step + 1

step1_button.on_click(step1_submit)

# --- Step 2: Map and Mask Upload ---
map_path = widgets.Text(
    description="Map Path:",
    style=dict(description_width="initial"),
)
# One mask path per line; step2_submit splits this text on newlines.
masks_path = widgets.Textarea(
    description="Mask Paths:",
    style=dict(description_width="initial"),
)
# One "<int>:<int>" pair per line; parsed by get_class_merge_dict.
class_merge = widgets.Textarea(
    description="Class Merging: ",
    value="0:1\n1:1\n 2:2\n3:3\n4:4\n5:2",
    style=dict(description_width="initial"),
)
target_spatial_ref = widgets.Text(
    description="Target Spatial Ref",
    style=dict(description_width="initial"),
)
step2_button = widgets.Button(description="Next")
loading2 = widgets.Label("")  # status line for step 2

def get_class_merge_dict(class_merge_str=None):
    """Parse class-merging text into a ``{class_id: merged_class_id}`` dict.

    Every non-blank line must look like ``<int>:<int>``. Whitespace inside a
    line is ignored and blank lines (for example a trailing newline in the
    text area) are skipped. If a line contains more than one colon, only the
    first two fields are used.

    Args:
        class_merge_str: Text to parse. When ``None`` (the default) the
            current value of the ``class_merge`` widget is used.

    Returns:
        dict[int, int]: Mapping built from the left-hand to the right-hand
        integer of every line; later lines overwrite earlier ones that share
        the same left-hand value.

    Raises:
        ValueError: If a non-blank line has no ``:`` separator or one of its
            two fields is not an integer.
    """
    if class_merge_str is None:
        class_merge_str = class_merge.value

    class_merge_dict = {}
    for line_number, raw_line in enumerate(class_merge_str.splitlines(), start=1):
        # Drop all whitespace so " 2 : 2 " parses the same as "2:2".
        entry = "".join(raw_line.split())
        if not entry:
            continue
        fields = entry.split(":")
        if len(fields) < 2:
            raise ValueError(
                f"Class merging line {line_number} must look like "
                f"'<class>:<merged class>', got {raw_line!r}"
            )
        class_merge_dict[int(fields[0])] = int(fields[1])

    return class_merge_dict


def step2_submit(b):
    """Handle the step-2 "Next" click: preprocess the inputs and show step 3.

    Reads the map path, mask paths, target spatial reference and class
    merging from the step-2 widgets, creates the ``SamplingDesignPipeline``
    (writing to a ``sampling_design`` directory next to the map file), runs
    its preprocessing, stores per-region pixel counts and areas in
    ``results`` and creates one expected-accuracy input per merged class.
    Does nothing unless the wizard is currently on step 2.

    Args:
        b: The clicked button, as passed by ipywidgets (unused).
    """
    global step
    global results
    global sampling_design_pipeline
    global regions
    if step != 2:
        return

    loading2.value = "Preprocessing files. This might take a while..."

    results["map_path"] = map_path.value
    # Ignore blank lines (e.g. a trailing newline in the text area); otherwise
    # an empty path keyed by an empty name would be handed to the pipeline.
    mask_lines = (line.strip() for line in masks_path.value.splitlines())
    results["mask_paths"] = {Path(m).stem: m for m in mask_lines if m}
    results["target_spatial_ref"] = target_spatial_ref.value
    output_path = os.path.join(os.path.dirname(results["map_path"]), "sampling_design")
    sampling_design_pipeline = SamplingDesignPipeline(
        sampling_method=results["sampling_method"],
        output_path=output_path,
        use_cached=True)

    # TODO (kept from the original): need to add class merging
    # NOTE(review): the merge map is already passed to preprocess() below;
    # confirm whether this TODO is still open.
    results["class_merge_dict"] = get_class_merge_dict()
    results["classes"] = set(results["class_merge_dict"].values())

    regions = sampling_design_pipeline.preprocess(map_path=results["map_path"],
                                                  mask_paths=results["mask_paths"],
                                                  target_spatial_ref=results["target_spatial_ref"],
                                                  class_merge_map=results["class_merge_dict"])

    loading2.value = "Pixel counting..."
    results["pixel_counts"] = [region.get_pixel_counts_by_class() for region in regions]
    results["areas"] = [region.get_areas() for region in regions]

    loading2.value = ""

    # Sorted so the per-class inputs always appear in the same order.
    for class_id in sorted(results["classes"]):
        expected_accuracy_widgets.append(widgets.FloatText(description=f"{class_id}:", value=0.85))

    step3_ui()
    step += 1

step2_button.on_click(step2_submit)

# --- Step 3: Sampling Calculations ---
expected_accuracy_header = widgets.Label("Expected Users Accuracies (per class)")
target_error = widgets.FloatText(
    description="Target Error:",
    value=0.01,
    style=dict(description_width="initial"),
)
allocation_method = widgets.Dropdown(
    description="Allocation Method:",
    options=["Proportional", "Neyman"],
    style=dict(description_width="initial"),
)
allocate_sampled_header = widgets.Label("Allocate Samples")
allocate_button = widgets.Button(description="Allocate")
# Created hidden; no code in this cell switches the visibility back on.
expected_errors_header = widgets.Label("Expected Errors")
expected_errors_header.layout.visibility = "hidden"
expected_error_button = widgets.Button(
    description="Update Expected Error",
    style=dict(description_width="initial"),
)
sampling_button = widgets.Button(
    description="Run Sampling",
    style=dict(description_width="initial"),
)
loading3 = widgets.Label("")  # status line for step 3

def get_expected_uas():
    """Collect the expected accuracy entered for every class.

    Returns:
        dict: Maps each class id (an ``int`` parsed from the widget
        description, with the ``:`` removed) to the current value of its
        accuracy widget.
    """
    return {
        int(accuracy_input.description.replace(":", "")): accuracy_input.value
        for accuracy_input in expected_accuracy_widgets
    }

def get_updated_sampling_designs():
    """Read the (possibly edited) sample counts back from the step-3 widgets.

    ``samples_per_class_widgets`` is a flat list in which a ``Label`` holding
    a region name is followed by the per-class inputs of that region. The
    list is walked in order and every input is assigned to the most recently
    seen region label.

    Returns:
        dict: ``{region_name: {class_id: num_samples}}`` where ``class_id``
        is the ``int`` parsed from the input's description.
    """
    updated_sampling_designs = {}
    cur_region = None
    for widget in samples_per_class_widgets:
        # Use the public widget class with isinstance instead of comparing
        # against a private module path with an exact-type check.
        if isinstance(widget, widgets.Label):
            cur_region = widget.value
            continue

        class_id = int(widget.description.replace(":", ""))
        updated_sampling_designs.setdefault(cur_region, {})[class_id] = widget.value

    return updated_sampling_designs
    

def allocate_samples(b):
    """Handle the "Allocate" click: compute sampling designs and show them.

    Calls ``create_and_save_sampling_designs`` on the pipeline with the
    expected accuracies, target error and allocation method from the step-3
    widgets, stores both returned values in ``results`` and rebuilds the
    per-region sample-count inputs and expected-error labels before
    re-rendering the step-3 form.

    Args:
        b: The clicked button, as passed by ipywidgets (unused).
    """
    global regions, sampling_design_pipeline
    loading3.value = "Computing num samples and allocating..."
    sampling_designs, sampling_design_details = sampling_design_pipeline.create_and_save_sampling_designs(
        regions=regions,
        expected_uas=get_expected_uas(),
        target_error=target_error.value,
        allocation_method_name=allocation_method.value
    )
    results["sampling_designs"] = sampling_designs
    results["sampling_design_details"] = sampling_design_details

    # Start from empty lists so that clicking "Allocate" again replaces the
    # previous allocation instead of appending a duplicate set of widgets.
    # Cleared in place because other callbacks read these module-level lists.
    samples_per_class_widgets.clear()
    expected_error_displays.clear()

    for region_name, sampling_design in sampling_designs.items():
        # A region label followed by its per-class inputs; this layout is what
        # get_updated_sampling_designs relies on when reading values back.
        samples_per_class_widgets.append(widgets.Label(region_name))
        samples_per_class_widgets.extend(
            widgets.IntText(description=f"{class_id}:", value=val)
            for class_id, val in sampling_design.items()
        )

        # Placeholder text; update_expected_error fills in the value.
        expected_error_displays.append(widgets.Label(f"{region_name}: -"))

    clear_output(wait=True)
    # NOTE(review): expected_errors_header is created with visibility "hidden"
    # and nothing in this cell makes it visible again - confirm the intent.
    display(expected_accuracy_header, *expected_accuracy_widgets, target_error, allocation_method, allocate_button, *samples_per_class_widgets, expected_errors_header, expected_error_button, *expected_error_displays, sampling_button, loading3, out, reset_button)
    loading3.value = ""

def update_expected_error(b):
    """Handle the "Update Expected Error" click.

    Asks the pipeline for the expected target errors of the sampling designs
    currently entered in the widgets and writes one value per region into the
    matching ``"<region>: <value>"`` label.

    Args:
        b: The clicked button, as passed by ipywidgets (unused).
    """
    global regions, sampling_design_pipeline
    loading3.value = "Computing new expected errors..."
    expected_errors = sampling_design_pipeline.get_expected_target_errors(
        region_names=[region.name for region in regions],
        sampling_designs=get_updated_sampling_designs(),
    )

    for expected_error_display in expected_error_displays:
        # The label text is "<region>: <value>"; split on the LAST colon so a
        # region name that itself contains ":" is recovered intact.
        region_name = expected_error_display.value.rsplit(":", 1)[0]
        # ".4f" = four decimal places (the former "4f" only set a minimum width).
        expected_error = f'{float(expected_errors[region_name]):.4f}'
        expected_error_display.value = f'{region_name}: {expected_error}'
    loading3.value = ""

def run_sampling(b):
    """Handle the "Run Sampling" click: draw the samples and map them.

    Calls ``sample_and_save`` on the pipeline with the sampling designs
    currently entered in the widgets, concatenates the returned per-region
    sample sets and displays a folium map with one circle marker per sample
    (colored by ``stratum_id``) plus a color legend.

    Args:
        b: The clicked button, as passed by ipywidgets (unused).
    """
    global regions, results, sampling_design_pipeline
    loading3.value = "Sampling..."
    updated_sampling_designs = get_updated_sampling_designs()
    sample_sets = sampling_design_pipeline.sample_and_save(regions=regions, sampling_designs=updated_sampling_designs)
    loading3.value = "Sampling completed!"

    sample_set_all_regions = pd.concat(list(sample_sets.values()))

    def random_color():
        """Return a random ``#rrggbb`` color string."""
        return "#{:06x}".format(random.randint(0, 0xFFFFFF))

    unique_classes = sample_set_all_regions["stratum_id"].unique()
    color_map = {class_id: random_color() for class_id in unique_classes}

    # The masks are reprojected to EPSG:4326 (lon/lat), so the merged frame
    # must carry that CRS as well.
    gdfs = [gpd.read_file(fp).to_crs(epsg=4326) for fp in results["mask_paths"].values()]
    merged_gdf = gpd.GeoDataFrame(geometry=[gdf.geometry.unary_union for gdf in gdfs], crs="EPSG:4326")

    # Center on the middle of the bounding box of ALL masks. total_bounds is
    # (minx, miny, maxx, maxy), i.e. x = longitude and y = latitude here,
    # while folium expects the location as [latitude, longitude].
    min_lon, min_lat, max_lon, max_lat = merged_gdf.total_bounds
    map_center = [(min_lat + max_lat) / 2, (min_lon + max_lon) / 2]
    m = folium.Map(location=map_center, zoom_start=4, tiles="cartodbpositron")
    marker_cluster = MarkerCluster().add_to(m)

    for _, row in sample_set_all_regions.iterrows():
        stratum_color = color_map[row["stratum_id"]]
        folium.CircleMarker(
            location=[row["LAT"], row["LON"]],
            radius=6,
            color=stratum_color,
            fill=True,
            fill_color=stratum_color,
            fill_opacity=0.7,
            popup=f'Stratum ID: {row["stratum_id"]}'
        ).add_to(marker_cluster)

    legend_html = '''
    <div style="position: fixed; bottom: 50px; left: 50px; width: 200px; height: auto; background-color: white; 
        z-index:9999; font-size:14px; padding:10px; border-radius:5px;">
        <b>Legend</b><br>
    '''

    legend_html += "".join(
        f'<i class="fa fa-circle" style="color:{color}"></i> Stratum {class_id}<br>'
        for class_id, color in color_map.items()
    )

    legend_html += '</div>'

    m.get_root().html.add_child(folium.Element(legend_html))
    display(m)

    
# Wire the three step-3 buttons to their click handlers.
allocate_button.on_click(allocate_samples)
expected_error_button.on_click(update_expected_error)
sampling_button.on_click(run_sampling)

# --- Reset Button ---
def reset_all(b):
    """Handle the "Reset" click: discard all state and return to step 1.

    Resets the step counter, the collected ``results``, the status labels,
    the pipeline and regions, empties every dynamically built widget list
    and re-renders the step-1 form.

    Args:
        b: The clicked button, as passed by ipywidgets (unused).
    """
    global step, results, expected_accuracy_widgets, samples_per_class_widgets, sampling_design_pipeline, regions
    step = 1
    results = {}
    loading1.value = ""
    loading2.value = ""
    loading3.value = ""
    # Clear in place: the other callbacks hold on to these same list objects.
    expected_accuracy_widgets.clear()
    samples_per_class_widgets.clear()
    # Previously expected_accuracy_widgets was cleared twice and this list was
    # never emptied, so stale error labels survived a reset.
    expected_error_displays.clear()
    sampling_design_pipeline = None
    regions = None
    clear_output(wait=True)
    step1_ui()

reset_button = widgets.Button(description="Reset")
reset_button.on_click(reset_all)

def step1_ui():
    """Render step 1: sampling-method dropdown, "Next" button, status label and reset button."""
    clear_output(wait=True)
    step1_children = (sampling_method, step1_button, loading1, reset_button)
    display(*step1_children)

def step2_ui():
    """Render step 2: path inputs, class-merging text, "Next" button, status label and reset button."""
    clear_output(wait=True)
    step2_children = (
        map_path,
        masks_path,
        target_spatial_ref,
        class_merge,
        step2_button,
        loading2,
        reset_button,
    )
    display(*step2_children)

def step3_ui():
    """Render step 3 before any allocation: accuracy inputs, target error, allocation method and "Allocate" button."""
    clear_output(wait=True)
    display(
        expected_accuracy_header,
        *expected_accuracy_widgets,
        target_error,
        allocation_method,
        allocate_button,
        loading3,
        out,
        reset_button,
    )

# Initialize UI: render the step-1 form as soon as this cell is executed.
step1_ui()
