# Retroactive Threat Intelligence Matching

### Overview

This notebook correlates Threat Intelligence Indicators, from the ThreatIntelIndicators table, with log data from multiple sources over a configurable lookback period, aggregates matches by TI indicator, and saves results to a managed table for further analysis.  Future runs will reference matches already in the result set to avoid generating duplicate alerts each time this notebook is run.

### How to Run Notebook

Reference the general [Sentinel Notebook Readme](./README.md) for guidance on installing and running notebooks.  

### Key Features:
- **Multiple log source support**
- **Flexible matching modes**: current (TI valid now), loose (ignore validity)
- **Configurable log sources**: Enable/disable different log types as needed.
- **Adjustable lookback period**: Configure how long back it should look for matches.

### Currently Supported Log Sources:
- **SigninLogs**: Standard user sign-in activities.
- **Syslogs**: General table for logging system and security events.
- **CommonSecurityLogs**: Table for collecting events in the Common Event format from different security sources.

### Required Customer Input:
- **WORKSPACE_NAME**: Customer Log Analytics workspace name.  This will be used for retrieving indicator and log data, as well as for outputing match results.  If 'None' is provided then the notebook will look for the first log analytics workspace that is not the Sentinel generated 'default' workspace.
- **LOOKBACK_DAYS**: 14-365.  Lookback time period for logs matching.  Default 365.
- **MATCH_MODE**: Which ThreatIntelIndicators to match against which logs: current (TI valid now), loose (ignore validity).  Default "current".
- Enabled the log sources that you would like to match against under the `LOG SOURCE TOGGLES - SUPPORTED` section.

### Output Schema:
Results are aggregated by TI indicator with match counts and event references for detailed analysis.  The RetroThreatMatchResults_SPRK_CL output table will be generated on the provided Log Analytics workspace.

| Column Name | Type |Description |
|-------------|------|------------|
|MatchId | string |  Unique identifier for the match result record (reference to the original ThreatIntelIndicators Id) | 
|JobId | string | Identifier for the retroactive matching job execution.  This is a random uuid created by the notebook. |
| JobStartTime          | datetime         | Timestamp when the retro-matching job started. |
| JobEndTime            | datetime         | Timestamp when the retro-matching job completed. |
| MatchType             | string           | Type of match (e.g., "IoC", "Observable", "CVE", "TTP"/"MITRE-Technique"). |
| ObservableType        | string           | Subtype of the match (e.g., "IP", "Domain", "URL", "SHA256", "x509", "JA3").| 
| ObservableValue       | string           | Observable value (IoC value == Observable value).  Domain, IP, URL, etc. |
| TIReferenceId         | string           | Reference to the Threat Intelligence record (e.g., internal IoC ID or STIX ID). |
| TIValue               | string           | Actual IoC or observable value that was matched (e.g., "malicious.com", name of TTP, etc.). |
| MatchCount            | int              | Number of events or records in the environment that matched this TI object. |
| EventReferences       | dynamic          | Array of matched events with format `[{"Table":"SigninLogs","RecordId":"abc123"}, ...]`. |
| TTPs                  | dynamic          | Array of MITRE techniques (e.g., `["T1059", "T1071.001"]`) associated with the matched TI. |
| ThreatActors          | dynamic          | Array of threat actor names tied to the matched TI object. |
| EnrichmentContext     | dynamic          | Optional dictionary of enrichment tags (e.g., industry, country, malware family, confidence score). |
| TenantId              | string           | Identifier of the customer environment (multi-tenant scenarios). |
| TimeGenerated | datetime | Timestamp of record creation in this table. |


In [4]:
# ===============================================================================
# PARAMETERS AND LOG SOURCE CONFIGURATION
# ===============================================================================

# Workspace and Data Configuration
WORKSPACE_NAME = None  # log analytics workspace required to be set by customer; or leave as None to auto-detect the first non-default workspace
LOOKBACK_DAYS = 365  # Days to look back for logs (default: 12 months)

# Matching Mode Configuration - Default "current"
# - "current": TI indicator must be valid at the current time
# - "loose": Ignore TI validity windows entirely
MATCH_MODE = "current"

# ===============================================================================
# LOG SOURCE TOGGLES - SUPPORTED
# ===============================================================================
ENABLE_SIGNIN_LOGS = True
ENABLE_SYS_LOGS = True
ENABLE_COMMON_SECURITY_LOGS = True
ENABLE_NON_INTERACTIVE_SIGNIN_LOGS = True
ENABLE_SERVICE_PRINCIPAL_SIGNIN_LOGS = True
ENABLE_MANAGED_IDENTITY_SIGNIN_LOGS = True

# ===============================================================================
# LOG SOURCE TOGGLES - WORK IN PROGRESS
# ===============================================================================
ENABLE_WINDOWS_EVENT_LOGS = False
ENABLE_SECURITY_EVENT_LOGS = False
ENABLE_SECURITY_IOT_RAW_EVENT_LOGS = False
ENABLE_OFFICE_LOGS = False
ENABLE_DNS_LOGS = False
ENABLE_EVENT_LOGS = False
ENABLE_W3CIIS_LOGS = False
ENABLE_AUDIT_LOGS = False
ENABLE_USER_RISK_EVENTS = False


# ===============================================================================
# DEBUG AND TESTING CONFIGS
# ===============================================================================
SHOW_DEBUG_LOGS = False
REDUCED_DEBUG_LOGS = True  # If True, only show summary debug logs at the end of the script, and counts
SHOW_STATS = False
USE_TEST_DATA_LOGS = False
USE_TEST_DATA_THREAT_INTEL = False

# Performance switches
AUTO_TUNE_SHUFFLE_PARTITIONS = False  # If True, estimate shuffle partitions before joins
TARGET_PARTITION_SIZE_BYTES = 256 * 1024 * 1024  # 256MB

# Table Names
THREAT_INTEL_TABLE = "ThreatIntelIndicators"
THREAT_INTEL_OBJECTS_TABLE = "ThreatIntelObjects"
RESULTS_TABLE = "RetroThreatMatchResults_p3_ttps_SPRK_CL"

# Version number
VERSION = "1.0.1"

# ===============================================================================
# LOG CONFIG IMPORTS AND SOURCE MAP
# ===============================================================================
from pyspark.sql.types import StructType, StructField, StringType
import re

# Define structs for certain logs configs
SECURITY_IOT_RAW_EVENT_LOGS_STRUCT = StructType([
    StructField("LocalAddress", StringType(), True),
    StructField("RemoteAddress", StringType(), True)
])
WINDOWS_EVENT_LOGS_STRUCT = StructType([
    StructField("TargetUserName", StringType(), True),
    StructField("FileHash", StringType(), True),
    StructField("SourceAddress", StringType(), True),
    StructField("DestAddress", StringType(), True)
])

# Regex patterns
URL_REGEX_PAT = r"(https?://(?:[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+(?::[^@\s]+)?@)?(?!-)(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]*[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}(?::\d{1,5})?(?:(?:/[-a-zA-Z0-9._~:/?#\[\]@!$&*+=,;%]*[a-zA-Z0-9_/])?|(?:\?[-a-zA-Z0-9._~:/?#\[\]@!$&*+=,;%=]*[a-zA-Z0-9_=&])|(?:#[-a-zA-Z0-9._~:/?#\[\]@!$&*+=,;%=]*[a-zA-Z0-9_]))?)"
DOMAIN_REGEX_PAT = r"(?:^|\s|[^\w.])((?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z]{2,}(?![@])|(?<=@)(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z]{2,})(?=$|\s|[^\w]|[.]))"
IPV4_PAT = r"(?<![.\d])(?:(?!0+\.0+\.0+\.0+)(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3})(?![\w.])"
IPV6_PAT = r"(?<![:\d])(?:(?:[0-9A-Fa-f]{1,4}:){7}[0-9A-Fa-f]{1,4}|(?:(?:[0-9A-Fa-f]{1,4}:){1,7}:|(?:(?:[0-9A-Fa-f]{1,4}:){0,6}[0-9A-Fa-f]{1,4})?::(?:(?:[0-9A-Fa-f]{1,4}:){0,6}[0-9A-Fa-f]{1,4})?))(?:%[0-9A-Za-z]+)?(?![:.\w])"
IP_REGEX_PAT = "(?:" + IPV4_PAT + "|" + IPV6_PAT + ")"

# Additional regex patterns
EVENT_KV_REGEX = r'<Data Name=\"(\w+)\">{?([^<]*?)}?</Data>'
EVENT_KV_KEYS = ['Hashes']

# Common compiled regexs
IP_REGEX = re.compile(IP_REGEX_PAT)
URL_REGEX = re.compile(URL_REGEX_PAT)
DOMAIN_REGEX = re.compile(DOMAIN_REGEX_PAT)

# Log Sources
LOG_SOURCES = {
    "SigninLogs": {
        "table_name": "SigninLogs",
        "id_field": "Id",
        "tenant_field": "TenantId",
        "time_field": "TimeGenerated",
        "enabled": ENABLE_SIGNIN_LOGS,
        "description": "Standard user sign-in logs",
        "join_conditions": [
            {"log_field": "IPAddress", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]}
        ]
    },
    "ManagedIdentitySigninLogs": {
        "table_name": "AADManagedIdentitySignInLogs",
        "id_field": "Id",
        "tenant_field": "TenantId",
        "time_field": "TimeGenerated",
        "enabled": ENABLE_MANAGED_IDENTITY_SIGNIN_LOGS,
        "description": "Managed identity sign-in logs",
        "join_conditions": [
            {"log_field": "IPAddress", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]}
        ]
    },
    "AuditLogs": {
        "table_name": "AuditLogs",
        "id_field": "Id",
        "tenant_field": "TenantId",
        "time_field": "TimeGenerated",
        "enabled": ENABLE_AUDIT_LOGS,
        "description": "Azure AD audit logs",
        "join_conditions": [
            {"log_field": "IPAddress", "supported_indicator_types": ["ipv4-addr:value"]}
        ]
    },
    "NonInteractiveUserSignInLogs": {
        "table_name": "AADNonInteractiveUserSignInLogs",
        "id_field": "Id",
        "tenant_field": "TenantId",
        "time_field": "TimeGenerated",
        "enabled": ENABLE_NON_INTERACTIVE_SIGNIN_LOGS,
        "description": "Non-interactive user sign-ins",
        "join_conditions": [
            {"log_field": "IPAddress", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value" ]}
        ]
    },
    "ServicePrincipalSignInLogs": {
        "table_name": "AADServicePrincipalSignInLogs",
        "id_field": "Id",
        "tenant_field": "TenantId", 
        "time_field": "TimeGenerated",
        "enabled": ENABLE_SERVICE_PRINCIPAL_SIGNIN_LOGS,
        "description": "Service principal sign-in logs",
        "join_conditions": [
            {"log_field": "IPAddress", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]}
        ]
    },
    "UserRiskEvents": {
        "table_name": "AADUserRiskEvents",
        "id_field": "Id",
        "tenant_field": "TenantId",
        "time_field": "TimeGenerated", 
        "enabled": ENABLE_USER_RISK_EVENTS,
        "description": "User risk events",
        "join_conditions": [
            {"log_field": "IPAddress",  "supported_indicator_types": ["ipv4-addr:value"]}
        ]
    },
    "Syslog": {
        "table_name": "SysLog_9_12_100m_SPRK_CL",
        "id_field": None,
        "tenant_field": "TenantId",    
        "time_field": "TimeGenerated", 
        "enabled": ENABLE_SYS_LOGS,
        "description": "System logs",
        "join_conditions": [
            {"log_field": "SyslogMessage", "log_field_array_regex": DOMAIN_REGEX, "supported_indicator_types": ["domain-name:value"]},
            {"log_field": "SyslogMessage", "log_field_array_regex": URL_REGEX, "supported_indicator_types": ["url:value"]},
            {"log_field": "SyslogMessage", "log_field_array_regex": IP_REGEX, "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
            {"log_field": "HostIP", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]}
        ]
    },
    "WindowsEventLogs": {
        "table_name": "WindowsEvent", 
        "json_field": "EventData",
        "json_struct": WINDOWS_EVENT_LOGS_STRUCT,
        "id_field": "EventId",
        "tenant_field": "TenantId",     
        "time_field": "TimeGenerated", 
        "enabled": ENABLE_WINDOWS_EVENT_LOGS,
        "description": "Windows event logs",
        "join_conditions": [
            {"log_field": "EventData.TargetUserName", "supported_indicator_types": ["email-addr:value"]},
            {"log_field": "EventData.FileHash", "supported_indicator_types": ["file:hashes.MD5", "file:hashes.SHA-1", "file:hashes.SHA-256", "file:ctime"]},
            {"log_field": "EventData.SourceAddress", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
            {"log_field": "EventData.DestAddress", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
        ]
    },
    "SecurityEventLogs": {
        "table_name": "SecurityEvent", 
        "id_field": "EventOriginId",
        "tenant_field": "TenantId",      
        "time_field": "TimeGenerated",   
        "enabled": ENABLE_SECURITY_EVENT_LOGS,
        "description": "Security event logs",
        "join_conditions": [
            {"log_field": "IpAddress", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
            {"log_field": "TargetUserName", "supported_indicator_types": ["email-addr:value"]},
            {"log_field": "FileHash", "supported_indicator_types": ["file:hashes.MD5", "file:hashes.SHA-1", "file:hashes.SHA-256"]},
        ]
    },
    "SecurityIoTRawEventLogs": {
        "table_name": "SecurityIoTRawEvent", 
        "json_field": "EventDetails",
        "json_struct": SECURITY_IOT_RAW_EVENT_LOGS_STRUCT,
        "id_field": "IoTRawEventId",                
        "tenant_field": None,      
        "time_field": "TimeGenerated",   
        "enabled": ENABLE_SECURITY_IOT_RAW_EVENT_LOGS,
        "description": "Security IoT raw event logs",
        "join_conditions": [
            {"log_field": "nested_data.LocalAddress",  "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
            {"log_field": "nested_data.RemoteAddress",  "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
        ]
    },
    "OfficeLogs": {
        "table_name": "OfficeActivity",  
        "id_field": "OfficeId",
        "tenant_field": "TenantId",      
        "time_field": "TimeGenerated",   
        "enabled": ENABLE_OFFICE_LOGS,
        "description": "Office logs",
        "join_conditions": [
            {"log_field": "ClientIP",   "log_field_array_regex": "\\[?(::ffff:)?((?:\\d{1,3}\\.){3}\\d{1,3}|[a-fA-F0-9:]+)(?:%\\d+)?\\]?", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
            {"log_field": "UserId",   "supported_indicator_types": ["email-addr:value"]},
        ]
    },
    "DnsLogs": {
        "table_name": "DnsEvents", 
        "id_field": "EventId",                   
        "tenant_field": "TenantId",         
        "time_field": "TimeGenerated",      
        "enabled": ENABLE_DNS_LOGS,
        "description": "DNS logs",
        "join_conditions": [
            {"log_field": "IPAddresses", "log_separator": ",", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
            {"log_field": "Name",  "supported_indicator_types": ["domain-name:value"]},
        ]
    },
    "CommonSecurityLogs": {
        "table_name": "CommonSecurityLog",
        "id_field": None,
        "tenant_field": "TenantId",
        "time_field": "TimeGenerated",
        "enabled": ENABLE_COMMON_SECURITY_LOGS,
        "description": "Common security logs",
        "join_conditions": [
            {"log_field": "RequestURL", "log_field_value_regex": DOMAIN_REGEX_PAT, "supported_indicator_types": ["domain-name:value"]},
            {"log_field": "AdditionalExtensions", "log_field_value_regex": DOMAIN_REGEX_PAT, "supported_indicator_types": ["domain-name:value"]},
            {"log_field": "RequestURL", "supported_indicator_types": ["url:value"]},
            {"log_field": "AdditionalExtensions", "log_field_value_regex": URL_REGEX_PAT, "supported_indicator_types": ["url:value"]},
            {"log_field": "FileHash", "supported_indicator_types": ["file:hashes.MD5", "file:hashes.SHA-1", "file:hashes.SHA-256"]},
            {"log_field": "SourceIP","supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
            {"log_field": "DestinationIP", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]}
        ]
    },
    "EventLogs": {
        "table_name": "Event",
        "id_field": "EventId",
        "tenant_field": "TenantId",
        "time_field": "TimeGenerated",
        "nested_regex_field": "EventData",
        "nested_regex_pattern": EVENT_KV_REGEX,
        "nested_regex_keys": EVENT_KV_KEYS,
        "enabled": ENABLE_EVENT_LOGS,
        "description": "Event logs",
        "join_conditions": [
            {"log_field": "Compute", "supported_indicator_types": ["domain-name:value"]},
            {"log_field": "nested_data.Hashes",  "log_field_array_regex": "(?<=:)\\s*([^,]*)", "supported_indicator_types": ["file:hashes.MD5", "file:hashes.SHA-1", "file:hashes.SHA-256"]},
        ]
    },
    "W3CIISLogs": {
        "table_name": "W3CIISLog",
        "id_field": "todo",
        "tenant_field": "TenantId",
        "time_field": "TimeGenerated",
        "enabled": ENABLE_W3CIIS_LOGS,
        "description":  "W3C IIS logs",
        "join_conditions": [
            {"log_field": "cIP", "supported_indicator_types": ["ipv4-addr:value", "ipv6-addr:value", "network-traffic:src_ref.value", "network-traffic:dst_ref.value"]},
        ]
    }
}

# ===============================================================================
# PARAMETER VALIDATION
# ===============================================================================
if LOOKBACK_DAYS <= 0:
    raise ValueError("LOOKBACK_DAYS must be positive")

if MATCH_MODE not in ["strict", "current", "loose"]:
    raise ValueError("MATCH_MODE must be one of: strict, current, loose")

if not RESULTS_TABLE or not isinstance(RESULTS_TABLE, str) or RESULTS_TABLE.strip() == "":
    raise ValueError("RESULTS_TABLE must be a non-empty string before saving results.")
if not RESULTS_TABLE.endswith('_SPRK_CL'):
    RESULTS_TABLE = f"{RESULTS_TABLE}_SPRK_CL"

enabled_sources = [source for source, config in LOG_SOURCES.items() if config["enabled"]]
if not enabled_sources and not USE_TEST_DATA_LOGS:
    raise ValueError("At least one log source must be enabled OR USE_TEST_DATA_LOGS must be True")

for source_name, config in LOG_SOURCES.items():
    for idx, join_condition in enumerate(config.get("join_conditions", [])):
        has_value_regex = "log_field_value_regex" in join_condition
        has_array_regex = "log_field_array_regex" in join_condition
        if has_value_regex and has_array_regex:
            raise ValueError(
                f"Configuration error in {source_name} join_condition[{idx}]: Cannot specify both 'log_field_value_regex' and 'log_field_array_regex'."
            )

print("Notebook version:", VERSION)
print(f"Configuration loaded: {WORKSPACE_NAME}, {LOOKBACK_DAYS} days lookback, '{MATCH_MODE}' matching mode")
if enabled_sources:
    print(f"Enabled log sources: {', '.join(enabled_sources)}")
else:
    print("No real log sources enabled - using test data only for fast execution")

# Collect supported indicator types from ENABLED sources only
supported_observable_keys = sorted({
    it
    for _, cfg in LOG_SOURCES.items()
    if cfg.get("enabled")
    for jc in cfg.get("join_conditions", [])
    for it in jc.get("supported_indicator_types", [])
})
if SHOW_DEBUG_LOGS:
    print(f"Collected {len(supported_observable_keys)} unique indicator types from ENABLED log sources:")
    for key in supported_observable_keys:
        print(f"  • {key}")

StatementMeta(MSGLarge, 1563, 5, Finished, Available, Finished)

Notebook version: 1.0.1
Configuration loaded: None, 365 days lookback, 'current' matching mode
Enabled log sources: SigninLogs, ManagedIdentitySigninLogs, NonInteractiveUserSignInLogs, ServicePrincipalSignInLogs, Syslog, CommonSecurityLogs


## Imports, Sentinel Provider, and Spark Configs

In [6]:
# ===============================================================================
# IMPORTS AND SETUP
# ===============================================================================
import json, uuid
from datetime import datetime, timedelta
import time

from pyspark.sql.functions import (
    broadcast, expr, lit, current_timestamp, col, array, struct, when,
    count as spark_count, row_number, first, collect_list, flatten, size,
    get_json_object, from_json, to_json, explode, regexp_extract, split, trim, lower, map_from_arrays,
    concat_ws, array_distinct, array_union, coalesce, udf, sum, count_distinct, collect_set
)

from pyspark.sql.types import (
    StringType, ArrayType, StructType, StructField, TimestampType
)
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from sentinel_lake.providers import MicrosoftSentinelProvider

# Start time
start = time.time()

# Spark conf levers (AQE, skew, local shuffle, parquet, thresholds)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", str(2 * TARGET_PARTITION_SIZE_BYTES))
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(150 * 1024 * 1024))
spark.conf.set("spark.sql.parquet.compression.codec", "snappy")
spark.conf.set("spark.sql.parquet.filterPushdown", "true")
spark.conf.set("spark.sql.parquet.mergeSchema", "false")
spark.conf.set("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", str(TARGET_PARTITION_SIZE_BYTES))
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", str(TARGET_PARTITION_SIZE_BYTES))
spark.conf.set("spark.sql.files.maxPartitionBytes", str(TARGET_PARTITION_SIZE_BYTES))

# Provider init + load base TI tables
data_provider = MicrosoftSentinelProvider(spark)
print("✓ Microsoft Sentinel data provider initialized successfully")

# Ensure enrichment mapping variable is defined (may be populated later)
indicator_actor_by_indicator = None

# Job start time for logging and tracking
job_start_time = current_timestamp()

# Helper: quick sample count
def sample_count(df, sample_rate=0.01):
    try:
        sc = df.sample(fraction=sample_rate, seed=42).count()
        return int(sc / sample_rate)
    except Exception:
        return 0

# Optional: auto-tune shuffle partitions using rough size estimate
def maybe_auto_tune_shuffle(df_list):
    if not AUTO_TUNE_SHUFFLE_PARTITIONS:
        return
    try:
        MIN_LIMIT = 200
        MAX_LIMIT = 10000
        def est_parts(df):
            samp = df.limit(MAX_LIMIT).rdd.map(lambda r: len(str(r))).mean()
            cnt = df.count()
            est_bytes = samp * cnt
            return max(1, int(est_bytes // TARGET_PARTITION_SIZE_BYTES))
        parts = max(est_parts(d) for d in df_list if d is not None)
        if SHOW_DEBUG_LOGS: print(f"Max estimated partitions={parts}, setting limits [{MIN_LIMIT}, {MAX_LIMIT}]")
        parts = max(MIN_LIMIT, min(MAX_LIMIT, parts))
        spark.conf.set("spark.sql.shuffle.partitions", str(parts))
        if SHOW_DEBUG_LOGS: print(f"Auto-tuned spark.sql.shuffle.partitions={parts}")
    except Exception as e:
        if SHOW_DEBUG_LOGS: print(f"Auto-tune skipped: {e}")


if WORKSPACE_NAME is None or WORKSPACE_NAME.strip() == "":
    print("No workspace name provided, automatically selecting the first available non-default Log Analytics workspace.")

    databases = data_provider.list_databases()
    for db in databases:
        if db not in ["default"]:
            WORKSPACE_NAME = db
            print(f"Auto-selected workspace: {WORKSPACE_NAME}")
            break

StatementMeta(MSGLarge, 1563, 7, Finished, Available, Finished)

✓ Microsoft Sentinel data provider initialized successfully
No workspace name provided, automatically selecting the first available non-default Log Analytics workspace.
Auto-selected workspace: Woodgrove-LogAnalyiticsWorkspace


## Threat actor enrichment via ThreatIntelObjects (with fallback)

In [7]:
# ===============================================================================
# LOAD THREAT INTEL OBJECTS AND BUILD INDICATOR -> THREAT ACTORS
# ===============================================================================
try:
    ti_objects_df = data_provider.read_table(THREAT_INTEL_OBJECTS_TABLE, WORKSPACE_NAME)
    ti_objects_df = ti_objects_df.withColumnRenamed(
        "TenantId", "TIO_TenantId"
    ).withColumnRenamed("TimeGenerated", "TIO_TimeGenerated")

    parsed = ti_objects_df.select(
        col("StixType").alias("ObjectType"),
        get_json_object(col("Data"), "$.id").alias("ObjectId"),
        get_json_object(col("Data"), "$.name").alias("name"),
        get_json_object(col("Data"), "$.source_ref").alias("source_ref"),
        get_json_object(col("Data"), "$.target_ref").alias("target_ref"),
        get_json_object(col("Data"), "$.relationship_type").alias("relationship_type"),
        from_json(
            get_json_object(col("Data"), "$.aliases"), ArrayType(StringType())
        ).alias("aliases"),
        from_json(
            get_json_object(col("Data"), "$.threat_actor_types"),
            ArrayType(StringType()),
        ).alias("threat_actor_types"),
        col("Data").alias("ObjectData"),  # keep raw JSON payload here
    )

    # relationships
    relationships = parsed.filter(col("ObjectType") == lit("relationship")).select(
        "source_ref", "target_ref"
    )

    # attack-patterns and actors (include ObjectData in actors so we can reference it later)
    attack_patterns = parsed.filter(col("ObjectType") == lit("attack-pattern")).select(
        col("ObjectId").alias("AttackPatternRef"),
        col("ObjectData"),
    )

    actors = parsed.filter(col("ObjectType") == lit("threat-actor")).select(
        col("ObjectId").alias("ThreatActorId"),
        col("name").alias("ThreatActorName"),
        col("aliases").alias("ThreatActorAliases"),
        col("threat_actor_types").alias("ThreatActorTypes"),
        col("ObjectData"),
    )

    # attack-pattern <-> threat-actor relationships (both directions)
    ap2ta = relationships.filter(
        col("source_ref").startswith("attack-pattern--")
        & col("target_ref").startswith("threat-actor--")
    ).select(
        col("source_ref").alias("AttackPatternRef"),
        col("target_ref").alias("ThreatActorRef"),
    )

    ap2ta_rev = relationships.filter(
        col("target_ref").startswith("attack-pattern--")
        & col("source_ref").startswith("threat-actor--")
    ).select(
        col("target_ref").alias("AttackPatternRef"),
        col("source_ref").alias("ThreatActorRef"),
    )

    ap2ta_all = ap2ta.union(ap2ta_rev).dropDuplicates()

    # parse external_references from attack-pattern ObjectData
    ext_ref_item = StructType(
        [
            StructField("external_id", StringType(), True),
            StructField("source_name", StringType(), True),
            StructField("url", StringType(), True),
        ]
    )
    ext_ref_wrapper = StructType(
        [StructField("external_references", ArrayType(ext_ref_item), True)]
    )

    attack_patterns_parsed = attack_patterns.select(
        "AttackPatternRef",
        from_json(col("ObjectData"), ext_ref_wrapper)
        .getField("external_references")
        .alias("external_references"),
        get_json_object(col("ObjectData"), "$.name").alias("name"),
    )

    # empty array literal
    empty_str_arr = from_json(lit("[]"), ArrayType(StringType()))

    # extract external_id array without explode (optionally filter by source_name == 'mitre-attack' if desired)
    ttp_expr = "array_distinct(filter(transform(external_references, x -> x.external_id), x -> x is not null))"

    # SQL-only extraction from name: match a MITRE token at the start of the (trimmed) name.
    # Pattern explanation:
    #  - ^(?i)        : start of string, case-insensitive
    #  - (T\d+(?:\.\d+)*) : capture T + digits, optional .digits groups (e.g., T1583 or T1583.003)
    #  - (?:\s|:|$)   : ensure token is followed by space, colon, or end-of-string
    #
    # We use regexp_extract(trim(name), <pattern>, 1) to capture the token, then wrap into an array,
    # filter out empty, uppercase, and dedupe with array_distinct.
    ttp_from_name_expr = (
        "array_distinct("
        "  transform("
        "    filter("
        "      array(regexp_extract(trim(name), '^(?i)(T\\\\d+(?:\\\\.\\\\d+)*)(?:\\\\s|:|$)', 1)),"
        "      x -> x IS NOT NULL AND x <> ''"
        "    ),"
        "    x -> upper(x)"
        "  )"
        ")"
    )

    # Final TTPs: prefer external_references when present/non-empty; otherwise extract from name start
    ap_ttps = attack_patterns_parsed.select(
        "AttackPatternRef",
        when(
            (col("external_references").isNotNull()) & (size(col("external_references")) > 0),
            expr(ttp_expr),
        )
        .otherwise(
            when(col("name").isNotNull(), expr(ttp_from_name_expr)).otherwise(empty_str_arr)
        )
        .alias("TTPs"),
    )

    # ensure referenced attack-patterns have entries
    ap_refs = ap2ta_all.select("AttackPatternRef").distinct()
    ap_ttps_all = ap_refs.join(ap_ttps, on="AttackPatternRef", how="left").select(
        col("AttackPatternRef"), coalesce(col("TTPs"), empty_str_arr).alias("TTPs")
    )

    # roll up TTPs to threat actors
    actor_ttp = ap2ta_all.join(ap_ttps_all, on="AttackPatternRef", how="left").select(
        col("ThreatActorRef"), col("TTPs")
    )

    actor_ttp_by_actor = (
        actor_ttp.groupBy("ThreatActorRef")
        .agg(array_distinct(flatten(collect_list(col("TTPs")))).alias("TTPs"))
        .select(
            col("ThreatActorRef"), coalesce(col("TTPs"), empty_str_arr).alias("TTPs")
        )
    )

    print(
        f"Extracted {actor_ttp_by_actor.count()} threat actors with TTPs from ThreatIntelObjects"
    )
    if SHOW_DEBUG_LOGS:
        actor_ttp_by_actor.show(10, truncate=False)

    # attach actor TTPs into actors_enriched (use qualified column expressions)
    a = actors.alias("a")
    t = actor_ttp_by_actor.alias("t")

    actors_enriched = a.join(
        t, col("a.ThreatActorId") == col("t.ThreatActorRef"), how="left"
    ).select(
        col("a.ThreatActorId").alias("ThreatActorRef"),
        col("a.ThreatActorName"),
        col("a.ThreatActorAliases"),
        col("a.ThreatActorTypes"),
        coalesce(col("t.TTPs"), empty_str_arr).alias("TTPs"),
        col("a.ObjectData").alias("ObjectData"),
    )

    print(f"Enriched threat actors:")
    if SHOW_DEBUG_LOGS:
        actors_enriched.show(10, truncate=True)

    # --- indicator <-> threat-actor pairs (both directions) ---
    i2a = relationships.filter(
        col("source_ref").startswith("indicator--")
        & col("target_ref").startswith("threat-actor--")
    ).select(
        col("source_ref").alias("IndicatorRef"),
        col("target_ref").alias("ThreatActorRef"),
    )

    i2a_rev = relationships.filter(
        col("target_ref").startswith("indicator--")
        & col("source_ref").startswith("threat-actor--")
    ).select(
        col("target_ref").alias("IndicatorRef"),
        col("source_ref").alias("ThreatActorRef"),
    )

    i2a_all = i2a.union(i2a_rev).dropDuplicates()

    # join indicators -> actors_enriched and collect actor names and TTPs per indicator
    enriched = i2a_all.join(
        broadcast(actors_enriched),
        i2a_all.ThreatActorRef == actors_enriched.ThreatActorRef,
        "left",
    ).select(
        i2a_all.IndicatorRef.alias("IndicatorId"),
        actors_enriched.ThreatActorName,
        actors_enriched.ThreatActorAliases,
        actors_enriched.ThreatActorTypes,
        actors_enriched.TTPs,
    )

    # build ThreatActors array (name + aliases)
    enriched = (
        enriched.withColumn(
            "_name_arr",
            when(
                col("ThreatActorName").isNotNull(), array(col("ThreatActorName"))
            ).otherwise(empty_str_arr),
        )
        .withColumn(
            "ThreatActors",
            array_distinct(
                array_union(
                    col("_name_arr"), coalesce(col("ThreatActorAliases"), empty_str_arr)
                )
            ),
        )
        .drop("_name_arr")
    )

    # aggregate to indicator
    indicator_actor_by_indicator = enriched.groupBy("IndicatorId").agg(
        array_distinct(
            flatten(collect_list(coalesce(col("ThreatActors"), empty_str_arr)))
        ).alias("ThreatActors"),
        array_distinct(
            flatten(collect_list(coalesce(col("TTPs"), empty_str_arr)))
        ).alias("TTPs"),
    )

    # normalize defaults
    indicator_actor_by_indicator = indicator_actor_by_indicator.withColumn(
        "ThreatActors",
        when(#
            (col("ThreatActors").isNull()) | (size(col("ThreatActors")) == 0),
            array(lit("Unknown Actor")),
        ).otherwise(col("ThreatActors")),
    ).withColumn("TTPs", coalesce(col("TTPs"), empty_str_arr))

    indicator_actor_by_indicator.cache()
    # 1) explode ThreatActors to one actor per row
    exploded_by_actor = indicator_actor_by_indicator.select(
        F.col("IndicatorId"),
        F.explode(F.col("ThreatActors")).alias("ThreatActor"),
        F.col("TTPs")
    )

    # 2) group by actor and aggregate
    actors_rolled_up = exploded_by_actor.groupBy("ThreatActor").agg(
        F.collect_set("IndicatorId").alias("IndicatorIds"),
        # collect_list -> flatten (array of arrays) -> distinct to dedupe TTPs
        F.array_distinct(F.flatten(F.collect_list(F.coalesce(F.col("TTPs"), empty_str_arr)))).alias("TTPs")
    )

    # Optional: sort and show results
    actors_rolled_up = actors_rolled_up.orderBy("ThreatActor")
    print(f"Rolled up to {actors_rolled_up.count()} unique threat actors with indicators and TTPs:")
    if SHOW_DEBUG_LOGS:
        actors_rolled_up.show(truncate=True)
    print("✓ ThreatIntelObjects enrichment ready")

except Exception as e:
    indicator_actor_by_indicator = None
    print(f"⚠ ThreatIntelObjects not available or failed to load: {e}")

StatementMeta(MSGLarge, 1563, 8, Finished, Available, Finished)

{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Loading table: ThreatIntelObjects"}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Successfully loaded table ThreatIntelObjects"}
Extracted 15 threat actors with TTPs from ThreatIntelObjects
Enriched threat actors:
Rolled up to 631 unique threat actors with indicators and TTPs:
✓ ThreatIntelObjects enrichment ready


## Load and filter log data

In [8]:
# ===============================================================================
# Helper function for creating composite Id fields
# ===============================================================================
def get_record_id(id_field):
    """
    Create an ID for log entries.
    Returns the actual ID or None if no ID field is available.
    """
    if id_field is None:
        # Return null for RecordId when there's no ID field
        return lit(None).alias("Id")
    else:
        # Single ID field
        return col(id_field).cast("string").alias("Id")


# ===============================================================================
# Helper function for extracting values
# ===============================================================================
def extract_all(pattern, text):
    if text is None:
        return []
    return re.findall(pattern, text)


# ===============================================================================
# Collect table names in the selected workspace
# ===============================================================================
tables = data_provider.list_tables(WORKSPACE_NAME)
table_names = [t.name for t in tables]

# ===============================================================================
# LOAD AND FILTER LOG DATA (multiple sources, base behavior)
# ===============================================================================
def load_log_source(source_name, config, start_date, end_date):
    try:
        # Store the actual table name for use in EventReferences
        actual_table_name = config["table_name"]
        if actual_table_name not in table_names:
            # If a configured table is not found, silently skip it with a warning
            print(f"Warning: {actual_table_name} not found in workspace {WORKSPACE_NAME}, skipping...")
            return []

        # Load the log data. We expect the table to exist at this point
        df = data_provider.read_table(config["table_name"], WORKSPACE_NAME)
        df = df.filter(
            (col(config["time_field"]) <= end_date)
            & (col(config["time_field"]) >= start_date)
        )
        # Parse nested formats
        if "json_struct" in config and "json_field" in config:
            df = df.withColumn(
                "nested_data",
                from_json(col(config["json_field"]), config["json_struct"]),
            )
        elif "nested_regex_field" in config and "nested_regex_pattern" in config:
            df = (
                df.withColumn(
                    "nested_data_raw",
                    map_from_arrays(
                        expr(
                            f"regexp_extract_all(`{config['nested_regex_field']}`, '{config['nested_regex_pattern']}', 1)"
                        ),
                        expr(
                            f"regexp_extract_all(`{config['nested_regex_field']}`, '{config['nested_regex_pattern']}', 2)"
                        ),
                    ),
                )
                .withColumn(
                    "nested_data",
                    struct(
                        *[
                            col("nested_data_raw").getItem(k).alias(k)
                            for k in config["nested_regex_keys"]
                        ]
                    ),
                )
                .drop("nested_data_raw")
            )
        dfs = []
        for jc in config.get("join_conditions", []):
            print(f"join condition: {jc}")
            log_field = jc.get("log_field")
            if "." in log_field and not log_field.startswith("nested_data."):
                raise ValueError(
                    f"Configuration error in {source_name}: Nested field '{log_field}' is not supported unless under 'nested_data.*'"
                )
            field_col = (
                col(log_field)
                if "." not in log_field or log_field.startswith("nested_data.")
                else col(log_field)
            )
            if "log_field_value_regex" in jc:
                extracted = regexp_extract(field_col, jc["log_field_value_regex"], 1)
                sub = df.select(
                    lit(actual_table_name).alias("LogSource"),
                    trim(lower(extracted)).alias("ObservableValue"),
                    lit(config["id_field"]).alias("IdField"),
                    get_record_id(config["id_field"]),
                    (
                        col(config.get("tenant_field"))
                        if config.get("tenant_field")
                        else lit(None)
                    ).alias("TenantId"),
                    col(config["time_field"]).alias("TimeGenerated"),
                    lit(log_field).alias("LogField"),
                    extracted.alias("OriginalValue")
                )
            elif "log_field_array_regex" in jc:
                extract_all_udf = udf(lambda text: extract_all(jc['log_field_array_regex'], text), ArrayType(StringType()))

                # Extract all matches directly as an array
                temp = df.withColumn("_extracted_array", 
                    extract_all_udf(field_col)
                )

                # Explode the array to get individual values
                temp = temp.withColumn("_extracted_value", 
                    explode(col("_extracted_array"))
                )

                sub = temp.select(
                    lit(actual_table_name).alias("LogSource"),
                    trim(lower(col("_extracted_value"))).alias("ObservableValue"),
                    lit(config["id_field"]).alias("IdField"),
                    get_record_id(config["id_field"]),
                    (
                        col(config.get("tenant_field")).alias("TenantId")
                        if config.get("tenant_field")
                        else lit(None).alias("TenantId")
                    ),
                    col(config["time_field"]).alias("TimeGenerated"),
                    lit(log_field).alias("LogField"),
                    col("_extracted_value").alias("OriginalValue")
                ).filter(
                    col("ObservableValue").isNotNull() & (col("ObservableValue") != "")
                )
            else:
                if "log_separator" in jc:
                    temp = df.withColumn(
                        "_temp_array", split(field_col, jc["log_separator"])
                    ).withColumn("_temp_exploded", explode(col("_temp_array")))
                    sub = temp.select(
                        lit(actual_table_name).alias("LogSource"),
                        trim(lower(col("_temp_exploded"))).alias("ObservableValue"),
                        lit(config["id_field"]).alias("IdField"),
                        get_record_id(config["id_field"]),
                        (
                            col(config.get("tenant_field"))
                            if config.get("tenant_field")
                            else lit(None)
                        ).alias("TenantId"),
                        col(config["time_field"]).alias("TimeGenerated"),
                        lit(log_field).alias("LogField"),
                        col("_temp_exploded").alias("OriginalValue")
                    )
                else:
                    sub = df.select(
                        lit(actual_table_name).alias("LogSource"),
                        trim(lower(field_col)).alias("ObservableValue"),
                        lit(config["id_field"]).alias("IdField"),
                        get_record_id(config["id_field"]),
                        (
                            col(config.get("tenant_field"))
                            if config.get("tenant_field")
                            else lit(None)
                        ).alias("TenantId"),
                        col(config["time_field"]).alias("TimeGenerated"),
                        lit(log_field).alias("LogField"),
                        field_col.alias("OriginalValue")
                    )
            sub = sub.filter(
                col("ObservableValue").isNotNull() & (col("ObservableValue") != "")
            )
            if "log_filter_field" in jc and "log_filter_value" in jc:
                sub = sub.filter(col(jc["log_filter_field"]) == jc["log_filter_value"])
            dfs.append(sub)
        if not dfs:
            return None
        out = dfs[0]
        for other in dfs[1:]:
            out = out.union(other)
        out = out.dropDuplicates(["Id", "ObservableValue"]).repartition(
            "ObservableValue"
        )
        if SHOW_DEBUG_LOGS:
            print(f"✓ Filtered {source_name}")
        return out
    except Exception as e:
        print(f"✗ Error loading {source_name}: {e}")
        raise


# Calculate date range for filtering (though not currently used in load_log_source)
end_date = datetime.now() + timedelta(days=3)
start_date = end_date - timedelta(days=LOOKBACK_DAYS)

# Load enabled log sources
combined_logs_df = None
for source_name, config in LOG_SOURCES.items():
    if config.get("enabled", False):
        try:
            source_df = load_log_source(source_name, config, start_date, end_date)
            if source_df is not None:
                if combined_logs_df is None:
                    combined_logs_df = source_df
                else:
                    combined_logs_df = combined_logs_df.union(source_df)
        except Exception as e:
            print(f"⚠ Skipping {source_name} due to error: {e}")

if combined_logs_df is not None:
    combined_logs_df = combined_logs_df.cache()
    if SHOW_DEBUG_LOGS or REDUCED_DEBUG_LOGS:
        combined_logs_count = combined_logs_df.count()
        print(f"✓ Combined log sources loaded and cached. Total records: {combined_logs_count}")
elif not USE_TEST_DATA_LOGS:
    raise RuntimeError("No log data loaded and USE_TEST_DATA_LOGS is False")

StatementMeta(MSGLarge, 1563, 9, Finished, Available, Finished)

{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Loading table: SigninLogs"}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Successfully loaded table SigninLogs"}
join condition: {'log_field': 'IPAddress', 'supported_indicator_types': ['ipv4-addr:value', 'ipv6-addr:value', 'network-traffic:src_ref.value', 'network-traffic:dst_ref.value']}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Loading table: AADManagedIdentitySignInLogs"}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Successfully loaded table AADManagedIdentitySignInLogs"}
join condition: {'log_field': 'IPAddress', 'supported_indicator_types': ['ipv4-addr:value', 'ipv6-addr:value', 'network-traffic:src_ref.value', 'network-traffic:dst_ref.value']}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Loading table: AADNonInteractiveUserSignInLogs"}
{"level": "INFO", "run_id

## Load and split Threat Intelligence Indicators

In [9]:
# ===============================================================================
# LOAD AND SPLIT THREAT INTELLIGENCE INDICATORS BY OBSERVABLE KEY
# ===============================================================================
raw_threat_intel_df = data_provider.read_table(THREAT_INTEL_TABLE, WORKSPACE_NAME)
print(f"✓ Loaded threat intelligence table: {THREAT_INTEL_TABLE}")

# Current-time filter (for MATCH_MODE == 'current'); strict handled later at join using log times
current_time = current_timestamp()
threat_intel_df = raw_threat_intel_df
if MATCH_MODE == "current":
    threat_intel_df = threat_intel_df.filter((col("ValidFrom") <= current_time) & (col("ValidUntil") >= current_time))
if MATCH_MODE in ["current", "strict"]:
    threat_intel_df = threat_intel_df.filter(col("IsActive") == True)

indicator_dfs = {}
for key in supported_observable_keys:
    df = threat_intel_df.filter(col('ObservableKey') == key) 
    df = df.withColumnRenamed("TenantId", "TI_TenantId").withColumnRenamed("TimeGenerated", "TI_TimeGenerated")
    df = df.withColumn("ObservableValue", trim(lower(col("ObservableValue"))))
    # Keep needed columns for later
    df = df.select( 
        "Id", "TI_TenantId", "TI_TimeGenerated", "ObservableKey", "ObservableValue", "ValidFrom", "ValidUntil", "Data", "Pattern"
    )
    # Enrich ThreatActors via ThreatIntelObjects if available; else fallback to Data JSON field
    if indicator_actor_by_indicator is not None:
        df = df.withColumn("IndicatorId", get_json_object(col("Data"), "$.id"))
        df = df.join(indicator_actor_by_indicator, df.IndicatorId == indicator_actor_by_indicator.IndicatorId, "left").drop(indicator_actor_by_indicator.IndicatorId)
    else:
        df = df.withColumn("ThreatActors", when(col("Data").isNotNull(), from_json(get_json_object(col("Data"), "$.threat_actors"), ArrayType(StringType()))).otherwise(lit(None)))
    # Deduplicate: latest per Id
    w = Window.partitionBy("Id").orderBy(col("TI_TimeGenerated").desc())
    df = df.withColumn("row_num", row_number().over(w)).filter(col("row_num") == 1).drop("row_num")
    df = df.repartition("ObservableValue").cache()
    indicator_dfs[key] = df
print(f"✓ Prepared indicator splits for {len(indicator_dfs)} observable keys")

StatementMeta(MSGLarge, 1563, 10, Finished, Available, Finished)

{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Loading table: ThreatIntelIndicators"}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Successfully loaded table ThreatIntelIndicators"}
✓ Loaded threat intelligence table: ThreatIntelIndicators
✓ Prepared indicator splits for 9 observable keys


## Inject test data (optional)

In [10]:
# Test constants
TEST_IP = "142.202.188.59"
TEST_DOMAIN = "malicious-test.example.com"
TEST_URL = "http://malicious-test.example.com/payload"
TEST_FILE_HASH = "abcd1234567890abcd1234567890abcd1234567890abcd1234567890abcd1234"
TEST_TENANT_ID = "029c55c8-a7ec-418e-b741-de9d24add5fa"
TEST_TIMESTAMP = datetime.strptime("2025-07-15T16:29:31.883Z", "%Y-%m-%dT%H:%M:%S.%fZ")

def create_test_log_row(source_name="SigninLogs", observable_key="ipv4-addr:value", observable_value=TEST_IP):
    return {
        'LogSource': source_name, 'ObservableKey': observable_key, 'ObservableValue': observable_value,
        'Id': str(uuid.uuid4()), 'TenantId': TEST_TENANT_ID, 'TimeGenerated': TEST_TIMESTAMP,
    }

def create_test_ti_row(observable_key="ipv4-addr:value", observable_value=TEST_IP):
    any_df = next((df for df in indicator_dfs.values() if df is not None), None)
    if any_df is None: return None
    test_row = Row(**{f.name: None for f in any_df.schema.fields}).asDict()
    indicator_id = f"indicator--test-{uuid.uuid4()}"
    test_data = {
        'pattern': f"[{observable_key} = '{observable_value}']",
        'pattern_type': 'stix',
        'valid_from': '2025-07-15T16:25:04.2001568Z',
        'name': f"Test IOC - {observable_value}",
        'description': 'Test indicator',
        'indicator_types': ['WatchList'],
        'valid_until': '2025-07-31T21:03:26.8933330Z',
        'confidence': 75,
        'type': 'indicator',
        'id': indicator_id,
    }
    test_row.update({
        'TI_TenantId': TEST_TENANT_ID, 'TI_TimeGenerated': TEST_TIMESTAMP, 'WorkspaceId': TEST_TENANT_ID,
        'AzureTenantId': '536279f6-15cc-45f2-be2d-61e352b51eef', 'Id': f"TEST---{indicator_id}",
        'SourceSystem': 'Test Data Generator', 'LastUpdateMethod': 'TestDataInjection', 'IsDeleted': False,
        'AdditionalFields': json.dumps({'TLPLevel': 'Green'}), 'Data': json.dumps(test_data), 'IsActive': True,
        'ValidUntil': TEST_TIMESTAMP, 'ValidFrom': TEST_TIMESTAMP, 'Created': TEST_TIMESTAMP, 'Modified': TEST_TIMESTAMP,
        'Confidence': test_data['confidence'], 'Pattern': test_data['pattern'],
        'ObservableKey': observable_key, 'ObservableValue': observable_value,
    })
    return test_row

if USE_TEST_DATA_LOGS and combined_logs_df is not None:
    enabled_log_sources = [s for s, c in LOG_SOURCES.items() if c['enabled']]
    test_observables = [("ipv4-addr:value", TEST_IP), ("domain-name:value", TEST_DOMAIN), ("url:value", TEST_URL), ("file:hashes.'SHA-256'", TEST_FILE_HASH)]
    rows = []
    if enabled_log_sources:
        for s, cfg in LOG_SOURCES.items():
            if cfg['enabled']:
                for k, v in test_observables: rows.append(create_test_log_row(s, k, v))
    else:
        for k, v in test_observables: rows.append(create_test_log_row("TestLogSource", k, v))
    if rows:
        test_df = spark.createDataFrame(rows, schema=combined_logs_df.schema)
        combined_logs_df = combined_logs_df.union(test_df).cache()
        print(f"✓ Injected {len(rows)} test log entries")

if USE_TEST_DATA_THREAT_INTEL:
    test_obs = [("ipv4-addr:value", TEST_IP), ("domain-name:value", TEST_DOMAIN), ("url:value", TEST_URL), ("file:hashes.'SHA-256'", TEST_FILE_HASH)]
    for k, v in test_obs:
        if k in indicator_dfs and indicator_dfs[k] is not None:
            ti_row = create_test_ti_row(k, v)
            if ti_row:
                indicator_dfs[k] = indicator_dfs[k].union(spark.createDataFrame([ti_row], schema=indicator_dfs[k].schema))
    print("✓ Injected test TI indicators where possible")

StatementMeta(MSGLarge, 1563, 11, Finished, Available, Finished)

## Filter to intersection and join (strict/current/loose)

In [11]:
# Combine indicators
deduped_indicators_df = None
for key, df in indicator_dfs.items():
    if df is not None:
        sel = df.select('ObservableValue','ObservableKey','ThreatActors','Id','Pattern','TI_TenantId','TI_TimeGenerated','ValidFrom','ValidUntil','Data', 'TTPs')
        deduped_indicators_df = sel if deduped_indicators_df is None else deduped_indicators_df.union(sel)
        print(f"✓ Added {key} indicators: {sel.count():,} records")
        if SHOW_DEBUG_LOGS:
            sel.show(5, truncate=True)

if deduped_indicators_df is None:
    raise RuntimeError('No threat intelligence indicators available after filtering')
else:
    if SHOW_DEBUG_LOGS or REDUCED_DEBUG_LOGS:
        indicators_count = deduped_indicators_df.count()
        print(f"✓ Combined indicators: {indicators_count}")

# Check for MatchIds with multiple ObservableValues
ti_check = deduped_indicators_df.groupBy("Id").agg(
    collect_set("ObservableValue").alias("ObservableValues")
).filter(size("ObservableValues") > 1)

if ti_check.count() > 0:
    print("WARNING: Found TI indicators with multiple observable values:")
    ti_check.show(truncate=False)

if SHOW_DEBUG_LOGS:
    print("Sample of combined logs DataFrame:")
    combined_logs_df.show()

# check for combined logs with multiple observable values
combined_log_check = combined_logs_df.groupBy("ObservableValue").agg(
    collect_set("ObservableValue").alias("ObservableValues")
).filter(size("ObservableValues") > 1)

if combined_log_check.count() > 0:
    print("WARNING: Found combined logs with multiple observable values:")
    combined_log_check.show(truncate=False)

distinct_log_values = combined_logs_df.select('ObservableValue').distinct()
print("Distinct log values:")
if SHOW_DEBUG_LOGS or REDUCED_DEBUG_LOGS:
    distinct_log_count = distinct_log_values.count()
    print(f"✓ Found {distinct_log_count:,} distinct observable values in logs")
if SHOW_DEBUG_LOGS:
    distinct_log_values.show(truncate=False)

distinct_indicator_values = deduped_indicators_df.select('ObservableValue').distinct()
print("Distinct indicator values:")
if SHOW_DEBUG_LOGS or REDUCED_DEBUG_LOGS:
    distinct_indicator_count = distinct_indicator_values.count()
    print(f"✓ Found {distinct_indicator_count:,} distinct observable values in indicators")
if SHOW_DEBUG_LOGS:
    distinct_indicator_values.show(truncate=False)

intersection_values = distinct_log_values.join(broadcast(distinct_indicator_values), ['ObservableValue'], 'inner').repartition('ObservableValue').cache()
print(f"✓ Found {intersection_values.count():,} intersecting observable values between logs and indicators")

filtered_logs_df = combined_logs_df.join(broadcast(intersection_values), ['ObservableValue'], 'inner').cache()
filtered_ti_df = deduped_indicators_df.join(broadcast(intersection_values), ['ObservableValue'], 'inner').cache()

# Optional auto-tune before the heavy join
maybe_auto_tune_shuffle([filtered_logs_df, filtered_ti_df])

logs_alias = filtered_logs_df.select(
    col('ObservableValue'),
    col('LogSource').alias('logs_LogSource'),
    col('IdField').alias('logs_IdField'),
    col('Id').cast('string').alias('logs_Id'),
    col('TenantId').alias('logs_TenantId'),
    col('TimeGenerated').alias('logs_TimeGenerated')
)
 
ti_alias = filtered_ti_df.select(
    col('ObservableValue'),
    col('ObservableKey').alias('ti_ObservableKey'),
    col('ThreatActors').alias('ti_ThreatActors'),
    col('TTPs').alias('ti_TTPs'),
    col('Id').alias('ti_Id'),
    col('Pattern').alias('ti_Pattern'),
    col('TI_TenantId').alias('ti_TI_TenantId'),
    col('TI_TimeGenerated').alias('ti_TI_TimeGenerated'),
    col('ValidFrom').alias('ti_ValidFrom'),
    col('ValidUntil').alias('ti_ValidUntil'),
    col('Data').alias('ti_Data')
)

if MATCH_MODE == 'strict':
    cond = (
        logs_alias["ObservableValue"] == ti_alias["ObservableValue"]
    ) & (
        logs_alias["logs_TimeGenerated"] >= ti_alias["ti_ValidFrom"]
    ) & (
        logs_alias["logs_TimeGenerated"] <= ti_alias["ti_ValidUntil"]
    )
    base_join = logs_alias.join(broadcast(ti_alias), cond, 'inner')
else:
    base_join = logs_alias.join(broadcast(ti_alias), ['ObservableValue'], 'inner')

if SHOW_DEBUG_LOGS:
    print("Sample of logs alias:")
    logs_alias.show(truncate=False)
    print("Sample of ti alias:")
    ti_alias.show(truncate=True)

# Select and alias columns to avoid ambiguity downstream
matched_indicators_df = base_join.select(
    'logs_LogSource',
    'logs_IdField',
    'logs_Id',
    'logs_TenantId',
    'logs_TimeGenerated',
    'ObservableValue',
    'ti_Id',
    'ti_ObservableKey',
    'ti_Pattern',
    'ti_TI_TenantId',
    'ti_TI_TimeGenerated',
    'ti_ValidFrom',
    'ti_ValidUntil',
    'ti_TTPs',
    'ti_ThreatActors',
    'ti_Data'
)

if SHOW_DEBUG_LOGS or REDUCED_DEBUG_LOGS:
    print(f"✓ Matched indicators count: {matched_indicators_df.count():,}")
if SHOW_DEBUG_LOGS:
    matched_indicators_df.show(10, truncate=True)

StatementMeta(MSGLarge, 1563, 12, Finished, Available, Finished)

✓ Added domain-name:value indicators: 34,215 records
✓ Added file:hashes.MD5 indicators: 15,831 records
✓ Added file:hashes.SHA-1 indicators: 0 records
✓ Added file:hashes.SHA-256 indicators: 0 records
✓ Added ipv4-addr:value indicators: 150 records
✓ Added ipv6-addr:value indicators: 0 records
✓ Added network-traffic:dst_ref.value indicators: 0 records
✓ Added network-traffic:src_ref.value indicators: 7,676 records
✓ Added url:value indicators: 2,183 records
✓ Combined indicators: 60055
Distinct log values:
✓ Found 249,171 distinct observable values in logs
Distinct indicator values:
✓ Found 59,280 distinct observable values in indicators
✓ Found 9,528 intersecting observable values between logs and indicators
✓ Matched indicators count: 9,718


## Build results and aggregate (base schema retained)

In [13]:
# Build detailed matches
job_id = str(uuid.uuid4())
data_schema = StructType(
    [
        StructField("indicator_types", ArrayType(StringType()), True),
        StructField("threat_actors", ArrayType(StringType()), True),
    ]
)

print("")

result_df = (
    matched_indicators_df.withColumn("MatchId", col("ti_Id"))
    .withColumn("JobId", lit(job_id))
    .withColumn("JobStartTime", job_start_time)
    .withColumn("JobEndTime", current_timestamp())
    .withColumn("MatchType", lit("IoC"))
    .withColumn("TIReferenceId", col("ti_Id"))
    .withColumn("TIValue", col("ti_Pattern"))
    .withColumn("MatchCount", lit(1).cast("Long"))
    .withColumn(
        "EventReferences",
        array(
            struct(
                col("logs_LogSource").alias("Table"),
                col("logs_IdField").alias("RecordIdField"),
                col("logs_Id").alias("RecordId"),
                col("logs_TimeGenerated").alias("TimeGenerated"),
                col("ti_ObservableKey").alias("LogField"),
                col("ObservableValue").alias("MatchedValue")
            )
        ),
    )
    .withColumn("Data_parsed", from_json(col("ti_Data"), data_schema))
    .withColumn("TTPs", coalesce(col("ti_TTPs"), col("Data_parsed.indicator_types")))
    .withColumn(
        "ThreatActors",
        coalesce(
            col("ti_ThreatActors"), 
            col("Data_parsed.threat_actors"),
            array(lit("Unknown Actor"))
        ),
    )
    .withColumn("EnrichmentContext", col("ti_Data"))
    .withColumn("TenantId", col("logs_TenantId"))
    .withColumn("TimeGenerated", col("logs_TimeGenerated"))
    .withColumn("TI_TenantId", col("ti_TI_TenantId"))
    .withColumn("TI_TimeGenerated", col("ti_TI_TimeGenerated"))
    .drop("Data_parsed")
)

result_df = result_df.select(
    "MatchId",
    "JobId",
    "JobStartTime",
    "JobEndTime",
    "MatchType",
    col("ti_ObservableKey").alias("ObservableType"),
    col("ObservableValue").alias("ObservableValue"),
    "TIReferenceId",
    "TIValue",
    "MatchCount",
    "EventReferences",
    "TTPs",
    "ThreatActors",
    "EnrichmentContext",
    "TenantId",
    "TimeGenerated",
    "TI_TenantId",
    "TI_TimeGenerated",
)

if SHOW_DEBUG_LOGS:
    rc = result_df.count()
    print(f"✓ Built results DataFrame with {rc:,} individual matches")

rolled_up_df = result_df.groupBy("MatchId", "ObservableValue").agg(
    first("JobId").alias("JobId"),
    first("JobStartTime").alias("JobStartTime"),
    first("JobEndTime").alias("JobEndTime"),
    first("MatchType").alias("MatchType"),
    first("ObservableType").alias("ObservableType"),
    first("TIReferenceId").alias("TIReferenceId"),
    first("TIValue").alias("TIValue"),
    collect_list("EventReferences").alias("_temp_EventReferences"),
    first("TTPs").alias("TTPs"),
    first("ThreatActors").alias("ThreatActors"),
    first("EnrichmentContext").alias("EnrichmentContext"),
    first("TenantId").alias("TenantId"),
    first("TimeGenerated").alias("TimeGenerated"),
    first("TI_TenantId").alias("TI_TenantId"),
    first("TI_TimeGenerated").alias("TI_TimeGenerated"),
).withColumn(
    "_flattened_EventReferences", 
    flatten(col("_temp_EventReferences"))
).withColumn(
    "MatchCount",
    size(col("_flattened_EventReferences")).cast("Long")
).withColumn(
    "EventReferences",
    to_json(col("_flattened_EventReferences"))
).drop("_temp_EventReferences", "_flattened_EventReferences")

# Also serialize TTPs and ThreatActors to JSON strings for consistent storage
rolled_up_df = rolled_up_df.withColumn(
    "TTPs", 
    when(col("TTPs").isNotNull(), to_json(col("TTPs"))).otherwise(lit(None))
).withColumn(
    "ThreatActors",
    when(col("ThreatActors").isNotNull(), to_json(col("ThreatActors"))).otherwise(lit(None))
)


StatementMeta(MSGLarge, 1563, 14, Finished, Available, Finished)




## Save results (incremental by MatchId)

Note: We keep the base table schema (EventReferences as array<struct>) and append only new MatchIds. Event-level dedupe can be added later with a merge routine if needed.

In [14]:
# Optional: clear existing
CLEAR_RESULTS_TABLE = False
if CLEAR_RESULTS_TABLE:
    try:
        data_provider.delete_table(RESULTS_TABLE, WORKSPACE_NAME)
        print(f"✓ Deleted existing results table: {RESULTS_TABLE}")
    except Exception as e:
        print(f"⚠ Could not delete table {RESULTS_TABLE}: {e}")

try:
    # Try to read the existing table directly
    existing_df = None
    existing_count = 0

    try:
        existing_df = data_provider.read_table(RESULTS_TABLE, WORKSPACE_NAME)
        existing_count = existing_df.count()
        table_exists = existing_count > 0
    except Exception as read_error:
        # Table doesn't exist or can't be read
        table_exists = False
        if SHOW_DEBUG_LOGS:
            print(
                f"ℹ️ Table {RESULTS_TABLE} not found or empty: {str(read_error)[:100]}..."
            )

    # Table doesn't exist or is empty - create new - exit early
    if not table_exists:
        print(f"📁 Creating new results table: {RESULTS_TABLE}")

        # Show what we're creating
        event_count = rolled_up_df.count()
        initial_events = rolled_up_df.agg({"MatchCount": "sum"}).collect()[0][0] or 0

        print(f"  • Creating with {event_count:,} TI indicator records")
        print(f"  • Total event references: {initial_events:,}")

        if SHOW_DEBUG_LOGS:
            print("\n🔍 Sample of data being saved (first 20 rows):")
            rolled_up_df.show(20, truncate=True)

        data_provider.save_as_table(rolled_up_df, RESULTS_TABLE, WORKSPACE_NAME)
        print("✓ Created table with initial results")
    else:
        print(f"📁 Found existing results table: {RESULTS_TABLE}")
        print(f"\n📊 Initial counts:")
        print(f"  • Existing records in table: {existing_count:,}")

        event_count = rolled_up_df.count()
        new_count = rolled_up_df.count()
        print(f"  • New records to process: {new_count:,}")

        # Parse the existing EventReferences JSON back to arrays and explode
        existing_exploded = (
            existing_df.withColumn(
                "EventReferences_parsed",
                from_json(
                    col("EventReferences"),
                    ArrayType(
                        StructType(
                            [
                                StructField("Table", StringType(), True),
                                StructField("RecordId", StringType(), True),
                                StructField("TimeGenerated", TimestampType(), True),
                                StructField("LogField", StringType(), True),
                                StructField("MatchedValue", StringType(), True),
                            ]
                        )
                    ),
                ),
            )
            .select("MatchId", explode(col("EventReferences_parsed")).alias("EventRef"))
            .select(
                col("MatchId").alias("existing_MatchId"),
                col("EventRef.Table").alias("existing_Table"),
                col("EventRef.RecordId").alias("existing_RecordId"),
                col("EventRef.TimeGenerated").alias("existing_TimeGenerated"),
                col("EventRef.LogField").alias("existing_LogField"),
                col("EventRef.MatchedValue").alias("existing_MatchedValue"),
            )
        )

        existing_event_count = existing_exploded.count()
        print(f"  • Existing individual events: {existing_event_count:,}")

        # Parse and explode the new results
        new_exploded = (
            rolled_up_df.withColumn(
                "EventReferences_parsed",
                from_json(
                    col("EventReferences"),
                    ArrayType(
                        StructType(
                            [
                                StructField("Table", StringType(), True),
                                StructField("RecordId", StringType(), True),
                                StructField("TimeGenerated", TimestampType(), True),
                                StructField("LogField", StringType(), True),
                                StructField("MatchedValue", StringType(), True),
                            ]
                        )
                    ),
                ),
            )
            .select("*", explode(col("EventReferences_parsed")).alias("EventRef"))
            .select(
                "*",
                col("EventRef.Table").alias("new_Table"),
                col("EventRef.RecordId").alias("new_RecordId"),
                col("EventRef.TimeGenerated").alias("new_TimeGenerated"),
                col("EventRef.LogField").alias("new_LogField"),
                col("EventRef.MatchedValue").alias("new_MatchedValue"),
            )
            .drop("EventRef", "EventReferences_parsed")
        )

        new_event_count = new_exploded.count()
        print(f"  • New individual events to check: {new_event_count:,}")

        if SHOW_DEBUG_LOGS:
            print("\n🔍 Analyzing RecordId patterns:")
            
            # Check how many records have null RecordIds
            new_null_ids = new_exploded.filter(col("new_RecordId").isNull()).count()
            new_total = new_exploded.count()
            print(f"  • New events with NULL RecordId: {new_null_ids:,} / {new_total:,} ({(new_null_ids/new_total*100):.1f}%)" if new_total > 0 else "  • No new events")
            
            existing_null_ids = existing_exploded.filter(col("existing_RecordId").isNull()).count()
            existing_total = existing_exploded.count()
            print(f"  • Existing events with NULL RecordId: {existing_null_ids:,} / {existing_total:,} ({(existing_null_ids/existing_total*100):.1f}%)" if existing_total > 0 else "  • No existing events")
            
            # Show sample of events with NULL RecordIds
            if new_null_ids > 0:
                print("\n  Sample of new events with NULL RecordId:")
                new_exploded.filter(col("new_RecordId").isNull()).select(
                    "MatchId", "new_Table", "new_RecordId", "new_TimeGenerated", "new_LogField", "new_MatchedValue"
                ).show(5, truncate=False)
            
            # Check for exact duplicates (should find them if dedup is working)
            print("\n🔍 Checking for exact matches between new and existing:")
            exact_matches = new_exploded.join(
                existing_exploded,
                (new_exploded.MatchId == existing_exploded.existing_MatchId)
                & (new_exploded.new_Table == existing_exploded.existing_Table)
                & (
                    ((new_exploded.new_RecordId.isNotNull()) & (new_exploded.new_RecordId == existing_exploded.existing_RecordId))
                    | ((new_exploded.new_RecordId.isNull()) & (existing_exploded.existing_RecordId.isNull()))
                )
                & (new_exploded.new_TimeGenerated == existing_exploded.existing_TimeGenerated)
                & (new_exploded.new_LogField == existing_exploded.existing_LogField),
                "inner"
            )
            
            exact_count = exact_matches.count()
            print(f"  • Found {exact_count:,} exact matches (these should be deduplicated)")
            
            if exact_count > 0:
                print("\n  Sample of matches that should be deduplicated:")
                exact_matches.select(
                    "MatchId", "new_Table", "new_RecordId", "new_MatchedValue", "existing_MatchedValue"
                ).show(10, truncate=False)

        # Show sample of what we're comparing (for debugging)
        if SHOW_DEBUG_LOGS:
            print("\n🔍 Sample of existing events (first 5):")
            existing_exploded.show(5, truncate=False)

            print("\n🔍 Sample of new events (first 5):")
            new_exploded.select(
                "MatchId",
                "new_Table",
                "new_RecordId",
                "new_TimeGenerated",
                "new_LogField",
                "new_MatchedValue",
            ).show(5, truncate=False)

        # Anti-join to find truly new event references
        new_events_only = new_exploded.join(
            existing_exploded,
            (new_exploded.MatchId == existing_exploded.existing_MatchId)
            & (new_exploded.new_Table == existing_exploded.existing_Table)
            & (
                # If RecordId exists, use it for matching
                (
                    (new_exploded.new_RecordId.isNotNull())
                    & (existing_exploded.existing_RecordId.isNotNull())
                    & (new_exploded.new_RecordId == existing_exploded.existing_RecordId)
                )
                |
                # Otherwise use combination of fields for matching
                # This handles the case where RecordId is null
                (
                    (new_exploded.new_RecordId.isNull())
                    & (existing_exploded.existing_RecordId.isNull())
                    & (
                        new_exploded.new_TimeGenerated
                        == existing_exploded.existing_TimeGenerated
                    )
                    & (new_exploded.new_LogField == existing_exploded.existing_LogField)
                    & (
                        # Compare normalized values to handle case differences
                        lower(trim(new_exploded.new_MatchedValue))
                        == lower(trim(existing_exploded.existing_MatchedValue))
                    )
                )
            ),
            "leftanti",
        )
        
        if SHOW_DEBUG_LOGS:
            # Check for mismatched observable values
            mismatch_check = new_events_only.filter(
                col("ObservableValue") != col("new_MatchedValue")
            ).select(
                "MatchId",
                "ObservableValue",
                "new_MatchedValue",
                "new_LogField",
                "new_Table"
            ).limit(10)
            
            mismatch_count = mismatch_check.count()
            if mismatch_count > 0:
                print(f"\n⚠️ WARNING: Found events where ObservableValue != MatchedValue:")
                mismatch_check.show(truncate=False)
                print(f"Total mismatches: {mismatch_count}")

        actual_new_events = new_events_only.count()
        duplicate_events = new_event_count - actual_new_events

        print(f"\n📈 Deduplication results:")
        print(f"  • Total events checked: {new_event_count:,}")
        print(
            f"  • Duplicate events (already exist): {duplicate_events:,} ({(duplicate_events/new_event_count*100):.1f}%)"
            if new_event_count > 0
            else "  • Duplicate events: 0"
        )
        print(
            f"  • Truly new events to add: {actual_new_events:,} ({(actual_new_events/new_event_count*100):.1f}%)"
            if new_event_count > 0
            else "  • Truly new events: 0"
        )

        # Show which MatchIds have new events
        if SHOW_DEBUG_LOGS:
            matchids_with_new = new_events_only.select("MatchId").distinct()
            print(f"\n🎯 MatchIds with new events: {matchids_with_new.count()}")

            # Show breakdown by Table
            print("\n📊 New events by table:")
            new_events_only.groupBy("new_Table").count().orderBy(
                "count", ascending=False
            ).show()

        if SHOW_DEBUG_LOGS and actual_new_events > 0:
            print("\n🔍 Sample of TRULY NEW events being added (first 20):")
            new_events_only.select(
                "MatchId",
                "ObservableValue",
                "new_Table",
                "new_RecordId",
                "new_TimeGenerated",
                "new_LogField",
                "new_MatchedValue",
            ).show(20, truncate=False)

        # Group by observable value to see what's triggering new matches
        print("\n📊 New events by ObservableValue (top 10):")
        new_events_only.groupBy("ObservableValue").agg(
            spark_count("*").alias("NewEventCount"),
            collect_set("new_LogField").alias("LogFields"),
            collect_set("new_Table").alias("Tables"),
        ).orderBy("NewEventCount", ascending=False).show(10, truncate=False)

        # Show a few specific examples with full details
        print("\n🔬 Detailed examples of new events (first 5 unique MatchIds):")
        sample_matchids = new_events_only.select("MatchId").distinct().limit(5)
        for row in sample_matchids.collect():
            match_id = row["MatchId"]
            print(f"\n  MatchId: {match_id[:50]}...")
            sample_events = (
                new_events_only.filter(col("MatchId") == match_id)
                .select(
                    "ObservableValue",
                    "new_Table",
                    "new_LogField",
                    "new_MatchedValue",
                    "new_TimeGenerated",
                )
                .limit(3)
            )
            sample_events.show(truncate=False)

            # Re-aggregate the new events back by MatchId
            new_matches_df = (
                new_events_only.groupBy("MatchId", "ObservableValue")
                .agg(
                    first("JobId").alias("JobId"),
                    first("JobStartTime").alias("JobStartTime"),
                    first("JobEndTime").alias("JobEndTime"),
                    first("MatchType").alias("MatchType"),
                    first("ObservableType").alias("ObservableType"),
                    first("TIReferenceId").alias("TIReferenceId"),
                    first("TIValue").alias("TIValue"),
                    collect_list(
                        struct(
                            col("new_Table").alias("Table"),
                            col("new_RecordId").alias("RecordId"),
                            col("new_TimeGenerated").alias("TimeGenerated"),
                            col("new_LogField").alias("LogField"),
                            col("new_MatchedValue").alias("MatchedValue"),
                        )
                    ).alias("_new_EventReferences"),
                    first("TTPs").alias("TTPs"),
                    first("ThreatActors").alias("ThreatActors"),
                    first("EnrichmentContext").alias("EnrichmentContext"),
                    first("TenantId").alias("TenantId"),
                    current_timestamp().alias(
                        "TimeGenerated"
                    ),
                    first("TI_TenantId").alias("TI_TenantId"),
                    first("TI_TimeGenerated").alias("TI_TimeGenerated"),
                )
                .withColumn("MatchCount", size(col("_new_EventReferences")))
                .withColumn("EventReferences", to_json(col("_new_EventReferences")))
                .drop("_new_EventReferences")
            )

            # Filter out any MatchIds that now have 0 events (all were duplicates)
            new_matches_df = new_matches_df.filter(col("MatchCount") > 0)

            to_add = new_matches_df.count()
            new_events_count = (
                new_matches_df.agg({"MatchCount": "sum"}).collect()[0][0] or 0
            )

            print(f"\n✅ Final summary:")
            print(f"  • New TI indicator records to add: {to_add:,}")
            print(f"  • Total new event references in those records: {new_events_count:,}")

except Exception as e:
    print(f"✗ Error saving results: {e}")
    import traceback

    traceback.print_exc()
    raise

StatementMeta(MSGLarge, 1563, 15, Finished, Available, Finished)

{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Loading table: RetroThreatMatchResults_p3_ttps_SPRK_CL"}
{"level": "ERROR", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Table RetroThreatMatchResults_p3_ttps_SPRK_CL not found"}
📁 Creating new results table: RetroThreatMatchResults_p3_ttps_SPRK_CL
  • Creating with 9,667 TI indicator records
  • Total event references: 9,718
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Saving DataFrame as table: RetroThreatMatchResults_p3_ttps_SPRK_CL"}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Table RetroThreatMatchResults_p3_ttps_SPRK_CL does not exist. Creating new table."}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Creating custom table: RetroThreatMatchResults_p3_ttps_SPRK_CL"}
{"level": "INFO", "run_id": "1b27b126-8991-4443-b4bc-2003ae5f2774", "message": "Successfully created/updated table: R

## Threat actor summary

In [15]:
# ===============================================================================
# THREAT ACTOR DETECTION SUMMARY
# ===============================================================================
if SHOW_STATS and rolled_up_df.count() > 0:
    print("\n" + "="*80)
    print("THREAT ACTOR DETECTION SUMMARY")
    print("="*80)
    
    # Parse ThreatActors JSON strings back to arrays for analysis
    threat_actor_df = rolled_up_df.withColumn(
        "ThreatActorArray",
        from_json(col("ThreatActors"), ArrayType(StringType()))
    ).select(
        "MatchId",
        "ObservableType", 
        "ObservableValue",
        "MatchCount",
        explode(col("ThreatActorArray")).alias("ThreatActor")
    )
    
    # Overall threat actor statistics
    actor_stats = threat_actor_df.groupBy("ThreatActor").agg(
        spark_count("MatchId").alias("UniqueIndicators"),
        sum("MatchCount").alias("TotalEvents"),
        collect_set("ObservableType").alias("IndicatorTypes"),
        count_distinct("ObservableValue").alias("UniqueObservables")
    ).orderBy(col("TotalEvents").desc())
    
    print("\n📊 Threat Actors Detected (by event volume):")
    print("-" * 60)
    actor_stats.show(20, truncate=False)
    
    # Top indicators per threat actor
    print("\n🎯 Top Indicators by Threat Actor:")
    print("-" * 60)
    
    top_actors = actor_stats.limit(5).collect()
    for actor_row in top_actors:
        actor_name = actor_row["ThreatActor"]
        if actor_name != "Unknown Actor":
            print(f"\n  {actor_name}:")
            
            actor_indicators = threat_actor_df.filter(
                col("ThreatActor") == actor_name
            ).groupBy("ObservableValue", "ObservableType").agg(
                sum("MatchCount").alias("Events")
            ).orderBy(col("Events").desc()).limit(3)
            
            for ind_row in actor_indicators.collect():
                print(f"    • {ind_row['ObservableValue']} ({ind_row['ObservableType']}) - {ind_row['Events']} events")
    
    # Summary statistics
    print("\n📈 Summary Statistics:")
    print("-" * 60)
    
    total_actors = actor_stats.filter(col("ThreatActor") != "Unknown Actor").count()
    unknown_count = actor_stats.filter(col("ThreatActor") == "Unknown Actor").collect()
    unknown_events = unknown_count[0]["TotalEvents"] if unknown_count else 0
    
    total_events = actor_stats.agg(sum("TotalEvents")).collect()[0][0]
    attributed_events = total_events - unknown_events if total_events else 0
    
    print(f"  • Identified Threat Actors: {total_actors}")
    print(f"  • Total Events with Attribution: {attributed_events:,} ({(attributed_events/total_events*100):.1f}%)" if total_events > 0 else "  • Total Events with Attribution: 0")
    print(f"  • Events without Attribution: {unknown_events:,} ({(unknown_events/total_events*100):.1f}%)" if total_events > 0 else "  • Events without Attribution: 0")

StatementMeta(MSGLarge, 1563, 16, Finished, Available, Finished)

# Summary

In [16]:
print("Notebook version:", VERSION)

if SHOW_DEBUG_LOGS or REDUCED_DEBUG_LOGS:
    end = time.time()
    # Standard synapse compute cost for big data analytics
    COMPUTE_COST_PER_VCORE_HOUR = 0.15

    # Calculate time spent and vcore hours based on cluster
    elapsed_seconds = end - start
    elapsed_minutes = elapsed_seconds / 60
    elapsed_hours = elapsed_minutes / 60
    num_executors = int(spark.conf.get("spark.executor.instances"))
    num_drivers = 1
    cores_per_executor = int(spark.conf.get("spark.executor.cores"))
    total_vcores = (num_executors + num_drivers) * cores_per_executor
    vcore_hours = total_vcores * elapsed_hours
    job_cost = vcore_hours * COMPUTE_COST_PER_VCORE_HOUR
    mc = matched_indicators_df.count()

    # Output summary
    print(f"Cluster configs:")
    print(f"Executors: {num_executors}")
    print(f"Drivers: {num_drivers}")
    print(f"VCores per executor: {cores_per_executor}")
    print(f"Total vcores: {total_vcores}")
    print("------------------------------------------")
    print(f"Combined logs count: {combined_logs_count:,}")
    print(f"Indicators count: {indicators_count:,}")
    print("------------------------------------------")
    print(f"Distinct log count: {distinct_log_count:,}")
    print(f"Distinct indicator count: {distinct_indicator_count:,}")
    print("------------------------------------------")
    print(f"Matched log count: {mc:,}")
    print(f"Matched rolled up indicator count: {event_count:,}")
    print("------------------------------------------")
    print(f"Runtime: {elapsed_minutes:.2f} minutes")
    print("------------------------------------------")
    print(f"VCore-hours used: {vcore_hours:.2f}")
    print(f"Job cost: ${job_cost:.2f}")

    print("\n\n------------------------------------------")
    print(f"Indicators partitions: {deduped_indicators_df.rdd.getNumPartitions()}")
    print(f"Logs partitions: {combined_logs_df.rdd.getNumPartitions()}")
    print(f"Distinct indicators partitions: {distinct_indicator_values.rdd.getNumPartitions()}")
    print(f"Distinct logs partitions: {distinct_log_values.rdd.getNumPartitions()}")
    print(f"Intersecting values partitions: {intersection_values.rdd.getNumPartitions()}")
    print(f"Rolled up result df partitions: {rolled_up_df.rdd.getNumPartitions()}")
    print("------------------------------------------")


StatementMeta(MSGLarge, 1563, 17, Finished, Available, Finished)

Notebook version: 1.0.1
Cluster configs:
Executors: 4
Drivers: 1
VCores per executor: 16
Total vcores: 80
------------------------------------------
Combined logs count: 116,743,199
Indicators count: 60,055
------------------------------------------
Distinct log count: 249,171
Distinct indicator count: 59,280
------------------------------------------
Matched log count: 9,718
Matched rolled up indicator count: 9,667
------------------------------------------
Runtime: 23.84 minutes
------------------------------------------
VCore-hours used: 31.79
Job cost: $4.77


------------------------------------------
Indicators partitions: 1800
Logs partitions: 1200
Distinct indicators partitions: 1
Distinct logs partitions: 4
Intersecting values partitions: 200
Rolled up result df partitions: 13
------------------------------------------
