In [1]:
import torch
import os
import csv
import math
import numpy as np

In [2]:
import sys
import os
from importlib import reload
here = os.getcwd()
# Parent of the notebook directory; appended so the `models`/`utils`/`explainer`
# imports in the next cells can resolve.
project_root = os.path.join(here, "../")
sys.path.append(project_root)

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
from models.cde.cde_data_common import process_data,get_final_linear_input_channels,get_final_indices,wrap_data,augment_data
import models.cde.cde_train_common as train_common

from utils.test_utils import make_results_filenames

#from explainer.integrad import integrad
from explainer.FPGrowth_tree import *
from explainer.rule_pattern_miner import *
from explainer.explainer_utils import *
import explainer.RuleGrowth_tree as rgtree

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score,confusion_matrix,precision_score,recall_score,accuracy_score,roc_auc_score
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt 
import seaborn as sns

In [5]:
def new_make_model(grad_scale=100):
    """Build a model with the global ``make_model`` factory and scale its readout gradients.

    Backward hooks multiply the gradients of ``model.linear.weight`` and
    ``model.linear.bias`` by ``grad_scale`` during every backward pass.
    NOTE(review): with plain SGD this acts like a larger learning rate for the final
    layer; the optimizer actually used lives in cde_train_common and is not visible here.

    Args:
        grad_scale: multiplier applied to the final linear layer's gradients
            (default 100, the previously hard-coded value).

    Returns:
        ``(model, regularise)`` exactly as returned by ``make_model()``.
    """
    model, regularise = make_model()
    for param in (model.linear.weight, model.linear.bias):
        param.register_hook(lambda grad: grad_scale * grad)
    return model, regularise

In [6]:
import argparse

parser = argparse.ArgumentParser()

# (flag, add_argument options), registered in this order; the order only affects --help layout.
_cli_options = [
    ('--dpath', dict(default='../data', help='Path to data file')),
    ('--rpath', dict(default='./results', help='Path to save results')),
    ('--use_phys', dict(action='store_true', help='Use physiology data or not')),
    ('--seed', dict(default=0, type=int, help='Random seed')),
    ('--model_name', dict(default='ncde', type=str, help='Model name')),
    ('--interpolate', dict(default='cubic_spline', type=str, help='Interpolation function name')),
    ('--device', dict(default='cpu', type=str, help='cpu or cuda')),
    ('--side_input', dict(action='store_true', help='Use side input to final task')),
    ('--concat_z', dict(action='store_true', help='Concat hidden states for the final task')),
    ('--time_intensity', dict(action='store_true', help='Add time intensity')),
    ('--append_times', dict(action='store_true', help='Append time indices as one feature')),
    ('--intensity', dict(action='store_true', help='Add X intensity')),
    ('--time_len', dict(default=12, type=int, help='Length of time indices')),
    ('--hidden_channels', dict(default=2, type=int, help='Dimension of hidden states z')),
    ('--hidden_hidden_channels', dict(default=128, type=int, help='Dimension of hidden units of f')),
    ('--num_hidden_layers', dict(default=4, type=int, help='Number of hidden layers of f')),
    ('--batch_size', dict(default=1024, type=int, help='Batch size')),
    ('--max_epochs', dict(default=500, type=int, help='Maximum epochs')),
    ('--pos_weight', dict(default=20, type=int, help='Weight of positive class')),
    ('--lr', dict(default=0.0001, type=float, help='Raw learning rate')),
    ('--K', dict(default=5, type=int, help='K-fold cross-validation')),
]
for _flag, _options in _cli_options:
    parser.add_argument(_flag, **_options)

# parse_known_args ignores unrecognised entries in sys.argv (presumably the kernel's own
# arguments when running inside Jupyter) instead of exiting with an error.
args, _ = parser.parse_known_args()

#print(args)




model_name = args.model_name
# args.device is not consulted: CUDA is used whenever it is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Notebook-level overrides of the CLI values, applied in this order.
for _key, _value in [
    ('interpolate', "linear"),
    ('side_input', False),
    ('concat_z', True),
    ('time_intensity', True),
    ('intensity', True),
    ('time_len', 72),
    ('max_epochs', 50),
    ('pos_weight', 20),
]:
    setattr(args, _key, _value)

hidden_channels = args.hidden_channels
hidden_hidden_channels = args.hidden_hidden_channels
num_hidden_layers = args.num_hidden_layers

batch_size = args.batch_size
# Learning rate scaled linearly with the batch size, relative to a reference batch of 64.
lr = args.lr * (batch_size / 64)

dpath = args.dpath
use_phys = args.use_phys
concat_z = args.concat_z
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)

name = make_results_filenames(args, 'sepsis')

num_classes = 2
cummean = True
cumsum = False

In [33]:
# Root of the NDE-Models-DigitalHealth checkout; the raw sepsis data is read from
# <pp>data/raw/sepsis/. Set NDE_PROJECT_ROOT instead of editing this absolute path.
# os.path.join(..., "") guarantees the trailing separator that `pp + 'data/...'` relies on.
pp = os.path.join(os.environ.get("NDE_PROJECT_ROOT", "/Users/chenyu/github/NDE-Models-DigitalHealth/"), "")
# Trained checkpoint to explain, relative to the notebook's working directory.
args.model_path = "./results/sepsis/intensity_time_intensity_concatz/zdim2_hdim128_nlayer4_bs1024/posw20/interp_linear/model_1"

In [8]:
base_loc = f"{pp}data/raw/sepsis/"
# Append observed-flags for the Unit1/Unit2 static fields when loading records.
static_intensity = True
# NOTE(review): not read elsewhere in this notebook; args.time_intensity is what the
# augmentation/processing calls receive.
time_intensity = True

In [9]:
def _read_sepsis_record(path, max_hours, add_unit_intensity):
    """Parse one pipe-delimited sepsis ``.psv`` patient file (first line is a header).

    Args:
        path: file to read; every data row must have 41 columns.
        max_hours: reading stops at the first row whose ICULOS exceeds this value.
        add_unit_intensity: append observed-flags for Unit1/Unit2 to the static vector.

    Returns:
        ``None`` for a file without data rows, else ``(hourly, static, label)``:
        ``hourly`` has one row of 34 floats per hour from 1 to the last kept ICULOS
        (all-NaN rows for skipped hours), ``static`` is
        ``[age, gender, unit1, unit2, hospadmtime (+ unit1_obs, unit2_obs)]`` and
        ``label`` is the maximum SepsisLabel over the kept rows.
    """
    hourly = []
    label = 0.0
    prev_iculos = 0
    static_fields = None
    with open(path, newline='') as file:
        reader = csv.reader(file, delimiter='|')
        if next(reader, None) is None:  # empty file: not even a header
            return None
        for line in reader:
            assert len(line) == 41
            # the first 34 columns are the time-varying features
            *time_values, age, gender, unit1, unit2, hospadmtime, iculos, sepsislabel = line
            static_fields = (age, gender, unit1, unit2, hospadmtime)
            iculos = int(iculos)
            if iculos > max_hours:  # keep at most the first max_hours hours
                break
            # pad all-NaN rows for hours missing between consecutive records
            for _ in range(prev_iculos + 1, iculos):
                hourly.append([float('nan')] * len(time_values))
            prev_iculos = iculos
            hourly.append([float(value) for value in time_values])
            label = max(label, float(sepsislabel))
    if static_fields is None:  # header-only file
        return None

    age, gender, unit1, unit2, hospadmtime = static_fields
    unit1 = float(unit1)
    unit2 = float(unit2)
    unit1_obs = not math.isnan(unit1)
    unit2_obs = not math.isnan(unit2)
    if not unit1_obs:
        unit1 = 0.
    if not unit2_obs:
        unit2 = 0.
    hospadmtime = float(hospadmtime)
    if math.isnan(hospadmtime):
        hospadmtime = 0.  # this only happens for one record
    static = [float(age), float(gender), unit1, unit2, hospadmtime]
    if add_unit_intensity:
        static += [unit1_obs, unit2_obs]
    return hourly, static, label


X_times = []
X_static = []
y = []
H = args.time_len
# sorted(): os.listdir order is arbitrary, and the seeded train/test split downstream
# depends on record order, so an unsorted listing is not reproducible across machines.
for filename in sorted(os.listdir(base_loc)):
    if not filename.endswith('.psv'):
        continue
    record = _read_sepsis_record(os.path.join(base_loc, filename), H, static_intensity)
    if record is None:
        continue
    hourly, static, label = record
    if len(hourly) > 2:  # drop stays with fewer than 3 hourly rows
        # right-pad with all-NaN hours up to H rows
        n_features = len(hourly[0])
        hourly.extend([float('nan')] * n_features for _ in range(H - len(hourly)))
        X_times.append(hourly)
        X_static.append(static)
        y.append(label)
    

In [10]:
# Every kept record was padded to H hours, so this should equal args.time_len.
hours_per_record = len(X_times[0])
hours_per_record

72

In [11]:
# Observation times 0, 1, ..., time_len - 1 as a float32 tensor (one per hourly row).
times = torch.arange(args.time_len, dtype=torch.float32)

In [12]:
# Turn the per-record Python lists into dense NumPy arrays.
X_times, X_static, y = (np.array(values) for values in (X_times, X_static, y))

In [13]:
# Min-max scale every feature channel into [1, 2] over all records and hours, then encode
# missing (NaN) entries as 0 so they stay distinguishable from observed values.
# In-place ops avoid allocating full-size temporaries of this large float64 array.
channel_min = np.nanmin(X_times, axis=(0, 1))
channel_max = np.nanmax(X_times, axis=(0, 1))
channel_range = channel_max - channel_min
# A constant channel would divide by zero and turn its observed values into NaN (and then
# into 0, i.e. "missing"); a unit range maps them to 1 instead. All-NaN channels also land here.
channel_range[~(channel_range > 0)] = 1.
X_times -= channel_min
X_times /= channel_range
X_times += 1.
X_times[np.isnan(X_times)] = 0.

In [14]:
# Single precision for PyTorch.
X_times, y = X_times.astype(np.float32), y.astype(np.float32)

In [15]:
# (records, hours, features), (records, static), (records,), and the largest label.
(X_times.shape, X_static.shape, y.shape, y.max())

((40333, 72, 34), (40333, 7), (40333,), 1.0)

In [16]:
# 80/10/10 split; randomness comes from the global NumPy RNG seeded in the config cell.
X_train, X_holdout, y_train, y_holdout = train_test_split(X_times, y, test_size=0.2)
X_test, X_val, y_test, y_val = train_test_split(X_holdout, y_holdout, test_size=0.5)

In [17]:
tuple(part.shape for part in (X_train, X_test, X_val))

((32266, 72, 34), (4033, 72, 34), (4034, 72, 34))

In [18]:
# Positive (sepsis) count per split.
tuple(labels.sum() for labels in (y_train, y_test, y_val))

(1741.0, 214.0, 233.0)

In [19]:
def group_processed_data(X, y, times):
    """Tensorise one data split and build its CDE inputs.

    Args:
        X: feature array of the split (converted with ``torch.tensor``).
        y: label array of the split (converted with ``torch.tensor``).
        times: observation-time tensor shared by all samples.

    Returns:
        ``(coeffs, y_tensor, final_indices)``: ``coeffs`` from ``process_data`` using the
        notebook-level ``args``/``cummean``/``cumsum`` options, and ``final_indices`` the
        first of the two values returned by ``get_final_indices(times, y_tensor)``.
    """
    features = torch.tensor(X)
    labels = torch.tensor(y)
    final_indices, _unused = get_final_indices(times, labels)
    coeffs = process_data(
        times,
        features,
        intensity=args.intensity,
        time_intensity=args.time_intensity,
        cummean=cummean,
        cumsum=cumsum,
        append_times=args.append_times,
        interpolate=args.interpolate,
    )
    return coeffs, labels, final_indices

In [20]:
# Augmented (un-interpolated) training inputs, later used as the explainer's input space.
X_train_raw = augment_data(
    torch.tensor(X_train),
    times,
    append_times=args.append_times,
    cumsum=cumsum,
    cummean=cummean,
    time_intensity=args.time_intensity,
    intensity=args.intensity,
)

In [21]:
# Processed in the order train, test, validation.
train_data, test_data, val_data = (
    group_processed_data(features, labels, times)
    for features, labels in ((X_train, y_train), (X_test, y_test), (X_val, y_val))
)

check X torch.Size([32266, 72, 136])
check X torch.Size([4033, 72, 136])
check X torch.Size([4034, 72, 136])


In [22]:
first_coeffs = train_data[0][0]
input_channels = first_coeffs.shape[-1]
input_channels

136

In [23]:
output_channels = 1
# Stream the whole latent trajectory to the readout only when hidden states are concatenated.
stream = bool(concat_z)

In [24]:
# Input width of the final linear layer when hidden states are concatenated and/or
# side inputs are fed to the final task; None otherwise.
if concat_z or args.side_input:
    if args.side_input:
        # group_processed_data returns only (coeffs, y, final_indices): no side-input
        # tensor exists for these splits, so there is nothing to size side_input_dim from.
        raise NotImplementedError(
            "args.side_input=True is not supported: the processed splits carry no side-input features")
    side_input_dim = 0
    final_linear_input_channels = get_final_linear_input_channels(
        hidden_channels, side_input_dim=side_input_dim, time_len=args.time_len)
else:
    final_linear_input_channels = None

In [25]:
# Zero-argument factory returning (model, regularise); called by new_make_model and when
# reloading the trained checkpoint.
make_model = train_common.make_model(
    model_name,
    input_channels,
    output_channels,
    hidden_channels,
    hidden_hidden_channels,
    num_hidden_layers,
    use_intensity=False,
    final_linear_input_channels=final_linear_input_channels,
    initial=True,
    side_input=args.side_input,
    append_times=args.append_times,
    interpolate=args.interpolate,
)


In [26]:
# Re-execute models/cde/cde_train_common.py so edits to it are picked up; the module repr
# is displayed. NOTE(review): this runs after `make_model` was already built from the
# module in the previous cell.
reload(train_common)

<module 'models.cde.cde_train_common' from '/Users/chenyu/github/INSPIRE/code/notebooks/../models/cde/cde_train_common.py'>

In [27]:
# NOTE: `times` is rebound to wrap_data's output, so re-running this cell feeds the
# previously returned times back into wrap_data.
times, train_dataloader, val_dataloader, test_dataloader = wrap_data(
    times, train_data, val_data, test_data, device,
    batch_size=batch_size, num_workers=0,
)

pos_weight = torch.tensor(args.pos_weight)
log, log_num = train_common.main(
    name, times, train_dataloader, val_dataloader, test_dataloader, device,
    new_make_model, num_classes, args.max_epochs, lr,
    kwargs={'stream': stream},
    pos_weight=pos_weight,
    step_mode=True,
    rpath=args.rpath,
)

interpolate linear


  0%|                                                                                                                                         | 0/50 [00:00<?, ?it/s]

Starting training for model:

NeuralCDE(
  input_channels=136, hidden_channels=2, output_channels=1, initial=True
  (func): FinalTanh(
    input_channels: 136, hidden_channels: 2, hidden_hidden_channels: 128, num_hidden_layers: 4
    (linear_in): Linear(in_features=2, out_features=128, bias=True)
    (linears): ModuleList(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=128, bias=True)
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
    (linear_out): Linear(in_features=128, out_features=272, bias=True)
  )
  (initial_network): Linear(in_features=136, out_features=2, bias=True)
  (linear): Linear(in_features=144, out_features=1, bias=True)
)




  2%|██▌                                                                                                                              | 1/50 [00:34<28:17, 34.64s/it]

Epoch: 0  Train loss: 2.38  Train auroc: 0.752  Val loss: 2.42  Val auroc: 0.745


 22%|████████████████████████████▏                                                                                                   | 11/50 [05:12<18:50, 28.98s/it]

save model
Epoch: 10  Train loss: 2.49  Train auroc: 0.795  Val loss: 2.57  Val auroc: 0.797


 42%|█████████████████████████████████████████████████████▊                                                                          | 21/50 [09:39<13:42, 28.36s/it]

Epoch: 20  Train loss: 2.67  Train auroc: 0.808  Val loss: 2.88  Val auroc: 0.794


 62%|███████████████████████████████████████████████████████████████████████████████▎                                                | 31/50 [14:17<09:18, 29.40s/it]

save model
Epoch: 30  Train loss: 2.09  Train auroc: 0.823  Val loss: 2.19  Val auroc: 0.818


 82%|████████████████████████████████████████████████████████████████████████████████████████████████████████▉                       | 41/50 [18:54<04:26, 29.65s/it]

Epoch: 40  Train loss: 2.17  Train auroc: 0.815  Val loss: 2.18  Val auroc: 0.83


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [23:09<00:00, 27.79s/it]

save model
Epoch: 49  Train loss: 2.0  Train auroc: 0.836  Val loss: 2.11  Val auroc: 0.82
best epoch 49





#####################
test_metrics
{'accuracy': 0.7174479365348816, 'confusion': array([[2081.,  833.],
       [  35.,  123.]]), 'dataset_size': 3072, 'loss': 2.0691311359405518, 'auroc': 0.8119836146755516, 'average_precision': 0.22283657680613425, 'sensitivity': 0.7784810126582279, 'specificity': 0.7141386410432395}
#####################
val_metrics
{'accuracy': 0.7294921875, 'confusion': array([[2100.,  794.],
       [  37.,  141.]]), 'dataset_size': 3072, 'loss': 2.128657817840576, 'auroc': 0.8140864865704324, 'average_precision': 0.29333702529790034}
#####################
train_metrics
{'accuracy': 0.720703125, 'confusion': array([[21483.,  8542.],
       [  324.,  1395.]]), 'dataset_size': 31744, 'loss': 2.011030912399292, 'auroc': 0.8340240801852634, 'average_precision': 0.2598731232206565}


In [28]:
# The 34 hourly variables (presumably in the column order of the raw .psv files).
raw_feature_names = (
    'HR O2Sat Temp SBP MAP DBP Resp EtCO2 BaseExcess HCO3 FiO2 pH PaCO2 SaO2 '
    'AST BUN Alkalinephos Calcium Chloride Creatinine Bilirubin_direct Glucose Lactate Magnesium '
    'Phosphate Potassium Bilirubin_total TroponinI Hct Hgb PTT WBC Fibrinogen Platelets'
).split()

In [29]:
# LaTeX labels "$z_i(t_{h})$", time-major (h outer, latent index i inner).
latent_feature_names = [
    rf'$z_{i}(t_{{{h}}})$'
    for h in range(args.time_len)
    for i in range(hidden_channels)
]


In [30]:
# Names for the derived channels, grouped by suffix (all *_ctime, then *_cmax, then *_cmean);
# presumably matching the channel order produced by the augmentation (136 = input_channels).
intensity_feature_names = [f'{rf}_{suffix}' for suffix in ('ctime', 'cmax', 'cmean') for rf in raw_feature_names]
input_feature_names = raw_feature_names + intensity_feature_names
len(input_feature_names)

136

In [31]:
# Channels whose name contains "ctime" are tagged "int"; all others "float".
feature_types = ["float"] * len(input_feature_names)
for idx, fn in enumerate(input_feature_names):
    if "ctime" in fn:
        feature_types[idx] = "int"

In [32]:
# Whether time indices were appended as an extra input feature.
getattr(args, 'append_times')

False

In [34]:
## load model
model, regularise_parameters = make_model()
model.load_state_dict(torch.load(args.model_path))
model.to('cpu')
model.eval()

interpolate linear


NeuralCDE(
  input_channels=136, hidden_channels=2, output_channels=1, initial=True
  (func): FinalTanh(
    input_channels: 136, hidden_channels: 2, hidden_hidden_channels: 128, num_hidden_layers: 4
    (linear_in): Linear(in_features=2, out_features=128, bias=True)
    (linears): ModuleList(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=128, bias=True)
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
    (linear_out): Linear(in_features=128, out_features=272, bias=True)
  )
  (initial_network): Linear(in_features=136, out_features=2, bias=True)
  (linear): Linear(in_features=144, out_features=1, bias=True)
)

In [35]:
# Keep the model's append_times attribute in sync with the notebook setting
# (presumably read inside latent_representation — TODO confirm).
setattr(model, 'append_times', args.append_times)

In [36]:
def gen_intgrad_baselines(reps, x, y):
    """Build integrated-gradient baselines for the inputs ``x``.

    Args:
        reps: per-sample representations aligned with ``x`` along dim 0 (any trailing shape).
        x: input samples as a torch tensor, indexed by sample along dim 0.
        y: integer-valued class labels (NumPy array or tensor) aligned with ``x``.

    Returns:
        Tensor stacked along a new dim 0 containing, in order: an all-zero input, the
        sample of ``x`` whose row of ``reps`` has the smallest squared L2 norm, and the
        mean input of each class ``0 .. int(max(y))``.
    """
    # Squared L2 norm of each sample's representation over all non-batch dims
    # (uses the `reps` argument, not a notebook global).
    reps_norm = torch.square(reps.reshape(reps.shape[0], -1)).sum(dim=1)
    bid = torch.argmin(reps_norm)

    centres = []
    for c in range(int(y.max()) + 1):
        mask = torch.as_tensor(y == c, device=x.device)
        # NOTE: a class with no samples yields an all-NaN centre.
        centres.append(x[mask].mean(dim=0).unsqueeze(0))

    baselines = torch.vstack([torch.zeros_like(x[0]).unsqueeze(0), x[bid].unsqueeze(0)] + centres)
    return baselines

In [37]:
def gen_balanced_subset(x, y, size_per_class=500, replace=True):
    """Draw the same number of samples from every class, concatenated class by class.

    Args:
        x: samples as a torch tensor, indexed along dim 0.
        y: integer-valued class labels (NumPy array) aligned with ``x``.
        size_per_class: number of draws per class.
        replace: sample with replacement (default, the original behaviour; duplicates are
            possible). Pass False to get distinct samples, which requires every class to
            have at least ``size_per_class`` members.

    Returns:
        Tensor with ``C * size_per_class`` rows (class 0 draws first, then class 1, ...),
        where ``C = int(max(y)) + 1``. Uses the global NumPy RNG.
    """
    C = int(y.max()) + 1
    subset = []
    for c in range(C):
        y_c = (y == c)
        id_c = np.random.choice(np.sum(y_c), size=size_per_class, replace=replace)
        subset.append(x[y_c][id_c])
    return torch.vstack(subset)

In [38]:

import explainer.explainer_utils as eutils
import models.cde as cde
# import controldiffeq

eutils = reload(eutils)
# Only the package __init__ is re-executed here; its repr is displayed.
reload(cde)

<module 'models.cde' from '/Users/chenyu/github/INSPIRE/code/notebooks/../models/cde/__init__.py'>

In [39]:
## find the index of the baseline sample which has minimum norm of latent states
# Latent trajectories of every training sample; the baseline set includes the sample
# whose latent trajectory has the smallest norm.
all_reps = model.latent_representation(X_train_raw, times=times).detach()
baselines = gen_intgrad_baselines(reps=all_reps, x=X_train_raw, y=y_train)
subset = gen_balanced_subset(x=X_train_raw, y=y_train, size_per_class=200)


In [40]:
int_g, z_shift = eutils.calc_baselines_intg(
    model=model,
    baselines=baselines,
    test_examples=subset,
    times=times,
)

In [41]:
(z_shift.shape, int_g.shape)

(torch.Size([1600, 72, 2]), torch.Size([1600, 72, 2, 136]))

In [75]:
# Per-input-channel attribution for each latent state, divided by z_shift (assumed to be
# the latent's change between input and baseline — TODO confirm). Kept as a torch tensor
# so the torch ops in the following cells work on a fresh kernel; 0/0 entries become NaN
# and are zeroed here.
z_int_g = int_g / z_shift.unsqueeze(3)
z_int_g[torch.isnan(z_int_g)] = 0.

In [43]:
# Zero the NaNs from 0/0 divisions; works whether z_int_g is a torch tensor or an ndarray.
if isinstance(z_int_g, torch.Tensor):
    z_int_g[torch.isnan(z_int_g)] = 0.
else:
    z_int_g[np.isnan(z_int_g)] = 0.

In [74]:
# Number of attribution entries with magnitude >= 0.01 (the itemset threshold used below);
# np.asarray accepts either a CPU tensor or an ndarray.
(np.abs(np.asarray(z_int_g)) >= 0.01).sum()

15090848

In [107]:
# One itemset collection per flattened latent state (time_len x hidden_channels of them).
K = int(args.hidden_channels * args.time_len)
itemsets = transform_intgrad_to_itemsets(z_int_g, K=K, thd=0.01)

In [108]:
for latent_idx in range(K):
    n_itemsets = len(itemsets[latent_idx])
    print(latent_idx, n_itemsets)

0 0
1 0
2 1502
3 1502
4 1562
5 1562
6 1578
7 1578
8 1586
9 1586
10 1594
11 1594
12 1598
13 1598
14 1598
15 1598
16 1600
17 1600
18 1600
19 1600
20 1600
21 1600
22 1600
23 1600
24 1600
25 1600
26 1600
27 1600
28 1600
29 1600
30 1600
31 1600
32 1600
33 1600
34 1600
35 1600
36 1600
37 1600
38 1600
39 1600
40 1600
41 1600
42 1600
43 1600
44 1600
45 1600
46 1600
47 1600
48 1600
49 1600
50 1600
51 1600
52 1600
53 1600
54 1600
55 1600
56 1600
57 1600
58 1600
59 1600
60 1600
61 1600
62 1600
63 1600
64 1600
65 1600
66 1600
67 1600
68 1600
69 1600
70 1600
71 1600
72 1600
73 1600
74 1600
75 1600
76 1600
77 1600
78 1600
79 1600
80 1600
81 1600
82 1600
83 1600
84 1600
85 1600
86 1600
87 1600
88 1600
89 1600
90 1600
91 1600
92 1600
93 1600
94 1600
95 1600
96 1600
97 1600
98 1600
99 1600
100 1600
101 1600
102 1600
103 1600
104 1600
105 1600
106 1600
107 1600
108 1600
109 1600
110 1600
111 1600
112 1600
113 1600
114 1600
115 1600
116 1600
117 1600
118 1600
119 1600
120 1600
121 1600
122 1600
123 1600


In [91]:
# Detached parameters of the final linear layer, in parameters() order (weight, then bias).
linear_prams = [param.detach() for param in model.linear.parameters()]
for param in linear_prams:
    print(param.shape, param)

torch.Size([1, 144]) tensor([[ 0.0451, -0.0765, -0.0127,  0.0030, -0.0041, -0.0129,  0.0101,  0.0093,
          0.0083,  0.0096,  0.0144, -0.0235,  0.0119, -0.0161, -0.0237, -0.0025,
          0.0186,  0.0068,  0.0241, -0.0320, -0.0037, -0.0104, -0.0066,  0.0180,
          0.0291,  0.0313,  0.0255, -0.0275, -0.0362, -0.0259,  0.0344,  0.0517,
          0.0086, -0.0363, -0.0290,  0.0534,  0.0058, -0.0321,  0.0074, -0.0324,
          0.0015,  0.0802,  0.0144, -0.0390, -0.0397,  0.0020,  0.0054,  0.0203,
          0.0313,  0.0296,  0.0167, -0.0144, -0.0182, -0.0249, -0.0156, -0.0543,
         -0.0206,  0.0164, -0.0051,  0.0414, -0.0071, -0.0316,  0.0276,  0.0324,
         -0.0165, -0.0491,  0.0145,  0.0242, -0.0462, -0.0608,  0.0108,  0.0063,
          0.0040,  0.0203,  0.0142,  0.0691,  0.0180, -0.0442,  0.0235,  0.0594,
          0.0438, -0.0031, -0.0003, -0.0063, -0.0513,  0.0218, -0.0069, -0.0389,
          0.0192, -0.0677, -0.0762,  0.0785,  0.0103, -0.0626, -0.0164, -0.0460,
       

In [92]:
# Output of model.linear (no sigmoid applied) for every training sample, recomputed
# from the cached latent trajectories.
flat_reps = all_reps.reshape(all_reps.shape[0], -1)
weight, bias = linear_prams[0], linear_prams[1]
pred_y_train = (flat_reps @ weight.T + bias).numpy().reshape(-1)

In [93]:
# Rank flattened latent states by mean |z * w| over the training set; keep the top 20.
contrib = torch.abs(all_reps.reshape(all_reps.shape[0], -1) * linear_prams[0]).mean(dim=0)
top_latent = torch.argsort(contrib, descending=True)[:20].numpy()
top_latent

array([135,  99,  91,  89,  75,  93,  79, 143,  97,  69, 117,  41,  95,
       115,  77, 113,  65,  55, 129,  87])

In [94]:
import importlib
import explainer.rule_pattern_miner as rlm
# Pick up edits to explainer/rule_pattern_miner.py (module repr is displayed).
reload(rlm)

<module 'explainer.rule_pattern_miner' from '/Users/chenyu/github/INSPIRE/code/notebooks/../explainer/rule_pattern_miner.py'>

In [112]:
# Latent dimensions per time step; flattened latent index l = time_step * latent_num + latent_id,
# matching the reshape(N, -1) of the (N, time_len, hidden_channels) trajectories.
latent_num = hidden_channels
z_rules = {}
# Whether each flattened latent state has a positive weight in the final linear layer.
zw = (linear_prams[0][0] > 0).numpy()

for l in top_latent:
    l = int(l)
    print('### latent state {} ###'.format(l))
    time_step, latent_id = divmod(l, latent_num)

    z = all_reps[:, time_step, latent_id].numpy()
    x = X_train_raw[:, time_step, :].numpy()
    zw_pos = zw[l]

    itemsets_z = itemsets[l]

    z_rules[l] = rlm.find_pattern_by_latent_state(x, z, itemsets_z, zw_pos, y=y_train, c=1, num_grids=100, omega=0.2,
                                                  min_support_pos=200, min_support_neg=2000, max_depth=10,
                                                  feature_types=feature_types, top_K=3, verbose=False)
    

### latent state 135 ###
min -922.28217 max -0.021374801
p(y=1) 0.05395772639930577
thd_h -55.915968178517346 thd_l -391.28352844257915 pos True
[35.0] 1596.0
[35.0, 36.0] 1589.0
[35.0, 36.0, 39.0] 1584.0
[35.0, 36.0, 39.0, 41.0] 1576.0
[35.0, 36.0, 39.0, 41.0, 38.0] 1544.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0] 1491.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0, 40.0] 1349.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0, 40.0, 56.0] 1175.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0, 40.0, 56.0, 60.0] 910.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0, 40.0, 56.0, 60.0, 63.0] 760.0
feature set [34 35 38 40 37 36 39 55 59 62]
build_rule_tree
init rule tree
search rule for feature 34
check potential rule 34 26.893987966700696 0.0 4.737373737373738 331
add rule 34 (26.893987966700696, 0.0, 4.737373737373738, 331)
search rule for feature 35
build_rule_tree
init rule tree
search rule for feature 34
check potential rule 34 9.165178107431379 53.464646464646464 67.0 2082
add rule 34 (9.165178107431379, 53.464646464646464, 67.0, 

thd_h -61.84235200319779 thd_l -343.4712481457514 pos True
[36.0] 1597.0
[36.0, 39.0] 1594.0
[36.0, 39.0, 41.0] 1583.0
[36.0, 39.0, 41.0, 35.0] 1569.0
[36.0, 39.0, 41.0, 35.0, 38.0] 1541.0
[36.0, 39.0, 41.0, 35.0, 38.0, 37.0] 1478.0
[36.0, 39.0, 41.0, 35.0, 38.0, 37.0, 40.0] 1335.0
[36.0, 39.0, 41.0, 35.0, 38.0, 37.0, 40.0, 56.0] 1172.0
[36.0, 39.0, 41.0, 35.0, 38.0, 37.0, 40.0, 56.0, 60.0] 915.0
[36.0, 39.0, 41.0, 35.0, 38.0, 37.0, 40.0, 56.0, 60.0, 63.0] 747.0
feature set [35 38 40 34 37 36 39 55 59 62]
build_rule_tree
init rule tree
search rule for feature 35
check potential rule 35 20.300167334246865 3.6363636363636362 4.545454545454545 212.0
add rule 35 (20.300167334246865, 3.6363636363636362, 4.545454545454545, 212.0)
search rule for feature 38
build_rule_tree
init rule tree
search rule for feature 35
check potential rule 35 4.115721078521389 43.63636363636363 45.0 3587
add rule 35 (4.115721078521389, 43.63636363636363, 45.0, 3587)
search rule for feature 38
check potential rule 

check potential rule 38 1.2026492919067364 46.54545454545455 48.0 2450
add rule 38 (1.2026492919067364, 46.54545454545455, 48.0, 2450)
search rule for feature 40
### latent state 69 ###
min -525.8036 max -0.021374801
p(y=1) 0.05395772639930577
no enough support
thd_h -26.576032077235084 thd_l 0.0 pos False
[35.0] 1593.0
[35.0, 36.0] 1583.0
[35.0, 36.0, 39.0] 1574.0
[35.0, 36.0, 39.0, 41.0] 1559.0
[35.0, 36.0, 39.0, 41.0, 38.0] 1528.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0] 1465.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0, 40.0] 1325.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0, 40.0, 56.0] 1170.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0, 40.0, 56.0, 60.0] 932.0
[35.0, 36.0, 39.0, 41.0, 38.0, 37.0, 40.0, 56.0, 60.0, 63.0] 774.0
feature set [34 35 38 40 37 36 39 55 59 62]
build_rule_tree
init rule tree
search rule for feature 34
build_rule_tree
init rule tree
search rule for feature 34
### latent state 117 ###
min -825.69257 max -0.021374801
p(y=1) 0.05395772639930577
thd_h -58.402166092260245 thd_l -392.006

check potential rule 37 1.0213933270676692 19.7979797979798 20.0 2240
add rule 37 (1.0213933270676692, 19.7979797979798, 20.0, 2240)
search rule for feature 39
check potential rule 39 1.073210855914489 19.7979797979798 20.0 2078
add rule 39 (1.073210855914489, 19.7979797979798, 20.0, 2078)
search rule for feature 55
check potential rule 36 1.0623030521292878 3.8383838383838382 20.0 2298
add rule 36 (1.0623030521292878, 3.8383838383838382, 20.0, 2298)
search rule for feature 37
check potential rule 37 1.0213933270676692 19.7979797979798 20.0 2240
add rule 37 (1.0213933270676692, 19.7979797979798, 20.0, 2240)
search rule for feature 39
check potential rule 39 1.073210855914489 19.7979797979798 20.0 2078
add rule 39 (1.073210855914489, 19.7979797979798, 20.0, 2078)
search rule for feature 55
check potential rule 36 1.0623030521292878 3.8383838383838382 20.0 2298
add rule 36 (1.0623030521292878, 3.8383838383838382, 20.0, 2298)
search rule for feature 37
check potential rule 37 1.0213933270

no enough support
thd_h -29.381014657743037 thd_l 0.0 pos False
[35.0] 1594.0
[35.0, 36.0] 1585.0
[35.0, 36.0, 38.0] 1558.0
[35.0, 36.0, 38.0, 41.0] 1536.0
[35.0, 36.0, 38.0, 41.0, 39.0] 1517.0
[35.0, 36.0, 38.0, 41.0, 39.0, 37.0] 1452.0
[35.0, 36.0, 38.0, 41.0, 39.0, 37.0, 40.0] 1310.0
[35.0, 36.0, 38.0, 41.0, 39.0, 37.0, 40.0, 56.0] 1162.0
[35.0, 36.0, 38.0, 41.0, 39.0, 37.0, 40.0, 56.0, 60.0] 967.0
[35.0, 36.0, 38.0, 41.0, 39.0, 37.0, 40.0, 56.0, 60.0, 63.0] 817.0
feature set [34 35 37 40 38 36 39 55 59 62]
build_rule_tree
init rule tree
search rule for feature 34
build_rule_tree
init rule tree
search rule for feature 34
check potential rule 34 1.0000309933364326 1.909090909090909 27.0 32248.0
add rule 34 (1.0000309933364326, 1.909090909090909, 27.0, 32248.0)
search rule for feature 35
### latent state 129 ###
min -893.089 max -0.021374801
p(y=1) 0.05395772639930577
no enough support
thd_h -27.084029784708378 thd_l 0.0 pos False
[35.0] 1592.0
[35.0, 36.0] 1583.0
[35.0, 36.0, 39.0] 1

In [100]:
# itemsets_z is left over from the last iteration of the rule-mining loop.
n_last_itemsets = len(itemsets_z)
n_last_itemsets

110

In [101]:
rules_for_latent_99 = z_rules[99]
rules_for_latent_99

{'pos': True,
 'thd_h': -51.471871939798234,
 'thd_l': -441.02563598938286,
 'p(z>=thd_h)': 0.031581231017169774,
 'p(z<=thd_l)': 0.007934048224136863,
 'rule_dict_higher_z': [],
 'rule_dict_lower_z': []}

In [113]:
sorted_rules_pos = rlm.sort_rules(
    z_rules,
    input_feature_names,
    sort_by="cond_prob_y",
    pos=True,
)
sorted_rules_pos

135 True
99 True
91 True
89 False
75 True
93 False
79 True
143 False
97 True
69 False
117 True
41 True
95 False
115 False
77 False
113 True
65 False
55 False
129 False
87 False


[{'rules': [(35, 'O2Sat_ctime', '==', 4.0)],
  'zid': 75,
  'p(z>=thd_h)': 0.03967024112068431,
  'thd_h': -57.940185683788854,
  'pos': True,
  'cond_prob_y': 0.06046511627906977,
  'cond_prob_target': 0.8744186046511628,
  'support': 215,
  'ratio_y': 0.007466973004020678},
 {'rules': [(35, 'O2Sat_ctime', '==', 4.0)],
  'zid': 113,
  'p(z>=thd_h)': 0.03815161470278312,
  'thd_h': -57.2684707403032,
  'pos': True,
  'cond_prob_y': 0.05714285714285714,
  'cond_prob_target': 0.8952380952380953,
  'support': 210,
  'ratio_y': 0.0068925904652498565},
 {'rules': [(35, 'O2Sat_ctime', '==', 4.0)],
  'zid': 91,
  'p(z>=thd_h)': 0.04438108225376557,
  'thd_h': -61.84235200319779,
  'pos': True,
  'cond_prob_y': 0.05660377358490566,
  'cond_prob_target': 0.9009433962264151,
  'support': 212,
  'ratio_y': 0.0068925904652498565},
 {'rules': [(34, 'HR_ctime', '>=', 2.0)],
  'zid': 55,
  'p(z<=thd_l)': 0.9999690076241244,
  'thd_l': 0.0,
  'pos': False,
  'cond_prob_y': 0.05398784420739271,
  'cond

In [None]:
# NOTE(review): 138 is not among the top_latent ids mined above, so this may print nothing.
for rule in filter(lambda r: r["zid"] == 138, sorted_rules_pos):
    print(rule)

In [114]:
# Same ranking as above but for pos=False, using sort_rules' default sort key.
sorted_rules_neg = rlm.sort_rules(
    z_rules,
    input_feature_names,
    pos=False,
)
sorted_rules_neg

135 True
99 True
91 True
89 False
75 True
93 False
79 True
143 False
97 True
69 False
117 True
41 True
95 False
115 False
77 False
113 True
65 False
55 False
129 False
87 False


[{'rules': [(38, 'MAP_ctime', '==', 20.0),
   (40, 'Resp_ctime', '<=', 19.0),
   (35, 'O2Sat_ctime', '==', 20.0),
   (34, 'HR_ctime', '==', 20.0),
   (36, 'Temp_ctime', '>=', 4.0),
   (37, 'SBP_ctime', '==', 20.0),
   (39, 'DBP_ctime', '==', 20.0)],
  'zid': 41,
  'p(z<=thd_l)': 0.06762536416041653,
  'thd_l': -177.3311723213116,
  'pos': True,
  'cond_prob_y': 0.965832531280077,
  'cond_prob_target': 0.21751684311838307,
  'support': 2078,
  'ratio_y': 0.06574938574938576},
 {'rules': [(35, 'O2Sat_ctime', '==', 37.0),
   (38, 'MAP_ctime', '==', 37.0),
   (34, 'HR_ctime', '==', 37.0)],
  'zid': 75,
  'p(z<=thd_l)': 0.06799727267092295,
  'thd_l': -301.1991913907719,
  'pos': True,
  'cond_prob_y': 0.9620473537604457,
  'cond_prob_target': 0.2503481894150418,
  'support': 2872,
  'ratio_y': 0.09051597051597052},
 {'rules': [(34, 'HR_ctime', '==', 39.0),
   (38, 'MAP_ctime', '==', 39.0),
   (35, 'O2Sat_ctime', '==', 39.0)],
  'zid': 79,
  'p(z<=thd_l)': 0.07506353437054485,
  'thd_l': -3

In [None]:
# NOTE(review): `sample_local_rules` is not defined anywhere in this notebook, so evaluating
# the bare name raises NameError. Display it only if some earlier step created it.
globals().get('sample_local_rules')

In [115]:
def _linear_readout(inputs):
    """Return (latent trajectories, detached model.linear output) for a batch of inputs."""
    reps = model.latent_representation(inputs, times=times)
    return reps, model.linear(reps.reshape(reps.shape[0], -1)).detach()


baseline_reps, baseline_output = _linear_readout(baselines)
subset_reps, subset_output = _linear_readout(subset)

In [116]:
# Output shift of every subset sample relative to each baseline, stacked baseline by baseline.
yshift = torch.vstack([subset_output - base_row for base_row in baseline_output])

In [118]:
# Final-layer weights reshaped to (-1, hidden_channels): one row per time step.
weights = linear_prams[0].reshape(-1, args.hidden_channels)
y_int_g = output_intg_score(int_g, weights, yshift)
y_int_g.masked_fill_(torch.isnan(y_int_g), 0.)

itemsets_y = transform_intgrad_to_itemsets(y_int_g, thd=0.01, K=1)

In [143]:
# Pick up edits to explainer/rule_pattern_miner.py again (module repr is displayed).
reload(rlm)

<module 'explainer.rule_pattern_miner' from '/Users/chenyu/github/INSPIRE/code/notebooks/../explainer/rule_pattern_miner.py'>

In [120]:
# One flat feature vector per training sample: index = time_step * n_channels + channel.
n_train = X_train_raw.shape[0]
x = X_train_raw.reshape(n_train, -1).numpy()
zw_pos = 1

In [121]:
freq_items = gen_freq_feature_set(itemsets_y[0], min_support=500, max_len=20)
# Item ids are shifted down by one (presumably 1-based) to index the columns of x.
fids = np.array(freq_items).astype(int) - 1
print('feature set', fids)

[9147.0] 1599.0
[9147.0, 5339.0] 1597.0
[9147.0, 5339.0, 6699.0] 1595.0
[9147.0, 5339.0, 6699.0, 6703.0] 1593.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0] 1591.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0] 1590.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0] 1590.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0] 1590.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0] 1589.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0] 1588.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0, 6427.0] 1585.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0, 6427.0, 6563.0] 1585.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0, 6427.0, 6563.0, 6705.0] 1585.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0, 6427.0, 6563.0, 6705.0, 9151.0] 1585.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0,

In [167]:
# NOTE(review): pred_y_train is the raw model.linear output (no sigmoid), so 0.6 is a
# threshold on that score, not on a probability.
predicted_pos = pred_y_train >= 0.6
# One type entry per flattened (time step, channel) column of x.
flat_feature_types = feature_types * args.time_len
y_rule_candidates = rlm.gen_rule_list_for_one_target_greedy(
    x, fids, predicted_pos, y=y_train, c=1, sort_by="cond_prob_y",
    min_support=500, num_grids=100, max_depth=20, top_K=10,
    local_x=None, feature_types=flat_feature_types,
    verbose=False,
)

build_rule_tree
init rule tree
check fids [9146, 5338, 6698, 6702, 9152, 9418, 9554, 9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 9146
check potential rule 9146 4.596094086996457 63.61616161616162 67.0 568
add rule 9146 (4.596094086996457, 63.61616161616162, 67.0, 568)
check fids [5338, 6698, 6702, 9152, 9418, 9554, 9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 5338
check potential rule 5338 1.002951238701455 36.63636363636363 39.0 529
add rule 5338 (1.002951238701455, 36.63636363636363, 39.0, 529)
check fids [6698, 6702, 9152, 9418, 9554, 9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6698
check potential rule 6698 1.0020829403257248 46.525252525252526 49.0 513
add rule 6698 (1.0020829403257248, 46.525252525252526, 49.0, 513)
check fids [6702, 9152, 9418, 9554, 9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070,

no valid rule,skip 9560
check fids [4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 4664
no valid rule,skip 4664
check fids [5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 5072
no valid rule,skip 5072
check fids [6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6426
check potential rule 6426 1.0 44.62626262626263 47.0 501
no enough support,skip 501 1.0
no valid rule,skip 6426
check fids [6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6562
check potential rule 6562 1.0 45.57575757575758 48.0 501
no enough support,skip 501 1.0
no valid rule,skip 6562
check fids [6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6704
no valid rule,skip 6704
check fids [9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 9150
no valid rule,skip 9150
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feat

no valid rule,skip 5072
check fids [6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6426
check potential rule 6426 1.0 35.60606060606061 47.0 501
no enough support,skip 501 1.0
check potential rule 6426 1.0 35.60606060606061 47.0 501
no enough support,skip 501 1.0
no valid rule,skip 6426
check fids [6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6562
check potential rule 6562 1.0 36.84848484848485 48.0 501
no enough support,skip 501 1.0
check potential rule 6562 1.0 36.84848484848485 48.0 501
no enough support,skip 501 1.0
no valid rule,skip 6562
check fids [6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6704
check potential rule 6704 1.0 29.6969696969697 49.0 501
no enough support,skip 501 1.0
check potential rule 6704 1.0 29.6969696969697 49.0 501
no enough support,skip 501 1.0
check potential rule 6704 1.0 29.6969696969697 49.0 501
no enough support,skip 501 1.0
no valid rule,skip 6704
check 

check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9152 1.0071777764420669 44.66666666666667 67.0 508
add rule 9152 (1.0071777764420669, 44.66666666666667, 67.0, 508)
check fids [9418, 9554, 9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 9418
check potential rule 9418 1.0 52.969696969696976 69.0 508
no enough support,skip 508 1.0
check 

check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9150 1.002 46.696969696969695 67.0 500
add rule 9150 (1.002, 46.696969696969695, 67.0, 500)
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 3706
no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for fe

no valid rule,skip 4664
check fids [5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 5072
no valid rule,skip 5072
check fids [6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6426
check potential rule 6426 1.0 35.60606060606061 47.0 501
no enough support,skip 501 1.0
check potential rule 6426 1.0 35.60606060606061 47.0 501
no enough support,skip 501 1.0
no valid rule,skip 6426
check fids [6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6562
check potential rule 6562 1.0 36.84848484848485 48.0 501
no enough support,skip 501 1.0
check potential rule 6562 1.0 36.84848484848485 48.0 501
no enough support,skip 501 1.0
no valid rule,skip 6562
check fids [6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6704
check potential rule 6704 1.0 29.6969696969697 49.0 501
no enough support,skip 501 1.0
check potential rule 6704 1.0 29.6969696969697 49.0 501
no enough support

no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 19.80808080808081 37.0 502
no enough support,skip 502 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9152 1.0126912882818393 47.37373737373738 67.0 508
add rule 9152 (1.0126912882818393, 47.37373737373738, 67.0, 508)
check fids [9418, 9554, 9560, 4664, 507

no valid rule,skip 5072
check fids [6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6426
check potential rule 6426 1.0 32.75757575757576 47.0 508
no enough support,skip 508 1.0
check potential rule 6426 1.0 32.75757575757576 47.0 508
no enough support,skip 508 1.0
check potential rule 6426 1.0 32.75757575757576 47.0 508
no enough support,skip 508 1.0
no valid rule,skip 6426
check fids [6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6562
check potential rule 6562 1.0 33.93939393939394 48.0 508
no enough support,skip 508 1.0
check potential rule 6562 1.0 33.93939393939394 48.0 508
no enough support,skip 508 1.0
check potential rule 6562 1.0 33.93939393939394 48.0 508
no enough support,skip 508 1.0
no valid rule,skip 6562
check fids [6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6704
check potential rule 6704 1.0015475016934405 30.68686868686869 47.515151515151516 503
add rule 6704 (1.001547501693

no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9560 1.002736908454559 46.666666666666664 70.0 501
add rule 9560 (1.002736908454559, 46.666666666666664, 70.0, 501)
check fids [4664, 5072, 6426, 6562, 6704, 

check potential rule 9150 1.002 46.696969696969695 67.0 500
add rule 9150 (1.002, 46.696969696969695, 67.0, 500)
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 3706
no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach

check potential rule 9554 1.0 52.323232323232325 70.0 508
no enough support,skip 508 1.0
check potential rule 9554 1.0 52.323232323232325 70.0 508
no enough support,skip 508 1.0
check potential rule 9554 1.0 52.323232323232325 70.0 508
no enough support,skip 508 1.0
check potential rule 9554 1.0 52.323232323232325 70.0 508
no enough support,skip 508 1.0
no valid rule,skip 9554
check fids [9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 9560
check potential rule 9560 1.002736908454559 46.666666666666664 70.0 501
add rule 9560 (1.002736908454559, 46.666666666666664, 70.0, 501)
check fids [4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 4664
no valid rule,skip 4664
check fids [5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 5072
no valid rule,skip 5072
check fids [6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 

check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9150 1.002 46.696969696969695 67.0 500
add rule 9150 (1.002, 46.696969696969695, 67.0, 500)
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 3706
no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for fe

check potential rule 9560 1.0 47.37373737373738 70.0 508
no enough support,skip 508 1.0
check potential rule 9560 1.0 47.37373737373738 70.0 508
no enough support,skip 508 1.0
check potential rule 9560 1.0 47.37373737373738 70.0 508
no enough support,skip 508 1.0
no valid rule,skip 9560
check fids [4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 4664
no valid rule,skip 4664
check fids [5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 5072
no valid rule,skip 5072
check fids [6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6426
check potential rule 6426 1.0 32.75757575757576 47.0 508
no enough support,skip 508 1.0
check potential rule 6426 1.0 32.75757575757576 47.0 508
no enough support,skip 508 1.0
check potential rule 6426 1.0 32.75757575757576 47.0 508
no enough support,skip 508 1.0
no valid rule,skip 6426
check fids [6562, 6704, 9150, 3706, 4386, 5070, 5882, 6

no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 19.80808080808081 37.0 502
no enough support,skip 502 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 6704 1.0015475016934405 30.68686868686869 47.515151515151516 503
add rule 6704 (1.0015475016934405, 30.68686868686869, 47.515151515151516, 503)
check fids 

no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
check potential rule 5070 1.0 25.78787878787879 37.0 500
no enough support,skip 500 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
check potential rule 5882 1.0 31.70707070707071 43.0 500
no enough support,skip 500 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9150 1.002 46.696969696969695 67.0 500
add rule 9150 (1.002, 46.696969696969695, 67.0, 500)
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 3706
no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for fe

check potential rule 6562 1.0 36.84848484848485 48.0 501
no enough support,skip 501 1.0
check potential rule 6562 1.0 36.84848484848485 48.0 501
no enough support,skip 501 1.0
no valid rule,skip 6562
check fids [6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6704
check potential rule 6704 1.0 29.6969696969697 49.0 501
no enough support,skip 501 1.0
check potential rule 6704 1.0 29.6969696969697 49.0 501
no enough support,skip 501 1.0
check potential rule 6704 1.0 29.6969696969697 49.0 501
no enough support,skip 501 1.0
no valid rule,skip 6704
check fids [9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 9150
check potential rule 9150 1.002 46.696969696969695 67.0 500
add rule 9150 (1.002, 46.696969696969695, 67.0, 500)
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 3706
no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203

no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9560 1.002736908454559 46.666666666666664 70.0 501
add rule 9560 (1.002736908454559, 46.666666666666664, 70.0, 501)
check fids [4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 4664
no valid rule,skip 4664
check fids [5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 5072
no valid rule,skip 5072
check fids [6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6426
check potential rule 6426 1.0 35.60606060606061 47.0 501
no enough support,skip 501 1.0
check potential rule 6426 1.0 35.60606060606061 47.0 501
no enough support,skip 501 1.0
no valid rule,skip 6426
check fids [6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6562
check potential rule 6562 1.0 36.84848484848485 48.0 501
no enough support,skip 501 1.0
check potential rule 6562 1.0 36.84848484848485 48.0 501


check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9150 1.00199203187251 34.515151515151516 67.0 502
add rule 9150 (1.00199203187251, 34.515151515151516, 67.0, 502)
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 3706
no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 19.80808080808081 37.0 502
no enough support,skip 502 1.0
no valid rule,skip 5070
check fids [5882, 6291,

no valid rule,skip 6291
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 9150 1.00199203187251 34.515151515151516 67.0 502
add rule 9150 (1.00199203187251, 34.515151515151516, 67.0, 502)
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 3706
no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 19.80808080808081 37.0 502
no enough support,skip 502 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
no valid rule,skip 5882
check fids [6291,

no valid rule,skip 5072
check fids [6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6426
check potential rule 6426 1.0 6.646464646464646 47.0 500
no enough support,skip 500 1.0
no valid rule,skip 6426
check fids [6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6562
check potential rule 6562 1.0 6.787878787878788 48.0 500
no enough support,skip 500 1.0
no valid rule,skip 6562
check fids [6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6704
no valid rule,skip 6704
check fids [9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 9150
check potential rule 9150 1.0 4.737373737373738 67.0 500
no enough support,skip 500 1.0
no valid rule,skip 9150
check fids [3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 3706
no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
check potential rule 4386 1.0 3.878787878787879 32.0 500
no enough support,s

check potential rule 5203 1.0 4.98989898989899 38.0 514
no enough support,skip 514 1.0
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 5338 2.020629696814679 0.0 10.636363636363637 1891
add rule 5338 (2.020629696814679, 0.0, 10.636363636363637, 1891)
check fids [6698, 6702, 9152, 9418, 9554, 9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6698
check potential rule 6698 1.1332935573601568 5.9393939393939394 8.90909090909091 756
add rule 6698 (1.1332935573601568, 5.9393939393939394, 8.90909090909091, 756)
check fids [6702, 9152, 9418, 9554, 9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for feature 6702
check potential rule 6702 1.190337502141511 5.9393939393939394 49.0 546
add rule 6702 (1.190337502141511, 5.9393939393939394, 49.0, 546)
check fids [9152, 9418, 9554, 9560, 4664, 5072, 6426, 6562, 6704, 9150, 3706, 4386, 5070, 5882, 6291, 5203]
search rule for featur

no valid rule,skip 3706
check fids [4386, 5070, 5882, 6291, 5203]
search rule for feature 4386
no valid rule,skip 4386
check fids [5070, 5882, 6291, 5203]
search rule for feature 5070
check potential rule 5070 1.0 0.7474747474747475 37.0 600
no enough support,skip 600 1.0
check potential rule 5070 1.0 0.7474747474747475 37.0 600
no enough support,skip 600 1.0
no valid rule,skip 5070
check fids [5882, 6291, 5203]
search rule for feature 5882
no valid rule,skip 5882
check fids [6291, 5203]
search rule for feature 6291
check potential rule 6291 1.0033444816053512 0.9292929292929293 19.97979797979798 598
add rule 6291 (1.0033444816053512, 0.9292929292929293, 19.97979797979798, 598)
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 6291 1.0033444816053512 0.9292929292929293 19.97979797979798 598
add rule 6291 (1.0033444816053512, 0.9292929292929293, 19.97979797979798, 598)
check fids [5203]
search rule for feature 5203


check potential rule 6291 1.0033444816053512 0.9292929292929293 19.97979797979798 598
add rule 6291 (1.0033444816053512, 0.9292929292929293, 19.97979797979798, 598)
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth
check potential rule 6291 1.0033444816053512 0.9292929292929293 19.97979797979798 598
add rule 6291 (1.0033444816053512, 0.9292929292929293, 19.97979797979798, 598)
check fids [5203]
search rule for feature 5203
no valid rule,skip 5203
check fids []
reach max depth


In [168]:
# Display the mined candidate rule lists (conditions still use raw feature ids).
y_rule_candidates

[{'rules': [(9146, '>=', 6.0),
   (9146, '<=', 8.0),
   (6702, '>=', 6.0),
   (3706, '>=', 6.0),
   (6291, '>=', 5.0)],
  'cond_prob_target': 0.7115009746588694,
  'support': 513,
  'cond_prob_y': 0.5497076023391813,
  'ratio_y': 0.16197587593337162},
 {'rules': [(9146, '<=', 10.0),
   (5338, '>=', 6.0),
   (5338, '<=', 8.0),
   (6698, '>=', 6.0),
   (6698, '<=', 8.0),
   (6702, '>=', 6.0),
   (3706, '>=', 6.0),
   (6291, '>=', 5.0)],
  'cond_prob_target': 0.7120622568093385,
  'support': 514,
  'cond_prob_y': 0.5486381322957199,
  'ratio_y': 0.16197587593337162},
 {'rules': [(9146, '<=', 24.0),
   (5338, '>=', 6.0),
   (5338, '<=', 8.0),
   (6698, '>=', 6.0),
   (6698, '<=', 8.0),
   (6702, '>=', 6.0),
   (3706, '>=', 6.0),
   (6291, '>=', 5.0)],
  'cond_prob_target': 0.7120622568093385,
  'support': 514,
  'cond_prob_y': 0.5486381322957199,
  'ratio_y': 0.16197587593337162},
 {'rules': [(9146, '<=', 24.0),
   (5338, '<=', 10.0),
   (6698, '>=', 6.0),
   (6698, '<=', 8.0),
   (6702, '

In [169]:
# Replace raw feature ids in every candidate's conditions with readable,
# time-indexed feature names. Each candidate dict is updated in place.
# NOTE(review): not idempotent -- re-running passes already-named 4-tuples back
# into replace_feature_names; re-run the mining cell before this one.
for candidate in y_rule_candidates:
    candidate["rules"] = rlm.replace_feature_names(candidate["rules"], input_feature_names, time_index=True)
y_rule_candidates

[{'rules': [(9146, 'HR_ctime_t67', '>=', 6.0),
   (9146, 'HR_ctime_t67', '<=', 8.0),
   (6702, 'MAP_ctime_t49', '>=', 6.0),
   (3706, 'HR_ctime_t27', '>=', 6.0),
   (6291, 'O2Sat_ctime_t46', '>=', 5.0)],
  'cond_prob_target': 0.7115009746588694,
  'support': 513,
  'cond_prob_y': 0.5497076023391813,
  'ratio_y': 0.16197587593337162},
 {'rules': [(9146, 'HR_ctime_t67', '<=', 10.0),
   (5338, 'HR_ctime_t39', '>=', 6.0),
   (5338, 'HR_ctime_t39', '<=', 8.0),
   (6698, 'HR_ctime_t49', '>=', 6.0),
   (6698, 'HR_ctime_t49', '<=', 8.0),
   (6702, 'MAP_ctime_t49', '>=', 6.0),
   (3706, 'HR_ctime_t27', '>=', 6.0),
   (6291, 'O2Sat_ctime_t46', '>=', 5.0)],
  'cond_prob_target': 0.7120622568093385,
  'support': 514,
  'cond_prob_y': 0.5486381322957199,
  'ratio_y': 0.16197587593337162},
 {'rules': [(9146, 'HR_ctime_t67', '<=', 24.0),
   (5338, 'HR_ctime_t39', '>=', 6.0),
   (5338, 'HR_ctime_t39', '<=', 8.0),
   (6698, 'HR_ctime_t49', '>=', 6.0),
   (6698, 'HR_ctime_t49', '<=', 8.0),
   (6702, 'MA

In [None]:
# sample_local_rules = {}
# for tid in range(14,15):
#     xi = X_train_raw[tid,:,:]
#     if pred_y_train[tid] > 0:
#         match = rlm.match_sample_rules(xi,sorted_rules_pos,time_dim=True,num_latent_per_time=2)
#     else:
#         match = rlm.match_sample_rules(xi,sorted_rules_neg,time_dim=True,num_latent_per_time=2)
#     print("match",match)
#     if match is None:
#         x_z_rules = {}
#         tz = all_reps[tid]
#         sort_zids = torch.argsort(torch.abs(tz.reshape(-1)*linear_prams[0]),descending=True).numpy().reshape(-1)

#         for zid in sort_zids[:20]:
#             time_step= int(zid/latent_num)
#             latent_id = int(zid%latent_num)
#             x = X_train_raw[:,time_step,:].numpy()
#             z = all_reps[:,time_step,latent_id].numpy()
#             zi = z[tid]
#             xi_t = xi[time_step,:] 
#             zw_pos = zw[zid]
#             x_z_rules[zid] = rlm.find_pattern_by_sample_latent_state(xi_t,zi,x,y_train,z,itemsets[zid],zw_pos=zw_pos,c=1,num_grids=100,omega=0.2,min_support_pos=200,min_support_neg=2000,max_depth=5)
            
#         sample_local_rules[tid] = rlm.sort_rules(x_z_rules,input_feature_names,pos=(pred_y_train[tid]>0))
#     else:
#         sample_local_rules[tid] = match

In [133]:
# Select one training sample (index `tid`) to explain and show the model output for it.
# NOTE(review): hard-coded sample index; consider moving it to a config constant.
tid = 14
xi = X_train_raw[tid,:,:]
pred_y_train[tid]

5.911136

In [155]:
# Re-import the rule miner module after editing its source.
# NOTE(review): redundant while `%autoreload 2` (enabled at the top of this notebook) is active.
reload(rlm)

<module 'explainer.rule_pattern_miner' from '/Users/chenyu/github/INSPIRE/code/notebooks/../explainer/rule_pattern_miner.py'>

In [170]:
# Find which candidate rule lists the selected sample satisfies. The sample is flattened
# (time_dim=False), presumably to match the flat feature-id layout used when mining.
match = rlm.match_sample_rules(xi.reshape(-1),y_rule_candidates,time_dim=False,num_latent_per_time=2)

In [171]:
# Print each matched rule list re-evaluated on the training data.
# match_sample_rules can return None when no candidate matches (the archived
# prototype explicitly checks `match is None`), so guard before iterating.
if match is None:
    print("no candidate rule matches sample", tid)
else:
    for mr in match:
        print(rlm.display_rules(mr,x,pred_y_train>=0.6,y=y_train,c=1,verbose=False))

{'rules': [(9146, 'HR_ctime_t67', '>=', 6.0)], 'cond_prob_target': 0.2, 'support': 31760, 'cond_prob_y': 0.05403022670025189, 'ratio_y': 0.9856404365307294}
{'rules': [(9146, 'HR_ctime_t67', '>=', 53.0), (5338, 'HR_ctime_t39', '>=', 25.0), (5338, 'HR_ctime_t39', '<=', 37.0), (6702, 'MAP_ctime_t49', '>=', 37.0), (6702, 'MAP_ctime_t49', '<=', 46.0)], 'cond_prob_target': 0.7055655296229802, 'support': 557, 'cond_prob_y': 0.17773788150807898, 'ratio_y': 0.05686387133831131}
{'rules': [(9146, 'HR_ctime_t67', '>=', 53.0), (5338, 'HR_ctime_t39', '>=', 25.0), (5338, 'HR_ctime_t39', '<=', 37.0), (6702, 'MAP_ctime_t49', '>=', 23.0), (6702, 'MAP_ctime_t49', '<=', 46.0)], 'cond_prob_target': 0.7017241379310345, 'support': 580, 'cond_prob_y': 0.17413793103448275, 'ratio_y': 0.058012636415852956}


In [178]:
# Per-latent rule search call kept for reference (not executed):
#x_z_rules[zid] = rlm.find_pattern_by_sample_latent_state(xi,zi,x,y_train,z,itemsets[zid],zw_pos=zw_pos,c=1,num_grids=100,omega=0.2,min_support_pos=200,min_support_neg=2000,max_depth=5)
# NOTE(review): redundant while `%autoreload 2` is active.
reload(rlm)

<module 'explainer.rule_pattern_miner' from '/Users/chenyu/github/INSPIRE/code/notebooks/../explainer/rule_pattern_miner.py'>

In [179]:
# Same greedy mining as the global candidates, but anchored on the selected sample
# (local_x = flattened `xi`), so the mined conditions are built around that sample's values.
local_rules = rlm.gen_rule_list_for_one_target_greedy(
    x,
    fids,
    pred_y_train >= 0.6,
    y=y_train,
    c=1,
    sort_by="cond_prob_y",
    min_support=500,
    num_grids=100,
    max_depth=20,
    top_K=10,
    local_x=xi.reshape(-1),
    feature_types=feature_types * args.time_len,
    verbose=False,
)

build_rule_tree
init rule tree
search rule for feature 9146
check potential rule 9146 2.655689043383564 52.78787878787879 67.0 2514
add rule 9146 (2.655689043383564, 52.78787878787879, 67.0, 2514)
search rule for feature 5338
check potential rule 5338 1.2565260079827538 24.818181818181817 37.81818181818181 828
add rule 5338 (1.2565260079827538, 24.818181818181817, 37.81818181818181, 828)
search rule for feature 6698
check potential rule 6698 1.0 34.64646464646465 49.0 828
no enough support,skip 828 1.0
check potential rule 6698 1.0 34.64646464646465 49.0 828
no enough support,skip 828 1.0
no valid rule,skip 6698
search rule for feature 6702
check potential rule 6702 1.0395164742488037 36.62626262626263 46.525252525252526 557
add rule 6702 (1.0395164742488037, 36.62626262626263, 46.525252525252526, 557)
search rule for feature 9152
no valid rule,skip 9152
search rule for feature 9418
check potential rule 9418 1.0 52.969696969696976 69.0 557
no enough support,skip 557 1.0
check potential

no valid rule,skip 3706
search rule for feature 4386
no valid rule,skip 4386
search rule for feature 5070
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
no valid rule,skip 5070
search rule for feature 5882
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
no valid rule,skip 5882
search rule for feature 6291
check potential rule 6291 1.004 26.949494949494948 46.0 500
add rule 6291 (1.004, 26.949494949494948, 46.0, 500)
search rule for feature 5203
no valid rule,skip 5203
reach max depth
check potential rule 6291 1.004 26.949494949494948 46.0 500
add rule 6291 (

no valid rule,skip 5203
reach max depth
check potential rule 6291 1.004 26.949494949494948 46.0 500
add rule 6291 (1.004, 26.949494949494948, 46.0, 500)
search rule for feature 5203
no valid rule,skip 5203
reach max depth
check potential rule 9150 1.0070730775528025 48.72727272727273 67.0 502
add rule 9150 (1.0070730775528025, 48.72727272727273, 67.0, 502)
search rule for feature 3706
no valid rule,skip 3706
search rule for feature 4386
no valid rule,skip 4386
search rule for feature 5070
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
no valid rule,skip 5070
search rule for feature 5882
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check poten

no valid rule,skip 5203
reach max depth
check potential rule 6291 1.004 26.949494949494948 46.0 500
add rule 6291 (1.004, 26.949494949494948, 46.0, 500)
search rule for feature 5203
no valid rule,skip 5203
reach max depth
check potential rule 9150 1.0070730775528025 48.72727272727273 67.0 502
add rule 9150 (1.0070730775528025, 48.72727272727273, 67.0, 502)
search rule for feature 3706
no valid rule,skip 3706
search rule for feature 4386
no valid rule,skip 4386
search rule for feature 5070
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
check potential rule 5070 1.0 23.91919191919192 37.0 502
no enough support,skip 502 1.0
no valid rule,skip 5070
search rule for feature 5882
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check potential rule 5882 1.0 28.666666666666668 43.0 502
no enough support,skip 502 1.0
check poten

In [180]:
# Display the sample-anchored rule lists mined in the previous cell.
local_rules

[{'rules': [(9146, '>=', 53.0),
   (5338, '>=', 25.0),
   (5338, '<=', 37.0),
   (6702, '>=', 23.0),
   (6702, '<=', 46.0),
   (6704, '<=', 46.0),
   (5070, '>=', 17.0),
   (5070, '<=', 35.0),
   (6291, '>=', 33.0)],
  'cond_prob_target': 0.7261904761904762,
  'support': 504,
  'cond_prob_y': 0.1865079365079365,
  'ratio_y': 0.05399195864445721},
 {'rules': [(9146, '>=', 53.0),
   (5338, '>=', 25.0),
   (5338, '<=', 37.0),
   (6702, '>=', 37.0),
   (6702, '<=', 46.0),
   (5072, '<=', 35.0),
   (6704, '<=', 46.0),
   (9150, '>=', 47.0),
   (6291, '>=', 27.0)],
  'cond_prob_target': 0.72,
  'support': 500,
  'cond_prob_y': 0.186,
  'ratio_y': 0.053417576105686385},
 {'rules': [(9146, '>=', 53.0),
   (5338, '>=', 25.0),
   (5338, '<=', 37.0),
   (6702, '>=', 23.0),
   (6702, '<=', 46.0),
   (6704, '>=', 24.0),
   (6704, '<=', 46.0),
   (9150, '>=', 49.0),
   (6291, '>=', 27.0)],
  'cond_prob_target': 0.718,
  'support': 500,
  'cond_prob_y': 0.184,
  'ratio_y': 0.05284319356691557}]

In [127]:




# Search the single top pattern for target `pred_y_train >= 0.6`,
# seeded with the first frequent itemset in `itemsets_y`.
y_rules = rlm.find_top_pattern_for_one_target(
    x,
    y_train,
    pred_y_train >= 0.6,
    itemsets_y[0],
    c=1,
    num_grids=100,
    omega=0.2,
    min_support=500,
    max_depth=20,
    feature_types=feature_types * args.time_len,
    verbose=False,
)


[9147.0] 1599.0
[9147.0, 5339.0] 1597.0
[9147.0, 5339.0, 6699.0] 1595.0
[9147.0, 5339.0, 6699.0, 6703.0] 1593.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0] 1591.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0] 1590.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0] 1590.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0] 1590.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0] 1589.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0] 1588.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0, 6427.0] 1585.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0, 6427.0, 6563.0] 1585.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0, 6427.0, 6563.0, 6705.0] 1585.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0, 9555.0, 9561.0, 4665.0, 5073.0, 6427.0, 6563.0, 6705.0, 9151.0] 1585.0
[9147.0, 5339.0, 6699.0, 6703.0, 9153.0, 9419.0,

search rule for feature 4658
search rule for feature 7923
add rule {'rule': [(7923, '>=', 23.434343434343432), (7923, '<=', 58.0)], 'support': (1.001890756302521, 500)}
search rule for feature 8330
search rule for feature 9010
search rule for feature 9147
search rule for feature 4256
search rule for feature 5474
search rule for feature 6566
search rule for feature 6838
search rule for feature 7654
search rule for feature 9011
search rule for feature 5066
search rule for feature 5746
search rule for feature 5747
search rule for feature 6019
9146 {'rule': [(9146, '>=', 63.61616161616162), (9146, '<=', 67.0)], 'support': (4.596094086996457, 568)}
5338 {'rule': [(5338, '>=', 36.63636363636363), (5338, '<=', 39.0)], 'support': (1.002951238701455, 529)}
6698 {'rule': [(6698, '>=', 46.525252525252526), (6698, '<=', 49.0)], 'support': (1.0020829403257248, 513)}
6702 {'rule': [(6702, '>=', 41.57575757575758), (6702, '<=', 49.0)], 'support': (1.0054204319944895, 506)}
9554 {'rule': [(9554, '>=',

In [128]:
# Display the top pattern (conditions still use raw feature ids).
y_rules

{'rules': [(9146, '>=', 64.0),
  (5338, '>=', 37.0),
  (6698, '>=', 47.0),
  (6702, '>=', 42.0),
  (9554, '>=', 65.0),
  (7923, '>=', 24.0)],
 'cond_prob_target': 0.95,
 'support': 500,
 'cond_prob_y': 0.138,
 'ratio_y': 0.03963239517518667}

In [129]:
# Attach readable, time-indexed feature names to each condition of the top pattern.
# NOTE(review): not idempotent -- a second run feeds already-named 4-tuples back into
# replace_feature_names, which may not handle them; re-run the mining cell first.
y_rules["rules"] = rlm.replace_feature_names(y_rules["rules"],input_feature_names,time_index=True)

In [130]:
# Display the top pattern with feature names attached.
y_rules

{'rules': [(9146, 'HR_ctime_t67', '>=', 64.0),
  (5338, 'HR_ctime_t39', '>=', 37.0),
  (6698, 'HR_ctime_t49', '>=', 47.0),
  (6702, 'MAP_ctime_t49', '>=', 42.0),
  (9554, 'HR_ctime_t70', '>=', 65.0),
  (7923, 'O2Sat_ctime_t58', '>=', 24.0)],
 'cond_prob_target': 0.95,
 'support': 500,
 'cond_prob_y': 0.138,
 'ratio_y': 0.03963239517518667}

In [None]:
# Number of training samples whose label is 0.
(y_train==0).sum()

In [None]:
# Value of flattened latent 40 for one sample's representation `tz`, vs. its min/max over all samples.
# NOTE(review): `tz` is only assigned in a later cell; magic index 40 -- consider a named constant.
tz.reshape(-1)[40],all_reps.reshape(all_reps.shape[0],-1)[:,40].min(),all_reps.reshape(all_reps.shape[0],-1)[:,40].max()

In [None]:
# Distribution of flattened latent 138 across all samples (magic index -- consider a named constant).
z = all_reps.reshape(all_reps.shape[0],-1)[:,138].numpy()
sns.displot(z)

In [None]:
# Scatter the selected latent `z` against the training labels.
sns.scatterplot(x=z,y=y_train)

In [None]:
# Sanity check: latent values and labels should have one entry per sample.
z.shape,y_train.shape

In [None]:
# Index of the last grid point that is <= 3.
# NOTE(review): `grids` is not assigned in any live cell visible here -- hidden kernel state.
np.arange(len(grids))[(grids -3.)<=0][-1]

In [None]:
# Raw input of sample 14 at (presumably) time step 71, feature 35 -- hard-coded indices.
X_train_raw[14][71,:][35]

In [None]:
# Compare feature 34 of `xi_t` with that feature's maximum over `x`.
# NOTE(review): `xi_t` is only assigned in the commented-out prototype cell, so this fails
# on a fresh kernel; also note feature 34 here vs. 35 in the previous cell -- confirm intent.
xi_t[34],x[:,34].max()

In [None]:
# Indices of the first 100 training samples ordered by model output, highest first.
# `torch.argsort` only accepts tensors, and `pred_y_train` is displayed as a numpy
# scalar elsewhere in this notebook (e.g. `5.911136`), so convert it with
# `torch.as_tensor`; this is a no-op when it is already a tensor.
torch.argsort(torch.as_tensor(pred_y_train[:100]), descending=True)

In [None]:
# Model output and label for sample 14.
# NOTE(review): hard-coded index; `tid` already holds the sample being explained.
pred_y_train[14],y_train[14]

In [None]:
# Print the ranked, readable rules mined for each latent id. The result must be
# printed explicitly: a bare call inside a loop body is never displayed.
# NOTE(review): `sort_rules` is defined in a later cell and `x_z_rules` is only built
# by the commented-out prototype -- run those first on a fresh kernel.
for zid in x_z_rules.keys():
    print(zid)
    print(sort_rules({zid:x_z_rules[zid]},input_feature_names,pos=(pred_y_train[tid]>0)))

In [None]:
def sort_rules(z_rules,input_feature_names,sort_by="cond_prob_y",pos=True,top=3):
    """Flatten per-latent rule dictionaries into one ranked list.

    For each latent id in ``z_rules``, the latent's ``'pos'`` flag is compared
    with ``pos``. If they are equal, the "higher z" side is used
    (``rule_dict_higher_z`` / ``p(z>=thd_h)`` / ``thd_h``). Otherwise the
    "lower z" side is used (``rule_dict_lower_z`` / ``p(z<=thd_l)`` /
    ``thd_l``). When the chosen rule list is absent, the side is picked by
    whether ``p(z>=thd_h)`` is present.

    Each rule ``(idx, op, value)`` is rewritten as
    ``(idx, input_feature_names[idx], op, value)``. The ``top`` entries per
    latent (by ``sort_by``, descending) are kept, then the combined list is
    sorted by ``sort_by`` descending. Prints each latent id and its 'pos' flag.

    Args:
        z_rules: dict mapping latent id -> rule dict containing 'pos', the
            chosen side's probability/threshold keys and its rule list.
        input_feature_names: indexable mapping feature index -> name.
        sort_by: key of each output entry used for ranking.
        pos: value compared against each latent's 'pos' flag.
        top: number of entries kept per latent.

    Returns:
        list of dicts with keys 'rules', 'zid', the side's probability key,
        the side's threshold key, 'cond_prob_y', 'cond_prob_z', 'support',
        'ratio_y'.
    """
    higher = ('rule_dict_higher_z', 'p(z>=thd_h)', 'thd_h')
    lower = ('rule_dict_lower_z', 'p(z<=thd_l)', 'thd_l')
    copied_fields = ('cond_prob_y', 'cond_prob_z', 'support', 'ratio_y')

    def rank_key(entry):
        return entry[sort_by]

    collected = []
    for latent_id in z_rules:
        info = z_rules[latent_id]
        print(latent_id, info['pos'])
        list_key, prob_key, thd_key = higher if info['pos'] == pos else lower
        # Some dicts only carry one side; fall back to whichever side is present.
        if list_key not in info.keys():
            list_key, prob_key, thd_key = higher if 'p(z>=thd_h)' in info.keys() else lower

        per_latent = []
        for rule_info in info[list_key].values():
            entry = {
                "rules": [(r[0], input_feature_names[r[0]], r[1], r[2]) for r in rule_info["rules"]],
                'zid': latent_id,
                prob_key: info[prob_key],
                thd_key: info[thd_key],
            }
            entry.update((field, rule_info[field]) for field in copied_fields)
            per_latent.append(entry)

        collected.extend(sorted(per_latent, key=rank_key, reverse=True)[:top])

    return sorted(collected, key=rank_key, reverse=True)

In [None]:
# Model output for the sample under inspection (index tid).
pred_y_train[tid]

In [None]:
# Decode the flat index of the top-ranked latent into (time step, latent unit).
# Presumably all_reps is (samples, time, latent_num), matching the indexing
# below — TODO confirm. divmod on ints avoids the float round-trip of int(zid/latent_num).
zid = top_latent[0]
time_step, latent_id = divmod(int(zid), int(latent_num))

z = all_reps[:,time_step,latent_id].numpy()
zi = z[tid]
zid,zi

In [None]:
# Rank the flattened latent entries of sample tid by |linear_prams[0] * z + linear_prams[1]|
# (presumably the final linear layer's weight and bias — confirm).
# argsort is ascending, so the largest magnitude is the last index.
# NOTE(review): the bias is added to every entry before abs(), so a non-zero bias
# changes the ranking; if per-latent contribution is intended, |w * z| may be meant.
tz = all_reps[tid]
sort_zids = torch.argsort(torch.abs(tz.reshape(-1)*linear_prams[0]+linear_prams[1]))

In [None]:
# Last index of row 0 of sort_zids, i.e. the entry with the largest magnitude (argsort is ascending).
sort_zids[0][-1]

In [None]:
# Show the current value of match.
match

In [None]:
# Show the latent index currently under inspection.
zid

In [None]:
# Label of the sample under inspection.
yi = y_train[tid]
yi

In [None]:
# NOTE(review): repeats an earlier cell that also displays match; candidate for removal.
match

In [None]:
# Sanity check: labels and latent values should cover the same number of samples.
y_train.shape,z.shape

In [None]:
# Run the rule search for latent zid around sample xi only when match is None.
# NOTE(review): c, num_grids, omega, min_support_pos/neg and max_depth are hard-coded
# tuning values; consider moving them to a config cell.
if match is None:
    zw_pos = zw[zid]
    x_z_rules = find_pattern_by_sample_latent_state(xi,zi,x,y_train,z,itemsets[zid],zw_pos=zw_pos,c=1,num_grids=20,omega=0.1,min_support_pos=500,min_support_neg=2000,max_depth=5)

In [None]:
# Inspect x_z_rules.
x_z_rules

In [None]:
# Flatten and rank the rules mined for latent zid. sort_rules takes a dict keyed
# by latent id and returns a list, so guard on the type: a second run would
# otherwise feed the already-sorted list back into sort_rules and fail.
if isinstance(x_z_rules, dict):
    x_z_rules = sort_rules({zid:x_z_rules},input_feature_names,pos=False)

In [None]:
# Inspect x_z_rules after ranking.
x_z_rules

In [None]:
# Float array 1.0, 2.0, ..., 100.0 (identical to np.linspace(1, 100, 100)).
g = np.arange(1.0, 101.0)

In [None]:
# Overwrite positions 1 and 2 with 1 so the array contains duplicate values.
g[1:3] = 1


In [None]:
# Sorted distinct values of g.
np.unique(g)

In [None]:
# Element at index 99 of g.
g[99]