In [1]:
%%configure -f
{
    "conf":{
        "spark.driver.maxResultSize":"50g",
        "spark.rpc.message.maxSize":"1024",
        "spark.kryoserializer.buffer.max":"512m"
    }
}

StatementMeta(, 154, -1, Finished, Available)

In [2]:
DateSliceEnd = ""

StatementMeta(maxpiasevoliL, 154, 2, Finished, Available)

In [3]:
import datetime as dt, json, pickle
import pyspark.sql.functions as f
from pyspark.sql.window import Window
from pyspark.sql.types import ArrayType, StringType, StructField, StructType, IntegerType, MapType, Row
from pyspark.ml.feature import VectorAssembler
import numpy as np

StatementMeta(maxpiasevoliL, 154, 3, Finished, Available)

In [4]:
date = DateSliceEnd.split('T')[0]
end_processing_timestamp = dt.datetime.strptime(date, '%Y-%m-%d')
start_processing_timestamp = end_processing_timestamp - dt.timedelta(days=3)

print('Start processing timestamp: ', start_processing_timestamp)
print('End processing timestamp: ', end_processing_timestamp)

StatementMeta(maxpiasevoliL, 154, 4, Finished, Available)

Start processing timestamp:  2023-09-18 00:00:00
End processing timestamp:  2023-10-09 00:00:00


In [5]:
output_path = "abfss://skg@rdamlapeussa.dfs.core.windows.net/output_data/tenant_graph_v1_all_tenants"

output_path = output_path + "/all_data/" + date
drop_alert_type_dups_on_same_entity = False

print('output path: ', output_path)

StatementMeta(maxpiasevoliL, 154, 5, Finished, Available)

output path:  abfss://skg@rdamlapeussa.dfs.core.windows.net/output_data/tenant_graph_v1_all_tenants/all_data/2023-10-09


In [6]:
kusto_service_name = 'eoaeusprod_eastus'

def get_incident_data():
    '''
    Get classified security incidents
    '''

    query = f"""SecurityIncident
    | where CreatedTime between(todatetime('{start_processing_timestamp.strftime('%Y-%m-%d')}') .. todatetime('{end_processing_timestamp.strftime('%Y-%m-%d')}'))
    | extend IsFusionIncident = toint(Description has "aka.ms/SentinelFusion" or Description has "Fusion has identified")
    | project WorkspaceId, IncidentNumber, IncidentName, IsFusionIncident, Severity, Title, ProviderName, AlertIds, CreatedTime, LastModifiedTime, Classification, ClassificationReason, ClosedTime, WorkspaceTenantId, IncidentDescription=Description
    """
        
    incidents_df = spark.read \
        .format("com.microsoft.kusto.spark.synapse.datasource") \
        .option("spark.synapse.linkedService", kusto_service_name) \
        .option("kustoDatabase", "SecurityInsights") \
        .option("kustoQuery", query) \
        .load()

    return incidents_df

incident_data = get_incident_data()

StatementMeta(maxpiasevoliL, 154, 6, Finished, Available)

In [7]:
kusto_service_name = 'eoaeusprod_eastus'

def get_alert_data():
    '''
    Get security alerts
    '''

    query = f"""SecurityAlert
    | where StartTime between(todatetime('{start_processing_timestamp.strftime('%Y-%m-%d')}') .. todatetime('{end_processing_timestamp.strftime('%Y-%m-%d')}'))
    | project SystemAlertId, Entities, AlertProviderName=ProviderName, AlertDisplayName=DisplayName, AlertSeverity, StartTime, WorkspaceId, Tactics, Techniques, WorkspaceTenantId, AlertDescription=Description
    """
        
    alerts_df = spark.read \
        .format("com.microsoft.kusto.spark.synapse.datasource") \
        .option("spark.synapse.linkedService", kusto_service_name) \
        .option("kustoDatabase", "SecurityInsights") \
        .option("kustoQuery", query) \
        .load()

    return alerts_df

alert_data = get_alert_data()

StatementMeta(maxpiasevoliL, 154, 7, Finished, Available)

In [8]:
window = Window.partitionBy("WorkspaceId", "IncidentNumber").orderBy(f.col("LastModifiedTime").desc())

deduped_incident_data = (
    incident_data
        .withColumn("rowNumber", f.row_number().over(window))
        .filter(f.col("rowNumber") == 1)
        .drop("rowNumber")
        .persist()
)

StatementMeta(maxpiasevoliL, 154, 8, Finished, Available)

In [9]:
window = Window.partitionBy("WorkspaceId", "SystemAlertId").orderBy(f.col("StartTime").desc())

deduped_alert_data = (
    alert_data
        .withColumn("rowNumber", f.row_number().over(window))
        .filter(f.col("rowNumber") == 1)
        .drop("rowNumber")
        .persist()
)

StatementMeta(maxpiasevoliL, 154, 9, Finished, Available)

In [10]:
incident_node_attributes = (
    deduped_incident_data
        .withColumn("IncidentNodeAttributes", f.create_map(
            f.lit("IncidentNumber"), f.col("IncidentNumber"),
            f.lit("WorkspaceId"), f.col("WorkspaceId"),
            f.lit("TenantId"), f.col("WorkspaceTenantId"),
            f.lit("Title"), f.col("Title"),
            f.lit("Description"), f.col("IncidentDescription"),
            f.lit("Severity"), f.col("Severity"),
            f.lit("ProviderName"), f.col("ProviderName"),
            f.lit("CreatedTime"), f.col("CreatedTime"),
            f.lit("ClosedTime"), f.col("ClosedTime"),
            f.lit("LastModifiedTime"), f.col("LastModifiedTime"),
            f.lit("TimeToClassificationInMinutes"), 
            (
                f.coalesce(
                    f.col("ClosedTime"), 
                    f.col("CreatedTime")
                ).cast("long") - f.col("CreatedTime").cast("long")
            )/60,
            f.lit("Classification"), f.coalesce(f.col("Classification"), f.lit("NoClassification")),
            f.lit("IsFusionIncident"), ((f.col("IncidentDescription").contains("Fusion has identified")) | (f.col("IncidentDescription").contains("https://aka.ms/SentinelFusion"))).cast("string"),
            f.lit("node_type"), f.lit("securityincident")
        ))
        .withColumn("IncidentNodeAttributes", f.to_json("IncidentNodeAttributes"))
        .select("WorkspaceId", "IncidentName", "WorkspaceTenantId", "IncidentNodeAttributes")
        .persist()
)

alert_node_attributes = (
    deduped_alert_data
        .withColumn("AlertNodeAttributes", f.create_map(
            f.lit("WorkspaceId"), f.col("WorkspaceId"),
            f.lit("TenantId"), f.col("WorkspaceTenantId"),
            f.lit("AlertDisplayName"), f.col("AlertDisplayName"),
            f.lit("Description"), f.col("AlertDescription"),
            f.lit("AlertProviderName"), f.col("AlertProviderName"),
            f.lit("AlertSeverity"), f.col("AlertSeverity"),
            f.lit("StartTime"), f.col("StartTime"),
            f.lit("Tactics"), f.col("Tactics"),
            f.lit("Techniques"), f.col("Techniques"),
            f.lit("node_type"), f.lit("securityalert")
        ))
        .withColumn("AlertNodeAttributes", f.to_json("AlertNodeAttributes"))
        .select("WorkspaceId", "SystemAlertId", "WorkspaceTenantId", "AlertNodeAttributes")
        .persist()
)

deduped_alert_data = deduped_alert_data.select("WorkspaceId", "WorkspaceTenantId", "SystemAlertId", "Entities", "AlertDisplayName", "AlertProviderName", "StartTime").persist()

StatementMeta(maxpiasevoliL, 154, 10, Finished, Available)

In [11]:
import re, json

StatementMeta(maxpiasevoliL, 154, 11, Finished, Available)

In [12]:
entity_schema = StructType([
    StructField("EntityType", StringType(), nullable=False),
    StructField("IdentifierField", StringType(), nullable=False),
    StructField("IdentifierValue", StringType(), nullable=False),
    StructField("EnrichedEntityNodeAttributes", StringType(), nullable=False)
])

def process_entity_identifiers(entities_json_string):
    try:
        entity_dict_list = json.loads(entities_json_string)
    except:
        return []

    def is_local_ipv4(ip_address):
        try:
            if ip_address == "0.0.0.0" or ip_address == "127.0.0.1" or ip_address[:7] == "192.168" or ip_address[:3] == "10.":
                return True
            elif (ip_address[0:4] == "172."):
                ip_address_fields = ip_address.split(".")
                if len(fields) >= 2:
                    ip_range = int(ip_address_fields[1])
                    if ip_range >= 16 and ip_range <= 32:
                        return True
            
            return False
        except:
            return False

    def get_identifier_value(entity_dict, identifier_field):
        return str(entity_dict[identifier_field]).lower() if identifier_field in entity_dict else ""
    
    entity_field_delimiter = "__"
    def union_fields(identifier_list):
        str_identifier_list = [str(identifier).lower() for identifier in identifier_list]
        return entity_field_delimiter.join(str_identifier_list)

    final_entities_list = []

    for entity_dict in entity_dict_list:
        if "Type" in entity_dict:
            type_value = entity_dict["Type"].lower()
            node_attributes = {"node_type": type_value}
            

            if type_value == "account":

                identifier_field = "AadUserId"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

                if "Name" in entity_dict and entity_dict["Name"] not in ["root", "system", "guest", "admin", "administrator", "user"]:
                    node_attributes['identifier_fields'] = "Name"
                    final_entities_list.append([type_value, "Name", entity_dict["Name"], json.dumps(node_attributes.copy())])

                if "Name" in entity_dict and entity_dict["Name"] not in ["root", "system", "guest", "admin", "administrator", "user"] and "UPNSuffix" in entity_dict:
                    node_attributes['identifier_fields'] = "Email"
                    final_entities_list.append([type_value, "Email", entity_dict["Name"] + "@"  + entity_dict['UPNSuffix'], json.dumps(node_attributes.copy())])
                
                if "Sid" in entity_dict:
                    if entity_dict["Sid"] in ['S-1-5-18']: continue
                    node_attributes['identifier_fields'] = "Sid"
                    final_entities_list.append([type_value, "Sid", entity_dict["Sid"], json.dumps(node_attributes.copy())])
                
                node_attributes['identifier_fields'] = "AadUserId"

            elif type_value == "cloud-application":

                identifier_field = "AppId"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

                if "Name" in entity_dict and "InstanceName" in entity_dict:
                    node_attributes['identifier_fields'] = "Name__InstanceName"
                    final_entities_list.append([type_value, "Name__InstanceName", union_fields([entity_dict["Name"], entity_dict['InstanceName']]), json.dumps(node_attributes.copy())])
                
                node_attributes['identifier_fields'] = "AppId"

            elif type_value == "file":

                identifier_field = "Name"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

                # if "Name" in entity_dict and "Directory" in entity_dict:
                #     final_entities_list.append([type_value, "Directory__Name", entity_dict['Directory'] + entity_field_delimiter + entity_dict["Name"], json.dumps(node_attributes.copy())])
                node_attributes['identifier_fields'] = "Name"

            elif type_value == "filehash":

                if "Algorithm" in entity_dict and "Value" in entity_dict:
                    node_attributes['identifier_fields'] = "Algorithm__Value"
                    final_entities_list.append([type_value, "Algorithm__Value", union_fields([entity_dict["Algorithm"], entity_dict["Value"]]), json.dumps(node_attributes.copy())])
                
                continue

            elif type_value == "host":

                identifier_field = "AadDeviceId"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

                if "AzureID" in entity_dict:
                    node_attributes['identifier_fields'] = "AzureID"
                    final_entities_list.append([type_value, "AzureID", entity_dict["AzureID"], json.dumps(node_attributes.copy())])

                if "HostName" in entity_dict:
                    node_attributes['identifier_fields'] = "HostName"
                    final_entities_list.append([type_value, "HostName", entity_dict["HostName"], json.dumps(node_attributes.copy())])

                if "OMSAgentID" in entity_dict:
                    node_attributes['identifier_fields'] = "OMSAgentID"
                    final_entities_list.append([type_value, "OMSAgentID", entity_dict["OMSAgentID"], json.dumps(node_attributes.copy())])

                node_attributes['identifier_fields'] = "AadDeviceId"

            elif type_value == "iotdevice":

                identifier_field = "DeviceId"
                node_attributes['identifier_fields'] = "DeviceId"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

            elif type_value == "ip":

                identifier_field = "Address"
                node_attributes['identifier_fields'] = "Address"
                identifier_value = get_identifier_value(entity_dict, identifier_field)
                if identifier_value in ('0.0.0.0', '127.0.0.1', '8.8.8.8'): continue
                node_attributes['IsLocalIPv4'] = str(is_local_ipv4(identifier_value))
                #if is_local_ipv4(identifier_value): continue
            
            elif type_value == "mailbox" or type_value == "mailboxconfiguration":

                identifier_field = "MailboxPrimaryAddress"
                node_attributes['identifier_fields'] = "MailboxPrimaryAddress"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

            elif type_value == "mailcluster":

                if "Source" in entity_dict and "Query" in entity_dict:
                    node_attributes['identifier_fields'] = "Source__Query"
                    final_entities_list.append([type_value, "Source__Query", union_fields([entity_dict["Source"], entity_dict["Query"]]), json.dumps(node_attributes.copy())])
                
                continue
            
            elif type_value == "mailmessage":

                identifier_field = "Sender"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

                if "Subject" in entity_dict: node_attributes['subject'] = entity_dict['Subject']

                if "Recipient" in entity_dict:
                    node_attributes['identifier_fields'] = "Recipient"
                    final_entities_list.append([type_value, "Recipient", entity_dict["Recipient"], json.dumps(node_attributes.copy())])

                if "SenderIP" in entity_dict and entity_dict["SenderIP"] not in ('0.0.0.0', '127.0.0.1', '8.8.8.8'):
                    node_attributes['identifier_fields'] = "SenderIP"
                    ip_address = entity_dict["SenderIP"]
                    copy_node_attributes = node_attributes.copy()
                    copy_node_attributes['IsLocalIPv4'] = str(is_local_ipv4(ip_address))
                    final_entities_list.append([type_value, "SenderIP", ip_address, json.dumps(copy_node_attributes)])
                
                node_attributes['identifier_fields'] = "Sender"

            elif type_value == "oauth-application":

                identifier_field = "OAuthObjectId"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

                if "OAuthAppId" in entity_dict:
                    node_attributes['identifier_fields'] = "OAuthAppId"
                    final_entities_list.append([type_value, "OAuthAppId", entity_dict["OAuthAppId"], json.dumps(node_attributes.copy())])
                
                node_attributes['identifier_fields'] = "OAuthObjectId"
            
            elif type_value == "process":

                if "ProcessId" in entity_dict and "CreatedTimeUtc" in entity_dict and "CommandLine" in entity_dict:
                    node_attributes['identifier_fields'] = "ProcessId__CreatedTimeUtc__CommandLine"
                    final_entities_list.append([type_value, "ProcessId__CreatedTimeUtc__CommandLine", union_fields([entity_dict["ProcessId"], entity_dict["CreatedTimeUtc"], entity_dict["CommandLine"]]), json.dumps(node_attributes.copy())])

                if "CommandLine" in entity_dict:
                    node_attributes['identifier_fields'] = "ExtractedFileName"
                    command_line = str(get_identifier_value(entity_dict, "CommandLine")).lower()
                    extracted_files = re.findall(r'([^\\\/\s"\']*\.(?:exe|pdf|dll|xlsx|docx|zip|png|txt|ps1|html|png|tmp))', command_line)
                    for extracted_file in extracted_files:
                        final_entities_list.append([type_value, "ExtractedFileName", extracted_file, json.dumps(node_attributes.copy())])

                continue
            
            elif type_value == "security-group":

                identifier_field = "ObjectGuid"
                identifier_value = get_identifier_value(entity_dict, identifier_field)

                if "SID" in entity_dict:
                    node_attributes['identifier_fields'] = "SID"
                    final_entities_list.append([type_value, "SID", entity_dict["SID"], json.dumps(node_attributes.copy())])
                
                node_attributes['identifier_fields'] = "ObjectGuid"

            elif type_value == "service-principal":

                identifier_field = "ServicePrincipalObjectId"
                node_attributes['identifier_fields'] = "ServicePrincipalObjectId"
                identifier_value = get_identifier_value(entity_dict, identifier_field)
            
            elif type_value == "url":

                identifier_field = "Url"
                node_attributes['identifier_fields'] = "Url"
                identifier_value = get_identifier_value(entity_dict, identifier_field)
            #     # TODO: add method to check if url is absolute or not
                #node_attributes['IsAbsoluteUrl'] = str(is_absolute_url(identifier_value))
            
            # both cloud resource types extract the subscription id from the resource url
            # elif type_value == "gcp-resource":

            #     identifier_field = "RelatedAzureResourceIds"
            #     resource_url = str(get_identifier_value(entity_dict, identifier_field)).lower()
            #     match = re.search(r'(?<=subscriptions/)[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}', resource_url)
            #     identifier_value = match.group() if match else ""
            
            elif type_value == "azure-resource":

                identifier_field = "ResourceId"
                resource_url = str(get_identifier_value(entity_dict, identifier_field)).lower()
                identifier_value = resource_url

                subscription_match = re.search(r'(?<=subscriptions/)[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}', resource_url)                
                if subscription_match:
                    node_attributes['identifier_fields'] = "SubscriptionId"
                    final_entities_list.append([type_value, "SubscriptionId", subscription_match.group(), json.dumps(node_attributes.copy())])

                resource_group_list = re.findall(r'resourcegroups/([^/]*)/', resource_url)
                if len(resource_group_list):
                    node_attributes['identifier_fields'] = "ResourceGroup"
                    final_entities_list.append([type_value, "ResourceGroup", resource_group_list[0], json.dumps(node_attributes.copy())])

                node_attributes['identifier_fields'] = "ResourceId"

            else:
                continue
            
            if identifier_value != "":
                final_entities_list.append([type_value, identifier_field, identifier_value, json.dumps(node_attributes)])
    
    return final_entities_list

process_entity_identifiers_udf = f.udf(process_entity_identifiers, ArrayType(entity_schema))

StatementMeta(maxpiasevoliL, 154, 12, Finished, Available)

In [13]:
deduped_alert_data_with_entities = deduped_alert_data.withColumn("EntityIdentifierList", process_entity_identifiers_udf("Entities")).drop("Entities")

StatementMeta(maxpiasevoliL, 154, 13, Finished, Available)

In [14]:
deduped_alert_data_with_entities_flattened = (
    deduped_alert_data_with_entities
        .select("*", f.explode("EntityIdentifierList"))
        .select("*", "col.*")
        .drop("col", "EntityIdentifierList")
        .where(f.col("IdentifierValue") != "")
        .withColumn("IdentifierValue", f.lower("IdentifierValue"))
)

print('Alert count pre-dedup: ', deduped_alert_data_with_entities_flattened.select("WorkspaceId", "SystemAlertId").distinct().count())

incident_to_alert_data = (
    deduped_incident_data
        .withColumn("AlertIds", f.from_json("AlertIds", ArrayType(StringType())))
        .withColumn("AssociatedAlertCount", f.size("AlertIds"))
        .select("*", f.explode("AlertIds").alias("SystemAlertId"))
        .where(f.col("SystemAlertId") != "")
        .select("WorkspaceId", "IncidentName", "SystemAlertId", "AssociatedAlertCount", "WorkspaceTenantId")
        .distinct()
)

print('Incident count pre-dedup: ', incident_to_alert_data.select("WorkspaceId", "IncidentName").distinct().count())

# account for alerts that are no longer present in alert table
incident_window = Window.partitionBy("WorkspaceId", "IncidentName")
incident_to_alert_data = (
    deduped_alert_data
        .select("WorkspaceId", "SystemAlertId")
        .distinct()
        .join(incident_to_alert_data, ["WorkspaceId", "SystemAlertId"])
)

# de-duplicate alert data again to remove duplicate alert types pointing to same entity identifier
if drop_alert_type_dups_on_same_entity:
    alert_window = Window.partitionBy("WorkspaceId", "AlertDisplayName", "AlertProviderName", "EntityType", "IdentifierField", "IdentifierValue").orderBy(f.col("StartTime").desc())

    deduped_alert_data_with_entities_flattened = (
        deduped_alert_data_with_entities_flattened
            .withColumn("rowNumber", f.row_number().over(alert_window))
            .persist()
    )

    alerts_to_drop = (
        deduped_alert_data_with_entities_flattened
            .filter(f.col("rowNumber") != 1)
            .select("WorkspaceId", "SystemAlertId")
            .distinct()
    )

    deduped_alert_data_with_entities_flattened = (
        deduped_alert_data_with_entities_flattened
            .filter(f.col("rowNumber") == 1)
            .drop("rowNumber")
    )

    # use alert-entity dedup to update incident to alert mappings
    incident_to_alert_data = (
        incident_to_alert_data
            .join(alerts_to_drop, ["WorkspaceId", "SystemAlertId"], 'leftanti')
    )

deduped_alert_data_with_entities_flattened = (
    deduped_alert_data_with_entities_flattened
        .drop("AlertDisplayName", "AlertProviderName", "StartTime")
        .persist()
)

# drop incidents where not all alerts are present
incident_to_alert_data = (
    incident_to_alert_data
        .withColumn("JoinedAlertCount", f.size(f.collect_set("SystemAlertId").over(incident_window)))
        .where(f.col("JoinedAlertCount") == f.col("AssociatedAlertCount"))
        .persist()
)

print('Alert count post-dedup: ', deduped_alert_data_with_entities_flattened.select("WorkspaceId", "SystemAlertId").distinct().count())
print('Incident count post dedup: ', incident_to_alert_data.select("WorkspaceId", "IncidentName").distinct().count())

StatementMeta(maxpiasevoliL, 154, 14, Finished, Available)

Alert count pre-dedup:  39682511
Incident count pre-dedup:  14117854
Alert count post-dedup:  39682511
Incident count post dedup:  12547490


In [15]:
deduped_alert_data_with_entities_flattened.printSchema()

StatementMeta(maxpiasevoliL, 154, 15, Finished, Available)

root
 |-- WorkspaceId: string (nullable = true)
 |-- WorkspaceTenantId: string (nullable = true)
 |-- SystemAlertId: string (nullable = true)
 |-- EntityType: string (nullable = true)
 |-- IdentifierField: string (nullable = true)
 |-- IdentifierValue: string (nullable = true)
 |-- EnrichedEntityNodeAttributes: string (nullable = true)



In [16]:
enriched_entity_node_attributes = (
    deduped_alert_data_with_entities_flattened
        .select(f.col("WorkspaceTenantId").alias("TenantId"), "EntityType", "IdentifierValue", "EnrichedEntityNodeAttributes")
        .distinct()
        .persist()
)

deduped_alert_data_with_entities_flattened = deduped_alert_data_with_entities_flattened.drop("EnrichedEntityNodeAttributes").distinct().persist()

StatementMeta(maxpiasevoliL, 154, 16, Finished, Available)

In [17]:
(
    deduped_alert_data_with_entities_flattened
        .select("EntityType", "IdentifierField")
        .distinct()
        .orderBy(f.col("EntityType").asc(), f.col("IdentifierField").asc())
        .show(n=50, truncate=False)
)

StatementMeta(maxpiasevoliL, 154, 17, Finished, Available)

+--------------------+--------------------------------------+
|EntityType          |IdentifierField                       |
+--------------------+--------------------------------------+
|account             |AadUserId                             |
|account             |Email                                 |
|account             |Name                                  |
|account             |Sid                                   |
|azure-resource      |ResourceGroup                         |
|azure-resource      |ResourceId                            |
|azure-resource      |SubscriptionId                        |
|cloud-application   |AppId                                 |
|cloud-application   |Name__InstanceName                    |
|file                |Name                                  |
|filehash            |Algorithm__Value                      |
|host                |AadDeviceId                           |
|host                |AzureID                               |
|host   

In [18]:
header = f"""EvaluationSkillInvocations
| where PreciseTimeStamp between(todatetime('{start_processing_timestamp.strftime('%Y-%m-%d')}') .. todatetime('{end_processing_timestamp.strftime('%Y-%m-%d')}'))
"""

regex_query = header + """| where State == 'Completed' and Success == 'true'
| extend SessionId = case(isempty(SessionId), sessionid, SessionId)
| extend SkillInputs = tolower(SkillInputs)
| extend SkillOutput = tolower(SkillOutput)
| extend SkillInputGuids = extract_all("([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})", SkillInputs)
| extend SkillOutputGuids = extract_all("([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})", SkillOutput)
| extend SkillInputSids = extract_all(@"(s-\d-\d+-[\d-]+)", SkillInputs)
| extend SkillOutputSids = extract_all(@"(s-\d-\d+-[\d-]+)", SkillOutput)
| extend SkillInputEmails = extract_all(@"([a-z0-9._%+-]+@[a-z0-9.-]+\.@[a-z]{2,})", SkillInputs)
| extend SkillOutputEmails = extract_all(@"([a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,})", SkillOutput)
| extend SkillInputIps = extract_all(@"((?:[0-9]{1,3}\.){3}[0-9]{1,3})", SkillInputs)
| extend SkillOutputIps = extract_all(@"((?:[0-9]{1,3}\.){3}[0-9]{1,3})", SkillOutput)
| extend SkillInputFiles = extract_all(@'([^"\\\/\s]*\.(?:exe|pdf|dll|xlsx|docx|zip|png|txt|ps1|html|png|tmp))', SkillInputs)
| extend SkillOutputFiles = extract_all(@'([^"\\\/\s]*\.(?:exe|pdf|dll|xlsx|docx|zip|png|txt|ps1|html|png|tmp|json))', SkillOutput)
| extend Identifiers = array_concat(
    SkillInputGuids, SkillOutputGuids, 
    SkillInputEmails, SkillOutputEmails, 
    SkillInputIps, SkillOutputIps,
    SkillInputSids, SkillOutputSids,
    SkillInputFiles, SkillOutputFiles
)
| mv-expand Identifiers
| extend Identifier = tostring(Identifiers)
| join kind=inner(
    Sessions
    | distinct SessionId, TenantId
) on SessionId
| distinct TenantId, SessionId, SkillInvocationId, SkillName, Identifier
"""

evaluation_skill_invocations_identifiers_regex = (
    spark
        .read 
        .format("com.microsoft.kusto.spark.synapse.datasource")
        .option("spark.synapse.linkedService", "medeinaapiprod")
        .option("kustoDatabase", "medeinalogs")
        .option("authType", "LS")
        .option("kustoQuery", regex_query)
        .load()
        .drop("SessionId1")
        .withColumn("ExtractedIncidentNodeAttributes", f.lit('{"extracted_via_regex": "True"}'))
        .withColumn("ExtractedAlertNodeAttributes", f.lit('{"extracted_via_regex": "True"}'))
        .withColumn("ExtractedEntityNodeAttributes", f.lit('{"extracted_via_regex": "True"}'))
        .withColumn("ExtractedEntityType", f.lit('no_entity_type_for_regex'))
)

StatementMeta(maxpiasevoliL, 154, 18, Finished, Available)

In [19]:
print("Potential identifiers from regex: ", evaluation_skill_invocations_identifiers_regex.count())

StatementMeta(maxpiasevoliL, 154, 19, Finished, Available)

Potential identifiers from regex:  657730


# Sentinel IncidentName Extraction
(Skill --> Incident --> Alert --> Entity --> Alert)

In [20]:
def join_session_and_incident_on_incident_guid(session_sdf, join_type = "right_outer"):
    return (
        deduped_incident_data
            .select("WorkspaceId", "IncidentName", "WorkspaceTenantId")
            .distinct()
            .join(session_sdf,
                (deduped_incident_data.WorkspaceTenantId == session_sdf.TenantId)
                & (deduped_incident_data.IncidentName == session_sdf.Identifier),
                join_type
            )
            .select(
                "WorkspaceId",
                "TenantId",
                "IncidentName",
                "SessionId",
                "SkillInvocationId",
                "ExtractedIncidentNodeAttributes"
            )
            .persist()
    )

StatementMeta(maxpiasevoliL, 154, 20, Finished, Available)

In [21]:
regex_session_and_incident_on_incident_guid = join_session_and_incident_on_incident_guid(evaluation_skill_invocations_identifiers_regex, "inner")

StatementMeta(maxpiasevoliL, 154, 21, Finished, Available)

In [22]:
print("Incident identifiers extracted via regex: ", regex_session_and_incident_on_incident_guid.count())

StatementMeta(maxpiasevoliL, 154, 22, Finished, Available)

Incident identifiers extracted via regex:  679


In [23]:
# static method for extracting incident identifiers
"""
KQL for adding Sentinel IncidentName col
"""
incident_name_query_str = f"""EvaluationSkillInvocations
| where PreciseTimeStamp between(todatetime('{start_processing_timestamp.strftime('%Y-%m-%d')}') .. todatetime('{end_processing_timestamp.strftime('%Y-%m-%d')}'))
| where State == 'Completed' and Success == 'true'
| where SkillName == 'GetSentinelIncidentBasicInfoByUrl'
| extend SessionId = case(isempty(SessionId), sessionid, SessionId)
| extend IncidentName = tostring(parse_json(SkillInputs).Url)
| extend IncidentName = tostring(split(IncidentName, '/')[-1])
| extend ExtractedIncidentNodeAttributes = dynamic_to_json(bag_pack("extracted_via_static", "True"))
| join kind=inner(
    Sessions
    | distinct SessionId, TenantId
) on SessionId
| project SessionId, SkillInvocationId, IncidentName, TenantId, ExtractedIncidentNodeAttributes
"""

evaluation_skill_invocations_incident_static_sdf = spark.read \
    .format("com.microsoft.kusto.spark.synapse.datasource") \
    .option("spark.synapse.linkedService", "medeinaapiprod") \
    .option("kustoDatabase", "medeinalogs") \
    .option("authType", "LS")\
    .option("kustoQuery", incident_name_query_str) \
    .load() \
    .withColumnRenamed("IncidentName", "Identifier")

StatementMeta(maxpiasevoliL, 154, 23, Finished, Available)

In [24]:
static_session_and_incident_on_incident_guid = join_session_and_incident_on_incident_guid(evaluation_skill_invocations_incident_static_sdf, "inner")

StatementMeta(maxpiasevoliL, 154, 24, Finished, Available)

In [25]:
print("Incident identifiers extracted via static: ", static_session_and_incident_on_incident_guid.count())

StatementMeta(maxpiasevoliL, 154, 25, Finished, Available)

Incident identifiers extracted via static:  31


# Sentinel Security Alert Extraction
(Skill --> Alert --> Entity --> Alert)

In [26]:
def join_session_and_alert_on_alert_guid(session_sdf, join_type = "right_outer"):
    return (
        deduped_alert_data
            .select("WorkspaceId", "SystemAlertId", "WorkspaceTenantId")
            .distinct()
            .join(session_sdf,
                (deduped_alert_data.WorkspaceTenantId == session_sdf.TenantId)
                & (deduped_alert_data.SystemAlertId == session_sdf.Identifier),
                join_type
            )
            .select(
                "WorkspaceId",
                "TenantId",
                "SystemAlertId",
                "SessionId",
                "SkillInvocationId",
                "ExtractedAlertNodeAttributes"
            )
            .persist()
    )

StatementMeta(maxpiasevoliL, 154, 26, Finished, Available)

In [27]:
regex_session_and_alert_on_alert_guid = join_session_and_alert_on_alert_guid(evaluation_skill_invocations_identifiers_regex, "inner")

StatementMeta(maxpiasevoliL, 154, 27, Finished, Available)

In [28]:
print("Alert identifiers extracted via regex: ", regex_session_and_alert_on_alert_guid.count())

StatementMeta(maxpiasevoliL, 154, 28, Finished, Available)

Alert identifiers extracted via regex:  167


In [29]:
# static method for extracting alerts
"""
KQL for adding Sentinel Alerts col
"""
alerts_query_str = f"""EvaluationSkillInvocations
| where PreciseTimeStamp between(todatetime('{start_processing_timestamp.strftime('%Y-%m-%d')}') .. todatetime('{end_processing_timestamp.strftime('%Y-%m-%d')}'))
| where State == 'Completed' and Success == 'true'
| where SkillName has 'GetSentinelIncidentAlerts'
| extend SessionId = case(isempty(SessionId), sessionid, SessionId)
| extend SkillOutput = parse_json(SkillOutput) 
| mv-expand SkillOutput = SkillOutput.value to typeof(dynamic)
| extend AlertName = tostring(SkillOutput.name)
| where isnotempty(AlertName)
| extend ExtractedAlertNodeAttributes = dynamic_to_json(bag_pack("extracted_via_static", "True"))
| join kind=inner(
    Sessions
    | distinct SessionId, TenantId
) on SessionId
| project SessionId, SkillInvocationId, AlertName, TenantId, ExtractedAlertNodeAttributes
"""

evaluation_skill_invocations_alert_static_sdf = spark.read \
    .format("com.microsoft.kusto.spark.synapse.datasource") \
    .option("spark.synapse.linkedService", "medeinaapiprod") \
    .option("kustoDatabase", "medeinalogs") \
    .option("authType", "LS")\
    .option("kustoQuery", alerts_query_str) \
    .load() \
    .withColumnRenamed("AlertName", "Identifier")

StatementMeta(maxpiasevoliL, 154, 29, Finished, Available)

In [30]:
print("Potential alert identifiers extracted via static: ", evaluation_skill_invocations_alert_static_sdf.count())

StatementMeta(maxpiasevoliL, 154, 30, Finished, Available)

Potential alert identifiers extracted via static:  889


In [31]:
static_session_and_alert_on_alert_guid = join_session_and_alert_on_alert_guid(evaluation_skill_invocations_alert_static_sdf, "inner")

StatementMeta(maxpiasevoliL, 154, 31, Finished, Available)

In [32]:
print("Alert identifiers extracted via static: ", static_session_and_alert_on_alert_guid.count())

StatementMeta(maxpiasevoliL, 154, 32, Finished, Available)

Alert identifiers extracted via static:  35


# Sentinl V3 Security Entity Extraction
(Skill --> Entity --> Alert)

In [33]:
def join_session_and_alert_on_identifier(session_sdf, join_type = "right_outer", join_on_entity_type = True):
    if join_on_entity_type:
        session_and_alert_data = (
            deduped_alert_data_with_entities_flattened
                .join(
                    session_sdf, 
                    (deduped_alert_data_with_entities_flattened.WorkspaceTenantId == session_sdf.TenantId) 
                        & (deduped_alert_data_with_entities_flattened.IdentifierValue == session_sdf.Identifier)
                        & (deduped_alert_data_with_entities_flattened.EntityType == session_sdf.ExtractedEntityType),
                    join_type
                )
        )
    else:
        # including this case since regex won't have specified entity type
        session_and_alert_data = (
            deduped_alert_data_with_entities_flattened
                .join(
                    session_sdf, 
                    (deduped_alert_data_with_entities_flattened.WorkspaceTenantId == session_sdf.TenantId) 
                        & (deduped_alert_data_with_entities_flattened.IdentifierValue == session_sdf.Identifier),
                    join_type
                )
        )

    session_and_alert_on_identifier = (
        session_and_alert_data
            .select(
                "WorkspaceId",
                "TenantId",
                f.coalesce(f.col("EntityType"), f.col("ExtractedEntityType")).alias("EntityType"),
                f.col("Identifier").alias("IdentifierValue"),
                "SystemAlertId",
                "SessionId",
                "SkillInvocationId",
                "ExtractedEntityNodeAttributes"
            )
            .persist()
    )

    return session_and_alert_on_identifier

StatementMeta(maxpiasevoliL, 154, 33, Finished, Available)

In [34]:
regex_session_and_alert_on_identifier = join_session_and_alert_on_identifier(evaluation_skill_invocations_identifiers_regex, join_type="inner", join_on_entity_type=False)

StatementMeta(maxpiasevoliL, 154, 34, Finished, Available)

In [35]:
print("Entities extracted via regex: ", regex_session_and_alert_on_identifier.count())

StatementMeta(maxpiasevoliL, 154, 35, Finished, Available)

Entities extracted via regex:  6720478


In [36]:
regex_session_and_alert_on_identifier.groupBy("EntityType").agg(f.size(f.collect_set("IdentifierValue")).alias("EntityCount")).orderBy(f.col("EntityType").asc()).show(n=100)

StatementMeta(maxpiasevoliL, 154, 36, Finished, Available)

+--------------------+-----------+
|          EntityType|EntityCount|
+--------------------+-----------+
|             account|        291|
|      azure-resource|         34|
|                file|        212|
|                host|         38|
|                  ip|        627|
|             mailbox|        775|
|mailboxconfiguration|          1|
|         mailmessage|        984|
|             process|        345|
|      security-group|          2|
|                 url|          4|
+--------------------+-----------+



In [37]:
# static method for extracting entities 
header = f"""let entity_field_delimiter = "__";
EvaluationSkillInvocations 
| where PreciseTimeStamp between(todatetime('{start_processing_timestamp.strftime('%Y-%m-%d')}') .. todatetime('{end_processing_timestamp.strftime('%Y-%m-%d')}'))
"""

entities_query_str = header + """
| where State == 'Completed' and Success == 'true'
| where SkillName contains 'GetSentinelIncidentEntitiesByUrl'
| extend SessionId = case(isempty(SessionId), sessionid, SessionId) 
| extend SkillOutput = parse_json(SkillOutput) 
| join kind=inner(
    Sessions
    | distinct SessionId, TenantId
) on SessionId 
| mv-expand SkillOutput = todynamic(SkillOutput.entities) // to typeof(dynamic)
| extend EntityId = tostring(SkillOutput.name)
| extend EntityType = tolower(tostring(SkillOutput.kind))
| extend Properties = todynamic(SkillOutput.properties) 
| where isnotempty(EntityType)
| extend Identifiers = case(
    EntityType == 'account', pack_array(strcat(tostring(Properties.accountName), "@", tostring(Properties.upnSuffix)), tostring(Properties.aadUserId), tostring(Properties.sid), tostring(Properties.accountName)),
    EntityType == 'cloudapplication', pack_array(tostring(Properties.appId), strcat(tostring(Properties.appName), entity_field_delimiter, tostring(Properties.instanceName))),
    EntityType == 'file', pack_array(tostring(Properties.fileName), strcat(tostring(Properties.directory), entity_field_delimiter, tostring(Properties.fileName))),
    EntityType == 'filehash', pack_array(strcat(tostring(Properties.algorithm), entity_field_delimiter, tostring(Properties.hashValue))),
    EntityType == 'host', pack_array(tostring(Properties.azureID), tostring(Properties.hostName), tostring(Properties.omsAgentID)),
    EntityType == 'iotdevice', pack_array(tostring(Properties.deviceId)),
    EntityType == 'ip', pack_array(tostring(Properties.address)),
    EntityType == 'mailbox' or EntityType == "mailboxconfiguration", pack_array(tostring(Properties.mailboxPrimaryAddress)),
    EntityType == 'mailcluster', pack_array(strcat(tostring(Properties.source), entity_field_delimiter, tostring(Properties.query))),
    EntityType == 'mailmessage', pack_array(tostring(Properties.sender), tostring(Properties.recipient), tostring(Properties.senderIP)),
    EntityType == 'oauth-application', pack_array(tostring(Properties.oAuthAppId)),
    EntityType == 'process', pack_array(strcat(tostring(Properties.processId), entity_field_delimiter, tostring(Properties.creationTimeUtc), entity_field_delimiter, tostring(Properties.commandLine))),
    EntityType == 'security-group', pack_array(tostring(Properties.sid), tostring(Properties.objectGuid)),
    EntityType == 'service-principal', pack_array(tostring(Properties.servicePrincipalObjectId)),
    EntityType == 'url', pack_array(tostring(Properties.url)),
    EntityType == 'azure-resource', pack_array(
        tolower(tostring(Properties.resourceId)),
        extract("subscriptions/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})", 1, tostring(tolower(Properties.resourceId))),
        extract("resourcegroups/([^/]*)/", 1, tolower(tostring(Properties.resourceId)))),
    pack_array("")
)
| extend ExtractedEntityType = case(EntityType == "cloudapplication", "cloud-application", EntityType)
| extend EntityTypeBag = bag_pack("node_type", ExtractedEntityType, "extracted_via_static", "True")
| extend ExtractedEntityNodeAttributes = dynamic_to_json(bag_merge(Properties, EntityTypeBag))
| mv-expand Identifiers
| extend Identifier = tolower(tostring(Identifiers))
| project SessionId, SkillInvocationId, TenantId, Identifier, ExtractedEntityType, ExtractedEntityNodeAttributes
"""

evaluation_skill_invocations_entities_static_sdf = spark.read \
    .format("com.microsoft.kusto.spark.synapse.datasource") \
    .option("spark.synapse.linkedService", "medeinaapiprod") \
    .option("kustoDatabase", "medeinalogs") \
    .option("authType", "LS")\
    .option("kustoQuery", entities_query_str) \
    .load()

StatementMeta(maxpiasevoliL, 154, 37, Finished, Available)

In [38]:
print("Potential entity identifiers extracted via static: ", evaluation_skill_invocations_entities_static_sdf.count())

StatementMeta(maxpiasevoliL, 154, 38, Finished, Available)

Potential entity identifiers extracted via static:  1222


In [39]:
evaluation_skill_invocations_entities_static_sdf.printSchema()

StatementMeta(maxpiasevoliL, 154, 39, Finished, Available)

root
 |-- SessionId: string (nullable = true)
 |-- SkillInvocationId: string (nullable = true)
 |-- TenantId: string (nullable = true)
 |-- Identifier: string (nullable = true)
 |-- ExtractedEntityType: string (nullable = true)
 |-- ExtractedEntityNodeAttributes: string (nullable = true)



In [40]:
static_session_and_alert_on_identifier = join_session_and_alert_on_identifier(evaluation_skill_invocations_entities_static_sdf)

StatementMeta(maxpiasevoliL, 154, 40, Finished, Available)

In [41]:
print("Entity identifiers extracted via static: ", static_session_and_alert_on_identifier.count())

StatementMeta(maxpiasevoliL, 154, 41, Finished, Available)

Entity identifiers extracted via static:  11387


In [42]:
static_session_and_alert_on_identifier.printSchema()

StatementMeta(maxpiasevoliL, 154, 42, Finished, Available)

root
 |-- WorkspaceId: string (nullable = true)
 |-- TenantId: string (nullable = true)
 |-- EntityType: string (nullable = true)
 |-- IdentifierValue: string (nullable = true)
 |-- SystemAlertId: string (nullable = true)
 |-- SessionId: string (nullable = true)
 |-- SkillInvocationId: string (nullable = true)
 |-- ExtractedEntityNodeAttributes: string (nullable = true)



In [43]:
# check entity types for which alert enrichment is working
print("check entity types for which alert enrichment is working")
static_session_and_alert_on_identifier.where(f.col("SystemAlertId").isNotNull()).select("EntityType").distinct().orderBy(f.col("EntityType").asc()).show()

StatementMeta(maxpiasevoliL, 154, 43, Finished, Available)

check entity types for which alert enrichment is working
+-----------------+
|       EntityType|
+-----------------+
|          account|
|cloud-application|
|             host|
|               ip|
|          mailbox|
|      mailcluster|
|      mailmessage|
|              url|
+-----------------+



In [44]:
# check all entity types extracted by static method
print("check all entity types extracted by static method")
static_session_and_alert_on_identifier.select("EntityType").distinct().orderBy(f.col("EntityType").asc()).show()

StatementMeta(maxpiasevoliL, 154, 44, Finished, Available)

check all entity types extracted by static method
+-----------------+
|       EntityType|
+-----------------+
|          account|
|    azureresource|
|cloud-application|
|    dnsresolution|
|             file|
|         filehash|
|             host|
|               ip|
|          mailbox|
|      mailcluster|
|      mailmessage|
|              url|
+-----------------+



In [45]:
evaluation_skill_invocations_entities_induction_raw = (
    spark
        .read
        .format('delta')
        .load('abfss://mlap@rdamlapeussa.dfs.core.windows.net/skg/induction_entity_extraction/gpt_extracted_entities_prompts_v2')
        .filter(
            (f.col("ProcessedTimestamp") >= start_processing_timestamp)
            & (f.col("ProcessedTimestamp") < end_processing_timestamp)
        )
        .filter(~f.col('Entities').contains("NONE"))
        .filter(~f.col('Entities').contains("None"))
        
)

StatementMeta(maxpiasevoliL, 154, 45, Finished, Available)

In [46]:
rows = []
for row in evaluation_skill_invocations_entities_induction_raw.collect():
    
    entities = row.Entities
    session_id = row.SessionId
    skill_invocation_id = row.SkillInvocationId
    tenant_id = row.TenantId 
    skillname = row.SkillName
    parsed_list = []
    prev_type = ""
    
    for entity in str(entities).strip().split(","):
        
        entity_ = entity.split(":",1)
        if len(entity_) == 2:
            attr_type = entity_[0]
            value = entity_[1]
        else:
            continue

        if attr_type == "domain" and prev_type == "user":
            parsed_list.append(str(prev_type + ":" + prev_value+"@"+value)) 
            
        else:    
            parsed_list.append(attr_type + ":" + value)

        prev_type = attr_type
        prev_value = value

    parsed_list_str = ",".join(str(x) for x in parsed_list)
    row = {}
    row['SessionId'] = session_id
    row['SkillInvocationId'] = skill_invocation_id
    row['SkillName'] = skillname
    row['TenantId'] = tenant_id
    row['Entities'] = parsed_list_str
    rows.append(row)

schema = StructType([StructField("SessionId", StringType(), True), StructField("SkillInvocationId", StringType(), True), StructField("SkillName", StringType(), True),
                        StructField("TenantId", StringType(), True), StructField("Entities", StringType(), True),])
evaluation_skill_invocations_entities_induction = (
    spark
        .createDataFrame((Row(**x) for x in rows), schema)
        .select('*', f.explode(f.split("Entities", ",")).alias("EntityPair"))
        .select('*', f.trim(f.element_at(f.split("EntityPair", ":"), 1)).alias("ExtractedEntityType"), f.trim(f.element_at(f.split("EntityPair", ":"), 2)).alias("Identifier"))
        .drop('Entities', 'EntityPair', 'SkillName')
        .where(f.col('ExtractedEntityType').isin('user', 'device', 'ipaddress'))
        .withColumn('ExtractedEntityType', f.when(f.col("ExtractedEntityType") == "user", "account").otherwise(f.col("ExtractedEntityType")))
        .withColumn('ExtractedEntityType', f.when(f.col("ExtractedEntityType") == "ipaddress", "ip").otherwise(f.col("ExtractedEntityType")))
        .distinct()
        .withColumn(
            'ExtractedEntityNodeAttributes', 
            f.to_json(
                f.create_map(
                    f.lit('node_type'), f.col("ExtractedEntityType"),
                    f.lit('extracted_via_induction'), f.lit("True")
                )
            )
        )
)

StatementMeta(maxpiasevoliL, 154, 46, Finished, Available)

In [47]:
evaluation_skill_invocations_entities_induction.groupBy("ExtractedEntityType").count().orderBy(f.col("count").desc()).show()

StatementMeta(maxpiasevoliL, 154, 47, Finished, Available)

+-------------------+-----+
|ExtractedEntityType|count|
+-------------------+-----+
|            account|  239|
|                 ip|  104|
|             device|   63|
+-------------------+-----+



In [48]:
print("Induction raw dataframe count: ", evaluation_skill_invocations_entities_induction.count())

StatementMeta(maxpiasevoliL, 154, 48, Finished, Available)

Induction raw dataframe count:  406


In [49]:
# @Anand, include the MSR induction method for extracting alerts here
induction_session_and_alert_on_identifier = join_session_and_alert_on_identifier(evaluation_skill_invocations_entities_induction, join_on_entity_type=False)

StatementMeta(maxpiasevoliL, 154, 49, Finished, Available)

In [50]:
print("Total count of edges added by induction: ", induction_session_and_alert_on_identifier.count())

StatementMeta(maxpiasevoliL, 154, 50, Finished, Available)

Total count of edges added by induction:  614


In [51]:
print("Count of induction rows joined to sentinel alerts: ", induction_session_and_alert_on_identifier.where(f.col("SystemAlertId").isNull()).count())

StatementMeta(maxpiasevoliL, 154, 51, Finished, Available)

Count of induction rows joined to sentinel alerts:  371


In [52]:
induction_session_and_alert_on_identifier.printSchema()

StatementMeta(maxpiasevoliL, 154, 52, Finished, Available)

root
 |-- WorkspaceId: string (nullable = true)
 |-- TenantId: string (nullable = true)
 |-- EntityType: string (nullable = true)
 |-- IdentifierValue: string (nullable = true)
 |-- SystemAlertId: string (nullable = true)
 |-- SessionId: string (nullable = true)
 |-- SkillInvocationId: string (nullable = true)
 |-- ExtractedEntityNodeAttributes: string (nullable = true)



In [53]:
static_session_and_alert_on_identifier.printSchema()

StatementMeta(maxpiasevoliL, 154, 53, Finished, Available)

root
 |-- WorkspaceId: string (nullable = true)
 |-- TenantId: string (nullable = true)
 |-- EntityType: string (nullable = true)
 |-- IdentifierValue: string (nullable = true)
 |-- SystemAlertId: string (nullable = true)
 |-- SessionId: string (nullable = true)
 |-- SkillInvocationId: string (nullable = true)
 |-- ExtractedEntityNodeAttributes: string (nullable = true)



In [54]:
# Merge entity attributes with extraction technique attributes
extracted_entity_node_attributes = (
    regex_session_and_alert_on_identifier
    .select("TenantId", "EntityType", "IdentifierValue", "ExtractedEntityNodeAttributes")
    .union(
        static_session_and_alert_on_identifier.select("TenantId", "EntityType", "IdentifierValue", "ExtractedEntityNodeAttributes")
    )
    .union(
        induction_session_and_alert_on_identifier.select("TenantId", "EntityType", "IdentifierValue", "ExtractedEntityNodeAttributes")
    )
    .groupBy("TenantId", "EntityType", "IdentifierValue")
    .agg(f.collect_list("ExtractedEntityNodeAttributes").alias("ExtractedEntityNodeAttributesList"))
)

def combine_node_attributes(enriched_node_attributes_dict, extracted_node_attributes_dict_list):
    if extracted_node_attributes_dict_list is None:
        extracted_node_attributes_dict_list = []
    dict_list = extracted_node_attributes_dict_list + [enriched_node_attributes_dict]
    return json.dumps({k:v for node_attributes_dict in dict_list for k,v in json.loads(node_attributes_dict).items()})

combined_node_attributes_udf = f.udf(combine_node_attributes, StringType())

entity_node_attributes = (
    enriched_entity_node_attributes
        .join(
            extracted_entity_node_attributes,
            ['TenantId', 'EntityType', 'IdentifierValue'],
            'full_outer'
        )
        .fillna('{}', ['EnrichedEntityNodeAttributes'])
        .withColumn("EntityNodeAttributes", combined_node_attributes_udf("EnrichedEntityNodeAttributes", "ExtractedEntityNodeAttributesList"))
        .drop("EnrichedEntityNodeAttributes", "ExtractedEntityNodeAttributesList")
        .persist()
)

StatementMeta(maxpiasevoliL, 154, 54, Finished, Available)

In [55]:
# Merge incident attributes
extracted_incident_node_attributes = (
    regex_session_and_incident_on_incident_guid
        .select("WorkspaceId", "IncidentName", "ExtractedIncidentNodeAttributes")
        .union(
            static_session_and_incident_on_incident_guid.select("WorkspaceId", "IncidentName", "ExtractedIncidentNodeAttributes")
        )
        .groupBy("WorkspaceId", "IncidentName")
        .agg(f.collect_list("ExtractedIncidentNodeAttributes").alias("ExtractedIncidentNodeAttributesList"))
)

incident_node_attributes = (
    incident_node_attributes
        .join(
            extracted_incident_node_attributes,
            ["WorkspaceId", "IncidentName"],
            'left_outer'
        )
        .withColumn("IncidentNodeAttributes", combined_node_attributes_udf("IncidentNodeAttributes", "ExtractedIncidentNodeAttributesList"))
        .drop("ExtractedIncidentNodeAttributesList")
        .persist()
)

StatementMeta(maxpiasevoliL, 154, 55, Finished, Available)

In [56]:
# Merge incident attributes
extracted_alert_node_attributes = (
    regex_session_and_alert_on_alert_guid
        .select("WorkspaceId", "SystemAlertId", "ExtractedAlertNodeAttributes")
        .union(
            static_session_and_alert_on_alert_guid.select("WorkspaceId", "SystemAlertId", "ExtractedAlertNodeAttributes")
        )
        .groupBy("WorkspaceId", "SystemAlertId")
        .agg(f.collect_list("ExtractedAlertNodeAttributes").alias("ExtractedAlertNodeAttributesList"))
)

alert_node_attributes = (
    alert_node_attributes
        .join(
            extracted_alert_node_attributes,
            ["WorkspaceId", "SystemAlertId"],
            'left_outer'
        )
        .withColumn("AlertNodeAttributes", combined_node_attributes_udf("AlertNodeAttributes", "ExtractedAlertNodeAttributesList"))
        .drop("ExtractedAlertNodeAttributesList")
        .persist()
)

StatementMeta(maxpiasevoliL, 154, 56, Finished, Available)

# Construct Networkx Graphs

In [57]:
# Get session nodes and properties

from pyspark.sql.functions import col, to_timestamp

evaluation_skill_invocations_query_str = f"""EvaluationSkillInvocations
| where isnotempty(PromptId)
| where PreciseTimeStamp between(todatetime('{start_processing_timestamp.strftime('%Y-%m-%d')}') .. todatetime('{end_processing_timestamp.strftime('%Y-%m-%d')}'))
| where State == "Completed" and Success == "true"  
| extend SessionId = case(isempty(SessionId), sessionid, SessionId)
| extend SkillInputs = replace_regex(SkillInputs, @"\\\\u0022", " ") // replace unicode quotation mark
| distinct SessionId, PromptId, SkillInvocationId, SkillName, SkillInputs, SkillOutput, ParentSkillInvocationId, PreciseTimeStamp
| join kind=inner(
    Sessions
    | distinct SessionId, TenantId
) on SessionId
"""

"""
KQL for adding Prompt col
"""
prompt_query_str = f"""EvaluationSkillInvocations
| where PreciseTimeStamp between(todatetime('{start_processing_timestamp.strftime('%Y-%m-%d')}') .. todatetime('{end_processing_timestamp.strftime('%Y-%m-%d')}'))
| where isnotempty(PromptId)
| where State == "Completed" and Success == "true"
| extend SessionId = case(isempty(SessionId), sessionid, SessionId)
| where SkillName == 'Prompt' and not(SkillInputs has('Sort the following JSON objects'))
| extend Prompt = tostring(parse_json(SkillInputs).Input)
| where isnotempty(Prompt)
| summarize Prompt=tostring(make_set(Prompt)), MinPreciseTimeStamp=min(PreciseTimeStamp), MaxPreciseTimeStamp=max(PreciseTimeStamp) by SessionId, PromptId
| join kind=inner(
    Sessions
    | distinct SessionId, TenantId
) on SessionId
"""

evaluation_skill_invocations_sdf  = spark.read \
    .format("com.microsoft.kusto.spark.synapse.datasource") \
    .option("spark.synapse.linkedService", "medeinaapiprod") \
    .option("kustoDatabase", "medeinalogs") \
    .option("authType", "LS")\
    .option("kustoQuery", evaluation_skill_invocations_query_str) \
    .load()

evaluation_skill_invocations_prompt_sdf = spark.read \
    .format("com.microsoft.kusto.spark.synapse.datasource") \
    .option("spark.synapse.linkedService", "medeinaapiprod") \
    .option("kustoDatabase", "medeinalogs") \
    .option("authType", "LS")\
    .option("kustoQuery", prompt_query_str) \
    .load()

print('Prompt count: ', evaluation_skill_invocations_prompt_sdf.count())

session_ids_sdf = evaluation_skill_invocations_prompt_sdf.select("SessionId").distinct()
prompt_ids_sdf = evaluation_skill_invocations_prompt_sdf.select("PromptId").distinct()

evaluation_skill_invocations_sdf = evaluation_skill_invocations_sdf.join(prompt_ids_sdf, "PromptId").persist()

tenant_ids_sdf = evaluation_skill_invocations_sdf.select("TenantId").distinct()

print('Tenant count: ', tenant_ids_sdf.count())

skill_ids_sdf = evaluation_skill_invocations_sdf.select("SkillInvocationId").distinct()

StatementMeta(maxpiasevoliL, 154, 57, Finished, Available)

Prompt count:  8952
Tenant count:  52


In [58]:
edges_df = spark.createDataFrame(
    data = spark.sparkContext.emptyRDD(), 
    schema = StructType([
    StructField("WorkspaceId", StringType(), nullable=False),
    StructField("TenantId", StringType(), nullable=False),
    StructField("Source", StringType(), nullable=False),
    StructField("Target", StringType(), nullable=False)])
)

StatementMeta(maxpiasevoliL, 154, 58, Finished, Available)

In [59]:
# skill to incident edges
for session_and_incident in [
    regex_session_and_incident_on_incident_guid,
    static_session_and_incident_on_incident_guid
]:

    edges_df = (
        edges_df
            .union(session_and_incident
                .join(skill_ids_sdf, "SkillInvocationId")
                .select(
                    "WorkspaceId", 
                    "TenantId", 
                    f.concat_ws("__", "TenantId", f.lit("SkillInvocationId"), "SkillInvocationId").alias("Source"), 
                    f.concat_ws("__", "TenantId", f.lit("securityincident"), "IncidentName").alias("Target")
                )
            )
    )

StatementMeta(maxpiasevoliL, 154, 59, Finished, Available)

In [60]:
# skill to alert edges
for session_and_alert in [
    regex_session_and_alert_on_alert_guid,
    static_session_and_alert_on_alert_guid
]:
    edges_df = (
        edges_df
            .union(session_and_alert
                .join(skill_ids_sdf, "SkillInvocationId")
                .select(
                    "WorkspaceId", 
                    "TenantId", 
                    f.concat_ws("__", "TenantId", f.lit("SkillInvocationId"), "SkillInvocationId").alias("Source"), 
                    f.concat_ws("__", "TenantId", f.lit("securityalert"), "SystemAlertId").alias("Target")
                )
            )
    )

StatementMeta(maxpiasevoliL, 154, 60, Finished, Available)

In [61]:
# skill to entity edges
for session_and_entity in [
    regex_session_and_alert_on_identifier,
    static_session_and_alert_on_identifier,
    induction_session_and_alert_on_identifier
]:

    edges_df = (
        edges_df
            .union(session_and_entity
                .join(skill_ids_sdf, "SkillInvocationId")
                .select(
                    "WorkspaceId", 
                    "TenantId", 
                    f.concat_ws("__", "TenantId", f.lit("SkillInvocationId"), "SkillInvocationId").alias("Source"), 
                    f.concat_ws("__", "TenantId", "EntityType", "IdentifierValue").alias("Target")
                )
            )
    )

StatementMeta(maxpiasevoliL, 154, 61, Finished, Available)

In [62]:
# incident to alert edges
edges_df = (
    edges_df
        .union(
            incident_to_alert_data
                .withColumnRenamed("WorkspaceTenantId", "TenantId")
                #.join(tenant_ids_sdf, "TenantId")
                .select(
                    "WorkspaceId", 
                    "TenantId", 
                    f.concat_ws("__", "TenantId", f.lit("securityincident"), "IncidentName").alias("Source"),
                    f.concat_ws("__", "TenantId", f.lit("securityalert"), "SystemAlertId").alias("Target")
                )
        )
)

StatementMeta(maxpiasevoliL, 154, 62, Finished, Available)

In [63]:
# alert to entity edges
edges_df = (
    edges_df
        .union(
            deduped_alert_data_with_entities_flattened
                .withColumnRenamed("WorkspaceTenantId", "TenantId")
                #.join(tenant_ids_sdf, "TenantId")
                .select(
                    "WorkspaceId", 
                    "TenantId", 
                    f.concat_ws("__", "TenantId", f.lit("securityalert"), "SystemAlertId").alias("Source"),
                    f.concat_ws("__", "TenantId", "EntityType", "IdentifierValue").alias("Target")
                )
        )
)

StatementMeta(maxpiasevoliL, 154, 63, Finished, Available)

In [64]:
edges_df = (
    edges_df
        .union(
            evaluation_skill_invocations_sdf
                .select(
                    f.lit("").alias("WorkspaceId"), 
                    "TenantId", 
                    f.concat_ws("__", "TenantId", f.lit("SessionId"), "SessionId").alias("Source"),
                    f.concat_ws("__", "TenantId", f.lit("PromptId"), "PromptId").alias("Target")
                )
        )
        .union(
            evaluation_skill_invocations_sdf
                .select(
                    f.lit("").alias("WorkspaceId"), 
                    "TenantId", 
                    f.concat_ws("__", "TenantId", f.lit("PromptId"), "PromptId").alias("Source"),
                    f.concat_ws("__", "TenantId", f.lit("SkillInvocationId"), "SkillInvocationId").alias("Target")
                )
        )
)

StatementMeta(maxpiasevoliL, 154, 64, Finished, Available)

In [65]:
evaluation_skill_invocations_sdf.printSchema()

StatementMeta(maxpiasevoliL, 154, 65, Finished, Available)

root
 |-- PromptId: string (nullable = true)
 |-- SessionId: string (nullable = true)
 |-- SkillInvocationId: string (nullable = true)
 |-- SkillName: string (nullable = true)
 |-- SkillInputs: string (nullable = true)
 |-- SkillOutput: string (nullable = true)
 |-- ParentSkillInvocationId: string (nullable = true)
 |-- PreciseTimeStamp: timestamp (nullable = true)
 |-- SessionId1: string (nullable = true)
 |-- TenantId: string (nullable = true)



In [66]:
import datetime as dt, datetime as dt, pandas as pd, networkx as nx

StatementMeta(maxpiasevoliL, 154, 66, Finished, Available)

In [67]:
node_attributes_df = spark.createDataFrame(
    data = spark.sparkContext.emptyRDD(), 
    schema = StructType([
    StructField("TenantId", StringType(), nullable=False),
    StructField("NodeId", StringType(), nullable=False),
    StructField("NodeAttributes", StringType(), nullable=False)])
)

StatementMeta(maxpiasevoliL, 154, 67, Finished, Available)

In [68]:
# session, prompt and skill attributes
node_attributes_df = (
    node_attributes_df
        .union(
            evaluation_skill_invocations_sdf
                .groupBy('TenantId', 'SessionId')
                .agg(
                    f.min('PreciseTimeStamp').alias('MinPreciseTimeStamp'),
                    f.max('PreciseTimeStamp').alias('MaxPreciseTimeStamp')
                )
                .select(
                    'TenantId',
                    f.concat_ws("__", "TenantId", f.lit("SessionId"), "SessionId"),
                    f.to_json(
                        f.create_map(
                            f.lit("node_type"), f.lit("SessionId"),
                            f.lit("MinPreciseTimeStamp"), f.col("MinPreciseTimeStamp"),
                            f.lit("MaxPreciseTimeStamp"), f.col("MaxPreciseTimeStamp")
                        )
                    )
                )
        )
        .union(
            evaluation_skill_invocations_prompt_sdf
                .select(
                    'TenantId',
                    f.concat_ws("__", "TenantId", f.lit("PromptId"), "PromptId"),
                    f.to_json(
                        f.create_map(
                            f.lit("node_type"), f.lit("PromptId"),
                            f.lit("Prompt"), f.col("Prompt"),
                            f.lit("MinPreciseTimeStamp"), f.col("MinPreciseTimeStamp"),
                            f.lit("MaxPreciseTimeStamp"), f.col("MaxPreciseTimeStamp")
                        )
                    )
                )
        )
        .union(
            evaluation_skill_invocations_sdf
                .select(
                    'TenantId',
                    f.concat_ws("__", "TenantId", f.lit("SkillInvocationId"), "SkillInvocationId"),
                    f.to_json(
                        f.create_map(
                            f.lit("node_type"), f.lit("SkillInvocationId"),
                            f.lit("SkillName"), f.col("SkillName"),
                            f.lit("SkillInputs"), f.col("SkillInputs"),
                            f.lit("SkillOutput"), f.col("SkillOutput"),
                            f.lit("ParentSkillInvocationId"), f.col("ParentSkillInvocationId"),
                            f.lit("PreciseTimeStamp"), f.col("PreciseTimeStamp"),
                        )
                    )
                )
        )
)


StatementMeta(maxpiasevoliL, 154, 68, Finished, Available)

In [69]:
# incident, alert, entity attributes
node_attributes_df = (
    node_attributes_df
        .union(
            incident_node_attributes
                .withColumnRenamed("WorkspaceTenantId", "TenantId")
                # .join(
                #     tenant_ids_sdf,
                #     'TenantId'
                # )
                .select(
                    'TenantId',
                    f.concat_ws("__", "TenantId", f.lit("securityincident"), "IncidentName"),
                    'IncidentNodeAttributes'
                )
        )
        .union(
            alert_node_attributes
                .withColumnRenamed("WorkspaceTenantId", "TenantId")
                # .join(
                #     tenant_ids_sdf,
                #     'TenantId'
                # )
                .select(
                    'TenantId',
                    f.concat_ws("__", "TenantId", f.lit("securityalert"), "SystemAlertId"),
                    'AlertNodeAttributes'
                )
        )
        .union(
            entity_node_attributes
                # .join(
                #     tenant_ids_sdf,
                #     'TenantId'
                # )
                .select(
                    'TenantId',
                    f.concat_ws("__", "TenantId", "EntityType", "IdentifierValue"),
                    'EntityNodeAttributes'
                )
        )
)


StatementMeta(maxpiasevoliL, 154, 69, Finished, Available)

In [70]:
edges_df = edges_df.distinct().persist()
node_attributes_df = node_attributes_df.distinct().persist()

StatementMeta(maxpiasevoliL, 154, 70, Finished, Available)

# COOK Enrichment

In [71]:
cook_path_str = "abfss://skg@rdamlapeussa.dfs.core.windows.net/input_data/cook_ontology_9_5_23"

StatementMeta(maxpiasevoliL, 154, 71, Finished, Available)

In [72]:
cook_subgraph = {}

rdf = spark.read.format("delta").load(cook_path_str)

for row in rdf.collect():
    v = row.Value
    k = row.Key
    
    v_pickled = pickle.loads(v)
    cook_subgraph[k] = v_pickled

cook_ontology_graph = cook_subgraph["cc_subgraph"]

StatementMeta(maxpiasevoliL, 154, 72, Finished, Available)

In [73]:
def make_edges_df(G):
    edges = []
    for source, target in G.edges():
        node_dict = {'Source': source, 'Target': target}
        edges.append(node_dict)
    return pd.DataFrame(edges)

def make_node_attr_df(G):
    nodes = []
    for node, attribute in G.nodes(data=True):
        node_dict = {'NodeId': node, 'NodeAttributes': attribute}
        nodes.append(node_dict)
    return pd.DataFrame(nodes)

def make_edge_attr_df(G):
    edges = []
    for start_node, end_node, attribute in G.edges(data=True):
        edge_dict = {'Source': start_node, 'Target': end_node, 'EdgeAttributes': attribute}
        edges.append(edge_dict)
    return pd.DataFrame(edges)

StatementMeta(maxpiasevoliL, 154, 73, Finished, Available)

In [74]:
# Convert cook subgraph (networkx graph) => cook_edges_df, cook_node_attr_df, cook_edges_attr_df
cook_edges_pdf = make_edges_df(cook_ontology_graph)
cook_edges_sdf = spark.createDataFrame(
    data = cook_edges_pdf, 
    schema = StructType([
    StructField("Source", StringType(), nullable=False),
    StructField("Target", StringType(), nullable=False)])
)

cook_node_attr_pdf = make_node_attr_df(cook_ontology_graph)
cook_node_attr_pdf['NodeAttributes'] = cook_node_attr_pdf['NodeAttributes'].apply(lambda x: json.dumps(x))
cook_node_attr_sdf = spark.createDataFrame(
    data = cook_node_attr_pdf, 
    schema = StructType([
    StructField("NodeId", StringType(), nullable=False),
    StructField("NodeAttributes", StringType(), nullable=False)])
)

cook_edge_attr_pdf = make_edge_attr_df(cook_ontology_graph)
cook_edge_attr_pdf['EdgeAttributes'] = cook_edge_attr_pdf['EdgeAttributes'].apply(lambda x: json.dumps(x))
cook_edge_attr_sdf = spark.createDataFrame(
    data = cook_edge_attr_pdf, 
    schema = StructType([
    StructField("Source", StringType(), nullable=False),
    StructField("Target", StringType(), nullable=False),
    StructField("EdgeAttributes", StringType(), nullable=False)])
)

StatementMeta(maxpiasevoliL, 154, 74, Finished, Available)

In [75]:
# Enrich node_attributes_df with AlertDisplayName
# Define the schema of the JSON string
alert_properties_schema = StructType([
    StructField("node_type", StringType()),
    StructField("AlertDisplayName", StringType())
])

alert_node_attributes_df = (
    node_attributes_df
        .withColumn("node_type", f.from_json(f.col("NodeAttributes"), alert_properties_schema).getItem("node_type"))
        .filter(f.col("node_type") == "securityalert")
        .withColumn("AlertDisplayName", f.from_json(f.col("NodeAttributes"), alert_properties_schema).getItem("AlertDisplayName"))
        .persist()
)

alert_node_attr_pdf = alert_node_attributes_df.toPandas()

StatementMeta(maxpiasevoliL, 154, 75, Finished, Available)

In [76]:
# Merge alerts and cook detections df
cook_node_attr_pdf['displayName'] = cook_node_attr_pdf['NodeAttributes'].apply(lambda x: json.loads(x)['properties']['displayName'])
cook_node_attr_pdf['node_type'] = cook_node_attr_pdf['NodeAttributes'].apply(lambda x: json.loads(x)['node_type'])

cook_alerts_tenant_pdf = pd.merge(alert_node_attr_pdf, cook_node_attr_pdf, left_on='AlertDisplayName', right_on='displayName')

cook_alerts_tenant_sdf = spark.createDataFrame(cook_alerts_tenant_pdf)

cook_tenant_ids = cook_alerts_tenant_pdf['TenantId'].unique().tolist()
cook_tenant_ids_df = spark.createDataFrame(cook_tenant_ids, StringType()).toDF("TenantId")

print("Number of tenants joined to COOK ontology: ", len(cook_tenant_ids))

StatementMeta(maxpiasevoliL, 154, 76, Finished, Available)

Number of tenants joined to COOK ontology:  6340


In [77]:
# alert -> cook detection edges
edges_df = (
    edges_df
        .unionByName(
            cook_alerts_tenant_sdf
                .select(
                    f.lit("").alias("WorkspaceId"), 
                    "TenantId", 
                    col("NodeId_x").alias("Source"),
                    col("NodeId_y").alias("Target")
                )
        )
)

StatementMeta(maxpiasevoliL, 154, 77, Finished, Available)

In [78]:
# Create a Cartesian product of cook_tenant_ids_df and cook_edges_sdf
cook_edges_tenant_cartesian_df = cook_tenant_ids_df.crossJoin(cook_edges_sdf)
cook_edges_tenant_cartesian_df_selected = cook_edges_tenant_cartesian_df.select(
    f.lit("").alias("WorkspaceId"),
    "TenantId",
    "Source",
    "Target"
)

edges_df = edges_df.unionByName(cook_edges_tenant_cartesian_df_selected)

StatementMeta(maxpiasevoliL, 154, 78, Finished, Available)

In [79]:
# Create a Cartesian product of cook_tenant_ids_df and cook_node_attr_sdf
node_attr_tenant_cartesian_df = cook_tenant_ids_df.crossJoin(cook_node_attr_sdf)
node_attr_tenant_cartesian_df_selected = node_attr_tenant_cartesian_df.select(
    "TenantId",
    "NodeId",
    "NodeAttributes"
)

node_attributes_df = node_attributes_df.unionByName(node_attr_tenant_cartesian_df_selected)

StatementMeta(maxpiasevoliL, 154, 79, Finished, Available)

In [80]:
edge_attributes_df = spark.createDataFrame(
    data = spark.sparkContext.emptyRDD(), 
    schema = StructType([
    StructField("TenantId", StringType(), nullable=False),
    StructField("Source", StringType(), nullable=False),
    StructField("Target", StringType(), nullable=False),
    StructField("EdgeAttributes", StringType(), nullable=False)])
)

StatementMeta(maxpiasevoliL, 154, 80, Finished, Available)

In [81]:
# Create a Cartesian product of tenant_ids_df and cook_edge_attr_sdf
edges_attr_tenant_cartesian_df = cook_tenant_ids_df.crossJoin(cook_edge_attr_sdf)
edges_attr_tenant_cartesian_df_selected = edges_attr_tenant_cartesian_df.select(
    "TenantId",
    "Source",
    "Target",
    "EdgeAttributes"
)

edge_attributes_df = edge_attributes_df.unionByName(edges_attr_tenant_cartesian_df_selected)

StatementMeta(maxpiasevoliL, 154, 81, Finished, Available)

# Write Tenant Graph DataFrames

In [82]:
edges_df = edges_df.distinct().withColumn("ProcessedTimestamp", f.lit(start_processing_timestamp)).persist()
node_attributes_df = node_attributes_df.distinct().withColumn("ProcessedTimestamp", f.lit(start_processing_timestamp)).persist()
edge_attributes_df = edge_attributes_df.distinct().withColumn("ProcessedTimestamp", f.lit(start_processing_timestamp)).persist()

StatementMeta(maxpiasevoliL, 154, 82, Finished, Available)

In [83]:
edges_df.write.mode('overwrite').parquet(output_path + '/edges')

StatementMeta(maxpiasevoliL, 154, 83, Finished, Available)

In [84]:
node_attributes_df.write.mode('overwrite').parquet(output_path + '/node_attributes')

StatementMeta(maxpiasevoliL, 154, 84, Finished, Available)

In [85]:
edge_attributes_df.write.mode('overwrite').parquet(output_path + '/edge_attributes')

StatementMeta(maxpiasevoliL, 154, 85, Finished, Available)