In [1]:
from sedona.spark import *
from contextlib import contextmanager
from pyspark.storagelevel import StorageLevel
from pyspark.sql.functions import *
from pyspark.sql.functions import col, first, sum as _sum
import time
import pandas as pd

class Enricher:

    def __init__(self, crs=3035):
        self.crs = crs
        self.cores = None
        self.res = None
        self.sedona = None
        self.res_agr = None
        self.df1 = None
        self.df2 = None
            
    def setup(self, which="wherobots", ex_mem=26, dr_mem=24, log_level=None):
        if which == "wherobots":
            config = SedonaContext.builder().getOrCreate()
            self.sedona = SedonaContext.create(config)
            
            self.cores = self.sedona.sparkContext.defaultParallelism
            print(f"Wherobots setup started with {self.cores} cores for parellelism.")
        elif which == "sedona":
            config = SedonaContext.builder() .\
                config("spark.executor.memory", f"{ex_mem}g") .\
                config("spark.driver.memory", f"{dr_mem}g") .\
                config('spark.jars.packages',
                    'org.apache.sedona:sedona-spark-shaded-3.5_2.12:1.7.0,'
                    'org.datasyslab:geotools-wrapper:1.7.0-28.5'). \
                getOrCreate()

            self.sedona = SedonaContext.create(config)
            
            if log_level in ["OFF", "ERROR", "WARN", "INFO", "DEBUG"]:
                self.sedona.sparkContext.setLogLevel(log_level)
                
            self.cores = self.sedona.sparkContext.defaultParallelism
            print(f"Sedona initialized with {self.cores} cores for parellelism.")
        else:
            raise ValueError("Invalid 'which'. Choose either 'wherobots' or 'sedona'")
        
    
    @contextmanager
    def get_time(self, task_name):
        start = time.time()
        yield
        elapsed = time.time() - start
    
        print(f"{task_name}... DONE in {(elapsed/60):.2f} min" \
              if elapsed >= 60 else f"{task_name}... DONE in {elapsed:.2f} sec")

    def load_datsets(self, which="wherobots", path1=None, path2=None):
        print(f"Make sure the geometry column is named \"geometry\" in the datasets")
        if which == "wherobots":
            self.df1 = self.sedona.read.format("shapefile").load("s3://wbts-wbc-o2l64ln0km/kc2se8vgd6/data/customer-nzdf2d7gjuu3ou/countries/")
            self.df2 = self.sedona.read.format("geoparquet").load("s3://wbts-wbc-o2l64ln0km/kc2se8vgd6/data/customer-nzdf2d7gjuu3ou/bench_data/grids1.parquet")

        elif which == "sedona":
            # corine_path = "./data_Corine/land_cover_100m.parquet"
            # cens_path = "./data_Italy/merged/merged_pop_geom/merged.parquet"
            
            path1 = "./data_EU/comuni_shp/"
            path2 = "./data_EU/census_grid_EU/grids.parquet"

            self.df1 = self.sedona.read.format("shapefile").load(path1)
            self.df2 = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(path2)
            # cens_df = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(cens_path)
            # corine_df = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(corine_path)
            # cens_df = cens_df.withColumn("geometry", expr("ST_MakeValid(geometry)"))
            
        else:
            raise ValueError("Invalid 'which'. Choose either 'wherobots' or 'sedona'")
        
        print(f"Loaded. \n A.cols: {self.df1.columns} \n \n B.cols: {self.df2.columns}\n")
    

    def _make_cache(self, *dfs):
        cached_dfs = []
        for df in dfs:
            if df.storageLevel != StorageLevel.NONE:
                df.unpersist()
            df.cache()
            print(f"Dataset cached. {df.count()} rows.")
            cached_dfs.append(df)
        return cached_dfs
    
    def clear_memory(self, *keep):
        if self.res.storageLevel != StorageLevel.NONE:
            self.res.unpersist()
        self.res = None
    
    def join_chey(self, group_by=None, *cols, pred="ST_Intersects", rel_str="2********", make_geom=True, ratio=True, aggr=True, madre=True, cache=True, grid_area=1e6):
        join_expr = f"{pred}(df1.geometry, df2.geometry)"
        if pred == "ST_Relate":
            join_expr = f"{pred}(df1.geometry, df2.geometry, '{rel_str}')"

        self.res = self.df1.alias("df1").join(
            self.df2.alias("df2"), expr(join_expr)
        ).select(
            expr("df1.geometry").alias("df1_geom"),
            expr("df2.geometry").alias("df2_geom"),
            *[f"df1.{c}" for c in self.df1.columns if c != "geometry"],
            *[f"df2.{c}" for c in self.df2.columns if c != "geometry" and c not in self.df1.columns]
        )

        if make_geom:
            self.res = self.res.withColumn("intr_geom", expr("ST_Intersection(df1_geom, df2_geom)"))
            if ratio:
                if grid_area > 0:
                    self.res = self.res.withColumn("intr_ratio", expr(f"ST_Area(intr_geom) / {grid_area}"))
                else:
                    self.res = self.res.withColumn("intr_ratio", expr("ST_Area(intr_geom) / ST_Area(df2_geom)"))
        
        if aggr:
            if not ratio:
                self.res_agr = self.res.groupBy(group_by).\
                    agg(*([first(col(c)).alias(c) for c in self.res.columns if c not in cols and c != group_by]\
                        + [_sum(col(c)).alias(f"agr_{c}") for c in cols]))
            else:
                if "intr_ratio" not in self.res.columns:
                    raise ValueError("ratio column not found. run 'make_int_ratio()' first")
                self.res_agr = self.res.groupBy(group_by).\
                    agg(*([first(col(c)).alias(c) for c in self.res.columns if c not in cols and c != group_by]\
                    + [ceil(_sum(col(c) * col("intr_ratio"))).alias(f"agr_{c}") for c in cols])
                )
            if madre:
                self.res = self.res.join(self.res_agr, on=group_by, how="left")
            
            if cache:
                if madre:
                    self._make_cache(self.res)
                self._make_cache(self.res_agr)
            return self.res_agr
        
        else:
            return self.res

    def export(self, df="madre", path="outputs", name="unnamed", how="repartition", num=None, clear=False):
        if num is None:
            num = self.cores
        if how == "repartition":
            self.res = self.res.repartition(num)
        elif how == "coalesce":
            self.res = self.res.coalesce(num)
        else:
            raise ValueError("Invalid 'how'. Choose either 'repartition' or 'coalesce'")

        if df == "madre":
            self.res.write.mode("overwrite").format("geoparquet").save(f"./{path}/" + f"/{name}")
        else:
            self.res_agr.write.mode("overwrite").format("geoparquet").save(f"./{path}/" + f"/{name}_agr")
        
        if clear:
            self.clear_memory()


In [3]:
from sedona.spark import *
from contextlib import contextmanager
from pyspark.storagelevel import StorageLevel
from pyspark.sql.functions import *
from pyspark.sql.functions import col, first, sum as _sum
import time
import ipywidgets as widgets
from IPython.display import display, clear_output
import sys
import io
import pandas as pd

class Enricher:

    def __init__(self, crs=3035):
        self.crs = crs
        self.cores = None
        self.res = None
        self.sedona = None
        self.res_agr = None
        self.df1 = None
        self.df2 = None
            
    def setup(self, which="wherobots", ex_mem=26, dr_mem=24, log_level=None):
        if which == "wherobots":
            config = SedonaContext.builder().getOrCreate()
            self.sedona = SedonaContext.create(config)
            
            self.cores = self.sedona.sparkContext.defaultParallelism
            print(f"Wherobots setup started with {self.cores} cores for parellelism.")
        elif which == "sedona":
            config = SedonaContext.builder() .\
                config("spark.executor.memory", f"{ex_mem}g") .\
                config("spark.driver.memory", f"{dr_mem}g") .\
                config('spark.jars.packages',
                    'org.apache.sedona:sedona-spark-shaded-3.5_2.12:1.7.0,'
                    'org.datasyslab:geotools-wrapper:1.7.0-28.5'). \
                getOrCreate()

            self.sedona = SedonaContext.create(config)
            
            if log_level in ["OFF", "ERROR", "WARN", "INFO", "DEBUG"]:
                self.sedona.sparkContext.setLogLevel(log_level)
                
            self.cores = self.sedona.sparkContext.defaultParallelism
            print(f"Sedona initialized with {self.cores} cores for parellelism.")
        else:
            raise ValueError("Invalid 'which'. Choose either 'wherobots' or 'sedona'")
        
    
    @contextmanager
    def get_time(self, task_name):
        start = time.time()
        yield
        elapsed = time.time() - start
    
        print(f"{task_name}... DONE in {(elapsed/60):.2f} min" \
              if elapsed >= 60 else f"{task_name}... DONE in {elapsed:.2f} sec")

    def load_datsets(self, which="wherobots", path1=None, path2=None):
        print(f"Make sure the geometry column is named \"geometry\" in the datasets")
        if which == "wherobots":
            self.df1 = self.sedona.read.format("shapefile").load("s3://wbts-wbc-o2l64ln0km/kc2se8vgd6/data/customer-nzdf2d7gjuu3ou/countries/")
            self.df2 = self.sedona.read.format("geoparquet").load("s3://wbts-wbc-o2l64ln0km/kc2se8vgd6/data/customer-nzdf2d7gjuu3ou/bench_data/grids1.parquet")

        elif which == "sedona":
            # corine_path = "./data_Corine/land_cover_100m.parquet"
            # cens_path = "./data_Italy/merged/merged_pop_geom/merged.parquet"
            
            path1 = "./data_EU/comuni_shp/"
            path2 = "./data_EU/census_grid_EU/grids.parquet"

            self.df1 = self.sedona.read.format("shapefile").load(path1)
            self.df2 = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(path2)
            # cens_df = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(cens_path)
            # corine_df = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(corine_path)
            # cens_df = cens_df.withColumn("geometry", expr("ST_MakeValid(geometry)"))
            
        else:
            raise ValueError("Invalid 'which'. Choose either 'wherobots' or 'sedona'")
        
        print(f"Loaded. \n A.cols: {self.df1.columns} \n \n B.cols: {self.df2.columns}\n")
    

    def _make_cache(self, *dfs):
        cached_dfs = []
        for df in dfs:
            if df.storageLevel != StorageLevel.NONE:
                df.unpersist()
            df.cache()
            print(f"Dataset cached. {df.count()} rows.")
            cached_dfs.append(df)
        return cached_dfs
    
    def clear_memory(self, *keep):
        if self.res.storageLevel != StorageLevel.NONE:
            self.res.unpersist()
        self.res = None
    
    def join_chey(self, *cols, group_by=None, pred="ST_Intersects", rel_str="2********", make_geom=True, ratio=True, aggr=True, madre=True, cache=True, grid_area=1e6):
        join_expr = f"{pred}(df1.geometry, df2.geometry)"
        if pred == "ST_Relate":
            join_expr = f"{pred}(df1.geometry, df2.geometry, '{rel_str}')"

        self.res = self.df1.alias("df1").join(
            self.df2.alias("df2"), expr(join_expr)
        ).select(
            expr("df1.geometry").alias("df1_geom"),
            expr("df2.geometry").alias("df2_geom"),
            *[f"df1.{c}" for c in self.df1.columns if c != "geometry"],
            *[f"df2.{c}" for c in self.df2.columns if c != "geometry" and c not in self.df1.columns]
        )

        if make_geom:
            self.res = self.res.withColumn("intr_geom", expr("ST_Intersection(df1_geom, df2_geom)"))
            if ratio:
                if grid_area > 0:
                    self.res = self.res.withColumn("intr_ratio", expr(f"ST_Area(intr_geom) / {grid_area}"))
                else:
                    self.res = self.res.withColumn("intr_ratio", expr("ST_Area(intr_geom) / ST_Area(df2_geom)"))
        
        if aggr:
            if not ratio:
                self.res_agr = self.res.groupBy(group_by).\
                    agg(*([first(col(c)).alias(c) for c in self.res.columns if c not in cols and c != group_by]\
                        + [_sum(col(c)).alias(f"agr_{c}") for c in cols]))
            else:
                if "intr_ratio" not in self.res.columns:
                    raise ValueError("ratio column not found. run 'make_int_ratio()' first")
                self.res_agr = self.res.groupBy(group_by).\
                    agg(*([first(col(c)).alias(c) for c in self.res.columns if c not in cols and c != group_by]\
                    + [ceil(_sum(col(c) * col("intr_ratio"))).alias(f"agr_{c}") for c in cols])
                )
            if madre:
                self.res = self.res.join(self.res_agr, on=group_by, how="left")
            
        if cache:
            if aggr:
                self._make_cache(self.res_agr)
            self._make_cache(self.res)

        return self.res

    def export(self, df="madre", path="outputs", name="unnamed", how="repartition", num=None, clear=False):
        if num is None:
            num = self.cores
        if how == "repartition":
            self.res = self.res.repartition(num)
        elif how == "coalesce":
            self.res = self.res.coalesce(num)
        else:
            raise ValueError("Invalid 'how'. Choose either 'repartition' or 'coalesce'")

        if df == "madre":
            self.res.write.mode("overwrite").format("geoparquet").save(f"./{path}/" + f"/{name}")
        else:
            self.res_agr.write.mode("overwrite").format("geoparquet").save(f"./{path}/" + f"/{name}_agr")
        
        if clear:
            self.clear_memory()







import ipywidgets as widgets
from IPython.display import display
from io import StringIO
import sys

class EnricherUI:
    def __init__(self, enricher):
        self.enricher = enricher
        self.selected_cols = []  # Track selected columns for aggregation
        self.group_by_col = None  # Track the selected group_by column
        self._init_ui()

    def _init_ui(self):
        # Heading: Join
        self.join_heading = widgets.HTML(value="<h2>Join</h2>")

        # Checkboxes for Join
        self.make_geom_checkbox = widgets.Checkbox(value=False, description="Make Geometry")
        self.intr_ratio_checkbox = widgets.Checkbox(value=False, description="Intersection Ratio")
        self.madre_checkbox = widgets.Checkbox(value=False, description="Madre")

        # Text boxes for Predicate and Rel Str
        self.predicate_text = widgets.Text(value="ST_Intersects", description="Predicate:")
        self.rel_str_text = widgets.Text(value="2********", description="Rel Str:")

        # Grid area text box (disabled by default)
        self.grid_area_text = widgets.FloatText(value=1e6, description="Grid Area:", disabled=True)

        # Checkboxes for Cache and Aggregate
        self.cache_checkbox = widgets.Checkbox(value=True, description="Cache")
        self.aggregate_checkbox = widgets.Checkbox(value=False, description="Aggregate")

        # Group by dropdown (disabled by default)
        self.group_by_dropdown = widgets.Dropdown(options=self.enricher.df1.columns, description="Group By:", disabled=True)

        # Aggregate columns dropdown (disabled by default)
        self.agg_cols_dropdown = widgets.SelectMultiple(options=self.enricher.df2.columns, description="Aggregate Cols:", disabled=True)

        # Selected aggregate columns display
        self.selected_agg_cols_label = widgets.Label(value="Selected cols to aggregate:")
        self.selected_agg_cols = widgets.Select(options=[], description="Selected:")

        # Selected group_by column display
        self.selected_group_by_label = widgets.Label(value="Selected group_by column:")
        self.selected_group_by = widgets.Label(value="None")

        # Join button
        self.join_button = widgets.Button(description="Join")

        # Heading: Export
        self.export_heading = widgets.HTML(value="<h2>Export</h2>")

        # Export fields
        self.df_text = widgets.Text(value="madre", description="DF:")
        self.path_text = widgets.Text(value="outputs", description="Path:")
        self.name_text = widgets.Text(value="unnamed", description="Name:")
        self.how_text = widgets.Text(value="repartition", description="How:")
        self.num_text = widgets.IntText(value=self.enricher.cores, description="Num:")
        self.clear_checkbox = widgets.Checkbox(value=False, description="Clear")
        self.export_button = widgets.Button(description="Export")

        # Console output text box
        self.console_output = widgets.Textarea(value="", description="Console:", layout=widgets.Layout(width="100%", height="200px"))
        self.clear_console_button = widgets.Button(description="Clear Console")

        # Layout
        self._setup_layout()
        self._setup_event_handlers()

    def _setup_layout(self):
        # Join section
        self.join_section = widgets.VBox([
            self.join_heading,
            widgets.HBox([self.make_geom_checkbox, self.intr_ratio_checkbox, self.madre_checkbox]),
            widgets.HBox([self.predicate_text, self.rel_str_text]),
            self.grid_area_text,
            widgets.HBox([self.cache_checkbox, self.aggregate_checkbox]),
            self.group_by_dropdown,
            self.selected_group_by_label,
            self.selected_group_by,
            self.agg_cols_dropdown,
            self.selected_agg_cols_label,
            self.selected_agg_cols,
            self.join_button
        ])

        # Export section
        self.export_section = widgets.VBox([
            self.export_heading,
            self.df_text,
            self.path_text,
            self.name_text,
            self.how_text,
            self.num_text,
            self.clear_checkbox,
            self.export_button
        ])

        # Console section
        self.console_section = widgets.VBox([
            self.console_output,
            self.clear_console_button
        ])

        # Display everything
        display(widgets.VBox([self.join_section, self.export_section, self.console_section]))

    def _setup_event_handlers(self):
        # Enable/disable grid area based on intr_ratio checkbox
        def on_intr_ratio_change(change):
            self.grid_area_text.disabled = not change['new']
        self.intr_ratio_checkbox.observe(on_intr_ratio_change, names='value')

        # Enable/disable group by and aggregate cols based on aggregate checkbox
        def on_aggregate_change(change):
            self.group_by_dropdown.disabled = not change['new']
            self.agg_cols_dropdown.disabled = not change['new']
        self.aggregate_checkbox.observe(on_aggregate_change, names='value')

        # Handle selection of group_by column
        def on_group_by_change(change):
            self.group_by_col = change['new']
            self.selected_group_by.value = self.group_by_col
        self.group_by_dropdown.observe(on_group_by_change, names='value')

        # Handle selection of aggregate columns
        def on_agg_cols_change(change):
            self.update_selected_columns(change)
        self.agg_cols_dropdown.observe(on_agg_cols_change, names='value')

        # Join button click handler
        def on_join_button_click(b):
            # Capture console output
            old_stdout = sys.stdout
            sys.stdout = captured_output = StringIO()

            try:
                group_by = self.group_by_col if self.aggregate_checkbox.value else None
                agg_cols = self.selected_cols if self.aggregate_checkbox.value else []

                self.enricher.join_chey(
                    group_by=group_by,
                    pred=self.predicate_text.value,
                    rel_str=self.rel_str_text.value,
                    make_geom=self.make_geom_checkbox.value,
                    ratio=self.intr_ratio_checkbox.value,
                    aggr=self.aggregate_checkbox.value,
                    madre=self.madre_checkbox.value,
                    cache=self.cache_checkbox.value,
                    grid_area=self.grid_area_text.value,
                    *agg_cols
                )
                # print(f"pred: {self.predicate_text.value}, \n rel_str: {self.rel_str_text.value} \n make_geom: {self.make_geom_checkbox.value} \n ratio: {self.intr_ratio_checkbox.value} \n aggr: {self.aggregate_checkbox.value} \n madre: {self.madre_checkbox.value} \n cache: {self.cache_checkbox.value} \n grid_area: {self.grid_area_text.value}")
                print("Join operation completed.")
            except Exception as e:
                print(f"Error: {str(e)}")
            finally:
                sys.stdout = old_stdout
                self.console_output.value += captured_output.getvalue()

        self.join_button.on_click(on_join_button_click)

        # Export button click handler
        def on_export_button_click(b):
            # Capture console output
            old_stdout = sys.stdout
            sys.stdout = captured_output = StringIO()

            try:
                self.enricher.export(
                    df=self.df_text.value,
                    path=self.path_text.value,
                    name=self.name_text.value,
                    how=self.how_text.value,
                    num=self.num_text.value,
                    clear=self.clear_checkbox.value
                )
                print("Export operation completed.")
            except Exception as e:
                print(f"Error: {str(e)}")
            finally:
                sys.stdout = old_stdout
                self.console_output.value += captured_output.getvalue()

        self.export_button.on_click(on_export_button_click)

        # Clear console button click handler
        def on_clear_console_button_click(b):
            self.console_output.value = ""
        self.clear_console_button.on_click(on_clear_console_button_click)

    def update_selected_columns(self, change):
        """Update the list of selected columns for group_by and aggregate_cols."""
        # Add newly selected columns to the selected list
        for col in change["new"]:
            if col not in self.selected_cols:
                self.selected_cols.append(col)
    
        # Update the dropdown options to disable already selected columns
        available_options = [col for col in self.enricher.df2.columns if col not in self.selected_cols]
        self.agg_cols_dropdown.options = available_options
    
        # Update the display of selected columns
        self.selected_agg_cols.options = self.selected_cols

In [None]:
from sedona.spark import *
from contextlib import contextmanager
from pyspark.storagelevel import StorageLevel
from pyspark.sql.functions import *
from pyspark.sql.functions import col, first, sum as _sum
import time
import ipywidgets as widgets
from IPython.display import display, clear_output
import sys
import io
import pandas as pd

class Enricher:

    def __init__(self, crs=3035):
        self.crs = crs
        self.cores = None
        self.res = None
        self.sedona = None
        self.res_agr = None
        self.df1 = None
        self.df2 = None
            
    def setup(self, which="wherobots", ex_mem=26, dr_mem=24, log_level=None):
        if which == "wherobots":
            config = SedonaContext.builder().getOrCreate()
            self.sedona = SedonaContext.create(config)
            
            self.cores = self.sedona.sparkContext.defaultParallelism
            print(f"Wherobots setup started with {self.cores} cores for parellelism.")
        elif which == "sedona":
            config = SedonaContext.builder() .\
                config("spark.executor.memory", f"{ex_mem}g") .\
                config("spark.driver.memory", f"{dr_mem}g") .\
                config('spark.jars.packages',
                    'org.apache.sedona:sedona-spark-shaded-3.5_2.12:1.7.0,'
                    'org.datasyslab:geotools-wrapper:1.7.0-28.5'). \
                getOrCreate()

            self.sedona = SedonaContext.create(config)
            
            if log_level in ["OFF", "ERROR", "WARN", "INFO", "DEBUG"]:
                self.sedona.sparkContext.setLogLevel(log_level)
                
            self.cores = self.sedona.sparkContext.defaultParallelism
            print(f"Sedona initialized with {self.cores} cores for parellelism.")
        else:
            raise ValueError("Invalid 'which'. Choose either 'wherobots' or 'sedona'")
        
    
    @contextmanager
    def get_time(self, task_name):
        start = time.time()
        yield
        elapsed = time.time() - start
    
        print(f"{task_name}... DONE in {(elapsed/60):.2f} min" \
              if elapsed >= 60 else f"{task_name}... DONE in {elapsed:.2f} sec")

    def load_datsets(self, which="wherobots", path1=None, path2=None):
        print(f"Make sure the geometry column is named \"geometry\" in the datasets")
        if which == "wherobots":
            self.df1 = self.sedona.read.format("shapefile").load("s3://wbts-wbc-o2l64ln0km/kc2se8vgd6/data/customer-nzdf2d7gjuu3ou/countries/")
            self.df2 = self.sedona.read.format("geoparquet").load("s3://wbts-wbc-o2l64ln0km/kc2se8vgd6/data/customer-nzdf2d7gjuu3ou/bench_data/grids1.parquet")

        elif which == "sedona":
            # corine_path = "./data_Corine/land_cover_100m.parquet"
            # cens_path = "./data_Italy/merged/merged_pop_geom/merged.parquet"
            
            path1 = "./data_EU/comuni_shp/"
            path2 = "./data_EU/census_grid_EU/grids.parquet"

            self.df1 = self.sedona.read.format("shapefile").load(path1)
            self.df2 = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(path2)
            # cens_df = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(cens_path)
            # corine_df = self.sedona.read.format("geoparquet").option("legacyMode", "true").load(corine_path)
            # cens_df = cens_df.withColumn("geometry", expr("ST_MakeValid(geometry)"))
            
        else:
            raise ValueError("Invalid 'which'. Choose either 'wherobots' or 'sedona'")
        
        print(f"Loaded. \n A.cols: {self.df1.columns} \n \n B.cols: {self.df2.columns}\n")
    

    def _make_cache(self, *dfs):
        cached_dfs = []
        for df in dfs:
            if df.storageLevel != StorageLevel.NONE:
                df.unpersist()
            df.cache()
            print(f"Dataset cached. {df.count()} rows.")
            cached_dfs.append(df)
        return cached_dfs
    
    def clear_memory(self, *keep):
        if self.res.storageLevel != StorageLevel.NONE:
            self.res.unpersist()
        self.res = None
    
    def join_chey(self, *cols, group_by=None, pred="ST_Intersects", rel_str="2********", make_geom=True, ratio=True, aggr=True, madre=True, cache=True, grid_area=1e6):
        join_expr = f"{pred}(df1.geometry, df2.geometry)"
        if pred == "ST_Relate":
            join_expr = f"{pred}(df1.geometry, df2.geometry, '{rel_str}')"

        self.res = self.df1.alias("df1").join(
            self.df2.alias("df2"), expr(join_expr)
        ).select(
            expr("df1.geometry").alias("df1_geom"),
            expr("df2.geometry").alias("df2_geom"),
            *[f"df1.{c}" for c in self.df1.columns if c != "geometry"],
            *[f"df2.{c}" for c in self.df2.columns if c != "geometry" and c not in self.df1.columns]
        )

        if make_geom:
            self.res = self.res.withColumn("intr_geom", expr("ST_Intersection(df1_geom, df2_geom)"))
            if ratio:
                if grid_area > 0:
                    self.res = self.res.withColumn("intr_ratio", expr(f"ST_Area(intr_geom) / {grid_area}"))
                else:
                    self.res = self.res.withColumn("intr_ratio", expr("ST_Area(intr_geom) / ST_Area(df2_geom)"))
        
        if aggr:
            if not ratio:
                self.res_agr = self.res.groupBy(group_by).\
                    agg(*([first(col(c)).alias(c) for c in self.res.columns if c not in cols and c != group_by]\
                        + [_sum(col(c)).alias(f"agr_{c}") for c in cols]))
            else:
                if "intr_ratio" not in self.res.columns:
                    raise ValueError("ratio column not found. run 'make_int_ratio()' first")
                self.res_agr = self.res.groupBy(group_by).\
                    agg(*([first(col(c)).alias(c) for c in self.res.columns if c not in cols and c != group_by]\
                    + [ceil(_sum(col(c) * col("intr_ratio"))).alias(f"agr_{c}") for c in cols])
                )
            if madre:
                self.res = self.res.join(self.res_agr, on=group_by, how="left")
            
        if cache:
            if aggr:
                self._make_cache(self.res_agr)
            self._make_cache(self.res)

        return self.res

    def export(self, df="madre", path="outputs", name="unnamed", how="repartition", num=None, clear=False):
        if num is None:
            num = self.cores
        if how == "repartition":
            self.res = self.res.repartition(num)
        elif how == "coalesce":
            self.res = self.res.coalesce(num)
        else:
            raise ValueError("Invalid 'how'. Choose either 'repartition' or 'coalesce'")

        if df == "madre":
            self.res.write.mode("overwrite").format("geoparquet").save(f"./{path}/" + f"/{name}")
        else:
            self.res_agr.write.mode("overwrite").format("geoparquet").save(f"./{path}/" + f"/{name}_agr")
        
        if clear:
            self.clear_memory()







import ipywidgets as widgets
from IPython.display import display
from io import StringIO
import sys
import logging
from pyspark.sql import SparkSession

class EnricherUI:
    def __init__(self, enricher):
        self.enricher = enricher
        self.selected_cols = []
        self.group_by_col = None
        self._init_ui()

    def capture_output(self, func, *args, **kwargs):
        """Capture standard output and Spark logs, redirecting them to the console box."""
        # Create a StringIO buffer to capture output
        buffer = StringIO()

        # Redirect sys.stdout to the buffer
        old_stdout = sys.stdout
        sys.stdout = buffer

        # Redirect Spark logs to the buffer
        spark_logger = logging.getLogger("py4j")
        spark_logger.setLevel(logging.INFO)
        handler = logging.StreamHandler(buffer)
        spark_logger.addHandler(handler)

        try:
            func(*args, **kwargs)
        except Exception as e:
            buffer.write(f"Error: {str(e)}\n")
        finally:
            sys.stdout = old_stdout
            spark_logger.removeHandler(handler)

            # Update the console box with the captured output
            self.console_output.value += buffer.getvalue()


    def _init_ui(self):
        # Heading: Join
        self.join_heading = widgets.HTML(value="<h2>Join</h2>")

        # Checkboxes for Join
        self.make_geom_checkbox = widgets.Checkbox(value=False, description="Make Geometry")
        self.intr_ratio_checkbox = widgets.Checkbox(value=False, description="Intersection Ratio")
        self.madre_checkbox = widgets.Checkbox(value=False, description="Madre")

        # Text boxes for Predicate and Rel Str
        self.predicate_text = widgets.Text(value="ST_Intersects", description="Predicate:")
        self.rel_str_text = widgets.Text(value="2********", description="Rel Str:")

        # Grid area text box (disabled by default)
        self.grid_area_text = widgets.FloatText(value=1e6, description="Grid Area:", disabled=True)

        # Checkboxes for Cache and Aggregate
        self.cache_checkbox = widgets.Checkbox(value=True, description="Cache")
        self.aggregate_checkbox = widgets.Checkbox(value=False, description="Aggregate")

        # Group by dropdown (disabled by default)
        self.group_by_dropdown = widgets.Dropdown(options=self.enricher.df1.columns, description="Group By:", disabled=True)

        # Aggregate columns dropdown (disabled by default)
        self.agg_cols_dropdown = widgets.SelectMultiple(options=self.enricher.df2.columns, description="Aggregate Cols:", disabled=True)

        # Selected aggregate columns display
        self.selected_agg_cols_label = widgets.Label(value="Selected cols to aggregate:")
        self.selected_agg_cols = widgets.Select(options=[], description="Selected:")

        # Selected group_by column display
        self.selected_group_by_label = widgets.Label(value="Selected group_by column:")
        self.selected_group_by = widgets.Label(value="None")

        # Join button
        self.join_button = widgets.Button(description="Join")

        # Heading: Export
        self.export_heading = widgets.HTML(value="<h2>Export</h2>")

        # Export fields
        self.df_text = widgets.Text(value="madre", description="DF:")
        self.path_text = widgets.Text(value="outputs", description="Path:")
        self.name_text = widgets.Text(value="unnamed", description="Name:")
        self.how_text = widgets.Text(value="repartition", description="How:")
        self.num_text = widgets.IntText(value=self.enricher.cores, description="Num:")
        self.clear_checkbox = widgets.Checkbox(value=False, description="Clear")
        self.export_button = widgets.Button(description="Export")

        # Console output text box
        self.console_output = widgets.Textarea(value="", description="Console:", layout=widgets.Layout(width="100%", height="200px"))
        self.clear_console_button = widgets.Button(description="Clear Console")

        # Layout
        self._setup_layout()
        self._setup_event_handlers()

    def _setup_layout(self):
        # Join section
        self.join_section = widgets.VBox([
            self.join_heading,
            widgets.HBox([self.make_geom_checkbox, self.intr_ratio_checkbox, self.madre_checkbox]),
            widgets.HBox([self.predicate_text, self.rel_str_text]),
            self.grid_area_text,
            widgets.HBox([self.cache_checkbox, self.aggregate_checkbox]),
            self.group_by_dropdown,
            self.selected_group_by_label,
            self.selected_group_by,
            self.agg_cols_dropdown,
            self.selected_agg_cols_label,
            self.selected_agg_cols,
            self.join_button
        ])

        # Export section
        self.export_section = widgets.VBox([
            self.export_heading,
            self.df_text,
            self.path_text,
            self.name_text,
            self.how_text,
            self.num_text,
            self.clear_checkbox,
            self.export_button
        ])

        # Console section
        self.console_section = widgets.VBox([
            self.console_output,
            self.clear_console_button
        ])

        # Display everything
        display(widgets.VBox([self.join_section, self.export_section, self.console_section]))

    def _setup_event_handlers(self):
        # Enable/disable grid area based on intr_ratio checkbox
        def on_intr_ratio_change(change):
            self.grid_area_text.disabled = not change['new']
        self.intr_ratio_checkbox.observe(on_intr_ratio_change, names='value')

        # Enable/disable group by and aggregate cols based on aggregate checkbox
        def on_aggregate_change(change):
            self.group_by_dropdown.disabled = not change['new']
            self.agg_cols_dropdown.disabled = not change['new']
        self.aggregate_checkbox.observe(on_aggregate_change, names='value')

        # Handle selection of group_by column
        def on_group_by_change(change):
            self.group_by_col = change['new']
            self.selected_group_by.value = self.group_by_col
        self.group_by_dropdown.observe(on_group_by_change, names='value')

        # Handle selection of aggregate columns
        def on_agg_cols_change(change):
            self.update_selected_columns(change)
        self.agg_cols_dropdown.observe(on_agg_cols_change, names='value')


        # Join button click handler
        def on_join_button_click(b):           
            self.capture_output(self._perform_join)            
        self.join_button.on_click(on_join_button_click)
        
    def _perform_join(self):
        """Perform the join operation."""
        group_by = self.group_by_col if self.aggregate_checkbox.value else None
        agg_cols = self.selected_cols if self.aggregate_checkbox.value else []

        self.enricher.join_chey(
                group_by=group_by,
                pred=self.predicate_text.value,
                rel_str=self.rel_str_text.value,
                make_geom=self.make_geom_checkbox.value,
                ratio=self.intr_ratio_checkbox.value,
                aggr=self.aggregate_checkbox.value,
                madre=self.madre_checkbox.value,
                cache=self.cache_checkbox.value,
                grid_area=self.grid_area_text.value,
                *agg_cols
        )
        print("Join operation completed.")

        # Export button click handler
        def on_export_button_click(b):
            # Capture console output
            old_stdout = sys.stdout
            sys.stdout = captured_output = StringIO()

            try:
                self.enricher.export(
                    df=self.df_text.value,
                    path=self.path_text.value,
                    name=self.name_text.value,
                    how=self.how_text.value,
                    num=self.num_text.value,
                    clear=self.clear_checkbox.value
                )
                print("Export operation completed.")
            except Exception as e:
                print(f"Error: {str(e)}")
            finally:
                sys.stdout = old_stdout
                self.console_output.value += captured_output.getvalue()

        self.export_button.on_click(on_export_button_click)

        # Clear console button click handler
        def on_clear_console_button_click(b):
            self.console_output.value = ""
        self.clear_console_button.on_click(on_clear_console_button_click)

    def update_selected_columns(self, change):
        """Update the list of selected columns for group_by and aggregate_cols."""
        # Add newly selected columns to the selected list
        for col in change["new"]:
            if col not in self.selected_cols:
                self.selected_cols.append(col)
    
        # Update the dropdown options to disable already selected columns
        available_options = [col for col in self.enricher.df2.columns if col not in self.selected_cols]
        self.agg_cols_dropdown.options = available_options
    
        # Update the display of selected columns
        self.selected_agg_cols.options = self.selected_cols

In [None]:
# Setup

obj = Enricher()
obj.setup(which="sedona", ex_mem=26, dr_mem=24)
path1="./data_EU/countries_shp/"
path2="./data_EU/census_grid_EU/grids.parquet"

obj.load_datsets(which="sedona", path1=path1, path2=path2)


In [4]:
# GUI

obj_ui = EnricherUI(obj)


VBox(children=(VBox(children=(HTML(value='<h2>Join</h2>'), HBox(children=(Checkbox(value=False, description='M…

25/01/27 11:31:44 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/01/27 11:31:56 WARN JoinQuery: UseIndex is true, but no index exists. Will build index on the fly.
                                                                                

In [5]:
obj_ui.enricher.res.show(5)

25/01/27 11:36:26 WARN JoinQuery: UseIndex is true, but no index exists. Will build index on the fly.
[Stage 27:>                                                         (0 + 1) / 1]

+--------------------+--------------------+--------------+-------+---------+--------------------+--------------------+---------+-------------+--------------------+--------------------+---------+--------------+--------------------+---+-----+-----+------+------+------+----+-----+------+----+-----+------+-------+------------+---------+------------------+
|            df1_geom|            df2_geom|       COMM_ID|CNTR_ID|CNTR_CODE|           COMM_NAME|           NAME_ASCI|TRUE_FLAG|     NSI_CODE|            NAME_NSI|           NAME_LATN|NUTS_CODE|           FID|              GRD_ID|  T|    M|    F|Y_LT15|Y_1564|Y_GE65| EMP|  NAT|EU_OTH| OTH| SAME|CHG_IN|CHG_OUT|LAND_SURFACE|POPULATED|CONFIDENTIALSTATUS|
+--------------------+--------------------+--------------+-------+---------+--------------------+--------------------+---------+-------------+--------------------+--------------------+---------+--------------+--------------------+---+-----+-----+------+------+------+----+-----+------+----+--

                                                                                

In [2]:
# No GUI


obj = Enricher()

obj.setup(which="sedona", ex_mem=26, dr_mem=24)

path1="./data_EU/countries_shp/"
path2="./data_EU/census_grid_EU/grids.parquet"

obj.load_datsets(which="sedona", path1=path1, path2=path2)

agr_cols = ['T', 'M', 'F', 'Y_LT15', 'Y_1564', 'Y_GE65', 'EMP', 'NAT', 'EU_OTH', 'OTH', 'SAME', 'CHG_IN', 'CHG_OUT', 'LAND_SURFACE', 'POPULATED']

with obj.get_time("join + make geom + intersection areas + aggregate"):
    obj.join_chey('CNTR_ID', *agr_cols)

25/01/24 14:42:51 WARN Utils: Your hostname, marvin resolves to a loopback address: 127.0.1.1; using 172.20.27.4 instead (on interface eth0)
25/01/24 14:42:51 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


:: loading settings :: url = jar:file:/data/homes_data/sudheer/benchmark_data/sedona_venv/lib/python3.12/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /data/homes_data/sudheer/.ivy2/cache
The jars for the packages stored in: /data/homes_data/sudheer/.ivy2/jars
org.apache.sedona#sedona-spark-shaded-3.5_2.12 added as a dependency
org.datasyslab#geotools-wrapper added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-8cfb8944-1123-44e5-8687-12d080f76c5d;1.0
	confs: [default]
	found org.apache.sedona#sedona-spark-shaded-3.5_2.12;1.7.0 in central
	found org.datasyslab#geotools-wrapper;1.7.0-28.5 in central
:: resolution report :: resolve 261ms :: artifacts dl 10ms
	:: modules in use:
	org.apache.sedona#sedona-spark-shaded-3.5_2.12;1.7.0 from central in [default]
	org.datasyslab#geotools-wrapper;1.7.0-28.5 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	-----------------------------------------

Sedona initialized with 10 cores for parellelism.
Make sure the geometry column is named "geometry" in the datasets


                                                                                

Loaded. 
 A.cols: ['geometry', 'COMM_ID', 'CNTR_ID', 'CNTR_CODE', 'COMM_NAME', 'NAME_ASCI', 'TRUE_FLAG', 'NSI_CODE', 'NAME_NSI', 'NAME_LATN', 'NUTS_CODE', 'FID'] 
 
 B.cols: ['GRD_ID', 'T', 'M', 'F', 'Y_LT15', 'Y_1564', 'Y_GE65', 'EMP', 'NAT', 'EU_OTH', 'OTH', 'SAME', 'CHG_IN', 'CHG_OUT', 'LAND_SURFACE', 'POPULATED', 'CONFIDENTIALSTATUS', 'geometry']



25/01/24 14:43:13 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/01/24 14:43:26 WARN JoinQuery: UseIndex is true, but no index exists. Will build index on the fly.
25/01/24 14:43:40 WARN JoinQuery: UseIndex is true, but no index exists. Will build index on the fly.
25/01/24 15:04:33 WARN MemoryStore: Not enough space to cache rdd_65_78 in memory! (computed 844.5 MiB so far)
25/01/24 15:04:33 WARN BlockManager: Persisting block rdd_65_78 to disk instead.
25/01/24 15:04:34 WARN MemoryStore: Not enough space to cache rdd_65_75 in memory! (computed 838.1 MiB so far)
25/01/24 15:04:34 WARN BlockManager: Persisting block rdd_65_75 to disk instead.
25/01/24 15:04:36 WARN MemoryStore: Not enough space to cache rdd_65_173 in memory! (computed 128.6 MiB so far)
25/01/24 15:04:36 WARN BlockManager: Persisting block rdd_65_173 to disk instead.
25/01/24 15:04:37 WARN Memory

Dataset cached. 6315546 rows.


25/01/24 15:09:23 WARN JoinQuery: UseIndex is true, but no index exists. Will build index on the fly.

Dataset cached. 41 rows.
join + make geom + intersection areas + aggregate... DONE in 44.35 min


                                                                                