In [None]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

from pyspark.sql import functions as F
from pyspark.sql.functions import col as c, lit as l
from pyspark.sql import Column
from typing import List
from pyspark.sql import Window as W
from functools import reduce
import operator

In [None]:
# Get (or reuse) the notebook's Spark session, running on a local master with 2 threads.
spark = (
    SparkSession.builder
    .master("local[2]")
    .appName("AoC_2022_10[2]")
    .getOrCreate()
)

In [None]:
def value_transformer(e: Column) -> Column:
    """Convert one instruction line into an ``(op, val)`` struct column.

    The line is split on a single space. The first token is exposed as
    ``op``; the second token is cast to ``int`` and exposed as ``val``,
    falling back to 0 whenever that cast yields null.

    :param e: string column holding a single instruction line.
    :return: struct column with fields ``op`` and ``val``.
    """
    tokens = F.split(e, " ")
    operand = F.coalesce(tokens.getItem(1).cast("int"), l(0))
    return F.struct(tokens.getItem(0).alias("op"), operand.alias("val"))


# Read the whole puzzle file as a single row, break it into lines, parse each
# line into (op, val) and explode to one row per instruction, keeping the
# line position as `rn` so the program order can be restored later.
inputs = (
    spark.read.text("./data/aoc_10.txt", wholetext=True)
    # Constant key used downstream as the window / group-by partition.
    .withColumn("fn", l("fn"))
    # Accept both LF and CRLF line endings: a stray "\r" left on the op name
    # would never match the op lookup performed later in the notebook.
    .withColumn("value", F.split("value", r"\r?\n"))
    # Drop blank lines (typically the one produced by a trailing newline) so
    # they do not turn into phantom instructions with an empty op.
    .withColumn("value", F.filter("value", lambda line: F.length(F.trim(line)) > 0))
    .withColumn("value", F.transform("value", value_transformer))
    .select("fn", F.posexplode("value").alias("rn", "op_val"))
    .select("fn", "rn", "op_val.*")
)
inputs.show(truncate=False)

In [None]:
# sample = [(1, "a", "noop", 0),
#           (2, "a", "noop", 0),
#           (3, "a", "addx", 5),
#           (4, "a", "noop", 0),
#           (5, "a", "noop", 0),
#           (6, "a", "addx", 6), ]
#
# inputs = spark.createDataFrame(sample, ("rn", "fn", "op", "val"))

In [None]:
# Literal columns for the two op names, and the number of cycles each one takes.
noop = l("noop")
addx = l("addx")
cycles_map = F.create_map(noop, l(1), addx, l(2))

# Running frame: all rows of the same `fn`, in `rn` order, from the first row
# up to and including the current one.
ws = (
    W.partitionBy("fn")
    .orderBy("rn")
    .rowsBetween(W.unboundedPreceding, W.currentRow)
)

# Running sum of `val` plus 1, i.e. the value once the current row is applied;
# subtracting the row's own `val` gives the value before it was applied.
val_after_cycle = (F.sum("val").over(ws) + 1).alias("val_after_cycle")
val_before_cycle = (val_after_cycle - c("val")).alias("val_before_cycle")

# Cycle cost of the current op, and the inclusive [cycles_start, cycles_end]
# range of cycle numbers the current row occupies.
cycles = cycles_map[c("op")].alias("cycles")
cycles_end = F.sum(cycles).over(ws).alias("cycles_end")
cycles_start = (cycles_end - cycles + 1).alias("cycles_start")

# Cycle numbers at which a strength is sampled: 20, 60, 100, 140, 180, 220.
checkpoints = range(20, 221, 40)

# Fold the checkpoints into one when(...).when(...)...otherwise(0) expression.
# The fold is seeded with the functions module itself so the first step calls
# F.when(...) and every later step calls Column.when(...) on the chain so far.
eval_cond = reduce(
    lambda chain, cycle: chain.when(
        l(cycle).between(c("cycles_start"), c("cycles_end")),
        cycle * c("val_before_cycle"),
    ),
    checkpoints,
    F,
).otherwise(0)

strengths = (
    inputs
    .select("*", cycles, val_before_cycle, val_after_cycle, cycles_start, cycles_end)
    .select("*", eval_cond.alias("strength"))
)
strengths.groupby("fn").agg(F.sum("strength").alias("total_strength")).head()