In [1]:
import os
import copy
import time
import pickle
import sys

import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from sian.utils import gettimestamp
from sian.data import Final_TabularDataset
from sian.models import TrainingArgs

from sian.fis import layerwise_FIS_Hyperparameters, batchwise_FIS_Hyperparameters
from sian.interpret import unmasked_FID_Hyperparameters, masked_FID_Hyperparameters
from sian import initalize_the_explainer

from sian import train_mlp_final, do_the_fis_final, train_sian_final #steps 1, 2, and 3
from sian.interpret import plot_all_GAM_functions #step 4


%load_ext autoreload
%autoreload 2

In [None]:
# Seed the global torch and numpy RNGs so stochastic steps (initialisation,
# shuffling, sampling) repeat across Restart & Run All. CUDA kernels may still
# be nondeterministic.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Training hyperparameters, passed positionally to TrainingArgs below
# (by name: batch size, epochs, learning rate). Shared by Step 1 and Step 3.
BS = 32
EP = 10  # 100 was used for full-length runs
LR = 5e-3

# Supported datasets -> owner of the preprocessing recipe used to load them.
DATASET_PREPROC_OWNERS = {
    "UCI_275_bike_sharing_dataset": "SIAN2022",
    "UCI_186_wine_quality": "SIAN2022",
    "UCI_2_adults_dataset": "InstaSHAP2025",
    "UCI_31_tree_cover_type_dataset": "InstaSHAP2025",
}
dataset_str = "UCI_275_bike_sharing_dataset"
preproc_owner = DATASET_PREPROC_OWNERS[dataset_str]

data_base_path = "../data/"
load_dataset_path = data_base_path
save_dataset_path = f"{data_base_path}{dataset_str}/"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mlp_training_args = TrainingArgs(BS, EP, LR, device)
sian_training_args = TrainingArgs(BS, EP, LR, device)

# Masked variants of the MLP (Step 1) and SIAN (Step 3); set both to False
# for the unmasked pipeline.
is_masked_mlp = True
is_masked_sian = True

FIS_style = 'batchwise'  # or 'layerwise'

# Highest interaction order searched by FIS. Must be an integer in 1..3: the
# FIS cell takes its per-order thresholds from 3-entry tuples and does not
# support None yet.
MAX_K = 2


In [None]:

# Create a timestamped output folder for this run; the model-config cell
# points both training runs' saving_settings at it.
results_path = "results/"
exp_datetimestr = gettimestamp()
exp_folder = f"{results_path}{exp_datetimestr}_demo_simple_testing/"
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(exp_folder, exist_ok=True)
print(exp_folder)

In [None]:

# DEFAULT MODEL PARAMETRIZATION
# NOTE(review): the -1 endpoints of each size list are presumably filled in
# from the dataset (input/output widths) by the library — TODO confirm.

# Step 1 network: a plain MLP, masked or not per `is_masked_mlp`.
mlp_training_args.model_config.net_name = "MLP"
mlp_training_args.model_config.sizes = [-1, 256, 128, 64, -1]
mlp_training_args.model_config.is_masked = is_masked_mlp
mlp_training_args.saving_settings.exp_folder = exp_folder

# Step 3 network: SIAN-K, with an additional `small_sizes` layout
# (NOTE(review): presumably for per-interaction subnetworks — confirm).
sian_training_args.model_config.net_name = "SIAN-K"
sian_training_args.model_config.sizes = [-1, 256, 128, 64, -1]
sian_training_args.model_config.small_sizes = [-1, 32, 24, 16, -1]
sian_training_args.model_config.is_masked = is_masked_sian
sian_training_args.saving_settings.exp_folder = exp_folder

In [None]:
# Build the tabular dataset wrapper for the selected dataset/preprocessing.
# NOTE(review): save_dataset_path suggests processed data is cached there —
# confirm in Final_TabularDataset.
dataset_obj = Final_TabularDataset(
    dataset_str,
    preproc_owner=preproc_owner,
    load_dataset_path=load_dataset_path,
    save_dataset_path=save_dataset_path,
)


In [None]:
# `D` (presumably the input dimensionality — confirm) is not used further in
# this notebook; the readable labels are shown for a quick sanity check.
D, readable_labels = dataset_obj.get_D(), dataset_obj.get_readable_labels()
print(readable_labels)

# SIAN Step 1: Train Masked MLP

In [None]:
# Step 1: train the (optionally masked) MLP and pull out the two results
# used later: the trained network and its validation tensor.
mlp_results = train_mlp_final(dataset_obj, mlp_training_args)
trained_mlp, val_tensor = (mlp_results[key] for key in ("trained_mlp", "val_tensor"))


# SIAN Step 2: Masked Archipelago FIS

### Set up FID hyperparameters

In [None]:

# Configure feature interaction detection (FID) to match how the MLP was
# trained:
#   * masked MLP   -> "masking_based" Sobol-style scores on the torch tensor;
#   * unmasked MLP -> "triangle_marginal" scores on a numpy copy of the data.
output_type = "regression"  #TODO: can set to classification when masking version has support (not sobol version)
grouped_features_dict = dataset_obj.get_grouped_feature_dict()

# Read the validation data straight from the Step 1 results rather than the
# global `val_tensor`: Step 3 rebinds that name to SIAN's validation tensor,
# so re-running this cell afterwards would otherwise silently use the wrong one.
mlp_val_tensor = mlp_results["val_tensor"]

if is_masked_mlp:
    fid_masking_style = "masking_based"
    score_type_name = "new_arch_inter_sobol_score"
    # NOTE: only used for the batchwise plots.
    inc_rem_pel_list = ['inc_inter_sobol_score', 'rem_inter_sobol_score', 'new_arch_inter_sobol_score']
    fis_valX = mlp_val_tensor

    my_FID_hypers = masked_FID_Hyperparameters(fid_masking_style, output_type, score_type_name, inc_rem_pel_list,
                                               grouped_features_dict)
else:
    fid_masking_style = "triangle_marginal"
    score_type_name = "old_arch_inter_score"
    # NOTE: only used for the batchwise plots.
    inc_rem_pel_list = ['inc_inter_score', 'rem_inter_score', 'old_arch_inter_score']
    fis_valX = mlp_val_tensor.detach().cpu().numpy()

    my_FID_hypers = unmasked_FID_Hyperparameters(fid_masking_style, output_type, score_type_name, inc_rem_pel_list,
                                                 device, grouped_features_dict)


### Set up FIS hyperparameters

In [None]:

# Build the FIS hyperparameters for the chosen style.
# Per-order thresholds: entry k-1 of each tuple applies to interactions of
# order k, so MAX_K can be at most the tuple length (3).
tau_tup = (1.0, 0.5, 0.33)
theta_tup = (0.8, 0.4, 0.2)  # only used by the layerwise style

if FIS_style not in ("batchwise", "layerwise"):
    raise ValueError(f"FIS_style={FIS_style} not recognized")
# NOTE: no good MAX_K = None support yet; fail with a clear message instead of
# an IndexError/TypeError from the threshold lookup below.
if MAX_K is None or not 1 <= MAX_K <= len(tau_tup):
    raise ValueError(f"MAX_K={MAX_K!r} not supported: expected an integer in 1..{len(tau_tup)}")

tau_thresholds = {k: tau_tup[k - 1] for k in range(1, MAX_K + 1)}

# The explainer (jam_arch) is not built yet, so None is passed in its slot and
# it is attached later via my_FIS_hypers.add_the_explainer(...).
if FIS_style == "batchwise":
    max_number_of_rounds = 5
    inters_per_round = 1

    my_FIS_hypers = batchwise_FIS_Hyperparameters(MAX_K, tau_thresholds, max_number_of_rounds, inters_per_round,
                   None,  # jam_arch placeholder
                   tuples_initialization=None, pick_underlings=False, fill_underlings=False, PLOTTING=True)
else:  # "layerwise"
    theta_percentile_mode = True
    theta_thresholds = {k: theta_tup[k - 1] for k in range(1, MAX_K + 1)}

    my_FIS_hypers = layerwise_FIS_Hyperparameters(MAX_K, tau_thresholds, theta_thresholds,
                   None,  # jam_arch placeholder
                   theta_percentile_mode=theta_percentile_mode)


### Finalize the FID and FIS hyperparameters (attach the explainer)

In [None]:
# Wrap the Step 1 MLP in an explainer configured by the FID hyperparameters,
# then attach it to the FIS hyperparameters (built with a None placeholder).
jam_arch = initalize_the_explainer(
    trained_mlp,
    my_FID_hypers,
)
my_FIS_hypers.add_the_explainer(jam_arch)

### Run the FIS algorithm

In [None]:
# Run feature interaction selection on the validation data prepared in
# Step 2 and report the elapsed time.
FIS_AGG_K = 100  # passed to do_the_fis_final as AGG_K

# perf_counter is monotonic and high-resolution; time.time() follows the
# system clock and can jump, corrupting interval measurements.
FIS_algorithm_start_time = time.perf_counter()
FIS_interactions = do_the_fis_final(my_FIS_hypers, fis_valX, AGG_K=FIS_AGG_K)
FIS_algorithm_time_taken = time.perf_counter() - FIS_algorithm_start_time
print("FIS_algorithm_time_taken", FIS_algorithm_time_taken)

# SIAN Step 3: Train the InstaSHAP GAM

In [None]:
# Show the selected interactions: a header line, then the interactions.
print("FIS_interactions", FIS_interactions, sep="\n")

In [None]:
# Step 3: hand the selected interactions to the SIAN config and train it.
sian_training_args.model_config.FIS_interactions = FIS_interactions
sian_results = train_sian_final(dataset_obj, sian_training_args)
# NOTE: this rebinds the global `val_tensor` (previously the Step 1 value).
trained_sian, val_tensor = (sian_results[key] for key in ("trained_sian", "val_tensor"))


# SIAN Step 4: Plotting Learned Shapes

In [None]:
# Step 4: plot the learned shape functions of the trained SIAN.
# Take the validation data from `sian_results` instead of the global
# `val_tensor`, whose value depends on which step's cell ran last.
full_readable_labels = dataset_obj.get_full_readable_labels()
sian_val_array = sian_results["val_tensor"].detach().cpu().numpy()
# NOTE(review): if trained_sian is a torch nn.Module, .cpu() moves it to the
# CPU in place, so any later cell will also see it on the CPU.
plot_all_GAM_functions(trained_sian.cpu(), sian_val_array, full_readable_labels)


In [None]:
# Empty placeholder cell (no-op); safe to delete.
pass