In [None]:
import pickle
import pandas as pd
import os
from explainers.dce import DistributionalCounterfactualExplainer
import torch
from utils.visualization import *
from utils.data_processing import *
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec


# Show every column when displaying wide DataFrames.
pd.set_option('display.max_columns', None)

# Re-import edited project modules (explainers/, utils/) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2

In [None]:
# Make a local TeX Live install visible to matplotlib's `text.usetex` rendering
# (enabled further down). The path is machine-specific, so only add it when the
# directory exists, and only once, so re-running this cell does not keep growing PATH.
TEXLIVE_BIN = 'C:\\texlive\\2023\\bin\\windows'
if os.path.isdir(TEXLIVE_BIN) and TEXLIVE_BIN not in os.environ["PATH"].split(os.pathsep):
    os.environ["PATH"] += os.pathsep + TEXLIVE_BIN

In [None]:
# Directory holding the German-credit CSVs and the pickled explainer used below.
data_path = 'data/german_credit'

In [None]:
# Raw German-credit data: keep an untouched copy in df_ and work on df.
df_ = pd.read_csv(os.path.join(data_path, 'german_credit_data.csv'))
df = df_.copy()

In [None]:
df.sample(5)

In [None]:
# Project helper (star-imported above); the target 'Risk' is mapped good -> 0, bad -> 1.
# Assuming feature_encoding applies target_encode_dict to 'Risk', Risk means on df are bad-risk rates.
df, label_mappings = feature_encoding(df=df, target_name='Risk', target_encode_dict={"good": 0, "bad": 1})

In [None]:
# Factual and counterfactual samples stored next to the raw data
# (presumably exported by the explainer run — confirm their provenance).
factual = pd.read_csv(os.path.join(data_path,'factual.csv'))
counterfactual = pd.read_csv(os.path.join(data_path,'counterfactual.csv'))

In [None]:
# Quantile curves of 'Credit amount', factual vs counterfactual (project helper plot_quantile).
plot_quantile(factual=factual, counterfactual=counterfactual, column_name='Credit amount');

In [None]:
# 40th percentile of the factual credit amounts.
np.quantile(factual['Credit amount'], 0.4)

In [None]:
factual.head(10)

In [None]:
counterfactual.head(10)

In [None]:
# Bucket applicants into age bands and report the mean Risk of each band.
interval = (0, 25, 35, 65, 120)
cats = ['Student', 'Young', 'Adult', 'Senior']
df["Age_cat"] = pd.cut(df.Age, interval, labels=cats)

risk_by_age = {cat: df.loc[df["Age_cat"] == cat, "Risk"].mean() for cat in cats}
for cat, risk_prob in risk_by_age.items():
    print(f'Risk {cat}: {risk_prob}')

In [None]:
# Quantile curves of 'Age', factual vs counterfactual.
plot_quantile(factual=factual, counterfactual=counterfactual, column_name='Age');

In [None]:
# Mean Risk of each sample; if Risk is 0/1-encoded like df (good=0, bad=1) this is
# the bad-risk share — confirm the encoding of the CSVs.
factual['Risk'].mean(), counterfactual['Risk'].mean()

In [None]:
# Quantile curves of 'Risk', factual vs counterfactual.
plot_quantile(factual=factual, counterfactual=counterfactual, column_name='Risk');

In [None]:
# Mean Risk of df in three credit-amount bands: < 1000, 1000..6000 (inclusive), > 6000.
credit = df['Credit amount']
credit_bands = {
    'Risk Low credit amount:': credit < 1000,
    'Risk Middle credit amount:': (credit >= 1000) & (credit <= 6000),
    'Risk High credit amount:': credit > 6000,
}
for band_label, band_mask in credit_bands.items():
    print(band_label, df.loc[band_mask, 'Risk'].mean())

In [None]:
# Mean Risk of the factual and counterfactual samples (same summary as the earlier cell).
factual['Risk'].mean(), counterfactual['Risk'].mean()

In [None]:
# Tag each row with its source; the count plots below use the 'data' column as hue.
# (The stacked frame `fcf` is built later in this cell, once all tag columns exist.)
factual['data'] = 'factual'
counterfactual['data'] = 'counterfactual'

def plot_quantile_ax(factual, counterfactual, column_name, ax=None, n_quantiles=100):
    """Plot the quantile curves of one column for the factual and counterfactual data.

    Each curve puts the column's quantile values on the x axis and the quantile
    levels ``np.linspace(0, 1, n_quantiles)`` on the y axis.

    Parameters
    ----------
    factual, counterfactual : pandas.DataFrame
        Frames that both contain ``column_name``.
    column_name : str
        Column to plot; also used as the axes title.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``), i.e. the
        axes pyplot-style calls would draw on.
    n_quantiles : int, default 100
        Number of evenly spaced quantile levels between 0 and 1.
    """
    if ax is None:
        ax = plt.gca()

    # Same level grid for both curves, computed once.
    levels = np.linspace(0, 1, n_quantiles)
    quantiles_factual = factual[column_name].quantile(levels)
    quantiles_counterfactual = counterfactual[column_name].quantile(levels)

    ax.plot(quantiles_factual.values, levels, label="Factual")
    ax.plot(quantiles_counterfactual.values, levels, label="Counterfactual")
    ax.set_xlabel("Quantile Values")
    ax.set_ylabel("Quantiles")
    ax.set_title(f"{column_name}")
    ax.legend()
    ax.grid(True)

# Assuming factual and counterfactual are pandas DataFrames with the same columns
columns = ['Age', 'Credit amount', 'Duration', 'Risk']

# One row of subplots, one quantile panel per numeric column
plt.figure(figsize=(20, 5))
for i, column in enumerate(columns):
    plt.subplot(1, len(columns), i + 1)
    plot_quantile_ax(factual, counterfactual, column)

plt.tight_layout()
plt.show()

def hist_plot_ax(df, x, hue, title, ax):
    """Draw a seaborn count plot of column ``x`` on ``ax``, with bars split by ``hue``.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing columns ``x`` and ``hue``.
    x : str
        Column to count; also used as the x-axis label.
    hue : str
        Column whose values split each bar (here the 'data' source tag).
    title : str
        Axes title.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    """
    # Rows are sorted by `x` before plotting (presumably so categories appear in sorted order).
    g = sns.countplot(x=x, hue=hue, data=df.sort_values(by=x), palette="hls", ax=ax)
    # Rotate the existing tick labels in place rather than reading them back and
    # re-setting them, which replaces the axis formatter just to change rotation.
    g.tick_params(axis='x', labelrotation=45)
    g.set_xlabel(x, fontsize=12)
    g.set_ylabel("Count", fontsize=12)
    g.set_title(title, fontsize=20)


columns = ['Sex', 'Job', 'Housing', 'Saving accounts', 'Checking account']

factual['is_cf'] = False
counterfactual['is_cf'] = True

# Stack both samples; the 'data' column tells them apart in the plots.
fcf = pd.concat([factual, counterfactual])

# One row of count plots, one panel per categorical column
plt.figure(figsize=(25, 5))
for i, column in enumerate(columns):
    ax = plt.subplot(1, len(columns), i + 1)
    hist_plot_ax(fcf, column, 'data', column, ax)

plt.tight_layout()
plt.show()

In [None]:
# Count plot of 'Purpose' on its own wide figure, bars split by the 'data' source tag.
fig, ax = plt.subplots(figsize=(25,5))

hist_plot_ax(fcf, 'Purpose', 'data', 'Purpose', ax)

plt.tight_layout()
plt.show()

In [None]:
# Quantiles of Risk at 100 evenly spaced levels in [0, 1] (Series indexed by level),
# formatted as coordinate strings in the next cells.
column_name = 'Risk'
quantiles_factual = factual[column_name].quantile(np.linspace(0, 1, 100))
quantiles_counterfactual = counterfactual[column_name].quantile(np.linspace(0, 1, 100))

In [None]:
# Factual quantile curve as " (value,level)" pairs, presumably for pasting into a
# pgfplots `coordinates {...}` block.
s_factual = "".join(
    f" ({np.round(v,6)},{np.round(k,6)})"
    for k, v in quantiles_factual.to_dict().items()
)
print(s_factual)

In [None]:
# Counterfactual quantile curve as " (value,level)" pairs, same format as the factual one.
s_counterfactual = "".join(
    f" ({np.round(v,6)},{np.round(k,6)})"
    for k, v in quantiles_counterfactual.to_dict().items()
)
print(s_counterfactual)

In [None]:
# Category counts of 'Purpose' in the factual sample.
column_name = 'Purpose'
factual[column_name].value_counts()

In [None]:
# Same counts for the counterfactual sample.
counterfactual[column_name].value_counts()

In [None]:
# Mean Risk per 'Purpose' category, factual and counterfactual side by side.
column_name = 'Purpose'
risk_factual = factual.groupby(column_name)['Risk'].mean().to_frame('Risk_factual')
risk_counterfactual = counterfactual.groupby(column_name)['Risk'].mean().to_frame('Risk_counterfactual')
pd.concat([risk_factual, risk_counterfactual], axis=1)

In [None]:
column_name = 'Saving accounts'
indice = factual[(factual[column_name] != counterfactual[column_name])].index

factual.loc[indice]


In [None]:
counterfactual.loc[indice]

In [None]:
def colorful_scatter(df, x, y, color_col, title, ax):
    """Scatter ``df[x]`` against ``df[y]`` on ``ax``, coloured by ``df[color_col]``.

    A colorbar labelled ``color_col`` is attached next to ``ax``; the axis labels
    are the column names and the axes title is ``title``.
    """
    points = ax.scatter(df[x], df[y], alpha=0.7, c=df[color_col], cmap='rocket_r')
    plt.colorbar(points, ax=ax, label=color_col)
    ax.set(xlabel=x, ylabel=y)
    ax.set_title(title)

# Credit amount vs duration, coloured by Risk, restricted to 'little' savings accounts:
# factual on the left, counterfactual on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

panels = [(factual, 'Factual', ax1), (counterfactual, 'Counterfactual', ax2)]
for frame, panel_title, panel_ax in panels:
    little_savings = frame[frame['Saving accounts'] == 'little']
    colorful_scatter(little_savings, 'Credit amount', 'Duration', 'Risk', panel_title, panel_ax)

plt.tight_layout()
plt.show()

In [None]:
# Base font size: used for rcParams['font.size'] and for tick labels in the scatter helper below.
fontsize = 25

# Enable LaTeX text rendering in Matplotlib
# NOTE: this updates the global rcParams, so every later figure in this notebook is
# also rendered through LaTeX (a TeX installation must be reachable on PATH).
plt.rcParams.update({
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{times}",  # Ensure you use the times package
    "font.family": "serif",
    "font.serif": ["Times", "Times New Roman"],  # This should use Times font
    "font.size": fontsize
})

def colorful_scatter_with_sizing(df, x, y, size_col, color_col, title, ax, show_colorbar=True, cbar_ax=None,
                                 xlabel='Credit Amount', xticks=None, yticks=None, ylim=(20, 80),
                                 size_scale=10, labelsize=None):
    """Publication-style scatter of ``df[x]`` vs ``df[y]`` with marker size and colour encodings.

    Marker area is ``df[size_col] * size_scale`` and colour is ``df[color_col]``.
    The axes get a dashed grid, white background and thin black frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing ``x``, ``y``, ``size_col`` and ``color_col``.
    x, y : str
        Columns for the horizontal and vertical positions; ``y`` is also the y label.
    size_col, color_col : str
        Columns mapped to marker size and colour.
    title : str
        Axes title.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    show_colorbar : bool, default True
        Whether to draw a colorbar labelled ``color_col``.
    cbar_ax : matplotlib.axes.Axes, optional
        Dedicated axes for the colorbar. If omitted while ``show_colorbar`` is
        True, the colorbar is attached next to ``ax`` instead of being skipped.
    xlabel : str, default 'Credit Amount'
        Label for the x axis.
    xticks, yticks : array-like, optional
        Tick positions; default to ``np.arange(0, 15000, 4000)`` and
        ``np.arange(20, 80, 10)`` respectively.
    ylim : tuple or None, default (20, 80)
        y-axis limits; ``None`` leaves matplotlib's autoscaling in place.
    size_scale : float, default 10
        Multiplier applied to ``df[size_col]`` to get marker areas.
    labelsize : float, optional
        Tick-label font size; defaults to the module-level ``fontsize``.
    """
    if labelsize is None:
        labelsize = fontsize

    scatter = ax.scatter(df[x], df[y], s=df[size_col] * size_scale, alpha=0.7, c=df[color_col],
                         cmap='rocket_r', edgecolor='black')

    if show_colorbar:
        if cbar_ax is not None:
            cbar = plt.colorbar(scatter, cax=cbar_ax, label=color_col)
        else:
            # No dedicated colorbar axes: steal space from `ax` rather than silently drawing nothing.
            cbar = plt.colorbar(scatter, ax=ax, label=color_col)
        cbar.ax.tick_params(labelsize=labelsize)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(y)
    ax.set_title(title)
    ax.tick_params(axis='both', which='major', labelsize=labelsize)

    # Grid intervals
    ax.set_xticks(np.arange(0, 15000, 4000) if xticks is None else xticks)
    ax.set_yticks(np.arange(20, 80, 10) if yticks is None else yticks)

    ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray')  # Grid lines like TikZ
    ax.set_facecolor('white')  # White background like TikZ
    if ylim is not None:
        ax.set_ylim(*ylim)

    # Make sure the spines (frame) are visible
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(0.5)

# GridSpec with 3 columns: two equal-width scatter panels and a narrow third column for the shared colorbar
fig = plt.figure(figsize=(12, 5))
gs = GridSpec(1, 3, width_ratios=[1, 1, 0.05])

# Create the two subplots and the colorbar axis
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])
cbar_ax = fig.add_subplot(gs[2])

# Only the counterfactual panel draws the (shared) colorbar
colorful_scatter_with_sizing(factual, 'Credit amount', 'Age', 'Duration', 'Risk', 'Factual', ax1, show_colorbar=False)
colorful_scatter_with_sizing(counterfactual, 'Credit amount', 'Age', 'Duration', 'Risk', 'Counterfactual', ax2, cbar_ax=cbar_ax)

# Adjust the layout so there's no extra space
plt.tight_layout()

# savefig raises if the output directory is missing, so create it first.
# Save before showing: some backends release the figure once it is shown.
os.makedirs('pictures', exist_ok=True)
plt.savefig('pictures/german_credit_scatter.pdf', format='pdf', bbox_inches='tight')

# Display the plot
plt.show()

# Close the figure after saving to free its memory
plt.close(fig)



In [None]:
# Mean Risk per 'Saving accounts' category in the factual sample.
factual.groupby('Saving accounts')['Risk'].mean()

In [None]:
# Same breakdown for the counterfactual sample.
counterfactual.groupby('Saving accounts')['Risk'].mean()

In [None]:
# Load the explainer object pickled in data_path (presumably a
# DistributionalCounterfactualExplainer — confirm).
# NOTE(security): pickle.load can execute arbitrary code; only load files you produced yourself.
with open(os.path.join(data_path, 'explainer.pkl'), 'rb') as file:
    explainer = pickle.load(file)

In [None]:
# Per-feature score: explainer.wd.distance (presumably a Wasserstein distance — confirm)
# between column `col_index` of explainer.X and of explainer.X_prime, largest first.
scored = []
for column, col_index in zip(explainer.explain_columns, explainer.explain_indices):
    source = torch.FloatTensor(explainer.X[:, col_index])
    target = torch.FloatTensor(explainer.X_prime[:, col_index])
    wd_dist, _ = explainer.wd.distance(source, target, delta=0)
    scored.append((column, wd_dist.item()))

col_names = [name for name, _ in scored]
col_scores = [score for _, score in scored]

pd.DataFrame({'Feature': col_names, 'Score': col_scores}).sort_values(by='Score', ascending=False)

In [None]:
# Correlation between the columns of explainer.X_prime.
# X_prime is indexed like an array/tensor elsewhere in this notebook
# (`explainer.X_prime[:, col_index]`, `@ theta`), so it has no pandas `.corr()`;
# wrap it in a DataFrame first, detaching it if it is a torch tensor.
x_prime = explainer.X_prime
if isinstance(x_prime, torch.Tensor):
    x_prime = x_prime.detach().cpu().numpy()
corr = pd.DataFrame(x_prime).corr()

plt.figure(figsize=(16, 6))
heatmap = sns.heatmap(corr, vmin=-1, vmax=1, annot=True, cmap='BrBG')
heatmap.set_title('Correlation Heatmap', fontdict={'fontsize':18}, pad=12);
# To export: plt.savefig('heatmap.png', dpi=300, bbox_inches='tight')
# (dpi sets the resolution; bbox_inches='tight' keeps labels from being cropped)

In [None]:

# Average, over all directions in explainer.swd.thetas, of the 1-D projection of the
# explained columns: `fa` for explainer.X_prime, `dfa` for explainer.best_X.
thetas = explainer.swd.thetas
x_prime_sub = explainer.X_prime[:, explainer.explain_indices]  # loop-invariant, slice once
best_x_sub = explainer.best_X[:, explainer.explain_indices]

fa = torch.zeros_like(x_prime_sub @ thetas[0])
dfa = torch.zeros_like(x_prime_sub @ thetas[0])
for theta in thetas:
    fa += x_prime_sub @ theta
    dfa += best_x_sub @ theta

fa /= len(thetas)
dfa /= len(thetas)

# Hand pandas plain NumPy arrays rather than tensors.
plot_quantile(
    factual=pd.DataFrame({'X': fa.detach().cpu().numpy()}),
    counterfactual=pd.DataFrame({'X': dfa.detach().cpu().numpy()}),
    column_name='X',
);

In [None]:
s_factual = ""
s_counterfactual = ""

for k, v in enumerate(np.sort(fa.numpy())):
    s_factual += f" ({np.round(v,6)},{np.round(k/100,6)})"

for k, v in enumerate(np.sort(dfa.numpy())):
    s_counterfactual += f" ({np.round(v,6)},{np.round(k/100,6)})"

print(s_factual)
print(s_counterfactual)

In [None]:
fa

In [None]:
# Heatmaps of explainer.wd.nu and of the normalised average of explainer.swd.mu_list,
# drawn on a shared colour scale.
matrix_nu = explainer.wd.nu.detach().numpy()

mu_avg = torch.zeros_like(explainer.swd.mu_list[0])
for mu in explainer.swd.mu_list:
    mu_avg += mu

total_sum = mu_avg.sum()

# Normalise so the entries sum to 1, then detach to NumPy like matrix_nu so that
# imshow and min/max below operate on plain arrays (a grad-tracking tensor cannot
# be converted implicitly).
matrix_mu = (mu_avg / total_sum).detach().cpu().numpy()

# Determine the global minimum and maximum values across both matrices
vmin = min(matrix_mu.min(), matrix_nu.min())
vmax = max(matrix_mu.max(), matrix_nu.max())

# Create a figure and a set of subplots
fig, axs = plt.subplots(1, 2, figsize=(20, 8))  # 1 row, 2 columns

# First subplot for matrix_mu (raw string: "\m" is an invalid escape sequence)
im_mu = axs[0].imshow(matrix_mu, cmap='viridis', vmin=vmin, vmax=vmax)
axs[0].set_title(r"Heatmap of $\mu$")

# Second subplot for matrix_nu
im_nu = axs[1].imshow(matrix_nu, cmap='viridis', vmin=vmin, vmax=vmax)
axs[1].set_title(r"Heatmap of $\nu$")

# Create a colorbar for the whole figure
fig.colorbar(im_mu, ax=axs, orientation='vertical')

# Display the plots
plt.show()

In [None]:
row_num = 16

# Interleave rows: both heads carry index labels 0..row_num-1, so a *stable* sort on
# the concatenated index puts factual row i directly before counterfactual row i.
# 'mergesort' is the documented stable option.
combined = pd.concat([factual.head(row_num), counterfactual.head(row_num)]).sort_index(kind='mergesort')

# Define formatters for specific columns
formatters = {
    "Risk": lambda x: f"{x:.4f}"
}

# Convert to LaTeX
latex_code = combined.to_latex(index=False, formatters=formatters,
                               caption="[\\textit{{German-Credit}}] Data points of factual and counterfactual distributions.", label="tab:german-credit")

print(latex_code)