## Step 1: Install PySpark and Import Libraries

In [None]:
# Install PySpark (Colab only)
!pip install pyspark -q

# Spark session, SQL helpers, classifiers and evaluators used throughout the notebook.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
# NOTE(review): the tuning and MulticlassMetrics imports are not referenced in the
# cells visible here - confirm they are still needed.
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.mllib.evaluation import MulticlassMetrics
import time
import json
import os
import gc  # explicit gc.collect() calls between memory-heavy stages

print("‚úÖ Libraries imported successfully")

[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m434.2/434.2 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m203.0/203.0 kB[0m [31m16.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pyspark (pyproject.toml) ... [?25l[?25hdone
‚úÖ Libraries imported successfully


## Step 2: Mount Google Drive

In [None]:
# Mount Google Drive when running in Colab; otherwise use a local data directory.
# Only a missing google.colab package selects the local branch. A failed mount
# inside Colab now raises instead of silently pointing BASE_DIR at a local path
# that does not exist there.
try:
    from google.colab import drive
except ImportError:
    # NIDS_DATA_DIR overrides the original hard-coded default, so the notebook
    # can run on another machine without editing this cell.
    BASE_DIR = os.environ.get(
        "NIDS_DATA_DIR",
        "d:/Coding/real-time-network-intrusion-detection-spark-kafka/data",
    )
    print(f"‚úÖ Running locally. Data directory: {BASE_DIR}")
    IS_COLAB = False
else:
    drive.mount('/content/drive')
    BASE_DIR = "/content/drive/MyDrive/NetworkIDS"
    print("‚úÖ Google Drive mounted successfully!")
    IS_COLAB = True

# Define paths (input parquet dataset and output directory for fitted models)
DATA_PATH = f"{BASE_DIR}/output/parquet/cicids_merged_harmonized"
MODEL_DIR = f"{BASE_DIR}/output/models"
os.makedirs(MODEL_DIR, exist_ok=True)

print(f"üìÇ Data path: {DATA_PATH}")
print(f"üìÇ Model directory: {MODEL_DIR}")

Mounted at /content/drive
‚úÖ Google Drive mounted successfully!
üìÇ Data path: /content/drive/MyDrive/NetworkIDS/output/parquet/cicids_merged_harmonized
üìÇ Model directory: /content/drive/MyDrive/NetworkIDS/output/models


## Step 3: Create Spark Session

In [None]:
# Ask spark-submit for 8 GB driver and executor heaps when PySpark launches its JVM.
# NOTE(review): presumably this has to run before the first SparkSession is
# created in the kernel to have any effect - confirm.
os.environ["PYSPARK_SUBMIT_ARGS"] = "--driver-memory 8g --executor-memory 8g pyspark-shell"


In [None]:
# Create the Spark session used for model training.
# The settings favour stability on a memory-constrained machine over throughput.

# Drop unreachable Python objects before starting the session.
gc.collect()

spark = (
    SparkSession.builder
    .appName("NIDS-ModelTraining")
    .master("local[2]")  # two local cores -> fewer concurrent tasks holding data
    .config("spark.driver.memory", "8g")
    .config("spark.executor.memory", "8g")
    .config("spark.sql.shuffle.partitions", "100")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.sql.parquet.enableVectorizedReader", "false")
    .config("spark.driver.maxResultSize", "2g")
    # Generous timeouts so long-running tree training is not mistaken for a hang.
    .config("spark.network.timeout", "800s")
    .config("spark.executor.heartbeatInterval", "60s")
    .config("spark.sql.broadcastTimeout", "600")
    .config("spark.rpc.askTimeout", "600s")
    # spark.storage.memoryFraction was dropped: it belongs to the legacy memory
    # manager and is ignored by current Spark; spark.memory.fraction is the
    # setting that governs the unified memory region.
    .config("spark.memory.fraction", "0.6")
    .getOrCreate()
)

spark.sparkContext.setLogLevel("ERROR")
print("‚úÖ Spark session created")
print(f"üìä Spark version: {spark.version}")
print("üîß Using 2 cores to reduce memory pressure")

‚úÖ Spark session created
üìä Spark version: 4.0.1
üîß Using 2 cores to reduce memory pressure


## Step 4: Load Harmonized Dataset

In [None]:
# Load the harmonized dataset
print("Loading harmonized dataset...")
start_time = time.time()

# NOTE(review): DataFrame reads are lazy, so the timing below most likely covers
# schema discovery only, not a full scan of the data - confirm before quoting it.
df = spark.read.parquet(DATA_PATH)

# Show basic info
print(f"‚úÖ Dataset loaded in {time.time() - start_time:.2f} seconds")
print(f"üìä Columns: {len(df.columns)}")
print("\nSchema (key columns):")
# The loop variable is deliberately not called `col`: it is a notebook-level
# global, and a later cell imports pyspark.sql.functions.col under that name,
# which re-running this cell would otherwise overwrite with a string.
for col_name in ('features_scaled', 'binary_label', 'unified_label', 'sample_weight', 'multiclass_weight'):
    if col_name in df.columns:
        print(f"  - {col_name}: {df.schema[col_name].dataType}")

Loading harmonized dataset...
‚úÖ Dataset loaded in 4.62 seconds
üìä Columns: 34

Schema (key columns):
  - features_scaled: VectorUDT()
  - binary_label: IntegerType()
  - unified_label: IntegerType()
  - sample_weight: DoubleType()
  - multiclass_weight: DoubleType()


In [None]:
# Row counts per class for both labelling schemes.
binary_counts = df.groupBy('binary_label').count()
print("Binary Label Distribution:")
binary_counts.show()

# Ordered by label id so the table reads top to bottom.
unified_counts = df.groupBy('unified_label').count().orderBy('unified_label')
print("\nUnified Label Distribution:")
unified_counts.show(10)

Binary Label Distribution:
+------------+--------+
|binary_label|   count|
+------------+--------+
|           1| 2779281|
|           0|15484134|
+------------+--------+


Unified Label Distribution:
+-------------+--------+
|unified_label|   count|
+-------------+--------+
|            0|15484134|
|            1|  699820|
|            2|  705921|
|            3|  165820|
|            4|     928|
|            5|  161095|
|            6|  284263|
|            7|   90819|
|            8|  670615|
+-------------+--------+



## Step 5: Prepare Data for Training

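We'll use a seeded random 80/20 split (`randomSplit`). This is not a stratified split, but with this many rows each class ends up in the train and test sets in roughly its original proportion.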

In [None]:
# Keep only the feature vector, the two label columns and their weight columns;
# everything else is dropped to reduce memory.
training_columns = [
    'features_scaled',
    'binary_label',
    'unified_label',
    'sample_weight',
    'multiclass_weight',
]
df_train = df.select(*training_columns)

print(f"‚úÖ Selected {len(df_train.columns)} columns for training")
df_train.printSchema()

‚úÖ Selected 5 columns for training
root
 |-- features_scaled: vector (nullable = true)
 |-- binary_label: integer (nullable = true)
 |-- unified_label: integer (nullable = true)
 |-- sample_weight: double (nullable = true)
 |-- multiclass_weight: double (nullable = true)



In [None]:
# === SPARK HEALTH CHECK - Run before each training ===
def ensure_spark_active():
    """Verify the Spark session is usable; rebuild it and the sampled data if not.

    A trivial job is run as a liveness probe. If the probe raises, the old
    session is stopped (best effort), a new session is created with the same
    settings as the initial one, and the parquet data is re-read, split 80/20
    and sampled again so that the globals ``train_sampled`` and
    ``test_sampled`` refer to DataFrames of the new session.

    The sampling fraction is taken from the notebook-level ``SAMPLE_FRACTION``
    when it is defined, falling back to 0.15, so recovered data matches the
    sampling cell.

    NOTE(review): models fitted before a recreation were produced by the old
    session - presumably they must be refitted or reloaded afterwards; confirm.

    Returns:
        bool: always ``True``. Errors raised while recreating the session or
        reloading the data propagate to the caller.
    """
    global spark, train_sampled, test_sampled
    try:
        # Test if Spark is alive
        spark.sparkContext.parallelize([1,2,3]).count()
    except Exception as e:
        print(f"‚ö†Ô∏è Spark session dead: {e}")
        print("üîÑ Recreating Spark session...")
    else:
        print("‚úÖ Spark session is active")
        return True

    # Stopping a dead session is best effort. Exception (not a bare except) is
    # caught so that KeyboardInterrupt / SystemExit still get through.
    try:
        spark.stop()
    except Exception:
        pass

    gc.collect()

    # Same configuration as the initial session, including the broadcast and
    # RPC timeouts that the recovery path previously left out.
    spark = (
        SparkSession.builder
        .appName("NIDS-ModelTraining")
        .master("local[2]")
        .config("spark.driver.memory", "8g")
        .config("spark.executor.memory", "8g")
        .config("spark.sql.shuffle.partitions", "100")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .config("spark.sql.parquet.enableVectorizedReader", "false")
        .config("spark.driver.maxResultSize", "2g")
        .config("spark.network.timeout", "800s")
        .config("spark.executor.heartbeatInterval", "60s")
        .config("spark.sql.broadcastTimeout", "600")
        .config("spark.rpc.askTimeout", "600s")
        .config("spark.memory.fraction", "0.6")
        .getOrCreate()
    )

    spark.sparkContext.setLogLevel("ERROR")

    # Reload data after session recreation
    print("üîÑ Reloading training data...")
    fraction = globals().get("SAMPLE_FRACTION", 0.15)
    df = spark.read.parquet(DATA_PATH)
    df_train = df.select(
        'features_scaled', 'binary_label', 'unified_label',
        'sample_weight', 'multiclass_weight'
    )
    # Same seed and weights as the original split so the rows are reproducible.
    train_df, test_df = df_train.randomSplit([0.8, 0.2], seed=42)
    train_sampled = train_df.sample(fraction=fraction, seed=42).cache()
    test_sampled = test_df.sample(fraction=fraction, seed=42).cache()
    train_sampled.count()  # Materialize
    test_sampled.count()

    print("‚úÖ New Spark session created and data reloaded")
    return True

In [None]:
# Split data into train/test sets (80/20 split)
print("Splitting data into train/test sets...")

# Seeded random split. No label column is involved, so this is NOT a stratified
# split; class proportions are preserved only approximately.
train_df, test_df = df_train.randomSplit([0.8, 0.2], seed=42)

# Cache the splits
train_df.cache()
test_df.cache()

# Get counts (triggers caching)
train_total = train_df.count()
test_total = test_df.count()

print(f"‚úÖ Training set: {train_total:,} records")
print(f"‚úÖ Test set: {test_total:,} records")
print(f"üìä Train/Test ratio: {train_total/(train_total+test_total)*100:.1f}% / {test_total/(train_total+test_total)*100:.1f}%")

# Unpersist the original df to free memory
# NOTE(review): df is never cached in the visible cells, so this looks like a
# no-op - confirm.
df.unpersist()
gc.collect()

Splitting data into train/test sets...
‚úÖ Training set: 14,608,562 records
‚úÖ Test set: 3,654,853 records
üìä Train/Test ratio: 80.0% / 20.0%


16

In [None]:
# Down-sample both splits so training fits in memory (Colab's 12GB limit).
print("üìä Sampling data to fit in memory...")
SAMPLE_FRACTION = 0.15  # fraction of each split that is kept

# Release the cached full-size splits before building the samples.
for full_split in (train_df, test_df):
    full_split.unpersist()
gc.collect()

# Same seed for both samples; each sample is marked for caching straight away.
train_sampled = train_df.sample(fraction=SAMPLE_FRACTION, seed=42).cache()
test_sampled = test_df.sample(fraction=SAMPLE_FRACTION, seed=42).cache()

# count() runs a job over each sample, which fills the cache.
train_count = train_sampled.count()
test_count = test_sampled.count()

print(f"‚úÖ Training set: {train_count:,} records (sampled)")
print(f"‚úÖ Test set: {test_count:,} records (sampled)")
print(f"üí° Using {SAMPLE_FRACTION*100:.0f}% sample to prevent OOM crashes")

üìä Sampling data to fit in memory...
‚úÖ Training set: 2,191,961 records (sampled)
‚úÖ Test set: 548,492 records (sampled)
üí° Using 15% sample to prevent OOM crashes


In [None]:
# Quick verification - peek at a few rows of the training data.
# The rows come from train_sampled, which is cached and already materialized.
# train_df was unpersisted in the sampling cell, so reading from it would
# recompute the split from the parquet files just to show five rows.
print("Sample from train set:")
train_sampled.select('binary_label', 'unified_label').show(5)

print("‚úÖ Data ready for training")

Sample from train set:
+------------+-------------+
|binary_label|unified_label|
+------------+-------------+
|           0|            0|
|           0|            0|
|           0|            0|
|           0|            0|
|           0|            0|
+------------+-------------+
only showing top 5 rows
‚úÖ Data ready for training


## Step 6: Train Binary Classification Models

### 6.1 Random Forest - Binary Classification

In [None]:
# Random Forest for Binary Classification - WITH RECOVERY
banner = "=" * 60
print(banner)
print("Training Random Forest - Binary Classification")
print(banner)

# Rebuild the session first if it has died.
ensure_spark_active()

start_time = time.time()

# Deliberately small forest: fewer and shallower trees for stability.
rf_binary_params = {
    'featuresCol': 'features_scaled',
    'labelCol': 'binary_label',
    'weightCol': 'sample_weight',
    'numTrees': 30,
    'maxDepth': 6,
    'maxBins': 32,
    'minInstancesPerNode': 10,
    'seed': 42,
}
rf_binary = RandomForestClassifier(**rf_binary_params)

print("Training model on sampled data...")
try:
    rf_binary_model = rf_binary.fit(train_sampled)
    elapsed = time.time() - start_time
    print(f"‚úÖ Training completed in {elapsed/60:.2f} minutes")
except Exception as e:
    # Report a hint, then re-raise so the failure is not hidden.
    print(f"‚ùå Training failed: {e}")
    print("üí° Try reducing SAMPLE_FRACTION or numTrees")
    raise


Training Random Forest - Binary Classification
‚úÖ Spark session is active
Training model on sampled data...
‚úÖ Training completed in 0.72 minutes


In [None]:
# Evaluate Random Forest - Binary
print("Evaluating Random Forest - Binary Classification...")

ensure_spark_active()

# Score the sampled test set once. Only the columns the evaluators read are
# kept, and the result is cached, so the six metric passes below reuse the
# stored predictions instead of re-running the forest each time.
rf_binary_preds = (
    rf_binary_model.transform(test_sampled)
    .select('binary_label', 'rawPrediction', 'prediction')
    .cache()
)

# Threshold-free ranking metrics computed from the raw prediction column.
binary_evaluator_auc = BinaryClassificationEvaluator(
    labelCol='binary_label',
    rawPredictionCol='rawPrediction',
    metricName='areaUnderROC'
)

binary_evaluator_pr = BinaryClassificationEvaluator(
    labelCol='binary_label',
    rawPredictionCol='rawPrediction',
    metricName='areaUnderPR'
)

# Metrics on the hard predictions. The evaluators defined in this cell are
# reused by the GBT evaluation cell.
multi_evaluator = MulticlassClassificationEvaluator(
    labelCol='binary_label',
    predictionCol='prediction'
)

auc_roc = binary_evaluator_auc.evaluate(rf_binary_preds)
auc_pr = binary_evaluator_pr.evaluate(rf_binary_preds)
accuracy = multi_evaluator.evaluate(rf_binary_preds, {multi_evaluator.metricName: 'accuracy'})
f1 = multi_evaluator.evaluate(rf_binary_preds, {multi_evaluator.metricName: 'f1'})
# "Precision" and "Recall" below are the class-weighted averages
# (weightedPrecision / weightedRecall), not attack-class-only figures.
precision = multi_evaluator.evaluate(rf_binary_preds, {multi_evaluator.metricName: 'weightedPrecision'})
recall = multi_evaluator.evaluate(rf_binary_preds, {multi_evaluator.metricName: 'weightedRecall'})

print("\n" + "="*50)
print("Random Forest - Binary Classification Results")
print("="*50)
print(f"AUC-ROC:   {auc_roc:.4f}")
print(f"AUC-PR:    {auc_pr:.4f}")
print(f"Accuracy:  {accuracy:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")

rf_binary_results = {
    'model': 'Random Forest',
    'task': 'Binary Classification',
    'auc_roc': auc_roc,
    'auc_pr': auc_pr,
    'accuracy': accuracy,
    'f1': f1,
    'precision': precision,
    'recall': recall
}

# Release the cached predictions. The trailing semicolon keeps gc.collect()'s
# return value out of the cell output.
rf_binary_preds.unpersist()
gc.collect();


Evaluating Random Forest - Binary Classification...
‚úÖ Spark session is active

Random Forest - Binary Classification Results
AUC-ROC:   0.9757
AUC-PR:    0.9414
Accuracy:  0.9715
F1 Score:  0.9719
Precision: 0.9728
Recall:    0.9715


732

### 6.2 Gradient Boosted Trees - Binary Classification

In [None]:
# GBT for Binary Classification - WITH RECOVERY
banner = "=" * 60
print(banner)
print("Training Gradient Boosted Trees - Binary Classification")
print(banner)

# Make sure the session is alive and drop unreachable objects before fitting.
ensure_spark_active()
gc.collect()

start_time = time.time()

# Few, shallow boosting rounds - reduced for stability.
gbt_binary_params = {
    'featuresCol': 'features_scaled',
    'labelCol': 'binary_label',
    'weightCol': 'sample_weight',
    'maxIter': 20,
    'maxDepth': 5,
    'seed': 42,
}
gbt_binary = GBTClassifier(**gbt_binary_params)

print("Training model on sampled data...")
try:
    gbt_binary_model = gbt_binary.fit(train_sampled)
    elapsed = time.time() - start_time
    print(f"‚úÖ Training completed in {elapsed/60:.2f} minutes")
except Exception as e:
    # Report a hint, then re-raise so the failure is not hidden.
    print(f"‚ùå Training failed: {e}")
    print("üí° Try reducing maxIter or maxDepth")
    raise


Training Gradient Boosted Trees - Binary Classification
‚úÖ Spark session is active
Training model on sampled data...
‚úÖ Training completed in 1.46 minutes


In [None]:
# Evaluate GBT - Binary
# Reuses binary_evaluator_auc, binary_evaluator_pr and multi_evaluator, which
# are created in the Random Forest binary evaluation cell; run that cell first.
print("Evaluating GBT - Binary Classification...")

ensure_spark_active()

# Score once, keep only the columns the evaluators read, and cache, so the six
# metric passes below do not each re-run the boosted trees.
gbt_binary_preds = (
    gbt_binary_model.transform(test_sampled)
    .select('binary_label', 'rawPrediction', 'prediction')
    .cache()
)

auc_roc = binary_evaluator_auc.evaluate(gbt_binary_preds)
auc_pr = binary_evaluator_pr.evaluate(gbt_binary_preds)
accuracy = multi_evaluator.evaluate(gbt_binary_preds, {multi_evaluator.metricName: 'accuracy'})
f1 = multi_evaluator.evaluate(gbt_binary_preds, {multi_evaluator.metricName: 'f1'})
# Class-weighted averages (weightedPrecision / weightedRecall).
precision = multi_evaluator.evaluate(gbt_binary_preds, {multi_evaluator.metricName: 'weightedPrecision'})
recall = multi_evaluator.evaluate(gbt_binary_preds, {multi_evaluator.metricName: 'weightedRecall'})

print("\n" + "="*50)
print("GBT - Binary Classification Results")
print("="*50)
print(f"AUC-ROC:   {auc_roc:.4f}")
print(f"AUC-PR:    {auc_pr:.4f}")
print(f"Accuracy:  {accuracy:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")

gbt_binary_results = {
    'model': 'Gradient Boosted Trees',
    'task': 'Binary Classification',
    'auc_roc': auc_roc,
    'auc_pr': auc_pr,
    'accuracy': accuracy,
    'f1': f1,
    'precision': precision,
    'recall': recall
}

# Release the cached predictions; the semicolon suppresses gc.collect()'s
# return value in the cell output.
gbt_binary_preds.unpersist()
gc.collect();

Evaluating GBT - Binary Classification...
‚úÖ Spark session is active

GBT - Binary Classification Results
AUC-ROC:   0.9852
AUC-PR:    0.9713
Accuracy:  0.9817
F1 Score:  0.9818
Precision: 0.9819
Recall:    0.9817


336

## Step 7: Train Multi-class Classification Models

### 7.1 Random Forest - Multi-class (9 classes: benign + 8 attack types)

In [None]:

# Random Forest for Multi-class Classification - WITH RECOVERY & OPTIMIZATION
banner = "=" * 60
print(banner)
print("Training Random Forest - Multi-class Classification (9 classes)")
print(banner)

# Make sure the session is alive and drop unreachable objects before fitting.
ensure_spark_active()
gc.collect()

start_time = time.time()

# Tree count and depth reduced from 50 / 10; the two min* settings prune
# low-value splits.
rf_multi_params = {
    'featuresCol': 'features_scaled',
    'labelCol': 'unified_label',
    'weightCol': 'multiclass_weight',
    'numTrees': 30,
    'maxDepth': 8,
    'maxBins': 32,
    'minInstancesPerNode': 10,
    'minInfoGain': 0.001,
    'seed': 42,
}
rf_multi = RandomForestClassifier(**rf_multi_params)

print("Training model on sampled data...")
try:
    rf_multi_model = rf_multi.fit(train_sampled)
    elapsed = time.time() - start_time
    print(f"‚úÖ Training completed in {elapsed/60:.2f} minutes")
except Exception as e:
    # Report a hint, then re-raise so the failure is not hidden.
    print(f"‚ùå Training failed: {e}")
    print("üí° Try reducing numTrees to 20 or maxDepth to 6")
    raise



Training Random Forest - Multi-class Classification (9 classes)
‚úÖ Spark session is active
Training model on sampled data...
‚úÖ Training completed in 0.96 minutes


In [None]:
# Evaluate Random Forest - Multi-class
print("Evaluating Random Forest - Multi-class Classification...")

ensure_spark_active()

# Score once, keep only label and prediction, and cache, so the four metric
# passes below reuse the stored predictions instead of re-running the forest.
rf_multi_preds = (
    rf_multi_model.transform(test_sampled)
    .select('unified_label', 'prediction')
    .cache()
)

# Create evaluator specifically for multi-class with unified_label
mc_evaluator = MulticlassClassificationEvaluator(
    labelCol='unified_label',
    predictionCol='prediction'
)

accuracy = mc_evaluator.evaluate(rf_multi_preds, {mc_evaluator.metricName: 'accuracy'})
f1 = mc_evaluator.evaluate(rf_multi_preds, {mc_evaluator.metricName: 'f1'})
# Class-weighted averages (weightedPrecision / weightedRecall).
precision = mc_evaluator.evaluate(rf_multi_preds, {mc_evaluator.metricName: 'weightedPrecision'})
recall = mc_evaluator.evaluate(rf_multi_preds, {mc_evaluator.metricName: 'weightedRecall'})

print("\n" + "="*50)
print("Random Forest - Multi-class Classification Results")
print("="*50)
print(f"Accuracy:  {accuracy:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")

rf_multi_results = {
    'model': 'Random Forest',
    'task': 'Multi-class Classification (9 classes)',
    'accuracy': accuracy,
    'f1': f1,
    'precision': precision,
    'recall': recall
}

# Release the cached predictions; the semicolon suppresses gc.collect()'s
# return value in the cell output.
rf_multi_preds.unpersist()
gc.collect();

Evaluating Random Forest - Multi-class Classification...
‚úÖ Spark session is active

Random Forest - Multi-class Classification Results
Accuracy:  0.5573
F1 Score:  0.6813
Precision: 0.9666
Recall:    0.5573


90

In [None]:

# Confusion Matrix for Multi-class (regenerate predictions for display)
print("\nGenerating Confusion Matrix...")
ensure_spark_active()

# Only label and prediction are needed; caching lets both tables below share
# one scoring pass.
rf_multi_preds = (
    rf_multi_model.transform(test_sampled)
    .select('unified_label', 'prediction')
    .cache()
)

# correct / total within each TRUE label, i.e. per-class recall.
print("Per-class Prediction Accuracy:")
rf_multi_preds.groupBy('unified_label') \
    .agg(
        F.count('*').alias('total'),
        F.sum(F.when(F.col('prediction') == F.col('unified_label'), 1).otherwise(0)).alias('correct')
    ) \
    .withColumn('accuracy', F.round(F.col('correct') / F.col('total'), 4)) \
    .orderBy('unified_label').show()

# The confusion matrix this cell announces: one row per true label, one column
# per predicted label, with missing combinations shown as 0.
print("Confusion Matrix (rows: unified_label, columns: prediction):")
rf_multi_preds.groupBy('unified_label') \
    .pivot('prediction') \
    .count() \
    .na.fill(0) \
    .orderBy('unified_label').show()

# Release the cached predictions; the semicolon suppresses gc.collect()'s
# return value in the cell output.
rf_multi_preds.unpersist()
gc.collect();



Generating Confusion Matrix...
‚úÖ Spark session is active
Per-class Prediction Accuracy:
+-------------+------+-------+--------+
|unified_label| total|correct|accuracy|
+-------------+------+-------+--------+
|            0|464986| 224359|  0.4825|
|            1| 21029|  20331|  0.9668|
|            2| 21303|  21265|  0.9982|
|            3|  4933|   4929|  0.9992|
|            4|    30|     27|     0.9|
|            5|  4867|   3915|  0.8044|
|            6|  8539|   8478|  0.9929|
|            7|  2686|   2677|  0.9966|
|            8| 20119|  19687|  0.9785|
+-------------+------+-------+--------+



395

## Step 8: Save Trained Models

In [None]:
# Re-verify Google Drive connection before saving
if IS_COLAB:
    try:
        # Listing the base directory fails with an OSError subclass when the
        # mount has gone away.
        os.listdir(BASE_DIR)
    except OSError:
        # Only filesystem errors trigger a remount; anything else (including
        # KeyboardInterrupt) propagates instead of being swallowed.
        print("‚ö†Ô∏è Drive disconnected! Remounting...")
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        print("‚úÖ Drive remounted successfully")
    else:
        print("‚úÖ Google Drive connection verified")

‚úÖ Google Drive connection verified


In [None]:
# Persist the three fitted models under MODEL_DIR, replacing any earlier copies.
print("Saving trained models...")

rf_binary_path = f"{MODEL_DIR}/rf_binary_classifier"
gbt_binary_path = f"{MODEL_DIR}/gbt_binary_classifier"
rf_multi_path = f"{MODEL_DIR}/rf_multiclass_classifier"

# Order: Random Forest binary, GBT binary, Random Forest multi-class.
models_to_save = (
    (rf_binary_model, rf_binary_path),
    (gbt_binary_model, gbt_binary_path),
    (rf_multi_model, rf_multi_path),
)
for fitted_model, target_path in models_to_save:
    fitted_model.write().overwrite().save(target_path)
    print(f"‚úÖ Saved: {target_path}")

Saving trained models...
‚úÖ Saved: /content/drive/MyDrive/NetworkIDS/output/models/rf_binary_classifier
‚úÖ Saved: /content/drive/MyDrive/NetworkIDS/output/models/gbt_binary_classifier
‚úÖ Saved: /content/drive/MyDrive/NetworkIDS/output/models/rf_multiclass_classifier


In [None]:
# Collect every model's metrics plus the sampled split sizes and write them as JSON.
all_results = dict(
    rf_binary=rf_binary_results,
    gbt_binary=gbt_binary_results,
    rf_multiclass=rf_multi_results,
    train_size=train_count,
    test_size=test_count,
    # Sum of the sampled train and test counts.
    total_records=train_count + test_count,
)

results_path = f"{MODEL_DIR}/training_results.json"
with open(results_path, 'w') as results_file:
    results_file.write(json.dumps(all_results, indent=2))

print(f"\n‚úÖ Results saved to: {results_path}")


‚úÖ Results saved to: /content/drive/MyDrive/NetworkIDS/output/models/training_results.json


## Step 9: Model Comparison Summary

In [None]:
# Final side-by-side comparison of the trained models.
heavy_rule = "=" * 70
light_rule = "-" * 70


def _summary_header(metric_titles):
    """Return the table header: 'Model' padded to 25, each title padded to 10."""
    return " ".join([f"{'Model':<25}"] + [f"{title:<10}" for title in metric_titles])


def _summary_row(model_name, results, metric_keys):
    """Return one table row with the selected metrics formatted to 4 decimals."""
    return " ".join([f"{model_name:<25}"] + [f"{results[key]:<10.4f}" for key in metric_keys])


print("\n" + heavy_rule)
print("MODEL TRAINING SUMMARY")
print(heavy_rule)

print("\nüìä BINARY CLASSIFICATION (Attack vs Benign)")
print(light_rule)
print(_summary_header(['AUC-ROC', 'Accuracy', 'F1', 'Precision', 'Recall']))
print(light_rule)
binary_keys = ['auc_roc', 'accuracy', 'f1', 'precision', 'recall']
for model_name, results in (
    ('Random Forest', rf_binary_results),
    ('Gradient Boosted Trees', gbt_binary_results),
):
    print(_summary_row(model_name, results, binary_keys))

print("\nüìä MULTI-CLASS CLASSIFICATION (9 Attack Types)")
print(light_rule)
print(_summary_header(['Accuracy', 'F1', 'Precision', 'Recall']))
print(light_rule)
print(_summary_row('Random Forest', rf_multi_results, ['accuracy', 'f1', 'precision', 'recall']))

print("\n" + heavy_rule)
print("‚úÖ All models trained and saved successfully!")
print(f"üìÅ Models location: {MODEL_DIR}")
print(heavy_rule)


MODEL TRAINING SUMMARY

üìä BINARY CLASSIFICATION (Attack vs Benign)
----------------------------------------------------------------------
Model                     AUC-ROC    Accuracy   F1         Precision  Recall    
----------------------------------------------------------------------
Random Forest             0.9757     0.9715     0.9719     0.9728     0.9715    
Gradient Boosted Trees    0.9852     0.9817     0.9818     0.9819     0.9817    

üìä MULTI-CLASS CLASSIFICATION (9 Attack Types)
----------------------------------------------------------------------
Model                     Accuracy   F1         Precision  Recall    
----------------------------------------------------------------------
Random Forest             0.5573     0.6813     0.9666     0.5573    

‚úÖ All models trained and saved successfully!
üìÅ Models location: /content/drive/MyDrive/NetworkIDS/output/models


## Step 10: Feature Importance (Optional)

In [None]:
# Rank the binary Random Forest's features by importance and list the top 20.
print("Top 20 Most Important Features (Random Forest - Binary):")
print("="*50)

importances = rf_binary_model.featureImportances.toArray()

# (feature index, importance) pairs, largest importance first.
feature_importance = sorted(
    enumerate(importances),
    key=lambda pair: pair[1],
    reverse=True,
)

print(f"{'Rank':<6} {'Feature Index':<15} {'Importance':<12}")
print("-"*35)
top_features = feature_importance[:20]
for position, (feature_idx, score) in enumerate(top_features, start=1):
    print(f"{position:<6} {feature_idx:<15} {score:.6f}")

Top 20 Most Important Features (Random Forest - Binary):
Rank   Feature Index   Importance  
-----------------------------------
1      0               0.081287
2      66              0.050510
3      35              0.048572
4      7               0.045518
5      20              0.037889
6      64              0.034247
7      25              0.032383
8      21              0.031753
9      40              0.029736
10     38              0.026515
11     63              0.026513
12     24              0.026096
13     5               0.026089
14     69              0.025748
15     52              0.024871
16     19              0.024671
17     13              0.021418
18     9               0.020865
19     46              0.020662
20     36              0.020216


## Summary

### Models Trained:
1. **Random Forest - Binary** (`rf_binary_classifier`)
   - Task: Attack vs Benign
   - Use case: Quick attack detection

2. **Gradient Boosted Trees - Binary** (`gbt_binary_classifier`)
   - Task: Attack vs Benign
   - Use case: Higher accuracy attack detection

3. **Random Forest - Multi-class** (`rf_multiclass_classifier`)
   - Task: Identify specific attack type (9 classes)
   - Use case: Detailed threat classification

### Saved Artifacts:
- Models: `/content/drive/MyDrive/NetworkIDS/output/models/`
- Results: `training_results.json`

### Next Steps:
1. Deploy models for real-time inference
2. Integrate with Kafka streaming pipeline
3. Build alerting/monitoring dashboard

In [None]:
# Cleanup: release the cached samples, then stop the Spark session.
print("Cleaning up...")
# Each sample is released independently, so a failure (or a missing variable)
# for one no longer skips the other. Failures are reported rather than
# silently swallowed, and never abort the shutdown.
for cached_name in ("train_sampled", "test_sampled"):
    cached_df = globals().get(cached_name)
    if cached_df is None:
        continue
    try:
        cached_df.unpersist()
    except Exception as exc:
        print(f"Could not unpersist {cached_name}: {exc}")

gc.collect()
spark.stop()
print("‚úÖ Spark session stopped")
print("\nüéâ Model training complete! Ready for deployment.")

Cleaning up...
‚úÖ Spark session stopped

üéâ Model training complete! Ready for deployment.


**Reasoning**:
To optimize the Random Forest binary classification evaluation and prevent `SparkOutOfMemoryError`, I will replace the existing code in cell `70b75119` with the provided code that incorporates sampling for AUC-ROC and AUC-PR calculations while retaining full dataset evaluation for other metrics.



In [None]:
# === RELOAD SESSION AND DATA ===
import gc
gc.collect()

!pip install pyspark -q

# Recreate Spark session
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import time
import os

# Rebuild the local two-core Spark session from a table of settings.
session_settings = {
    "spark.driver.memory": "8g",
    "spark.executor.memory": "8g",
    "spark.sql.shuffle.partitions": "100",
    "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
    "spark.sql.parquet.enableVectorizedReader": "false",
    "spark.driver.maxResultSize": "2g",
    "spark.network.timeout": "800s",
    "spark.executor.heartbeatInterval": "60s",
    "spark.memory.fraction": "0.6",
}

session_builder = SparkSession.builder.appName("NIDS-ModelTraining")
for setting_key, setting_value in session_settings.items():
    session_builder = session_builder.config(setting_key, setting_value)
spark = session_builder.master("local[2]").getOrCreate()

# Keep the console quiet: errors only.
spark.sparkContext.setLogLevel("ERROR")
print("‚úÖ Spark session created")

[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m434.2/434.2 MB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m203.0/203.0 kB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pyspark (pyproject.toml) ... [?25l[?25hdone
‚úÖ Spark session created


In [None]:
# === MOUNT GOOGLE DRIVE FIRST ===
from google.colab import drive
drive.mount('/content/drive')
print("‚úÖ Google Drive mounted")

# === VERIFY DATA PATH ===
import os

DATA_PATH = "/content/drive/MyDrive/NetworkIDS/output/parquet/cicids_merged_harmonized"
MODEL_DIR = "/content/drive/MyDrive/NetworkIDS/output/models"


def _print_listing(directory):
    """Print a heading for *directory* followed by one bullet per entry in it."""
    print(f"\nüìÇ Available in {directory}:")
    for entry in os.listdir(directory):
        print(f"  - {entry}")


if not os.path.exists(DATA_PATH):
    print(f"‚ùå Data path NOT found: {DATA_PATH}")
    # Walk up from the expected location and show what is actually there.
    base = "/content/drive/MyDrive/NetworkIDS"
    if not os.path.exists(base):
        print(f"‚ùå Base directory not found: {base}")
        print("\nüìÇ Available in MyDrive:")
        print(os.listdir("/content/drive/MyDrive")[:10])
    else:
        _print_listing(base)
        output_path = f"{base}/output"
        if os.path.exists(output_path):
            _print_listing(output_path)
else:
    print(f"‚úÖ Data path exists: {DATA_PATH}")
    # Show the first few entries only.
    print(f"üìÅ Contents: {os.listdir(DATA_PATH)[:5]}...")


Mounted at /content/drive
‚úÖ Google Drive mounted
‚úÖ Data path exists: /content/drive/MyDrive/NetworkIDS/output/parquet/cicids_merged_harmonized
üìÅ Contents: ['part-00000-6b64f6cd-5f77-487e-bf7f-d3220977a1b6-c000.snappy.parquet', '.part-00000-6b64f6cd-5f77-487e-bf7f-d3220977a1b6-c000.snappy.parquet.crc', 'part-00010-6b64f6cd-5f77-487e-bf7f-d3220977a1b6-c000.snappy.parquet', 'part-00003-6b64f6cd-5f77-487e-bf7f-d3220977a1b6-c000.snappy.parquet', 'part-00007-6b64f6cd-5f77-487e-bf7f-d3220977a1b6-c000.snappy.parquet']...


In [None]:
# === RELOAD DATA ===
print("Loading data...")
df = spark.read.parquet(DATA_PATH)

# Feature vector, both label columns and both weight columns.
needed_columns = [
    'features_scaled',
    'binary_label',
    'unified_label',
    'sample_weight',
    'multiclass_weight',
]
df_train = df.select(*needed_columns)

# 80/20 split, then the same 15% sample as before (same seed).
train_df, test_df = df_train.randomSplit([0.8, 0.2], seed=42)

SAMPLE_FRACTION = 0.15
train_sampled = train_df.sample(fraction=SAMPLE_FRACTION, seed=42)
train_sampled.cache()
test_sampled = test_df.sample(fraction=SAMPLE_FRACTION, seed=42)
test_sampled.cache()

# Counting runs a job over each sample, which fills the cache.
train_count = train_sampled.count()
test_count = test_sampled.count()

print(f"‚úÖ Training set: {train_count:,} samples")
print(f"‚úÖ Test set: {test_count:,} samples")

# Clean up
df.unpersist()
gc.collect()

Loading data...
‚úÖ Training set: 2,191,961 samples
‚úÖ Test set: 548,492 samples


576

In [None]:
# === DIAGNOSTIC: Check class imbalance ===
rule = "=" * 60
print(rule)
print("CLASS DISTRIBUTION ANALYSIS")
print(rule)

# Same per-class count table for the training sample, then the test sample.
for split_title, split_df in (("Training", train_sampled), ("Test", test_sampled)):
    print(f"\nüìä {split_title} Set Class Distribution:")
    split_df.groupBy('unified_label').count().orderBy('unified_label').show()


CLASS DISTRIBUTION ANALYSIS

üìä Training Set Class Distribution:
+-------------+-------+
|unified_label|  count|
+-------------+-------+
|            0|1859701|
|            1|  83895|
|            2|  84335|
|            3|  19958|
|            4|    122|
|            5|  19264|
|            6|  33676|
|            7|  10962|
|            8|  80048|
+-------------+-------+


üìä Test Set Class Distribution:
+-------------+------+
|unified_label| count|
+-------------+------+
|            0|464986|
|            1| 21029|
|            2| 21303|
|            3|  4933|
|            4|    30|
|            5|  4867|
|            6|  8539|
|            7|  2686|
|            8| 20119|
+-------------+------+



In [None]:
# === IMPROVED CLASS WEIGHTS ===
print("="*60)
print("Calculating improved class weights...")
print("="*60)

# One Spark job: per-class counts, ordered by label so the printout is stable.
class_counts = (
    train_sampled.groupBy('unified_label').count()
    .orderBy('unified_label')
    .collect()
)
# The per-class counts add up to the row count, so a second count() job over
# the training sample is unnecessary.
total_samples = sum(row['count'] for row in class_counts)
num_classes = len(class_counts)

# Inverse frequency with sqrt smoothing:
#   weight = sqrt(total_samples / (num_classes * class_count))
# The square root damps the weight given to very rare classes.
class_weight_map = {}
for row in class_counts:
    label = row['unified_label']
    count = row['count']
    weight = (total_samples / (num_classes * count)) ** 0.5
    class_weight_map[label] = weight
    print(f"  Class {label}: {count:,} samples -> weight {weight:.4f}")

# Apply weights
from pyspark.sql.functions import when, col, lit

# Nested when/otherwise chain: each label maps to its weight, and any label
# not in the map falls through to 1.0.
weight_expr = lit(1.0)
for label, weight in class_weight_map.items():
    weight_expr = when(col('unified_label') == label, weight).otherwise(weight_expr)

train_reweighted = train_sampled.withColumn('improved_weight', weight_expr)
train_reweighted.cache()
train_reweighted.count()  # materialize the cache before training

print("\n‚úÖ Improved weights applied")


Calculating improved class weights...
  Class 8: 80,048 samples -> weight 1.7443
  Class 7: 10,962 samples -> weight 4.7136
  Class 1: 83,895 samples -> weight 1.7038
  Class 6: 33,676 samples -> weight 2.6893
  Class 3: 19,958 samples -> weight 3.4933
  Class 5: 19,264 samples -> weight 3.5557
  Class 2: 84,335 samples -> weight 1.6994
  Class 0: 1,859,701 samples -> weight 0.3619
  Class 4: 122 samples -> weight 44.6802

‚úÖ Improved weights applied


In [None]:
# === TRAIN IMPROVED MULTI-CLASS MODEL ===
# Fits a weighted Random Forest on `train_reweighted`, reports wall-clock
# training time, then releases the cached training frame.
banner = "=" * 60
print(banner)
print("Training IMPROVED Random Forest - Multi-class")
print(banner)

# Timer starts before estimator construction, so the reported duration covers
# both building the estimator and fitting it.
start_time = time.time()

# Hyperparameters gathered in one place; `weightCol` points at the per-class
# weights column added to the training frame.
rf_params = dict(
    featuresCol='features_scaled',
    labelCol='unified_label',
    weightCol='improved_weight',
    numTrees=50,
    maxDepth=12,
    maxBins=64,
    minInstancesPerNode=5,
    minInfoGain=0.0,
    featureSubsetStrategy='sqrt',
    seed=42,
)
rf_multi_improved = RandomForestClassifier(**rf_params)

print("üöÄ Training...")
rf_multi_improved_model = rf_multi_improved.fit(train_reweighted)
elapsed = time.time() - start_time
elapsed_minutes = elapsed / 60
print(f"‚úÖ Training completed in {elapsed_minutes:.2f} minutes")

# Free the cached training data now that the model is fitted.
train_reweighted.unpersist()
gc.collect()

Training IMPROVED Random Forest - Multi-class
üöÄ Training...
‚úÖ Training completed in 3.78 minutes


260

In [None]:
# === EVALUATE IMPROVED MODEL ===
# Scores the sampled test set with the improved multi-class model, prints the
# aggregate metrics, compares them with the original model's baseline, and
# shows per-class recall.
print("="*60)
print("Evaluating Improved Multi-class Model")
print("="*60)

# Baseline metrics of the original model used for the comparison below
# (0.5573 accuracy, 0.6813 F1, 0.5573 recall).
ORIGINAL_ACCURACY = 0.5573
ORIGINAL_F1 = 0.6813
ORIGINAL_RECALL = 0.5573

preds = rf_multi_improved_model.transform(test_sampled)

# The predictions feed five separate Spark actions below (four metrics plus
# the per-class table). Without caching, each action re-runs the forest over
# the whole test set. Only the label and prediction columns are needed for
# these metrics, so cache that slim projection instead of the full frame.
eval_preds = preds.select('unified_label', 'prediction').cache()

mc_evaluator = MulticlassClassificationEvaluator(
    labelCol='unified_label',
    predictionCol='prediction'
)

accuracy = mc_evaluator.evaluate(eval_preds, {mc_evaluator.metricName: 'accuracy'})
f1 = mc_evaluator.evaluate(eval_preds, {mc_evaluator.metricName: 'f1'})
precision = mc_evaluator.evaluate(eval_preds, {mc_evaluator.metricName: 'weightedPrecision'})
recall = mc_evaluator.evaluate(eval_preds, {mc_evaluator.metricName: 'weightedRecall'})

print(f"\nAccuracy:  {accuracy:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")

# Compare with original; deltas are printed as differences scaled by 100.
print("\nüìä IMPROVEMENT vs ORIGINAL:")
print(f"  Accuracy: {ORIGINAL_ACCURACY:.4f} -> {accuracy:.4f} ({(accuracy-ORIGINAL_ACCURACY)*100:+.2f}%)")
print(f"  F1 Score: {ORIGINAL_F1:.4f} -> {f1:.4f} ({(f1-ORIGINAL_F1)*100:+.2f}%)")
print(f"  Recall:   {ORIGINAL_RECALL:.4f} -> {recall:.4f} ({(recall-ORIGINAL_RECALL)*100:+.2f}%)")

# Per-class recall: share of rows of each true label that were predicted correctly.
print("\nüìä Per-Class Recall:")
eval_preds.withColumn('correct', F.when(F.col('unified_label') == F.col('prediction'), 1).otherwise(0)) \
    .groupBy('unified_label').agg(
        F.count('*').alias('total'),
        F.sum('correct').alias('correct'),
        F.round(F.sum('correct') / F.count('*'), 4).alias('recall')
    ).orderBy('unified_label').show()

# Release the cached projection (the uncached `preds` frame holds no storage).
eval_preds.unpersist()
gc.collect()

Evaluating Improved Multi-class Model

Accuracy:  0.9870
F1 Score:  0.9858
Precision: 0.9851
Recall:    0.9870

üìä IMPROVEMENT vs ORIGINAL:
  Accuracy: 0.5573 -> 0.9870 (+42.97%)
  F1 Score: 0.6813 -> 0.9858 (+30.45%)
  Recall:   0.5573 -> 0.9870 (+42.97%)

üìä Per-Class Recall:
+-------------+------+-------+------+
|unified_label| total|correct|recall|
+-------------+------+-------+------+
|            0|464986| 462475|0.9946|
|            1| 21029|  20343|0.9674|
|            2| 21303|  21278|0.9988|
|            3|  4933|   4924|0.9982|
|            4|    30|     18|   0.6|
|            5|  4867|   1091|0.2242|
|            6|  8539|   8500|0.9954|
|            7|  2686|   2681|0.9981|
|            8| 20119|  20035|0.9958|
+-------------+------+-------+------+



199

In [None]:

# === SAVE IMPROVED MODEL ===
# Writes the fitted improved multi-class model under MODEL_DIR, replacing any
# model previously saved at the same path.
improved_model_path = f"{MODEL_DIR}/rf_multiclass_improved"
model_writer = rf_multi_improved_model.write().overwrite()
model_writer.save(improved_model_path)
print(f"‚úÖ Improved model saved to: {improved_model_path}")

‚úÖ Improved model saved to: /content/drive/MyDrive/NetworkIDS/output/models/rf_multiclass_improved
