In [1]:
#!/usr/bin/env python3
"""
full_scatter_attack_vs_noattack.py

Create a scatter plot using 2 randomly chosen numeric columns (from the dataset)
where point color is determined by attack (1) vs no-attack (0). The script:

 - Scans available CSV feature files to discover numeric columns (first chunk of each file).
 - Randomly selects two numeric columns (seeded for reproducibility).
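 - Streams every CSV chunk, extracts the chosen columns and the attack label per-row
   (label: 0 for "normal" file, 1 for others) and collects all rows. Chunks that lack
   either chosen column fall back to their first two numeric (non-label) columns.
 - If the total number of points becomes extremely large, the script stops reading
   once more than 5 * SAFETY_MAX points have been collected and then randomly
   downsamples to SAFETY_MAX points to avoid running out of memory / crashing.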
 - Produces a scatter plot saved to OUTPUT_DIR/combined_scatter.png.
 - Cleans up common temporary files (*.tmp, *.temp) in system temp dir and base_path.

NOTE:
 - This script attempts to plot *all* data points, but includes a safety cap to avoid
   OOM in practical environments. Adjust SAFETY_MAX if you have sufficient memory.
"""

import os
import glob
import tempfile
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ------------------------------
# Config - adjust as needed
# ------------------------------
# Root directory that contains one sub-folder per feature level.
base_path = "./"

# Feature level -> sub-folder name (joined onto base_path when reading CSVs).
folders = dict(
    packet="packet_features",
    uniflow="uniflow_features",
    biflow="biflow_features",
)

# Capture name -> CSV filename, without any feature-level prefix.
files = {
    capture: f"{capture}.csv"
    for capture in ("normal", "sparta", "scan_A", "mqtt_bruteforce", "scan_sU")
}
def build_filenames(prefix):
    """Return a capture-name -> filename dict of the form ``<prefix>_<capture>.csv``."""
    capture_names = ("normal", "sparta", "scan_A", "mqtt_bruteforce", "scan_sU")
    return {capture: f"{prefix}_{capture}.csv" for capture in capture_names}

# Feature level -> {capture name -> CSV filename}. The packet level uses the
# unprefixed names; uniflow/biflow filenames are prefixed with the level name.
feature_files = dict(
    packet=files,
    uniflow=build_filenames("uniflow"),
    biflow=build_filenames("biflow"),
)

# Directory the plot is written to; created up front if it does not exist.
OUTPUT_DIR = "outputs_scatter"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Rows per chunk when streaming CSVs through pandas.read_csv.
CHUNK_SIZE = 100_000
RANDOM_SEED = 42

# Plotting cap: collections larger than this are randomly downsampled to this
# many points. Set to None to attempt plotting everything (dangerous if the
# dataset is huge).
SAFETY_MAX = 3_000_000

# Seed both RNGs the script draws from (column choice and downsampling).
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# ------------------------------
# Helpers
# ------------------------------
def discover_numeric_columns(feature_files_map, base_path, chunksize=CHUNK_SIZE, max_files_first_chunks=50):
    """
    Build the union of numeric column names seen at the head of each CSV file.

    Only the first ``chunksize`` rows of every existing file are read, and dtype
    detection is based on those rows alone. Columns named "label", "attack" or
    "attack_type" (case-insensitive) are excluded because they are labels, not
    features. Files that are missing are skipped silently; files that cannot be
    read or inspected are skipped with a warning.

    Args:
        feature_files_map: mapping of feature level -> {capture name -> filename}.
            Each level must also be a key of the module-level ``folders`` mapping.
        base_path: root directory that contains the per-level folders.
        chunksize: number of leading rows to read from each file.
        max_files_first_chunks: stop after this many files were inspected successfully.

    Returns:
        Sorted list of the numeric column names observed across the inspected files.
    """
    excluded_names = ("label", "attack", "attack_type")
    numeric_cols_set = set()
    files_inspected = 0

    for level, file_dict in feature_files_map.items():
        folder_path = os.path.join(base_path, folders[level])
        for fname in file_dict.values():
            fpath = os.path.join(folder_path, fname)
            if not os.path.exists(fpath):
                continue
            try:
                # nrows reads just the head and closes the file immediately; a
                # chunked reader abandoned after one chunk would keep it open.
                head = pd.read_csv(fpath, nrows=chunksize, low_memory=False)
                numeric_names = head.select_dtypes(include=[np.number]).columns.tolist()
                feature_names = [c for c in numeric_names if c.lower() not in excluded_names]
            except Exception as e:
                # Best-effort scan: one unreadable file must not abort discovery.
                print(f"[WARN] Could not read first chunk of {fpath}: {e}")
                continue
            numeric_cols_set.update(feature_names)
            files_inspected += 1
            if files_inspected >= max_files_first_chunks:
                break
        if files_inspected >= max_files_first_chunks:
            break

    return sorted(numeric_cols_set)

def pick_two_random_numeric_columns(numeric_cols):
    """Draw two distinct column names using the module-level ``random`` generator.

    Returns ``(None, None)`` when fewer than two candidates are available.
    """
    if len(numeric_cols) >= 2:
        first, second = random.sample(numeric_cols, 2)
        return first, second
    return None, None

def stream_collect_all(feature_files_map, base_path, x_col, y_col, chunksize=CHUNK_SIZE, safety_max=SAFETY_MAX):
    """
    Stream all CSVs chunk-by-chunk and collect two feature columns plus a 0/1 label.

    Every row of a file whose capture name is "normal" is labelled 0; rows of all
    other files are labelled 1. Rows with a missing value in either selected column
    are dropped. When a chunk lacks ``x_col`` or ``y_col`` the chunk's first two
    numeric (non-label) columns are used instead; this is reported once per file
    because those points then do not correspond to the requested features.

    Collection stops as soon as more than ``5 * safety_max`` points were gathered,
    so files later in the iteration order may not be read at all. The collected
    points are then randomly downsampled (NumPy global RNG, without replacement)
    to ``safety_max``.

    Args:
        feature_files_map: mapping of feature level -> {capture name -> filename}.
            Each level must also be a key of the module-level ``folders`` mapping.
        base_path: root directory that contains the per-level folders.
        x_col, y_col: names of the columns to collect.
        chunksize: rows per chunk passed to ``pandas.read_csv``.
        safety_max: maximum number of points returned; ``None`` disables both the
            early stop and the downsampling.

    Returns:
        ``(X, y)`` where ``X`` is a float64 array of shape (n, 2) and ``y`` is an
        int32 array of shape (n,).

    Raises:
        RuntimeError: if no rows could be collected from any file.
    """
    # Per-chunk NumPy buffers: far smaller than Python lists of boxed floats.
    x_parts, y_parts, label_parts = [], [], []
    total = 0
    hard_limit = None if safety_max is None else safety_max * 5
    limit_hit = False

    for level, file_dict in feature_files_map.items():
        folder_path = os.path.join(base_path, folders[level])
        for name, fname in file_dict.items():
            fpath = os.path.join(folder_path, fname)
            if not os.path.exists(fpath):
                continue
            label_value = 0 if name == "normal" else 1
            warned_fallback = False
            try:
                reader = pd.read_csv(fpath, chunksize=chunksize, low_memory=False)
                try:
                    for chunk in reader:
                        cols = _resolve_plot_columns(chunk, x_col, y_col)
                        if cols is None:
                            continue  # skip chunk if no usable numeric features
                        if cols != (x_col, y_col) and not warned_fallback:
                            print(f"[WARN] {fpath}: missing '{x_col}' and/or '{y_col}'; "
                                  f"substituting '{cols[0]}' / '{cols[1]}' "
                                  f"(these points will not match the axis labels).")
                            warned_fallback = True

                        sub = chunk[list(cols)].dropna()
                        if len(sub) == 0:
                            continue

                        # Convert both columns before storing either, so a failed
                        # conversion cannot leave the x/y buffers out of step.
                        x_vals = sub.iloc[:, 0].astype(np.float64).to_numpy()
                        y_vals = sub.iloc[:, 1].astype(np.float64).to_numpy()
                        x_parts.append(x_vals)
                        y_parts.append(y_vals)
                        label_parts.append(np.full(len(sub), label_value, dtype=np.int32))
                        total += len(sub)

                        # Stop early rather than buffering an enormous dataset that
                        # would be downsampled anyway.
                        if hard_limit is not None and total > hard_limit:
                            print(f"[WARN] Collected {total} points; exceeding safe accumulation multiple. Stopping collection early.")
                            limit_hit = True
                            break
                finally:
                    # Release the file handle even on early stop or error.
                    reader.close()
            except pd.errors.EmptyDataError:
                print(f"[WARN] {fpath} is empty; skipping.")
            except Exception as e:
                # Best-effort: keep what was collected and move on to the next file.
                print(f"[WARN] Error processing {fpath}: {e}")
            if limit_hit:
                break
        if limit_hit:
            break

    if total == 0:
        raise RuntimeError("No numeric data collected. Check CSVs and numeric columns.")

    X = np.column_stack((np.concatenate(x_parts), np.concatenate(y_parts)))
    y = np.concatenate(label_parts)
    print(f"[COLLECT] Collected {X.shape[0]} points in total across files.")

    # Downsample to safety_max if needed
    if safety_max is not None and X.shape[0] > safety_max:
        print(f"[DOWNSAMPLE] Downsampling from {X.shape[0]} to {safety_max} points for safety.")
        idxs = np.random.choice(X.shape[0], size=safety_max, replace=False)
        X = X[idxs]
        y = y[idxs]

    return X, y


def _resolve_plot_columns(chunk, x_col, y_col):
    """Return the (x, y) column names to read from ``chunk``, or None if it has no usable pair."""
    if x_col in chunk.columns and y_col in chunk.columns:
        return x_col, y_col
    # Requested columns absent: fall back to the first two numeric, non-label columns.
    candidates = [
        c for c in chunk.select_dtypes(include=[np.number]).columns.tolist()
        if c.lower() not in ("label", "attack", "attack_type")
    ]
    if len(candidates) >= 2:
        return candidates[0], candidates[1]
    return None

def plot_scatter_full(X, y, x_label, y_label, out_path):
    """Draw a two-class scatter plot and save it to ``out_path``.

    Rows of ``X`` whose label in ``y`` equals 0 are drawn first in blue
    ("No Attack"); rows labelled 1 are drawn on top in red ("Attack") for
    visibility. Rows carrying any other label are not drawn.

    Args:
        X: array indexed as ``X[mask, 0]`` (x values) and ``X[mask, 1]`` (y values).
        y: label array compared element-wise against 0 and 1.
        x_label: x-axis label, also used in the title.
        y_label: y-axis label, also used in the title.
        out_path: destination handed to ``savefig`` (saved at 200 dpi).
    """
    no_attack = (y == 0)
    attack = (y == 1)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter(X[no_attack, 0], X[no_attack, 1], c="tab:blue", label="No Attack", alpha=0.6, s=6)
        ax.scatter(X[attack, 0], X[attack, 1], c="tab:red", label="Attack", alpha=0.6, s=6)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(f"Scatter plot: {y_label} vs {x_label} (Attack=red, No Attack=blue)")
        # Fixed position: the default loc="best" searches for free space among
        # all points, which is slow with large amounts of data.
        ax.legend(loc="upper right")
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        # Free the figure even if rendering or saving fails.
        plt.close(fig)
    print(f"[SAVED] Scatter plot saved to: {out_path}")

def cleanup_temp_files(base_dirs=None):
    """Remove ``*.tmp`` and ``*.temp`` files found directly inside the given directories.

    Only regular files at the top level of each directory are considered;
    sub-directories are neither removed nor descended into. Files that cannot be
    removed (for example because another process holds them open) are reported
    and skipped.

    NOTE(review): with the default directories this deletes matching files in the
    system temp dir that may belong to other programs -- confirm this is intended.

    Args:
        base_dirs: iterable of directory paths to clean. Defaults to the system
            temp directory and the module-level ``base_path``.
    """
    if base_dirs is None:
        base_dirs = [tempfile.gettempdir(), base_path]
    # Suffix patterns only: a "tmp*" prefix pattern would also match unrelated
    # files such as "tmp_results.csv" in the working directory.
    patterns = ("*.tmp", "*.temp")
    removed = 0
    for directory in base_dirs:
        # Escape the directory part so brackets etc. in the path are literal.
        escaped_dir = glob.escape(directory)
        for pattern in patterns:
            for candidate in glob.glob(os.path.join(escaped_dir, pattern)):
                if not os.path.isfile(candidate):
                    continue
                try:
                    os.remove(candidate)
                except OSError as e:
                    print(f"[CLEANUP] Could not remove {candidate}: {e}")
                else:
                    removed += 1
                    print(f"[CLEANUP] Removed temp file: {candidate}")
    if removed == 0:
        print("[CLEANUP] No temp files found to remove.")
    else:
        print(f"[CLEANUP] Removed {removed} temp files.")

# ------------------------------
# Main
# ------------------------------
if __name__ == "__main__":
    print("== Start: Full scatter plot (attack vs no-attack) ==")

    # Step 1: gather candidate feature columns from the head of each CSV.
    numeric_cols = discover_numeric_columns(feature_files, base_path, chunksize=CHUNK_SIZE)
    if len(numeric_cols) < 2:
        print("[ERROR] Not enough numeric columns discovered across dataset to make a scatter plot.")
        raise SystemExit(1)
    print(f"[DISCOVERED] Numeric columns (sample): {numeric_cols[:20]} (total {len(numeric_cols)})")

    # Step 2: choose the two axes at random (seeded).
    x_col, y_col = pick_two_random_numeric_columns(numeric_cols)
    if None in (x_col, y_col):
        print("[ERROR] Unable to pick two numeric columns.")
        raise SystemExit(1)
    print(f"[SELECTED] Random numeric columns chosen for plotting: X='{x_col}', Y='{y_col}'")

    # Step 3: stream the CSVs and collect the chosen columns plus labels,
    # capped by SAFETY_MAX.
    X_all, y_all = stream_collect_all(
        feature_files,
        base_path,
        x_col,
        y_col,
        chunksize=CHUNK_SIZE,
        safety_max=SAFETY_MAX,
    )

    # Step 4: render the scatter plot into the output directory.
    out_file = os.path.join(OUTPUT_DIR, "combined_scatter.png")
    plot_scatter_full(X_all, y_all, x_label=x_col, y_label=y_col, out_path=out_file)

    # Step 5: remove temporary files from the system temp dir and base_path.
    cleanup_temp_files([tempfile.gettempdir(), base_path])

    print("== Done ==")

== Start: Full scatter plot (attack vs no-attack) ==
[DISCOVERED] Numeric columns (sample): ['bwd_max_iat', 'bwd_max_pkt_len', 'bwd_mean_iat', 'bwd_mean_pkt_len', 'bwd_min_iat', 'bwd_min_pkt_len', 'bwd_num_bytes', 'bwd_num_pkts', 'bwd_num_psh_flags', 'bwd_num_rst_flags', 'bwd_num_urg_flags', 'bwd_std_iat', 'bwd_std_pkt_len', 'dst_port', 'fwd_max_iat', 'fwd_max_pkt_len', 'fwd_mean_iat', 'fwd_mean_pkt_len', 'fwd_min_iat', 'fwd_min_pkt_len'] (total 69)
[SELECTED] Random numeric columns chosen for plotting: X='fwd_max_iat', Y='bwd_mean_pkt_len'
[WARN] Collected 15099994 points; exceeding safe accumulation multiple. Stopping collection early.
[COLLECT] Collected 15099994 points in total across files.
[DOWNSAMPLE] Downsampling from 15099994 to 3000000 points for safety.


  plt.tight_layout()
  plt.savefig(out_path, dpi=200)


[SAVED] Scatter plot saved to: outputs_scatter\combined_scatter.png
[CLEANUP] Could not remove C:\Users\VALMIK~1\AppData\Local\Temp\0bfaf4d2-a756-4c46-8b84-425ea23f954d.tmp: [WinError 32] The process cannot access the file because it is being used by another process: 'C:\\Users\\VALMIK~1\\AppData\\Local\\Temp\\0bfaf4d2-a756-4c46-8b84-425ea23f954d.tmp'
[CLEANUP] Could not remove C:\Users\VALMIK~1\AppData\Local\Temp\0fcf55a5-762e-4589-b10b-5eca2bef0e65.tmp: [WinError 32] The process cannot access the file because it is being used by another process: 'C:\\Users\\VALMIK~1\\AppData\\Local\\Temp\\0fcf55a5-762e-4589-b10b-5eca2bef0e65.tmp'
[CLEANUP] Could not remove C:\Users\VALMIK~1\AppData\Local\Temp\30804d63-b694-4472-8339-72895e4ea950.tmp: [WinError 32] The process cannot access the file because it is being used by another process: 'C:\\Users\\VALMIK~1\\AppData\\Local\\Temp\\30804d63-b694-4472-8339-72895e4ea950.tmp'
[CLEANUP] Could not remove C:\Users\VALMIK~1\AppData\Local\Temp\4667ddf0-5