In [None]:
# Set the PySpark environment variables
import os
import sys

# Point PySpark at the local Spark distribution and hand it the interpreter
# that is running this notebook.
# The PYSPARK_DRIVER_PYTHON ('jupyter') / PYSPARK_DRIVER_PYTHON_OPTS ('lab')
# overrides are intentionally left unset.
os.environ.update(
    {
        "SPARK_HOME": r"C:\_dev\spark-3.5.1-hadoop3",
        "PYSPARK_PYTHON": sys.executable,
    }
)

In [None]:
from pyspark.sql.types import *
from decimal import Decimal
from datetime import datetime, date
from pyspark.sql.functions import *
import random
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    DoubleType,
    TimestampType,
)
from pyspark.sql.window import Window
from collections import namedtuple
from collections import defaultdict
from pyspark.sql import SparkSession

In [None]:
# Build the SparkSession used by every cell below, registered under the
# application name "QA_SCD2".
spark = (
    SparkSession.builder
    .appName("QA_SCD2")
    .getOrCreate()
)

In [None]:
# function to compare two dataframes
# test functions


def test_SrcExceptTarget(source, target, ac, failed_flag, extra=""):
    if source.exceptAll(target).count() != 0:
        print("Test failed: records in source that are not in target - ", ac, extra)
        source.exceptAll(target).display()
        failed_flag = 1
    return failed_flag


def test_TargetExceptSrc(source, target, ac, failed_flag, extra=""):
    if target.exceptAll(source).count() != 0:
        print("Test failed: records in target that are not in source - ", ac, extra)
        target.exceptAll(source).display()
        failed_flag = 1
    return failed_flag


def test_distinctTarget(source, target, ac, failed_flag, id):
    src = source.select(col(id).alias("id"))
    joined = src.join(target, src.id == target[id], "inner")
    if joined.count() != joined.dropDuplicates().count:
        print("Test failed: duplicate target records - ", ac)
        failed_flag = 1
    return failed_flag


def test_countSourceTarget(source, target, ac, failed_flag):
    if target.count() != source.count():
        print("Test failed: incorrect number of records - ", ac)
        print("Distinct records in source: ", source.count())
        print("Distinct records in curated: ", target.count())
        failed_flag = 1
    return failed_flag

In [None]:
def create_expected_source(source, scd2_cols, ID_COL="pros_id"):
    """Collapse runs of consecutive rows whose tracked columns did not change.

    Rows are looked at per ``ID_COL`` in ``start_at`` order. The tracked
    columns are concatenated into one comparison string, each row is compared
    with the previous and next row of the same id and ranked as First, Middle
    or Last of a run of equal values (or "" otherwise). Rows ranked Last get
    their ``start_at`` replaced by the earliest ``start_at`` found by the
    helper window; rows ranked First or Middle are removed.

    Args:
        source: DataFrame containing ``ID_COL``, a column named ``start_at``
            and every column listed in ``scd2_cols``.
        scd2_cols: Columns passed to ``concat_ws`` to build the comparison
            string.
        ID_COL: Name of the key column used for partitioning (default
            ``"pros_id"``).

    Returns:
        The reduced DataFrame without the helper columns, ordered by
        ``ID_COL`` and ``start_at``.
    """
    # Names of the helper values/columns used only inside this function.
    SEPARATOR = "¿"
    FIRST = "First"
    MIDDLE = "Middle"
    LAST = "Last"
    CONCAT_COL = "concat"
    RANK_COL = "rank"
    CONSECUTIVE_COL = "consecutive"
    START_AT_COL = "start_at"
    START_AT_EARLIEST_COL = "start_at_earliest"

    # Per-id window in start_at order; lag/lead below look at the neighbouring
    # rows of the same id.
    windowPartition = Window.partitionBy(ID_COL).orderBy([ID_COL, START_AT_COL])
    # https://stackoverflow.com/questions/74942562/concat-ws-and-coalesce-in-pyspark
    # NOTE(review): concat_ws presumably skips null inputs (see the link above),
    # so rows differing only in which column is null may compare equal -- verify.
    source = (
        source.withColumn(
            CONCAT_COL,
            concat_ws(SEPARATOR, *scd2_cols),
        )
        # Rank each row by comparing its value with the previous/next row.
        # NOTE(review): lag()/lead() presumably yield null on the first/last
        # row of a partition, which would make these comparisons null and send
        # a run touching a partition edge to otherwise("") -- confirm intended.
        .withColumn(
            RANK_COL,
            when(
                # differs from previous, equals next -> start of a run
                (col(CONCAT_COL) != lag(CONCAT_COL).over(windowPartition))
                & (col(CONCAT_COL) == lead(CONCAT_COL).over(windowPartition)),
                lit(FIRST),
            )
            .when(
                # equals both neighbours -> inside a run
                (col(CONCAT_COL) == lag(CONCAT_COL).over(windowPartition))
                & (col(CONCAT_COL) == lead(CONCAT_COL).over(windowPartition)),
                lit(MIDDLE),
            )
            .when(
                # equals previous, differs from next -> end of a run
                (col(CONCAT_COL) == lag(CONCAT_COL).over(windowPartition))
                & (col(CONCAT_COL) != lead(CONCAT_COL).over(windowPartition)),
                lit(LAST),
            )
            .otherwise(lit("")),
        )
        # True when the row equals at least one of its neighbours.
        .withColumn(
            CONSECUTIVE_COL,
            when(
                (col(CONCAT_COL) == lag(CONCAT_COL).over(windowPartition))
                | (col(CONCAT_COL) == lead(CONCAT_COL).over(windowPartition)),
                lit(True),
            ).otherwise(lit(False)),
        )
        # NOTE(review): the window below groups by the concatenated *value*,
        # not by run; two separate runs with the same value for one id would
        # share a partition -- confirm this cannot happen in the input.
        .withColumn(  # find earliest start date
            START_AT_EARLIEST_COL,
            when(
                col(RANK_COL) == LAST,
                min(col(START_AT_COL)).over(
                    (
                        Window.partitionBy(
                            [ID_COL, CONCAT_COL, CONSECUTIVE_COL]
                        ).orderBy([ID_COL, START_AT_COL])
                    )
                ),
            ).otherwise(lit(None)),
        )
    )

    # source.display()

    # update start_at with earliest start date when rank = last
    # (start_at_earliest is only non-null on rows ranked Last)
    source = source.withColumn(
        START_AT_COL,
        when(
            col(START_AT_EARLIEST_COL).isNotNull(), col(START_AT_EARLIEST_COL)
        ).otherwise(col(START_AT_COL)),
    )

    # drop rows with rank = first or middle
    source = source.filter((col(RANK_COL) != MIDDLE) & (col(RANK_COL) != FIRST))

    # drop helper columns such as concat, rank, consecutive, and start_at_earliest
    source = source.drop(CONCAT_COL, RANK_COL, CONSECUTIVE_COL, START_AT_EARLIEST_COL)

    return source.orderBy(ID_COL, START_AT_COL)

In [None]:
# rename source table's columns to match target table's columns
def rename_source_cols(source_df, source_cols, target_cols):
    """Rename columns of ``source_df`` position by position.

    ``source_cols[i]`` is renamed to ``target_cols[i]``; pairs whose names are
    already equal are skipped. Renames are applied one after another in list
    order.

    Args:
        source_df: DataFrame whose columns are renamed.
        source_cols: Current column names.
        target_cols: Desired column names, same length as ``source_cols``.

    Returns:
        The DataFrame with the renames applied.

    Raises:
        ValueError: If the two name lists differ in length.
    """
    if len(source_cols) != len(target_cols):
        # Raise instead of calling exit(): inside a notebook exit() tears down
        # (or merely schedules the shutdown of) the kernel, and zip() below
        # would otherwise silently ignore the surplus names.
        raise ValueError(
            "Column count mismatch: "
            f"{len(source_cols)} source column(s) vs {len(target_cols)} target column(s)"
        )

    for old_name, new_name in zip(source_cols, target_cols):
        if old_name != new_name:
            source_df = source_df.withColumnRenamed(old_name, new_name)

    return source_df


# keep only the listed columns of a table (every other column is dropped)
def drop_cols(df, cols):
    """Return ``df`` reduced to the columns named in ``cols``.

    Every column of ``df`` whose name is not in ``cols`` is dropped, one
    ``drop`` call per column. When nothing has to be dropped the input
    object itself is returned.
    """
    unwanted = [name for name in df.columns if name not in cols]

    for name in unwanted:
        df = df.drop(name)

    return df

In [None]:
def find_intersections(list, lists):
    """Split each interval of ``list`` at the breakpoints found in ``lists``.

    All start/end values of every interval in ``lists`` are collected as
    breakpoints and sorted, with ``None`` (an open end) last. Each interval
    of ``list`` is then cut into consecutive ``[start, end]`` pieces at the
    breakpoints that fall inside it. A single cursor moves forward through
    the breakpoints and never rewinds, so ``list`` is expected to be sorted
    and non-overlapping.

    Only values present in ``lists`` end a piece; include ``list`` itself in
    ``lists`` so that every interval's own end (also the open end ``None``)
    is emitted.

    Args:
        list: Iterable of ``(start, end)`` pairs; ``end`` may be ``None``.
            (The name shadows the builtin but is kept for compatibility.)
        lists: Iterable of interval lists that supply the breakpoints.

    Returns:
        A list of ``[start, end]`` pieces. Processing stops after the first
        open-ended interval whose ``None`` end has been emitted.
    """
    # Sort key stand-in for None so that open ends sort after every real date.
    maxdate = date(datetime.max.year, 1, 1)

    # collect all points first into a set and then into a sorted sequence
    breaks = set()
    for group in lists:
        breaks.update(*group)
    breaks = sorted(breaks, key=lambda x: (x or maxdate))
    total = len(breaks)

    intersections = []
    index = 0
    for start, end in list:
        # Skip breakpoints at or before the interval start. The bounds and
        # None guards keep this from running off the end of `breaks` or
        # comparing None with a date.
        while (
            index < total
            and breaks[index] is not None
            and breaks[index] <= start
        ):
            index += 1

        # Emit one piece per breakpoint inside the interval; an open-ended
        # interval takes every remaining dated breakpoint.
        while (
            index < total
            and breaks[index] is not None
            and (end is None or breaks[index] <= end)
        ):
            intersections.append([start, breaks[index]])
            start = breaks[index]
            index += 1

        # Open-ended interval: close it with the open end and stop.
        if end is None and index < total and breaks[index] is None:
            intersections.append([start, None])
            return intersections

    return intersections

In [None]:
# One dated interval: an identifier plus start/end (end_date None = open end).
NT = namedtuple("NT", ["id", "start_date", "end_date"])


def find_date_intersections(daterange, lists):
    """Split each ``NT`` interval of ``daterange`` at the dates in ``lists``.

    Every ``start_date`` and ``end_date`` of every ``NT`` in ``lists`` is
    collected as a breakpoint and sorted, with ``None`` (an open end) last.
    Each interval of ``daterange`` is then cut into consecutive pieces at the
    breakpoints that fall inside it; each piece keeps the interval's ``id``.
    A single cursor moves forward through the breakpoints and never rewinds,
    so ``daterange`` is expected to be sorted and non-overlapping.

    Only values present in ``lists`` end a piece; include ``daterange``
    itself in ``lists`` so that every interval's own end (also the open end
    ``None``) is emitted.

    Args:
        daterange: Iterable of ``NT`` (or ``(id, start, end)``) intervals;
            ``end`` may be ``None``.
        lists: Iterable of ``NT`` lists that supply the breakpoints.

    Returns:
        A list of ``NT`` pieces. Processing stops after the first open-ended
        interval whose ``None`` end has been emitted.
    """
    # collect all endpoints first into a set and then into a sorted sequence
    breaks = set()
    for group in lists:
        for item in group:
            breaks.add(item.start_date)
            breaks.add(item.end_date)

    # Sort key stand-in for None so that open ends sort after every real date.
    maxdate = date(datetime.max.year, 1, 1)
    breaks = sorted(breaks, key=lambda x: (x or maxdate))
    total = len(breaks)

    intersections: list[NT] = []
    index = 0
    for key, start, end in daterange:
        # Skip breakpoints at or before the interval start. The bounds and
        # None guards keep this from running off the end of `breaks` or
        # comparing None with a date.
        while (
            index < total
            and breaks[index] is not None
            and breaks[index] <= start
        ):
            index += 1

        # Emit one piece per breakpoint inside the interval; an open-ended
        # interval takes every remaining dated breakpoint.
        while (
            index < total
            and breaks[index] is not None
            and (end is None or breaks[index] <= end)
        ):
            intersections.append(NT(key, start, breaks[index]))
            start = breaks[index]
            index += 1

        # Open-ended interval: close it with the open end and stop.
        if end is None and index < total and breaks[index] is None:
            intersections.append(NT(key, start, None))
            return intersections

    return intersections

In [None]:
# Target test data (Gperson): one row per surname version of a person, valid
# from __START_AT until __END_AT (None = still open).
schema = (
    StructType()
    .add("person_id", DecimalType(25, 0), True)
    .add("Surname_cache", StringType(), True)
    .add("__START_AT", DateType(), True)
    .add("__END_AT", DateType(), True)
)

# Rows are written compactly as (person_id, surname, start, end) and then
# converted to the Decimal / datetime values handed to Spark.
data = [
    (
        Decimal(person_id),
        surname,
        datetime.strptime(start_at, "%Y-%m-%d"),
        datetime.strptime(end_at, "%Y-%m-%d") if end_at is not None else None,
    )
    for person_id, surname, start_at, end_at in [
        ("1", "t", "2023-01-01", "2023-01-02"),
        ("1", "te", "2023-01-03", "2023-01-04"),
        ("1", "test", "2023-01-05", None),
        ("2", "tes", "2023-02-05", "2023-02-09"),
        ("2", "test", "2023-02-10", None),
    ]
]

gperson = spark.createDataFrame(data, schema)

In [None]:
# Test data for Gpid: identity numbers (IDNumber) attached to a person (WId),
# valid from START_AT until END_AT (None = still open).
schema = (
    StructType()
    .add("Id", DecimalType(25, 0), True)
    .add("WId", DecimalType(), True)
    .add("IDNumber", DecimalType(), True)
    .add("START_AT", DateType(), True)
    .add("END_AT", DateType(), True)
)

# Rows are written compactly as (Id, WId, IDNumber, start, end) and then
# converted to the Decimal / datetime values handed to Spark.
data = [
    (
        Decimal(row_id),
        Decimal(person_ref),
        Decimal(id_number),
        datetime.strptime(start_at, "%Y-%m-%d"),
        datetime.strptime(end_at, "%Y-%m-%d") if end_at is not None else None,
    )
    for row_id, person_ref, id_number, start_at, end_at in [
        ("1", "1", "123", "2023-01-03", "2023-01-04"),
        ("2", "1", "456", "2023-01-06", "2023-01-10"),
        ("3", "1", "789", "2023-01-12", None),
        ("4", "2", "888", "2023-02-01", "2023-02-07"),
        ("5", "2", "999", "2023-02-20", None),
    ]
]

gpid = spark.createDataFrame(data, schema)

In [None]:
# 1
# Left join gperson to gpid:
# - gperson.person_id = gpid.WId
# Drop duplicate rows, sort the result, and add a running row number
# (row_idx) that identifies each joined row.
# NOTE(review): row_idx follows monotonically_increasing_id(), which is
# assumed to preserve the preceding sort order -- confirm.
gperson_gpid = (
    gperson.join(gpid, gperson.person_id == gpid.WId, "left")
    .dropDuplicates()
    .orderBy("person_id", "Surname_cache", "__START_AT", "START_AT")
    .withColumn(
        "row_idx",
        row_number().over(Window.orderBy(monotonically_increasing_id())),
    )
)

gperson_gpid.show()

In [None]:
# For every joined row, split the person validity range at the dates of the
# matching gpid range and remember the pieces (keyed by the row's row_idx).
list_of_intersections = []

for row in gperson_gpid.collect():
    row_key = row["row_idx"]

    # Both ranges carry the same key so the pieces can be joined back later.
    person_range = [NT(row_key, row["__START_AT"], row["__END_AT"])]
    gpid_range = [NT(row_key, row["START_AT"], row["END_AT"])]

    intersections = find_date_intersections(
        person_range, [person_range, gpid_range]
    )
    list_of_intersections.append(intersections)

    print(intersections)

In [None]:
# Flatten the per-row lists of NT pieces into one list of records.
abc = [piece for pieces in list_of_intersections for piece in pieces]

# One row per piece: the originating row_idx plus the piece's new date range.
b = spark.createDataFrame(abc, ['row_idx', 'start_date_new', 'end_date_new'])

b.show()

# Join on the column *name* rather than on an equality expression so the
# result keeps a single row_idx column; with the expression form both
# row_idx columns survive and any later reference to row_idx is ambiguous.
# NOTE: this rebinds gperson_gpid, so re-run from the cell that builds
# gperson_gpid rather than re-running this cell on its own.
gperson_gpid = gperson_gpid.join(b, "row_idx", "left")

gperson_gpid.show()
