In [None]:
import json
import os
import numpy as np
import pandas as pd
from collections import defaultdict
from libraries.utils import *
from libraries.hybrid import hybrid

## Load Data

In [None]:
############ configuration ################
############################################

# Experiment selectors; together they pick the trace directories built below.
CODE = 'mamba2'                     ### application (code)
BEHAVIOUR_FAULTY = 'faulty_data'    ### normal, faulty_data
BEHAVIOUR_NORMAL = 'normal'         ### normal, faulty_data
THREAD = 'single'                   ### single, multi
VER = 3                             ### format of data collection

base_dir = '../trace_data'  ### can be replaced with 'csv', 'exe_plot', 'histogram'

# Both behaviours share the same <code>/<thread>_thread/version_<n> prefix.
_version_dir = f'{base_dir}/{CODE}/{THREAD}_thread/version_{VER}'
normalbase_path = f'{_version_dir}/{BEHAVIOUR_NORMAL}'
faultybase_path = f'{_version_dir}/{BEHAVIOUR_FAULTY}'

print(normalbase_path, faultybase_path, sep='\n')

In [None]:

def _without_ds_store(paths):
    """Return the entries of `paths` that do not contain '.DS_Store', as a new list."""
    return [path for path in paths if '.DS_Store' not in path]


# os.listdir() returns names in arbitrary order. Sorting keeps the training order,
# and the file picked later via train_varlist_path[0], reproducible across runs
# and machines -- the same treatment the test paths below already get.
train_base_path = os.path.join(normalbase_path, 'train_data')
train_data_path = sorted(_without_ds_store(
    os.path.join(train_base_path, name)
    for name in os.listdir(train_base_path)
))
train_varlist_path = sorted(_without_ds_store(
    os.path.join(normalbase_path, name)
    for name in os.listdir(normalbase_path)
    if 'varlist' in name
))

######### get paths #######################
paths_log, paths_traces, varlist_path, paths_label = get_paths(faultybase_path)

### remove .DS_Store entries from all lists and sort them
paths_log = sorted(_without_ds_store(paths_log))
paths_traces = sorted(_without_ds_store(paths_traces))
varlist_path = sorted(_without_ds_store(varlist_path))
paths_label = sorted(_without_ds_store(paths_label))

print(train_data_path)
print(paths_log)
print(paths_traces)
print(varlist_path)
print(paths_label)

# Traces and labels are paired positionally later on, relying on the sorted order above.
test_data_path = paths_traces
test_label_path = paths_label


In [None]:
############# check varlist is consistent ############
############# only for version 3 ######################

# NOTE: to_number / from_number are only assigned when VER == 3.
if VER == 3:
    # The first training varlist is passed first, followed by every faulty-run varlist.
    check_con, _ = is_consistent([train_varlist_path[0], *varlist_path])  ### compare with train varlist

    if check_con != False:
        varlist_file = varlist_path[0]
    else:
        ### load normal varlist
        print('loading normal varlist')
        varlist_file = train_varlist_path[0]

    # Same loading steps for either file: read the mapping, then derive from_number from it.
    to_number = read_json(varlist_file)
    from_number = mapint2var(to_number)



In [None]:
############ Get variable list ######################
# Values of from_number, ordered by ascending key.
sorted_keys = sorted(from_number)
var_list = [from_number[var_id] for var_id in sorted_keys]   ### get the variable list

## Train

In [None]:
### initialize the hybrid model
# The instance is bound to the same name as the imported `hybrid` callable. With a
# plain `hybrid = hybrid()`, re-running this cell would call the previous instance
# instead of constructing a new model. Remember the original callable the first time
# this cell runs and reuse it on every later run, so the cell stays re-runnable.
try:
    _hybrid_factory
except NameError:
    _hybrid_factory = hybrid
hybrid = _hybrid_factory()

In [None]:
# Train the hybrid model on the trace files listed in train_data_path.
hybrid.train(train_data_path)

In [None]:
# Keep a reference to the model's `transitions` attribute and print it raw;
# its keys and values are mapped through from_number in the next cell.
transitions = hybrid.transitions
print(transitions)

In [None]:
### viz transitions

# For every key in `transitions`, print its name (looked up in from_number)
# followed by the names of all entries stored under that key.
for key, linked_keys in transitions.items():
    print(from_number[key], ':', end=' ')
    for val in linked_keys:
        print(from_number[val], end=', ')
    print('\n')

In [None]:
thresholds = hybrid.thresholds
### visualize the thresholds for varlist
# One entry per key: the name looked up in from_number, then the stored threshold value.
for key, threshold_value in thresholds.items():
    print(from_number[key], ':', threshold_value, end=', ')
    print('\n')

### Visualising Thresholds

In [None]:
#### plot exe_list to visualize the distribution of execution intervals
hybrid.viz_thresholds()


### Validation

In [None]:
#### Detect anomalies in faulty traces
DIFF_VAL = 2  # not referenced anywhere else in this cell

# Per-file accumulators; every entry appended below is a
# (trace_path, detections, label_path) tuple.
all_tp = []
all_fp = []
all_detections = []          ### used to plot detections
all_group_detections = []    ### never filled in this cell
all_merged_detections = []   ### never filled in this cell

# Flat prediction / ground-truth labels gathered over all files.
y_pred_all = []
y_true_all = []

for ti, (test_data, test_label) in enumerate(zip(test_data_path, test_label_path)):
    print(ti, test_data, test_label)

    ### run the model on one trace; detection in format: [var, (ts1,ts2), file_name]
    hybrid_detections = hybrid.test_single(test_data, thresholds)
    # Read right after test_single; after the loop these hold the last file's values.
    ei_detection = hybrid.ei_detections
    st_detection = hybrid.st_detections

    all_detections.append((test_data, hybrid_detections, test_label))

    ### load ground truths: take the entry stored under the first key of 'labels'
    ground_truth_raw = read_traces(test_label)
    labels_by_trace = ground_truth_raw['labels']
    label_trace_name = list(labels_by_trace)[0]
    ground_truth = labels_by_trace[label_trace_name]
    print('ground truths:', ground_truth)
    print(len(ground_truth))

    ### split the detections into correct ones and the rest, plus per-sample labels
    correct_pred, rest_pred, y_pred, y_true = hybrid.get_correct_detections(hybrid_detections, ground_truth)

    # Every detection must land in exactly one of the two groups.
    assert len(hybrid_detections) == len(correct_pred) + len(rest_pred)

    all_tp.append((test_data, correct_pred, test_label))
    all_fp.append((test_data, rest_pred, test_label))

    y_pred_all += y_pred
    y_true_all += y_true

In [None]:
# Display the detections left over from the last iteration of the detection loop
# (hybrid_detections is reassigned for every test file).
hybrid_detections

In [None]:
# Design note kept as a bare string literal; it describes the idea that the
# commented-out code following it sketches (merging EI and ST detections).
'''
Approach 1: (based on the assumption that ST can make very precise detections but can miss some detections, while EI can detect every event but gives more wide detections)

for each ei detection detect all the st detection within it, store it in a list
if st detection exist for any ei detection, output only st.
if st detection does not exist for any ei detection, output ei

'''

# all_detections = []
# for ei_det in ei_detection:
#     #### structure of detection, get elements
#     ei_var = ei_det[0]
#     eits1, eits2 = ei_det[1]
#     print('EI', eits1, eits2)

#     #### get all st detections within the ei detection
#     st_detections_within_ei = []
#     for i, st_det in enumerate(st_detection):
#         st_var = st_det[0]
#         stts1, stts2 = st_det[1]
#         print('ST', i, stts1, stts2)

#         if eits1 <= stts1 and eits2 >= stts2:
#             st_detections_within_ei.append(st_det)
            
#     if len(st_detections_within_ei) > 0:
#         for det in st_detections_within_ei:
#             print('Removed', det)
#             st_detection.remove(det)
#         all_detections.extend(st_detections_within_ei)
#         print(ei_det, 'replaced with', st_detections_within_ei)
#     else:
#         all_detections.append([ei_det])
#         print('added', ei_det)

# print('Any Detections in ST that are not in EI:')
# print(st_detection)
# all_detections.extend(st_detection)



In [None]:
# st_detection / ei_detection are reassigned on every pass of the detection loop,
# so the counts and the list shown here belong to the last test file only.
print(len(st_detection), len(ei_detection))
st_detection

In [None]:
### Evaluation metrics

from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score, average_precision_score, ConfusionMatrixDisplay

# The scores below use scikit-learn's binary defaults, i.e. label 1 is the positive class.

# Calculate precision
precision = precision_score(y_true_all, y_pred_all)
print(f'Precision: {precision:.4f}')

# Calculate recall
recall = recall_score(y_true_all, y_pred_all)
print(f'Recall: {recall:.4f}')

# # Calculate average precision
# average_precision = average_precision_score(y_true_all, y_pred_all)
# print(f'Average Precision: {average_precision:.4f}')

# Calculate F1 score
f1 = f1_score(y_true_all, y_pred_all)
print(f"F1 Score: {f1:.4f}")

# Calculate confusion matrix
conf_matrix = confusion_matrix(y_true_all, y_pred_all)
print("Confusion Matrix:")
print(conf_matrix)
if len(conf_matrix) == 1:
    # Only one label occurs across y_true_all and y_pred_all, so scikit-learn returns
    # a 1x1 matrix. Pad it to 2x2 for the ['normal', 'anomaly'] display, putting the
    # count in the cell of the label that is actually present: bottom-right when that
    # label is 1 (the positive class), top-left otherwise. The earlier padding always
    # used bottom-right, which reported an all-normal run as all anomalies.
    sample_count = conf_matrix[0][0]
    if y_true_all[0] == 1:
        conf_matrix = np.array([[0, 0], [0, sample_count]])
    else:
        conf_matrix = np.array([[sample_count, 0], [0, 0]])
disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=['normal', 'anomaly'])
disp.plot()

## Save Detections

In [None]:
######## save detections for the dashboard to plot #############
import traceback


def save_detection_records(records, filename_prefix=''):
    """Write one JSON file per record next to the (renamed) label file.

    Parameters
    ----------
    records : iterable of (trace_path, detections, label_path) tuples
        `detections` is dumped with json.dump; `label_path` determines the output
        location; `trace_path` is not used.
    filename_prefix : str
        Inserted in front of 'st_detections.json' in the output file name
        ('' for all detections, 'tp_' / 'fp_' for the split lists).

    The output path is `label_path` with every 'labels' replaced by 'st_detections'.
    Saving is best effort: a failure for one record prints the traceback and an
    error line, then the remaining records are still processed.
    """
    for _trace_path, detections, label_path in records:
        detection_path = label_path.replace('labels', 'st_detections')
        output_path = detection_path.replace(
            'st_detections.json', f'{filename_prefix}st_detections.json'
        )

        # A bare file name has an empty dirname, which os.makedirs() rejects.
        detection_dir = os.path.dirname(detection_path)
        if detection_dir and not os.path.exists(detection_dir):
            os.makedirs(detection_dir, exist_ok=True)
            print(f'Created Directory: {detection_dir}')

        try:
            with open(output_path, 'w') as f:
                json.dump(detections, f)
            print(f'Saved detections in {output_path}')
        except Exception:
            # print_exc() works on every Python 3 version; the single-argument
            # form of print_exception() requires 3.10+.
            traceback.print_exc()
            print('Error in saving detections')
            continue


# Same order as before: all detections first, then true positives, then false positives.
for records, filename_prefix in (
    (all_detections, ''),
    (all_tp, 'tp_'),
    (all_fp, 'fp_'),
):
    save_detection_records(records, filename_prefix)

## Plot Detections

In [None]:
# ### plot gt and detections
# for test_data, detections, test_label in all_detections:
# # for test_data, detections, test_label in all_fp:
#     # print('test_data:', test_data)
#     # print('detections:', detections)
#     # print(test_label)

#     ### prepare trace to plot
#     col_data = preprocess_traces([test_data])
#     all_df = get_dataframe(col_data) 
#     # print(all_df[0])

#     ### prepare detections to plot
#     timestamps = col_data[0][1]
#     print('timestamps:', timestamps)
#     plot_val = []
#     plot_x_ticks = []
#     plot_class = []
#     for det in detections:
#         # print(det)
#         det_ts1, det_ts2 = det[1]
#         # print(det_ts1, det_ts2)

#         det_ind1_pre = [ abs(t-det_ts1) for t in timestamps]
#         det_ind1 = det_ind1_pre.index(min(det_ind1_pre))

#         det_ind2_pre = [ abs(t-det_ts2) for t in timestamps]
#         det_ind2 = det_ind2_pre.index(min(det_ind2_pre))
#         # print(det_ind1, det_ind2)
#         # print(timestamps[det_ind1], timestamps[det_ind2])

#         plot_val += [(det_ind1, det_ind2)]
#         plot_x_ticks += [(timestamps[det_ind1], timestamps[det_ind2])]
#         plot_class += [0]

#     plot_detections = [plot_val, plot_x_ticks, plot_class]

#     ### get ground truths
#     gt_plot = prepare_gt(test_label)

#     ### plot
#     for df in all_df:
#         # print(df.columns)
#         plot_single_trace(df, 
#                           var_list, 
#                           with_time=False, 
#                           is_xticks=True, 
#                           detections=plot_detections, 
#                           dt_classlist=['detection'],
#                           ground_truths=gt_plot,
#                           gt_classlist=['gt_communication', 'gt_sensor', 'gt_bitflip'],
#                           )

#     # break