In [None]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

from pyspark.sql import functions as F
from pyspark.sql.functions import col as c, lit as l
from pyspark.sql import Column
from typing import List
from pyspark.sql import Window as W
from functools import reduce
import operator

In [None]:
# Local Spark session with two worker threads for this puzzle.
spark = (
    SparkSession.builder
    .master("local[2]")
    .appName("AoC_2022_8[2]")
    .getOrCreate()
)

In [None]:
# One row per grid line: `rn` is the 0-based line index and `filename`
# identifies the grid (it is the window partition key used below).
inputs_raw = (spark.read.text("./data/aoc_8.txt", wholetext=True)
                   .withColumn("filename", F.input_file_name())
                   # accept CRLF as well as LF line endings
                   .withColumn("value", F.split("value", r"\r?\n"))
                   .select(F.posexplode("value").alias("rn", "value"), "filename")
                   # a trailing newline yields an empty last "line"; drop it so it is not
                   # mistaken for the bottom edge of the grid
                   .filter(F.length("value") > 0))

trees_arr = inputs_raw.withColumn("trees", F.split("value", "")).select("rn", "trees", "filename")

# Grid width = number of characters per line. Taken from the line length rather than
# `size(split(value, "")) - 1`, because whether splitting on an empty pattern keeps a
# trailing "" element differs between Spark versions (the `- 1` then drops a real column).
nr_cols = inputs_raw.agg(F.max(F.length("value")).alias("width")).head()["width"]

# col_0 .. col_{nr_cols-1}: integer tree height at each position of the line.
row_selection_expr = [F.col("trees")[r].cast("int").alias(f"col_{r}") for r in range(nr_cols)]
inputs = trees_arr.select("rn", *row_selection_expr, "filename")

In [None]:
# Every grid (one per input file) is processed in line order.
ws = W.partitionBy("filename").orderBy("rn")

# Row frames relative to the current line:
#   up                -> all lines strictly above it
#   down              -> all lines strictly below it
#   ws_complete_frame -> the whole grid (used to find its first/last line)
up, down, ws_complete_frame = (
    ws.rowsBetween(W.unboundedPreceding, -1),
    ws.rowsBetween(1, W.unboundedFollowing),
    ws.rowsBetween(W.unboundedPreceding, W.unboundedFollowing),
)

# Names of the per-position height columns built when parsing the input.
nr_colnames = [f"col_{i}" for i in range(nr_cols)]

In [None]:
def traversal_compute(idx: int) -> Column:
    """Return a 0/1 column for the tree in column `idx` of the current line.

    The value is 1 when the tree is strictly taller than every other tree in at least
    one direction (above, below, left or right) and 0 otherwise. Callers apply it only
    to interior cells; on the first/last line the `up`/`down` maximum is null.
    """
    height = F.col(f"col_{idx}")
    # Tallest tree in each direction; -1 stands in for "no trees" on the left/right.
    tallest = [
        F.max(height).over(up),
        F.max(height).over(down),
        F.greatest(*nr_colnames[0:idx], F.lit(-1)),
        F.greatest(*nr_colnames[idx + 1:], F.lit(-1)),
    ]
    seen_from_somewhere = reduce(operator.or_, (height > t for t in tallest))
    return F.when(seen_from_somewhere, F.lit(1)).otherwise(F.lit(0))

def visibility(idx: int) -> Column:
    """Return a 0/1 visibility column for column `idx` of the current line.

    Cells in the first/last column and on the first/last line of the grid are always 1;
    every other cell is decided by `traversal_compute(idx)`.
    """
    if idx in (0, nr_cols - 1):
        return F.lit(1)
    rn = F.col("rn")
    on_edge_line = (F.min("rn").over(ws_complete_frame) == rn) | (F.max("rn").over(ws_complete_frame) == rn)
    return F.when(on_edge_line, F.lit(1)).otherwise(traversal_compute(idx))

In [None]:
# One 0/1 visibility column per grid position; a line's visible count is their sum.
result_selection_expr = [visibility(pos).alias(f"col_{pos}") for pos in range(nr_cols)]

row_summation = reduce(operator.add, result_selection_expr)

In [None]:
# Part 1: total number of visible trees per grid.
(inputs
    .select("filename", row_summation.alias("row_counts"))
    .groupBy("filename")
    .agg(F.sum("row_counts").alias("result"))
    .collect())

In [None]:
# Part 2 reuses the window specs (`ws`, `up`, `down`, `ws_complete_frame`) and
# `nr_colnames` defined in the Part 1 window cell. They are deliberately not
# re-declared here, so the two parts cannot silently drift apart.

def construct_accum(stop: Column, distance: Column) -> Column:
    """Build the fold state used by `viewing_distance`.

    Returns a struct<stop: boolean, distance: int>. `stop` marks that the view has
    been blocked; `distance` is the number of trees counted so far.
    """
    stop_field = F.lit(stop).cast("boolean").alias("stop")
    distance_field = F.lit(distance).cast("int").alias("distance")
    return F.struct(stop_field, distance_field)

def viewing_distance(dc: Column, arr: Column) -> Column:
    """Count the trees seen from a tree of height `dc` looking along `arr`.

    `arr` must be ordered nearest-first. Trees shorter than `dc` are counted and the
    view continues. The first tree at least as tall as `dc` is counted and then ends
    the view. An empty `arr` gives 0.
    """
    def step(state: Column, h: Column) -> Column:
        halted = state.stop == l(True)
        blocks_view = ~halted & (h >= dc)
        sees_past = ~halted & (h < dc)
        one_more = state.distance + 1
        return (
            F.when(blocks_view, construct_accum(l(True), one_more))
            .when(sees_past, construct_accum(state.stop, one_more))
            .otherwise(construct_accum(l(True), state.distance))
        )

    def finish(state: Column) -> Column:
        return state.distance

    return F.aggregate(arr, construct_accum(l(False), l(0)), step, finish)

def traversal_score(idx: int) -> Column:
    """Scenic score of the tree in column `idx` of the current line.

    The score is the product of the viewing distances up, down, left and right. Each
    direction's heights are passed to `viewing_distance` nearest-first.
    """
    # Local name chosen so it does not shadow the module-level `c` (the `col` alias).
    height = F.col(f"col_{idx}")
    # `up` spans the lines above in `rn` order (assumes collect_list keeps the window
    # order); reverse so the nearest line comes first.
    up_score = viewing_distance(height, F.reverse(F.collect_list(height).over(up)))
    down_score = viewing_distance(height, F.collect_list(height).over(down))
    # Columns to the left, nearest first.
    left_score = viewing_distance(height, F.array(*nr_colnames[:idx][::-1]))
    right_score = viewing_distance(height, F.array(*nr_colnames[idx + 1:]))
    return up_score * down_score * left_score * right_score

def score(idx: int) -> Column:
    """Scenic-score column for column `idx` of the current line.

    Cells in the first/last column and on the first/last line of the grid score 0;
    every other cell gets `traversal_score(idx)`.
    """
    if idx in (0, nr_cols - 1):
        return F.lit(0)
    rn = F.col("rn")
    on_edge_line = (F.min("rn").over(ws_complete_frame) == rn) | (F.max("rn").over(ws_complete_frame) == rn)
    return F.when(on_edge_line, F.lit(0)).otherwise(traversal_score(idx))

In [None]:
# Scenic score of every tree, laid out in the same col_<i> columns as the input.
score_selection_expr = [score(pos).alias(f"col_{pos}") for pos in range(nr_cols)]

# Best score within one line; the aggregation below takes the max over all lines.
score_max = F.greatest(*nr_colnames).alias("row_max")

In [None]:
# Part 2: highest scenic score per grid.
(inputs
    .select(*score_selection_expr, "filename")
    .groupBy("filename")
    .agg(F.max(score_max).alias("result"))
    .collect())