## Permissions
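The service principal used by this notebook needs at least the following role assignments:
* Monitoring Metrics Publisher on the Data Collection Rule (DCR), to write logs through it
* Monitoring Reader on the DCR, to look up its immutable ID
* Monitoring Reader on the Data Collection Endpoint (DCE), to look up its logs ingestion endpoint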

# Setup

In [0]:
%pip install -r ./requirements.txt -q
# Restart the Python interpreter after the install; the cells below import the
# packages listed in requirements.txt.
dbutils.library.restartPython()

In [0]:
import json

from azure.core.exceptions import HttpResponseError
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.monitor.ingestion import LogsIngestionClient

from pyspark.sql import functions as F
from pyspark.sql.types import TimestampType

from sentinel_libraries import *

# Variables

In [0]:
# Notebook parameters, declared as text widgets in (name, default) form.
# The first group describes what to stream and how; the second group holds
# the Azure identity and resource coordinates.
_widget_defaults = (
    ("log_analytics_table_name", ""),
    ("system_table_name", ""),
    ("checkpoint_volume_location", ""),
    ("starting_datetime", ""),
    ("workspace_ids", ""),
    ("processing_time", "1 minute"),
    ("tenant_id", ""),
    ("subscription_id", ""),
    ("sp_client_id", ""),
    ("sp_client_secret_scope", ""),
    ("sp_client_secret_key", ""),
    ("resource_group_name", ""),
)

for _widget_name, _widget_default in _widget_defaults:
    dbutils.widgets.text(_widget_name, _widget_default)

In [0]:
# Streaming source/target configuration, read from the notebook widgets.
log_analytics_table_name = dbutils.widgets.get("log_analytics_table_name")
system_table_name = dbutils.widgets.get("system_table_name")
checkpoint_volume_location = dbutils.widgets.get("checkpoint_volume_location")
starting_datetime = dbutils.widgets.get("starting_datetime")

# The workspace_ids widget carries a JSON document and defaults to an empty
# string. json.loads("") raises JSONDecodeError, so a blank value is mapped
# to an empty list; the later `if workspace_ids:` check then skips the
# workspace filter instead of this cell failing.
_workspace_ids_raw = dbutils.widgets.get("workspace_ids").strip()
workspace_ids = json.loads(_workspace_ids_raw) if _workspace_ids_raw else []

processing_time = dbutils.widgets.get("processing_time")

# Azure identity and resource coordinates.
tenant_id = dbutils.widgets.get("tenant_id")
subscription_id = dbutils.widgets.get("subscription_id")
sp_client_id = dbutils.widgets.get("sp_client_id")
sp_client_secret_scope = dbutils.widgets.get("sp_client_secret_scope")
sp_client_secret_key = dbutils.widgets.get("sp_client_secret_key")
resource_group_name = dbutils.widgets.get("resource_group_name")

# The client secret itself is never passed as a widget; only the secret
# scope/key that locate it are, and the value is fetched here.
sp_secret = dbutils.secrets.get(sp_client_secret_scope, sp_client_secret_key)

In [0]:
# Echo the streaming parameters, one per line, so they show up in the run
# output. The identity-related values are deliberately not printed here.
print(
    log_analytics_table_name,
    system_table_name,
    checkpoint_volume_location,
    starting_datetime,
    workspace_ids,
    processing_time,
    sep="\n",
)

# Initialization

In [0]:
# Service principal credential; used both for the DCR/DCE lookups and for the
# ingestion client created at the end of this cell.
credentials = ClientSecretCredential(tenant_id=tenant_id, client_id=sp_client_id, client_secret=sp_secret)

# Azure resource names are all derived from the Log Analytics table name.
raw_stream_declaration_name = "Custom-" + log_analytics_table_name + "RawData"
data_collection_rule_name = log_analytics_table_name + "-dcr"
data_collection_endpoint_name = log_analytics_table_name + "-dce"

# Resolve the DCR immutable ID, then the DCE logs ingestion endpoint.
# get_dcr / get_dce come from the sentinel_libraries star import.
dcr_id = get_dcr(
    credentials,
    subscription_id,
    resource_group_name,
    data_collection_rule_name,
).immutable_id
dce_url = get_dce(
    credentials,
    subscription_id,
    resource_group_name,
    data_collection_endpoint_name,
).logs_ingestion.endpoint

client = LogsIngestionClient(
    endpoint=dce_url,
    credential=credentials,
    logging_enable=True,
)

## Core

In [0]:
from typing import Iterator
from pyspark.sql import Row, DataFrame
from pyspark.sql.functions import to_json, struct, col
import json

def push_logs_batched(partition: Iterator[Row]):
    """Upload one partition's rows to the ingestion client in size-bounded batches.

    Every row must expose a ``row_json`` attribute holding one JSON document
    as a string. Rows are parsed and accumulated until adding the next one
    would take the batch past ``batch_size`` bytes (measured on the UTF-8
    encoding of ``row_json``); the accumulated batch is then uploaded and a
    new batch is started. A single row that is larger than ``batch_size`` on
    its own is still sent, as a batch of one.

    Uploads are best-effort: an ``HttpResponseError`` is printed and the
    function carries on with the next batch.

    Args:
        partition: Iterator over the rows of a single partition.

    Returns:
        None.
    """
    def push_logs(logs):
        # NOTE(review): a failed upload is only printed; the rows of that
        # batch are not retried here and the caller is not told about it.
        try:
            client.upload(rule_id=dcr_id, stream_name=raw_stream_declaration_name, logs=logs)
        except HttpResponseError as e:
            print(f"Upload failed: {e}")

    batch_size = 1_000_000  # 1MB limits
    current_batch_size = 0
    current_batch = []

    for r in partition:
        # Measure bytes rather than characters: the limit is a byte limit and
        # non-ASCII text takes more than one byte per character in UTF-8.
        row_size = len(r.row_json.encode("utf-8"))

        # Flush before adding a row that would not fit. Never flush an empty
        # batch (the case when the very first row is already oversized), so
        # that no upload call is made with an empty list.
        if current_batch and row_size > (batch_size - current_batch_size):
            push_logs(current_batch)
            current_batch_size = 0
            current_batch = []

        # Add the current row to the batch
        current_batch.append(json.loads(r.row_json))
        current_batch_size += row_size

    # Send whatever is left once the partition is exhausted.
    if current_batch:
        push_logs(current_batch)

def push_log_to_sentinel_batched(df: DataFrame, epoch_id: int) -> None:
    """Run ``push_logs_batched`` once for every partition of ``df``.

    The ``(df, epoch_id)`` signature is what ``foreachBatch`` is given in the
    streaming query definition.

    Args:
        df: DataFrame whose partitions are handed to ``push_logs_batched``.
        epoch_id: Accepted as the second argument but not used.
    """
    df.foreachPartition(push_logs_batched)

In [0]:
# Streaming read of the source table, with the reader options set up front.
input_df = (
    spark.readStream
    .option("skipChangeCommits", "true")
    .option("maxFilesPerTrigger", 1000)
    .option("maxBytesPerTrigger", "1g")
    .table(system_table_name)
)

# Keep only events on or after the configured starting point.
input_df = input_df.filter(F.col("event_date") >= starting_datetime)

# Optional restriction to the listed workspaces; an empty list means no filter.
if workspace_ids:
    input_df = input_df.filter(F.col("workspace_id").isin(workspace_ids))

# Add TimeGenerated as a copy of event_time, then collapse every row into a
# single JSON string column named row_json (the column push_logs_batched reads).
_row_json_df = input_df.withColumn("TimeGenerated", col("event_time")).select(
    to_json(struct("*")).alias("row_json")
)

# Configured but not yet started writer; the trigger is chosen in the next cell.
streaming_query = (
    _row_json_df.writeStream
    .queryName(f"{system_table_name} streaming process")
    .outputMode("append")
    .option("checkpointLocation", checkpoint_volume_location)
    .foreachBatch(push_log_to_sentinel_batched)
)

In [0]:
# Pick the trigger from the processing_time parameter:
#   - blank     -> availableNow trigger, and block until the query terminates
#   - otherwise -> processing-time trigger with the given interval; the cell
#                  does not wait on the started query
if not processing_time:
    print("Starting streaming query with availableNow trigger")
    q = streaming_query.trigger(availableNow=True).start()
    q.awaitTermination()
else:
    print(f"Starting streaming query with processing time: {processing_time}")
    streaming_query.trigger(processingTime=processing_time).start()