In [1]:
%cd ~/gymnasium-test/
# %load_ext line_profiler
# %load_ext memory_profiler
# %load_ext autoreload
# %autoreload 2

/home/n.saumik/gymnasium-test


In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["MPLCONFIGDIR"] = "/tmp/"
import tensorflow as tf

tf.config.set_visible_devices([], "GPU")
tf.get_logger().setLevel("ERROR")

import numpy as np
from tqdm import tqdm

from tree import Tree, make_model, Forest, EarlyStopper
np.set_printoptions(linewidth=200)

import pandas as pd

importing numpy
importing tensorflow
done with imports


2024-11-09 02:06:57.089312: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [3]:
threshes = [0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95, 0.99, 1]

def get_diff(x, steps, nodes=32, layers=2, seeds=10, verbose=False):
    
    x = int(x/4)
    
    all_pcts = []
    
    for seed in tqdm(range(seeds), desc=f"{x},{steps},{nodes},{layers}", disable=verbose):
    
        low = Forest(4, "crit-9-10-0-0.2", x*seed, x*(seed+1))
        low_X, low_Y = low.get_training_data("A", window=0)
        low_Y[:] = 0

        high = Forest(4, "crit-9-10-0-0.2", x*(seed+1), x*(seed+2))
        high_X, high_Y = high.get_training_data("D", window=0)
        high_Y[:] = 1

        train_X = tf.convert_to_tensor(np.concatenate([low_X, high_X]))
        train_Y = tf.convert_to_tensor(np.concatenate([low_Y, high_Y]))

        model = make_model(4, 0, num_nodes=nodes, num_layers=layers)

        ds = tf.data.Dataset.from_tensor_slices((train_X, train_Y))
        ds = ds.repeat().shuffle(100000).batch(32).take(steps)

        acc = model.fit(ds, verbose=verbose).history["accuracy"][0]
        low_pred = model.predict(low_X)
        pcts = [(low_pred[:,1] <= t).mean().round(3) for t in threshes]
        
        all_pcts.append(pcts)
    pcts = np.array(all_pcts)
    return pcts

    

In [8]:
results = {}

In [9]:
%%time
for data_size in [2**6, 2**8, 2**10, 2**12, 2**14, 2**16]:
    for nodes, layers in [(16,1), (32,2), (64,3), (128,4)]:
        for steps in [2**4, 2**8, 2**12]:
            pcts = get_diff(data_size, steps, nodes, layers, seeds=20, verbose=False)
            results[data_size, steps, nodes, layers] = pcts.mean(axis=0)
        for steps in [2**16]:
            pcts = get_diff(data_size, steps, nodes, layers, seeds=2, verbose=True)
            results[data_size, steps, nodes, layers] = pcts.mean(axis=0)
        print("*"*50)
    print("$"*50)
    print("*"*50)

16,16,16,1: 100%|██████████| 20/20 [00:13<00:00,  1.46it/s]
16,256,16,1: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s]
16,4096,16,1: 100%|██████████| 20/20 [01:03<00:00,  3.17s/it]


**************************************************


16,16,32,2: 100%|██████████| 20/20 [00:14<00:00,  1.42it/s]
16,256,32,2: 100%|██████████| 20/20 [00:17<00:00,  1.16it/s]
16,4096,32,2: 100%|██████████| 20/20 [01:07<00:00,  3.37s/it]


**************************************************


16,16,64,3: 100%|██████████| 20/20 [00:14<00:00,  1.38it/s]
16,256,64,3: 100%|██████████| 20/20 [00:18<00:00,  1.06it/s]
16,4096,64,3: 100%|██████████| 20/20 [01:14<00:00,  3.70s/it]


**************************************************


16,16,128,4: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s]
16,256,128,4: 100%|██████████| 20/20 [00:20<00:00,  1.02s/it]
16,4096,128,4: 100%|██████████| 20/20 [01:39<00:00,  4.96s/it]


**************************************************
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
**************************************************


64,16,16,1: 100%|██████████| 20/20 [00:13<00:00,  1.47it/s]
64,256,16,1: 100%|██████████| 20/20 [00:16<00:00,  1.21it/s]
64,4096,16,1: 100%|██████████| 20/20 [01:02<00:00,  3.13s/it]


**************************************************


64,16,32,2: 100%|██████████| 20/20 [00:14<00:00,  1.39it/s]
64,256,32,2: 100%|██████████| 20/20 [00:17<00:00,  1.14it/s]
64,4096,32,2: 100%|██████████| 20/20 [01:06<00:00,  3.34s/it]


**************************************************


64,16,64,3: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s]
64,256,64,3: 100%|██████████| 20/20 [00:18<00:00,  1.06it/s]
64,4096,64,3: 100%|██████████| 20/20 [01:14<00:00,  3.71s/it]


**************************************************


64,16,128,4: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s]
64,256,128,4: 100%|██████████| 20/20 [00:20<00:00,  1.03s/it]
64,4096,128,4: 100%|██████████| 20/20 [01:37<00:00,  4.87s/it]


**************************************************
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
**************************************************


256,16,16,1: 100%|██████████| 20/20 [00:15<00:00,  1.31it/s]
256,256,16,1: 100%|██████████| 20/20 [00:17<00:00,  1.12it/s]
256,4096,16,1: 100%|██████████| 20/20 [01:03<00:00,  3.20s/it]


**************************************************


256,16,32,2: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s]
256,256,32,2: 100%|██████████| 20/20 [00:19<00:00,  1.05it/s]
256,4096,32,2: 100%|██████████| 20/20 [01:08<00:00,  3.43s/it]


**************************************************


256,16,64,3: 100%|██████████| 20/20 [00:16<00:00,  1.22it/s]
256,256,64,3: 100%|██████████| 20/20 [00:20<00:00,  1.01s/it]
256,4096,64,3: 100%|██████████| 20/20 [01:14<00:00,  3.74s/it]


**************************************************


256,16,128,4: 100%|██████████| 20/20 [00:17<00:00,  1.16it/s]
256,256,128,4: 100%|██████████| 20/20 [00:22<00:00,  1.10s/it]
256,4096,128,4: 100%|██████████| 20/20 [01:39<00:00,  4.95s/it]


**************************************************
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
**************************************************


1024,16,16,1: 100%|██████████| 20/20 [00:19<00:00,  1.01it/s]
1024,256,16,1: 100%|██████████| 20/20 [00:22<00:00,  1.13s/it]
1024,4096,16,1: 100%|██████████| 20/20 [01:08<00:00,  3.43s/it]


**************************************************


1024,16,32,2: 100%|██████████| 20/20 [00:20<00:00,  1.04s/it]
1024,256,32,2: 100%|██████████| 20/20 [00:23<00:00,  1.20s/it]
1024,4096,32,2: 100%|██████████| 20/20 [01:12<00:00,  3.64s/it]


**************************************************


1024,16,64,3: 100%|██████████| 20/20 [00:21<00:00,  1.07s/it]
1024,256,64,3: 100%|██████████| 20/20 [00:25<00:00,  1.26s/it]
1024,4096,64,3: 100%|██████████| 20/20 [01:22<00:00,  4.12s/it]


**************************************************


1024,16,128,4: 100%|██████████| 20/20 [00:22<00:00,  1.13s/it]
1024,256,128,4: 100%|██████████| 20/20 [00:27<00:00,  1.40s/it]
1024,4096,128,4: 100%|██████████| 20/20 [01:44<00:00,  5.21s/it]


**************************************************
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
**************************************************


4096,16,16,1: 100%|██████████| 20/20 [00:38<00:00,  1.92s/it]
4096,256,16,1: 100%|██████████| 20/20 [00:41<00:00,  2.09s/it]
4096,4096,16,1: 100%|██████████| 20/20 [01:28<00:00,  4.42s/it]


**************************************************


4096,16,32,2: 100%|██████████| 20/20 [00:39<00:00,  1.97s/it]
4096,256,32,2: 100%|██████████| 20/20 [00:43<00:00,  2.19s/it]
4096,4096,32,2: 100%|██████████| 20/20 [01:32<00:00,  4.60s/it]


**************************************************


4096,16,64,3: 100%|██████████| 20/20 [00:40<00:00,  2.01s/it]
4096,256,64,3: 100%|██████████| 20/20 [00:45<00:00,  2.26s/it]
4096,4096,64,3: 100%|██████████| 20/20 [01:41<00:00,  5.07s/it]


**************************************************


4096,16,128,4: 100%|██████████| 20/20 [00:42<00:00,  2.11s/it]
4096,256,128,4: 100%|██████████| 20/20 [00:47<00:00,  2.38s/it]
4096,4096,128,4: 100%|██████████| 20/20 [02:04<00:00,  6.22s/it]


**************************************************
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
**************************************************


16384,16,16,1: 100%|██████████| 20/20 [01:52<00:00,  5.61s/it]
16384,256,16,1: 100%|██████████| 20/20 [01:55<00:00,  5.78s/it]
16384,4096,16,1: 100%|██████████| 20/20 [02:42<00:00,  8.13s/it]


**************************************************


16384,16,32,2: 100%|██████████| 20/20 [01:55<00:00,  5.79s/it]
16384,256,32,2: 100%|██████████| 20/20 [01:56<00:00,  5.85s/it]
16384,4096,32,2: 100%|██████████| 20/20 [02:48<00:00,  8.41s/it]


**************************************************


16384,16,64,3: 100%|██████████| 20/20 [01:56<00:00,  5.83s/it]
16384,256,64,3: 100%|██████████| 20/20 [02:00<00:00,  6.05s/it]
16384,4096,64,3: 100%|██████████| 20/20 [02:58<00:00,  8.91s/it]


**************************************************


16384,16,128,4: 100%|██████████| 20/20 [02:01<00:00,  6.08s/it]
16384,256,128,4: 100%|██████████| 20/20 [02:06<00:00,  6.33s/it]
16384,4096,128,4: 100%|██████████| 20/20 [03:25<00:00, 10.28s/it]


**************************************************
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
**************************************************
CPU times: user 2h 20min 44s, sys: 16min 23s, total: 2h 37min 8s
Wall time: 1h 54min 44s


In [10]:
df = []
for k,v in results.items():
    # print(f"{k[0]};{k[1]};{k[2]};{k[3]};{';'.join(v.round(3).astype(str))}")
    df.append(list(k) + list(v))
df = pd.DataFrame(df, columns=["x", "steps", "nodes", "layers"] + threshes)

In [21]:
# df.groupby(['steps', 'x']).apply(display)
df.round(3).groupby(['nodes', 'layers', 'steps']).apply(display)

Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
0,64,16,16,1,0.0,0.061,0.074,0.11,0.154,0.23,0.608,0.936,0.967,0.98,0.989,1.0
16,256,16,16,1,0.0,0.034,0.045,0.06,0.099,0.172,0.536,0.879,0.93,0.955,0.981,1.0
32,1024,16,16,1,0.0,0.038,0.053,0.079,0.105,0.14,0.482,0.841,0.887,0.927,0.971,1.0
48,4096,16,16,1,0.0,0.03,0.046,0.072,0.105,0.16,0.587,0.874,0.918,0.943,0.976,1.0
64,16384,16,16,1,0.0,0.029,0.04,0.069,0.104,0.143,0.484,0.851,0.894,0.924,0.969,1.0
80,65536,16,16,1,0.0,0.037,0.048,0.073,0.099,0.144,0.477,0.833,0.882,0.912,0.958,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
1,64,256,16,1,0.0,0.131,0.201,0.322,0.364,0.399,0.806,0.999,1.0,1.0,1.0,1.0
17,256,256,16,1,0.0,0.022,0.036,0.062,0.096,0.172,0.658,0.995,1.0,1.0,1.0,1.0
33,1024,256,16,1,0.0,0.001,0.003,0.01,0.021,0.057,0.577,0.985,0.995,0.998,0.999,1.0
49,4096,256,16,1,0.0,0.001,0.002,0.005,0.011,0.027,0.58,0.983,0.994,0.998,1.0,1.0
65,16384,256,16,1,0.0,0.001,0.001,0.005,0.013,0.031,0.735,0.987,0.995,0.998,1.0,1.0
81,65536,256,16,1,0.0,0.0,0.001,0.004,0.01,0.027,0.695,0.986,0.995,0.998,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
2,64,4096,16,1,0.0,0.444,0.49,0.576,0.698,0.867,0.993,1.0,1.0,1.0,1.0,1.0
18,256,4096,16,1,0.0,0.162,0.189,0.227,0.257,0.353,0.817,0.995,1.0,1.0,1.0,1.0
34,1024,4096,16,1,0.0,0.012,0.021,0.044,0.07,0.114,0.633,0.995,1.0,1.0,1.0,1.0
50,4096,4096,16,1,0.0,0.0,0.0,0.001,0.003,0.017,0.621,0.998,1.0,1.0,1.0,1.0
66,16384,4096,16,1,0.0,0.0,0.0,0.0,0.0,0.002,0.74,1.0,1.0,1.0,1.0,1.0
82,65536,4096,16,1,0.0,0.0,0.0,0.0,0.0,0.0,0.953,1.0,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
3,64,65536,16,1,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
19,256,65536,16,1,0.042,0.734,0.844,0.956,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
35,1024,65536,16,1,0.0,0.147,0.186,0.262,0.346,0.485,0.809,0.954,0.979,0.992,0.998,1.0
51,4096,65536,16,1,0.0,0.017,0.024,0.038,0.069,0.141,0.602,0.987,1.0,1.0,1.0,1.0
67,16384,65536,16,1,0.0,0.0,0.0,0.002,0.027,0.088,0.614,0.998,1.0,1.0,1.0,1.0
83,65536,65536,16,1,0.0,0.0,0.0,0.002,0.016,0.088,0.841,1.0,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
4,64,16,32,2,0.0,0.007,0.017,0.044,0.1,0.2,0.616,0.982,0.997,1.0,1.0,1.0
20,256,16,32,2,0.0,0.003,0.006,0.019,0.045,0.098,0.512,0.938,0.977,0.993,1.0,1.0
36,1024,16,32,2,0.0,0.004,0.009,0.022,0.043,0.091,0.496,0.945,0.98,0.989,0.999,1.0
52,4096,16,32,2,0.0,0.004,0.007,0.019,0.037,0.086,0.553,0.919,0.968,0.986,0.998,1.0
68,16384,16,32,2,0.0,0.002,0.005,0.012,0.029,0.07,0.501,0.933,0.973,0.988,0.998,1.0
84,65536,16,32,2,0.0,0.002,0.005,0.016,0.035,0.078,0.472,0.935,0.973,0.987,0.998,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
5,64,256,32,2,0.0,0.108,0.217,0.332,0.392,0.405,0.848,1.0,1.0,1.0,1.0,1.0
21,256,256,32,2,0.0,0.007,0.018,0.058,0.108,0.2,0.694,0.996,1.0,1.0,1.0,1.0
37,1024,256,32,2,0.0,0.001,0.001,0.005,0.013,0.041,0.607,0.995,0.999,1.0,1.0,1.0
53,4096,256,32,2,0.0,0.0,0.0,0.002,0.005,0.018,0.612,0.992,0.998,0.999,1.0,1.0
69,16384,256,32,2,0.0,0.0,0.0,0.001,0.003,0.015,0.787,0.995,0.999,1.0,1.0,1.0
85,65536,256,32,2,0.0,0.0,0.0,0.001,0.003,0.015,0.781,0.995,0.999,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
6,64,4096,32,2,0.0,0.65,0.744,0.91,0.976,0.997,1.0,1.0,1.0,1.0,1.0,1.0
22,256,4096,32,2,0.0,0.205,0.234,0.282,0.363,0.514,0.877,0.995,1.0,1.0,1.0,1.0
38,1024,4096,32,2,0.0,0.03,0.052,0.082,0.103,0.149,0.676,0.995,1.0,1.0,1.0,1.0
54,4096,4096,32,2,0.0,0.0,0.001,0.004,0.012,0.035,0.563,0.997,1.0,1.0,1.0,1.0
70,16384,4096,32,2,0.0,0.0,0.0,0.0,0.0,0.005,0.701,1.0,1.0,1.0,1.0,1.0
86,65536,4096,32,2,0.0,0.0,0.0,0.0,0.0,0.002,0.963,1.0,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
7,64,65536,32,2,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
23,256,65536,32,2,0.01,0.998,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
39,1024,65536,32,2,0.016,0.649,0.72,0.839,0.93,0.984,0.998,1.0,1.0,1.0,1.0,1.0
55,4096,65536,32,2,0.0,0.081,0.112,0.168,0.226,0.331,0.715,0.97,0.996,1.0,1.0,1.0
71,16384,65536,32,2,0.0,0.006,0.018,0.069,0.121,0.172,0.76,0.988,0.999,1.0,1.0,1.0
87,65536,65536,32,2,0.0,0.002,0.016,0.062,0.108,0.138,0.79,0.996,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
8,64,16,64,3,0.0,0.0,0.0,0.016,0.03,0.111,0.688,1.0,1.0,1.0,1.0,1.0
24,256,16,64,3,0.0,0.0,0.001,0.002,0.008,0.041,0.598,0.988,0.998,1.0,1.0,1.0
40,1024,16,64,3,0.0,0.0,0.0,0.001,0.003,0.017,0.538,0.975,0.996,0.999,1.0,1.0
56,4096,16,64,3,0.0,0.0,0.0,0.001,0.005,0.024,0.52,0.978,0.996,1.0,1.0,1.0
72,16384,16,64,3,0.0,0.0,0.0,0.002,0.006,0.022,0.5,0.981,0.994,0.998,1.0,1.0
88,65536,16,64,3,0.0,0.0,0.0,0.001,0.004,0.024,0.486,0.977,0.996,0.999,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
9,64,256,64,3,0.0,0.142,0.236,0.353,0.397,0.407,0.869,1.0,1.0,1.0,1.0,1.0
25,256,256,64,3,0.0,0.003,0.009,0.05,0.129,0.219,0.701,0.998,1.0,1.0,1.0,1.0
41,1024,256,64,3,0.0,0.0,0.0,0.001,0.008,0.037,0.632,0.998,1.0,1.0,1.0,1.0
57,4096,256,64,3,0.0,0.0,0.0,0.0,0.001,0.007,0.626,0.999,1.0,1.0,1.0,1.0
73,16384,256,64,3,0.0,0.0,0.0,0.0,0.0,0.005,0.819,0.999,1.0,1.0,1.0,1.0
89,65536,256,64,3,0.0,0.0,0.0,0.0,0.0,0.006,0.851,0.999,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
10,64,4096,64,3,0.006,0.89,0.966,0.998,0.999,1.0,1.0,1.0,1.0,1.0,1.0,1.0
26,256,4096,64,3,0.0,0.331,0.403,0.513,0.636,0.786,0.957,0.994,0.999,1.0,1.0,1.0
42,1024,4096,64,3,0.0,0.061,0.082,0.106,0.121,0.191,0.705,0.991,1.0,1.0,1.0,1.0
58,4096,4096,64,3,0.0,0.001,0.002,0.01,0.027,0.056,0.646,0.998,1.0,1.0,1.0,1.0
74,16384,4096,64,3,0.0,0.0,0.0,0.0,0.0,0.006,0.735,1.0,1.0,1.0,1.0,1.0
90,65536,4096,64,3,0.0,0.0,0.0,0.0,0.0,0.002,0.97,1.0,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
11,64,65536,64,3,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
27,256,65536,64,3,0.034,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
43,1024,65536,64,3,0.006,0.98,0.991,0.998,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
59,4096,65536,64,3,0.0,0.167,0.207,0.284,0.369,0.488,0.745,0.936,0.976,0.992,1.0,1.0
75,16384,65536,64,3,0.0,0.016,0.048,0.104,0.141,0.204,0.693,0.974,0.994,0.998,1.0,1.0
91,65536,65536,64,3,0.0,0.054,0.085,0.114,0.134,0.182,0.818,0.993,0.999,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
12,64,16,128,4,0.0,0.0,0.0,0.0,0.0,0.035,0.696,1.0,1.0,1.0,1.0,1.0
28,256,16,128,4,0.0,0.0,0.0,0.0,0.0,0.007,0.578,0.998,1.0,1.0,1.0,1.0
44,1024,16,128,4,0.0,0.0,0.0,0.0,0.0,0.002,0.614,0.999,1.0,1.0,1.0,1.0
60,4096,16,128,4,0.0,0.0,0.0,0.0,0.0,0.002,0.606,0.999,1.0,1.0,1.0,1.0
76,16384,16,128,4,0.0,0.0,0.0,0.0,0.0,0.003,0.441,0.997,1.0,1.0,1.0,1.0
92,65536,16,128,4,0.0,0.0,0.0,0.0,0.0,0.002,0.446,0.998,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
13,64,256,128,4,0.0,0.123,0.239,0.368,0.396,0.4,0.874,1.0,1.0,1.0,1.0,1.0
29,256,256,128,4,0.0,0.0,0.003,0.051,0.123,0.22,0.71,0.997,1.0,1.0,1.0,1.0
45,1024,256,128,4,0.0,0.0,0.0,0.0,0.002,0.028,0.586,0.999,1.0,1.0,1.0,1.0
61,4096,256,128,4,0.0,0.0,0.0,0.0,0.0,0.002,0.6,1.0,1.0,1.0,1.0,1.0
77,16384,256,128,4,0.0,0.0,0.0,0.0,0.0,0.001,0.79,1.0,1.0,1.0,1.0,1.0
93,65536,256,128,4,0.0,0.0,0.0,0.0,0.0,0.001,0.852,1.0,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
14,64,4096,128,4,0.0,0.986,0.998,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
30,256,4096,128,4,0.0,0.565,0.677,0.841,0.931,0.977,0.994,0.998,0.999,1.0,1.0,1.0
46,1024,4096,128,4,0.0,0.076,0.095,0.119,0.152,0.23,0.716,0.98,0.996,1.0,1.0,1.0
62,4096,4096,128,4,0.0,0.001,0.004,0.017,0.033,0.061,0.641,0.998,1.0,1.0,1.0,1.0
78,16384,4096,128,4,0.0,0.0,0.0,0.0,0.0,0.008,0.786,1.0,1.0,1.0,1.0,1.0
94,65536,4096,128,4,0.0,0.0,0.0,0.0,0.0,0.001,0.974,1.0,1.0,1.0,1.0,1.0


Unnamed: 0,x,steps,nodes,layers,0,0.01,0.02,0.05,0.1,0.2,0.5,0.8,0.9,0.95,0.99,1
15,64,65536,128,4,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
31,256,65536,128,4,0.016,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
47,1024,65536,128,4,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
63,4096,65536,128,4,0.0,0.636,0.694,0.766,0.817,0.866,0.948,0.988,0.996,0.998,1.0,1.0
79,16384,65536,128,4,0.0,0.086,0.116,0.152,0.194,0.275,0.707,0.978,0.996,0.999,1.0,1.0
95,65536,65536,128,4,0.0,0.078,0.104,0.134,0.158,0.213,0.81,0.992,0.999,1.0,1.0,1.0


  df.round(3).groupby(['nodes', 'layers', 'steps']).apply(display)


In [19]:
df[df.nodes == 64].pivot(index='x', columns='steps', values=0.2)

steps,16,256,4096,65536
x,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
64,0.1114,0.40715,1.0,1.0
256,0.0414,0.2193,0.7859,1.0
1024,0.01735,0.0369,0.19115,1.0
4096,0.0241,0.0073,0.0555,0.488
16384,0.02225,0.00485,0.0061,0.204
65536,0.02435,0.00615,0.0017,0.1825


In [20]:
df[df.nodes == 128].pivot(index='x', columns='steps', values=0.2)

steps,16,256,4096,65536
x,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
64,0.0354,0.39985,1.0,1.0
256,0.00735,0.22045,0.9767,1.0
1024,0.00225,0.0278,0.23015,1.0
4096,0.0022,0.0025,0.0613,0.8665
16384,0.0031,0.00085,0.0078,0.275
65536,0.0025,0.00065,0.0013,0.213


In [5]:
%%time
get_diff(64, 2**4)
get_diff(64, 2**8)
get_diff(64, 2**12)
get_diff(64, 2**16)
get_diff(64, 2**20)

Steps 16: 0.498


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.031,
 0.2: 0.047,
 0.5: 0.745,
 0.8: 0.974,
 0.9: 0.995,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 256: 0.631


{0: 0.0,
 0.01: 0.026,
 0.02: 0.047,
 0.05: 0.047,
 0.1: 0.151,
 0.2: 0.234,
 0.5: 0.458,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 4096: 0.814


{0: 0.0,
 0.01: 0.333,
 0.02: 0.344,
 0.05: 0.349,
 0.1: 0.365,
 0.2: 0.547,
 0.5: 0.906,
 0.8: 0.99,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 65536: 0.986


{0: 0.016,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 1048576: 0.999


{0: 0.094,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


CPU times: user 15min 39s, sys: 59.7 s, total: 16min 38s
Wall time: 12min 13s


In [6]:
%%time
get_diff(256, 2**4)
get_diff(256, 2**8)
get_diff(256, 2**12)
get_diff(256, 2**16)
get_diff(256, 2**20)

Steps 16: 0.463


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.003,
 0.2: 0.046,
 0.5: 0.602,
 0.8: 0.944,
 0.9: 0.983,
 0.95: 0.996,
 0.99: 1.0,
 1: 1.0}


Steps 256: 0.568


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.008,
 0.2: 0.031,
 0.5: 0.561,
 0.8: 0.987,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 4096: 0.653


{0: 0.0,
 0.01: 0.025,
 0.02: 0.059,
 0.05: 0.087,
 0.1: 0.117,
 0.2: 0.133,
 0.5: 0.628,
 0.8: 0.98,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 65536: 0.899


{0: 0.023,
 0.01: 0.725,
 0.02: 0.807,
 0.05: 0.906,
 0.1: 0.978,
 0.2: 0.997,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 1048576: 0.993


{0: 0.035,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


CPU times: user 15min 15s, sys: 54.8 s, total: 16min 10s
Wall time: 11min 52s


In [7]:
%%time
get_diff(1024, 2**4)
get_diff(1024, 2**8)
get_diff(1024, 2**12)
get_diff(1024, 2**16)
get_diff(1024, 2**20)

Steps 16: 0.467


{0: 0.0,
 0.01: 0.01,
 0.02: 0.013,
 0.05: 0.03,
 0.1: 0.092,
 0.2: 0.159,
 0.5: 0.252,
 0.8: 0.834,
 0.9: 0.96,
 0.95: 0.978,
 0.99: 0.994,
 1: 1.0}


Steps 256: 0.518


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.004,
 0.5: 0.537,
 0.8: 0.996,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 4096: 0.560


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.002,
 0.1: 0.012,
 0.2: 0.021,
 0.5: 0.626,
 0.8: 0.998,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 65536: 0.686


{0: 0.0,
 0.01: 0.078,
 0.02: 0.109,
 0.05: 0.167,
 0.1: 0.226,
 0.2: 0.332,
 0.5: 0.71,
 0.8: 0.98,
 0.9: 0.998,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 1048576: 0.901


{0: 0.009,
 0.01: 0.666,
 0.02: 0.715,
 0.05: 0.786,
 0.1: 0.84,
 0.2: 0.891,
 0.5: 0.94,
 0.8: 0.966,
 0.9: 0.973,
 0.95: 0.979,
 0.99: 0.988,
 1: 1.0}


CPU times: user 15min 15s, sys: 53.6 s, total: 16min 9s
Wall time: 11min 54s


In [8]:
%%time
get_diff(4096, 2**4)
get_diff(4096, 2**8)
get_diff(4096, 2**12)
get_diff(4096, 2**16)
get_diff(4096, 2**20)

Steps 16: 0.486


{0: 0.0,
 0.01: 0.002,
 0.02: 0.004,
 0.05: 0.008,
 0.1: 0.013,
 0.2: 0.039,
 0.5: 0.427,
 0.8: 0.923,
 0.9: 0.979,
 0.95: 0.987,
 0.99: 1.0,
 1: 1.0}


Steps 256: 0.507


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.01,
 0.5: 0.788,
 0.8: 0.996,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 4096: 0.530


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.604,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 65536: 0.607


{0: 0.0,
 0.01: 0.002,
 0.02: 0.008,
 0.05: 0.043,
 0.1: 0.084,
 0.2: 0.139,
 0.5: 0.658,
 0.8: 0.971,
 0.9: 0.998,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 1048576: 0.696


{0: 0.0,
 0.01: 0.108,
 0.02: 0.132,
 0.05: 0.171,
 0.1: 0.222,
 0.2: 0.321,
 0.5: 0.732,
 0.8: 0.977,
 0.9: 0.995,
 0.95: 0.999,
 0.99: 1.0,
 1: 1.0}


CPU times: user 15min 24s, sys: 54.4 s, total: 16min 18s
Wall time: 12min 2s


In [9]:
%%time
get_diff(16384, 2**4)
get_diff(16384, 2**8)
get_diff(16384, 2**12)
get_diff(16384, 2**16)
get_diff(16384, 2**20)

Steps 16: 0.477


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.001,
 0.2: 0.02,
 0.5: 0.515,
 0.8: 0.935,
 0.9: 0.972,
 0.95: 0.994,
 0.99: 1.0,
 1: 1.0}


Steps 256: 0.504


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.001,
 0.2: 0.01,
 0.5: 0.671,
 0.8: 0.987,
 0.9: 0.995,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 4096: 0.572


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.003,
 0.5: 0.989,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 65536: 0.612


{0: 0.0,
 0.01: 0.001,
 0.02: 0.013,
 0.05: 0.058,
 0.1: 0.104,
 0.2: 0.132,
 0.5: 0.73,
 0.8: 0.993,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}


Steps 1048576: 0.657


{0: 0.0,
 0.01: 0.091,
 0.02: 0.116,
 0.05: 0.141,
 0.1: 0.162,
 0.2: 0.205,
 0.5: 0.724,
 0.8: 0.994,
 0.9: 0.998,
 0.95: 0.999,
 0.99: 1.0,
 1: 1.0}


CPU times: user 15min 50s, sys: 54.2 s, total: 16min 45s
Wall time: 12min 26s


In [10]:
%%time
for num_nodes, num_layers in [(64,4)]:
    for train_size in [2**6, 2**8, 2**10, 2**12, 2**14, 2**16, 2**18]:
        for discrim_step in [2**4, 2**8, 2**12, 2**16]:
            print(num_nodes, num_layers, train_size, discrim_step)
            get_diff(train_size, discrim_step, num_nodes, num_layers)
            print()
        print('*'*50)
        print()

64 4 64 16
Steps 16: 0.516


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.536,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 64 256
Steps 256: 0.639


{0: 0.0,
 0.01: 0.0,
 0.02: 0.005,
 0.05: 0.031,
 0.1: 0.057,
 0.2: 0.224,
 0.5: 0.568,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 64 4096
Steps 4096: 0.886


{0: 0.0,
 0.01: 0.661,
 0.02: 0.76,
 0.05: 0.885,
 0.1: 0.958,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 64 65536
Steps 65536: 0.994


{0: 0.099,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

64 4 256 16
Steps 16: 0.541


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.003,
 0.5: 0.227,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 256 256
Steps 256: 0.564


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.008,
 0.5: 0.663,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 256 4096
Steps 4096: 0.692


{0: 0.0,
 0.01: 0.125,
 0.02: 0.129,
 0.05: 0.143,
 0.1: 0.163,
 0.2: 0.285,
 0.5: 0.831,
 0.8: 0.993,
 0.9: 0.999,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 256 65536
Steps 65536: 0.953


{0: 0.0,
 0.01: 0.999,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

64 4 1024 16
Steps 16: 0.531


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.003,
 0.5: 0.674,
 0.8: 0.992,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 1024 256
Steps 256: 0.523


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.003,
 0.5: 0.903,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 1024 4096
Steps 4096: 0.571


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.007,
 0.1: 0.014,
 0.2: 0.041,
 0.5: 0.739,
 0.8: 0.999,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 1024 65536
Steps 65536: 0.774


{0: 0.001,
 0.01: 0.446,
 0.02: 0.519,
 0.05: 0.636,
 0.1: 0.729,
 0.2: 0.821,
 0.5: 0.939,
 0.8: 0.986,
 0.9: 0.995,
 0.95: 0.997,
 0.99: 0.999,
 1: 1.0}



**************************************************

64 4 4096 16
Steps 16: 0.494


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.166,
 0.8: 0.911,
 0.9: 0.982,
 0.95: 0.995,
 0.99: 1.0,
 1: 1.0}



64 4 4096 256
Steps 256: 0.530


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.921,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 4096 4096
Steps 4096: 0.546


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.005,
 0.5: 0.775,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 4096 65536
Steps 65536: 0.638


{0: 0.0,
 0.01: 0.035,
 0.02: 0.072,
 0.05: 0.122,
 0.1: 0.168,
 0.2: 0.252,
 0.5: 0.753,
 0.8: 0.985,
 0.9: 0.999,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

64 4 16384 16
Steps 16: 0.479


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.004,
 0.5: 0.343,
 0.8: 0.988,
 0.9: 0.998,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 16384 256
Steps 256: 0.515


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.002,
 0.5: 0.973,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 16384 4096
Steps 4096: 0.582


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.003,
 0.5: 0.988,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 16384 65536
Steps 65536: 0.627


{0: 0.0,
 0.01: 0.076,
 0.02: 0.106,
 0.05: 0.135,
 0.1: 0.163,
 0.2: 0.217,
 0.5: 0.816,
 0.8: 0.997,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

64 4 65536 16
Steps 16: 1.000


{0: 0.0,
 0.01: 0.103,
 0.02: 0.153,
 0.05: 0.319,
 0.1: 0.418,
 0.2: 0.52,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 65536 256
Steps 256: 1.000


{0: 0.0,
 0.01: 0.525,
 0.02: 0.633,
 0.05: 0.979,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 65536 4096
Steps 4096: 0.959


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.025,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 65536 65536
Steps 65536: 0.751


{0: 0.0,
 0.01: 0.039,
 0.02: 0.064,
 0.05: 0.104,
 0.1: 0.127,
 0.2: 0.178,
 0.5: 0.912,
 0.8: 0.998,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

64 4 262144 16
Steps 16: 0.980


{0: 0.0,
 0.01: 0.035,
 0.02: 0.067,
 0.05: 0.224,
 0.1: 0.353,
 0.2: 0.498,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 262144 256
Steps 256: 0.994


{0: 0.0,
 0.01: 0.53,
 0.02: 0.708,
 0.05: 0.992,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 262144 4096
Steps 4096: 1.000


{0: 0.0,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



64 4 262144 65536
Steps 65536: 0.934


{0: 0.0,
 0.01: 0.999,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

CPU times: user 15min 44s, sys: 54.8 s, total: 16min 39s
Wall time: 13min 15s


In [11]:
%%time
for num_nodes, num_layers in [(128,6)]:
    for train_size in [2**6, 2**8, 2**10, 2**12, 2**14, 2**16, 2**18]:
        for discrim_step in [2**4, 2**8, 2**12, 2**16]:
            print(num_nodes, num_layers, train_size, discrim_step)
            get_diff(train_size, discrim_step, num_nodes, num_layers)
            print()
        print('*'*50)
        print()

128 6 64 16
Steps 16: 0.574


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.531,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 64 256
Steps 256: 0.686


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.016,
 0.1: 0.193,
 0.2: 0.286,
 0.5: 0.589,
 0.8: 0.969,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 64 4096
Steps 4096: 0.919


{0: 0.0,
 0.01: 0.984,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 64 65536
Steps 65536: 0.995


{0: 0.0,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

128 6 256 16
Steps 16: 0.531


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.996,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 256 256
Steps 256: 0.587


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.009,
 0.5: 0.415,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 256 4096
Steps 4096: 0.685


{0: 0.0,
 0.01: 0.108,
 0.02: 0.13,
 0.05: 0.143,
 0.1: 0.176,
 0.2: 0.266,
 0.5: 0.688,
 0.8: 0.965,
 0.9: 0.996,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 256 65536
Steps 65536: 0.965


{0: 0.0,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

128 6 1024 16
Steps 16: 0.525


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.635,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 1024 256
Steps 256: 0.513


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.8,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 1024 4096
Steps 4096: 0.584


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.017,
 0.1: 0.021,
 0.2: 0.039,
 0.5: 0.442,
 0.8: 0.995,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 1024 65536
Steps 65536: 0.837


{0: 0.0,
 0.01: 0.714,
 0.02: 0.759,
 0.05: 0.81,
 0.1: 0.847,
 0.2: 0.891,
 0.5: 0.95,
 0.8: 0.982,
 0.9: 0.993,
 0.95: 0.997,
 0.99: 1.0,
 1: 1.0}



**************************************************

128 6 4096 16
Steps 16: 0.514


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.456,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 4096 256
Steps 256: 0.526


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.905,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 4096 4096
Steps 4096: 0.547


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.003,
 0.5: 0.842,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 4096 65536
Steps 65536: 0.665


{0: 0.0,
 0.01: 0.131,
 0.02: 0.162,
 0.05: 0.216,
 0.1: 0.282,
 0.2: 0.389,
 0.5: 0.74,
 0.8: 0.981,
 0.9: 0.997,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

128 6 16384 16
Steps 16: 0.479


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.319,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 16384 256
Steps 256: 0.528


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.771,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 16384 4096
Steps 4096: 0.581


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.0,
 0.5: 0.86,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 16384 65536
Steps 65536: 0.633


{0: 0.0,
 0.01: 0.081,
 0.02: 0.115,
 0.05: 0.149,
 0.1: 0.173,
 0.2: 0.215,
 0.5: 0.907,
 0.8: 0.999,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

128 6 65536 16
Steps 16: 1.000


{0: 0.0,
 0.01: 0.009,
 0.02: 0.04,
 0.05: 0.193,
 0.1: 0.392,
 0.2: 0.523,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 65536 256
Steps 256: 0.999


{0: 0.0,
 0.01: 0.664,
 0.02: 0.927,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 65536 4096
Steps 4096: 0.959


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.0,
 0.2: 0.001,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 65536 65536
Steps 65536: 0.754


{0: 0.0,
 0.01: 0.092,
 0.02: 0.114,
 0.05: 0.136,
 0.1: 0.157,
 0.2: 0.196,
 0.5: 0.941,
 0.8: 0.995,
 0.9: 0.999,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

128 6 262144 16
Steps 16: 0.736


{0: 0.0,
 0.01: 0.0,
 0.02: 0.0,
 0.05: 0.0,
 0.1: 0.002,
 0.2: 0.054,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 262144 256
Steps 256: 0.994


{0: 0.0,
 0.01: 0.605,
 0.02: 0.879,
 0.05: 0.999,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 262144 4096
Steps 4096: 0.999


{0: 0.0,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



128 6 262144 65536
Steps 65536: 0.934


{0: 0.0,
 0.01: 1.0,
 0.02: 1.0,
 0.05: 1.0,
 0.1: 1.0,
 0.2: 1.0,
 0.5: 1.0,
 0.8: 1.0,
 0.9: 1.0,
 0.95: 1.0,
 0.99: 1.0,
 1: 1.0}



**************************************************

CPU times: user 22min 30s, sys: 1min 3s, total: 23min 34s
Wall time: 17min 36s
