# Threat Detection Model Training & Registry

**Purpose:** Train a GradientBoosting classifier to predict high-threat security events and register it to Snowflake Model Registry.

**Environment:** Snowflake Notebook on Container Runtime (CPU compute pool)

**Key Steps:**
1. Load security events data via Snowpark
2. Feature engineering using Snowpark DataFrames (push compute to Snowflake)
3. Train a scikit-learn `GradientBoostingClassifier` on the notebook's container compute (the training split is pulled into pandas)
4. Register model to Snowflake Model Registry with metrics and metadata
5. Test inference and create batch scoring pipeline

In [None]:
# Cell 1: setup_imports
import warnings
warnings.filterwarnings('ignore')

from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import functions as F
from snowflake.snowpark.types import IntegerType, FloatType, StringType
from snowflake.ml.registry import Registry

# Attach to the Snowpark session that the notebook runtime already provides.
session = get_active_session()
print("[Cell 1] Connected to Snowflake")

# Echo the session's current account, database and schema, in that order.
for label, getter in (
    ("Account", session.get_current_account),
    ("Database", session.get_current_database),
    ("Schema", session.get_current_schema),
):
    print(f"{label}: {getter()}")

In [None]:
# set_execution_context
# Make SONY_SECURITY_DEMO.SOC the session's current database/schema so
# unqualified object names resolve there.
for use_stmt in ("USE DATABASE SONY_SECURITY_DEMO", "USE SCHEMA SOC"):
    session.sql(use_stmt).collect()

current_db = session.get_current_database()
current_schema = session.get_current_schema()
print(f"Context set: {current_db}.{current_schema}")

## 1. Load Security Events Data

Load security events from Snowflake using Snowpark and derive temporal features (event hour and day of week) from the event timestamp. Text indicators from analyst notes are added in the feature-engineering step.

In [None]:
# load_security_events
# Columns taken as-is from the source table; two temporal features are then
# derived from TIMESTAMP with Snowpark SQL functions.
RAW_EVENT_COLS = [
    "EVENT_ID",
    "TIMESTAMP",
    "EVENT_TYPE",
    "SEVERITY",
    "ACTION_TAKEN",
    "DATA_SOURCE",
    "THREAT_CATEGORY",
    "IS_ESCALATED",
    "ANALYST_NOTES",
    "THREAT_SCORE",
]
event_ts = F.col("TIMESTAMP")

events_df = session.table("SONY_SECURITY_DEMO.SOC.SECURITY_EVENTS").select(
    *[F.col(name) for name in RAW_EVENT_COLS],
    F.hour(event_ts).alias("EVENT_HOUR"),
    F.dayofweek(event_ts).alias("DAY_OF_WEEK"),
)

print(f"Loaded {events_df.count()} security events")
events_df.limit(5).show()

## 2. Feature Engineering

Create numeric features from the categorical columns and keyword flags from the analyst notes. String comparisons are written with `F.sql_expr()` so literals are quoted correctly in the generated SQL, and all computation is pushed down to Snowflake.

In [None]:
# Cell 6: feature_engineering_severity
# Encode SEVERITY, EVENT_TYPE, and THREAT_CATEGORY as ordinal risk scores
# (SEVERITY_ENCODED spans 1-4; EVENT_TYPE_RISK and THREAT_CAT_RISK span 1-5).
# Each rule list becomes a CASE WHEN chain evaluated top-down (first match
# wins); values matching no rule - including NULLs - get the fallback score.
# NOTE: Uses F.sql_expr() to ensure proper SQL string literal quoting

categorical_score_specs = {
    "SEVERITY_ENCODED": (
        [
            ("SEVERITY = 'Critical'", 4),
            ("SEVERITY = 'High'", 3),
            ("SEVERITY = 'Medium'", 2),
        ],
        1,
    ),
    "EVENT_TYPE_RISK": (
        [
            ("EVENT_TYPE IN ('C2 communication', 'Lateral movement', 'Data exfiltration attempt')", 5),
            ("EVENT_TYPE IN ('Privilege escalation', 'Malware detected', 'Credential theft')", 4),
            ("EVENT_TYPE IN ('Suspicious process', 'Unusual access pattern')", 3),
            ("EVENT_TYPE IN ('Authentication failure', 'Network scan')", 2),
        ],
        1,
    ),
    "THREAT_CAT_RISK": (
        [
            ("THREAT_CATEGORY IN ('APT', 'Ransomware')", 5),
            ("THREAT_CATEGORY IN ('Data theft', 'Credential theft')", 4),
            ("THREAT_CATEGORY IN ('Malware', 'Insider threat')", 3),
            ("THREAT_CATEGORY IN ('Reconnaissance', 'Phishing')", 2),
        ],
        1,
    ),
}

features_df = events_df
for new_col, (rules, fallback) in categorical_score_specs.items():
    (head_cond, head_score), *tail_rules = rules
    case_expr = F.when(F.sql_expr(head_cond), head_score)
    for cond_sql, score in tail_rules:
        case_expr = case_expr.when(F.sql_expr(cond_sql), score)
    features_df = features_df.with_column(new_col, case_expr.otherwise(fallback))

print("[Cell 6] Added: SEVERITY_ENCODED, EVENT_TYPE_RISK, THREAT_CAT_RISK")

In [None]:
# Cell 7: feature_engineering_source_action
# Encode DATA_SOURCE and ACTION_TAKEN as ordinal scores (rules evaluated
# top-down, first match wins) and add two 0/1 flags: weekend and escalated.
# NOTE: Uses F.sql_expr() for string comparisons

source_action_specs = [
    (
        "SOURCE_RELIABILITY",
        [
            ("DATA_SOURCE IN ('EDR', 'SIEM')", 5),
            ("DATA_SOURCE IN ('Firewall', 'IAM')", 4),
            ("DATA_SOURCE = 'Cloud_Audit'", 3),
        ],
        2,  # any other (or NULL) data source
    ),
    (
        "ACTION_RISK",
        [
            ("ACTION_TAKEN = 'Allowed'", 4),
            ("ACTION_TAKEN = 'Alerted'", 3),
            ("ACTION_TAKEN = 'Investigated'", 2),
            ("ACTION_TAKEN = 'Quarantined'", 1),
        ],
        0,  # any other (or NULL) action
    ),
]

for target_col, branches, default_value in source_action_specs:
    tiered_expr = None
    for branch_sql, branch_value in branches:
        branch_cond = F.sql_expr(branch_sql)
        if tiered_expr is None:
            tiered_expr = F.when(branch_cond, branch_value)
        else:
            tiered_expr = tiered_expr.when(branch_cond, branch_value)
    features_df = features_df.with_column(target_col, tiered_expr.otherwise(default_value))

# Weekend flag. Assumes DAYOFWEEK numbers Sunday=0 ... Saturday=6 (Snowflake's
# default WEEK_START = 0) - TODO confirm if the account overrides WEEK_START.
weekend_flag = F.when(F.sql_expr("DAY_OF_WEEK IN (0, 6)"), 1).otherwise(0)
features_df = features_df.with_column("IS_WEEKEND", weekend_flag)

# Snowpark overloads == to build a SQL comparison, so `is True` cannot be used
# here. A NULL IS_ESCALATED makes the comparison NULL and lands in otherwise(0).
escalated_flag = F.when(F.col("IS_ESCALATED") == True, 1).otherwise(0)  # noqa: E712
features_df = features_df.with_column("IS_ESCALATED_INT", escalated_flag)

print("[Cell 7] Added: SOURCE_RELIABILITY, ACTION_RISK, IS_WEEKEND, IS_ESCALATED_INT")

In [None]:
# Cell 8: feature_engineering_text_indicators
# 0/1 keyword flags derived from the free-text ANALYST_NOTES column with SQL LIKE.
# NOTE: Snowflake LIKE is case-sensitive, so '%ransomware%' does not match
# "Ransomware". Anything that recomputes these flags at scoring time must use
# the exact same patterns AND the same case sensitivity, or the model sees
# inputs it was not trained on.
# A NULL note makes each predicate NULL, which falls through to otherwise(0).

note_flag_predicates = {
    "HAS_APT_MENTION": "ANALYST_NOTES LIKE '%APT%' OR ANALYST_NOTES LIKE '%advanced persistent%'",
    "HAS_RANSOMWARE_MENTION": "ANALYST_NOTES LIKE '%ransomware%' OR ANALYST_NOTES LIKE '%encryption%'",
    "HAS_EXFIL_MENTION": "ANALYST_NOTES LIKE '%exfiltration%' OR ANALYST_NOTES LIKE '%data theft%'",
    # presumably meant to catch MITRE ATT&CK technique IDs such as T1059
    "HAS_MITRE_REF": "ANALYST_NOTES LIKE '%T1%'",
    "HAS_ESCALATION_MENTION": "ANALYST_NOTES LIKE '%escalated%' OR ANALYST_NOTES LIKE '%Tier 2%'",
}

for flag_name, predicate_sql in note_flag_predicates.items():
    features_df = features_df.with_column(
        flag_name,
        F.when(F.sql_expr(predicate_sql), 1).otherwise(0),
    )

print("[Cell 8] Added: HAS_APT_MENTION, HAS_RANSOMWARE_MENTION, HAS_EXFIL_MENTION, HAS_MITRE_REF, HAS_ESCALATION_MENTION")

In [None]:
# create_target_label
# Binary training target: 1 when THREAT_SCORE >= 0.6, otherwise 0.
# NOTE(review): a NULL THREAT_SCORE also yields 0 (and so survives the later
# dropna) - confirm that unscored events should count as "not high threat".
HIGH_THREAT_CUTOFF = 0.6
high_threat_label = F.when(F.col("THREAT_SCORE") >= HIGH_THREAT_CUTOFF, 1).otherwise(0)
features_df = features_df.with_column("IS_HIGH_THREAT", high_threat_label)

# Class balance check (lazy; executed by show()).
target_dist = features_df.group_by("IS_HIGH_THREAT").count()
print("Target distribution (IS_HIGH_THREAT):")
target_dist.show()

## 3. Prepare Training Data

Select the feature columns, drop rows containing NULLs, and split into train/test sets (80/20, fixed seed) with Snowpark's `DataFrame.random_split`, so the split itself runs in Snowflake.

In [None]:
# Cell 11: prepare_training_data
# Model inputs, in the exact column order the model is trained on. Training,
# evaluation, the inference smoke test and batch scoring all index by this list.
FEATURE_COLS = [
    # categorical risk encodings
    "SEVERITY_ENCODED",
    "EVENT_TYPE_RISK",
    "THREAT_CAT_RISK",
    "SOURCE_RELIABILITY",
    "ACTION_RISK",
    "IS_ESCALATED_INT",
    # temporal features
    "EVENT_HOUR",
    "DAY_OF_WEEK",
    "IS_WEEKEND",
    # analyst-note keyword flags
    "HAS_APT_MENTION",
    "HAS_RANSOMWARE_MENTION",
    "HAS_EXFIL_MENTION",
    "HAS_MITRE_REF",
    "HAS_ESCALATION_MENTION",
]

# Kept as a one-element list so it concatenates directly with FEATURE_COLS.
LABEL_COL = ["IS_HIGH_THREAT"]

# dropna() discards any row with a NULL in a feature or the label, which would
# otherwise break training.
ml_df = features_df.select([*FEATURE_COLS, *LABEL_COL]).dropna()

# 80/20 split, seeded for reproducibility.
train_df, test_df = ml_df.random_split([0.8, 0.2], seed=42)

for split_name, split_df in (("Training", train_df), ("Test", test_df)):
    print(f"[Cell 11] {split_name} samples: {split_df.count()}")

## 4. Train Model

Train a scikit-learn GradientBoosting classifier. The training split is converted to a pandas DataFrame and the model is fit locally in the notebook's container runtime rather than inside Snowflake.

In [None]:
# Cell 13: train_gradient_boosting_model
# Fit a scikit-learn GradientBoostingClassifier in the notebook process on the
# training split, which is first pulled from Snowflake into pandas.
# NOTE: Converting to pandas first to avoid UDF permission issues with source table

from sklearn.ensemble import GradientBoostingClassifier as SklearnGBC

GBC_PARAMS = dict(
    n_estimators=100,
    max_depth=5,
    learning_rate=0.1,
    random_state=42,
)

train_pd = train_df.to_pandas()
X_train, y_train = train_pd[FEATURE_COLS], train_pd[LABEL_COL[0]]

print(f"[Cell 13] Training GradientBoosting model on {len(X_train)} samples...")

# fit() returns the estimator itself, so construct-and-fit can be chained.
model = SklearnGBC(**GBC_PARAMS).fit(X_train, y_train)
print("[Cell 13] Training complete!")

## 5. Evaluate Model

Generate predictions on test data and calculate classification metrics.

In [None]:
# Cell 15: evaluate_model_predictions
# Score the held-out split locally and compute binary-classification metrics
# (sklearn's default positive label, 1 = high threat).

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

test_pd = test_df.to_pandas()
X_test, y_test = test_pd[FEATURE_COLS], test_pd[LABEL_COL[0]]

y_pred = model.predict(X_test)

accuracy, precision, recall, f1 = (
    scorer(y_test, y_pred)
    for scorer in (accuracy_score, precision_score, recall_score, f1_score)
)

rule_line = "=" * 50
print(rule_line)
print("[Cell 15] MODEL EVALUATION RESULTS")
print(rule_line)
for metric_label, metric_value in (
    ("Accuracy:  ", accuracy),
    ("Precision: ", precision),
    ("Recall:    ", recall),
    ("F1 Score:  ", f1),
):
    print(f"{metric_label}{metric_value:.4f}")
print(rule_line)

## 6. Register Model to Snowflake Model Registry

Register the trained model with comprehensive metadata including:
- Sample input data for schema inference
- Performance metrics (accuracy, precision, recall, F1)
- Descriptive comment with use case and training context

In [None]:
# Cell 17: register_model_to_registry
# Log the locally trained scikit-learn model to the Snowflake Model Registry.
# sample_input lets the registry infer the input signature; the held-out
# metrics from the evaluation step are stored alongside the version.
# NOTE(review): version_name is fixed, so re-running this cell presumably fails
# once v1_0_0 already exists - bump the version (or drop the old one) first.

reg = Registry(
    session=session,
    database_name="SONY_SECURITY_DEMO",
    schema_name="SOC"
)

sample_input = train_pd[FEATURE_COLS].head(10)

# Describe the data actually used instead of a hard-coded event count, which
# silently goes stale whenever SECURITY_EVENTS changes. These are the rows that
# survived dropna() and were split into train/test.
n_train_events = len(train_pd)
n_test_events = len(test_pd)
n_labeled_events = n_train_events + n_test_events

model_version = reg.log_model(
    model,
    model_name="THREAT_DETECTION_MODEL",
    version_name="v1_0_0",
    sample_input_data=sample_input,
    comment=f"""
    GradientBoosting threat detection model for Sony Interactive SOC.

    Purpose: Predicts high-threat security events (IS_HIGH_THREAT) based on:
    - Severity and event type risk scores
    - Threat category and data source reliability
    - Temporal features (hour, day of week, weekend)
    - Text indicators from analyst notes (APT, ransomware, exfiltration mentions)

    Training Data: SONY_SECURITY_DEMO.SOC.SECURITY_EVENTS ({n_labeled_events} labeled events: {n_train_events} train / {n_test_events} test)
    Target: Events with THREAT_SCORE >= 0.6 labeled as high-threat
    Use Case: Batch predictions for SOC prioritization and alerting
    """,
    metrics={
        # float() turns numpy scalars into plain Python floats for JSON metadata.
        "accuracy": round(float(accuracy), 4),
        "precision": round(float(precision), 4),
        "recall": round(float(recall), 4),
        "f1_score": round(float(f1), 4)
    }
)

print("\n" + "=" * 50)
print("[Cell 17] MODEL REGISTERED TO SNOWFLAKE")
print("=" * 50)
print(f"Model Name: {model_version.model_name}")
print(f"Version: {model_version.version_name}")
print("Registry: SONY_SECURITY_DEMO.SOC")
print("=" * 50)

## 7. Verify Registration & Test Inference

Confirm the model is registered and test inference with a sample high-risk event.

In [None]:
# verify_model_registration
# List every model in the registry schema, then the versions of this model.
print("Registered Models in SONY_SECURITY_DEMO.SOC:")
models = reg.show_models()
print(models)

print("\nModel versions for THREAT_DETECTION_MODEL:")
# get_model() returns the model handle (not a specific version).
threat_model = reg.get_model("THREAT_DETECTION_MODEL")
print(threat_model.show_versions())

In [None]:
# test_model_inference
# Smoke-test the registered version with one hand-built, deliberately
# high-risk feature vector.
print("Testing inference with a high-risk event scenario...")

# Values keyed by feature name; the row is emitted in FEATURE_COLS order so it
# lines up with the schema the model was trained on.
high_risk_event = {
    "SEVERITY_ENCODED": 4,        # Critical
    "EVENT_TYPE_RISK": 5,         # C2/Lateral movement
    "THREAT_CAT_RISK": 5,         # APT
    "SOURCE_RELIABILITY": 5,      # EDR
    "ACTION_RISK": 4,             # Allowed
    "IS_ESCALATED_INT": 1,
    "EVENT_HOUR": 14,
    "DAY_OF_WEEK": 2,
    "IS_WEEKEND": 0,
    "HAS_APT_MENTION": 1,
    "HAS_RANSOMWARE_MENTION": 0,
    "HAS_EXFIL_MENTION": 1,
    "HAS_MITRE_REF": 1,
    "HAS_ESCALATION_MENTION": 1,
}
test_event = session.create_dataframe(
    [[high_risk_event[feature] for feature in FEATURE_COLS]],
    schema=FEATURE_COLS,
)

result = model_version.run(test_event, function_name="predict")
print("\nInference result (expected: high-risk prediction):")
result.show()

## 8. Create Feature View for Batch Scoring

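Create a view that computes the model features on the fly for all security events, enabling batch scoring with the registered model. The view's encodings must exactly match those used to build the training features; any difference (for example, case-insensitive instead of case-sensitive text matching) changes the model's inputs at scoring time.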

In [None]:
# create_feature_view
# (Re)create a view that recomputes the model's input features in SQL for
# every row of SECURITY_EVENTS, together with context columns, for batch scoring.
# Every CASE below must mirror the Snowpark feature engineering used to build
# the training set. In particular the text indicators use case-sensitive LIKE,
# exactly as in training. ILIKE here would flag notes that were never flagged
# during training (e.g. '%APT%' would also match "capture" or "adapter"),
# silently shifting the feature distribution between training and scoring.
create_view_sql = """
CREATE OR REPLACE VIEW SONY_SECURITY_DEMO.SOC.THREAT_MODEL_FEATURES AS
SELECT 
    EVENT_ID,
    TIMESTAMP,
    EVENT_TYPE,
    SEVERITY,
    THREAT_CATEGORY,
    DATA_SOURCE,
    ACTION_TAKEN,
    IS_ESCALATED,
    USER_ID,
    HOSTNAME,
    ANALYST_NOTES,
    THREAT_SCORE AS RULE_BASED_SCORE,
    
    -- Encoded features for model input
    CASE SEVERITY 
        WHEN 'Critical' THEN 4 WHEN 'High' THEN 3 
        WHEN 'Medium' THEN 2 ELSE 1 
    END AS SEVERITY_ENCODED,
    
    CASE EVENT_TYPE 
        WHEN 'C2 communication' THEN 5 WHEN 'Lateral movement' THEN 5 
        WHEN 'Data exfiltration attempt' THEN 5 WHEN 'Privilege escalation' THEN 4 
        WHEN 'Malware detected' THEN 4 WHEN 'Credential theft' THEN 4
        WHEN 'Suspicious process' THEN 3 WHEN 'Unusual access pattern' THEN 3
        WHEN 'Authentication failure' THEN 2 WHEN 'Network scan' THEN 2 
        ELSE 1 
    END AS EVENT_TYPE_RISK,
    
    CASE THREAT_CATEGORY
        WHEN 'APT' THEN 5 WHEN 'Ransomware' THEN 5 
        WHEN 'Data theft' THEN 4 WHEN 'Credential theft' THEN 4
        WHEN 'Malware' THEN 3 WHEN 'Insider threat' THEN 3
        WHEN 'Reconnaissance' THEN 2 WHEN 'Phishing' THEN 2 
        ELSE 1
    END AS THREAT_CAT_RISK,
    
    CASE DATA_SOURCE 
        WHEN 'EDR' THEN 5 WHEN 'SIEM' THEN 5 
        WHEN 'Firewall' THEN 4 WHEN 'IAM' THEN 4 
        WHEN 'Cloud_Audit' THEN 3 ELSE 2 
    END AS SOURCE_RELIABILITY,
    
    CASE ACTION_TAKEN 
        WHEN 'Allowed' THEN 4 WHEN 'Alerted' THEN 3 
        WHEN 'Investigated' THEN 2 WHEN 'Quarantined' THEN 1 
        ELSE 0 
    END AS ACTION_RISK,
    
    CASE WHEN IS_ESCALATED THEN 1 ELSE 0 END AS IS_ESCALATED_INT,
    HOUR(TIMESTAMP) AS EVENT_HOUR,
    DAYOFWEEK(TIMESTAMP) AS DAY_OF_WEEK,
    CASE WHEN DAYOFWEEK(TIMESTAMP) IN (0, 6) THEN 1 ELSE 0 END AS IS_WEEKEND,
    
    -- Text indicators: case-sensitive LIKE, identical to the training features
    CASE WHEN ANALYST_NOTES LIKE '%APT%' 
              OR ANALYST_NOTES LIKE '%advanced persistent%' THEN 1 ELSE 0 END AS HAS_APT_MENTION,
    CASE WHEN ANALYST_NOTES LIKE '%ransomware%' 
              OR ANALYST_NOTES LIKE '%encryption%' THEN 1 ELSE 0 END AS HAS_RANSOMWARE_MENTION,
    CASE WHEN ANALYST_NOTES LIKE '%exfiltration%' 
              OR ANALYST_NOTES LIKE '%data theft%' THEN 1 ELSE 0 END AS HAS_EXFIL_MENTION,
    CASE WHEN ANALYST_NOTES LIKE '%T1%' THEN 1 ELSE 0 END AS HAS_MITRE_REF,
    CASE WHEN ANALYST_NOTES LIKE '%escalated%' 
              OR ANALYST_NOTES LIKE '%Tier 2%' THEN 1 ELSE 0 END AS HAS_ESCALATION_MENTION

FROM SONY_SECURITY_DEMO.SOC.SECURITY_EVENTS
"""

session.sql(create_view_sql).collect()
print("Created feature view: SONY_SECURITY_DEMO.SOC.THREAT_MODEL_FEATURES")

## 9. Score All Events & Create Results Table

Use the registered model to score all security events and persist results for SOC operations.

In [None]:
# batch_score_all_events
# Score every event exposed by the THREAT_MODEL_FEATURES view with the
# registered model version and stamp the results with a scoring time.
print("Scoring all security events with registered model...")

feature_view_df = session.table("SONY_SECURITY_DEMO.SOC.THREAT_MODEL_FEATURES")

# Context columns carried through to the results, followed by the model
# inputs in the same order used for training.
features_for_scoring = feature_view_df.select(
    F.col("EVENT_ID"),
    F.col("TIMESTAMP"),
    F.col("EVENT_TYPE"),
    F.col("SEVERITY"),
    F.col("THREAT_CATEGORY"),
    F.col("DATA_SOURCE"),
    F.col("USER_ID"),
    F.col("HOSTNAME"),
    F.col("ANALYST_NOTES"),
    F.col("RULE_BASED_SCORE"),
    *[F.col(c) for c in FEATURE_COLS]
)

scored_df = model_version.run(features_for_scoring, function_name="predict")


def _bare_identifier(name):
    """Normalize a Snowflake column identifier for comparison (unquoted, upper-case)."""
    return name.strip('"').upper()


# The model was logged from a plain scikit-learn estimator without an explicit
# signature, so the registry - not this notebook - names the prediction column
# (typically something like "output_feature_0"); there is no "PREDICTED_THREAT"
# column to rename. Locate the column run() appended to the inputs instead of
# hard-coding its name.
input_identifiers = {_bare_identifier(c) for c in features_for_scoring.columns}
prediction_cols = [
    c for c in scored_df.columns if _bare_identifier(c) not in input_identifiers
]
if len(prediction_cols) != 1:
    raise ValueError(
        f"Expected exactly one prediction column from model_version.run, found: {prediction_cols}"
    )

scored_df = scored_df.with_column_renamed(prediction_cols[0], "ML_HIGH_THREAT_PREDICTION")
scored_df = scored_df.with_column("SCORED_AT", F.current_timestamp())

# Materialize once: otherwise count(), show() and the later save each re-run
# model inference, and SCORED_AT would differ between those executions.
scored_df = scored_df.cache_result()

print(f"\nScored {scored_df.count()} events")
print("\nTop 10 ML-predicted high-risk events:")
scored_df.filter(F.col("ML_HIGH_THREAT_PREDICTION") == 1).order_by(
    F.col("RULE_BASED_SCORE").desc()
).limit(10).show()

In [None]:
# persist_scored_events
# Replace the SCORED_EVENTS table wholesale with the latest scoring run.
SCORED_EVENTS_TABLE = "SONY_SECURITY_DEMO.SOC.SCORED_EVENTS"
scored_df.write.save_as_table(SCORED_EVENTS_TABLE, mode="overwrite")
print(f"Created scored events table: {SCORED_EVENTS_TABLE}")

## 10. Summary Statistics

Compare ML predictions with rule-based scoring and show final metrics.

In [None]:
# summary_statistics
# Compare the model's positives with the rule-based positives
# (RULE_BASED_SCORE >= 0.6) across everything persisted in SCORED_EVENTS.
summary = session.sql("""
    SELECT 
        COUNT(*) AS total_events,
        SUM(ML_HIGH_THREAT_PREDICTION) AS ml_high_threat_count,
        ROUND(SUM(ML_HIGH_THREAT_PREDICTION) * 100.0 / COUNT(*), 2) AS ml_high_threat_pct,
        SUM(CASE WHEN RULE_BASED_SCORE >= 0.6 THEN 1 ELSE 0 END) AS rule_high_threat_count,
        ROUND(AVG(RULE_BASED_SCORE), 3) AS avg_rule_score
    FROM SONY_SECURITY_DEMO.SOC.SCORED_EVENTS
""").to_pandas()

# The aggregate query yields a single row; pull each column's scalar out by name.
stats = {col_name: summary[col_name].values[0] for col_name in summary.columns}

divider = "=" * 60
print("\n" + divider)
print("THREAT DETECTION SUMMARY")
print(divider)
print(f"Total Events Scored: {stats['TOTAL_EVENTS']}")
print(f"ML High-Threat Predictions: {stats['ML_HIGH_THREAT_COUNT']} ({stats['ML_HIGH_THREAT_PCT']}%)")
print(f"Rule-Based High-Threat (>=0.6): {stats['RULE_HIGH_THREAT_COUNT']}")
print(f"Average Rule-Based Score: {stats['AVG_RULE_SCORE']}")
print(divider)

print("\nNotebook complete! Model registered and all events scored.")
print("\nNext steps:")
for next_step in (
    "1. Query SONY_SECURITY_DEMO.SOC.SCORED_EVENTS for ML predictions",
    "2. Use the SECURITY_ANALYST agent to investigate high-threat events",
    "3. Create alerts based on ML_HIGH_THREAT_PREDICTION = 1",
):
    print(next_step)