# Model Finetuning on Client Traffic after Pruning and Subset Search

In this notebook we investigate the potential effects of the pruning and feature-subset search performed in the previous step. The aim is to verify whether it is possible to empower the lightweight model with the knowledge of the oracle model obtained in the initial Federated Learning training phase.

The main problem is that, by pruning the model and removing features from the network traffic to fit specific deployment scenarios, the organization may have obtained a model that performs worse than the original one.
Moreover, if the organization wants to specialize that model to its own traffic categories, how does the model behave against the other, previously seen attacks?
Is it possible to preserve that knowledge? Will the model suffer from catastrophic forgetting? Can catastrophic forgetting be postponed somehow?

We will analyse different learning algorithms that the organization can adopt during this finetuning process, and measure the loss of the model's global knowledge with respect to all the attacks.

In [1]:
%reload_ext autoreload
%autoreload 2
import pandas as pd
import threadpoolctl
import itertools
import matplotlib as mpl
from copy import deepcopy
from sklearn.metrics import accuracy_score
import os

from intellect.model.torch.model import Mlp
from intellect.model.torch.pruning import globally_unstructured_connections_l1
from intellect.io import load, dump, create_dir
from intellect.inspect import set_seed
from intellect.scoring import compute_metric_percategory
from intellect.dataset import (Dataset, ContinuouLearningAlgorithm, FeatureAvailability,
    portions_from_data, indexes_for_oracle_learning, InputForLearn)

# Restrict native thread pools to a single thread via threadpoolctl; the trailing
# semicolon only stops the notebook from echoing the returned object.
threadpoolctl.threadpool_limits(limits=1);
# Plot resolution and pandas display limits for the tables shown in this notebook.
mpl.rcParams.update({"figure.dpi": 70})
pd.set_option("display.max_rows", 20)
pd.set_option("display.max_columns", 100)

Define parameters and scenarios to be tested.

In [11]:
# ---------------------------------------------------------------------------
# Parameters
# ---------------------------------------------------------------------------

# Input dataset plus the files/directories consumed and produced by this notebook.
DATASET = "./dataset_shrinked.h5"
TRAIN_MODEL = "train_output/oracle.pt"
RANK_DIR = "rank_prune_output/"
OUTPUT_DIR = "refit_output/"

# Client categories, benign labels and dataset split ratios.
# These must be equal to the values used in the previous notebook.
CLIENT_CATEGORIES = ["BENIGN", "DDoS"]
BENIGN_LABELS = ["BENIGN"]
DATASET_PORTIONS = (0.6, 0.1, 0.1, 0.2)

# Target feature-subset size ratios for which this notebook runs the tests.
TARGET_SUBSET_RATIOS = (0.1, 0.3, 0.5, 0.8)

# Tested scenarios; every tuple lists the options expanded into the test grid:
#   o_to_o:  oracle -> oracle; the student/client model is a copy of the oracle.
#   o_to_po: oracle -> pruned oracle; the student/client is a pruned copy of the oracle.
#   o_to_eo: oracle -> edge oracle; the student/client is a copy of the oracle that
#            is provided with only a limited set of features.
#   o_to_ec: oracle -> edge client; the student/client is a pruned copy of the oracle
#            AND is provided with only a limited set of features.
SCENARIOS = {
    "o_to_o": {
        "availability": (FeatureAvailability.bilateral,),
        "learn_input": (InputForLearn.client,),
    },
    "o_to_po": {
        "availability": (FeatureAvailability.bilateral,),
        "learn_input": (InputForLearn.client,),
    },
    "o_to_eo": {
        "availability": (FeatureAvailability.none, FeatureAvailability.oracle),
        "learn_input": (InputForLearn.client, InputForLearn.oracle, InputForLearn.mixed),
    },
    "o_to_ec": {
        "availability": (FeatureAvailability.none, FeatureAvailability.oracle),
        "learn_input": (InputForLearn.client, InputForLearn.oracle, InputForLearn.mixed),
    },
}

# Knowledge-distillation hyperparameter grid. Only one point is enabled; the wider
# grid explored earlier was alpha in (1, 0.9) and temperature in (20, 14, 7, 4).
KD_HYPERPARAMS = {
    "alpha": (1,),
    "temperature": (14,),
}

# Hyperparameters shared by every learning algorithm (also expanded as a grid).
# max_epochs was previously (100,).
# NOTE(review): epochs_wo_improve (presumably an early-stopping patience) is larger
# than max_epochs, so it could never trigger — confirm this is intended.
COMMON_PARAMETERS = {
    "max_epochs": (20,),
    "epochs_wo_improve": (100,),
    "batch_size": (64,),
    "algorithm": (
        ContinuouLearningAlgorithm.ground_truth,
        ContinuouLearningAlgorithm.ground_inferred,
        ContinuouLearningAlgorithm.knowledge_distillation,
    ),
}

Load the dataset, keeping the validation portion to measure the loss of knowledge and the finetune portion, which will act as the new retraining data.

In [3]:
def get_dataset():
    """Load ``DATASET`` and split it according to ``DATASET_PORTIONS``.

    The seed is reset first (via ``set_seed``) so that the split is the same on
    every call. The result of ``portions_from_data`` is returned unchanged; it is
    unpacked below into four portions, one per ratio in ``DATASET_PORTIONS``.
    """
    set_seed()
    portions = portions_from_data(
        DATASET,
        normalize=True,
        benign_labels=BENIGN_LABELS,
        ratios=DATASET_PORTIONS,
    )
    return portions


_, validation, finetune, _ = get_dataset()
# Column order of the result tables: the global score, then the client categories,
# then every other category by decreasing frequency in the finetune portion.
cols = ["Global", *CLIENT_CATEGORIES]
cols += [label for label in finetune._y.value_counts().sort_values(ascending=False).index.values
         if label not in CLIENT_CATEGORIES]

In [17]:
# Module-level caches of the metrics measured BEFORE finetuning, shared across
# run_test calls. Both models are reloaded from TRAIN_MODEL on every call, so these
# metrics only depend on the values combined into the cache keys inside run_test.
oracle_cache = {}
student_cache = {}

def run_test(save_prefix: str, finetune_ds: Dataset, features_available=None,
             skip=False, prune_ratio=None, availability=None, **kwargs):
    """Finetune a student copy of the oracle and record accuracy before and after.

    Oracle and student are both loaded from ``TRAIN_MODEL``; the student is
    optionally pruned and then fitted on ``finetune_ds`` with the oracle passed to
    ``student_net.fit``. Per-category accuracy (``compute_metric_percategory`` with
    ``accuracy_score``) is measured on the global ``validation`` portion and on
    ``finetune_ds``, with the feature columns returned by
    ``indexes_for_oracle_learning`` zeroed out.

    Args:
        save_prefix: path prefix of the outputs ``{save_prefix}.csv``,
            ``{save_prefix}_monitor.csv`` and ``{save_prefix}_history.csv``.
        finetune_ds: dataset the student is finetuned on.
        features_available: features available to the client; ``None`` is treated
            as an empty list. Forwarded to ``indexes_for_oracle_learning``.
        skip: when True and ``{save_prefix}.csv`` already exists, return at once.
        prune_ratio: when not None, the student is pruned with
            ``globally_unstructured_connections_l1`` at this ratio.
        availability: forwarded to ``indexes_for_oracle_learning``.
        **kwargs: forwarded to ``student_net.fit`` (``run_scenario`` passes the
            grid parameters here, e.g. algorithm, epochs, batch size, learn_kwargs).

    Returns:
        A DataFrame with columns ``cols`` and rows "Validation/Finetune Before",
        "Validation/Finetune After" (student) and "Oracle Validation/Finetune";
        ``None`` when the test is skipped.
    """
    # Resume support: skip tests whose result table is already on disk.
    if skip is True and os.path.isfile(f"{save_prefix}.csv"):
        return
    if features_available is None:
        features_available = []
    # Presumably resets all RNGs so each test is reproducible — confirm in intellect.inspect.
    set_seed()
    oracle_net= Mlp.load(TRAIN_MODEL)
    student_net = Mlp.load(TRAIN_MODEL)
    
    if prune_ratio is not None:
        student_net = globally_unstructured_connections_l1(student_net, prune_ratio)

    # idx / idx_oracle: feature columns zeroed out below for the student and the
    # oracle respectively.
    # NOTE(review): the same indexes are passed to fit as `idx_active_features*`;
    # confirm whether they denote removed or kept features.
    idx, idx_oracle = indexes_for_oracle_learning(finetune_ds, features_available, availability)
    
    # NOTE(review): hash(finetune_ds) is identity-based unless Dataset defines
    # __hash__, so a new dataset object reusing a freed id could hit a stale entry;
    # also str() of a numpy array longer than its print threshold is abbreviated
    # with "..." — confirm both are safe for the datasets/masks used here.
    key_oracle = (hash(finetune_ds), hash(str(idx_oracle)))
    key_student = (hash(finetune_ds), hash(str(idx)), hash(prune_ratio))

    if key_oracle not in oracle_cache:
        # Oracle accuracy on copies of both datasets with the oracle's masked columns zeroed.
        oracle_tmp_val = validation.clone()
        oracle_tmp_test = finetune_ds.clone()
        oracle_tmp_val.X.iloc[:, idx_oracle] = 0.
        oracle_tmp_test.X.iloc[:, idx_oracle] = 0.
        oracle_cache[key_oracle] = {
            "validation": compute_metric_percategory(oracle_tmp_val.y, oracle_net.predict(oracle_tmp_val.X), oracle_tmp_val._y, scorer=accuracy_score),
            "finetune": compute_metric_percategory(oracle_tmp_test.y, oracle_net.predict(oracle_tmp_test.X), oracle_tmp_test._y, scorer=accuracy_score)}

    # Student's view of the data (its masked columns zeroed); tmp_val is also the
    # monitoring set passed to fit.
    tmp_val = validation.clone()
    tmp_test = finetune_ds.clone()
    tmp_val.X.iloc[:, idx] = 0.
    tmp_test.X.iloc[:, idx] = 0.

    if key_student not in student_cache:
        # Student accuracy before finetuning (must run before fit below).
        student_cache[key_student] = {
            "validation": compute_metric_percategory(tmp_val.y, student_net.predict(tmp_val.X), tmp_val._y, scorer=accuracy_score),
            "finetune": compute_metric_percategory(tmp_test.y, student_net.predict(tmp_test.X), tmp_test._y, scorer=accuracy_score)}
    
    # NOTE(review): keyword spelled `monitori_ds`; confirm it matches Mlp.fit's
    # signature (if fit accepts **kwargs, a misspelling would be silently ignored).
    hs, m = student_net.fit(finetune_ds, oracle=oracle_net, idx_active_features=idx, idx_active_features_oracle=idx_oracle,
                            monitori_ds=tmp_val, **kwargs)
    dump(m, f"{save_prefix}_monitor.csv")
    dump(hs, f"{save_prefix}_history.csv")
    
    # Result table: one row per measurement, one column per entry of `cols`.
    df = pd.DataFrame(columns=cols)
    df.loc["Validation Before"] = student_cache[key_student]["validation"]
    df.loc["Validation After"] = compute_metric_percategory(tmp_val.y, student_net.predict(tmp_val.X), tmp_val._y, scorer=accuracy_score)
    df.loc["Finetune Before"] = student_cache[key_student]["finetune"]
    df.loc["Finetune After"] = compute_metric_percategory(tmp_test.y, student_net.predict(tmp_test.X), tmp_test._y, scorer=accuracy_score)
    df.loc["Oracle Validation"] = oracle_cache[key_oracle]["validation"]
    df.loc["Oracle Finetune"] = oracle_cache[key_oracle]["finetune"]
    dump(df, f"{save_prefix}.csv")
    return df

In [5]:
def run_scenario(scenario_dict, dirname, categories=None, features_available=None, skip=False, **kwargs):
    """Expand one scenario into its parameter grid and call ``run_test`` on each point.

    The grid is the Cartesian product of the option tuples in ``scenario_dict``
    followed by those in ``COMMON_PARAMETERS``. Points using the knowledge
    distillation algorithm are further expanded over ``KD_HYPERPARAMS``, passed to
    ``run_test`` as ``learn_kwargs``. Each output path is ``OUTPUT_DIR + dirname``
    followed by ``key_value`` parts (joined by ``-``) for every parameter.

    Args:
        scenario_dict: parameter name -> tuple of options (see ``SCENARIOS``).
        dirname: sub-path of ``OUTPUT_DIR`` (with trailing slash) for the outputs.
        categories: when non-empty, the finetune portion is filtered to these
            categories and balanced before use.
        features_available: feature subset given to the client; when None or
            empty, every point is forced to ``FeatureAvailability.bilateral``.
        skip: forwarded to ``run_test``; when False the output directory is
            passed to ``create_dir`` unconditionally.
        **kwargs: forwarded to ``run_test`` (e.g. ``prune_ratio``).
    """
    categories = [] if categories is None else categories
    out_dir = OUTPUT_DIR + dirname
    if skip is False or not os.path.isdir(out_dir):
        create_dir(out_dir)

    data = finetune
    if categories:
        set_seed()
        data = finetune.filter_categories(categories).balance_categories()

    # Scenario options first, common ones after: this order fixes both the grid
    # iteration order and the order of the parts in the output file names.
    grid = {**deepcopy(scenario_dict), **COMMON_PARAMETERS}
    restricted = features_available is not None and len(features_available) > 0

    for values in itertools.product(*grid.values()):
        combo = dict(zip(grid.keys(), values))
        # Enum members are compared by .value/.name rather than identity, presumably
        # so the checks keep working after %autoreload re-creates the enum classes.
        # Without features, only client-side learning input is meaningful.
        if (combo["availability"].value == FeatureAvailability.none.value
                and combo["learn_input"].value != InputForLearn.client.value):
            continue
        if not restricted:
            combo["availability"] = FeatureAvailability.bilateral

        if combo["algorithm"].name != ContinuouLearningAlgorithm.knowledge_distillation.name:
            runs = [combo]
        else:
            runs = [
                {**combo, "learn_kwargs": dict(zip(KD_HYPERPARAMS.keys(), kd_values))}
                for kd_values in itertools.product(*KD_HYPERPARAMS.values())
            ]
        print("----Running", len(runs), "combinations for", combo["algorithm"].name, "and", combo["availability"], "and", combo["learn_input"])

        for run in runs:
            suffix = "-".join(f"{key}_{getattr(val, 'name', val)}" for key, val in run.items())
            run_test(out_dir + suffix, data, **kwargs, features_available=features_available, skip=skip, **run)

In [6]:
# Create the root output directory for all refit results (what happens when it
# already exists depends on intellect.io.create_dir, which is not visible here).
create_dir(OUTPUT_DIR)

## Test Refit Algorithms for Different Use Cases

Oracle to Oracle.

In [None]:
# Oracle -> oracle: no pruning and no feature subset (run_scenario then forces
# bilateral availability for every combination).
run_scenario(SCENARIOS["o_to_o"], "o_to_o_few_c/", categories=CLIENT_CATEGORIES, features_available=None, prune_ratio=None, skip=True)

Oracle to Pruned Oracle. Select the highest accepted prune ratio, i.e. the one with the largest performance degradation that is still accepted.

In [None]:
res = load(RANK_DIR + "traffic_few_c_pruning_ratios_only.csv", index_col=0)
col = res["Prune Ratio"]
# Take the LAST row whose prune ratio equals the maximum. Select it by position:
# the index comes from the CSV's first column (index_col=0), so feeding one of its
# labels to .iloc would only pick the right row when labels equal row positions.
row = res.iloc[(col == col.max()).to_numpy().nonzero()[0][-1]]
print("Highest pruning ratio to preserve the accuracy is", row["Prune Ratio"], f"(Accuracy of {row['Accuracy']})")
run_scenario(SCENARIOS["o_to_po"], "o_to_po_few_c/", categories=CLIENT_CATEGORIES,
             features_available=None, prune_ratio=row["Prune Ratio"], skip=True)

Oracle to Edge Oracle. Select the worst-performing accepted feature subset.

In [None]:
for s in TARGET_SUBSET_RATIOS:
    res = load(
        RANK_DIR + f"rank_few_c_traffic_few_c_subsets_features_for_subsetsize_{s}.csv",
        index_col=0,
    )
    # The last row is taken as the worst accepted subset (presumably rows are ranked best-first).
    worst_row = res.iloc[-1]
    # Selected features = non-null entries of that row, excluding the score column.
    subset = worst_row.loc[worst_row.notna() & ~worst_row.index.isin(["Accuracy"])].index.values
    print("For a subset size of ratio", s, "the worst performant subset identified achieved an Accuracy of", worst_row["Accuracy"])
    run_scenario(SCENARIOS["o_to_eo"], f"o_to_eo_few_c_subset_{s}/", categories=CLIENT_CATEGORIES,
                 features_available=subset, prune_ratio=None, skip=True)

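Oracle to Edge Client. Select the highest accepted prune ratio (the largest performance degradation that is still accepted) and, among the subsets reaching it, the worst-performing one.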

In [None]:
for s in TARGET_SUBSET_RATIOS:
    res = load(
        RANK_DIR + f"rank_few_c_traffic_few_c_combo_pruned_models_subsets_features_for_subsetsize_{s}.csv",
        index_col=0,
    )
    col = res["Prune Ratio"]
    # Label of the last row that reaches the highest accepted prune ratio.
    idx = col[col.eq(col.max())].index.values[-1]
    worst_row = res.loc[idx]
    prune_ratio = worst_row["Prune Ratio"]
    # Selected features = non-null entries, excluding the score and prune-ratio columns.
    subset = worst_row.loc[worst_row.notna() & ~worst_row.index.isin(["Accuracy", "Prune Ratio"])].index.values
    print("For a subset size of ratio", s, "the worst performant subset identified has a maximum accepted prune ratio of",
          prune_ratio, f"(Accuracy of {worst_row['Accuracy']})")
    run_scenario(SCENARIOS["o_to_ec"], f"o_to_ec_few_c_subset_{s}/", categories=CLIENT_CATEGORIES,
                 features_available=subset, prune_ratio=prune_ratio, skip=True)

For further tests, results are also available for the model whose feature ranking was performed on all traffic categories, and for the setting where the stochastic search is performed before pruning (there, the oracle is expected to always outperform the pruned model).