In [None]:
############################ 0. PREPARATION ############################

#-------------------------- import packages --------------------------
import random
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
from itertools import combinations
import seaborn as sns
import copy
import itertools

import tensorflow as tf
import tensorflow_lattice as tfl
import tf_keras as keras # need keras 2 to fit Lattice model
from tf_keras.models import Sequential, Model
from tf_keras.layers import Dense, Input, Multiply, Add, Embedding, Reshape, Concatenate, Dropout, BatchNormalization, Lambda, Layer, CategoryEncoding, Activation
from tf_keras.constraints import Constraint
from tf_keras.callbacks import EarlyStopping
from tf_keras.initializers import Zeros, Constant
from tf_keras.optimizers import Adam, Nadam, RMSprop
from tf_keras.models import clone_model
import keras_tuner as kt
from tf_keras import backend as K
from tf_keras import regularizers
from tf_keras.utils import plot_model
from tf_keras.losses import Poisson, Loss
from tf_keras.metrics import MeanAbsoluteError, RootMeanSquaredError
from scipy.stats import gamma


from pygam import LinearGAM, GAM, s, f, l
from sklearn.linear_model import ElasticNet
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import OneHotEncoder, StandardScaler, OrdinalEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import make_column_transformer
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error as mae
from sklearn.tree import DecisionTreeRegressor
from interpret.glassbox import ExplainableBoostingClassifier

In [None]:
# -------------------------- hyperparameters of final model --------------------------
# Pairwise interactions kept in the final model (each gets its own subnetwork or lattice).
imp_interactions = [("X3", "X4"), ("X5", "X6")]

# Containers filled by the model-architecture cell: Input layers and subnetwork outputs.
inputs = []
sub_outputs = []

# Width / depth of the MLP subnetworks for main effects and for pairwise interactions.
num_neurons_main = 40
num_layers_main = 5
num_neurons_interaction = 100
num_layers_interaction = 5
activation = 'leaky_relu'

# (l1, l2) weights of the 'hessian' smoothness regularizer on the PWL calibrators.
lattice_smooth_reg = (0.1, 0)
# Number of keypoints of each piecewise-linear calibrator.
num_keypoints = 30
# Variable -> monotonicity direction (a dict, despite the name).
monotonicity_list = {"X3" : "decreasing"}

# Lattice vertices per important variable: one per observed category for
# categorical variables, 5 for every other variable.
lattice_sizes_full = {
    var_name: (X_train[var_name].nunique() if var_name in cat_vars else 5)
    for var_name in imp_vars
}

In [None]:
# -------------------------- model architecture --------------------------
# Start from empty containers so that re-running this cell does not append a second
# set of Input layers / subnetwork outputs (which would produce duplicate layer names
# or a model with the wrong number of inputs).
inputs = []  # Input layers, one per variable, in imp_vars order
sub_outputs = []  # subnetwork outputs: main effects (imp_vars order), then interactions

# main effect
for name in imp_vars:

    # Input layers
    input_layer = Input(shape = (1,), name = name)
    inputs.append(input_layer)

    # subnetworks for main effects
    if name in cat_vars: # categorical variables
        # categorical variables go through a 1-dimensional Embedding (one learned value
        # per category code), then BatchNormalization without a learned scale.
        # NOTE(review): Embedding assumes integer codes in [0, nunique) — TODO confirm encoding
        embed_layer = Embedding(input_dim = X_train[name].nunique(),
                                output_dim = 1,
                                name = f"{name}_embed")(input_layer)
        embed_layer_reshape = Reshape(target_shape = (1,), name = f"{name}_reshape")(embed_layer)
        cat_output = BatchNormalization(scale = False, name = f"{name}_dense")(embed_layer_reshape)
        sub_outputs.append(cat_output)
    elif name in monotonicity_list: # variables with monotonicity constraint
        # Monotone PWL calibrator feeding a 1-D lattice that is increasing in its input,
        # so the composite follows the calibrator's monotonicity direction.
        calibrator_layer = tfl.layers.PWLCalibration(
            input_keypoints = np.linspace(X_train[name].min(), X_train[name].max(), num = num_keypoints), # keypoints
            # TFL lattices expect inputs in [0, lattice_size - 1]; bound the calibrator
            # output to that domain instead of leaving it unconstrained.
            output_min = 0.0,
            output_max = lattice_sizes_full[name] - 1.0,
            kernel_regularizer = ('hessian', lattice_smooth_reg[0], lattice_smooth_reg[1]), # for smoothness
            monotonicity = monotonicity_list[name], # monotonicity constraint
            name = f"{name}_calibrator"
        )(input_layer)
        lattice_layer = tfl.layers.Lattice(lattice_sizes = [lattice_sizes_full[name]],
                                           monotonicities = ["increasing"],
                                           name = f"{name}_lattice")(calibrator_layer)
        mean_layer = BatchNormalization(scale = False, name = f"{name}_mean")(lattice_layer)
        sub_outputs.append(mean_layer)
    else: # numeric variables
        # create_subnet is defined elsewhere in the notebook
        subnet = create_subnet(num_layers_main, num_neurons_main, activation, f"{name}_subnetwork")
        sub_output = subnet(input_layer)
        sub_outputs.append(sub_output)

# pairwise interaction effect
for (var1, var2) in imp_interactions:
    var1_input = inputs[imp_vars.index(var1)]
    var2_input = inputs[imp_vars.index(var2)]

    if any(var in monotonicity_list for var in [var1, var2]):
        # At least one variable is monotone: calibrate both inputs and combine them
        # in a 2-D lattice (add_calibrate_layer is defined elsewhere in the notebook).
        lattice_inputs = []

        # create calibrator_layer
        calibrator_var1 = add_calibrate_layer(var1, X_train, monotonicity_list, lattice_sizes_full, num_keypoints, cat_vars)
        calibrator_var2 = add_calibrate_layer(var2, X_train, monotonicity_list, lattice_sizes_full, num_keypoints, cat_vars)
        calibrator_layer_var1 = calibrator_var1(var1_input)
        calibrator_layer_var2 = calibrator_var2(var2_input)

        # lattice
        lattice_inputs.append(calibrator_layer_var1)
        lattice_inputs.append(calibrator_layer_var2)
        lattice_layer = tfl.layers.Lattice(lattice_sizes = [lattice_sizes_full[var1], lattice_sizes_full[var2]],
                                           monotonicities = ["increasing" if var1 in monotonicity_list else 'none',
                                                             "increasing" if var2 in monotonicity_list else 'none'],
                                           #  kernel_regularizer = tfl.pwl_calibration_layer.HessianRegularizer(
                                           #     l1 = 0.01, l2 = 0),
                                           name = f"{var1}_{var2}_lattice")(lattice_inputs)
        pairwise_output = BatchNormalization(scale = False, name = f"{var1}_{var2}_mean")(lattice_layer)
        sub_outputs.append(pairwise_output)
    else:
        # Unconstrained pair: MLP on the concatenated pair of inputs.
        pairwise_input_layer = Concatenate(name = f"{var1}_{var2}_concat")([var1_input, var2_input])
        pairwise_subnet = create_subnet(num_layers_interaction,
                                        num_neurons_interaction,
                                        activation,
                                        f"{var1}_{var2}_subnetwork")
        pairwise_dense = pairwise_subnet(pairwise_input_layer)
        sub_outputs.append(pairwise_dense)


# combine subnets' outputs
subnets = Concatenate(name = "subnet_output")(sub_outputs)
# AddSubnetOutput is a custom layer defined elsewhere in the notebook
# (presumably sums the subnet columns before the exponential activation — TODO confirm).
output_layer = AddSubnetOutput(activation = 'exponential',
                               name = "final_output")(subnets)

# final model
model_main_pairwise = Model(inputs = inputs, outputs = output_layer)

In [None]:
# Render the architecture graph (requires pydot + graphviz). The keyword values
# written out below are tf_keras' defaults.
plot_model(model_main_pairwise, to_file = "model.png", show_shapes = False, show_layer_names = True)

In [None]:
# -------------------------- add marginal clarity constraint --------------------------
# Marginal clarity: penalize the batch-averaged product between each interaction
# subnetwork output and the main-effect outputs of its two variables.
# NOTE: add_loss accumulates on the model — re-running this cell without rebuilding
# the model (architecture cell) adds every penalty a second time.
lambd = 5  # weight of the marginal clarity penalty

# Symbolic output of the "subnet_output" Concatenate layer. Columns are the subnetworks:
# main effects first (imp_vars order), then interactions (imp_interactions order).
# Assumes each subnetwork emits a single unit — TODO confirm for create_subnet.
subnet_output = model_main_pairwise.get_layer("subnet_output").output

for i, (var1, var2) in enumerate(imp_interactions):
    # Select COLUMNS (axis 1). A bare integer index (subnet_output[k]) would pick the
    # k-th sample of the batch instead of the k-th subnetwork.
    main_output_var1 = subnet_output[:, imp_vars.index(var1)]
    main_output_var2 = subnet_output[:, imp_vars.index(var2)]
    interaction_output = subnet_output[:, len(imp_vars) + i]
    penalty = lambd * (
        K.abs(K.mean(main_output_var1 * interaction_output)) +
        K.abs(K.mean(main_output_var2 * interaction_output))
    )

    # Add the penalty to the model's total loss
    model_main_pairwise.add_loss(penalty)

In [None]:
# -------------------------- compile and fit --------------------------
# remove irrelevant main effects
X_train_important = []
X_val_important = []
X_test_important = []
for i in range(len(X_train_split)):
    if all_vars[i] in imp_vars:
        X_train_important.append(X_train_split[i])
        X_val_important.append(X_val_split[i])
        X_test_important.append(X_test_split[i])

model_main_pairwise.compile(optimizer = "rmsprop", loss = gamma_log_likelihood, metrics = [RootMeanSquaredError()])
es = EarlyStopping(restore_best_weights = True, patience = 10)
%time hist_main_pairwise = model_main_pairwise.fit(X_train_important, y_train, epochs = 5_000, \
    callbacks = [es], batch_size = 1_000, validation_data = (X_val_important, y_val))

In [None]:
# -------------------------- quantify variable importance --------------------------
# Sub-model that exposes the concatenated per-subnetwork outputs of the trained model.
subnet_layer = model_main_pairwise.get_layer("subnet_output")
subnet_output_model = Model(inputs = model_main_pairwise.inputs,
                            outputs = subnet_layer.output)
# Score the whole training set as a single batch.
n_train_rows = X_train_important[0].shape[0]
subnet_output_values = subnet_output_model.predict(X_train_important, batch_size = n_train_rows)

# Importance of a subnetwork = variance of its output column over the training data.
subnet_variance = [np.var(column) for column in subnet_output_values.T]

# Labels in column order: main effects, then "<var1>_<var2>" for each interaction.
covariates = list(imp_vars) + [f"{first}_{second}" for first, second in imp_interactions]


# -------------------------- plot the variable importance --------------------------
var_importance = pd.DataFrame({"Covariates": covariates, "Subnet Variance": subnet_variance})

# Largest variance first, so the bars appear in decreasing order of importance.
varimp_sorted = var_importance.sort_values(by = "Subnet Variance", ascending = False)

fig, ax = plt.subplots(figsize = (10, 6))
sns.barplot(data = varimp_sorted, x = "Covariates", y = "Subnet Variance", ax = ax)
ax.set_title("Subnet Variance by Covariates")
ax.set_xlabel("Covariates")
ax.set_ylabel("Subnet Variance")
plt.show()

In [None]:
# -------------------------- Create the Pairwise Only Model --------------------------
pairwise_effect = ("X3", "X4")  # the pairwise effect we want to view
pairwise_index = imp_interactions.index(pairwise_effect)
var1_name, var2_name = imp_interactions[pairwise_index]
# Same trained weights, but outputs every subnetwork column instead of the final prediction.
subnet_output_layer = model_main_pairwise.get_layer("subnet_output")
subnet_model = Model(inputs = model_main_pairwise.inputs, outputs = subnet_output_layer.output)


# -------------------------- Generate the Input Grid --------------------------
grid_length = 100  # points per axis, i.e. grid_length**2 grid points in total
var1_values = np.linspace(X_train[var1_name].min(), X_train[var1_name].max(), grid_length)
var2_values = np.linspace(X_train[var2_name].min(), X_train[var2_name].max(), grid_length)

grid_var1, grid_var2 = np.meshgrid(var1_values, var2_values)
grid_flat_var1, grid_flat_var2 = grid_var1.ravel(), grid_var2.ravel()


# -------------------------- Prepare Inputs for the Model --------------------------
# Only the two interacting variables vary; every other input is fed zeros. The
# interaction subnetwork only receives these two inputs, so the zeros do not affect
# the column read below.
grid_by_var = {var1_name: grid_flat_var1, var2_name: grid_flat_var2}
grid_inputs = [
    grid_by_var[v] if v in grid_by_var else np.zeros_like(grid_flat_var1)
    for v in imp_vars
]


# -------------------------- Predict Using the Pairwise-Only Model --------------------------
# Interaction columns come after the len(imp_vars) main-effect columns.
interaction_column = len(imp_vars) + pairwise_index
all_subnet_predictions = subnet_model.predict(grid_inputs, batch_size = grid_length**2)
pairwise_predictions = all_subnet_predictions[:, interaction_column]
heatmap_values = pairwise_predictions.reshape(grid_var1.shape)


# -------------------------- Plot the Heatmap --------------------------
plt.figure(figsize = (10, 8))
contour = plt.contourf(grid_var1, grid_var2, heatmap_values, levels = 10)
plt.colorbar(contour)
plt.title('Predicted Function by the Model')
plt.xlabel(f"{var1_name}")
plt.ylabel(f"{var2_name}")
plt.show()

In [None]:
#-------------------------- view the shape function --------------------------
# Per-subnetwork outputs on the training data (one column per subnetwork).
subnet_output = subnet_output_model.predict(X_train_important)

var_name = "X4"
var_index = imp_vars.index(var_name)

# A main-effect subnetwork only receives its own input, so its output is a
# deterministic function of x: draw the points sorted by x as a line. seaborn's
# lineplot would instead group repeated x values and bootstrap a confidence band,
# which is slow on the full training set and adds nothing for a deterministic curve.
x_values = np.ravel(X_train_important[var_index])  # accept (n,) or (n, 1) inputs
y_values = subnet_output[:, var_index]
order = np.argsort(x_values, kind = "stable")

fig, ax = plt.subplots(figsize = (12, 6))
ax.plot(x_values[order], y_values[order])
ax.set_xlabel(f"{var_name}")
ax.set_ylabel('Subnetwork Output')
ax.set_title(f"Shape function for {var_name}")
ax.grid(True)
plt.show()