In [None]:
from sedona.spark import *
from contextlib import contextmanager
from pyspark.storagelevel import StorageLevel
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
import time
import ipywidgets as widgets
from IPython.display import display, clear_output
import sys
import io
import pandas as pd
import geopandas as gpd
import fiona

import ipywidgets as widgets
from ipywidgets import Button, Layout
from IPython.display import *
from io import StringIO
import sys


class Enricher:

    def __init__(self, crs="EPSG:3035"):
        self.crs = crs if type(crs) is not int else f"EPSG:{crs}"
        self.cores = None
        self.res = None
        self.sedona = None
        self.res_agr = None
        self.df1 = None
        self.df2 = None
        self.dfs_list = {}

            
    def setup_cluster(self, which="wherobots", ex_mem=26, dr_mem=24, log_level=None):
        if which == "wherobots":
            config = SedonaContext.builder().getOrCreate()
            self.sedona = SedonaContext.create(config)
            
            self.cores = self.sedona.sparkContext.defaultParallelism
            print(f"Wherobots setup started with {self.cores} cores for parellelism.")
        elif which == "sedona":
            # config = SedonaContext.builder() .\
            #     config("spark.executor.memory", f"{ex_mem}g").\
            #     config("spark.driver.memory", f"{dr_mem}g").\
            #     config('spark.jars.packages',\
            #         'org.apache.sedona:sedona-spark-shaded-3.5_2.12:1.7.0,'\
            #         'org.datasyslab:geotools-wrapper:1.7.0-28.5').\
            #     getOrCreate()

            config = SedonaContext.builder() \
                .config("spark.executor.memory", f"{ex_mem}g") \
                .config("spark.driver.memory", f"{dr_mem}g") \
                .config("spark.local.dir", "./tmp_spark_spills") \
                .config("spark.driver.maxResultSize", "4g") \
                .config('spark.jars.packages',
                    'org.apache.sedona:sedona-spark-shaded-3.5_2.12:1.7.0,'
                    'org.datasyslab:geotools-wrapper:1.7.0-28.5') \
                .getOrCreate()


            self.sedona = SedonaContext.create(config)
            
            if log_level in ["OFF", "ERROR", "WARN", "INFO", "DEBUG"]:
                self.sedona.sparkContext.setLogLevel(log_level)
                
            self.cores = self.sedona.sparkContext.defaultParallelism
            print(f"Sedona initialized with {self.cores} cores for parellelism.")
        else:
            raise ValueError("Invalid 'which'. Choose either 'wherobots' or 'sedona'")
        
    @contextmanager
    def get_time(self, task_name):
        start = time.time()
        yield
        elapsed = time.time() - start
    
        print(f"{task_name}... DONE in {(elapsed/60):.2f} min" \
              if elapsed >= 60 else f"{task_name}... DONE in {elapsed:.2f} sec")

    
    def load(self, datasets, silent=True):
        print("\nLoading datasets...")
        print(f"Make sure the geometry column is named \"geometry\" in the datasets")
    
        self.datasets = {}
        for name, (path, fformat) in datasets.items():
            if fformat == "geoparquet":
                gdf = gpd.read_parquet(path)
                crs = f"EPSG:{gdf.crs.to_epsg()}"
                print(f"Loaded '{name}': {gdf.shape}, '{crs}'")
            elif fformat == "geopackage":
                layers = fiona.listlayers(path)
                if layers:
                    self.dfs_list[name] = self.sedona.read.format(fformat).option("tableName", layers[0]).load(path)
                    gdf = gpd.read_file(f'{path}', engine='pyogrio', use_arrow=True)
                    crs = gdf.crs
                    print(f"Loaded '{name}': {gdf.shape}, '{crs}'")
                else:
                    print(f"No layers found in GeoPackage '{name}'")
            else:
                gdf = gpd.read_file(path)
                crs = gdf.crs
                self.dfs_list[name] = self.sedona.read.format(fformat).load(path)
                print(f"Loaded '{name}': {gdf.shape}, '{crs}'")
            
            self.datasets[name] = (path, fformat, crs)

            
        print(f"{len(self.dfs_list)} datasets loaded. \n")
        
        if not silent:
            for name, df in self.dfs_list.items():
                print(f"\n Dataset: \"{name}\", count: {df.count()}")
    
                geometry_types = df.select(F.expr("ST_GeometryType(geometry)")).distinct().collect()
                
                res_string = [
                    f"{geometry_type} ({(df.filter(F.expr(f'ST_GeometryType(geometry) = \'{geometry_type}\'')).count() / df.count()) * 100:.2f}%)"
                    for row in geometry_types if (geometry_type := row[0])
                ]
                
                print(f"\"{name}\" has geometries of type(s): {', '.join(res_string)}")
                df.printSchema()

    def force_repartition(self, skip=[]):
        for name, df in self.dfs_list.items():
            if name not in skip:
                self.dfs_list[name] = df.repartition(self.cores)
    
    def inspect_partitions(self):
        for name, df in self.dfs_list.items():
            print(f"'{name}' partitions: {df.rdd.getNumPartitions()}")
            print(f"'{name}' distribution: {df.rdd.glom().map(len).collect()}")

    
    def transform(self, target=None, lazy=True):
        if target is None:
            target = self.crs
        elif type(target) is int:
            self.crs = f"EPSG:{target}"
            target = self.crs
        else:
            self.crs = target
        
        print()
        print("Transforming CRS...")
        for name, df in self.dfs_list.items():
            df = df.withColumn("geometry", F.expr(f"ST_Transform(geometry, '{self.datasets[name][2]}', '{target}')"))
            print(f"Changed CRS of '{name}': '{self.datasets[name][2]}' to '{target}'")
            self.dfs_list[name] = df
            
            if not lazy:
                self._make_cache(self.dfs_list.values())
            
    def fix_geometries(self):
        for name, df in self.dfs_list.items():
            invalid_count = df.filter(F.expr("NOT ST_IsValid(geometry)")).count()
            print(f"'{name}' has {((invalid_count / df.count()) * 100 if df.count() > 0 else 0):.2f}% invalid geometries.")
            
            if invalid_count > 0:
                df = df.withColumn("geometry", F.expr("ST_MakeValid(geometry)"))
                print(f"Fixed {invalid_count} geometries in '{name}'")
            else:
                print(f"Nothing to fix in '{name}'")
            
            self.dfs_list[name] = df

    
    def _make_cache(self, dfs=[]):
        for df in dfs:
            if isinstance(df, DataFrame):
                if df.storageLevel != StorageLevel.NONE:
                    df.unpersist()
                df.cache()
            # print(f"Dataset cached. {df.count()} rows.")

    
    def clear_memory(self, *keep):
        if self.res.storageLevel != StorageLevel.NONE:
            self.res.unpersist()
        self.res = None

    def join_chey_simple(self, selected_aggs, df1_name, df2_name):
        self.res_agr = self.dfs_list[df1_name].alias("df1").join(
            self.dfs_list[df2_name].alias("df2"), F.expr("ST_Intersects(df1.geometry, df2.geometry)")
        ).select(
            F.expr("df1.geometry").alias("df1_geom"),
            F.expr("df2.geometry").alias("df2_geom"),
            *[f"df1.{c}" for c in self.dfs_list[df1_name].columns if c != "geometry"],
            *[f"df2.{c}" for c in self.dfs_list[df2_name].columns if c != "geometry" and c not in self.dfs_list[df1_name].columns]
        )
    
    def join_chey_new(self, selected_aggs, df1_name, df2_name, group_by=None, pred="ST_Intersects", rel_str="2********", make_geom=True, ratio=True, madre=False, cache=True, grid_area=None):
        
        # self.df1 = self.dfs_list[df1_name]
        # self.df2 = self.dfs_list[df2_name]
    
        if self.res_agr is None:
            join_expr = f"{pred}(df1.geometry, df2.geometry)"
            if pred == "ST_Relate":
                join_expr = f"{pred}(df1.geometry, df2.geometry, '{rel_str}')"
        
            self.res = self.dfs_list[df1_name].alias("df1").join(
                self.dfs_list[df2_name].alias("df2"), F.expr(join_expr)
            ).select(
                F.expr("df1.geometry").alias("df1_geom"),
                F.expr("df2.geometry").alias("df2_geom"),
                *[f"df1.{c}" for c in self.dfs_list[df1_name].columns if c != "geometry"],
                *[f"df2.{c}" for c in self.dfs_list[df2_name].columns if c != "geometry" and c not in self.dfs_list[df1_name].columns]
            )

            
            self.res = self.res.withColumn("intr_geometry", F.expr("ST_Intersection(df1_geom, df2_geom)"))
            if ratio:
                if grid_area > 0:
                    self.res = self.res.withColumn("intr_ratio", F.expr(f"ST_Area(intr_geometry) / {grid_area}"))
                else:
                    self.res = self.res.withColumn("intr_ratio", F.expr("ST_Area(intr_geometry) / ST_Area(df2_geom)"))

                agg_exprs = []
                for col_name, agg_func in selected_aggs.items():
                    if agg_func == "sum":
                        agg_exprs.append(F.sum(F.col(col_name) * F.col("intr_ratio")).alias(f"{col_name}_agr_{agg_func}"))
                    elif agg_func == "mean":
                        agg_exprs.append(F.mean(F.col(col_name) * F.col("intr_ratio")).alias(f"{col_name}_agr_{agg_func}"))
                    elif agg_func == "min":
                        agg_exprs.append(F.min(F.col(col_name) * F.col("intr_ratio")).alias(f"{col_name}_agr_{agg_func}"))
                    elif agg_func == "max":
                        agg_exprs.append(F.max(F.col(col_name) * F.col("intr_ratio")).alias(f"{col_name}_agr_{agg_func}"))
                    elif agg_func == "count":
                        agg_exprs.append(F.count(F.col(col_name)).alias(f"{col_name}_agr_{agg_func}"))
                    elif agg_func == "first":
                        agg_exprs.append(F.first(F.col(col_name)).alias(f"{col_name}_agr_{agg_func}"))
                    else:
                        raise ValueError(f"Unsupported aggregation function: {agg_func}")
                    
                df1_cols = [F.first(F.col(f"df1.{c}")).alias(c) for c in self.dfs_list[df1_name].columns if c != group_by and c != "geometry"]
                df1_cols.append(F.first(F.col("df1_geom")).alias("geometry"))
                self.res_agr = self.res.groupBy(group_by).agg(*df1_cols, *agg_exprs)

                print(f"Aggregation completed. {self.res_agr.count()} rows.")
            
            if madre:
                columns_to_drop = ["df1_geom", "df2_geom"] + list(selected_aggs.keys())
                self.res = self.res.drop(*columns_to_drop)
                self.res = self.res.join(self.res_agr.drop("geometry"), on=group_by, how="left")
                self.res = self.res.withColumnRenamed("intr_geometry", "geometry")
    
        if cache:
            if madre:
                self.res.cache()
                self.res.count()
            self.res_agr.cache()
            self.res_agr.count()
            
    
        return self.res_agr
    
    

    def export(self, df="default", path="outputs", name="unnamed", how="repartition", num=None, clear=False):
        if num is None:
            num = self.cores
        if how == "repartition":
            self.res = self.res.repartition(num)
        elif how == "coalesce":
            self.res = self.res.coalesce(num)
        else:
            raise ValueError("Invalid 'how'. Choose either 'repartition' or 'coalesce'")

        if df == "default":
            self.res_agr.write.mode("overwrite").format("geoparquet").save(f"./{path}/" + f"/{name}_agr")
        else:
            self.res.write.mode("overwrite").format("geoparquet").save(f"./{path}/" + f"/{name}_madre")
        
        if clear:
            self.clear_memory()



class EnricherUI:
    def __init__(self, enricher):
        self.enricher = enricher
        self.loaded_dataframes = {}
        self.selected_cols = []
        self.group_by_col = None
        self._init_ui()
        self.loaded_dataframes = self.list_dataframes_in_memory()
        self.df1_dropdown.options = list(self.loaded_dataframes.keys())
        self.df2_dropdown.options = list(self.loaded_dataframes.keys())
        self.df1_dropdown.disabled = False
        self.df2_dropdown.disabled = False        
        self.selected_aggs = {}
        self.agg_options = ["sum", "count", "mean", "min", "max", "first"]

    def list_dataframes_in_memory(self):
        # return {name: obj for name, obj in globals().items() if isinstance(obj, pd.DataFrame)}
        return {name: df for name, df in self.enricher.dfs_list.items()}
    
    def _init_ui(self):
        # Heading:
        self.heading = widgets.HTML(value="<h1>Enrich with Overlay & Aggregation</h1>")

        # First line: Enrich <df1> with <df2>
        self.df1_dropdown = widgets.Dropdown(options=[], description="df1:", disabled=True, style={'description_width': 'initial'}, layout=widgets.Layout(margin="5px 20px", width="150px"))
        self.df2_dropdown = widgets.Dropdown(options=[], description="df2:", disabled=True, style={'description_width': 'initial'}, layout=widgets.Layout(margin="5px 20px", width="150px"))
        self.load_status = widgets.HTML(value="<small>Status: No dataframes loaded.</small>")
        self.load_button = widgets.Button(description="Load", disabled=True, layout=widgets.Layout(margin="5px 0px", width="100px", border="2px solid black"))
        self.load_button.style.font_weight = 'bold'

        # Second line: with attributes: <cols>
        self.cols_dropdown = widgets.SelectMultiple(options=[], description="aggr cols:", disabled=True, style={'description_width': 'initial'})
        self.agg_table_output = widgets.Output()        
        
        # Third line: unique id: <col>
        self.group_by_dropdown = widgets.Dropdown(options=[], description="unique id:", disabled=True, style={'description_width': 'initial'}, layout=widgets.Layout(margin="5px 20px", width="250px"))
        
        self.agg_status = widgets.HTML(value="")

        # Advanced options
        self.advanced_checkbox = widgets.Checkbox(value=False, description="Advanced options", style={'description_width': 'initial'})
        self.preserve_geoms_checkbox = widgets.Checkbox(value=False, description="Preserve overlapping geometries", disabled=False)
        self.intersection_ratio_checkbox = widgets.Checkbox(value=False, description="Enricher has uniform grids", disabled=False)
        self.grid_area_text = widgets.FloatText(value=1e6, description="Grid area:", disabled=True, layout=widgets.Layout(width="165px"))
        self.custom_predicate_checkbox = widgets.Checkbox(value=False, description="Custom ST_Relate predicate string:", disabled=False, layout=Layout(width="500px"))
        self.custom_predicate_text = widgets.Text(value="2********", disabled=True, layout=widgets.Layout(width="200px"))
        
        self.go_button = widgets.Button(description="Go", disabled=True, layout=widgets.Layout(margin="5px 0px", width="100px", border="2px solid black"))
        self.go_button.style.font_weight = 'bold'
        
        # Console output text box
        self.console_output = widgets.Output(layout=widgets.Layout(width="100%", height="200px", border="1px solid black"))
        self.clear_console_button = widgets.Button(description="Clear", layout=widgets.Layout(width='80px', border="2px solid black"))
        self.clear_console_button.style.font_weight = 'bold'

        # Layout
        self._setup_layout()
        self._setup_event_handlers()

    def _setup_layout(self):
        # First line: Enrich <df1> with <df2>
        df_selection_line = widgets.HBox([
            widgets.HTML(value="<h2 style='display: inline; margin-right: 10px;'>Enrich </h2>"),
            self.df1_dropdown,
            widgets.HTML(value="<h2 style='display: inline; margin-right: 10px;'> with </h2>"),
            self.df2_dropdown,
            self.load_button
        ])

        # Second line: with attributes: <cols>
        cols_selection_line = widgets.HBox([
            widgets.HTML(value="<h2 style='display: inline; margin-right: 10px;'> with attributes: </h2>"),
            self.cols_dropdown,
            self.agg_table_output,
        ])

        # Third line: unique id: <col>
        grp_by_selection_line = widgets.HBox([
            widgets.HTML(value=f"<h2 style='display: inline; margin-right: 10px;'><span id='unique_id_text'>unique identifier:</span> </h2>"),
            self.group_by_dropdown,
            self.go_button
        ])

        # Advanced options
        self.advanced_options = widgets.VBox(
            [
                self.preserve_geoms_checkbox,
                widgets.HBox([self.intersection_ratio_checkbox, self.grid_area_text]),
                widgets.HBox([self.custom_predicate_checkbox, self.custom_predicate_text]),
            ],
            layout=widgets.Layout(max_width="600px", display="none")
        )

        # Main layout
        self.main_layout = widgets.VBox([
            self.heading,
            widgets.HTML(value="<div style='height: 5px;'></div>"),
            df_selection_line,
            self.load_status,
            cols_selection_line,
            widgets.HTML(value="<div style='height: 4px;'></div>"),
            grp_by_selection_line,
            self.agg_status,
            widgets.HTML(value="<div style='height: 3px;'></div>"),
            self.advanced_checkbox,
            self.advanced_options,
            widgets.HBox([self.console_output, self.clear_console_button])
        ])

        # Display everything
        display(self.main_layout)

    def _setup_event_handlers(self):
        # Enable/disable load button based on dataframe selection
        def on_df_selection_change(change):
            if self.df1_dropdown.value and self.df2_dropdown.value:
                self.load_button.disabled = False
            else:
                self.load_button.disabled = True
        self.df1_dropdown.observe(on_df_selection_change, names='value')
        self.df2_dropdown.observe(on_df_selection_change, names='value')

        # Handle load button click
        def on_load_button_click(b):
            try:
                df1_name = self.df1_dropdown.value
                df2_name = self.df2_dropdown.value

                if df1_name not in self.loaded_dataframes or df2_name not in self.loaded_dataframes:
                    raise ValueError("Selected dataframes are not loaded in memory.")

                # Set the selected dataframes in the Enricher
                self.enricher.df1 = self.loaded_dataframes[df1_name]
                self.enricher.df2 = self.loaded_dataframes[df2_name]

                # Update column dropdowns
                self.cols_dropdown.options = self.loaded_dataframes[self.df2_dropdown.value].columns
                self.group_by_dropdown.options = self.loaded_dataframes[self.df1_dropdown.value].columns
                self.cols_dropdown.disabled = False
                self.group_by_dropdown.disabled = False

                self.load_status.value = f"<small>Status: Loaded {df1_name} and {df2_name}.</small>"
                self.main_layout.children[6].children[0].value = f"<h2 style='display: inline; margin-right: 10px;'>{df1_name}'s unique identifier: </h2>"
            except Exception as e:
                self.load_status.value = f"<small>Error: {str(e)}</small>"

        self.load_button.on_click(on_load_button_click)

        # Handle column selection
        def on_cols_change(change):
            for col in change["new"]:
                if col not in self.selected_cols:
                    self.selected_cols.append(col)
            
            self.cols_dropdown.options = [col for col in self.loaded_dataframes[self.df2_dropdown.value].columns if col not in self.selected_cols]
            
            # Preserve previously selected operations, default to "sum" for new columns
            for col in self.selected_cols:
                if col not in self.selected_aggs:
                    self.selected_aggs[col] = "sum"            
            
            def generate_agg_table():
                headers = ["Selected Column", "Operation", ""]
                
                cell_style = widgets.Layout(
                    # border="1px solid black", 
                    padding="0px 2px",
                    align_items="center", 
                    justify_content="center", 
                    width="125px"
                )
                
                clr_style = widgets.Layout(border="1px solid black", padding="0px 2px",align_items="center", justify_content="center", width="80px")
                
                header_row = [
                    widgets.HTML(f"<b>{headers[0]}</b>", layout=cell_style),
                    widgets.HTML(f"<b>{headers[1]}</b>", layout=cell_style),
                    widgets.HTML("", layout=widgets.Layout(padding="0px 2px",align_items="center", justify_content="center", width="80px"))
                ]
            
                rows = []
                for col in self.selected_aggs:
                    dropdown = widgets.Dropdown(
                        options=self.agg_options, 
                        value=self.selected_aggs[col], 
                        layout=cell_style
                    )
                    dropdown.observe(lambda change, col=col: self.selected_aggs.update({col: change["new"]}), names="value")
            
                    clear_button = widgets.Button(description="Clear", layout=clr_style)

                    def on_clear(btn, col=col):
                        self.selected_cols.remove(col)
                        self.agg_status.value = f"Status: Aggregating with: {', '.join([f'<b>{col}</b>' for col in self.selected_cols]) if self.selected_cols else '<i>select cols</i>'}, grouping by: {f'<b>{self.group_by_col}</b>' if self.group_by_col else '<i>select cols</i>'}"
                        del self.selected_aggs[col]
                        self.cols_dropdown.options = [col for col in self.loaded_dataframes[self.df2_dropdown.value].columns if col not in self.selected_cols]
                        
                        with self.agg_table_output:
                            self.agg_table_output.clear_output()
                            display(generate_agg_table())
                        
                    clear_button.on_click(on_clear)
                        
                    rows.extend([
                        widgets.HTML(col, layout=cell_style),
                        dropdown,
                        clear_button
                    ])
                
                scrollable_container = widgets.VBox([
                    widgets.GridBox(
                        children=header_row + rows,
                        layout=widgets.Layout(
                            grid_template_columns="150px 150px 90px",
                            grid_template_rows="auto",
                            padding="1px",
                            width="max-content",
                        )
                    )
                ], layout=widgets.Layout(
                    max_height="150px",
                    overflow_y="auto",
                    border="1px solid black"
                ))
            
                return scrollable_container


            with self.agg_table_output:
                self.agg_table_output.clear_output()
                display(generate_agg_table())

            self.agg_status.value = f"Status: Will aggregate with: {', '.join([f'<b>{col}</b>' for col in self.selected_cols]) if self.selected_cols else '<i>select cols</i>'}; grouping by: {f'<b>{self.group_by_col}</b>' if self.group_by_col else '<i>select cols</i>'}"
            if self.group_by_col and self.selected_cols:
                self.go_button.disabled = False            
        self.cols_dropdown.observe(on_cols_change, names='value')

        
        def on_group_by_change(change):
            self.group_by_col = change["new"]
            self.agg_status.value = f"Status: Will aggregate with: {', '.join([f'<b>{col}</b>' for col in self.selected_cols]) if self.selected_cols else '<i>select col</i>'}; grouping by: {f'<b>{self.group_by_col}</b>' if self.group_by_col else '<i>select col</i>'}"
            if self.group_by_col and self.selected_cols:
                self.go_button.disabled = False                
                

        self.group_by_dropdown.observe(on_group_by_change, names='value')


        def on_advanced_checkbox_change(change):
            if change["new"]:
                self.advanced_options.layout.display = "block"                

            else:
                self.advanced_options.layout.display = "none"

        self.advanced_checkbox.observe(on_advanced_checkbox_change, names='value')

        def on_intersection_ratio_change(change):
            self.grid_area_text.disabled = not change["new"]
        self.intersection_ratio_checkbox.observe(on_intersection_ratio_change, names='value')

        def on_custom_predicate_change(change):
            self.custom_predicate_text.disabled = not change["new"]
        self.custom_predicate_checkbox.observe(on_custom_predicate_change, names='value')

        def on_go_button_click(b):
            with self.console_output:
                try:
                    print("Performing operation. This may take a while. Check logs for Spark logs and completion status.")
                    self.enricher.join_chey_new(
                        selected_aggs=self.selected_aggs,
                        df1_name=self.df1_dropdown.value,
                        df2_name=self.df2_dropdown.value,
                        group_by=self.group_by_col,
                        pred="ST_Relate" if self.custom_predicate_checkbox.value else "ST_Intersects",
                        rel_str=self.custom_predicate_text.value if self.custom_predicate_checkbox.value else "2********",
                        make_geom=True,
                        ratio=self.intersection_ratio_checkbox.value,
                        madre=self.preserve_geoms_checkbox.value,
                        cache=True,
                        grid_area=float(self.grid_area_text.value) if self.intersection_ratio_checkbox.value else 1e6
                    )
                    print("Enrichment operation completed.")
                    
                except Exception as e:
                    print(f"Error: {str(e)}")
        
        self.go_button.on_click(on_go_button_click)

        def on_clear_console_button_click(b):
            self.console_output.clear_output()
        self.clear_console_button.on_click(on_clear_console_button_click)

    def add_dataframe(self, name, dataframe):
        self.loaded_dataframes[name] = dataframe
        self.df1_dropdown.options = list(self.loaded_dataframes.keys())
        self.df2_dropdown.options = list(self.loaded_dataframes.keys())
        self.df1_dropdown.disabled = False
        self.df2_dropdown.disabled = False
        

In [23]:
# Load data

# file paths:

path_contr = "./data_EU/countries_shp/"
path_reg = "./data_Italy/regioni/"
path_prov = "./data_Italy/provinci"
path_com_EU = "./data_EU/comuni_shp/"
path_com = "./data_Italy/comuni/"
path_grids = "./data_EU/census_grid_EU/grids_corrected.parquet"
path_grids_new = "./data_EU/census_grid_EU/grids_new.gpkg"


# datasets:
# format: {display_name: (path, file_format), ...}

datasets = {
    "countries": (path_contr, "shapefile"),
    "regions": (path_reg, "shapefile"),
    "provinces": (path_prov, "shapefile"),
    "comuni_EU": (path_com_EU, "shapefile"),
    "comuni": (path_com, "shapefile"),
    "pop_grids": (path_grids, "geoparquet"),
    "pop_grids_new": (path_grids_new, "geopackage")
    # "census": (path_census, ""),
}


obj = Enricher(crs="EPSG:3035")
obj.setup_cluster(which="sedona", ex_mem=26, dr_mem=24, log_level="ERROR")

obj.load(datasets, silent=True)
# obj.fix_geometries()
# obj.force_repartition(skip=['pop_grids'])
# obj.transform(lazy=False)

# obj.inspect_partitions()

# obj.dfs_list['comuni_EU'] = obj.dfs_list['comuni_EU'].filter(F.col('CNTR_ID').isin(["IT", "DE"]))

Sedona initialized with 10 cores for parellelism.

Loading datasets...
Make sure the geometry column is named "geometry" in the datasets
Loaded 'countries': (259, 12), 'EPSG:3035'
Loaded 'regions': (20, 6), 'EPSG:32632'
Loaded 'provinces': (107, 13), 'EPSG:32632'
Loaded 'comuni_EU': (122750, 12), 'EPSG:3035'
Loaded 'comuni': (7899, 13), 'EPSG:32632'
in geoparquet...


  return ogr_read(


Loaded 'pop_grids_new' from layer 'type': (7055226, 20), 'EPSG:3035'
6 datasets loaded. 



In [None]:
import geopandas as gpd
import pyarrow as pa
import pyarrow.parquet as pq

# Path to your GeoPackage file
gpkg_file = "./data_EU/census_grid_EU/grids_new.gpkg"

# Path to the output GeoParquet file
parquet_file = "./data_EU/census_grid_EU/grids_new.parquet"

# Read the GeoPackage file
gdf = gpd.read_file(f'{gpkg_file}', engine='pyogrio', use_arrow=True)

print(gdf.columns)
# # gdf = gdf.rename_geometry("geometry")

# # Convert the GeoDataFrame to a PyArrow Table
# table = pa.Table.from_pandas(gdf)

# # Write the table to a GeoParquet file
# pq.write_table(table, parquet_file)


  return next(self.gen)


In [7]:

for name, df in obj.dfs_list.items():
    df.coalesce(1).rdd.saveAsPickleFile(f"./pickles/dfs_list/{name}")
    



                                                                                

In [None]:
countries_df = obj.dfs_list['countries'].filter(F.col('CNTR_ID').isin(["IT", "NL", "BE", "DE"]))

In [4]:
# GUI

obj_ui = EnricherUI(obj)


VBox(children=(HTML(value='<h1>Enrich with Overlay & Aggregation</h1>'), HTML(value="<div style='height: 5px;'…

In [None]:

obj.res.columns
# unique_values = obj.res.select('CNTR_ID').distinct().rdd.flatMap(lambda x: x).collect()
# print(unique_values)

In [10]:
temp = obj.res.drop("df1_geom", "intr_geometry")
print(temp.columns)
temp.coalesce(1).write.mode("overwrite").format("geoparquet").save(f"./outputs")

                                                                                

In [4]:
from keplergl import KeplerGl
import geopandas as gpd
from shapely.geometry import shape
import decimal

# temp_df = obj.res_agr
# temp_df = obj.res.filter(F.col('CNTR_ID').isin("NL", "BE", "DE", "IT"))
# temp_df = obj.dfs_list['comuni_EU']
temp_df = obj.dfs_list['pop_grids'].filter(F.col('T')>1000)


def prep_for_map(res_agr, crs, geom_col="geometry"):
    df = res_agr.toPandas()
    df = df.map(lambda x: float(x) if isinstance(x, decimal.Decimal) else x)
    df['geometry'] = df[f'{geom_col}'].apply(lambda geom: shape(geom))

    gdf = gpd.GeoDataFrame(df, geometry='geometry')
    gdf.crs = crs
    
    return gdf

map = KeplerGl(height=600)
map.add_data(data=prep_for_map(temp_df, obj.crs, geom_col="df2_geom"), name="pop_grids")
map

User Guide: https://docs.kepler.gl/docs/keplergl-jupyter


                                                                                

KeyError: 'df2_geom'

In [None]:
from keplergl import KeplerGl
import geopandas as gpd
from shapely.geometry import shape
import decimal

res_agr = obj.res_agr

# res_agr = obj.dfs_list['provinces']

df = res_agr.toPandas()
df = df.map(lambda x: float(x) if isinstance(x, decimal.Decimal) else x)
df['geometry'] = df['geometry'].apply(lambda geom: shape(geom))

gdf = gpd.GeoDataFrame(df, geometry='geometry')
# gdf.crs = "EPSG:3035"
gdf.crs = obj.crs

map = KeplerGl(height=600)
map.add_data(data=gdf, name="df")
map

In [12]:
full_EU_com_enriched = obj.res_agr

In [11]:
from keplergl import KeplerGl
import geopandas as gpd
from shapely.geometry import shape

com_pop = obj.res_agr

res_agr = com_pop.toPandas()
res_agr['geometry'] = res_agr['geometry'].apply(lambda geom: shape(geom))

gdf = gpd.GeoDataFrame(res_agr, geometry='geometry')
gdf.crs = "EPSG:3035"

map = KeplerGl(height=600)
map.add_data(data=gdf, name="res_agr")
map

User Guide: https://docs.kepler.gl/docs/keplergl-jupyter


KeplerGl(data={'res_agr': {'index': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,…