In [None]:
from typing import Union, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import dalex as dx

from tqdm import tqdm
from xai.data.reader import read_data
from xai.models import RandomForestModel, LogisticRegressionModel
from xai.validation import HoldOutValidation

import warnings
# Notebook-wide settings: silence every warning category and let pandas
# show up to 500 columns before it truncates a wide frame.
warnings.simplefilter(action="ignore")
pd.options.display.max_columns = 500

In [None]:
# Load the hotel-bookings CSV; the two values returned by read_data are
# unpacked into `features` and `target`.
(features, target) = read_data(
    'data/hotel_bookings.csv',
)

In [None]:
# Preview the loaded features as the cell output.
# NOTE(review): presumably a pandas DataFrame (first rows) — confirm against read_data.
features.head()

In [None]:
# Model wrappers to compare and the constructor keyword arguments for each.
# The two lists are matched by position: entry i of `model_params` is passed
# to entry i of `model_classes`.
model_classes = [
    RandomForestModel,
    LogisticRegressionModel,
]
model_params = [
    {'n_jobs': 6},
    {'C': 1, 'max_iter': 2000, 'n_jobs': 6},
]

In [None]:
# Instantiate every configured wrapper and fit it on the full dataset.
# The list stores whatever `fit` returns; later cells use the entries as
# fitted model objects, so `fit` is expected to return the model itself.
models = [
    model_cls(**init_kwargs).fit(features, target)
    for model_cls, init_kwargs in zip(model_classes, model_params)
]

In [None]:
# Run the first model's feature-engineering step on the raw features, with
# train=False.
# NOTE(review): this single transformed frame is later handed to the
# explainers of *all* models — confirm every wrapper shares the same
# feature engineering as models[0].
transformed_features = models[0]._feature_engineering(
    features,
    train=False,
)

In [None]:
# --- Partial-dependence (PDP) and accumulated (ALE) profiles per feature ---
#
# For every feature listed below, draw one figure with two panels:
#   left  - partial dependence profile of each fitted model
#   right - accumulated dependence profile of each fitted model
# Each figure is saved as '<feature index>.png' and then displayed.

# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names. Pick whichever name this installation knows so
# the cell runs on both old and new matplotlib. If neither is available the
# default style is kept (the style is purely cosmetic).
_GRID_STYLE_NAMES = ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid')
_grid_style = next(
    (name for name in _GRID_STYLE_NAMES if name in plt.style.available),
    None,
)
if _grid_style is not None:
    plt.style.use(_grid_style)


IMPORTANT_FEATUES = [
    'lead_time',
    'arrival_date_month',
    'arrival_date_week_number',
    'is_repeated_guest',
    'booking_changes',
    'adr',
    'days_in_waiting_list'
]

# An explainer depends only on the model and the data, not on the feature or
# the profile type, so build it once per model instead of once per
# (feature, model, panel) combination.
explainers = []
for model in models:
    exp = dx.Explainer(model.model, data=transformed_features, y=target, verbose=False)
    explainers.append((model.__class__.__name__, exp))

for feature_id, feature_name in enumerate(IMPORTANT_FEATUES):
    # Explicit figure/axes handles instead of the implicit "current axes".
    fig, (pdp_ax, ale_ax) = plt.subplots(1, 2, figsize=(24, 8), facecolor='w')

    # Left panel: partial dependence, one line per model.
    for model_label, exp in explainers:
        pdp_num = exp.model_profile(type='partial', variables=[feature_name])
        pdp_ax.plot(pdp_num.result._x_, pdp_num.result._yhat_, label=model_label, lw=4)

    pdp_ax.set_title('Partial dependance profile', fontsize=24)
    pdp_ax.set_ylabel('Model predictions', fontsize=32)
    pdp_ax.set_xlabel('Feature values', fontsize=32)
    pdp_ax.tick_params(axis='both', labelsize=24)
    pdp_ax.legend(fontsize=24)

    # Right panel: accumulated profile, one line per model.
    for model_label, exp in explainers:
        ale_num = exp.model_profile(type='accumulated', variables=[feature_name])
        ale_ax.plot(ale_num.result._x_, ale_num.result._yhat_, label=model_label, lw=4)

    ale_ax.set_title('Accumulated dependance profile', fontsize=24)
    ale_ax.set_xlabel('Feature values', fontsize=32)
    ale_ax.tick_params(axis='both', labelsize=24)
    ale_ax.legend(fontsize=24)

    fig.suptitle(feature_name, fontsize=32)
    fig.tight_layout()
    fig.savefig(str(feature_id) + '.png')
    plt.show()
    # Release the figure so figures do not pile up in memory across the loop.
    plt.close(fig)