# Load Packages

In [None]:
import functions as f
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px

# Apply the ggplot theme to every matplotlib figure drawn in this notebook.
plt.style.use('ggplot')
# Seed NumPy's legacy global random state so repeated runs draw the same numbers.
np.random.seed(0)

# Load Data

In [None]:
# Raw inputs. Exact duplicate rows of the prepared data are dropped on load.
df = pd.read_csv("../data/prepped_data.csv", low_memory=False, index_col=0).drop_duplicates()
segments = pd.read_csv("../customer_segmentation/segments.csv", index_col=0)

# Keep only rows whose first data year is 2021 or later and key the frame by
# the hashed policy number (set_index also removes the column itself).
df = df.loc[df["first_data_year"] >= 2021].set_index("policy_nr_hashed")

# Human-readable names for the one-hot cluster columns produced below.
group_names = {
    'cluster_0': 'Value Seekers',
    'cluster_1': 'High-Income Customers',
    'cluster_2': 'Basic Coverage',
    'cluster_3': 'Rural Customers',
}

# Key the segments by hashed policy number, one-hot encode the cluster label
# and rename the resulting dummy columns to the segment names.
segments = (
    pd.get_dummies(segments.set_index("policy_nr_hashed"), columns=["cluster"])
    .rename(columns=group_names)
)

# Run Double ML

In [None]:
# Run the project's double-ML pipeline on the prepared data.
# `double_mls` and `splits` are read by later cells with the same keys
# (the GATE loop looks up splits[k] for every key k of double_mls).
# NOTE(review): the meaning of `splits=3` and `iters=50` is defined inside
# functions.global_run and is not visible here — confirm against that module.
first_stage_1, first_stage_2, double_mls, splits = f.global_run(
    df,
    splits=3,
    cols_to_drop_manual=['last_type'],
    iters=50,
    log=False,
    intermediary_scores=False,
)

In [None]:
# Show the group average treatment effects (GATE) per customer segment for
# every fitted model in `double_mls`.
#
# The unused counter `i` and the commented-out `if` / `i += 1` scaffolding
# were removed, and the loop body is indented one level as usual.
# `k` and `v` keep their names on purpose: they stay bound to the last
# key/model after the loop, and later cells may rely on `v`.
for k, v in double_mls.items():
    print(k)
    # Row labels of this model's split (presumably hashed policy numbers,
    # matching the index of `segments` — verify against functions.global_run).
    included_policy_nr = splits[k].index.to_list()
    # Align the segment dummies with the split, row for row. `.loc` raises a
    # KeyError if any label of the split is missing from `segments`.
    segments_i = segments.loc[included_policy_nr]
    gate = v.gate(groups=segments_i)
    display(gate.summary)

In [None]:
# Sensitivity analysis and contour plot for the last fitted model.
#
# The model is taken explicitly from `double_mls` instead of relying on a
# loop variable left over from an earlier cell, so this cell no longer depends
# on hidden kernel state. For a dict, the last value is the same object the
# loop over double_mls.items() ends on.
sensitivity_model = list(double_mls.values())[-1]

# Confounding-strength parameters for the sensitivity analysis.
# NOTE(review): presumably the maxima found by the (commented-out) benchmark
# check at the end of this notebook — TODO confirm.
SENSITIVITY_CF_Y = 0.1158
SENSITIVITY_CF_D = 0.1158
SENSITIVITY_RHO = 0.4833

sensitivity_model.sensitivity_analysis(
    cf_y=SENSITIVITY_CF_Y,
    cf_d=SENSITIVITY_CF_D,
    rho=SENSITIVITY_RHO,
)
fig = sensitivity_model.sensitivity_plot()
# Fixed-size figure with tight margins and no colour bar; left as the last
# expression so the notebook renders the figure.
fig.update_layout(
    autosize=False,
    width=600,
    height=500,
    margin=dict(l=20, r=20, t=20, b=20),
    coloraxis_showscale=False
)

In [None]:
# Treatment effect per discount range, one list per customer segment.
# NOTE(review): values are hard-coded, presumably copied from the GATE
# summaries printed above — confirm when the models are re-run.
group_1 = [9.209, 10.312, 19.822]
group_2 = [6.740, 16.398, 19.323]
group_3 = [7.418, 6.760, 15.915]
group_4 = [6.057, 11.811, 21.584]

x = ["0%-16.2%", "16.2%-24.3%", "24.3%-30%"]

# (legend label, values, line colour) for each segment, in drawing order.
segment_series = (
    ("Value Seekers (1)", group_1, "firebrick"),
    ("High-Income Customers (2)", group_2, "goldenrod"),
    ("Basic Coverage (3)", group_3, "darkcyan"),
    ("Rural Customers (4)", group_4, "green"),
)

fig, ax = plt.subplots()
for series_label, series_values, series_color in segment_series:
    ax.plot(
        x,
        series_values,
        label=series_label,
        marker="x",
        linestyle='--',
        linewidth=1,
        color=series_color,
    )

ax.set_xticks(x)
ax.legend()
ax.set_xlabel("Discount Range")
ax.set_ylabel("Average Treatment Effect [%]")
# plt.savefig('../plots/gate.png', dpi=200)
plt.show()

In [None]:
# Per-segment values for each discount range, plotted for the final model.
# NOTE(review): values are hard-coded; their source is not visible in this
# notebook — confirm when the models are re-run.
group_1 = [21.9, 23, 32.5]
group_2 = [16.3, 26, 28.9]
group_3 = [27.7, 27.1, 36.2]
group_4 = [16.7, 22.4, 32.2]

x = ["0%-16.2%", "16.2%-24.3%", "24.3%-30%"]

# (legend label, values, line colour) for each segment, in drawing order.
segment_series = (
    ("Value Seekers (1)", group_1, "firebrick"),
    ("High-Income Customers (2)", group_2, "goldenrod"),
    ("Basic Coverage (3)", group_3, "darkcyan"),
    ("Rural Customers (4)", group_4, "green"),
)

fig, ax = plt.subplots()
for series_label, series_values, series_color in segment_series:
    ax.plot(
        x,
        series_values,
        label=series_label,
        marker="x",
        linestyle='--',
        linewidth=1,
        color=series_color,
    )

ax.set_xticks(x)
ax.legend()
ax.set_xlabel("Discount Range")
ax.set_ylabel("Average Treatment Effect [%]")
# Write the figure to disk before showing it.
plt.savefig('../plots/final_model.png', dpi=200)
plt.show()

# Code to check the maximum cf_y, cf_d and rho benchmark values in our data (currently commented out)

In [None]:
# i = 0

# for k_, v_ in double_mls.items():

#     if i == 0:
#         print(k_)
#         display(v_.summary)
#         features = [col for col in splits[k_].columns if col not in ['welcome_discount', 'churn']]
#         benchmark_sensitivities = {}

#         def process_feature(feature):
#             return feature, v_.sensitivity_benchmark(benchmarking_set=[feature])

#         results = Parallel(n_jobs=-1)(delayed(process_feature)(feature) for feature in features)

#         for feature, result in results:
#             benchmark_sensitivities[feature] = result

#         cf_y_lst = []
#         cf_d_lst = []
#         names = []
#         rhos = []

#         for k, v in benchmark_sensitivities.items():
#             cf_y_lst.append(v.loc["welcome_discount", "cf_y"])
#             cf_d_lst.append(v.loc["welcome_discount", "cf_d"])
#             rhos.append(v.loc["welcome_discount", "rho"])
#             names.append(k)

#         benchmark_dict = {
#             "cf_y" : cf_y_lst,
#             "cf_d" : cf_d_lst,
#             "name" : names
#         }

#         v_.sensitivity_analysis(cf_y=0.04, cf_d=0.03)
#         v_.sensitivity_plot(benchmarks=benchmark_dict)

#     i += 1

# print(f"Max cf_y: {np.max(cf_y_lst)}")
# print(f"Max cf_d: {np.max(cf_d_lst)}")
# print(f"Max rho: {np.max([np.abs(rho) for rho in rhos if np.abs(rho) != 1.0])}")