In [None]:
%pip install tqdm

In [None]:
from tqdm import tqdm
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.types import StructType, IntegerType,StringType,FloatType,DecimalType
from pyspark.sql.functions import from_json, col, expr, base64
import base64

In [None]:
# --- Azure storage ---------------------------------------------------------
STORAGE_ACCOUNT_NAME = "azuredatalakespuertaf"
# Root of the "raw" blob container of the storage account.
RAW_PATH = f"wasbs://raw@{STORAGE_ACCOUNT_NAME}.blob.core.windows.net"
# Per-entity path templates; fill the "{}" with an entity name via str.format.
MOUNT_POINT_RAW = "/mnt/raw/{}"
RAW_CHECKPOINT_LOCATION = "/mnt/raw/{}/check"

# --- Databricks secret that holds the storage account key ------------------
DATABRICKS_SECRET_SCOPE = "accessdatalake"
AZURE_SECRET_PATH = f"fs.azure.account.key.{STORAGE_ACCOUNT_NAME}.blob.core.windows.net"
AZURE_SECRET_NAME = "datalakeaccess"

# --- Kafka -----------------------------------------------------------------
# ENTITIES[i] is meant to pair with TOPICS[i]; keep both lists in the same order.
ENTITIES = ["customer","date","product","reseller","sales","sales_order","sales_territory"]
TOPICS = [
    "prueba.customer",
    "prueba.date",
    "prueba.product",
    "prueba.reseller",
    "prueba.sales",
    "prueba.sales-order",
    # Fixed typo: was "prueba.sales-terrritory" (three r's).
    # NOTE(review): confirm this matches the topic name the producer publishes to.
    "prueba.sales-territory",
]
# NOTE(review): hard-coded broker address; consider a widget or secret instead.
KAFKA_SERVER = "74.249.34.106:9092"

assert len(ENTITIES) == len(TOPICS), "ENTITIES and TOPICS must be index-aligned"

In [None]:
# Configure and create the raw-container mount points (Databricks - Azure).
def gen_mount_point(entiti_name:str) -> bool:
    """Mount the raw blob container at ``MOUNT_POINT_RAW.format(entiti_name)``.

    The storage account key is read from the Databricks secret
    ``DATABRICKS_SECRET_SCOPE`` / ``AZURE_SECRET_NAME`` and passed to the mount
    under the ``AZURE_SECRET_PATH`` config key.

    Args:
        entiti_name: Entity name substituted into ``MOUNT_POINT_RAW``.

    Returns:
        True if a new mount was created; False if that mount point was already
        mounted and was left as-is.
    """
    mount_point = MOUNT_POINT_RAW.format(entiti_name)
    # dbutils.fs.mount raises when the mount point is already in use, which made
    # this cell fail on every re-run; treat an existing mount as already done.
    if any(info.mountPoint == mount_point for info in dbutils.fs.mounts()):
        return False
    # NOTE(review): every entity mounts the same container root (RAW_PATH), so all
    # mount points alias one storage location -- confirm this is intended.
    dbutils.fs.mount(
        source = RAW_PATH,
        mount_point = mount_point,
        extra_configs = {AZURE_SECRET_PATH: dbutils.secrets.get(scope=DATABRICKS_SECRET_SCOPE, key=AZURE_SECRET_NAME)},
    )
    return True

# Create one raw mount point per entity. Iterating through tqdm closes the bar
# when the loop finishes; the manual bar used before was never closed.
for entiti in tqdm(ENTITIES, desc="mounting raw"):
    gen_mount_point(entiti)

In [None]:
def _schema_from(columns) -> StructType:
    """Build a StructType by adding each ``(column_name, data_type)`` pair in order."""
    schema = StructType()
    for column_name, data_type in columns:
        schema.add(column_name, data_type)
    return schema


# JSON payload schemas for the Kafka topics (one per entity), passed to get_data.
# NOTE(review): monetary columns use FloatType (32-bit float); consider
# DoubleType/DecimalType if exact amounts matter downstream.

schema_customer = _schema_from([
    ("CustomerKey", IntegerType()),
    ("CustomerID", StringType()),
    ("Customer", StringType()),
    ("City", StringType()),
    ("StateProvince", StringType()),
    ("CountryRegion", StringType()),
    ("PostalCode", StringType()),
])

schema_date = _schema_from([
    ("DateKey", IntegerType()),
    ("Date", StringType()),
    ("FiscalYear", StringType()),
    ("FiscalQuarter", StringType()),
    ("Month", StringType()),
    ("FullDate", StringType()),
    ("MonthKey", IntegerType()),
])

schema_product = _schema_from([
    ("ProductKey", IntegerType()),
    ("SKU", StringType()),
    ("Product", StringType()),
    ("StandardCost", FloatType()),
    ("Color", StringType()),
    ("ListPrice", FloatType()),
    ("Model", StringType()),
    ("Subcategory", StringType()),
    ("Category", StringType()),
])

schema_reseller = _schema_from([
    ("ResellerKey", IntegerType()),
    ("ResellerID", StringType()),
    ("BusinessType", StringType()),
    ("Reseller", StringType()),
    ("City", StringType()),
    ("StateProvince", StringType()),
    ("CountryRegion", StringType()),
    ("PostalCode", StringType()),
])

schema_sales = _schema_from([
    ("SalesOrderLineKey", IntegerType()),
    ("ResellerKey", IntegerType()),
    ("CustomerKey", IntegerType()),
    ("ProductKey", IntegerType()),
    ("OrderDateKey", IntegerType()),
    ("DueDateKey", IntegerType()),
    ("ShipDateKey", IntegerType()),
    ("SalesTerritoryKey", IntegerType()),
    ("OrderQuantity", IntegerType()),
    ("UnitPrice", FloatType()),
    ("ExtendedAmount", FloatType()),
    ("UnitPriceDiscountPct", DecimalType(5,2)),
    ("ProductStandardCost", FloatType()),
    ("TotalProductCost", FloatType()),
    ("SalesAmount", FloatType()),
])

schema_sales_order = _schema_from([
    ("Channel", StringType()),
    ("SalesOrderLineKey", IntegerType()),
    ("SalesOrder", StringType()),
    ("SalesOrderLine", IntegerType()),
])

schema_sales_territory = _schema_from([
    ("SalesTerritoryKey", IntegerType()),
    ("Region", StringType()),
    ("Country", StringType()),
    ("Group", StringType()),
])

In [None]:
def read_kafka_topic(kafka_server:str, topic_name:str) -> SparkDataFrame:
    """Open a streaming DataFrame subscribed to a single Kafka topic.

    Args:
        kafka_server: Value for ``kafka.bootstrap.servers`` (e.g. ``"host:port"``).
        topic_name: Name of the topic to subscribe to.

    Returns:
        The unparsed streaming DataFrame produced by Spark's ``kafka`` source.
    """
    kafka_options = {
        "kafka.bootstrap.servers": kafka_server,
        "subscribe": topic_name,
        # Start from the oldest available offset when the query starts fresh.
        "startingOffsets": "earliest",
        # Keep the query running rather than failing if offsets have disappeared.
        "failOnDataLoss": "false",
    }
    return spark.readStream.format("kafka").options(**kafka_options).load()

In [None]:
def get_data(stream_df:SparkDataFrame, schema:StructType) -> SparkDataFrame:
    """Decode the Kafka ``value`` column as JSON and flatten it into columns.

    Args:
        stream_df: DataFrame with a ``value`` column, e.g. from read_kafka_topic.
        schema: Schema describing the JSON payload.

    Returns:
        DataFrame with one column per field of ``schema``; every other input
        column is dropped.
    """
    payload_text = col("value").cast(StringType())
    # NOTE(review): from_json does not raise on payloads that don't match
    # ``schema``; such rows presumably surface as nulls -- confirm if that matters.
    return (
        stream_df
        .select(from_json(payload_text, schema).alias("data"))
        .select("data.*")
    )

In [None]:
def save_data(data_df:SparkDataFrame, format_:str, out_path:str,mode:str, check_path:str,
              processing_time:str = "30 seconds"):
    """Configure (but do not start) a streaming sink for ``data_df``.

    Args:
        data_df: Streaming DataFrame to write.
        format_: Sink format passed to ``DataStreamWriter.format``.
        out_path: Output path for the sink (``path`` option).
        mode: Streaming output mode passed to ``outputMode``.
        check_path: Checkpoint location (``checkpointLocation`` option); should
            not be shared between different streaming queries.
        processing_time: Micro-batch trigger interval. Defaults to the
            previously hard-coded ``"30 seconds"``.

    Returns:
        The configured ``DataStreamWriter``. Nothing is written until the
        caller invokes ``.start()`` on it.
    """
    return (
        data_df.writeStream
        .format(format_)
        .option("path", out_path)
        .outputMode(mode)
        .option("checkpointLocation", check_path)
        .trigger(processingTime=processing_time)
    )

In [None]:
def main(kafka_server:str, topic_name:str, schema:StructType,format_:str, out_path:str,mode:str, check_path:str):
    """Wire up the Kafka -> JSON -> sink pipeline for one topic.

    Reads ``topic_name`` from ``kafka_server``, parses each message with
    ``schema`` and configures a streaming writer into ``out_path`` using the
    given format, output mode and checkpoint location.

    Returns:
        The ``DataStreamWriter`` built by save_data; call ``.start()`` on it to run.
    """
    return save_data(
        data_df=get_data(
            stream_df=read_kafka_topic(kafka_server=kafka_server, topic_name=topic_name),
            schema=schema,
        ),
        format_=format_,
        out_path=out_path,
        mode=mode,
        check_path=check_path,
    )

In [None]:
# Interactive preview of the parsed customer topic. display() of a streaming
# DataFrame keeps a query running; stop it with the cell below.
customer_topic = TOPICS[ENTITIES.index("customer")]
data = read_kafka_topic(KAFKA_SERVER, customer_topic)
parsed = get_data(data, schema_customer)
display(parsed)

In [None]:
# Stop every streaming query still active on this Spark session.
running_queries = spark.streams.active
for query in running_queries:
    query.stop()