In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import random
import torch
import torch.nn as nn
import anomaly_tpp as tpp

from tqdm.auto import tqdm, trange
from statsmodels.distributions.empirical_distribution import ECDF

sns.set_style("whitegrid")
%matplotlib inline

In [2]:
# Reproducibility: a single seed shared by every RNG this notebook touches
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Experiment settings
batch_size = 64
t_max = 100

In [3]:
# Real-world scenario the whole notebook runs on. The commented-out line is an
# alternative scenario from the same module; swap the two lines to evaluate on it.
# scenario = tpp.scenarios.real_world.STEAD()
scenario = tpp.scenarios.real_world.ServerLogs()

In [4]:
# In-distribution (ID) splits of the scenario: the training split is used to fit the model
# and to estimate the null distribution of each statistic, the test split is held out.
id_train = scenario.id_train
id_test = scenario.id_test

# Shuffled mini-batches of the ID training sequences, consumed by the model-fitting cell
dl_train = id_train.get_dataloader(batch_size=batch_size, shuffle=True)

In [5]:
# Fit a neural TPP model on the training ID sequences.
# torch is re-seeded immediately before fitting; the notebook-wide `seed` is reused
# (rather than repeating the literal) so changing it in the config cell also changes it here.
torch.manual_seed(seed)
ntpp = tpp.utils.fit_ntpp_model(dl_train, num_marks=id_train.num_marks)

  0%|          | 0/100 [00:00<?, ?it/s]

Early stopping at epoch 42


In [6]:
# Test statistics compared in this notebook.
# Each entry is called further down as
#     stat(poisson_times_per_mark=..., model=..., batch=...)
# and its `__name__` is used as the key in the result dictionaries and tables.
test_statistics = [
    tpp.statistics.ks_arrival,
    tpp.statistics.ks_interevent,
    tpp.statistics.chi_squared,
    tpp.statistics.sum_of_squared_spacings,
    tpp.statistics.loglike,
]

### Estimate distribution of each test statistic under $H_0$

In [7]:
# In-distribution (ID) training sequences are used to estimate the CDF of the test statistic under H_0
# (this is then used to compute the p-values).
# All ID training sequences are collated into one batch and passed through the fitted model.
# NOTE(review): presumably `extract_poisson_arrival_times` applies the model's compensator
# (time rescaling) to obtain per-mark arrival times -- confirm in `tpp.utils`.
id_train_batch = tpp.data.Batch.from_list(id_train)
id_train_poisson_times = tpp.utils.extract_poisson_arrival_times(ntpp, id_train_batch)

In [8]:
# Null distribution of every test statistic, keyed by the statistic's name.
# Each statistic is evaluated on the ID training sequences and wrapped in an ECDF,
# which approximates its CDF under H_0 and is later used to compute the p-values.
ecdfs = {
    stat.__name__: ECDF(
        stat(poisson_times_per_mark=id_train_poisson_times, model=ntpp, batch=id_train_batch)
    )
    for stat in test_statistics
}

def twosided_pval(stat_name: str, scores: np.ndarray):
    """Convert values of a test statistic into two-sided p-values under H_0.

    The null CDF is taken from the module-level ``ecdfs`` dictionary, so that
    dictionary must already contain an entry for ``stat_name``.

    Args:
        stat_name: Key into ``ecdfs``, i.e. the ``__name__`` of the test statistic
            (for example "ks_arrival" or "loglike").
        scores: Value of the statistic for each sample in the test set,
            shape [num_test_samples]

    Returns:
        p_vals: Two-sided p-value for each sample in the test set,
            shape [num_test_samples]
    """
    null_cdf = ecdfs[stat_name]
    lower_tail = null_cdf(scores)
    upper_tail = 1 - lower_tail
    # Twice the smaller tail mass: values extreme in either direction get a small p-value
    return 2 * np.minimum(lower_tail, upper_tail)

### Compute test statistic for ID test sequences

In [9]:
# Reference set: the ID test sequences, which are later contrasted with the OOD test
# sequences to evaluate the different test statistics
id_test_batch = tpp.data.Batch.from_list(id_test)
id_test_poisson_times = tpp.utils.extract_poisson_arrival_times(ntpp, id_test_batch)

# Value of every statistic on all ID test sequences, keyed by the statistic's name
id_test_scores = {
    stat.__name__: stat(poisson_times_per_mark=id_test_poisson_times, model=ntpp, batch=id_test_batch)
    for stat in test_statistics
}

### Compute test statistic for OOD test sequences & evaluate AUC ROC based on the p-values

In [10]:
results = []

# The p-values of the ID test sequences depend only on the statistic, not on the
# OOD scenario, so they are computed once here instead of once per scenario.
id_test_pvals = {
    stat_name: twosided_pval(stat_name, id_scores)
    for stat_name, id_scores in id_test_scores.items()
}

for scenario_name, ood_test in scenario.ood_test_datasets.items():
    # Collate the OOD sequences of this scenario and pass them through the fitted model
    ood_test_batch = tpp.data.Batch.from_list(ood_test)
    ood_test_poisson_times = tpp.utils.extract_poisson_arrival_times(ntpp, ood_test_batch)

    for stat in test_statistics:
        stat_name = stat.__name__

        ood_scores = stat(poisson_times_per_mark=ood_test_poisson_times, model=ntpp, batch=ood_test_batch)
        ood_pvals = twosided_pval(stat_name, ood_scores)

        # AUC ROC computed from the ID-test and OOD-test p-values of this statistic
        auc = tpp.utils.roc_auc_from_pvals(id_test_pvals[stat_name], ood_pvals)

        # One record per (scenario, statistic) pair
        results.append({"statistic": stat_name, "auc": auc, "scenario": scenario_name})

In [11]:
# Long-format results: one row per (scenario, statistic) pair with its AUC
df = pd.DataFrame(results)

In [12]:
# AUC ROC in percent: one row per scenario, one column per test statistic
(
    df.groupby(["scenario", "statistic"])
    .mean()
    .round(3)
    .unstack()
    .mul(100)
)

Unnamed: 0_level_0,auc,auc,auc,auc,auc
statistic,chi_squared,ks_arrival,ks_interevent,loglike,sum_of_squared_spacings
scenario,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
Packet corruption (1%),78.0,43.2,73.1,91.3,94.6
Packet corruption (10%),56.0,73.5,99.0,99.0,99.0
Packet delay (all services),98.5,97.3,94.8,95.7,98.7
Packet delay (frontend),98.1,90.8,67.9,99.2,96.4
Packet duplication(1%),27.7,55.1,58.3,81.5,91.0
