In [1]:
import utils, RKHS_DAGMA_extractj
import torch
import time
import numpy as np
import matplotlib.pyplot as plt
from CausalDisco.analytics import r2_sortability, var_sortability
from CausalDisco.baselines import r2_sort_regress, var_sort_regress
from cdt.metrics import SID, SHD
import seaborn as sns
import cdt
import multiprocessing
import threading
import json
from sklearn.preprocessing import StandardScaler

Detecting 1 CUDA device(s).


In [2]:
sns.set_context("paper")

torch.set_default_dtype(torch.float64)

device= torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)

# utils.set_random_seed(1)
# torch.manual_seed(1)

cdt.SETTINGS.rpath= 'C:\Program Files\R\R-4.3.3\\bin\Rscript'

num_cores = multiprocessing.cpu_count()

# Print the number of CPU cores
# print("Number of CPU cores:", num_cores)
# print(torch.__version__)

# Do RKHS_DAGMA Discovery

In [3]:
def run_model(lambda1, tau, X, device, B_true, gamma, T=6, lr = 0.03):
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = {}
    
    X = X.to(device)
    eq_model = RKHS_DAGMA_extractj.DagmaRKHS(X, gamma = gamma).to(device)
    model = RKHS_DAGMA_extractj.DagmaRKHS_nonlinear(eq_model)
    
    x_est_start = eq_model.forward()
    start_mse = eq_model.mse(x_est_start).detach().cpu().numpy()
    time_start = time.time()
    W_est_no_thresh, output = model.fit(X, lambda1=lambda1, tau=tau, T = T, mu_init = 1.0, lr=lr, w_threshold=0.0)
    time_end = time.time()

    thresh_values = [0.005, 0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
    for thresh in thresh_values:
        try:
            W_est_dagma = abs(W_est_no_thresh) * (abs(W_est_no_thresh) > thresh)
            acc_dagma = utils.count_accuracy(B_true, W_est_dagma != 0)
            valid = "yes"
        #sid_dagma = SID(B_true, W_est_dagma != 0).item()
        except Exception as e:
            W_est_dagma = results_var_sort_regress
            acc_dagma = utils.count_accuracy(B_true, W_est_dagma != 0)
            valid = "no"
            
        diff_dagma = np.linalg.norm(W_est_dagma - abs(B_true))
        x_est = eq_model.forward()
        mse_dagma = eq_model.mse(x_est).detach().cpu().numpy()
        W_est = eq_model.fc1_to_adj()
        h_val = eq_model.h_func(W_est, s=1).detach().cpu().numpy()
        #parameters = eq_model.get_parameters()
        W_est_dagma = W_est_dagma.tolist()
        B = B_true.tolist()
        key = f'lambda_{lambda1}_tau_{tau}_thresh_{thresh}_gamma_{gamma}'
        results[key] = {
            'SHD': acc_dagma['shd'],
            'TPR': acc_dagma['tpr'],
            'Time Elapsed': time_end - time_start,
            'F1': acc_dagma['f1'],
            'diff': diff_dagma,
            'mse': mse_dagma.item(),
            'valid': valid,
            'h_val': h_val.item(),
            'start mse': start_mse,
            #'W_est_dagma': W_est_dagma,
            #'parameters': parameters,
            #'eq_model': eq_model,
            #'W_est_no_thresh': W_est_no_thresh,
            #'B_true': B
        }

    torch.cuda.empty_cache()

    # filename = f'model_results_lambda_{lambda1}_tau_{tau}.txt'

    # with open(filename, 'w') as file:
    #     json.dump(results, file, indent=4) 

    return results

In [4]:
device= torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
print('device:', device)
torch.set_default_device(device)

device: cuda


## d = 10

In [5]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 10, 40, 'ER', 'gp-add' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [6]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (10,)


3.178208630818641

In [7]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_3.178208630818641': {'SHD': 10, 'TPR': 0.875, 'Time Elapsed': 111.54160380363464, 'F1': 0.823529411764706, 'diff': 4.788222279912862, 'mse': 5.5314629353657665, 'valid': 'yes', 'h_val': 5.752303225031771e-07, 'start mse': array(25.27285238)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_3.178208630818641': {'SHD': 10, 'TPR': 0.875, 'Time Elapsed': 111.54160380363464, 'F1': 0.823529411764706, 'diff': 4.788222279912862, 'mse': 5.5314629353657665, 'valid': 'yes', 'h_val': 5.752303225031771e-07, 'start mse': array(25.27285238)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_3.178208630818641': {'SHD': 10, 'TPR': 0.875, 'Time Elapsed': 111.54160380363464, 'F1': 0.823529411764706, 'diff': 4.788222279912862, 'mse': 5.5314629353657665, 'valid': 'yes', 'h_val': 5.752303225031771e-07, 'start mse': array(25.27285238)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_3.178208630818641': {'SHD': 15, 'TPR': 0.75, 'Time Elapsed': 111.54160380363464, 'F1': 0.759493670

In [14]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 10, 40, 'ER', 'gp' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [15]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (10,)


3.178208630818641

In [16]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_3.178208630818641': {'SHD': 12, 'TPR': 0.825, 'Time Elapsed': 116.18225765228271, 'F1': 0.7764705882352941, 'diff': 5.439742479059952, 'mse': 5.196008865737593, 'valid': 'yes', 'h_val': 6.066906434313414e-07, 'start mse': array(8.75967607)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_3.178208630818641': {'SHD': 12, 'TPR': 0.825, 'Time Elapsed': 116.18225765228271, 'F1': 0.7764705882352941, 'diff': 5.439742479059952, 'mse': 5.196008865737593, 'valid': 'yes', 'h_val': 6.066906434313414e-07, 'start mse': array(8.75967607)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_3.178208630818641': {'SHD': 14, 'TPR': 0.775, 'Time Elapsed': 116.18225765228271, 'F1': 0.7654320987654322, 'diff': 5.454810483613001, 'mse': 5.196008865737593, 'valid': 'yes', 'h_val': 6.066906434313414e-07, 'start mse': array(8.75967607)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_3.178208630818641': {'SHD': 16, 'TPR': 0.7, 'Time Elapsed': 116.18225765228271, 'F1': 0.7368421052631

In [17]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 10, 40, 'ER', 'mlp' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [18]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (10,)


3.1782086308186415

In [19]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_3.1782086308186415': {'SHD': 11, 'TPR': 0.85, 'Time Elapsed': 126.6613118648529, 'F1': 0.7999999999999998, 'diff': 4.728593918590755, 'mse': 15.105515544973583, 'valid': 'yes', 'h_val': 7.562950843365002e-08, 'start mse': array(233.45203392)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_3.1782086308186415': {'SHD': 11, 'TPR': 0.85, 'Time Elapsed': 126.6613118648529, 'F1': 0.7999999999999998, 'diff': 4.728593918590755, 'mse': 15.105515544973583, 'valid': 'yes', 'h_val': 7.562950843365002e-08, 'start mse': array(233.45203392)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_3.1782086308186415': {'SHD': 11, 'TPR': 0.85, 'Time Elapsed': 126.6613118648529, 'F1': 0.7999999999999998, 'diff': 4.728593918590755, 'mse': 15.105515544973583, 'valid': 'yes', 'h_val': 7.562950843365002e-08, 'start mse': array(233.45203392)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_3.1782086308186415': {'SHD': 14, 'TPR': 0.775, 'Time Elapsed': 126.6613118648529, 'F1': 0.75609

## d = 20

In [11]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 20, 80, 'ER', 'gp-add' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [12]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (20,)


4.494665749754947

In [13]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_4.494665749754947': {'SHD': 57, 'TPR': 0.4, 'Time Elapsed': 213.7935333251953, 'F1': 0.512, 'diff': 9.549452203399568, 'mse': 9.070456274509697, 'valid': 'no', 'h_val': 2.6989195856332477e-06, 'start mse': array(45.23177469)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_4.494665749754947': {'SHD': 57, 'TPR': 0.4, 'Time Elapsed': 213.7935333251953, 'F1': 0.512, 'diff': 9.549452203399568, 'mse': 9.070456274509697, 'valid': 'no', 'h_val': 2.6989195856332477e-06, 'start mse': array(45.23177469)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_4.494665749754947': {'SHD': 80, 'TPR': 0.6375, 'Time Elapsed': 213.7935333251953, 'F1': 0.5454545454545454, 'diff': 8.34355358366529, 'mse': 9.070456274509697, 'valid': 'yes', 'h_val': 2.6989195856332477e-06, 'start mse': array(45.23177469)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_4.494665749754947': {'SHD': 82, 'TPR': 0.25, 'Time Elapsed': 213.7935333251953, 'F1': 0.3252032520325203, 'diff': 8.59191405450685

In [20]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 20, 80, 'ER', 'gp' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [21]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (20,)


4.494665749754947

In [22]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_4.494665749754947': {'SHD': 77, 'TPR': 0.0625, 'Time Elapsed': 259.99971771240234, 'F1': 0.11494252873563218, 'diff': 8.979311720307308, 'mse': 8.762290217358036, 'valid': 'no', 'h_val': 2.933543130921039e-06, 'start mse': array(16.65709448)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_4.494665749754947': {'SHD': 77, 'TPR': 0.0625, 'Time Elapsed': 259.99971771240234, 'F1': 0.11494252873563218, 'diff': 8.979311720307308, 'mse': 8.762290217358036, 'valid': 'no', 'h_val': 2.933543130921039e-06, 'start mse': array(16.65709448)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_4.494665749754947': {'SHD': 90, 'TPR': 0.35, 'Time Elapsed': 259.99971771240234, 'F1': 0.3544303797468354, 'diff': 8.714333146792647, 'mse': 8.762290217358036, 'valid': 'yes', 'h_val': 2.933543130921039e-06, 'start mse': array(16.65709448)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_4.494665749754947': {'SHD': 78, 'TPR': 0.075, 'Time Elapsed': 259.99971771240234, 'F1': 0.1304347

In [23]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 20, 80, 'ER', 'mlp' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [24]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (20,)


4.494665749754947

In [25]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_4.494665749754947': {'SHD': 89, 'TPR': 0.3125, 'Time Elapsed': 201.7218062877655, 'F1': 0.32258064516129037, 'diff': 11.573401030471501, 'mse': 56.189253245633864, 'valid': 'no', 'h_val': 3.308850372527302e-06, 'start mse': array(566.18164828)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_4.494665749754947': {'SHD': 89, 'TPR': 0.3125, 'Time Elapsed': 201.7218062877655, 'F1': 0.32258064516129037, 'diff': 11.573401030471501, 'mse': 56.189253245633864, 'valid': 'no', 'h_val': 3.308850372527302e-06, 'start mse': array(566.18164828)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_4.494665749754947': {'SHD': 134, 'TPR': 0.325, 'Time Elapsed': 201.7218062877655, 'F1': 0.23318385650224208, 'diff': 8.832931582726905, 'mse': 56.189253245633864, 'valid': 'yes', 'h_val': 3.308850372527302e-06, 'start mse': array(566.18164828)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_4.494665749754947': {'SHD': 120, 'TPR': 0.15, 'Time Elapsed': 201.7218062877655, 'F1': 0.

## d = 30

In [8]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 30, 120, 'ER', 'gp-add' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [9]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (30,)


5.504818825631803

In [10]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_5.504818825631803': {'SHD': 100, 'TPR': 0.31666666666666665, 'Time Elapsed': 657.9865639209747, 'F1': 0.4153005464480874, 'diff': 11.061193522749809, 'mse': 11.910339478139663, 'valid': 'no', 'h_val': 8.758266565100176e-06, 'start mse': array(82.66305331)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_5.504818825631803': {'SHD': 100, 'TPR': 0.31666666666666665, 'Time Elapsed': 657.9865639209747, 'F1': 0.4153005464480874, 'diff': 11.061193522749809, 'mse': 11.910339478139663, 'valid': 'no', 'h_val': 8.758266565100176e-06, 'start mse': array(82.66305331)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_5.504818825631803': {'SHD': 152, 'TPR': 0.16666666666666666, 'Time Elapsed': 657.9865639209747, 'F1': 0.2, 'diff': 10.815765689591244, 'mse': 11.910339478139663, 'valid': 'yes', 'h_val': 8.758266565100176e-06, 'start mse': array(82.66305331)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_5.504818825631803': {'SHD': 137, 'TPR': 0.08333333333333333, 'Time 

In [26]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 30, 120, 'ER', 'gp' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [27]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (30,)


5.504818825631802

In [28]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_5.504818825631802': {'SHD': 120, 'TPR': 0.11666666666666667, 'Time Elapsed': 623.3074667453766, 'F1': 0.18918918918918917, 'diff': 10.92158351096266, 'mse': 10.81834864435269, 'valid': 'no', 'h_val': 9.963814463609275e-06, 'start mse': array(28.51419061)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_5.504818825631802': {'SHD': 120, 'TPR': 0.11666666666666667, 'Time Elapsed': 623.3074667453766, 'F1': 0.18918918918918917, 'diff': 10.92158351096266, 'mse': 10.81834864435269, 'valid': 'no', 'h_val': 9.963814463609275e-06, 'start mse': array(28.51419061)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_5.504818825631802': {'SHD': 146, 'TPR': 0.10833333333333334, 'Time Elapsed': 623.3074667453766, 'F1': 0.143646408839779, 'diff': 10.87873481406377, 'mse': 10.81834864435269, 'valid': 'yes', 'h_val': 9.963814463609275e-06, 'start mse': array(28.51419061)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_5.504818825631802': {'SHD': 120, 'TPR': 0.016666666666666

In [29]:
utils.set_random_seed(0)
torch.manual_seed(0)
n, d, s0, graph_type, sem_type = 100, 30, 120, 'ER', 'mlp' # s0 = expected number of edges
B_true = utils.simulate_dag(d, s0, graph_type) # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type) # [n, d] sample matrix
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [30]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
variances = np.var(X_scaled, axis=0, ddof=1)
print("variance: ", variances.shape)
average_variance = np.mean(variances)
gamma = 1 / (d * average_variance)
gamma = np.sqrt(1/gamma)
gamma

variance:  (30,)


5.504818825631803

In [31]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma)
print(results)

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_5.504818825631803': {'SHD': 136, 'TPR': 0.38333333333333336, 'Time Elapsed': 658.035590171814, 'F1': 0.3739837398373984, 'diff': 14.821410204881557, 'mse': 133.10486720750913, 'valid': 'no', 'h_val': 1.89771960400576e-05, 'start mse': array(1327.67953455)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_5.504818825631803': {'SHD': 136, 'TPR': 0.38333333333333336, 'Time Elapsed': 658.035590171814, 'F1': 0.3739837398373984, 'diff': 14.821410204881557, 'mse': 133.10486720750913, 'valid': 'no', 'h_val': 1.89771960400576e-05, 'start mse': array(1327.67953455)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_5.504818825631803': {'SHD': 270, 'TPR': 0.19166666666666668, 'Time Elapsed': 658.035590171814, 'F1': 0.12885154061624654, 'diff': 10.918137717658894, 'mse': 133.10486720750913, 'valid': 'yes', 'h_val': 1.89771960400576e-05, 'start mse': array(1327.67953455)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_5.504818825631803': {'SHD': 198, 'TPR': 0.058333333

In [32]:
results = run_model(lambda1 = 1e-3, tau = 1e-4, X = X, device = device, B_true = B_true, gamma = gamma, T = 7)
print(results)

  0%|          | 0/38000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_5.504818825631803': {'SHD': 136, 'TPR': 0.38333333333333336, 'Time Elapsed': 761.8567404747009, 'F1': 0.3739837398373984, 'diff': 14.821410204881557, 'mse': 133.55946718635593, 'valid': 'no', 'h_val': 2.6893091759744644e-07, 'start mse': array(1327.67953455)}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_5.504818825631803': {'SHD': 136, 'TPR': 0.38333333333333336, 'Time Elapsed': 761.8567404747009, 'F1': 0.3739837398373984, 'diff': 14.821410204881557, 'mse': 133.55946718635593, 'valid': 'no', 'h_val': 2.6893091759744644e-07, 'start mse': array(1327.67953455)}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_5.504818825631803': {'SHD': 278, 'TPR': 0.19166666666666668, 'Time Elapsed': 761.8567404747009, 'F1': 0.1253405994550409, 'diff': 10.918919402254616, 'mse': 133.55946718635593, 'valid': 'yes', 'h_val': 2.6893091759744644e-07, 'start mse': array(1327.67953455)}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_5.504818825631803': {'SHD': 200, 'TPR': 0.0