In [1]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

from pyspark.sql import functions as F
from pyspark.sql.functions import col as c, lit as l
from pyspark.sql import Column
from typing import List
from pyspark.sql import Window as W
from functools import reduce
import operator

In [2]:
# Local Spark session for this notebook ("local[2]" master, fixed application name).
spark = (
    SparkSession.builder
    .master("local[2]")
    .appName("AoC_2022_9[2]")
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/12/09 21:45:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
def repeater(e: Column) -> Column:
    """Expand one input line into an array holding its first token N times.

    The line is split on a single space; the first token is repeated as many
    times as the second token says (after casting that token to int).

    :param e: string column holding one line of the input.
    :return: array column built with ``array_repeat``.
    """
    tokens = F.split(e, " ")
    direction = tokens[0]
    n_steps = tokens[1].cast("int")
    return F.array_repeat(direction, n_steps)


# Read the whole input file as a single row, split it into lines and expand every
# line into one array element per single step ("movements").
#
# Blank lines (typically the one produced by a trailing newline) are dropped before
# expansion: for such a line `repeater` has no count token and yields NULL, and
# `flatten` returns NULL for the whole array as soon as one inner array is NULL.
# The split pattern also accepts Windows line endings.
inputs_df = (spark.read.text("./data/aoc_9.txt", wholetext=True)
             .withColumn("value", F.filter(F.split("value", "\r?\n"),
                                           lambda line: F.trim(line) != l("")))
             .withColumn("movements", F.flatten(F.transform("value", lambda e: repeater(e)))))
inputs_df.show(truncate=False)

                                                                                

+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [4]:
def coords(hx: Column, hy: Column, tx: Column, ty: Column) -> Column:
    """Bundle four coordinate values into one struct column.

    Every argument is passed through ``F.lit`` (so callers in this notebook also
    hand in plain Python ints), cast to int, and stored under the field names
    ``hx``, ``hy``, ``tx``, ``ty`` in that order.

    :return: struct column with the int fields hx, hy, tx, ty.
    """
    named_values = {"hx": hx, "hy": hy, "tx": tx, "ty": ty}
    return F.struct(*(F.lit(value).cast("int").alias(field)
                      for field, value in named_values.items()))


def move_left(coord: Column):
    """Return a copy of ``coord`` whose ``hx`` is decreased by 1; other fields are carried over."""
    shifted_hx = coord["hx"] - 1
    return coords(shifted_hx, coord["hy"], coord["tx"], coord["ty"])


def move_right(coord: Column):
    """Return a copy of ``coord`` whose ``hx`` is increased by 1; other fields are carried over."""
    shifted_hx = coord["hx"] + 1
    return coords(shifted_hx, coord["hy"], coord["tx"], coord["ty"])


def move_up(coord: Column):
    """Return a copy of ``coord`` whose ``hy`` is increased by 1; other fields are carried over."""
    shifted_hy = coord["hy"] + 1
    return coords(coord["hx"], shifted_hy, coord["tx"], coord["ty"])


def move_down(coord: Column):
    """Return a copy of ``coord`` whose ``hy`` is decreased by 1; other fields are carried over."""
    shifted_hy = coord["hy"] - 1
    return coords(coord["hx"], shifted_hy, coord["tx"], coord["ty"])


def _straight_val(dim: Column) -> Column:
    """Shift ``dim`` by one unit based on its own sign.

    Non-negative values (including 0) are decreased by 1, negative values are
    increased by 1. Note that the sign is taken from ``dim`` itself, not from
    any difference to another coordinate.
    """
    is_non_negative = dim >= l(0)
    unit = F.when(is_non_negative, l(1)).otherwise(l(-1))
    return dim - unit


def straight_follow(coord: Column) -> Column:
    """Move the tail along a shared row or column towards the head.

    When head and tail share a column (``hx == tx``) the tail's ``ty`` is
    updated; otherwise, when they share a row (``hy == ty``), the tail's ``tx``
    is updated. Along the chosen axis the tail takes one step towards the head,
    but only if the two are more than one cell apart; an overlapping or
    adjacent tail stays where it is.

    The step direction comes from the sign of the head-to-tail gap. Using the
    sign of the head's absolute coordinate instead sends the tail the wrong way
    whenever the head sits on the near-zero side of the tail, and stepping
    without the distance guard makes an adjacent tail jump over the head.

    :param coord: struct column with int fields hx, hy, tx, ty.
    :return: updated struct column, or NULL when head and tail share neither a
        row nor a column.
    """
    hx, hy, tx, ty = (coord[field] for field in ("hx", "hy", "tx", "ty"))

    def _close_gap(head: Column, tail: Column) -> Column:
        # One step towards the head, applied only when the gap exceeds one cell.
        gap = head - tail
        return F.when(F.abs(gap) > 1, tail + F.signum(gap).cast("int")).otherwise(tail)

    vertical_follow = coords(hx, hy, tx, _close_gap(hy, ty))
    horizontal_follow = coords(hx, hy, _close_gap(hx, tx), ty)
    return (F.when(hx == tx, vertical_follow)
            .when(hy == ty, horizontal_follow)
            .otherwise(l(None)))


def reconcile_move(new_coord: Column, old_coord: Column) -> Column:
    """Settle the tail after the head has moved.

    Rule applied to ``new_coord`` (head already moved, tail not yet):

    * if the tail touches the head (both axis gaps are at most 1, which includes
      overlap and diagonal adjacency) the tail stays put;
    * otherwise the tail takes one step towards the head on every axis where
      they differ (a straight step on a shared row/column, a diagonal step
      otherwise).

    The rule is computed directly from the head/tail gap, so a diagonally
    adjacent tail is no longer dragged onto the head's previous cell and a
    straight step always goes towards the head.

    :param new_coord: struct column (hx, hy, tx, ty) after the head move.
    :param old_coord: struct column before the head move. Kept for interface
        compatibility; the rule above does not need it.
    :return: struct column (hx, hy, tx, ty) with the tail settled.
    """
    hx, hy = new_coord["hx"], new_coord["hy"]
    tx, ty = new_coord["tx"], new_coord["ty"]
    gap_x, gap_y = hx - tx, hy - ty

    touching = (F.abs(gap_x) <= 1) & (F.abs(gap_y) <= 1)
    # signum yields -1 / 0 / 1, so an axis that already lines up is left untouched.
    settled_tx = F.when(touching, tx).otherwise(tx + F.signum(gap_x).cast("int"))
    settled_ty = F.when(touching, ty).otherwise(ty + F.signum(gap_y).cast("int"))
    return coords(hx, hy, settled_tx, settled_ty)


def agg_struct(coord: Column, visited: Column) -> Column:
    """Build the fold accumulator: a struct with fields ``coord`` and ``visited``.

    ``visited`` is cast to ``array<struct<tx: int, ty: int>>`` before being stored.
    """
    typed_visited = visited.cast("array<struct<tx: int, ty: int>>")
    return F.struct(coord.alias("coord"), typed_visited.alias("visited"))


def transformer(last_coord: Column, x: Column) -> Column:
    """Apply one movement letter to ``last_coord`` and settle the tail.

    ``x`` is compared against "U", "D", "L", "R" (in that order) to pick the
    head move; the chain has no fallback branch, so any other value produces a
    NULL moved coordinate. The result is then passed to ``reconcile_move``
    together with ``last_coord``.
    """
    head_moves = (("U", move_up), ("D", move_down), ("L", move_left), ("R", move_right))

    moved = None
    for letter, mover in head_moves:
        branch_value = mover(last_coord)
        if moved is None:
            moved = F.when(x == l(letter), branch_value)
        else:
            moved = moved.when(x == l(letter), branch_value)
    return reconcile_move(moved, last_coord)


def solver(arr: Column) -> Column:
    """Fold the array of single-step movements into the array of visited tail cells.

    The accumulator (see ``agg_struct``) carries the current coordinate struct
    and the tail cells seen so far. It starts with everything at (0, 0); each
    step runs ``transformer`` and merges the new tail cell into ``visited`` with
    ``array_union``. The final value is the accumulator's ``visited`` field.
    """
    origin_cell = F.struct(l(0).alias("tx"), l(0).alias("ty"))
    initial_state = agg_struct(coords(0, 0, 0, 0), F.array(origin_cell))

    def merger(acc: Column, x: Column) -> Column:
        moved = transformer(acc["coord"], x)
        tail_cell = F.struct(moved["tx"].alias("tx"), moved["ty"].alias("ty"))
        seen = F.array_union(acc["visited"], F.array(tail_cell))
        return agg_struct(moved, seen)

    return F.aggregate(arr, initial_state, merger, lambda acc: acc["visited"])

In [None]:
# Run the fold over the real input and print the resulting array untruncated.
(inputs_df
 .select(solver(c("movements")))
 .show(truncate=False))

[Stage 1:>                                                          (0 + 1) / 1]

In [39]:
# Spot-check transformer() on hand-picked states; every alias becomes one output column.
prev_coord = coords(2, 2, 1, 1)
d1, d2, d3 = (transformer(prev_coord, l(step)).alias(tag)
              for tag, step in (("d1", "U"), ("d2", "R"), ("d3", "L")))
d4, d5, d6 = (transformer(coords(*start), l(step)).alias(tag)
              for tag, start, step in (("d4", (0, 0, 0, 0), "L"),
                                       ("d5", (0, 1, 0, 0), "U"),
                                       ("d6", (0, 1, 0, 0), "D")))

spark.range(1).select(d1, d2, d3, d4, d5, d6).show(truncate=False)

22/12/09 21:21:49 WARN Column: Constructing trivially true equals predicate, 'U = U'. Perhaps you need to use aliases.
22/12/09 21:21:49 WARN Column: Constructing trivially true equals predicate, 'R = R'. Perhaps you need to use aliases.
22/12/09 21:21:49 WARN Column: Constructing trivially true equals predicate, 'L = L'. Perhaps you need to use aliases.
22/12/09 21:21:49 WARN Column: Constructing trivially true equals predicate, 'L = L'. Perhaps you need to use aliases.
22/12/09 21:21:49 WARN Column: Constructing trivially true equals predicate, 'U = U'. Perhaps you need to use aliases.
22/12/09 21:21:49 WARN Column: Constructing trivially true equals predicate, 'D = D'. Perhaps you need to use aliases.


+------------+------------+------------+-------------+------------+------------+
|d1          |d2          |d3          |d4           |d5          |d6          |
+------------+------------+------------+-------------+------------+------------+
|{2, 3, 2, 2}|{3, 2, 2, 2}|{1, 2, 1, 1}|{-1, 0, 0, 0}|{0, 2, 0, 1}|{0, 0, 0, 0}|
+------------+------------+------------+-------------+------------+------------+



In [32]:
# Spot-check reconcile_move(): two diagonal catch-ups, one overlapping state and
# two straight catch-ups (vertical s1, horizontal s2).
prev_coord = coords(2, 2, 1, 1)
new1_coord = coords(2, 3, 1, 1)
new2_coord = coords(3, 2, 1, 1)

covers_conditions = reconcile_move(coords(100, 200, 100, 200), prev_coord).alias("covers_conditions")
s1 = reconcile_move(coords(1, -2, 1, 0), prev_coord).alias("s1")
s2 = reconcile_move(coords(-2, 1, 0, 1), prev_coord).alias("s2")

reconcile_columns = [
    reconcile_move(new1_coord, prev_coord).alias("1"),
    reconcile_move(new2_coord, prev_coord).alias("2"),
    covers_conditions,
    s1,
    s2,
]
spark.range(1).select(*reconcile_columns).show(truncate=False)

+------------+------------+--------------------+--------------+--------------+
|1           |2           |covers_conditions   |s1            |s2            |
+------------+------------+--------------------+--------------+--------------+
|{2, 3, 2, 2}|{3, 2, 2, 2}|{100, 200, 100, 200}|{1, -2, 1, -1}|{-2, 1, -1, 1}|
+------------+------------+--------------------+--------------+--------------+



In [17]:
# Apply each single-step head move to a state where head and tail sit at the origin.
test_init_accum = coords(0, 0, 0, 0)
spark.range(1).select(*[
    mover(test_init_accum).alias(label)
    for label, mover in (("left", move_left),
                         ("right", move_right),
                         ("up", move_up),
                         ("down", move_down))
]).show(truncate=False)

+-------------+------------+------------+-------------+
|left         |right       |up          |down         |
+-------------+------------+------------+-------------+
|{-1, 0, 0, 0}|{1, 0, 0, 0}|{0, 1, 0, 0}|{0, -1, 0, 0}|
+-------------+------------+------------+-------------+



In [24]:
# Worked examples for a tail following the head along a shared column.
# Notation: head (hx, hy), tail (tx, ty) -> expected tail position afterwards.
# head (1, 3),  tail (1, 1)  -> tail (1, 2)   i.e. ty becomes 3 - 1
# head (1, -3), tail (1, -1) -> tail (1, -2)  i.e. ty becomes -3 - (-1)
# head (1, -2), tail (1, 0)  -> tail (1, -1)
# head (1, 1),  tail (1, -1) -> tail (1, 0)

# One aliased expression per (hy, ty) pair; the alias spells out hx_hy_tx_ty.
vertical_follow_exprs = [straight_follow(coords(1, hy, 1, ty)).alias(f"1_{hy}_1_{ty}") for hy, ty in
                         [(3, 1), (-3, -1), (-2, 0), (1, -1), (-1, 1)]]

# Same pairs mirrored onto the x axis (head and tail share the row y = 1).
horizontal_follow_exprs = [straight_follow(coords(hx, 1, tx, 1)).alias(f"{hx}_1_{tx}_1") for hx, tx in
                           [(3, 1), (-3, -1), (-2, 0), (1, -1), (-1, 1)]]

# Vertical check is currently disabled; its previously recorded output is kept in the string below.
# spark.range(1).select(vertical_follow_exprs).show(truncate=False)

"""
+------------+--------------+--------------+------------+-------------+
|1_3_1_1     |1_-3_1_-1     |1_-2_1_0      |1_1_1_-1    |1_-1_1_1     |
+------------+--------------+--------------+------------+-------------+
|{1, 3, 1, 2}|{1, -3, 1, -2}|{1, -2, 1, -1}|{1, 1, 1, 0}|{1, -1, 1, 0}|
+------------+--------------+--------------+------------+-------------+
"""

spark.range(1).select(horizontal_follow_exprs).show(truncate=False)

# Recorded output of the horizontal check. NOTE: as the last expression of the
# cell, this string is also what the cell displays as its value.
"""
+------------+--------------+--------------+------------+-------------+
|3_1_1_1     |-3_1_-1_1     |-2_1_0_1      |1_1_-1_1    |-1_1_1_1     |
+------------+--------------+--------------+------------+-------------+
|{3, 1, 2, 1}|{-3, 1, -2, 1}|{-2, 1, -1, 1}|{1, 1, 0, 1}|{-1, 1, 0, 1}|
+------------+--------------+--------------+------------+-------------+
"""

+------------+--------------+--------------+------------+-------------+
|3_1_1_1     |-3_1_-1_1     |-2_1_0_1      |1_1_-1_1    |-1_1_1_1     |
+------------+--------------+--------------+------------+-------------+
|{3, 1, 2, 1}|{-3, 1, -2, 1}|{-2, 1, -1, 1}|{1, 1, 0, 1}|{-1, 1, 0, 1}|
+------------+--------------+--------------+------------+-------------+



'\n+------------+--------------+--------------+------------+-------------+\n|3_1_1_1     |-3_1_-1_1     |-2_1_0_1      |1_1_-1_1    |-1_1_1_1     |\n+------------+--------------+--------------+------------+-------------+\n|{3, 1, 2, 1}|{-3, 1, -2, 1}|{-2, 1, -1, 1}|{1, 1, 0, 1}|{-1, 1, 0, 1}|\n+------------+--------------+--------------+------------+-------------+\n'

In [47]:
# Sanity check of array_repeat for counts 0..4 (one row per id produced by range(5)).
(spark.range(5)
 .select(F.array_repeat(F.lit("a"), F.col("id").cast("int")))
 .show(truncate=False))

+--------------------------------+
|array_repeat(a, CAST(id AS INT))|
+--------------------------------+
|[]                              |
|[a]                             |
|[a, a]                          |
|[a, a, a]                       |
|[a, a, a, a]                    |
+--------------------------------+

