In [4]:
from DATALOADER import *
from DRIFT_MODEL import *
from STATS_TEST import *
from LOSS_ANALYSIS import *
import matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [5]:
import os  # used below to build the output path and make sure the figure directory exists

matplotlib.use('agg')  # non-interactive backend: figures are only written to disk, never shown
# Drift-detection pipeline: compares reference (train) data with current
# (prediction) data by (a) reconstruction-loss analysis with an autoencoder-style
# detector and (b) per-feature statistical tests, then plots both scores per feature.

# 0. Environment Setting
train_folder_path = 'Drift Detection/Input/train'
pred_folder_path = 'Drift Detection/Input/prediction'
output_path = 'Drift Detection/Output'
model_path = 'Drift Detection/Model/model_1_0720.hdf5'

# 1. Data Preprocessing
reference, current = read_files(train_folder_path, pred_folder_path, output_path=output_path)
print('Files Read.')
data_dict = clean_data(reference, current, output_path, val_size=0.2)
feature_names = data_dict['val'].columns
print('Data Cleaned.')

# 2. Model Initialization
# NOTE(review): default_weight=True presumably loads the weights at model_path,
# in which case hyper_params may only matter when training -- confirm in init_detector.
detector, scaled_data = init_detector(data_dict, default_weight=True, model_path=model_path,
                                      hyper_params={'n_epochs': 20, 'batch_size': 128})
print('Model Ready.')

# 3. Model Inference: reconstruct both the training (baseline) and test data.
pred_base = detector.predict(scaled_data['train'])
pred = detector.predict(scaled_data['test'])
print('Data Reconstructed.')

# 4. Loss Analysis
# The threshold computed on the training (baseline) reconstruction loss is
# reused for the test data so both are judged against the same cut-off.
loss_result_base = cal_loss(scaled_data['train'], pred_base, feature_names=feature_names,
                            cal_type='Base', output_path=output_path)
loss_result = cal_loss(scaled_data['test'], pred, feature_names=feature_names,
                       threshold=loss_result_base['threshold'],
                       cal_type='Predict', output_path=output_path)
ae_df = pp_change(loss_result_base['loss_percentage'], loss_result['loss_percentage'], feature_names)
print('Outlier Loss Decomposed.')

# 5. Statistical drift tests on reference vs. current data
# (presumably backed by Evidently AI -- see STATS_TEST).
drift_score = stats_test(data_dict['ref'], data_dict['cur'], output_path)
print('Stats Tests Completed.')

# 6. Combine Result from AE model and Stats tests.
# Left join keeps every feature from the AE result; features without a
# statistical score get NaN there.
result_df = pd.merge(ae_df, drift_score, on='feature', how='left')

# Plot: juxtaposed barplot -- one bar per (feature, score type) pair.
plot_df = result_df.melt(id_vars='feature', var_name='Score Type', value_name='Score')
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x='feature', y='Score', hue='Score Type', data=plot_df, ax=ax)
ax.tick_params(axis='x', labelrotation=90)
fig.tight_layout()  # keep the rotated feature labels inside the saved image
figures_dir = os.path.join(output_path, 'figures')
os.makedirs(figures_dir, exist_ok=True)  # savefig does not create missing directories
fig.savefig(os.path.join(figures_dir, 'Drift_Score.png'))
plt.close(fig)  # free the figure; re-running the cell would otherwise accumulate open figures