In [1]:
import dice_ml
from dice_ml.utils import helpers # helper functions
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from microbiome_ml.data_processing import load_data, filter_data, clr_transform
from microbiome_ml.modeling import train_model

In [2]:
# Input configuration: the metadata column used as the label, and the two
# input files (paths are relative to the notebook's working directory).
target_column = "Group"
abundance_path = "../Dataset/abundance_crc.txt"
metadata_path = "../Dataset/metadata_crc.txt"

# Read the abundance table together with the labels taken from `target_column`.
abundance, labels = load_data(
    abundance_file=abundance_path,
    metadata_file=metadata_path,
    target_column=target_column,
)

# Two-stage preprocessing: filter_data first, then clr_transform on its result.
# Each stage keeps its own name so the intermediate frame stays inspectable.
filtered_data = filter_data(abundance)
clr_data = clr_transform(filtered_data)

In [3]:
# Fit the classifier on the transformed feature matrix.
# NOTE(review): `train_model` receives every row of `clr_data`, including the
# rows that are placed in `test_dataset` below, so the query instances used
# later are presumably not unseen by the model -- TODO confirm this is
# intended. If a held-out fit is wanted, pass
# `train_dataset.drop(columns="target")` / `train_dataset["target"]` instead.
model = train_model(clr_data, labels, n_jobs=1)

# Build the labelled frame as a NEW object instead of writing a 'target'
# column into `clr_data`. Mutating `clr_data` in place made this cell
# non-idempotent: on a second run the label column would be handed to
# `train_model` as a feature.
dataset = clr_data.assign(target=labels.values)
target = dataset['target']

# Stratified 80/20 split with a fixed seed. Only the frame itself is split;
# the labels travel with it in the 'target' column.
train_dataset, test_dataset = train_test_split(
    dataset,
    test_size=0.2,
    random_state=0,
    stratify=target,
)

# Every column except the outcome is declared as a continuous feature.
# Selecting by name avoids relying on 'target' being the last column.
feature_names = dataset.columns.drop('target').tolist()

# DiCE data interface built from the training split.
d = dice_ml.Data(dataframe=train_dataset,
                 continuous_features=feature_names,
                 outcome_name='target')


In [4]:
# Wrap the already-fitted estimator for DiCE using the sklearn backend.
m = dice_ml.Model(model=model, backend='sklearn')

# Explainer that combines the data interface `d` with the wrapped model,
# using DiCE's "genetic" method.
exp = dice_ml.Dice(d, m, method="genetic")

# Candidate queries: rows of the test split whose target equals 1, with the
# outcome column removed so that only feature columns remain.
is_positive = test_dataset["target"] == 1
queries = test_dataset.loc[is_positive].drop(columns="target")

# Keep the first candidate as a one-row DataFrame (not a Series).
query_instance = queries.iloc[:1]

# Request five counterfactuals that flip the predicted class.
dice_exp = exp.generate_counterfactuals(
    query_instance,
    total_CFs=5,
    desired_class="opposite",
    verbose=True,
)

# Display the query and its counterfactuals, showing only the changed values.
dice_exp.visualize_as_dataframe(show_only_changes=True)

  0%|          | 0/1 [00:00<?, ?it/s]

Initializing initial parameters to the genetic algorithm...
Initialization complete! Generating counterfactuals...


100%|██████████| 1/1 [00:00<00:00,  2.46it/s]

Diverse Counterfactuals found! total time taken: 00 min 00 sec
Query instance (original outcome : 1)





Unnamed: 0,Victivallis vadensis [1000],Akkermansia muciniphila [1008],Alistipes shahii [1052],unnamed Alistipes sp. HGB5 [1053],Alistipes putredinis [1054],Porphyromonas asaccharolytica [1056],Porphyromonas uenonis [1057],Prevotella buccalis [1059],Prevotella buccae [1065],Prevotella dentalis [1067],...,Bifidobacterium bifidum [968],Bifidobacterium breve [969],Bifidobacterium longum [970],Bifidobacterium dentium [971],Bifidobacterium adolescentis [972],Bifidobacterium catenulatum-Bifidobacterium pseudocatenulatum complex [973],Bifidobacterium angulatum [974],Parascardovia denticolens [976],Scardovia inopinata [977],target
0,-1.392562,2.105299,-0.737418,1.244804,1.749803,-0.648682,-0.370292,-0.216932,-0.369124,-0.291348,...,-3.153677,-1.074443,-6.321915,-2.7292,-5.181201,-4.566145,-2.094948,-0.462484,-0.31182,1



Diverse Counterfactual set (new outcome: 0)


Unnamed: 0,Victivallis vadensis [1000],Akkermansia muciniphila [1008],Alistipes shahii [1052],unnamed Alistipes sp. HGB5 [1053],Alistipes putredinis [1054],Porphyromonas asaccharolytica [1056],Porphyromonas uenonis [1057],Prevotella buccalis [1059],Prevotella buccae [1065],Prevotella dentalis [1067],...,Bifidobacterium bifidum [968],Bifidobacterium breve [969],Bifidobacterium longum [970],Bifidobacterium dentium [971],Bifidobacterium adolescentis [972],Bifidobacterium catenulatum-Bifidobacterium pseudocatenulatum complex [973],Bifidobacterium angulatum [974],Parascardovia denticolens [976],Scardovia inopinata [977],target
0,-1.3925618,2.5960073,0.2726408,1.2961415,-1.4707574,-0.648682,-0.37029198,-0.21693194,-0.3691239,-0.29134834,...,2.3052354,-1.0744432,-0.585362,-2.7292001,-0.23809,1.3244883,2.6235359,-0.4624842,-0.31181964,0.0
0,-1.3925618,-0.5483761,1.8003366,2.8619192,2.7115362,-0.648682,-0.37029198,-0.21693194,-0.3691239,0.1859421,...,-0.2381356,-1.0744432,0.757404,-2.7292001,-1.525136,0.4966835,-2.0949483,1.2733662,-0.31181964,0.0
0,4.060926,5.8407087,1.027295,-0.3991702,-0.2232277,3.177045,0.83171743,0.72164696,-0.3691239,-0.29134834,...,5.135869,-0.0267243,3.780432,2.3585169,3.188324,-0.2514397,-0.6094009,1.6135049,-0.31181964,0.0
0,-1.3925618,-1.1396621,-2.0101523,-1.4717476,-2.3619301,-0.648682,-0.37029198,-0.21693194,-0.3691239,-0.29134834,...,5.7065363,3.8999512,3.153469,1.5425001,2.868416,5.0884075,1.445936,-0.4624842,-0.31181964,0.0
0,1.5887978,2.6953855,-2.6713111,-0.7949058,2.8986728,0.308279,-0.1064781,-0.21693194,-0.3691239,0.38246113,...,6.0879245,0.9306861,1.00752,0.9376402,3.257947,3.1315076,1.3579293,2.1745932,-0.31181964,0.0
