In [None]:
# imports
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from data.augmentation import MultiViewDataInjector
from data.dataset import MultiViewIntrusionData, MultiViewIntrusionDataTransformer

# Parameters

In [None]:
# data set
test_ratio = 0.5  # passed to split_train_test as test_ratio
# passed to split_train_test; presumably the fraction of anomalies left in the
# training split (0.0 = clean training set) — TODO confirm
contamination_rate = 0.0

# general
ckpt_root = '../tmp'  # directory handed to the trainer as ckpt_root
ckpt_file = None  # set to a checkpoint path to restore trainer and model from file
seed = 0  # seeds the python, numpy and torch RNGs so runs are reproducible

benign_label = 0
anomaly_label = 1
batch_size = 512
lr = 1e-4
weight_decay = 1e-4
epochs = 1
# fall back to the CPU so the notebook still runs on machines without a GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG on all devices, CUDA included

# ssl params
encoder_class = 'CNN'
encoder_args = {}  # extra encoder kwargs; filled in by the data / model cells
mlp_params = {
    'embedding_dim': 256,
    'output_dim': 256,
    'n_layers': 2,
    'batch_norm': True,
    'dropout': False,
}

# augmentation params
# presumably one augmentation list per view (two views here) — TODO confirm
transformations = [[{'ShuffleSwapNoise': {'p': 0.4}}], [{'ShuffleSwapNoise': {'p': 0.4}}]]
mixup_alpha = None
n_subsets = 2
overlap = 1.0
# Any setting other than the default (2 subsets, full overlap) switches to
# feature-subset views. In that case the per-view augmentations above are
# discarded and replaced by "no augmentation" for every subset.
subsets = n_subsets != 2 or overlap != 1.0
if subsets:
    transformations = [None] * n_subsets

# Load Data

In [None]:
# define augmentations
# The training injector applies the configured per-view augmentations. The
# test injector is built with no augmentations ([None] per subset) and
# training=False, and is only needed when feature subsets are enabled.
train_transform = MultiViewDataInjector(transformations, n_subsets, overlap, training=True)
test_transform = (
    MultiViewDataInjector([None] * n_subsets, n_subsets, overlap, training=False)
    if subsets
    else None
)

# load data set
path_to_dataset = '../data/processed/unswnb15.csv'
if encoder_class == 'FTTransformerEncoder':
    dataset = MultiViewIntrusionDataTransformer(path_to_dataset, train_transform, test_transform, shuffle_features=subsets)
    encoder_args = {
        'categorical_col_indices': dataset.categorical_cols_idx,
        'categories_unique_values': dataset.unique_cats,
        'numeric_col_indices': dataset.numeric_cols_idx
    }
    # Use the same condition that enables shuffle_features above: once the
    # features are shuffled and split into subsets, the data set's column
    # indices no longer describe a view's columns.
    if subsets:
        # Workaround: mark every column of a view (presumably
        # n_overlap + n_features_subset wide) as numeric.
        # NOTE(review): categorical_col_indices is left unchanged — confirm intended.
        encoder_args['numeric_col_indices'] = range(train_transform.n_overlap + train_transform.n_features_subset)
else:
    dataset = MultiViewIntrusionData(path_to_dataset, train_transform, test_transform, shuffle_features=subsets)

train_set, test_set = dataset.split_train_test(test_ratio=test_ratio, contamination_rate=contamination_rate, pos_label=anomaly_label)
train_ldr = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_ldr = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

In [None]:
def report_split(title, split):
    """Print the size and class balance of one data split.

    Counts entries of `split.labels` equal to `benign_label` and to
    `anomaly_label` (both set in the parameter cell). Assumes `split.labels`
    is array-like so `==` compares element-wise — TODO confirm.
    """
    print(title)
    print('samples: ', len(split))
    print('normal samples: ', np.count_nonzero(split.labels == benign_label))
    print('malicious samples: ', np.count_nonzero(split.labels == anomaly_label))


print('columns: ', dataset.columns)

report_split('train set: ', train_set)
print("\n")
report_split('test set: ', test_set)

# SimSiam Training

In [None]:
from models.ssl.simsiam import SimSiam
from trainer.ssl.simsiam import SimSiam_Trainer

# CNN and MLP encoders are told the input width explicitly. For the
# FTTransformer encoder, encoder_args was already filled with the column
# layout when the data set was loaded.
if encoder_class in ('CNN', 'MLP_Encoder'):
    encoder_args['num_features'] = dataset.in_features

# SimSiam model; encoder-specific keyword arguments are forwarded as-is.
model = SimSiam(
    device=device,
    in_features=dataset.in_features,
    n_instances=dataset.n_instances,
    encoder_class=encoder_class,
    mixup_alpha=mixup_alpha,
    mlp_params=mlp_params,
    **encoder_args
)

# Trainer for the model above; it is deliberately not given a test loader.
trainer = SimSiam_Trainer(
    model=model, batch_size=batch_size,
    lr=lr, weight_decay=weight_decay, n_epochs=epochs,
    device=device, anomaly_label=anomaly_label,
    test_ldr=None, ckpt_root=ckpt_root,
)

In [None]:
# If a checkpoint file is configured, replace the freshly built trainer and
# model with the ones restored from `ckpt_file`. SimSiam_Trainer is already
# imported in the model-construction cell, so it is not re-imported here.
# NOTE(review): the training cell below still runs afterwards, so a restored
# model is trained further — confirm that is intended.
if ckpt_file:
    trainer, model = SimSiam_Trainer.load_from_file(ckpt_file, device=device)

In [None]:
# train model
# Trains `trainer`'s model on the training loader, whose samples go through
# the train-time MultiViewDataInjector built in the data cell. lr,
# weight_decay and n_epochs were passed to the trainer's constructor (or come
# from the checkpoint if one was loaded).
trainer.train(train_ldr)

In [None]:
from models.ssl_evaluation.kmeans import KMeans_Eval

# K-means is fitted on the training set with its augmentations switched off.
# The switch is undone in `finally`, so the training set keeps its
# augmentations afterwards. Without this, re-running the training cell would
# silently train on un-augmented views, and the cell would not be idempotent.
train_set_transform = train_set.transform
train_set_views = train_set_transform.transformations if subsets else None

kmeans = KMeans_Eval(
    encoder=model.get_encoder(),
    batch_size=batch_size,
    device=device
)

try:
    if subsets:
        # keep the injector (feature-subset split) but drop per-view augmentations
        train_set_transform.transformations = [None] * n_subsets
    else:
        train_set.transform = None
    kmeans.fit(train_ldr)
finally:
    if subsets:
        train_set_transform.transformations = train_set_views
    else:
        train_set.transform = train_set_transform

kmeans.validate(test_ldr)