In [None]:
import os
import sys
import numpy as np
import seaborn as sns
import explainer.rule_pattern_miner as rlm
import explainer.DT_rules as dtr

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.nn import Linear
from torch_geometric.nn import GraphConv
from torch_geometric.nn import global_mean_pool
from torch.utils.data import TensorDataset, DataLoader

from sklearn.metrics import recall_score, precision_score, roc_auc_score, roc_curve
from sklearn.tree import DecisionTreeClassifier, export_text

from rdkit import Chem
from rdkit.Chem import Descriptors

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


In [None]:
# Make the directory two levels above the current working directory importable.
# NOTE(review): the `explainer` imports in the first cell execute before this
# path is appended; if `explainer` lives under that directory, those imports
# would fail on a fresh kernel -- confirm and move this above the imports if so.
here = os.getcwd()
sys.path.append(os.path.join(here,"../../"))

# Global random seed, reused for torch / numpy seeding and for the decision trees.
seed = 42

In [None]:
def recover_feature_raw_value(fid,fval,featur_names,raw_min_max,ntype="min_max"):
    """Map a normalised feature value back to its raw scale.

    Args:
        fid: Index of the feature in ``featur_names``.
        fval: Normalised value to convert.
        featur_names: Sequence of feature names, indexed by ``fid``.
        raw_min_max: Table (with ``.columns`` and ``.loc``) holding raw values
            per feature; the column minimum and maximum define the scale.
        ntype: Normalisation type; only ``"min_max"`` is supported.

    Returns:
        ``fval * (max - min) + min`` when the feature has a column in
        ``raw_min_max``; otherwise ``fval`` unchanged.

    Raises:
        TypeError: If ``ntype`` is anything other than ``"min_max"``.
    """
    feature = featur_names[fid]
    if ntype != "min_max":
        raise TypeError("Not yet supported type")
    if feature not in raw_min_max.columns:
        # No raw range is known for this feature, so pass the value through.
        return fval
    raw_column = raw_min_max.loc[:, feature]
    upper = raw_column.max()
    lower = raw_column.min()
    return fval*(upper-lower)+lower

## 1.0 Data loading and preprocessing

### 1.1 Load the data into graph

In [None]:
from torch_geometric.datasets import MoleculeNet

# Load the Tox21 benchmark as a PyG dataset rooted at data/MoleculeNet.
dataset = MoleculeNet(root='data/MoleculeNet', name='Tox21')

# Short summary of what was loaded, preceded by a blank line.
summary_lines = [
    '',
    f'Dataset: {dataset}:',
    '====================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
]
print('\n'.join(summary_lines))

### 1.2 Transform SMILES to RDKit descriptors

In [None]:
def calculate_all_descriptors(smiles):
    """Evaluate every RDKit descriptor in ``Descriptors.descList`` for a molecule.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        ``None`` if RDKit cannot parse ``smiles``; otherwise a tuple
        ``(values, names)`` of two parallel lists, in ``descList`` order.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None  # Handle invalid SMILES input

    names = [name for name, _ in Descriptors.descList]
    values = [compute(molecule) for _, compute in Descriptors.descList]
    return values, names

def smiles_to_rdkit(dataset):
    """Compute RDKit descriptor features for every molecule in ``dataset``.

    Args:
        dataset: Indexable collection whose items expose a ``smiles`` attribute.

    Returns:
        A tuple ``(features, descriptors)``: ``features`` is a float32 tensor
        with one row per molecule (in dataset order) and ``descriptors`` is
        the list of descriptor names for its columns. For an empty dataset an
        empty tensor and an empty name list are returned.

    Raises:
        ValueError: If a SMILES string cannot be parsed. Rows must stay
            aligned with ``dataset``, so invalid molecules are not skipped.
    """
    rdk_features = []
    descriptors = []  # stays empty (instead of unbound) when the dataset is empty
    for i in range(len(dataset)):
        smiles = dataset[i].smiles
        result = calculate_all_descriptors(smiles)
        if result is None:
            # calculate_all_descriptors signals an unparsable SMILES with None.
            raise ValueError(f"Invalid SMILES at dataset index {i}: {smiles!r}")
        features, descriptors = result
        rdk_features.append(features)

    return torch.tensor(rdk_features, dtype=torch.float32), descriptors

### 1.3 Divide the data into training, validation and testing sets

In [None]:
# Seed the RNGs before shuffling so the split is reproducible.
# NOTE: this cell is not idempotent -- re-running it reshuffles the already
# shuffled `dataset`; restart the kernel to reproduce the original split.
torch.manual_seed(seed)
np.random.seed(seed)
dataset = dataset.shuffle()

# 70 % / 10 % / 20 % train / validation / test split of the graph dataset.
train_dataset = dataset[:int(len(dataset) * 0.7)]
val_dataset = dataset[int(len(dataset) * 0.7):int(len(dataset) * 0.8)]
test_dataset = dataset[int(len(dataset) * 0.8):]


def _drop_nan_columns(features, names):
    """Drop every feature column containing a NaN, keeping `names` aligned."""
    # The column mask is computed once rather than once per descriptor.
    keep = ~torch.any(features.isnan(), dim=0)
    kept_names = [name for name, flag in zip(names, keep.tolist()) if flag]
    return features[:, keep], kept_names


## generate rdk descriptors from smiles representation of molecules
dataset_rdk, descriptors = smiles_to_rdkit(dataset)
## remove descriptors with nan values
dataset_rdk, descriptors = _drop_nan_columns(dataset_rdk, descriptors)
## normalize (z-score per descriptor, statistics taken over the whole dataset)
dataset_rdk = (dataset_rdk - dataset_rdk.mean(axis=0))/dataset_rdk.std(axis=0)
## drop the descriptors that became NaN during normalisation
dataset_rdk, descriptors = _drop_nan_columns(dataset_rdk, descriptors)

# Same 70 / 10 / 20 row split as the graphs (rows follow the shuffled dataset order).
train_dataset_rdk = dataset_rdk[:int(len(dataset_rdk) * 0.7)]
val_dataset_rdk = dataset_rdk[int(len(dataset_rdk) * 0.7):int(len(dataset_rdk) * 0.8)]
test_dataset_rdk = dataset_rdk[int(len(dataset_rdk) * 0.8):]


print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of validation graphs: {len(val_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

## task weights for training
# mask is 1 where a label is present and 0 where it is NaN.
mask = 1-torch.isnan(train_dataset.y).type(torch.float32)
# Inverse of the fraction of labelled molecules per task.
task_weight = 1 / torch.mean(mask, axis = 0)
# Negative-to-positive count ratio per task (infinite if a task has no positives).
label_weight = torch.sum(train_dataset.y == 0., axis = 0) / torch.sum(train_dataset.y == 1., axis = 0)

## 2.0 PyTorch model

In [None]:
# Training hyper-parameters shared by the cells below.
batch_size = 64
learning_rate = 0.001
num_epoch = 200

### 2.1 MLP model on RDK features

### 2.2 GNN model

In [None]:


class GNN(torch.nn.Module):
    """Three-layer GraphConv network with mean-pool readout and sigmoid outputs.

    Input and output sizes are read from the module-level ``dataset``
    (``num_node_features`` and ``num_classes``); weight initialisation is
    seeded with the module-level ``seed``.
    """

    def __init__(self, hidden_channels):
        """Build the layers; ``hidden_channels`` is the width of every hidden layer."""
        super().__init__()
        torch.manual_seed(seed)
        # Creation order is kept fixed so seeded initialisation is unchanged.
        self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)
        self.activation_fn_last = nn.Sigmoid()

    def forward(self, x, edge_index, batch):
        """Return per-graph sigmoid outputs for a batch of graphs.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to each GraphConv layer.
            batch: Node-to-graph assignment used by the mean-pool readout.
        """
        # Node embeddings: ReLU after the first two convolutions only.
        node_emb = F.relu(self.conv1(x, edge_index))
        node_emb = F.relu(self.conv2(node_emb, edge_index))
        node_emb = self.conv3(node_emb, edge_index)

        # Readout: average node embeddings per graph -> [batch_size, hidden_channels]
        graph_emb = global_mean_pool(node_emb, batch)

        # Classifier head: dropout (active only in training mode), linear, sigmoid.
        graph_emb = F.dropout(graph_emb, p=0.5, training=self.training)
        return self.activation_fn_last(self.lin(graph_emb))

class GNNtraining(object):
    """Training and evaluation harness for a multi-task graph classifier.

    The wrapped model is called as ``model(x, edge_index, batch)`` and must
    return probabilities in [0, 1] with the same shape as the batch labels.
    Missing labels are encoded as NaN in ``data.y``.
    """

    def __init__(self,
                 model,
                 learning_rate=0.001,
                 num_epoch=200,
                 use_cuda=False):
        """Store the model and settings; move the model to GPU if ``use_cuda``.

        Args:
            model: The network to train.
            learning_rate: Adam learning rate.
            num_epoch: Number of passes over the training loader.
            use_cuda: Whether to run the model and batches on CUDA.
        """
        self.model = model
        self.learning_rate = learning_rate
        self.num_epoch = num_epoch

        self.use_cuda = use_cuda
        if use_cuda:
            self.model.cuda()

    def training(self, train_loader, val_loader, task_weight, label_weight):
        """Train with weighted binary cross-entropy and print per-epoch metrics.

        Args:
            train_loader: Iterable of training batches (``x``, ``edge_index``,
                ``batch``, ``y`` attributes).
            val_loader: Iterable of validation batches, used only for reporting.
            task_weight: Per-task weight applied to every valid (non-NaN) label.
            label_weight: Per-task weight applied to positive labels; all other
                entries get weight 1.
        """
        # Pass the parameters in their natural order: wrapping them in a set
        # made the optimizer's parameter ordering vary between runs.
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate, eps=1e-3)

        for epoch in range(self.num_epoch):
            # evaluation() switches to eval mode, so re-enable training mode each epoch.
            self.model.train()
            train_bce = float("nan")  # reported as-is if the loader yields no batch
            for data in train_loader:
                y_batch = torch.nan_to_num(data.y, nan=0.0) # nan to 0.0

                # Missing labels get weight 0, so they do not contribute to the loss.
                task_weight_batch = (~data.y.isnan()) * task_weight
                # Positives are weighted by label_weight; zero entries are reset to 1.
                label_weight_batch = y_batch * label_weight
                label_weight_batch[label_weight_batch == 0.0] = 1.0
                w_batch = task_weight_batch * label_weight_batch
                if self.use_cuda:
                    data = data.cuda(); y_batch = y_batch.cuda(); w_batch = w_batch.cuda()

                # The weights differ per batch, so the criterion is rebuilt each step.
                criterion = nn.BCELoss(weight=w_batch)
                optimizer.zero_grad()
                output = self.model(data.x.to(torch.float32), data.edge_index, data.batch)
                loss = criterion(output, y_batch)
                loss.backward()
                optimizer.step()
                # Loss of the most recent batch; the last one is what gets reported.
                train_bce = loss.item()

            train_acc = self.evaluation(train_loader)
            val_acc = self.evaluation(val_loader)
            print('>>> Epoch {:5d}/{:5d} | train_bce={:.5f} | train_acc={:.5f} | val_acc={:.5f}'.format(epoch, self.num_epoch, train_bce, train_acc, val_acc))

    def evaluation(self, loader):
        """Return accuracy over all valid (non-NaN) labels at a 0.5 threshold.

        The model is switched to eval mode and left there. Returns NaN when
        the loader contains no valid label, since accuracy is then undefined.
        """
        self.model.eval()
        correct = 0; total = 0
        # No gradients are needed here; skipping them avoids building autograd graphs.
        with torch.no_grad():
            for data in loader:
                if self.use_cuda:
                    data = data.cuda()
                output = self.model(data.x.to(torch.float32), data.edge_index, data.batch)
                pred = (output > 0.5).to(torch.float32)
                # NaN labels never compare equal, so they are never counted as correct.
                correct += int((pred == data.y).sum())
                total += int((~data.y.isnan()).sum())

        if total == 0:
            return float("nan")
        return correct/total


In [None]:
from torch_geometric.loader import DataLoader as GDataLoader

torch.manual_seed(seed)

# Graph mini-batch loaders in fixed (unshuffled) order for the train and validation splits.
graph_loader_options = dict(batch_size=batch_size, shuffle=False)
train_loader = GDataLoader(train_dataset, **graph_loader_options)
val_loader = GDataLoader(val_dataset, **graph_loader_options)

In [None]:
torch.manual_seed(seed)
gnn = GNN(hidden_channels=64)
# The model is used on CPU below (its outputs are converted with .numpy()), so
# map the checkpoint to CPU; this also lets a GPU-saved checkpoint load on a
# CPU-only machine.
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
gnn_state_dict = torch.load('gnn_models/gnn_42.pt', map_location='cpu')
gnn.load_state_dict(gnn_state_dict)
# A freshly built module is in training mode; switch to eval mode so the
# dropout layer in GNN.forward is disabled when the model is used for prediction.
_ = gnn.eval()

## 3.0 Interpretation

In [None]:
torch.manual_seed(seed)
np.random.seed(seed)

# Full-batch loaders: one batch holds the whole split, so a single next() retrieves it.
train_loader_visualization = GDataLoader(train_dataset, batch_size=len(train_dataset), shuffle=False)
train_data = next(iter(train_loader_visualization))

# Use the complete validation split. Taking the first batch of `val_loader`
# would evaluate on only `batch_size` graphs.
val_loader_visualization = GDataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)
val_data = next(iter(val_loader_visualization))

# RDKit descriptor features for the training split, row-aligned with train_dataset.y.
train_loader_rdk_visualization = DataLoader(TensorDataset(train_dataset_rdk, train_dataset.y), batch_size=len(train_dataset_rdk), shuffle=False)
train_data_x, train_data_y = next(iter(train_loader_rdk_visualization))

# Inference only: eval mode disables the dropout in GNN.forward (a freshly
# constructed module is in training mode), and no_grad skips autograd bookkeeping.
gnn.eval()
with torch.no_grad():
    y_pred_s_gnn = gnn(train_data.x.to(torch.float32), train_data.edge_index, train_data.batch).numpy()

    y_pred_test_gnn = gnn(val_data.x.to(torch.float32), val_data.edge_index, val_data.batch).numpy()


### 3.1 Task AR

#### 3.11 Task AR; select hyper-parameters of AMORE and DT.

In [None]:
# Column of the multi-task label matrix analysed in this section.
idx_task = 0

# Training labels for the task; NaN marks molecules without a label, which are dropped.
y_true = train_data.y.numpy()[:, idx_task]
idx = ~np.isnan(y_true)
y_pred_s = y_pred_s_gnn[:, idx_task][idx]
y_true = y_true[idx].astype(int)

# Decision threshold on the GNN score that maximises TPR - FPR on the training labels.
fpr, tpr, thresholds = roc_curve(y_true, y_pred_s)
th_id = np.argmax(tpr - fpr); y_thd = thresholds[th_id]

fids = [i for i in range(len(descriptors))]
# Feature type per descriptor: "cat" for uint8 columns, otherwise the dtype name.
# The masked feature matrix is built once instead of twice per descriptor.
X_task = train_data_x.numpy()[idx]
feature_types = []
for i in range(len(descriptors)):
    column_dtype = X_task[:, i].dtype
    if column_dtype != np.uint8:
        feature_types.append(str(column_dtype))
    else:
        feature_types.append("cat")

# Same label / score extraction for the validation batch held in val_data.
y_true_test = val_data.y.numpy()[:, idx_task]
idx_test = ~np.isnan(y_true_test)
y_pred_test = y_pred_test_gnn[:, idx_task][idx_test]
y_true_test = y_true_test[idx_test].astype(int)

print('Task {}: total positive: {:}; total predictived positive:{:}; recall:{:.5f}; precision:{:.5f}; auroc: {:.5f}; test auroc:  {:.5f}'.format(idx_task, sum(y_true), sum(y_pred_s>y_thd), recall_score(y_true, y_pred_s>y_thd), precision_score(y_true, y_pred_s>y_thd), roc_auc_score(y_true, y_pred_s), roc_auc_score(y_true_test, y_pred_test)))


In [None]:
## set "grid_search = True" to do a grid search for hyperparameters
## NOTE: later cells read `best_configs` and `config_metric_records`, which are
## only defined when this search runs.

grid_search = True

if grid_search:
    # Search space handed to the AMORE grid search.
    ng_range = np.arange(2,21)               # candidate values 2..20 (presumably number of grids/bins -- confirm)
    bin_strategies = ["uniform","kmeans"]    # candidate binning strategies
    support_range = np.arange(100,1000,100)  # candidate minimum supports: 100, 200, ..., 900
    confidence_lower_bound = 0.8
    max_depth=2
    top_K=3


    # Target pattern: molecules whose GNN score exceeds y_thd; y_true is passed as the reference labels.
    best_rule_set,best_configs,config_metric_records = rlm.param_grid_search_for_amore(bin_strategies,ng_range,support_range,train_data_x.numpy()[idx],fids,target_indices=y_pred_s>y_thd,y=y_true,c=1,confidence_lower_bound = confidence_lower_bound,
                                                                                        max_depth=max_depth,top_K=top_K,sort_by="fitness")

In [None]:
## set "grid_search = True" to do a grid search for hyperparameters
## NOTE: later cells read `DT_best_configs` and `DT_config_metric_records`,
## which are only defined when this search runs.

grid_search = True

if grid_search:
    # Split criteria tried for the decision tree; c is the class of interest.
    criteria = ["gini", "entropy", "log_loss"]; c=1.
    # Fraction of samples labelled c; only used by the commented-out option below.
    w = (y_true==c).sum()/y_true.shape[0]
    # class_weight_options = [{0:0.5,1:0.5},{0:1./(1.-w),1:1./w}]
    class_weight_options = [{0:0.5,1:0.5},'balanced']
    support_range = np.arange(100,1000,100)  # candidate minimum supports: 100, 200, ..., 900
    confidence_lower_bound = 0.
    max_depth=2
    # Target pattern: molecules whose GNN score exceeds y_thd; y_true is passed as the reference labels.
    DT_best_rule_set, DT_best_configs, DT_config_metric_records = dtr.param_grid_search_for_DT(criteria,support_range,weight_options=class_weight_options,X=train_data_x.numpy()[idx],y=y_true,target_indices=y_pred_s>y_thd,c=1,max_depth=max_depth,feature_names=descriptors,confidence_lower_bound=confidence_lower_bound,seed=seed)


#### 3.12 Task AR; fitness and confidence with different minimum support

In [None]:
# Minimum confidence a configuration must reach to be preferred, per method.
confidence_lower_bound_dt = 0.0
confidence_lower_bound_amore = 0.8

# Stack the per-configuration records into matrices with one row per
# configuration and one column per minimum-support value (columns are plotted
# against `support_range` below, so they are expected to follow it).
cf_mtx = np.vstack([rec['top_confidence_records'] for rec in config_metric_records.values()])
ft_mtx = np.vstack([rec['top_fitness_records'] for rec in config_metric_records.values()])
as_mtx = np.vstack([rec['actual_support'] for rec in config_metric_records.values()])
DT_cf_mtx = np.vstack([rec['top_confidence_records'] for rec in DT_config_metric_records.values()])
DT_ft_mtx = np.vstack([rec['top_fitness_records'] for rec in DT_config_metric_records.values()])
DT_as_mtx = np.vstack([rec['actual_support'] for rec in DT_config_metric_records.values()])

best_cfs, best_fts, best_ass = [], [], []
DT_best_cfs, DT_best_fts, DT_best_ass = [], [], []

# Fitness with below-bound configurations zeroed out, so argmax prefers
# configurations that satisfy the confidence bound.
ft_mtx_cp = ft_mtx.copy()
DT_ft_mtx_cp = DT_ft_mtx.copy()
ft_mtx_cp[cf_mtx < confidence_lower_bound_amore] = 0
DT_ft_mtx_cp[DT_cf_mtx < confidence_lower_bound_dt] = 0.

for s in range(cf_mtx.shape[1]):
    # AMORE: fittest configuration meeting the bound; if none does (or its
    # confidence is NaN), fall back to the fittest configuration overall.
    cid = np.argmax(ft_mtx_cp[:, s])
    if not (cf_mtx[cid, s] >= confidence_lower_bound_amore):
        cid = np.argmax(ft_mtx[:, s])
    best_cfs.append(cf_mtx[cid, s])
    best_fts.append(ft_mtx[cid, s])
    best_ass.append(as_mtx[cid, s])

    # Decision tree: same selection rule with its own bound.
    cid = np.argmax(DT_ft_mtx_cp[:, s])
    if not (DT_cf_mtx[cid, s] >= confidence_lower_bound_dt):
        cid = np.argmax(DT_ft_mtx[:, s])
    DT_best_cfs.append(DT_cf_mtx[cid, s])
    DT_best_fts.append(DT_ft_mtx[cid, s])
    DT_best_ass.append(DT_as_mtx[cid, s])


sns.set_style('whitegrid')
plt.figure(figsize=(5,4))
color1 = '#377EB8'  # Blue
color2 = '#E41A1C'  # Red
color3 = '#4DAF4A'  # green
# Colour encodes the method, line style encodes the metric.
plt.plot(support_range,best_cfs,'-o',color=color1,markersize=4)
plt.plot(support_range,best_fts,'--o',color=color1,markersize=4)
plt.plot(support_range,DT_best_cfs,'-o',color=color2,markersize=4)
plt.plot(support_range,DT_best_fts,'--o',color=color2,markersize=4)
plt.xlim(100,900)
plt.ylim(0.,1.)
# Creating custom lines for the color legend
custom_lines_color = [Line2D([0], [0], color=color1, lw=4),
                      Line2D([0], [0], color=color2, lw=4)]
# Creating custom lines for the line style legend
custom_lines_style = [Line2D([0], [0], color='grey', lw=2, linestyle='-'),
                      Line2D([0], [0], color='grey', lw=2, linestyle='--')]
# Attach both legends. A second plt.legend call replaces the first, so the
# first one is re-added to the axes as a separate artist.
color_legend = plt.legend(custom_lines_color, ['AMORE', 'DT'], loc='lower left')
plt.gca().add_artist(color_legend)
plt.legend(custom_lines_style, ['confidence', 'fitness'], loc='lower right')

plt.xlabel('Specified minimum support')
plt.ylabel('Confidence / fitness')
# plt.savefig('plot/compare_DT_AR.svg')

#### 3.13 Task AR; Rules extracted from AMORE and DT

In [None]:
### search rules for target pattern: pred_y > y_thd  ###
### we set the hyperparameters obtained by the grid search step above ###

min_support=best_configs['min_support']
bin_strategy=best_configs['bin_strategy']
num_grids=best_configs['num_grids']
max_depth=2
top_K=3

# Mine rule candidates for the molecules the GNN scores above y_thd.
y_rule_candidates = rlm.gen_rule_list_for_one_target(train_data_x.numpy()[idx],fids,y_pred_s>y_thd,y=y_true,c=1,sort_by="fitness",
                                                    min_support=min_support,num_grids=num_grids,max_depth=max_depth,top_K=top_K,
                                                    local_x=None,feature_types=feature_types,bin_strategy=bin_strategy,
                                                    verbose=False)

# Rewrite each candidate's "rules" entry using the descriptor names (each entry is updated in place).
for i, rules in enumerate(y_rule_candidates):
    rules["rules"] = rlm.replace_feature_names(rules["rules"],descriptors)
    y_rule_candidates[i] = rules
y_rule_candidates

In [None]:
### Obtain rules for target pattern: pred_y > y_thd from a DecisionTreeClassifier ###
### We set the hyperparameters obtained by the grid search step above ###
criterion = DT_best_configs['criterion']
class_weight = DT_best_configs['class_weight']
min_support = DT_best_configs['min_support']

input_feature_names =descriptors
# NOTE: max_depth is not assigned in this cell; it carries over from whichever earlier cell set it last.
treemodel = DecisionTreeClassifier(max_depth=max_depth,min_samples_leaf=min_support,criterion=criterion,random_state=seed,class_weight=class_weight)
# The tree is fitted to the GNN's thresholded predictions, not to the true labels.
treemodel.fit(train_data_x.numpy()[idx],y_pred_s>y_thd)
rule_list, rule_value_list, rule_metric_list, new_lines = dtr.obtain_rule_lists_from_DT(treemodel,train_data_x.numpy()[idx],y_true,y_pred_s>y_thd,np.arange(train_data_x.numpy()[idx].shape[-1]),descriptors,c=1)
print(export_text(treemodel))

## display rules extracted by DT classifier
dtr.display_rules_from_DT(rule_list,rule_metric_list,input_feature_names)

### 3.2 Task ER-LBD

#### 3.21 Task ER-LBD; select hyper-parameters of AMORE and DT.

In [None]:
# Column of the multi-task label matrix analysed in this section.
idx_task = 4

# Training labels for the task; NaN marks molecules without a label, which are dropped.
y_true = train_data.y.numpy()[:, idx_task]
idx = ~np.isnan(y_true)
y_pred_s = y_pred_s_gnn[:, idx_task][idx]
y_true = y_true[idx].astype(int)

# Decision threshold on the GNN score that maximises TPR - FPR on the training labels.
fpr, tpr, thresholds = roc_curve(y_true, y_pred_s)
th_id = np.argmax(tpr - fpr); y_thd = thresholds[th_id]

fids = [i for i in range(len(descriptors))]
# Feature type per descriptor: "cat" for uint8 columns, otherwise the dtype name.
# The masked feature matrix is built once instead of twice per descriptor.
X_task = train_data_x.numpy()[idx]
feature_types = []
for i in range(len(descriptors)):
    column_dtype = X_task[:, i].dtype
    if column_dtype != np.uint8:
        feature_types.append(str(column_dtype))
    else:
        feature_types.append("cat")

# Same label / score extraction for the validation batch held in val_data.
y_true_test = val_data.y.numpy()[:, idx_task]
idx_test = ~np.isnan(y_true_test)
y_pred_test = y_pred_test_gnn[:, idx_task][idx_test]
y_true_test = y_true_test[idx_test].astype(int)

print('Task {}: total positive: {:}; total predictived positive:{:}; recall:{:.5f}; precision:{:.5f}; auroc: {:.5f}; test auroc:  {:.5f}'.format(idx_task, sum(y_true), sum(y_pred_s>y_thd), recall_score(y_true, y_pred_s>y_thd), precision_score(y_true, y_pred_s>y_thd), roc_auc_score(y_true, y_pred_s), roc_auc_score(y_true_test, y_pred_test)))


In [None]:
## set "grid_search = True" to do a grid search for hyperparameters
## NOTE: later cells read `best_configs` and `config_metric_records`, which are
## only defined when this search runs.

grid_search = True

if grid_search:
    # Search space handed to the AMORE grid search.
    ng_range = np.arange(2,21)               # candidate values 2..20 (presumably number of grids/bins -- confirm)
    bin_strategies = ["uniform","kmeans"]    # candidate binning strategies
    support_range = np.arange(100,2000,200)  # candidate minimum supports: 100, 300, ..., 1900
    confidence_lower_bound = 0.8
    max_depth=2
    top_K=3


    # Target pattern: molecules whose GNN score exceeds y_thd; y_true is passed as the reference labels.
    best_rule_set,best_configs,config_metric_records = rlm.param_grid_search_for_amore(bin_strategies,ng_range,support_range,train_data_x.numpy()[idx],fids,target_indices=y_pred_s>y_thd,y=y_true,c=1,confidence_lower_bound = confidence_lower_bound,
                                                                                        max_depth=max_depth,top_K=top_K,sort_by="fitness")

                                                                    

In [None]:
## set "grid_search = True" to do a grid search for hyperparameters
## NOTE: later cells read `DT_best_configs` and `DT_config_metric_records`,
## which are only defined when this search runs.

grid_search = True

if grid_search:
    # Split criteria tried for the decision tree; c is the class of interest.
    criteria = ["gini", "entropy", "log_loss"]; c=1.
    # Fraction of samples labelled c; not used further in this cell.
    w = (y_true==c).sum()/y_true.shape[0]
    class_weight_options = [{0:0.5,1:0.5},'balanced']
    support_range = np.arange(100,2000,200)  # candidate minimum supports: 100, 300, ..., 1900
    confidence_lower_bound = 0.
    max_depth=2
    # Target pattern: molecules whose GNN score exceeds y_thd; y_true is passed as the reference labels.
    DT_best_rule_set, DT_best_configs, DT_config_metric_records = dtr.param_grid_search_for_DT(criteria,support_range,weight_options=class_weight_options,X=train_data_x.numpy()[idx],y=y_true,target_indices=y_pred_s>y_thd,c=1,max_depth=max_depth,feature_names=descriptors,confidence_lower_bound=confidence_lower_bound,seed=seed)


#### 3.22 Task ER-LBD; fitness and confidence with different minimum support

In [None]:
# Minimum confidence a configuration must reach to be preferred, per method.
confidence_lower_bound_dt = 0.0
confidence_lower_bound_amore = 0.8

# Stack the per-configuration records into matrices with one row per
# configuration and one column per minimum-support value (columns are plotted
# against `support_range` below, so they are expected to follow it).
cf_mtx = np.vstack([rec['top_confidence_records'] for rec in config_metric_records.values()])
ft_mtx = np.vstack([rec['top_fitness_records'] for rec in config_metric_records.values()])
as_mtx = np.vstack([rec['actual_support'] for rec in config_metric_records.values()])
DT_cf_mtx = np.vstack([rec['top_confidence_records'] for rec in DT_config_metric_records.values()])
DT_ft_mtx = np.vstack([rec['top_fitness_records'] for rec in DT_config_metric_records.values()])
DT_as_mtx = np.vstack([rec['actual_support'] for rec in DT_config_metric_records.values()])

best_cfs, best_fts, best_ass = [], [], []
DT_best_cfs, DT_best_fts, DT_best_ass = [], [], []

# Fitness with below-bound configurations zeroed out, so argmax prefers
# configurations that satisfy the confidence bound.
ft_mtx_cp = ft_mtx.copy()
DT_ft_mtx_cp = DT_ft_mtx.copy()
ft_mtx_cp[cf_mtx < confidence_lower_bound_amore] = 0
DT_ft_mtx_cp[DT_cf_mtx < confidence_lower_bound_dt] = 0.

for s in range(cf_mtx.shape[1]):
    # AMORE: fittest configuration meeting the bound; if none does (or its
    # confidence is NaN), fall back to the fittest configuration overall.
    cid = np.argmax(ft_mtx_cp[:, s])
    if not (cf_mtx[cid, s] >= confidence_lower_bound_amore):
        cid = np.argmax(ft_mtx[:, s])
    best_cfs.append(cf_mtx[cid, s])
    best_fts.append(ft_mtx[cid, s])
    best_ass.append(as_mtx[cid, s])

    # Decision tree: same selection rule with its own bound.
    cid = np.argmax(DT_ft_mtx_cp[:, s])
    if not (DT_cf_mtx[cid, s] >= confidence_lower_bound_dt):
        cid = np.argmax(DT_ft_mtx[:, s])
    DT_best_cfs.append(DT_cf_mtx[cid, s])
    DT_best_fts.append(DT_ft_mtx[cid, s])
    DT_best_ass.append(DT_as_mtx[cid, s])


sns.set_style('whitegrid')
plt.figure(figsize=(5,4))
color1 = '#377EB8'  # Blue
color2 = '#E41A1C'  # Red
color3 = '#4DAF4A'  # green
# Colour encodes the method, line style encodes the metric.
plt.plot(support_range,best_cfs,'-o',color=color1,markersize=4)
plt.plot(support_range,best_fts,'--o',color=color1,markersize=4)
plt.plot(support_range,DT_best_cfs,'-o',color=color2,markersize=4)
plt.plot(support_range,DT_best_fts,'--o',color=color2,markersize=4)
plt.xlim(100,1900)
plt.ylim(0.,1.)
# Creating custom lines for the color legend
custom_lines_color = [Line2D([0], [0], color=color1, lw=4),
                      Line2D([0], [0], color=color2, lw=4)]
# Creating custom lines for the line style legend
custom_lines_style = [Line2D([0], [0], color='grey', lw=2, linestyle='-'),
                      Line2D([0], [0], color='grey', lw=2, linestyle='--')]
# Attach both legends. A second plt.legend call replaces the first, so the
# first one is re-added to the axes as a separate artist.
color_legend = plt.legend(custom_lines_color, ['AMORE', 'DT'], loc='lower left')
plt.gca().add_artist(color_legend)
plt.legend(custom_lines_style, ['confidence', 'fitness'], loc='lower right')
plt.xticks(np.arange(100, 2000, 200))
plt.xlabel('Specified minimum support')
plt.ylabel('Confidence / fitness')
# plt.savefig('plot/compare_DT_ER-LBD.svg')

#### 3.23 Task ER-LBD; Rules extracted from AMORE and DT

In [None]:
### search rules for target pattern: pred_y > y_thd  ###
### we set the hyperparameters obtained by the grid search step above ###

min_support=best_configs['min_support']
bin_strategy=best_configs['bin_strategy']
num_grids=best_configs['num_grids']
max_depth=2
top_K=3

# Mine rule candidates for the molecules the GNN scores above y_thd.
y_rule_candidates = rlm.gen_rule_list_for_one_target(train_data_x.numpy()[idx],fids,y_pred_s>y_thd,y=y_true,c=1,sort_by="fitness",
                                                    min_support=min_support,num_grids=num_grids,max_depth=max_depth,top_K=top_K,
                                                    local_x=None,feature_types=feature_types,bin_strategy=bin_strategy,
                                                    verbose=False)

# Rewrite each candidate's "rules" entry using the descriptor names (each entry is updated in place).
for i, rules in enumerate(y_rule_candidates):
    rules["rules"] = rlm.replace_feature_names(rules["rules"],descriptors)
    y_rule_candidates[i] = rules
y_rule_candidates

In [None]:
### Obtain rules for target pattern: pred_y > y_thd from a DecisionTreeClassifier ###
### We set the hyperparameters obtained by the grid search step above ###
criterion = DT_best_configs['criterion']
class_weight = DT_best_configs['class_weight']
min_support = DT_best_configs['min_support']


input_feature_names =descriptors
# NOTE: max_depth is not assigned in this cell; it carries over from whichever earlier cell set it last.
treemodel = DecisionTreeClassifier(max_depth=max_depth,min_samples_leaf=min_support,criterion=criterion,random_state=seed,class_weight=class_weight)
# The tree is fitted to the GNN's thresholded predictions, not to the true labels.
treemodel.fit(train_data_x.numpy()[idx],y_pred_s>y_thd)
rule_list, rule_value_list, rule_metric_list, new_lines = dtr.obtain_rule_lists_from_DT(treemodel,train_data_x.numpy()[idx],y_true,y_pred_s>y_thd,np.arange(train_data_x.numpy()[idx].shape[-1]),descriptors,c=1)
print(export_text(treemodel))

## display rules extracted by DT classifier
dtr.display_rules_from_DT(rule_list,rule_metric_list,input_feature_names)