In [None]:
# Use the shell escape command to install the package
!pip install graphframes
!pip install python-dotenv

In [None]:
import datetime
import os

import requests
from dotenv import load_dotenv
from pyspark.ml import Pipeline
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import udf, explode, col, lit,  date_format, to_date, hour, minute, when, sum
from pyspark.sql.types import ArrayType, StructType, StructField, StringType, BooleanType

# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Elasticsearch connection settings, read from the environment so credentials stay
# out of the notebook. The request code in this notebook prepends "https://" to
# ELASTICSEARCH, so it should be a bare host[:port].
es_url = os.getenv("ELASTICSEARCH")
es_username = os.getenv("ELASTICSEARCH_USERNAME")
es_password = os.getenv("ELASTICSEARCH_PASSWORD")

# Fail fast with a readable message. Otherwise a missing value only surfaces later,
# e.g. as a TypeError from "https://" + None inside a Spark task.
missing_settings = [
    name
    for name, value in (
        ("ELASTICSEARCH", es_url),
        ("ELASTICSEARCH_USERNAME", es_username),
        ("ELASTICSEARCH_PASSWORD", es_password),
    )
    if not value
]
if missing_settings:
    raise RuntimeError(
        "Missing required setting(s): " + ", ".join(missing_settings)
        + ". Define them in a .env file or in the environment."
    )

# Static settings such as spark.jars.packages and spark.driver.memory only take effect
# when this call launches the JVM. If a SparkSession already exists in the kernel,
# getOrCreate() reuses it and those settings are not applied; restart the kernel to
# change them.
spark = (
    SparkSession.builder.appName("REST_API_with_PySpark_DF")
    .config("spark.jars.packages", "graphframes:graphframes:0.8.3-spark3.5-s_2.12")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "4g")
    .getOrCreate()
)

# Fields pulled from each Elasticsearch hit. The same names are used as the `_source`
# filter sent to Elasticsearch and as the struct field names of the fetched rows.
# Every field is declared as a nullable string.
schema = StructType(
    [
        StructField(field_name, StringType(), True)
        for field_name in ("@timestamp", "source.address", "destination.address", "network.bytes")
    ]
)

# Look-back window, in days, for the Elasticsearch time filter.
days_to_fetch = 1

# End of the query window, fixed once on the driver when this cell runs. Every
# pagination task (and every re-evaluation of the lineage) then queries the same
# time range. Without this, each task calls now() at a different moment, and the
# from/size offsets slide over a moving result set.
query_window_end = datetime.datetime.now(datetime.timezone.utc)


@udf(returnType=ArrayType(schema))
def fetch_data(offset: int, limit: int, days: int):
    """Fetch one page of Suricata flow events from Elasticsearch.

    Runs as a Spark UDF, once per row of the pagination DataFrame.

    Args:
        offset: Index of the first hit to return (Elasticsearch ``from``).
        limit: Maximum number of hits to return (Elasticsearch ``size``).
        days: Length of the look-back window in days, ending at ``query_window_end``.

    Returns:
        A list of dicts keyed by the field names in ``schema``. The list is empty
        when the response JSON has no ``hits.hits`` (this includes error responses
        that carry a JSON body).
    """
    endpoint = "https://" + es_url + "/filebeat-*/_search"

    # Timezone-aware ISO-8601 timestamps, so the bounds are explicitly UTC instead
    # of a naive local time.
    start_date = (query_window_end - datetime.timedelta(days=days)).isoformat()
    end_date = query_window_end.isoformat()

    es_query = {
        "size": limit,
        "from": offset,
        "_source": [field.name for field in schema.fields],  # only the fields we keep
        # An explicit sort gives from/size paging a defined order; otherwise hits are
        # ordered by relevance score, which carries no meaning for these filters.
        "sort": [{"@timestamp": {"order": "asc"}}],
        "query": {
            "bool": {
                # Filter context: same matching as `must`, without computing scores.
                "filter": [
                    {"term": {"suricata.eve.event_type": "flow"}},
                    {"range": {"@timestamp": {"gte": start_date, "lte": end_date}}},
                ]
            }
        },
    }

    response = requests.get(
        endpoint,
        json=es_query,
        headers={"Content-Type": "application/json"},
        # NOTE(review): TLS certificate verification is disabled; prefer passing a CA
        # bundle path if the cluster uses a private CA.
        verify=False,
        auth=(es_username, es_password),
        timeout=60,  # a stuck connection must not hang the executor task forever
    )

    hits = response.json().get("hits", {}).get("hits", [])

    # Flatten the nested _source documents into the flat, dotted field names of `schema`.
    records = []
    for hit in hits:
        source = hit.get("_source", {})
        records.append({
            "@timestamp": source.get("@timestamp"),
            "source.address": source.get("source", {}).get("address"),
            "destination.address": source.get("destination", {}).get("address"),
            "network.bytes": source.get("network", {}).get("bytes"),
        })
    return records

# Count only the documents fetch_data pages through: Suricata flow events in the
# look-back window. An unfiltered _count over filebeat-* also counts every other
# event type and all older data, which would generate mostly useless pages.
count_start = (
    datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=days_to_fetch)
).isoformat()
count_query = {
    "query": {
        "bool": {
            "filter": [
                {"term": {"suricata.eve.event_type": "flow"}},
                {"range": {"@timestamp": {"gte": count_start}}},
            ]
        }
    }
}
count_response = requests.get(
    "https://" + es_url + "/filebeat-*/_count",
    json=count_query,
    verify=False,  # NOTE(review): TLS certificate verification is disabled
    auth=(es_username, es_password),
    timeout=60,
)
# Fail loudly on the driver; a silently defaulted count of 0 would only yield an
# empty DataFrame with no hint of what went wrong.
count_response.raise_for_status()
total_records = count_response.json().get("count", 0)

records_per_page = 500

# from + size paging cannot go past the index setting index.max_result_window
# (10,000 unless it was changed on the cluster); Elasticsearch rejects deeper pages.
# Adjust this constant if the setting was raised.
max_result_window = 10_000
if total_records > max_result_window:
    print(
        f"Warning: {total_records:,} matching documents, but from/size paging only "
        f"reaches the first {max_result_window:,}; the remainder is skipped."
    )
    total_records = max_result_window

# One row per page: (offset, limit).
offsets_df = spark.range(0, total_records, records_per_page).select(
    col("id").alias("offset"), lit(records_per_page).alias("limit")
)

# Fetch every page through the UDF and explode the per-page lists into one row per hit.
response_df = offsets_df.withColumn(
    "response", explode(fetch_data("offset", "limit", lit(days_to_fetch)))
)

# Flatten the fetched structs into top-level columns, selected by field name.
# cache(): later cells run several actions on this lineage, and without it each
# action would re-issue every HTTP request to Elasticsearch.
result_df = response_df.select(
    *[col("response")[field.name].alias(field.name) for field in schema.fields]
).cache()

result_df.show(truncate=False)

In [None]:
import ipaddress

def is_private_ip(ip):
    """Return True if ``ip`` is an internal (non-public) IP address.

    Treated as internal:
      * anything ``ipaddress`` flags as ``is_private`` (IPv4 or IPv6);
      * any IPv4 address that is not ``is_global`` (e.g. 100.64.0.0/10 shared space);
      * IPv6 link-local addresses.

    Anything that cannot be parsed as an IP address (including ``None``) yields False.
    """
    try:
        address = ipaddress.ip_address(ip)
    except ValueError:
        return False  # not parseable as an IPv4/IPv6 address

    if address.is_private:
        return True
    if address.version == 4:
        return not address.is_global
    # Only IPv6 remains at this point.
    return address.is_link_local

# Wrap the Python predicate so Spark can evaluate it row by row.
is_private_ip_udf = udf(is_private_ip, BooleanType())

# Backticks make Spark treat the dotted names as literal column names rather than
# struct-field access.
src_addr = col("`source.address`")
dst_addr = col("`destination.address`")

# Multicast groups are not private, so they must be removed BEFORE non-private
# addresses are collapsed into "Public". Otherwise they already read "Public" and
# these patterns can never match.
multicast_patterns = [
    r"^2(2[4-9]|3[0-9])\.",   # IPv4 multicast, 224.0.0.0/4
    r"(?i)^ff[0-9a-f]{2}:",   # IPv6 multicast, ff00::/8
]

# Applied to the relabeled addresses, i.e. only to addresses that stayed private.
private_non_endpoint_patterns = [
    r"^255\.255\.255\.255$",  # IPv4 limited broadcast
    r"^0\.0\.0\.0$",          # IPv4 unspecified address
    r"\.255$",                # heuristic: x.y.z.255, broadcast address of a /24
]

# A flow without an address is missing data, not public traffic; drop it instead of
# letting it be labeled "Public".
result_df = result_df.filter(src_addr.isNotNull() & dst_addr.isNotNull())

for pattern in multicast_patterns:
    result_df = result_df.filter(~src_addr.rlike(pattern) & ~dst_addr.rlike(pattern))

# Keep internal addresses as-is; collapse every other address into a single "Public" node.
result_df = (
    result_df
    .withColumn("source.address", when(is_private_ip_udf(src_addr), src_addr).otherwise(lit("Public")))
    .withColumn("destination.address", when(is_private_ip_udf(dst_addr), dst_addr).otherwise(lit("Public")))
)

for pattern in private_non_endpoint_patterns:
    result_df = result_df.filter(~src_addr.rlike(pattern) & ~dst_addr.rlike(pattern))

# Drop self-loops (this includes Public -> Public after relabeling). Cast the byte
# count explicitly instead of relying on implicit string coercion, then drop empty flows.
result_df = (
    result_df
    .filter(src_addr != dst_addr)
    .withColumn("network.bytes", col("`network.bytes`").cast("long"))
    .filter(col("`network.bytes`") != 0)
)

# Total bytes per (source, destination) pair.
aggregated_df = (
    result_df
    .groupBy("`source.address`", "`destination.address`")
    .agg(F.sum("`network.bytes`").alias("network.bytes"))
)

aggregated_df.count()

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from graphframes import GraphFrame

# Edges: one per (source, destination) pair from the aggregation, weighted by the
# summed network bytes.
edges = aggregated_df.select(
    col("`source.address`").alias("src"),
    col("`destination.address`").alias("dst"),
    col("`network.bytes`").alias("weight"),
)

# Vertices: every address seen at either end of an edge. An address that is both a
# source and a destination appears in both per-side lists, so the union is
# deduplicated again; GraphFrames expects each vertex id to be unique.
source_vertices = aggregated_df.select(col("`source.address`").alias("id")).distinct()
destination_vertices = aggregated_df.select(col("`destination.address`").alias("id")).distinct()
vertices = source_vertices.union(destination_vertices).distinct()

graph = GraphFrame(vertices, edges)

# Same graph with the byte totals as double-typed edge weights.
weighted_edges = graph.edges.withColumn("weight", col("weight").cast("double"))
weighted_graph = GraphFrame(graph.vertices, weighted_edges)


In [None]:
%matplotlib widget
import math
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.layout import spring_layout

# Convert vertices DataFrame to a list of tuples
vertices_df = [(row['id']) for row in vertices.collect()]

# Convert edges DataFrame to a list of tuples
edges_df = [(row['src'], row['dst'], {'weight': float(row['weight'])}) for row in edges.collect()]
print(len(edges_df))

# Create a networkx graph
nx_graph = nx.DiGraph()

# Add vertices
nx_graph.add_nodes_from(vertices_df)

# Add weighted edges
nx_graph.add_edges_from(edges_df)

# Compute layout for better visualization
#  layout = nx.fruchterman_reingold_layout(nx_graph)
layout = nx.spring_layout(nx_graph)

# Calculate edge colors based on weights
edge_weights = [d['weight'] for (u, v, d) in nx_graph.edges(data=True)]
print(len(edge_weights))

# Log adjustment
log_edge_weights = [math.log(weight + 1) for weight in edge_weights]

# Calculate min and max weights after the logarithmic transformation
min_log_weight = min(log_edge_weights)
max_log_weight = max(log_edge_weights)

# Normalize the logarithmic weights
normalized_log_weights = [(weight - min_log_weight) / (max_log_weight - min_log_weight) for weight in log_edge_weights]

# Rescale the logarithmic weights to a wider range
min_rescaled_weight = min(normalized_log_weights)
max_rescaled_weight = max(normalized_log_weights)
rescaled_weights = [(weight - min_rescaled_weight) / (max_rescaled_weight - min_rescaled_weight) for weight in normalized_log_weights]

# Create a color map from white to blue with increasing intensity
edge_colors = [(1 - weight, 1 - weight, 1) for weight in rescaled_weights]

# Plot the networkx graph
plt.figure(figsize=(10, 6))
nx.draw(nx_graph, pos=layout, with_labels=True, node_color='skyblue', connectionstyle=f"arc3,rad=.5", node_size=300, font_size=8, font_weight='bold', edge_color=edge_colors, width=2, alpha=0.8)

# Convert edge weights from bytes to megabytes
#edge_labels = nx.get_edge_attributes(nx_graph, 'weight')
#edge_labels_mb = {edge: f"{weight / (1024 * 1024):,.0f} MB" for edge, weight in edge_labels.items()}

# Draw edge labels
# nx.draw_networkx_edge_labels(nx_graph, pos=layout, edge_labels=edge_labels_mb, font_size=6, label_pos=0.5, font_color='red')

plt.title('Weighted Graph Visualization')
plt.show()
