In [None]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

from pyspark.sql import functions as F
from pyspark.sql import Column

In [None]:
# Local session with a single worker thread; getOrCreate() reuses an already
# running session instead of starting a second one.
spark = (
    SparkSession.builder
    .master("local[1]")
    .appName("AoC_2022_3[2]")
    .getOrCreate()
)

In [None]:
# Load the puzzle file through the CSV reader with a space delimiter and a
# one-column schema, so each row exposes a single string column named "items".
inputs = (
    spark.read
    .option("delimiter", " ")
    .schema("items string")
    .csv("./data/aoc_3.txt")
)

In [None]:
# Reusable reference to the "items" column; the expressions in later cells are
# built from it and resolve against whichever DataFrame they are applied to.
items = F.col("items")

In [None]:
# Array of the individual characters of the trimmed line: an empty split
# pattern breaks the string between every character.
# NOTE(review): some Spark releases reportedly append a trailing empty string
# when splitting on an empty pattern - confirm for the version in use.
items_split = F.split(F.trim(items), "")

def slices(start: Column, end: Column) -> Column:
    """Return the elements of ``items_split`` at 0-based positions ``[start, end)``.

    Uses the built-in array slice instead of rebuilding the substring one
    character at a time and splitting it again.

    Args:
        start: Column holding the 0-based index of the first element to keep.
        end: Column holding the 0-based index one past the last element to keep.

    Returns:
        An array column with the selected elements in their original order.
        The array is empty when ``end <= start``.
    """
    # F.slice is 1-based and takes a length rather than an end index. The
    # explicit int casts avoid relying on implicit bigint -> int narrowing
    # (the midpoint passed by the callers comes from F.floor).
    first = (start + 1).cast("int")
    # Clamp at zero so an empty or inverted range produces an empty array
    # rather than a negative length.
    count = F.greatest(end - start, F.lit(0)).cast("int")
    return F.slice(items_split, first, count)


# Character count of the line and the (floored) index of its midpoint.
items_length = F.length(items)
mid_point = F.floor(items_length / 2)

# The line is cut at the midpoint: positions [0, mid) and [mid, length).
first_compartment = slices(
    F.lit(0),
    mid_point,
)
second_compartment = slices(
    mid_point,
    items_length,
)

# Distinct characters that occur in both halves.
common_items = F.array_intersect(
    first_compartment,
    second_compartment,
).alias("common_items")


def _sum_priority(items_array: Column) -> Column:
    """Sum the priorities of the strings held in an array column.

    Each element is scored from the ASCII code of its first character:
    elements matching ``[a-z]`` score ``ascii - 96`` (a=1 .. z=26), elements
    matching ``[A-Z]`` score ``ascii - 64 + 26`` (A=27 .. Z=52), and anything
    else scores 0. The patterns are unanchored, so elements are expected to be
    single characters.

    Args:
        items_array: Array-of-strings column to score.

    Returns:
        A column holding the integer total of the element scores.
    """

    def _priority(item: Column) -> Column:
        # Score a single element; the lowercase rule is checked first.
        code_point = F.ascii(item)
        lowercase_rule = F.when(item.rlike("[a-z]"), code_point - 96)
        with_uppercase = lowercase_rule.when(item.rlike("[A-Z]"), code_point - 64 + 26)
        return with_uppercase.otherwise(F.lit(0))

    def _add(total: Column, score: Column) -> Column:
        # Running total used by the fold below.
        return total + score

    scores = F.transform(items_array, _priority)
    return F.aggregate(scores, F.lit(0), _add)
    

# Total priority over all lines of the characters shared by both halves.
(
    inputs
    .select(F.sum(_sum_priority(common_items)).alias("result"))
    .head()
)

In [None]:
# wholetext=True loads the file as a single row (column "value"); the source
# path is attached as "filename".
input_whole = (
    spark.read
    .text("./data/aoc_3.txt", wholetext=True)
    .withColumn("filename", F.input_file_name())
)

# One row per line together with its 0-based position in the file.
# Note: this rebinds `inputs`, replacing the CSV-based frame from the earlier cell.
inputs = input_whole.select(
    F.posexplode(F.split("value", "\n")).alias("linenumber", "items"),
    "filename",
)

In [None]:
# Bucket consecutive lines in threes: lines 0-2 -> 0, lines 3-5 -> 1, ...
group_id = F.floor(F.col("linenumber") / 3)
group_items = F.col("group_items")

# Fold the per-line character arrays of a group down to the characters present
# in every one of them, seeding the fold with the first line's array.
group_common = F.aggregate(
    group_items,
    group_items[0],
    lambda shared, line_chars: F.array_intersect(shared, line_chars),
)

# Collect each group's character arrays, then total the priority of the
# characters common to the whole group.
(
    inputs
    .groupby("filename", group_id)
    .agg(F.collect_list(items_split).alias("group_items"))
    .select(F.sum(_sum_priority(group_common)).alias("3b Result"))
    .head()
)