In [1]:
import utils, RKHS_DAGMA_extractj
import torch
import time
import numpy as np
import matplotlib.pyplot as plt
from CausalDisco.analytics import r2_sortability, var_sortability
from CausalDisco.baselines import r2_sort_regress, var_sort_regress
from cdt.metrics import SID, SHD
import seaborn as sns
import cdt
import multiprocessing
import threading
import json

Detecting 1 CUDA device(s).


In [2]:
sns.set_context("paper")

# Work in double precision; tensors created without an explicit device go to `device`.
torch.set_default_dtype(torch.float64)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)

# Seeds are set explicitly in each data-generation cell below.
# utils.set_random_seed(1)
# torch.manual_seed(1)

# Rscript executable used by cdt's R-backed metrics (e.g. SID).
# A raw string is used because the previous literal relied on invalid escape
# sequences ('\P', '\R'), which emit SyntaxWarning on Python 3.12+;
# the resulting path value is identical.
cdt.SETTINGS.rpath = r'C:\Program Files\R\R-4.3.3\bin\Rscript'

num_cores = multiprocessing.cpu_count()

# Causal discovery with RKHS-DAGMA

In [3]:
def run_model(lambda1, tau, X, device, B_true, gamma, T=5, lr = 0.03,
              thresh_values=None, fallback_W=None):
    """Fit RKHS-DAGMA on ``X`` and score the learned graph at several edge thresholds.

    Parameters
    ----------
    lambda1, tau : float
        Passed unchanged to ``DagmaRKHS_nonlinear.fit``.
    X : torch.Tensor
        Sample matrix; moved to ``device`` before the model is built.
    device : torch.device or str
        Device used for ``X`` and the model.
    B_true : numpy.ndarray
        Ground-truth adjacency matrix that every estimate is scored against.
    gamma : float
        Forwarded to ``RKHS_DAGMA_extractj.DagmaRKHS`` as ``gamma``.
    T, lr :
        Forwarded to ``fit`` (``mu_init=1.0`` and ``w_threshold=0.0`` are fixed,
        so thresholding happens here instead).
    thresh_values : iterable of float, optional
        Thresholds applied to ``|W|`` before scoring. Defaults to
        ``[0.005, 0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]``.
    fallback_W : array-like, optional
        Matrix that is scored instead whenever ``utils.count_accuracy`` raises
        on a thresholded estimate. Defaults to the notebook-global
        ``results_var_sort_regress`` (the original behaviour); pass it
        explicitly to avoid depending on hidden kernel state.

    Returns
    -------
    dict
        Keyed by ``'lambda_{lambda1}_tau_{tau}_thresh_{thresh}_gamma_{gamma}'``.
        Each value holds 'SHD', 'TPR', 'Time Elapsed', 'F1', 'diff', 'mse',
        'valid' ("yes" if the thresholded estimate itself was scored, "no" if
        the fallback was scored), 'h_val' and 'start mse'.
    """
    if thresh_values is None:
        thresh_values = [0.005, 0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]

    results = {}

    X = X.to(device)
    eq_model = RKHS_DAGMA_extractj.DagmaRKHS(X, gamma = gamma).to(device)
    model = RKHS_DAGMA_extractj.DagmaRKHS_nonlinear(eq_model)

    # Reconstruction error of the untrained model, kept for reference.
    x_est_start = eq_model.forward()
    start_mse = eq_model.mse(x_est_start).detach().cpu().numpy()

    # perf_counter is monotonic, so the measured duration is unaffected by clock changes.
    time_start = time.perf_counter()
    W_est_no_thresh, _ = model.fit(X, lambda1=lambda1, tau=tau, T = T, mu_init = 1.0, lr=lr, w_threshold=0.0)
    elapsed = time.perf_counter() - time_start

    # Post-fit diagnostics do not depend on the threshold: compute them once
    # rather than once per threshold.
    x_est = eq_model.forward()
    mse_dagma = eq_model.mse(x_est).detach().cpu().numpy()
    W_est = eq_model.fc1_to_adj()
    h_val = eq_model.h_func(W_est, s=1).detach().cpu().numpy()

    abs_B_true = abs(B_true)
    for thresh in thresh_values:
        try:
            W_est_dagma = abs(W_est_no_thresh) * (abs(W_est_no_thresh) > thresh)
            acc_dagma = utils.count_accuracy(B_true, W_est_dagma != 0)
            valid = "yes"
        except Exception:
            # NOTE(review): presumably count_accuracy rejects a thresholded
            # estimate that is not a DAG — confirm in utils. Score the fallback instead.
            if fallback_W is None:
                fallback_W = results_var_sort_regress
            W_est_dagma = fallback_W
            acc_dagma = utils.count_accuracy(B_true, W_est_dagma != 0)
            valid = "no"

        diff_dagma = np.linalg.norm(W_est_dagma - abs_B_true)
        key = f'lambda_{lambda1}_tau_{tau}_thresh_{thresh}_gamma_{gamma}'
        results[key] = {
            'SHD': acc_dagma['shd'],
            'TPR': acc_dagma['tpr'],
            'Time Elapsed': elapsed,
            'F1': acc_dagma['f1'],
            'diff': diff_dagma,
            'mse': mse_dagma.item(),
            'valid': valid,
            'h_val': h_val.item(),
            'start mse': start_mse.item(),
        }

    # Release cached GPU memory between (long) runs.
    torch.cuda.empty_cache()

    return results

In [4]:
# Resolve the compute device again (same rule as the setup cell) and make it
# the default for newly created tensors.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = 'cpu'  # uncomment to force CPU execution
torch.set_default_device(device)
print('device:', device)

device: cuda


## d = 40 nodes (n = 100 samples, ER graph with s0 = 160 expected edges, gamma = 16)

In [5]:
# Fix the seeds (project helper, then torch) so the simulated data are reproducible.
utils.set_random_seed(0)
torch.manual_seed(0)

# n samples, d nodes, s0 expected number of edges; ER graph with 'gp-add' mechanisms.
n, d, s0 = 100, 40, 160
graph_type, sem_type = 'ER', 'gp-add'

B_true = utils.simulate_dag(d, s0, graph_type)          # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type)   # [n, d] sample matrix

# Var-sortability baseline on the NumPy data; run_model scores this global
# whenever utils.count_accuracy raises on a thresholded estimate.
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [6]:
results = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X,
    device=device,
    B_true=B_true,
    gamma=16,
    T=6,
    lr=0.03,
)
print(results)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_16': {'SHD': 144, 'TPR': 0.38125, 'Time Elapsed': 2419.8881788253784, 'F1': 0.4535315985130111, 'diff': 13.827468670889223, 'mse': 51.38136997431348, 'valid': 'no', 'h_val': 4.870244527114279e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_16': {'SHD': 144, 'TPR': 0.38125, 'Time Elapsed': 2419.8881788253784, 'F1': 0.4535315985130111, 'diff': 13.827468670889223, 'mse': 51.38136997431348, 'valid': 'no', 'h_val': 4.870244527114279e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_16': {'SHD': 176, 'TPR': 0.40625, 'Time Elapsed': 2419.8881788253784, 'F1': 0.40880503144654085, 'diff': 12.131661318435935, 'mse': 51.38136997431348, 'valid': 'yes', 'h_val': 4.870244527114279e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_16': {'SHD': 145, 'TPR': 0.19375, 'Time Elapsed': 2419.8881788253784, 'F1': 0.29523809523809524, 'diff': 12.314893379426465, 'mse': 51.

In [5]:
# Fix the seeds (project helper, then torch) so the simulated data are reproducible.
utils.set_random_seed(0)
torch.manual_seed(0)

# n samples, d nodes, s0 expected number of edges; ER graph with 'gp' mechanisms.
n, d, s0 = 100, 40, 160
graph_type, sem_type = 'ER', 'gp'

B_true = utils.simulate_dag(d, s0, graph_type)          # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type)   # [n, d] sample matrix

# Var-sortability baseline on the NumPy data; run_model scores this global
# whenever utils.count_accuracy raises on a thresholded estimate.
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [8]:
results = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X,
    device=device,
    B_true=B_true,
    gamma=16,
    T=6,
    lr=0.03,
)
print(results)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_16': {'SHD': 162, 'TPR': 0.04375, 'Time Elapsed': 3079.360290288925, 'F1': 0.07734806629834254, 'diff': 12.621631656561732, 'mse': 28.620857185472133, 'valid': 'no', 'h_val': 5.29215299455607e-06, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_16': {'SHD': 162, 'TPR': 0.04375, 'Time Elapsed': 3079.360290288925, 'F1': 0.07734806629834254, 'diff': 12.621631656561732, 'mse': 28.620857185472133, 'valid': 'no', 'h_val': 5.29215299455607e-06, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_16': {'SHD': 162, 'TPR': 0.04375, 'Time Elapsed': 3079.360290288925, 'F1': 0.07734806629834254, 'diff': 12.621631656561732, 'mse': 28.620857185472133, 'valid': 'no', 'h_val': 5.29215299455607e-06, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_16': {'SHD': 167, 'TPR': 0.04375, 'Time Elapsed': 3079.360290288925, 'F1': 0.0748663101604278, 'diff': 12.585532995259692, 'mse': 28.6208

In [6]:
results = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X,
    device=device,
    B_true=B_true,
    gamma=16,
    T=7,
    lr=0.03,
)
print(results)  # one entry per edge threshold

  0%|          | 0/38000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_16': {'SHD': 162, 'TPR': 0.04375, 'Time Elapsed': 3864.0099194049835, 'F1': 0.07734806629834254, 'diff': 12.621631656561732, 'mse': 29.98038585187362, 'valid': 'no', 'h_val': 4.051712796272634e-07, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_16': {'SHD': 162, 'TPR': 0.04375, 'Time Elapsed': 3864.0099194049835, 'F1': 0.07734806629834254, 'diff': 12.621631656561732, 'mse': 29.98038585187362, 'valid': 'no', 'h_val': 4.051712796272634e-07, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_16': {'SHD': 232, 'TPR': 0.18125, 'Time Elapsed': 3864.0099194049835, 'F1': 0.19141914191419143, 'diff': 12.486368093483742, 'mse': 29.98038585187362, 'valid': 'yes', 'h_val': 4.051712796272634e-07, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_16': {'SHD': 167, 'TPR': 0.04375, 'Time Elapsed': 3864.0099194049835, 'F1': 0.07526881720430108, 'diff': 12.587399705630412, 'mse': 2

In [11]:
# Fix the seeds (project helper, then torch) so the simulated data are reproducible.
utils.set_random_seed(1)
torch.manual_seed(1)

# n samples, d nodes, s0 expected number of edges; ER graph with 'gp' mechanisms.
n, d, s0 = 100, 40, 160
graph_type, sem_type = 'ER', 'gp'

B_true = utils.simulate_dag(d, s0, graph_type)          # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type)   # [n, d] sample matrix

# Var-sortability baseline on the NumPy data; run_model scores this global
# whenever utils.count_accuracy raises on a thresholded estimate.
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [12]:
results = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X,
    device=device,
    B_true=B_true,
    gamma=16,
    T=6,
    lr=0.03,
)
print(results)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_16': {'SHD': 157, 'TPR': 0.1125, 'Time Elapsed': 3297.2519915103912, 'F1': 0.18367346938775508, 'diff': 12.526856337014774, 'mse': 28.336569124081006, 'valid': 'no', 'h_val': 5.357851334662559e-06, 'start mse': 36.19370792903379}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_16': {'SHD': 157, 'TPR': 0.1125, 'Time Elapsed': 3297.2519915103912, 'F1': 0.18367346938775508, 'diff': 12.526856337014774, 'mse': 28.336569124081006, 'valid': 'no', 'h_val': 5.357851334662559e-06, 'start mse': 36.19370792903379}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_16': {'SHD': 157, 'TPR': 0.1125, 'Time Elapsed': 3297.2519915103912, 'F1': 0.18367346938775508, 'diff': 12.526856337014774, 'mse': 28.336569124081006, 'valid': 'no', 'h_val': 5.357851334662559e-06, 'start mse': 36.19370792903379}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_16': {'SHD': 171, 'TPR': 0.0625, 'Time Elapsed': 3297.2519915103912, 'F1': 0.10256410256410256, 'diff': 12.548830287230937, 'mse': 28.

In [7]:
# Fix the seeds (project helper, then torch) so the simulated data are reproducible.
utils.set_random_seed(0)
torch.manual_seed(0)

# n samples, d nodes, s0 expected number of edges; ER graph with 'mlp' mechanisms.
n, d, s0 = 100, 40, 160
graph_type, sem_type = 'ER', 'mlp'

B_true = utils.simulate_dag(d, s0, graph_type)          # [d, d] binary adj matrix of DAG
X = utils.simulate_nonlinear_sem(B_true, n, sem_type)   # [n, d] sample matrix

# Var-sortability baseline on the NumPy data; run_model scores this global
# whenever utils.count_accuracy raises on a thresholded estimate.
results_var_sort_regress = var_sort_regress(X)
X = torch.from_numpy(X)

In [10]:
results = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X,
    device=device,
    B_true=B_true,
    gamma=16,
    T=6,
    lr=0.03,
)
print(results)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_16': {'SHD': 204, 'TPR': 0.29375, 'Time Elapsed': 2637.6562778949738, 'F1': 0.2883435582822086, 'diff': 14.586934967975731, 'mse': 411.5820018463442, 'valid': 'no', 'h_val': 0.0013344427467220782, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_16': {'SHD': 204, 'TPR': 0.29375, 'Time Elapsed': 2637.6562778949738, 'F1': 0.2883435582822086, 'diff': 14.586934967975731, 'mse': 411.5820018463442, 'valid': 'no', 'h_val': 0.0013344427467220782, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_16': {'SHD': 204, 'TPR': 0.29375, 'Time Elapsed': 2637.6562778949738, 'F1': 0.2883435582822086, 'diff': 14.586934967975731, 'mse': 411.5820018463442, 'valid': 'no', 'h_val': 0.0013344427467220782, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_16': {'SHD': 209, 'TPR': 0.0375, 'Time Elapsed': 2637.6562778949738, 'F1': 0.051063829787234026, 'diff': 12.65570140137631, 'mse': 411.58

In [8]:
results = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X,
    device=device,
    B_true=B_true,
    gamma=16,
    T=7,
    lr=0.03,
)
print(results)  # one entry per edge threshold

  0%|          | 0/38000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_16': {'SHD': 204, 'TPR': 0.29375, 'Time Elapsed': 3416.2624838352203, 'F1': 0.2883435582822086, 'diff': 14.586934967975731, 'mse': 697.3262170204313, 'valid': 'no', 'h_val': 8.645324390456802e-05, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_16': {'SHD': 204, 'TPR': 0.29375, 'Time Elapsed': 3416.2624838352203, 'F1': 0.2883435582822086, 'diff': 14.586934967975731, 'mse': 697.3262170204313, 'valid': 'no', 'h_val': 8.645324390456802e-05, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_16': {'SHD': 250, 'TPR': 0.08125, 'Time Elapsed': 3416.2624838352203, 'F1': 0.087248322147651, 'diff': 12.660664810661958, 'mse': 697.3262170204313, 'valid': 'yes', 'h_val': 8.645324390456802e-05, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_16': {'SHD': 198, 'TPR': 0.00625, 'Time Elapsed': 3416.2624838352203, 'F1': 0.009615384615384621, 'diff': 12.707480076635907, 'mse': 697.

## d = 40 (other test): varying gamma — 5, 15 and 20 on 'gp-add' data; 15 on 'gp' and 'mlp' data

In [5]:
# Fix the seeds (project helper, then torch) so the simulated data are reproducible.
utils.set_random_seed(0)
torch.manual_seed(0)

# n samples, d nodes, s0 expected number of edges; ER graph with 'gp-add' mechanisms.
n, d, s0 = 100, 40, 160
graph_type, sem_type = 'ER', 'gp-add'

B_true_40 = utils.simulate_dag(d, s0, graph_type)              # [d, d] binary adj matrix of DAG
X_40 = utils.simulate_nonlinear_sem(B_true_40, n, sem_type)    # [n, d] sample matrix

# Var-sortability baseline on the NumPy data; run_model scores this global
# whenever utils.count_accuracy raises on a thresholded estimate.
results_var_sort_regress = var_sort_regress(X_40)
X_40 = torch.from_numpy(X_40)

In [6]:
result_40 = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X_40,
    device=device,
    B_true=B_true_40,
    gamma=5,
    T=6,
    lr=0.03,
)
print(result_40)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_5': {'SHD': 144, 'TPR': 0.38125, 'Time Elapsed': 1763.4229617118835, 'F1': 0.4535315985130111, 'diff': 13.827468670889223, 'mse': 13.235900923376636, 'valid': 'no', 'h_val': 9.408637880365182e-07, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_5': {'SHD': 144, 'TPR': 0.38125, 'Time Elapsed': 1763.4229617118835, 'F1': 0.4535315985130111, 'diff': 13.827468670889223, 'mse': 13.235900923376636, 'valid': 'no', 'h_val': 9.408637880365182e-07, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_5': {'SHD': 160, 'TPR': 0.0, 'Time Elapsed': 1763.4229617118835, 'F1': 0.0, 'diff': 12.649110640673518, 'mse': 13.235900923376636, 'valid': 'yes', 'h_val': 9.408637880365182e-07, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_5': {'SHD': 160, 'TPR': 0.0, 'Time Elapsed': 1763.4229617118835, 'F1': 0.0, 'diff': 12.649110640673518, 'mse': 13.235900923376636, 'valid': 'yes', 'h_val':

In [6]:
result_40 = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X_40,
    device=device,
    B_true=B_true_40,
    gamma=15,
    T=6,
    lr=0.03,
)
print(result_40)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_15': {'SHD': 144, 'TPR': 0.38125, 'Time Elapsed': 2314.790427684784, 'F1': 0.4535315985130111, 'diff': 13.827468670889223, 'mse': 51.11033842013492, 'valid': 'no', 'h_val': 5.1950516501837694e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_15': {'SHD': 144, 'TPR': 0.38125, 'Time Elapsed': 2314.790427684784, 'F1': 0.4535315985130111, 'diff': 13.827468670889223, 'mse': 51.11033842013492, 'valid': 'no', 'h_val': 5.1950516501837694e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_15': {'SHD': 175, 'TPR': 0.39375, 'Time Elapsed': 2314.790427684784, 'F1': 0.4064516129032259, 'diff': 12.15993615139483, 'mse': 51.11033842013492, 'valid': 'yes', 'h_val': 5.1950516501837694e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_15': {'SHD': 146, 'TPR': 0.19375, 'Time Elapsed': 2314.790427684784, 'F1': 0.29523809523809524, 'diff': 12.32744396162458, 'mse': 51.1103

In [8]:
result_40 = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X_40,
    device=device,
    B_true=B_true_40,
    gamma=20,
    T=6,
    lr=0.03,
)
print(result_40)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_20': {'SHD': 144, 'TPR': 0.38125, 'Time Elapsed': 3085.1008954048157, 'F1': 0.4535315985130111, 'diff': 13.827468670889223, 'mse': 52.898395498180946, 'valid': 'no', 'h_val': 3.431714412942995e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_20': {'SHD': 144, 'TPR': 0.38125, 'Time Elapsed': 3085.1008954048157, 'F1': 0.4535315985130111, 'diff': 13.827468670889223, 'mse': 52.898395498180946, 'valid': 'no', 'h_val': 3.431714412942995e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_20': {'SHD': 185, 'TPR': 0.40625, 'Time Elapsed': 3085.1008954048157, 'F1': 0.393939393939394, 'diff': 12.141875467122972, 'mse': 52.898395498180946, 'valid': 'yes', 'h_val': 3.431714412942995e-05, 'start mse': 124.6053332243056}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_20': {'SHD': 148, 'TPR': 0.20625, 'Time Elapsed': 3085.1008954048157, 'F1': 0.3013698630136986, 'diff': 12.305376604218946, 'mse': 52.

In [6]:
# Fix the seeds (project helper, then torch) so the simulated data are reproducible.
utils.set_random_seed(0)
torch.manual_seed(0)

# n samples, d nodes, s0 expected number of edges; ER graph with 'gp' mechanisms.
n, d, s0 = 100, 40, 160
graph_type, sem_type = 'ER', 'gp'

B_true_40 = utils.simulate_dag(d, s0, graph_type)              # [d, d] binary adj matrix of DAG
X_40 = utils.simulate_nonlinear_sem(B_true_40, n, sem_type)    # [n, d] sample matrix

# Var-sortability baseline on the NumPy data; run_model scores this global
# whenever utils.count_accuracy raises on a thresholded estimate.
results_var_sort_regress = var_sort_regress(X_40)
X_40 = torch.from_numpy(X_40)

In [7]:
result_40 = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X_40,
    device=device,
    B_true=B_true_40,
    gamma=15,
    T=6,
    lr=0.03,
)
print(result_40)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_15': {'SHD': 162, 'TPR': 0.04375, 'Time Elapsed': 3081.624889612198, 'F1': 0.07734806629834254, 'diff': 12.621631656561732, 'mse': 28.196594407773702, 'valid': 'no', 'h_val': 5.817910550851333e-06, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_15': {'SHD': 162, 'TPR': 0.04375, 'Time Elapsed': 3081.624889612198, 'F1': 0.07734806629834254, 'diff': 12.621631656561732, 'mse': 28.196594407773702, 'valid': 'no', 'h_val': 5.817910550851333e-06, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_15': {'SHD': 258, 'TPR': 0.1875, 'Time Elapsed': 3081.624889612198, 'F1': 0.17441860465116277, 'diff': 12.457318004375743, 'mse': 28.196594407773702, 'valid': 'yes', 'h_val': 5.817910550851333e-06, 'start mse': 37.04060189095214}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_15': {'SHD': 166, 'TPR': 0.05625, 'Time Elapsed': 3081.624889612198, 'F1': 0.09473684210526316, 'diff': 12.551448615451518, 'mse': 28.

In [8]:
# Fix the seeds (project helper, then torch) so the simulated data are reproducible.
utils.set_random_seed(0)
torch.manual_seed(0)

# n samples, d nodes, s0 expected number of edges; ER graph with 'mlp' mechanisms.
n, d, s0 = 100, 40, 160
graph_type, sem_type = 'ER', 'mlp'

B_true_40 = utils.simulate_dag(d, s0, graph_type)              # [d, d] binary adj matrix of DAG
X_40 = utils.simulate_nonlinear_sem(B_true_40, n, sem_type)    # [n, d] sample matrix

# Var-sortability baseline on the NumPy data; run_model scores this global
# whenever utils.count_accuracy raises on a thresholded estimate.
results_var_sort_regress = var_sort_regress(X_40)
X_40 = torch.from_numpy(X_40)

In [9]:
result_40 = run_model(
    lambda1=1e-3,
    tau=1e-4,
    X=X_40,
    device=device,
    B_true=B_true_40,
    gamma=15,
    T=6,
    lr=0.03,
)
print(result_40)  # one entry per edge threshold

  0%|          | 0/33000.0 [00:00<?, ?it/s]

{'lambda_0.001_tau_0.0001_thresh_0.005_gamma_15': {'SHD': 204, 'TPR': 0.29375, 'Time Elapsed': 2639.4664521217346, 'F1': 0.2883435582822086, 'diff': 14.586934967975731, 'mse': 450.6736743203465, 'valid': 'no', 'h_val': 0.0012789279337924034, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.01_gamma_15': {'SHD': 204, 'TPR': 0.29375, 'Time Elapsed': 2639.4664521217346, 'F1': 0.2883435582822086, 'diff': 14.586934967975731, 'mse': 450.6736743203465, 'valid': 'no', 'h_val': 0.0012789279337924034, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.05_gamma_15': {'SHD': 204, 'TPR': 0.29375, 'Time Elapsed': 2639.4664521217346, 'F1': 0.2883435582822086, 'diff': 14.586934967975731, 'mse': 450.6736743203465, 'valid': 'no', 'h_val': 0.0012789279337924034, 'start mse': 1608.782437108537}, 'lambda_0.001_tau_0.0001_thresh_0.1_gamma_15': {'SHD': 211, 'TPR': 0.0375, 'Time Elapsed': 2639.4664521217346, 'F1': 0.05084745762711865, 'diff': 12.663227923719894, 'mse': 450.67