In [None]:
from unbiased_area_estimation.utils import read_config
from unbiased_area_estimation.unbiased_estimation import AreaEstimator
from unbiased_area_estimation.sampling_allocation import create_sample_allocator
import pandas as pd
from IPython.display import display, HTML
import ipywidgets as widgets

CONFIG_PATH = 'config.json' # Set the path to your config file here
config = read_config(CONFIG_PATH)
# Optional config flag: set "overwrite_existing": true in the config file to
# allow existing files to be overwritten (defaults to false).


# No need to change anything below this line usually
sample_allocator = create_sample_allocator(config['sampling'])

area_estimator = AreaEstimator(
    raster_path=config['raster_path'],
    mask_paths=config['mask_paths'],
    class_merge_dict=config['class_merge_dict'],
    epsg=config['epsg'],
    sample_allocator=sample_allocator,
    # output_dir is passed as temp_dir only when keep_preprocessed_files is set;
    # otherwise None (presumably AreaEstimator then picks a throwaway location —
    # confirm in AreaEstimator).
    temp_dir=config['output_dir'] if config['keep_preprocessed_files'] else None,
    output_dir=config['output_dir'],
    # Previously hard-coded to False; now an optional config key with the same
    # default, so existing config files behave exactly as before.
    overwrite_existing=config.get('overwrite_existing', False),
)

area_estimator.preprocess_files()

# Registry of the editable per-class inputs: {design key: {class key: widget}}.
# get_updated_values() reads the current counts back from here.
input_widgets_nsamples = {}
# Shared output area used by the "Create Samples" button handler.
output_widget = widgets.Output()
# Nested mapping {design key: {class key: suggested sample count}}, as consumed below.
sampling_designs = area_estimator.create_sample_designs()

# Default upper bound for each editable sample-count field.
MAX_SAMPLES_PER_CLASS = 10000

# Create input fields: one bold heading per design, one numeric field per class.
for key, sub_dict in sampling_designs.items():
    input_widgets_nsamples[key] = {}
    display(widgets.HTML(f"<b>{key}</b>"))

    for class_key, value in sub_dict.items():
        input_field = widgets.BoundedFloatText(
            value=value,
            min=0,
            # Bounded widgets silently clamp `value` into [min, max]; raise the
            # bound when the suggested count exceeds the default so the
            # suggestion is shown (and used) unmodified.
            max=max(MAX_SAMPLES_PER_CLASS, value),
            step=1,
            description=f'Class {class_key}',
        )
        input_widgets_nsamples[key][class_key] = input_field
        display(input_field)

def get_updated_values():
    """Return the sample counts currently entered in the input widgets.

    Reads every widget registered in the module-level
    ``input_widgets_nsamples`` mapping and returns a nested dict with the
    same keys, ``{design key: {class key: widget.value}}``.
    """
    return {
        design_key: {
            class_id: widget.value for class_id, widget in class_widgets.items()
        }
        for design_key, class_widgets in input_widgets_nsamples.items()
    }

def handle_create_samples_button_click(b):
    """Create and save sample sets for the counts currently entered.

    Registered with ``on_click`` on the "Create Samples" button. It reads the
    per-class counts via ``get_updated_values()`` and has ``area_estimator``
    create the sample sets. It saves them with ``single_file=True`` and
    reports the returned output location in ``output_widget``.

    Args:
        b: The clicked button, as passed by ``Button.on_click``. It is
            disabled while sampling runs, so repeated clicks cannot queue
            duplicate runs. It is re-enabled afterwards, even on failure.
    """
    # Local import: the name `HTML` in this file is IPython's display class.
    from html import escape

    b.disabled = True
    try:
        with output_widget:
            output_widget.clear_output()
            display(HTML('<p style="color:green; font-weight:bold; font-family: Lato, sans-serif;">Sampling... Please wait.</p>'))
            try:
                sample_sets = area_estimator.create_sample_sets(get_updated_values())
                output_dir = area_estimator.save_sample_sets(sample_sets=sample_sets, single_file=True)
            except Exception:
                # Replace the stale "please wait" notice, then re-raise so the
                # Output widget context still displays the traceback.
                output_widget.clear_output()
                display(HTML('<p style="color:red; font-weight:bold; font-family: Lato, sans-serif;">Sampling failed, see the error below.</p>'))
                raise
            output_widget.clear_output()
            # Escape the path: it is interpolated into HTML markup.
            display(HTML(f'<p style="color:green; font-weight:bold; font-family: Lato, sans-serif;">Samples created and saved successfully to {escape(str(output_dir))}!</p>'))
    finally:
        b.disabled = False

create_samples_button = widgets.Button(description="Create Samples")
create_samples_button.on_click(handle_create_samples_button_click)
display(create_samples_button, output_widget)