In [None]:
import os, sys, warnings
from pathlib import Path

# Silence UserWarning / FutureWarning so the experiment output below stays readable.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Put the directory one level above the current working directory on the
# import path so that the `src.*` imports used by this notebook resolve.
# NOTE(review): this assumes the kernel's working directory sits directly
# under the project root — confirm for your launch setup.
parent_dir = Path().resolve().parents[0]

# Guard the append: re-running this cell must not keep adding duplicate
# entries to sys.path.
if str(parent_dir) not in sys.path:
    sys.path.append(str(parent_dir))

In [None]:
from src.utils.utils import load_data, make_synthetic_data, train_and_evaluate_model
from src.gan.models import NetworkAnomalyDetector, DeeperNetworkAnomalyDetector
from src.gan.data_augmentors import CTGANAugmentor, TVAEAugmentor

# UNSW-NB15 Dataset

## Binary

In [None]:
# --- Experiment configuration: UNSW-NB15, binary target ----------------------
dataset = "UNSW-NB15"        # author's note: "UNSW-NB15" is smaller, "NF-ToN-IoT-v1" is larger
target_variable = 'Label'    # "Label" for binary, "Attack" for multi class
model_type = DeeperNetworkAnomalyDetector   # alternative: NetworkAnomalyDetector

# Both generative augmentors are configured with the same minority threshold.
# NOTE(review): the exact meaning of `minority_threshold` lives in the augmentor
# classes (presumably the class size below which samples are synthesised) — confirm.
augmentor_ctgan = CTGANAugmentor()
augmentor_tvae = TVAEAugmentor()
augmentor_ctgan.minority_threshold = augmentor_tvae.minority_threshold = 100_000

# The leading None is passed straight through as `gan_augmentor=None`
# (presumably the no-augmentation baseline — confirm in make_synthetic_data).
augmentors = [None, augmentor_ctgan, augmentor_tvae]

X_train, X_test, y_train, y_test = load_data(dataset, target_variable)

# One full prepare -> train -> evaluate run per augmentor.
# The `_tvae` suffix on the names below is historical: they are overwritten on
# every iteration and hold the data/model of whichever augmentor ran last.
for augmentor in augmentors:
    (X_train_tvae, y_train_tvae,
     X_test_tvae, y_test_tvae, le_target) = make_synthetic_data(
        X_train, y_train, X_test, y_test, target_variable, gan_augmentor=augmentor
    )
    model_tvae = train_and_evaluate_model(
        X_train_tvae, y_train_tvae, X_test_tvae, y_test_tvae, le_target, model=model_type
    )

## Multiclass

In [None]:
# --- Experiment configuration: UNSW-NB15, multi-class target -----------------
dataset = "UNSW-NB15"        # author's note: "UNSW-NB15" is smaller, "NF-ToN-IoT-v1" is larger
target_variable = 'Attack'   # "Label" for binary, "Attack" for multi class
model_type = DeeperNetworkAnomalyDetector   # alternative: NetworkAnomalyDetector

# Both generative augmentors are configured with the same minority threshold.
# NOTE(review): the exact meaning of `minority_threshold` lives in the augmentor
# classes (presumably the class size below which samples are synthesised) — confirm.
augmentor_ctgan = CTGANAugmentor()
augmentor_tvae = TVAEAugmentor()
augmentor_ctgan.minority_threshold = augmentor_tvae.minority_threshold = 15_000

# The leading None is passed straight through as `gan_augmentor=None`
# (presumably the no-augmentation baseline — confirm in make_synthetic_data).
augmentors = [None, augmentor_ctgan, augmentor_tvae]

X_train, X_test, y_train, y_test = load_data(dataset, target_variable)

# One full prepare -> train -> evaluate run per augmentor.
# The `_tvae` suffix on the names below is historical: they are overwritten on
# every iteration and hold the data/model of whichever augmentor ran last.
for augmentor in augmentors:
    (X_train_tvae, y_train_tvae,
     X_test_tvae, y_test_tvae, le_target) = make_synthetic_data(
        X_train, y_train, X_test, y_test, target_variable, gan_augmentor=augmentor
    )
    model_tvae = train_and_evaluate_model(
        X_train_tvae, y_train_tvae, X_test_tvae, y_test_tvae, le_target, model=model_type
    )

# NF-ToN-IoT-v1 Dataset

## Binary

In [None]:
# --- Experiment configuration: NF-ToN-IoT-v1, binary target ------------------
# The dataset key is spelled "NF-ToN-IoT-v1" to match the key used by the
# multi-class cell for the same dataset and the options listed in the comment.
# NOTE(review): load_data is defined outside this notebook — confirm that
# "NF-ToN-IoT-v1" is the key it expects.
dataset = "NF-ToN-IoT-v1"    # author's note: "UNSW-NB15" is smaller, "NF-ToN-IoT-v1" is larger
target_variable = 'Label'    # "Label" for binary, "Attack" for multi class
model_type = DeeperNetworkAnomalyDetector   # alternative: NetworkAnomalyDetector

# Both generative augmentors are configured with the same minority threshold.
# NOTE(review): the exact meaning of `minority_threshold` lives in the augmentor
# classes (presumably the class size below which samples are synthesised) — confirm.
augmentor_ctgan = CTGANAugmentor()
augmentor_ctgan.minority_threshold = 250000

augmentor_tvae = TVAEAugmentor()
augmentor_tvae.minority_threshold = 250000

# The leading None is passed straight through as `gan_augmentor=None`
# (presumably the no-augmentation baseline — confirm in make_synthetic_data).
augmentors = [None, augmentor_ctgan, augmentor_tvae]

X_train, X_test, y_train, y_test = load_data(dataset, target_variable)

# One full prepare -> train -> evaluate run per augmentor.
# The `_tvae` suffix on the names below is historical: they are overwritten on
# every iteration and hold the data/model of whichever augmentor ran last.
for augmentor in augmentors:
    X_train_tvae, y_train_tvae, X_test_tvae, y_test_tvae, le_target = make_synthetic_data(
        X_train,
        y_train,
        X_test,
        y_test,
        target_variable,
        gan_augmentor=augmentor,
    )

    model_tvae = train_and_evaluate_model(
        X_train_tvae,
        y_train_tvae,
        X_test_tvae,
        y_test_tvae,
        le_target,
        model=model_type,
    )

## Multiclass

In [None]:
# --- Experiment configuration: NF-ToN-IoT-v1, multi-class target -------------
dataset = "NF-ToN-IoT-v1"    # author's note: "UNSW-NB15" is smaller, "NF-ToN-IoT-v1" is larger
target_variable = 'Attack'   # "Label" for binary, "Attack" for multi class
model_type = DeeperNetworkAnomalyDetector   # alternative: NetworkAnomalyDetector

# Both generative augmentors are configured with the same minority threshold.
# NOTE(review): the exact meaning of `minority_threshold` lives in the augmentor
# classes (presumably the class size below which samples are synthesised) — confirm.
augmentor_ctgan = CTGANAugmentor()
augmentor_tvae = TVAEAugmentor()
augmentor_ctgan.minority_threshold = augmentor_tvae.minority_threshold = 50_000

# The leading None is passed straight through as `gan_augmentor=None`
# (presumably the no-augmentation baseline — confirm in make_synthetic_data).
augmentors = [None, augmentor_ctgan, augmentor_tvae]

X_train, X_test, y_train, y_test = load_data(dataset, target_variable)

# One full prepare -> train -> evaluate run per augmentor.
# The `_tvae` suffix on the names below is historical: they are overwritten on
# every iteration and hold the data/model of whichever augmentor ran last.
for augmentor in augmentors:
    (X_train_tvae, y_train_tvae,
     X_test_tvae, y_test_tvae, le_target) = make_synthetic_data(
        X_train, y_train, X_test, y_test, target_variable, gan_augmentor=augmentor
    )
    model_tvae = train_and_evaluate_model(
        X_train_tvae, y_train_tvae, X_test_tvae, y_test_tvae, le_target, model=model_type
    )