In [0]:
# Notebook magic: executes the sibling notebook `01-config` in this notebook's
# context so the names it defines become available here.
# NOTE(review): presumably this is where `base_dir_data` (used below) comes from - confirm.
%run ./01-config

In [0]:
# Root directory the produce_* helpers copy files into and validate_count reads from.
landing_zone = f"{base_dir_data}/raw"
# Directory the produce_* helpers copy the prepared test files from.
test_data_dir = f"{base_dir_data}/test_data"

In [0]:
import time
from pyspark.sql.functions import col


def produce_user_registration(set_num):
    """Copy the registered-users CSV of one test set into the landing zone.

    The file ``1-registered_users_<set_num>.csv`` is copied from
    ``test_data_dir`` to the ``registered_users_bz`` folder under
    ``landing_zone``, keeping its file name.

    Args:
        set_num: Number of the test data set, interpolated into the file name.
    """
    src = f"{test_data_dir}/1-registered_users_{set_num}.csv"
    dst = f"{landing_zone}/registered_users_bz/1-registered_users_{set_num}.csv"
    # Progress message stays on one line: "Producing <src>...Done".
    print(f"Producing {src}...", end='')
    dbutils.fs.cp(src, dst)
    print("Done")




def produce_profile_cdc(set_num):
    """Copy the user-info JSON of one test set into the landing zone.

    The file ``2-user_info_<set_num>.json`` is copied from ``test_data_dir``
    to the ``kafka_multiplex_bz`` folder under ``landing_zone``, keeping its
    file name.

    Args:
        set_num: Number of the test data set, interpolated into the file name.
    """
    src = f"{test_data_dir}/2-user_info_{set_num}.json"
    dst = f"{landing_zone}/kafka_multiplex_bz/2-user_info_{set_num}.json"
    # Progress message stays on one line: "Producing <src>...Done".
    print(f"Producing {src}...", end='')
    dbutils.fs.cp(src, dst)
    print("Done")


def produce_workout(set_num):
    """Copy the workout JSON of one test set into the landing zone.

    The file ``4-workout_<set_num>.json`` is copied from ``test_data_dir``
    to the ``kafka_multiplex_bz`` folder under ``landing_zone``, keeping its
    file name.

    Args:
        set_num: Number of the test data set, interpolated into the file name.
    """
    src = f"{test_data_dir}/4-workout_{set_num}.json"
    dst = f"{landing_zone}/kafka_multiplex_bz/4-workout_{set_num}.json"
    # Progress message stays on one line: "Producing <src>...Done".
    print(f"Producing {src}...", end='')
    dbutils.fs.cp(src, dst)
    print("Done")


def produce_bpm(set_num):
    """Copy the BPM JSON of one test set into the landing zone.

    The file ``3-bpm_<set_num>.json`` is copied from ``test_data_dir`` to the
    ``kafka_multiplex_bz`` folder under ``landing_zone``, keeping its file
    name.

    Args:
        set_num: Number of the test data set, interpolated into the file name.
    """
    src = f"{test_data_dir}/3-bpm_{set_num}.json"
    dst = f"{landing_zone}/kafka_multiplex_bz/3-bpm_{set_num}.json"
    # Progress message stays on one line: "Producing <src>...Done".
    print(f"Producing {src}...", end='')
    dbutils.fs.cp(src, dst)
    print("Done")


def produce_gym_logins(set_num):
    """Copy the gym-logins CSV of one test set into the landing zone.

    The file ``5-gym_logins_<set_num>.csv`` is copied from ``test_data_dir``
    to the ``gym_logins_bz`` folder under ``landing_zone``, keeping its file
    name.

    Args:
        set_num: Number of the test data set, interpolated into the file name.
    """
    src = f"{test_data_dir}/5-gym_logins_{set_num}.csv"
    dst = f"{landing_zone}/gym_logins_bz/5-gym_logins_{set_num}.csv"
    # Progress message stays on one line: "Producing <src>...Done".
    print(f"Producing {src}...", end='')
    dbutils.fs.cp(src, dst)
    print("Done")



def produce_data(set_num):
    """Produce every test file that exists for the given data set number.

    Sets 1 and 2 (any ``set_num <= 2``) get user registrations, profile CDC,
    workouts and gym logins; any ``set_num <= 10`` additionally gets BPM data.
    For larger values nothing is copied. The elapsed wall-clock time is
    printed in whole seconds.

    Args:
        set_num: Number of the test data set to produce.
    """
    began = int(time.time())
    print(f"\nProducing test data set {set_num} ...")

    # Producers that only apply to the first two sets, in their original call order.
    early_set_producers = (
        produce_user_registration,
        produce_profile_cdc,
        produce_workout,
        produce_gym_logins,
    )
    if set_num <= 2:
        for producer in early_set_producers:
            producer(set_num)

    if set_num <= 10:
        produce_bpm(set_num)

    elapsed = int(time.time()) - began
    print(f"Test data set {set_num} produced in {elapsed} seconds")


def validate_count(fmt, location, expected_count):
    """Check that the landed files for one source hold the expected row count.

    Loads every file matching ``<landing_zone>/<location>_*.<fmt>`` through
    the Spark reader (with the ``header`` option set to ``"true"``), counts
    the records and compares the result against ``expected_count``.

    Args:
        fmt: Spark reader format; also used as the file extension in the glob.
        location: Path prefix below ``landing_zone``, without the set suffix.
        expected_count: Number of records the matched files must contain.

    Raises:
        AssertionError: If the number of records differs from ``expected_count``.
    """
    print(f"Validating {location}...", end='')
    glob_path = f"{landing_zone}/{location}_*.{fmt}"

    reader = spark.read.format(fmt).option("header", "true")
    found = reader.load(glob_path).count()

    assert found == expected_count, f"Expected {expected_count:,} records, found {found:,} in {location}"
    print(f"Found {found:,} / Expected {expected_count:,} records: Success")


def validate_data(sets):
    """Validate the record counts of all landed sources after ``sets`` sets.

    Expected counts are hard-coded: each non-BPM source has one figure for
    ``sets == 1`` and another for every other value, while the BPM source
    expects ``sets * 253801`` records. Each source is checked with
    ``validate_count`` and the elapsed wall-clock time is printed in whole
    seconds.

    Args:
        sets: Number of test data sets that have been produced so far.
    """
    began = int(time.time())
    print(f"\nValidating test data {sets} sets...")

    only_first_set = sets == 1
    # (format, location prefix, expected record count), checked in this order.
    checks = (
        ("csv", "registered_users_bz/1-registered_users", 5 if only_first_set else 10),
        ("json", "kafka_multiplex_bz/2-user_info", 7 if only_first_set else 13),
        ("json", "kafka_multiplex_bz/3-bpm", sets * 253801),
        ("json", "kafka_multiplex_bz/4-workout", 16 if only_first_set else 32),
        ("csv", "gym_logins_bz/5-gym_logins", 8 if only_first_set else 16),
    )
    for fmt, location, expected in checks:
        validate_count(fmt, location, expected)

    elapsed = int(time.time()) - began
    print(f"Test data validation completed in {elapsed} seconds")