This notebook demonstrates the inconsistency of the NID (Neural Interaction Detection) method.

We first run the NID code and then compare its results with those of our method.

The code will run for a few minutes.

In [1]:
import torch
import numpy as np
from neural_interaction_detection import get_interactions
from multilayer_perceptron import MLP, train, get_weights
from utils import preprocess_data, get_pairwise_auc, get_anyorder_R_precision, set_seed, print_rankings

In [2]:
# Experiment configuration shared by the cells below.
use_main_effect_nets = True # toggle this to use "main effect" nets (forwarded to the MLP constructor)
# Total number of synthetic rows drawn for X.
num_samples = 30000
# Must stay 8: synth_func unpacks exactly eight feature columns.
num_features = 8

## Generate synthetic data with ground truth interactions

In [3]:
def synth_func(X):
    """Evaluate the synthetic regression target and report its true interactions.

    The target is the sum of two 4-way multiplicative terms, one over
    features 1-4 and one over features 5-8, so every pair of features inside
    the same group interacts and no pair across groups does.

    Parameters
    ----------
    X : array of shape (n_samples, 8)
        Input samples; the transpose is unpacked into exactly eight columns,
        so any other column count raises ``ValueError``.

    Returns
    -------
    tuple
        ``(Y, ground_truth)`` where ``Y`` holds one target value per sample
        and ``ground_truth`` is a list of sets of 1-indexed feature ids,
        one set per interaction term.
    """
    # One 1-D array per feature column.
    f1, f2, f3, f4, f5, f6, f7, f8 = X.transpose()

    # Features 1 and 3 (resp. 5 and 7) enter squared, the others linearly.
    first_term = f1**2*f2*f3**2*f4
    second_term = f5**2*f6*f7**2*f8

    true_interactions = [{1, 2, 3, 4}, {5, 6, 7, 8}]
    return first_term + second_term, true_interactions

In [4]:
# Fix the random seed so the sampled data is reproducible
# (set_seed comes from utils; presumably it seeds numpy and torch — confirm).
set_seed(42)
# Draw every feature independently and uniformly from [-1, 1).
X = np.random.uniform(low=-1, high=1, size=(num_samples,num_features))
# Targets plus the ground-truth interaction sets used for evaluation later.
Y, ground_truth = synth_func(X)
# With num_samples = 30000, the 10000/10000 validation/test sizes presumably
# leave 10000 rows for training — TODO confirm against preprocess_data.
data_loaders = preprocess_data(X, Y, valid_size=10000, test_size=10000, std_scale=True, get_torch_loaders=True)

## Train a multilayer perceptron (MLP)

In [5]:
# Run on CPU; the same device object is passed to train() in the next cell.
device = torch.device("cpu")
# [140, 100, 60, 20] is presumably the list of hidden-layer widths — confirm
# against the MLP class in multilayer_perceptron.
model = MLP(num_features, [140, 100, 60, 20], use_main_effect_nets=use_main_effect_nets).to(device)

In [6]:
# Fit the MLP on the data loaders; train() returns the model together with a
# loss value. l1_const is presumably the L1 regularization strength — confirm
# against train() in multilayer_perceptron.
model, mlp_loss = train(model, data_loaders, device=device, learning_rate=1e-2, l1_const = 5e-5, verbose=True)

starting to train
early stopping enabled
[epoch 1, total 100] train loss: 0.8319, val loss: 0.8109
[epoch 3, total 100] train loss: 0.5896, val loss: 0.6396
[epoch 5, total 100] train loss: 0.2602, val loss: 0.2320
[epoch 7, total 100] train loss: 0.1517, val loss: 0.2999
[epoch 9, total 100] train loss: 0.1350, val loss: 0.1484
[epoch 11, total 100] train loss: 0.0774, val loss: 0.1224
[epoch 13, total 100] train loss: 0.0862, val loss: 0.1022
[epoch 15, total 100] train loss: 0.0766, val loss: 0.0910
[epoch 17, total 100] train loss: 0.0715, val loss: 0.0839
[epoch 19, total 100] train loss: 0.0511, val loss: 0.0757
[epoch 21, total 100] train loss: 0.0790, val loss: 0.1116
[epoch 23, total 100] train loss: 0.0698, val loss: 0.1030
[epoch 25, total 100] train loss: 0.0740, val loss: 0.0621
[epoch 27, total 100] train loss: 0.0696, val loss: 0.0835
[epoch 29, total 100] train loss: 0.0487, val loss: 0.0788
[epoch 31, total 100] train loss: 0.0639, val loss: 0.0875
[epoch 33, total 100

## Get the MLP's learned weights

In [7]:
# Pull the trained network's weights out in the form get_interactions consumes.
model_weights = get_weights(model)

## Detect interactions from the weights

In [8]:
# NID: score feature pairs using only the learned weights (no data is passed).
# one_indexed=True reports features starting at 1, matching ground_truth.
pairwise_interactions = get_interactions(model_weights, pairwise=True, one_indexed=True)
# Display the (pair, strength) list as the cell output.
pairwise_interactions

[((2, 4), 70.7305),
 ((6, 8), 61.814903),
 ((5, 6), 35.181442),
 ((6, 7), 33.913887),
 ((5, 7), 29.893131),
 ((5, 8), 22.185238),
 ((2, 8), 20.204613),
 ((2, 6), 19.646027),
 ((7, 8), 18.416555),
 ((4, 8), 18.112186),
 ((4, 6), 17.493275),
 ((1, 6), 15.431152),
 ((1, 3), 15.218493),
 ((2, 7), 14.409973),
 ((1, 5), 13.93491),
 ((3, 6), 13.413807),
 ((4, 5), 13.235764),
 ((2, 3), 13.13788),
 ((1, 8), 12.991602),
 ((1, 2), 12.955213),
 ((2, 5), 12.950325),
 ((1, 7), 12.683628),
 ((3, 8), 12.498389),
 ((4, 7), 12.068602),
 ((3, 7), 12.059618),
 ((3, 5), 11.738636),
 ((1, 4), 11.333738),
 ((3, 4), 10.039406)]

## Detect interactions with our principled method

In [9]:
# We reuse the exact same model trained above, so both methods are evaluated
# on one and the same network.
# Import only the name this cell needs: a wildcard import would silently
# overwrite any notebook names that UCBtools happens to define as well.
from UCBtools import detect_Hessian_UCB

# torch.Tensor(...) converts the float64 NumPy samples into a float32 tensor.
X_torch = torch.Tensor(X)
# The third argument (20) is presumably the number of top-ranked pairs to
# return — the result contains 20 pairs; confirm against detect_Hessian_UCB.
# The call returns a tuple: (list of 0-indexed feature-pair arrays, a scalar).
UCB_interactions = detect_Hessian_UCB(model, X_torch, 20)
UCB_interactions

start dectecting
Initialization done! initial try3times
chosen arm: 0 strength: 0.9775244333012112 iteration: 177
chosen arm: 7 strength: 0.6998232301504943 iteration: 208
chosen arm: 2 strength: 0.8073215459666722 iteration: 275
chosen arm: 13 strength: 0.6627016655325304 iteration: 338
chosen arm: 27 strength: 0.8149593249545433 iteration: 359
chosen arm: 25 strength: 0.40065102562099736 iteration: 385
chosen arm: 24 strength: 5.1987242418581445 iteration: 423
chosen arm: 1 strength: 0.09203294216400458 iteration: 477
chosen arm: 22 strength: 0.40091696214383576 iteration: 487
chosen arm: 26 strength: 0.16952629989439932 iteration: 493
chosen arm: 23 strength: 0.09347923261619893 iteration: 524
chosen arm: 8 strength: 0.02082446837529359 iteration: 620
chosen arm: 16 strength: 0.010436514839125452 iteration: 681
chosen arm: 5 strength: 0.00732620194638539 iteration: 773
chosen arm: 9 strength: 0.004539545773195402 iteration: 858
chosen arm: 11 strength: 0.009087912594623049 iteration

([array([0, 1]),
  array([1, 2]),
  array([0, 3]),
  array([2, 3]),
  array([6, 7]),
  array([5, 6]),
  array([4, 7]),
  array([0, 2]),
  array([4, 5]),
  array([5, 7]),
  array([4, 6]),
  array([1, 3]),
  array([2, 6]),
  array([0, 6]),
  array([1, 4]),
  array([1, 6]),
  array([0, 5]),
  array([2, 5]),
  array([0, 4]),
  array([3, 4])],
 9.656995296478271)

In [10]:
# Convert the UCB result into the [((i, j), score), ...] format that
# get_pairwise_auc expects. The UCB output is an ordered list of 0-indexed
# pairs without strengths, so each pair gets a surrogate score equal to its
# reversed rank (first pair -> highest score, last pair -> 0).
ranked_pairs = UCB_interactions[0]
num_ranked = len(ranked_pairs)
our_pairwise_interactions = [
    # +1 switches to the 1-indexed feature ids used by ground_truth; int()
    # turns NumPy integer scalars into plain Python ints for a clean display.
    ((int(pair[0]) + 1, int(pair[1]) + 1), num_ranked - 1 - rank)
    for rank, pair in enumerate(ranked_pairs)
]
our_pairwise_interactions

[((1, 2), 19),
 ((2, 3), 18),
 ((1, 4), 17),
 ((3, 4), 16),
 ((7, 8), 15),
 ((6, 7), 14),
 ((5, 8), 13),
 ((1, 3), 12),
 ((5, 6), 11),
 ((6, 8), 10),
 ((5, 7), 9),
 ((2, 4), 8),
 ((3, 7), 7),
 ((1, 7), 6),
 ((2, 5), 5),
 ((2, 7), 4),
 ((1, 6), 3),
 ((3, 6), 2),
 ((1, 5), 1),
 ((4, 5), 0)]

## Evaluate the interactions (NID & Our method)

In [11]:
# Score both rankings against the same ground-truth interaction sets.
# NOTE(review): pairwise_interactions lists all 28 feature pairs, whereas
# our_pairwise_interactions lists only the 20 pairs returned by
# detect_Hessian_UCB, so the two AUCs are presumably computed over different
# candidate sets — confirm how get_pairwise_auc treats pairs that are absent.
NIDauc = get_pairwise_auc(pairwise_interactions, ground_truth)
ourauc = get_pairwise_auc(our_pairwise_interactions, ground_truth)
print("NID Pairwise AUC", NIDauc)
print("Our Pairwise AUC", ourauc)

NID Pairwise AUC 0.6979166666666666
Our Pairwise AUC 1.0
