In [None]:
"""~Wherobots setup"""

from sedona.spark import *
from sedona.spark import SedonaContext, Adapter
from contextlib import contextmanager
import time
from pyspark.sql.functions import col, expr
from pyspark.sql.functions import desc
from sedona.core.formatMapper.shapefileParser import ShapefileReader

# Build (or reuse) a Spark session named 'sedona-test', then hand it to
# SedonaContext.create() to get the Sedona-enabled session the later cells
# use for ST_* SQL expressions.
config = (
    SedonaContext.builder()
    .appName('sedona-test')
    .getOrCreate()
)
sedona = SedonaContext.create(config)
sc = sedona.sparkContext

# Report where the session runs and how many tasks Spark schedules by default.
print(f'master: {sedona.conf.get("spark.master")}')
print(f"sc.defaultParallelism: {sc.defaultParallelism}")

# Partition target of twice the default parallelism; presumably intended for
# later repartition() calls — TODO confirm where it is consumed.
num_partitions = 2 * sc.defaultParallelism
print(f"using partitions: {num_partitions}")

In [None]:
from pyspark.sql.functions import regexp_replace
from sedona.sql.types import GeometryType

# Running total, in seconds, of every get_time() block (successful or failed).
total_time = 0


def _format_elapsed(seconds):
    """Return *seconds* as 'X.XX min' when it is at least 60, else 'X.XX sec'."""
    if seconds >= 60:
        return f"{seconds / 60:.2f} min"
    return f"{seconds:.2f} sec"


@contextmanager
def get_time(task_name):
    """Time the enclosed ``with`` block and report it on stdout.

    The elapsed time is added to the module-level ``total_time`` and printed
    as ``"<task_name>... DONE in <t>"``. If the block raises, the time is
    still recorded and printed as ``"<task_name>... FAILED after <t>"``, and
    the exception propagates unchanged.

    Args:
        task_name: label printed in front of the timing message.
    """
    global total_time
    # perf_counter is monotonic, so system clock adjustments cannot make the
    # measured duration jump or go negative (time.time() can).
    start = time.perf_counter()
    status = "FAILED after"
    try:
        yield
        status = "DONE in"
    finally:
        # Runs on success and on exception, so a failing task is never
        # silently dropped from the timing report.
        elapsed = time.perf_counter() - start
        total_time += elapsed
        print(f"{task_name}... {status} {_format_elapsed(elapsed)}")
    
def chudu(df):
    """Print a quick exploratory profile of a DataFrame.

    Accepts either a pandas DataFrame or a (Sedona-enabled) Spark DataFrame.
    The notebook calls this with the Spark ``joined_df``, which has no
    ``memory_usage``/``isna`` and whose columns have no ``is_unique``, so a
    pandas-only implementation fails there.

    Printed sections: row and column counts plus column names, memory usage
    (pandas only), the first 5 rows, the columns whose values never repeat,
    and the columns that contain missing values with their counts. Two blank
    lines follow the report.

    Args:
        df: a ``pandas.DataFrame`` or a ``pyspark.sql.DataFrame``.

    Returns:
        None; all output goes to stdout.
    """
    # Dispatch on the defining module instead of probing attributes: a Spark
    # DataFrame resolves unknown attributes as column names, so a column
    # called e.g. "memory_usage" would fool hasattr().
    if type(df).__module__.startswith("pyspark"):
        _chudu_spark(df)
    else:
        _chudu_pandas(df)
    print()
    print()


def _chudu_pandas(df):
    """Print the profile sections for an in-memory pandas DataFrame."""
    columns = list(df.columns)
    # len(df) is the row count; df.count() would give per-column non-null
    # counts as a Series.
    print(f"Total rows: {len(df)}, Total columns: {len(columns)}")
    print(f"Columns: {columns}")

    memory_mb = df.memory_usage(deep=True).sum() / 1024**2
    print(f"Memory usage: {memory_mb:.2f} MB")

    print("\nFirst 5 rows:")
    print(df.head(5))

    # items() yields one Series per column even when labels are duplicated,
    # whereas df[name] would then return a DataFrame without is_unique.
    unique_fields = [name for name, series in df.items() if series.is_unique]
    print("\nCols with unique values no repeats:")
    print(unique_fields if unique_fields else "None")

    nan_counts = df.isna().sum()
    nan_counts = nan_counts[nan_counts > 0]
    print("\nColumns with NaN:")
    print(nan_counts if not nan_counts.empty else "None")


def _chudu_spark(df):
    """Print the profile sections for a Spark DataFrame using one aggregation job."""
    # Imported lazily so the pandas path works without pyspark installed.
    from pyspark.sql import functions as F
    from pyspark.sql.types import UserDefinedType

    columns = df.columns
    # Distinct counts are skipped for user-defined types (such as Sedona
    # geometry columns), which may not support hashing/grouping.
    distinct_candidates = [
        field.name
        for field in df.schema.fields
        if not isinstance(field.dataType, UserDefinedType)
    ]

    # Row count, per-column null counts and per-column distinct counts are
    # gathered in a single df.agg() so the (possibly expensive) lineage, e.g.
    # a spatial join, is evaluated once for all statistics. Positional
    # aliases avoid clashes with arbitrary column names.
    aggregates = [F.count(F.lit(1)).alias("n_rows")]
    aggregates += [
        F.count(F.when(df[name].isNull(), 1)).alias(f"null_{i}")
        for i, name in enumerate(columns)
    ]
    aggregates += [
        F.countDistinct(df[name]).alias(f"distinct_{i}")
        for i, name in enumerate(distinct_candidates)
    ]
    stats = df.agg(*aggregates).first()

    n_rows = stats["n_rows"]
    null_counts = {name: stats[f"null_{i}"] for i, name in enumerate(columns)}

    print(f"Total rows: {n_rows}, Total columns: {len(columns)}")
    print(f"Columns: {columns}")
    print("Memory usage: N/A (Spark DataFrames are distributed and lazily evaluated)")

    print("\nFirst 5 rows:")
    df.show(5)

    # Mirror pandas' Series.is_unique: countDistinct ignores nulls, so one
    # null acts as one extra distinct value and two or more nulls repeat.
    unique_fields = [
        name
        for i, name in enumerate(distinct_candidates)
        if stats[f"distinct_{i}"] + min(null_counts[name], 1) == n_rows
    ]
    print("\nCols with unique values no repeats:")
    print(unique_fields if unique_fields else "None")

    missing = {name: count for name, count in null_counts.items() if count > 0}
    print("\nColumns with null:")
    print(missing if missing else "None")

# Country polygons: load the shapefile directory as a SpatialRDD with the
# RDD-based ShapefileReader, then convert it to a DataFrame.
countries_path = "./my_data/data_EU/countries_shp/"
countries_rdd = ShapefileReader.readToGeometryRDD(sc, countries_path)
countries_df = Adapter.toDf(countries_rdd, sedona)

# Census grid cells: parse the WKB "geom" column into geometries, remove the
# literal "CRS3035RES1000m" text from GRD_ID, and keep only the columns the
# join needs.
grids_df = (
    sedona.read.parquet("./my_data/data_EU/census_grid_EU/grids.parquet")
    .withColumn("geom", expr("ST_GeomFromWKB(geom)"))
    .withColumn("GRD_ID", regexp_replace("GRD_ID", "CRS3035RES1000m", ""))
    .select("GRD_ID", "T", "geom")
)
countries_df = countries_df.select("CNTR_ID", "NAME_ENGL", "geometry")

# Spatial inner join: one output row per (grid cell, country) pair whose
# geometries intersect.
# NOTE(review): the grid IDs suggest EPSG:3035 coordinates — confirm the
# country shapefile uses the same CRS, since the predicate compares the raw
# coordinates of both sides.
# NOTE(review): a cell straddling a border matches every country it touches,
# so it appears once per country; summing "T" over joined_df would
# double-count those cells — confirm this is intended.
joined_df = grids_df.alias("grd").join(
    countries_df.alias("con"),
    on=expr("ST_Intersects(grd.geom, con.geometry)"),
    how="inner",
)
chudu(joined_df)