In [None]:
# imports
import numpy as np
import torch
from torch.utils.data import DataLoader

from data.dataset import IntrusionData

# Parameters

In [None]:
# Central configuration cell: every tunable value used by the later cells lives here.

# data set
test_ratio = 0.5           # forwarded to dataset.split_train_test(test_ratio=...)
contamination_rate = 0.0   # forwarded to dataset.split_train_test(contamination_rate=...)

# general
ckpt_root = '../tmp'       # handed to the trainer as ckpt_root
ckpt_file = None           # set to a checkpoint path to restore a saved trainer/model instead

benign_label = 0           # label value counted as "normal" in the split statistics
anomaly_label = 1          # label value counted as "malicious"; also passed as pos_label / anomaly_label
batch_size = 8192          # shared by both data loaders and the trainer
lr = 1e-4
weight_decay = 1e-4
epochs = 5                 # passed to the trainer as n_epochs
# Prefer the GPU, but fall back to the CPU so the notebook still runs end-to-end
# on machines without a usable CUDA device instead of failing at model creation.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# AE params (all forwarded unchanged to the AutoEncoder constructor)
latent_dim = 2
act_fn = 'relu'
n_layers = 4
compression_factor = 2
reg = 0.5

# Load Data

In [None]:
# load data set
# Reads the pre-processed CSV through the project's IntrusionData wrapper.
path_to_dataset = '../data/processed/unswnb15.csv'
dataset = IntrusionData(path_to_dataset)
dataset.load_data()

# Split into train/test subsets; the ratio, contamination rate and positive
# (anomaly) label all come from the configuration cell.
train_set, test_set = dataset.split_train_test(
    test_ratio=test_ratio,
    contamination_rate=contamination_rate,
    pos_label=anomaly_label,
)

# Training batches are reshuffled every epoch.
train_ldr = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not need a random order; keeping the test loader unshuffled
# makes every test pass iterate the samples in the same, repeatable sequence.
test_ldr = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

In [None]:
# Sanity check of the split: list the columns, then report the size and class
# balance of the train and test subsets.
print('columns: ', dataset.columns)

# Index the full dataset with each subset's indices to get (features, labels).
x_train, y_train = dataset[train_set.indices]
x_test, y_test = dataset[test_set.indices]

split_labels = (('train set: ', y_train), ('test set: ', y_test))
for position, (heading, labels) in enumerate(split_labels):
    if position > 0:
        # Separator between the two reports.
        print("\n")
    print(heading)
    print('samples: ', len(labels))
    # Counting the True entries of the comparison mask yields the number of
    # samples carrying the given label.
    print('normal samples: ', np.count_nonzero(labels == benign_label))
    print('malicious samples: ', np.count_nonzero(labels == anomaly_label))

# DAE Training

In [None]:
from baselines.model.reconstruction import AutoEncoder
from baselines.trainer.reconstruction import AutoEncoderTrainer

# Model: the input width and instance count are read from the loaded dataset,
# everything else comes from the "AE params" section of the configuration cell.
model_kwargs = dict(
    device=device,
    in_features=dataset.in_features,
    n_instances=dataset.n_instances,
    latent_dim=latent_dim,
    act_fn=act_fn,
    n_layers=n_layers,
    compression_factor=compression_factor,
    reg=reg,
)
model = AutoEncoder(**model_kwargs)

# Trainer: wraps the model together with the optimisation settings. The test
# loader is handed over at construction time
# (presumably for evaluation during training; confirm in AutoEncoderTrainer).
trainer_kwargs = dict(
    model=model,
    batch_size=batch_size,
    lr=lr,
    weight_decay=weight_decay,
    n_epochs=epochs,
    device=device,
    anomaly_label=anomaly_label,
    test_ldr=test_ldr,
    ckpt_root=ckpt_root,
)
trainer = AutoEncoderTrainer(**trainer_kwargs)

In [None]:
# Optional resume: when a checkpoint path is configured, replace the freshly
# built trainer and model with the ones restored from that file.
resume_from_checkpoint = bool(ckpt_file)
if resume_from_checkpoint:
    # Imported here as well so this cell works on its own.
    from baselines.trainer.reconstruction import AutoEncoderTrainer
    restored = AutoEncoderTrainer.load_from_file(ckpt_file, device=device)
    trainer, model = restored
    # Re-attach the current test loader to the restored trainer.
    trainer.test_ldr = test_ldr

In [None]:
# train model
# Runs the trainer on the batches yielded by the (shuffled) training loader.
# NOTE(review): number of epochs, optimiser settings and checkpoint location
# were fixed when the trainer was constructed; confirm against AutoEncoderTrainer.
trainer.train(train_ldr)

In [None]:
# Plot the metrics recorded by the trainer; the target path is passed through
# (presumably the figure is written there; confirm in AutoEncoderTrainer.plot_metrics).
metrics_figure_path = '../reports/figures/DAE_NB15_KDD.png'
trainer.plot_metrics(metrics_figure_path)