In [1]:
# General Imports
import json
import math
from functools import partial
from itertools import cycle
from pathlib import Path

import folium
import geopandas
import ipywidgets as widgets
import pandas as pd
from folium.plugins import BeautifyIcon
from IPython.display import clear_output, display

In [2]:
# Import Utility functions
def load_case_csv(
    filename: str | Path,
    use_filter: bool = True,
    filter_cols: bool = True,
    annot: dict[str, str] = None,
):
    """Load an NBN Atlas occurrence CSV export into a GeoDataFrame.

    Args:
        filename: Path to the CSV file.
        use_filter: When True, keep only rows whose
            "Identification verification status" contains "Accepted"; rows
            with a missing status are dropped.
        filter_cols: When True, keep only the columns in ``columns_keep``
            (those present in the file).
        annot: Optional constant-valued columns added to every row (e.g.
            ``{"name": ..., "dataset_type": ...}``). When it has no non-empty
            "name", the file stem is used. The caller's dict is not modified.

    Returns:
        A GeoDataFrame (CRS EPSG:4326) with point geometry built from the WGS84
        longitude/latitude columns (renamed to "Longitude"/"Latitude") and a
        "repr_point" column.
    """
    columns_keep = [
        "NBN Atlas record ID",
        "Occurrence ID",
        "Licence",
        "Scientific name",
        "Longitude (WGS84)",
        "Latitude (WGS84)",
        # "Longitude",
        # "Latitude",
        "Identification verification status",
        "Dataset name",
        "Dataset ID",
        "Data provider",
        "Data provider ID",
        "OSGR 100km",
        "OSGR 10km",
        "OSGR 2km",
        "OSGR 1km",
    ]
    result_df = pd.read_csv(filename)
    if filter_cols:
        result_df = result_df.filter(items=columns_keep)
    result_df = result_df.rename(
        columns={"Longitude (WGS84)": "Longitude", "Latitude (WGS84)": "Latitude"}
    )
    # The coordinates are WGS84 lon/lat, so declare the CRS explicitly. Without
    # it, spatial joins against the country outlines warn about a CRS mismatch
    # and re-projection with to_crs is impossible.
    result_df["geometry"] = geopandas.points_from_xy(
        result_df["Longitude"], result_df["Latitude"], crs="EPSG:4326"
    )

    if use_filter:
        # na=False: a missing status counts as "not accepted" instead of
        # producing a NaN mask, which .loc refuses to index with.
        accepted = result_df["Identification verification status"].str.contains(
            "Accepted", na=False
        )
        # copy() so the annotation columns below are written to a real frame,
        # not a slice (avoids SettingWithCopyWarning).
        result_df = result_df.loc[accepted, :].copy()
    if annot:
        annotations = dict(annot)  # never mutate the caller's dict
        if not annotations.get("name"):
            annotations["name"] = Path(filename).stem
        for k, v in annotations.items():
            result_df[k] = v
    result_df = geopandas.GeoDataFrame(result_df).set_geometry("geometry")
    result_df["repr_point"] = result_df.representative_point()
    return result_df


def load_geojson(filename: str | Path, filter_cols: bool = True, annot: dict[str, str] = None):
    """Load a GeoJSON file into a GeoDataFrame with plottable point columns.

    Adds a ``repr_point`` column (a point lying within each geometry) and
    ``Longitude``/``Latitude`` columns taken from it. This lets the result be
    plotted like the CSV data from ``load_case_csv``.

    Args:
        filename: Path to the GeoJSON file.
        filter_cols: When True, keep only the columns in ``columns_keep``
            (those present). The derived point columns are always kept,
            because the map code reads ``Latitude``/``Longitude``.
        annot: Optional constant-valued columns added to every row. When it
            has no non-empty "name", the file stem is used. The caller's dict
            is not modified.

    Returns:
        The loaded (and optionally column-filtered) GeoDataFrame.

    Raises:
        ValueError: if the file has no geometry column.
    """
    columns_keep = [
        "OBJECTID",
        "sphn_ref",
        "sphn_status",
        "pest_disease",
        "geometry",
        # Derived below; needed for plotting, as in load_case_csv.
        "repr_point",
        "Longitude",
        "Latitude",
    ]
    result_df = geopandas.read_file(filename)
    if "geometry" not in result_df.columns:
        raise ValueError("GeoJSON does not have a geometry column")
    repr_points = result_df.representative_point()
    result_df["repr_point"] = repr_points
    result_df["Longitude"] = repr_points.x
    result_df["Latitude"] = repr_points.y
    if filter_cols:
        result_df = result_df.filter(items=columns_keep)
    if annot:
        annotations = dict(annot)  # never mutate the caller's dict
        if not annotations.get("name"):
            annotations["name"] = Path(filename).stem
        for k, v in annotations.items():
            result_df[k] = v
    return result_df


def empty_list(alist):
    """Remove every element of ``alist`` in place and hand the same object back.

    Anything that is not a list (``None``, tuples, ...) is returned unchanged.
    """
    if not isinstance(alist, list):
        return alist
    alist.clear()
    return alist

In [3]:
# Class for Data File Descriptions
class DataFileDesc:
    """Widgets describing one input data file in the data-file form.

    Holds the following widgets:

    - the file path and a display name;
    - the import format ("NBN Atlas" CSV or "GeoJson");
    - whether to keep all columns;
    - the dataset type ("Cases"/"Trees");
    - a "Remove" button that calls ``remove_func(index, button)``.

    The label/text/input-type widgets and the Remove button carry a
    ``data_file_desc`` back-reference to this object.
    """

    def __init__(self, index: int, remove_func=None, data=None):
        """Create the widgets, optionally pre-filled from ``data``.

        Args:
            index: Position of this row in the form; passed to ``remove_func``.
            remove_func: Called as ``remove_func(index, button)`` when Remove is
                clicked. When None, the Remove button is disabled.
            data: Optional dict in the format returned by ``to_dict`` (e.g.
                read from a saved config). Dropdown values that are not valid
                options fall back to that dropdown's default.
        """
        self.index = index
        self.remove_func = remove_func
        # Handler currently registered on remove_btn; kept so update_index can
        # unregister exactly this handler before registering a new one.
        self._remove_handler = None
        self.label_fn = widgets.Label(value="DataFile Path")
        self.label_fn.data_file_desc = self
        self.filename = widgets.Text(
            placeholder="Input Path to datafile",
            description="",
            disabled=False,
        )
        self.filename.data_file_desc = self
        self.label_name = widgets.Label(value="High Level Name")
        self.label_name.data_file_desc = self
        self.name = widgets.Text(
            placeholder="Input Name to Use in plots tables etc",
            description="",
            disabled=False,
        )
        self.name.data_file_desc = self
        self.input_type = widgets.Dropdown(
            options=["NBN Atlas", "GeoJson"],
            value="NBN Atlas",
            description="Import:",
            disabled=False,
        )
        self.input_type.data_file_desc = self
        self.remove_btn = widgets.Button(
            description="Remove",
            disabled=self.remove_func is None,
            button_style="",
            tooltip="Remove DataFile",
            icon="remove",
        )
        self.keep_all_columns = widgets.Dropdown(
            options=["Yes", "No"],
            value="No",
            description="Keep Columns:",
            disabled=False,
        )
        self.dataset_type = widgets.Dropdown(
            options=["Cases", "Trees"], disabled=False, value="Trees", description="DataSet Type:"
        )

        self._register_remove_handler()
        self.remove_btn.data_file_desc = self
        if data:
            if "filename" in data:
                self.filename.value = data["filename"]
            if "name" in data:
                self.name.value = data["name"]
            if "input_type" in data:
                self._set_choice(self.input_type, data["input_type"], "NBN Atlas")
            if "keep_all_columns" in data:
                self._set_choice(self.keep_all_columns, data["keep_all_columns"], "No")
            if "dataset_type" in data:
                self._set_choice(self.dataset_type, data["dataset_type"], "Trees")

    @staticmethod
    def _set_choice(dropdown, value, fallback):
        """Set ``dropdown.value``, or ``fallback`` when ``value`` is rejected.

        Typically triggered by a stale or hand-edited config.
        """
        try:
            dropdown.value = value
        except Exception:  # the widget rejects values outside its options
            dropdown.value = fallback

    def _register_remove_handler(self):
        """(Re)bind the Remove button so it reports the current ``self.index``."""
        if self.remove_func is None:
            return
        if self._remove_handler is not None:
            self.remove_btn.on_click(self._remove_handler, remove=True)
        self._remove_handler = partial(self.remove_func, self.index)
        self.remove_btn.on_click(self._remove_handler)

    def _should_filter_columns(self) -> bool:
        """Map the "Keep Columns" dropdown to the loaders' ``filter_cols`` flag."""
        return self.keep_all_columns.value != "Yes"

    def get_widgets(self):
        """Return the row's widgets in display order."""
        return [
            self.label_fn,
            self.filename,
            self.label_name,
            self.name,
            self.input_type,
            self.keep_all_columns,
            self.dataset_type,
            self.remove_btn,
        ]

    def to_dict(self):
        """Serialise the current widget values (the format accepted as ``data``)."""
        return {
            "filename": self.filename.value,
            "name": self.name.value,
            "input_type": self.input_type.value,
            "keep_all_columns": self.keep_all_columns.value,
            "dataset_type": self.dataset_type.value,
        }

    def load_data(self):
        """Load the described file, annotated with its name and dataset type."""
        fn = Path(self.filename.value)
        extra_annotations = dict(name=self.name.value, dataset_type=self.dataset_type.value)
        if self.input_type.value == "GeoJson":
            # GeoJSON files are always loaded with all their columns.
            return load_geojson(fn, filter_cols=False, annot=extra_annotations)
        return load_case_csv(
            fn,
            use_filter=True,
            filter_cols=self._should_filter_columns(),
            annot=extra_annotations,
        )

    def update_index(self, indx: int):
        """Record a new row position and rebind the Remove button to it."""
        self.index = indx
        self._register_remove_handler()

In [4]:
# global variables
class GlobalVars:
    """Mutable state shared between the notebook's widget callbacks."""

    def __init__(self):
        # One DataFileDesc per row of the data-file form.
        self.data_file_boxes: list[DataFileDesc] = []
        # Frames produced by the most recent successful load, one per data file.
        self.g_all_dfs: list[pd.DataFrame] = []
        # Concatenation of g_all_dfs; None until data has been loaded.
        self.merged_df: pd.DataFrame | None = None
        # Country outlines keyed by name; filled in by the country-loading cell.
        self.countries_bbox: dict[str, pd.DataFrame] = {}


globalvars = GlobalVars()

In [6]:
# Output area that function_show_list re-renders with the data-file form.
out = widgets.Output()
display(out)

btn_add = widgets.Button(
    description="Add New Data File",
    disabled=False,
    button_style="",
    tooltip="Add DataFile",
    # Must be a Font Awesome icon name; "add" does not exist and rendered blank.
    icon="plus",
)

btn_load = widgets.Button(
    description="Load Current Data Files",
    disabled=False,
    button_style="",
    tooltip="Load current data",
    # "load" is not a Font Awesome icon name.
    icon="download",
)


def save_config(text_box, gvars, btn):
    """Button callback: write the data-file form to a JSON config file.

    The path is read from ``text_box.value``. The file holds a list of
    ``DataFileDesc.to_dict()`` entries. ``btn`` is the clicked button (unused).

    Success and failure are reported through ``function_show_list``, so the
    message appears inside the form's Output widget. Output emitted by a
    widget callback outside an Output widget is not shown in the cell.
    """
    path = Path(text_box.value)
    out_list = [dfb.to_dict() for dfb in gvars.data_file_boxes]
    try:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(out_list, f)
    except (OSError, TypeError, ValueError) as e:
        function_show_list(gvars, f"there was an error {e}")
        return
    function_show_list(gvars, f"Config was stored to {path}")


def load_config(text_box, gvars, btn):
    """Button callback: rebuild the data-file form from a JSON config file.

    The path is read from ``text_box.value`` and must contain a non-empty list
    of dicts as written by ``save_config``. ``btn`` is the clicked button
    (unused).

    Every outcome is reported through ``function_show_list``, so messages
    appear in the form's Output widget. When the file is missing, unreadable
    or invalid, the existing rows are left unchanged.
    """
    path = Path(text_box.value)
    if not path.is_file():
        function_show_list(gvars, f"Path: {path} does not exist or is not a file, nothing to load")
        return
    try:
        with open(path, "r", encoding="utf-8") as f:
            json_dicts = json.load(f)
    except (OSError, ValueError) as e:  # json.JSONDecodeError is a ValueError
        function_show_list(gvars, f"Could not read config {path}: {e}")
        return
    if not isinstance(json_dicts, list):
        function_show_list(gvars, f"File: {path} does not contain a list of data files")
        return
    if len(json_dicts) == 0:
        function_show_list(gvars, f"File: {path} is empty nothing to load")
        return
    empty_list(gvars.data_file_boxes)
    for i, json_dict in enumerate(json_dicts):
        gvars.data_file_boxes.append(
            DataFileDesc(index=i, remove_func=partial(remove_filebox, gvars), data=json_dict)
        )
    function_show_list(gvars, f"Config was loaded from {path} {len(gvars.data_file_boxes)}")


# Config bar: a path box plus save/load buttons for the data-file form.
text_box = widgets.Text(
    placeholder="Input Path to config file",
    description="Path to Config",
    disabled=False,
)

btn_save_config = widgets.Button(
    description="Store Config",
    disabled=False,
    button_style="",
    tooltip="Store current file config",
    icon="save",
)

btn_save_config.on_click(partial(save_config, text_box, globalvars))

btn_load_config = widgets.Button(
    description="Load Config",
    disabled=False,
    button_style="",
    tooltip="Load current file config",
    # "open" is not a Font Awesome icon name.
    icon="folder-open",
)
btn_load_config.on_click(partial(load_config, text_box, globalvars))
save_load_hbox = widgets.HBox([text_box, btn_save_config, btn_load_config])


def load_data(gvars, btn):
    """Button callback: load every configured data file and merge them.

    On success, ``gvars.g_all_dfs`` holds one frame per data file and
    ``gvars.merged_df`` holds their concatenation. When there are no data
    files, or any file fails to load, a message is shown in the form output
    and the previously loaded data is kept. ``btn`` is unused.
    """
    if not gvars.data_file_boxes:
        # pd.concat([]) would raise; report instead.
        function_show_list(gvars, msg="No data files configured, nothing to load")
        return
    frames = []
    for dfb in gvars.data_file_boxes:
        try:
            frames.append(dfb.load_data())
        except Exception as e:  # UI boundary: surface any read/parse error
            function_show_list(gvars, msg=f"Failed to load {dfb.filename.value!r}: {e}")
            return
    gvars.g_all_dfs = frames
    gvars.merged_df = pd.concat(frames)
    function_show_list(
        gvars, msg=f"All data are loaded {len(gvars.merged_df)} {len(gvars.g_all_dfs)}"
    )


btn_load.on_click(partial(load_data, globalvars))


def add_filebox(gvars: GlobalVars, btn=None):
    """Append an empty data-file row to the form and re-render it.

    Registered via ``on_click``, which calls the handler with the clicked
    button as an extra positional argument. ``btn`` accepts that argument
    (unused); the function can still be called with ``gvars`` alone.
    """
    index = len(gvars.data_file_boxes)
    new_dfb = DataFileDesc(index=index, remove_func=partial(remove_filebox, gvars))
    gvars.data_file_boxes.append(new_dfb)
    function_show_list(gvars)


btn_add.on_click(partial(add_filebox, globalvars))


def remove_filebox(gvars, index: int, _):
    """Remove the data-file row at ``index`` and renumber the remaining rows.

    Bound by ``DataFileDesc`` as ``remove_func(index, button)`` through
    ``partial(remove_filebox, gvars)``. The trailing argument is the clicked
    button (unused).
    """
    boxes = gvars.data_file_boxes
    # Explicit check instead of assert (stripped under -O; allowed negative indices).
    if not 0 <= index < len(boxes):
        function_show_list(gvars, f"Cannot remove data file {index}: no such entry")
        return
    boxes.pop(index)
    for i, box in enumerate(boxes):
        box.update_index(i)
    function_show_list(gvars)


def function_show_list(gvars, msg=None):
    """Re-render the data-file form into the ``out`` Output widget.

    The form consists of the config bar, one row per data file and the
    add/load buttons. They are followed by ``msg`` (when truthy) and a status
    line with the merged-row, data-file and loaded-frame counts.
    """
    merged_rows = 0 if gvars.merged_df is None else len(gvars.merged_df)
    file_rows = widgets.VBox([widgets.HBox(box.get_widgets()) for box in gvars.data_file_boxes])
    action_bar = widgets.HBox([btn_add, btn_load])
    with out:
        clear_output(wait=True)
        display(save_load_hbox)
        display(file_rows)
        display(action_bar)
        if msg:
            display(msg)
        display("show", merged_rows, len(gvars.data_file_boxes), len(gvars.g_all_dfs))


function_show_list(globalvars)

Output()

In [8]:
# Directory holding one <country>.geojson outline per selectable country.
COUNTRIES_DIR = Path("countries")


def load_countries(main_path: Path):
    """Read every ``*.geojson`` file directly inside ``main_path``.

    Returns a dict mapping file stem -> GeoDataFrame. Files are read in sorted
    order, so the dict (and the country selector built from its keys) has a
    stable order; ``Path.iterdir`` order is arbitrary.
    """
    return {
        p.stem: geopandas.read_file(p)
        for p in sorted(main_path.iterdir())
        if p.suffix == ".geojson"
    }


globalvars.countries_bbox = load_countries(COUNTRIES_DIR)

In [9]:
# Visualize everything from the supported countries
# Marker colours, cycled per dataset layer. They are passed to BeautifyIcon as
# CSS colours, so every entry must be a valid CSS colour name.
# ("lightred"/"darkpurple" are folium.Icon names but not CSS, so those markers
# had no fill.)
COLORS = [
    "red",
    "blue",
    "green",
    "purple",
    "orange",
    "darkred",
    "lightcoral",
    "beige",
    "darkblue",
    "darkgreen",
    "cadetblue",
    "indigo",
    "white",
    "pink",
    "lightblue",
    "lightgreen",
    "gray",
    "black",
    "lightgray",
]


def generate_map_points(df, m):
    """Draw one toggleable marker layer per dataset ``name`` on map ``m``.

    Each row becomes a small circle-dot marker at (``Latitude``, ``Longitude``),
    coloured per dataset by cycling through ``COLORS``. A LayerControl is
    added so the layers can be toggled.

    Args:
        df: Frame with ``name``, ``Latitude`` and ``Longitude`` columns.
        m: folium map to draw on.

    Returns:
        ``m``.
    """
    names = list(df.name.unique())
    layers = {
        name: folium.FeatureGroup(name=name, control=True).add_to(m) for name in names
    }
    layers_colors = dict(zip(names, cycle(COLORS)))
    for name, group_df in df.groupby("name"):
        color = layers_colors[name]
        layer = layers[name]
        for lat, lon in zip(group_df["Latitude"], group_df["Longitude"]):
            folium.Marker(
                location=[lat, lon],
                tooltip=name,
                popup=name,
                icon=BeautifyIcon(
                    icon_shape="circle-dot",
                    shadow_size=(0, 0),
                    background_color=color,
                    border_color=color,
                    iconSize=[1, 1],
                ),
            ).add_to(layer)
    folium.LayerControl().add_to(m)
    return m

In [10]:
def filter_by_country(df, countries, bboxes):
    """Keep the rows of ``df`` that fall within any selected country outline.

    Args:
        df: GeoDataFrame of observations.
        countries: Collection of selected country names (keys of ``bboxes``).
        bboxes: Mapping of country name -> GeoDataFrame outline.

    Returns:
        ``geopandas.sjoin(df, <selected outlines>)`` with default settings, or
        an empty slice of ``df`` when no key of ``bboxes`` is in ``countries``.
    """
    selected = [country_bb for country, country_bb in bboxes.items() if country in countries]
    if not selected:
        # pd.concat([]) raises ValueError; an empty selection matches nothing.
        return df.iloc[:0]
    return geopandas.sjoin(df, pd.concat(selected))

In [12]:
default_choice = "Scotland"
out2 = widgets.Output()
display(out2)
choices = list(globalvars.countries_bbox.keys())
country_select = widgets.SelectMultiple(
    options=choices,
    # A value outside `options` is rejected by the widget. Only preselect the
    # default when its outline was actually loaded from the countries directory.
    value=[default_choice] if default_choice in choices else [],
    description="Countries to show data from",
    disabled=False,
)
btn_vis_all = widgets.Button(
    description="Visualize Countries",
    button_style="",
    # "show" is not a Font Awesome icon name.
    icon="eye",
    tooltip="Visualize the data",
)
hbox_vis = widgets.HBox([country_select, btn_vis_all])


def init_map(
    loc=(54.0694407363737476, -3.846893832385978), zoom_start=6, min_zoom=6, control_scale=True
):
    """Create an empty folium map.

    Args:
        loc: (latitude, longitude) of the initial centre. The default is a
            tuple rather than a list, so no mutable default object is shared
            between calls.
        zoom_start: Initial zoom level.
        min_zoom: Minimum zoom level the user can zoom out to.
        control_scale: Whether to show a scale bar.

    Returns:
        The new ``folium.Map``.
    """
    return folium.Map(
        loc,
        zoom_start=zoom_start,
        min_zoom=min_zoom,
        control_scale=control_scale,
    )


def update_output(out, hbox_vis, m=None):
    """Replace the contents of ``out`` with the country controls and a map.

    The map is embedded as an HTML widget so it can sit in a VBox below
    ``hbox_vis``. When ``m`` is None, a fresh default map from ``init_map`` is
    shown.
    """
    with out:
        # Anything displayed before this would be cleared immediately anyway.
        clear_output(wait=True)
        mm: folium.Map = m if m is not None else init_map()
        html = widgets.HTML(mm._repr_html_())
        display(widgets.VBox([hbox_vis, html]))


def visualize_data(out, hbox, countries, gvars, button):
    """Button callback: map the loaded data for the selected countries.

    Args:
        out: Output widget that receives the controls and the map.
        hbox: Controls displayed above the map.
        countries: Selection widget whose ``value`` holds the chosen country names.
        gvars: Shared GlobalVars (reads ``merged_df`` and ``countries_bbox``).
        button: The clicked button (unused).
    """
    if gvars.merged_df is None:
        # Nothing to join against yet; tell the user instead of failing silently.
        with out:
            display("No data loaded yet - load the data files before visualizing")
        return
    all_selected = filter_by_country(gvars.merged_df, countries.value, gvars.countries_bbox)
    m = generate_map_points(all_selected, init_map())
    update_output(out, hbox, m)


# btn_vis_all is created fresh in this cell, so there are no stale handlers to clear.
btn_vis_all.on_click(partial(visualize_data, out2, hbox_vis, country_select, globalvars))
update_output(out2, hbox_vis)

Output()

In [13]:
# Do the case thing
# dataset_types = list(merged_df.dataset_type.unique())

# dataset_types_to_names = {}
# for data_type in dataset_types:
#     names_for_type = list(merged_df[merged_df.dataset_type == data_type].name.unique())
#     dataset_types_to_names[data_type] = names_for_type


# def generate_vbox(data_type, options):
#     label = widgets.Label(value=f"Options for {data_type}")
#     options_select = widgets.SelectMultiple(options=options, value=[options[0]], disabled=False)
#     return widgets.VBox([label, options_select])


# compute_sel_vboxes = {
#     data_type: generate_vbox(data_type, dataset_types_to_names[data_type])
#     for data_type in ["Cases", "Trees"]
#     if data_type in dataset_types
# }
# compute_sel_vboxes["Country"] = generate_vbox("Country", list(countries_bbox.keys()))
# compute_sel_hbox = widgets.HBox(list(compute_sel_vboxes.values()))
# display(compute_sel_hbox)

In [3]:
# out = widgets.Output()
# display(out)
# btn_compute = widgets.Button(
#     button_style="",
#     icon="cog",
#     description="Compute",
#     tooltip="Compute nearest point per case according to the distance",
# )
# lbl_distance = widgets.Label(value="Distance in km")
# txt_distance = widgets.FloatText(
#     value=5,
#     description="",
#     disabled=False,
# )
# cmap = init_map()
# hbox = widgets.HBox([lbl_distance, txt_distance, btn_compute])


# def find_closest_tree_per_case(cases, trees, countries, map_update, dist, df):
#     display("finding closest tree")
#     country_df = filter_by_country(df, countries.value, countries_bbox)
#     display(len(country_df))
#     display(dist)
#     # country_df.to_crs("epsg:32633",inplace=True)
#     col_renames = {"index_left": "other_index_left", "index_right": "other_index_right"}
#     cases_df = country_df[country_df.name.isin(cases.value)].rename(columns=col_renames)
#     trees_df = country_df[country_df.name.isin(trees.value)].rename(columns=col_renames)
#     display('cases', len(cases_df))
#     display('trees', len(trees_df))
#     cases_df.to_crs("epsg:32633", inplace=True)
#     trees_df.to_crs("epsg:32633", inplace=True)
#     joined = geopandas.sjoin_nearest(cases_df, trees_df, max_distance=dist, how="left")
#     display("joined", len(joined))
#     color_gen = cycle(COLORS)
#     feature_groups = {case: folium.FeatureGroup(name=case, control=True) for case in cases.value}
#     sub_groups = {}
#     sub_group_colors = {}
#     for (name_case, lat, long), group in joined.groupby(
#         ["name_left", "Latitude_left", "Longitude_left"]
#     ):
#         tree_names = sorted(list(group.name_right.unique()))
#         if len(tree_names) == 0:
#             continue
#         elif len(tree_names) == 1 and tree_names[0] is not None:
#             color = next(color_gen)
#             name = f"{name_case}"
#             if isinstance(tree_names[0], str):
#                 name += f" closest {tree_names[0]}"
#             if name in sub_group_colors:
#                 color = sub_group_colors[name]
#             else:
#                 color = next(color_gen)
#                 sub_group_colors[name] = color
#             sub_group = sub_groups.get(name,folium.plugins.FeatureGroupSubGroup(feature_groups[name_case], name))
#             if name not in sub_groups:
#                 sub_groups[name] = sub_group
#             group.apply(
#                 lambda row: folium.Marker(
#                     location=[row["Latitude_left"], row["Longitude_left"]],
#                     tooltip=name,
#                     popup=name,
#                     icon=BeautifyIcon(
#                         icon_shape="circle-dot",
#                         shadow_size=(0, 0),
#                         # icon='map-marker',
#                         background_color=color,
#                         border_color=color,
#                         number="",
#                         iconSize=[1, 1],
#                     ),
#                 ).add_to(sub_group),
#                 axis=1,
#             )
#         elif len(tree_names) > 1:
#             # for tree_name in tree_names:
#             color = next(color_gen)
#             name = f"{name_case} closest {','.join(tree_names)}"
#             if name in sub_group_colors:
#                 color = sub_group_colors[name]
#             else:
#                 color = next(color_gen)
#                 sub_group_colors[name] = color
#             sub_group = sub_group = sub_groups.get(name,folium.plugins.FeatureGroupSubGroup(feature_groups[name_case], name))
#             if name not in sub_groups:
#                 sub_groups[name] = sub_group
#             group.apply(
#                 lambda row: folium.Marker(
#                     location=[row["Latitude_left"], row["Longitude_left"]],
#                     tooltip=name,
#                     popup=name,
#                     icon=BeautifyIcon(
#                         icon_shape="circle-dot",
#                         shadow_size=(0, 0),
#                         # icon='map-marker',
#                         background_color=color,
#                         border_color=color,
#                         number="",
#                         iconSize=[1, 1],
#                     ),
#                 ).add_to(sub_group),
#                 axis=1,
#             )
#         else:
#             assert False, f"Unexpected group results case: {name_case}, names: {tree_names}"
#     for fgg in feature_groups.values():
#         map_update.add_child(fgg)
#     for sub_group in sub_groups.values():
#         map_update.add_child(sub_group)
#     display(sub_groups)
#     display(feature_groups)
#     folium.LayerControl(collapsed=False).add_to(map_update)
#     return map_update


# def calculate_closest_dist(
#     hbox, compute_map, case_select, tree_select, country_select, dist, df, btn=None
# ):
#     compute_map = init_map()
#     compute_map = find_closest_tree_per_case(
#         case_select, tree_select, country_select, compute_map, dist.value * 1000, df
#     )
#     update_compute_output(out, hbox, compute_map)


# btn_compute.on_click(
#     partial(
#         calculate_closest_dist,
#         hbox,
#         cmap,
#         compute_sel_vboxes["Cases"].children[1],
#         compute_sel_vboxes["Trees"].children[1],
#         compute_sel_vboxes["Country"].children[1],
#         txt_distance,
#         merged_df,
#     )
# )


# def update_compute_output(out, hbox, cm=None):
#     with out:
#         clear_output(wait=True)
#         display(widgets.VBox([hbox]))
#         display(cm)


# update_compute_output(out, hbox, cmap)

In [None]:
# larch = merged_df[merged_df.name == "Larch"].copy()
# other = merged_df[merged_df.name != "Larch"].copy()
# larch.geometry = larch.representative_point()
# other.geometry = other.representative_point()
# larch.to_crs("epsg:32633", inplace=True)
# other.to_crs("epsg:32633", inplace=True)
# nearest = geopandas.sjoin_nearest(other, larch, max_distance=5000, distance_col="distance")
# nearest.to_crs(crs=4326,inplace=True)
# display(compute_sel_vboxes["Cases"].children[1].value)
# display(compute_sel_vboxes["Trees"].children[1].value)
# display(compute_sel_vboxes["Country"].children[1].value)