In [1]:
import sys
from pathlib import Path

reporoot = Path(".").absolute().parent


In [2]:
sys.path.append(str(reporoot))


In [3]:
from c2st.check import c2st_


In [4]:
help(c2st_)

Help on function c2st_ in module c2st.check:

c2st_(X: numpy.ndarray, Y: numpy.ndarray, seed: int = 1, n_folds: int = 5, scoring: str = 'accuracy', z_score: bool = True, noise_scale: Optional[float] = None, verbosity: int = 0, clf_class=<class 'sklearn.ensemble._forest.RandomForestClassifier'>, clf_kwargs={}) -> numpy.ndarray
    Return accuracy of classifier trained to distinguish samples from supposedly
    two distributions <X> and <Y>. For details on the method, see [1,2].
    If the returned accuracy is 0.5, <X> and <Y> are considered to be from the
    same generating PDF, i.e. they can not be differentiated.
    If the returned accuracy is around 1., <X> and <Y> are considered to be from
    two different generating PDFs.
    
    Trains classifiers with N-fold cross-validation [3]. By default, a `RandomForestClassifier`
    by scikit-learn is used. This can be adopted using <clf_class> and
    <clf_kwargs> as in:
    
    ``` py
    clf = clf_class(random_state=seed, **clf_kwar

In [5]:
from __future__ import annotations
import numpy as np
from functools import partial
from numpy.random import default_rng
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn import __version__ as sklversion
import time
FIXEDSEED = 1309

In [6]:
def nn_c2st_(
    X: np.ndarray,
    Y: np.ndarray,
    seed: int = FIXEDSEED,
    n_folds: int = 5,
    scoring: str = "accuracy",
    z_score: bool = True,
    noise_scale: Optional[float] = None,
    verbosity: int = 0,
) -> np.ndarray:

    ndim = X.shape[1]
    clf_class = MLPClassifier
    clf_kwargs = {
        "activation": "relu",
        "hidden_layer_sizes": (10 * ndim, 10 * ndim),
        "max_iter": 1000,
        "solver": "adam",
    }

    return c2st_(
        X,
        Y,
        seed,
        n_folds,
        scoring,
        z_score,
        noise_scale,
        verbosity,
        clf_class,
        clf_kwargs,
    )

def early_c2st_(
    X: np.ndarray,
    Y: np.ndarray,
    seed: int = FIXEDSEED,
    n_folds: int = 5,
    scoring: str = "accuracy",
    z_score: bool = True,
    noise_scale: Optional[float] = None,
    verbosity: int = 0,
) -> np.ndarray:

    ndim = X.shape[1]
    clf_class = MLPClassifier
    clf_kwargs = {
        "activation": "relu",
        "hidden_layer_sizes": (10 * ndim, 10 * ndim),
        "max_iter": 1000,
        "solver": "adam",
        "early_stopping": True,
        "n_iter_no_change": 50,
    }

    return c2st_(
        X,
        Y,
        seed,
        n_folds,
        scoring,
        z_score,
        noise_scale,
        verbosity,
        clf_class,
        clf_kwargs,
    )

In [7]:
NDIM = 10
max_nsamples = 8096
sample_sizes = [ 2**it for it in range(7,12)]
sample_sizes.append(max_nsamples)
RNG = default_rng(FIXEDSEED)
print(sample_sizes)

[128, 256, 512, 1024, 2048, 8096]


In [8]:
center_normal = partial(RNG.multivariate_normal, mean=np.zeros(NDIM), cov=np.eye(NDIM))
distributions = { 0. : center_normal}

for alpha in np.linspace(0,2,9):
    distributions[alpha] = partial(RNG.multivariate_normal, mean=np.zeros(NDIM) + alpha, cov=np.eye(NDIM))



In [9]:
center_samples = center_normal(size=max_nsamples)
samples = {}
for k,v in distributions.items():
    samples[k] = v(size=max_nsamples)

assert samples[0.].shape == (8096, 10)

In [10]:
rf_results = {}
rf_timings = {}
total = len(samples.values())*len(sample_sizes)
cnt = 0


for k,v in samples.items():
    for size in sample_sizes:

        try:
            start = time.time()
            scores = c2st_(X = center_samples[:size,...],
                          Y = v[:size,...],
                          n_folds=10)
            end = time.time()
        except Exception as ex:
            print(ex)
            continue

        if not k in rf_results.keys():
            rf_results[k] = []
            rf_timings[k] = []

        mean = np.mean(scores)
        std = np.std(scores)
        rf_results[k].append(list(scores))
        rf_timings[k].append(len(scores)*[end-start])
        cnt += 1

        print(f"{cnt}/{total}: {k}[{size},...] = {mean,std} ({end-start} seconds)")

1/54: 0.0[128,...] = (0.42153846153846153, 0.12246385772832794) (1.0709571838378906 seconds)


2/54: 0.0[256,...] = (0.47835595776772244, 0.07806781384146226) (1.415243148803711 seconds)


3/54: 0.0[512,...] = (0.5009232819341329, 0.048259547167832775) (2.2146804332733154 seconds)


4/54: 0.0[1024,...] = (0.48974414155906265, 0.019876207761461956) (4.505118370056152 seconds)


5/54: 0.0[2048,...] = (0.5007263402707377, 0.01581535273212949) (9.713266134262085 seconds)


6/54: 0.0[8096,...] = (0.5003095189074188, 0.007796416628318716) (51.849912881851196 seconds)


7/54: 0.25[128,...] = (0.6201538461538462, 0.0922985892947385) (1.0183866024017334 seconds)


8/54: 0.25[256,...] = (0.5820889894419305, 0.053563623821549874) (1.343022346496582 seconds)


9/54: 0.25[512,...] = (0.5878926327812678, 0.044028780012037166) (2.0739011764526367 seconds)


10/54: 0.25[1024,...] = (0.6147321855571497, 0.03028625429776028) (3.8184967041015625 seconds)


11/54: 0.25[2048,...] = (0.6310996481602957, 0.02729351243745489) (7.2080979347229 seconds)


12/54: 0.25[8096,...] = (0.6420456919756898, 0.009978095785152703) (33.08326601982117 seconds)


13/54: 0.5[128,...] = (0.7461538461538462, 0.084270626326962) (1.048398494720459 seconds)


14/54: 0.5[256,...] = (0.7130844645550527, 0.055577278557293136) (1.311382532119751 seconds)


15/54: 0.5[512,...] = (0.7509423186750428, 0.045301667986281394) (1.9235589504241943 seconds)


16/54: 0.5[1024,...] = (0.7641535150645624, 0.014418230248470484) (3.3507187366485596 seconds)


17/54: 0.5[2048,...] = (0.7678221718647504, 0.019321606896966624) (6.692022323608398 seconds)


18/54: 0.5[8096,...] = (0.7739629324609765, 0.008799769004857676) (29.567869186401367 seconds)


19/54: 0.75[128,...] = (0.8515384615384616, 0.038110707622700454) (0.9769423007965088 seconds)


20/54: 0.75[256,...] = (0.8377828054298643, 0.038485486289529275) (1.221602201461792 seconds)


21/54: 0.75[512,...] = (0.8476584808680755, 0.032474086886125006) (1.8697459697723389 seconds)


22/54: 0.75[1024,...] = (0.8510616929698708, 0.01759884234274294) (3.1699037551879883 seconds)


23/54: 0.75[2048,...] = (0.8647468543145089, 0.01774062606967589) (6.287116765975952 seconds)


24/54: 0.75[8096,...] = (0.8708621386467794, 0.004054615967776685) (27.999122619628906 seconds)


25/54: 1.0[128,...] = (0.8984615384615384, 0.06938785597071738) (0.986522912979126 seconds)


26/54: 1.0[256,...] = (0.9219457013574661, 0.03888361550904267) (1.204362392425537 seconds)


27/54: 1.0[512,...] = (0.9208833047782219, 0.030780055386088787) (1.7639496326446533 seconds)


28/54: 1.0[1024,...] = (0.9311573409851744, 0.0168358066396484) (3.020636796951294 seconds)


29/54: 1.0[2048,...] = (0.9350593356789313, 0.01492591725480785) (5.9438629150390625 seconds)


30/54: 1.0[8096,...] = (0.9344123792311976, 0.003614051898303775) (29.04834747314453 seconds)


31/54: 1.25[128,...] = (0.9452307692307693, 0.04008983993822111) (0.9557645320892334 seconds)


32/54: 1.25[256,...] = (0.9530165912518853, 0.02941717860216238) (1.2233014106750488 seconds)


33/54: 1.25[512,...] = (0.9551208833047783, 0.027305324953413746) (1.927248239517212 seconds)


34/54: 1.25[1024,...] = (0.9604519368723098, 0.012053293529227824) (3.3602335453033447 seconds)


35/54: 1.25[2048,...] = (0.9667940843222613, 0.0070194217423293044) (5.902421236038208 seconds)


36/54: 1.25[8096,...] = (0.9725166807738355, 0.003931870642098374) (29.058518886566162 seconds)


37/54: 1.5[128,...] = (0.9763076923076923, 0.026353985360447817) (0.9288384914398193 seconds)


38/54: 1.5[256,...] = (0.9805429864253394, 0.021343315778351402) (1.1442861557006836 seconds)


39/54: 1.5[512,...] = (0.9824195697696554, 0.009602337614476528) (1.6496150493621826 seconds)


40/54: 1.5[1024,...] = (0.9843758967001435, 0.008942707086403262) (2.934046983718872 seconds)


41/54: 1.5[2048,...] = (0.9855978293279264, 0.00710716215025746) (5.90256142616272 seconds)


42/54: 1.5[8096,...] = (0.9891923073990194, 0.002652581628563902) (29.3986074924469 seconds)


43/54: 1.75[128,...] = (0.9883076923076924, 0.024799312836442926) (0.9422547817230225 seconds)


44/54: 1.75[256,...] = (0.9922322775263952, 0.00951422551306799) (1.1106212139129639 seconds)


45/54: 1.75[512,...] = (0.9941557205406435, 0.0078004425129916686) (1.618781328201294 seconds)


46/54: 1.75[1024,...] = (0.9941391678622669, 0.004256450194349834) (2.792975664138794 seconds)


47/54: 1.75[2048,...] = (0.9958488878287317, 0.0024549937141941465) (5.707763195037842 seconds)


48/54: 1.75[8096,...] = (0.9958003721242346, 0.0018280116523806229) (29.352715015411377 seconds)


49/54: 2.0[128,...] = (1.0, 0.0) (0.9355366230010986 seconds)


50/54: 2.0[256,...] = (0.998076923076923, 0.0057692307692307826) (1.1295642852783203 seconds)


51/54: 2.0[512,...] = (0.9980487340567296, 0.00390258992495368) (1.59157395362854 seconds)


52/54: 2.0[1024,...] = (0.9985365853658535, 0.002235402778027241) (2.721519708633423 seconds)


53/54: 2.0[2048,...] = (0.998533603673445, 0.0019556928463501473) (5.600362300872803 seconds)


54/54: 2.0[8096,...] = (0.9986412508864639, 0.0006052481616375465) (29.05195903778076 seconds)


In [11]:
header = "ndims,mode,dist_sigma,nsamples,c2st_score,crossvalid,total_cvtime_sec,nfolds,sklearn_version"
with open("rf_results.csv","w") as ocsv:
    ocsv.write(header+"\n")
    for k in samples.keys():
        for idx in range(len(sample_sizes)):
            scores = rf_results[k][idx]
            timings = rf_timings[k][idx]
            for cvid in range(len(timings)):
                score = scores[cvid]
                timing = timings[cvid]
                row = f"{NDIM},rf,{k},{sample_sizes[idx]},{score},{cvid},{timing},10,{sklversion}\n"
                ocsv.write(row)

In [18]:
!head -n25 rf_results.csv|column -t -s","

ndims  mode  dist_sigma  nsamples  c2st_score           crossvalid  total_cvtime_sec    nfolds  sklearn_version
10     rf    0.0         128       0.6153846153846154   0           1.0709571838378906  10      1.0.2
10     rf    0.0         128       0.34615384615384615  1           1.0709571838378906  10      1.0.2
10     rf    0.0         128       0.34615384615384615  2           1.0709571838378906  10      1.0.2
10     rf    0.0         128       0.6153846153846154   3           1.0709571838378906  10      1.0.2
10     rf    0.0         128       0.4230769230769231   4           1.0709571838378906  10      1.0.2
10     rf    0.0         128       0.2692307692307692   5           1.0709571838378906  10      1.0.2
10     rf    0.0         128       0.48                 6           1.0709571838378906  10      1.0.2
10     rf    0.0         128       0.48                 7           1.0709571838378906  10      1.0.2
10     rf    0.0         128       0.24                 8      

In [19]:
!tail -n25 rf_results.csv|column -t -s","

10  rf  2.0  1024  1.0                 5  2.721519708633423  10  1.0.2
10  rf  2.0  1024  1.0                 6  2.721519708633423  10  1.0.2
10  rf  2.0  1024  1.0                 7  2.721519708633423  10  1.0.2
10  rf  2.0  1024  1.0                 8  2.721519708633423  10  1.0.2
10  rf  2.0  1024  1.0                 9  2.721519708633423  10  1.0.2
10  rf  2.0  2048  1.0                 0  5.600362300872803  10  1.0.2
10  rf  2.0  2048  1.0                 1  5.600362300872803  10  1.0.2
10  rf  2.0  2048  0.9975609756097561  2  5.600362300872803  10  1.0.2
10  rf  2.0  2048  1.0                 3  5.600362300872803  10  1.0.2
10  rf  2.0  2048  1.0                 4  5.600362300872803  10  1.0.2
10  rf  2.0  2048  1.0                 5  5.600362300872803  10  1.0.2
10  rf  2.0  2048  0.9951100244498777  6  5.600362300872803  10  1.0.2
10  rf  2.0  2048  0.9951100244498777  7  5.600362300872803  10  1.0.2
10  rf  2.0  2048  1.0                 8  5.600362300872803  10 

In [23]:
early_results = {}
early_timings = {}
cnt = 0

for k,v in samples.items():
    for size in sample_sizes:

        try:
            start = time.time()
            scores = early_c2st_(center_samples[:size,...],
                                 v[:size,...],
                                 n_folds=10)
            end = time.time()
        except Exception as ex:
            print(ex)
            continue
        
        if not k in early_results.keys():
            early_results[k] = []
            early_timings[k] = []

        mean = np.mean(scores)
        std = np.std(scores)
        early_results[k].append(list(scores))
        early_timings[k].append(len(scores)*[end-start])
        # early_results[k].append((mean,std))
        # early_timings[k].append(end-start)
        cnt += 1

        print(f"{cnt}/{total}: {k}[{size},...] = {mean,std} ({end-start} seconds)")

1/54: 0.0[128,...] = (0.44984615384615384, 0.09478271632322777) (2.912928819656372 seconds)


2/54: 0.0[256,...] = (0.44528657616892914, 0.04583470991937402) (3.3204894065856934 seconds)


3/54: 0.0[512,...] = (0.46774224252807917, 0.045399203720113135) (6.249772787094116 seconds)


4/54: 0.0[1024,...] = (0.48537063605930175, 0.0309927203690838) (16.076107501983643 seconds)


5/54: 0.0[2048,...] = (0.5024312719899815, 0.02263994456097309) (24.020041942596436 seconds)


6/54: 0.0[8096,...] = (0.5069763762115007, 0.016434380431142717) (127.88017725944519 seconds)


7/54: 0.25[128,...] = (0.6095384615384615, 0.07421637963359005) (4.231306314468384 seconds)


8/54: 0.25[256,...] = (0.5975867269984916, 0.07653110436111268) (4.737743854522705 seconds)


9/54: 0.25[512,...] = (0.6102798400913765, 0.05881663203928149) (10.663694620132446 seconds)


10/54: 0.25[1024,...] = (0.6352295552367287, 0.03471375039508743) (15.909212350845337 seconds)


11/54: 0.25[2048,...] = (0.6367052298884848, 0.024157368858210623) (33.589600801467896 seconds)


12/54: 0.25[8096,...] = (0.6516185879105377, 0.008131449011014185) (116.77839612960815 seconds)


13/54: 0.5[128,...] = (0.7584615384615384, 0.06926153162696888) (5.854063272476196 seconds)


14/54: 0.5[256,...] = (0.7556184012066365, 0.051403472283029175) (10.744370222091675 seconds)


15/54: 0.5[512,...] = (0.7607272035027604, 0.038030000871389164) (8.065553665161133 seconds)


16/54: 0.5[1024,...] = (0.7778240076518412, 0.029306090268237012) (24.016499042510986 seconds)


17/54: 0.5[2048,...] = (0.7790381060289822, 0.017149559549045536) (35.47059869766235 seconds)


18/54: 0.5[8096,...] = (0.7824840055208595, 0.010050876573596214) (184.20869851112366 seconds)


19/54: 0.75[128,...] = (0.8715384615384615, 0.07286609263313974) (3.6640777587890625 seconds)


20/54: 0.75[256,...] = (0.8377073906485671, 0.07969155268451598) (7.17254376411438 seconds)


21/54: 0.75[512,...] = (0.8671711403007807, 0.0442401376973219) (14.189313650131226 seconds)


22/54: 0.75[1024,...] = (0.8701243424198948, 0.017998024416236505) (26.421959161758423 seconds)


23/54: 0.75[2048,...] = (0.8696231140795516, 0.01356402924201838) (54.42694807052612 seconds)


24/54: 0.75[8096,...] = (0.8778416413118905, 0.007610421458650434) (185.84786534309387 seconds)


25/54: 1.0[128,...] = (0.9299999999999999, 0.04487062269531246) (5.73030686378479 seconds)


26/54: 1.0[256,...] = (0.9237933634992459, 0.039700868708114415) (5.774042844772339 seconds)


27/54: 1.0[512,...] = (0.9267846944603084, 0.015755695102608564) (7.967698097229004 seconds)


28/54: 1.0[1024,...] = (0.9316355810616928, 0.012560645808229173) (15.577455997467041 seconds)


29/54: 1.0[2048,...] = (0.935549525910907, 0.011801919748741535) (41.068233489990234 seconds)


30/54: 1.0[8096,...] = (0.9386734686096432, 0.005772118226269467) (150.14718437194824 seconds)


31/54: 1.25[128,...] = (0.9453846153846154, 0.04616538317343739) (4.677994251251221 seconds)


32/54: 1.25[256,...] = (0.9414027149321267, 0.02314152670089475) (6.820226669311523 seconds)


33/54: 1.25[512,...] = (0.9560346468684561, 0.016518497765750092) (8.479653358459473 seconds)


34/54: 1.25[1024,...] = (0.9604423720707794, 0.014245832104329394) (18.327195405960083 seconds)


35/54: 1.25[2048,...] = (0.9670409684536944, 0.010315648958454732) (39.73344826698303 seconds)


36/54: 1.25[8096,...] = (0.9740614157497006, 0.0070136062896592) (262.198358297348 seconds)


37/54: 1.5[128,...] = (0.9610769230769233, 0.03440826926892488) (4.172002792358398 seconds)


38/54: 1.5[256,...] = (0.9570135746606334, 0.028839622088005123) (6.783476829528809 seconds)


39/54: 1.5[512,...] = (0.9765657719398438, 0.013257200584368951) (10.458009719848633 seconds)


40/54: 1.5[1024,...] = (0.9824175035868004, 0.005450573822579632) (14.669718980789185 seconds)


41/54: 1.5[2048,...] = (0.9829113244677679, 0.003617714026350787) (32.26592755317688 seconds)


42/54: 1.5[8096,...] = (0.9885127993960607, 0.0037894019329374824) (160.98377633094788 seconds)


43/54: 1.75[128,...] = (0.9690769230769231, 0.03771867209820716) (4.1310875415802 seconds)


44/54: 1.75[256,...] = (0.9687028657616892, 0.023505348913272887) (3.8933193683624268 seconds)


45/54: 1.75[512,...] = (0.984370835712926, 0.012493077876521912) (11.27149510383606 seconds)


46/54: 1.75[1024,...] = (0.9902391200382592, 0.0048732734195616545) (27.70166826248169 seconds)


47/54: 1.75[2048,...] = (0.992676963444451, 0.0030832717992818758) (37.20150279998779 seconds)


48/54: 1.75[8096,...] = (0.9951210928861742, 0.002000102755999108) (160.02601170539856 seconds)


49/54: 2.0[128,...] = (0.9729230769230771, 0.038702017438967624) (4.21706485748291 seconds)


50/54: 2.0[256,...] = (0.9804298642533936, 0.021479613861054833) (6.342424631118774 seconds)


51/54: 2.0[512,...] = (0.9892727964972398, 0.012661869316598583) (6.657649755477905 seconds)


52/54: 2.0[1024,...] = (0.9965829746532758, 0.003124603599945947) (17.929768562316895 seconds)


53/54: 2.0[2048,...] = (0.9956079670821157, 0.0034144658947605116) (33.33112931251526 seconds)


54/54: 2.0[8096,...] = (0.9984562182112111, 0.0012114221736267189) (141.06727719306946 seconds)


In [35]:
with open("early_results.csv","w") as ocsv:
    ocsv.write(header+"\n")
    for k in samples.keys():
        for idx in range(len(sample_sizes)):
            scores = early_results[k][idx]
            timings = early_timings[k][idx]
            for cvid in range(len(timings)):
                score = scores[cvid]
                timing = timings[cvid]
                row = f"{NDIM},early,{k},{sample_sizes[idx]},{score},{cvid},{timing},10,{sklversion}\n"
                ocsv.write(row)
            

In [36]:
!head early_results.csv|column -t -s,

ndims  mode   dist_sigma  nsamples  c2st_score           crossvalid  total_cvtime_sec   nfolds  sklearn_version
10     early  0.0         128       0.4230769230769231   0           2.912928819656372  10      1.0.2
10     early  0.0         128       0.5384615384615384   1           2.912928819656372  10      1.0.2
10     early  0.0         128       0.46153846153846156  2           2.912928819656372  10      1.0.2
10     early  0.0         128       0.2692307692307692   3           2.912928819656372  10      1.0.2
10     early  0.0         128       0.5384615384615384   4           2.912928819656372  10      1.0.2
10     early  0.0         128       0.3076923076923077   5           2.912928819656372  10      1.0.2
10     early  0.0         128       0.48                 6           2.912928819656372  10      1.0.2
10     early  0.0         128       0.56                 7           2.912928819656372  10      1.0.2
10     early  0.0         128       0.4                  8     

In [28]:
nn_results = {}
nn_timings = {}
cnt = 0

for k,v in samples.items():
    for size in sample_sizes:

        try:
            start = time.time()
            scores = nn_c2st_(center_samples[:size,...],
                              v[:size,...],
                              n_folds=10)
            end = time.time()
        except Exception as ex:
            print(ex)
            continue
        
        if not k in nn_results.keys():
            nn_results[k] = []
            nn_timings[k] = []

        mean = np.mean(scores)
        std = np.std(scores)
        nn_results[k].append(list(scores))
        nn_timings[k].append(len(scores)*[end-start])
        cnt += 1

        print(f"{cnt}/{total}: {k}[{size},...] = {mean,std} ({end-start} seconds)")

1/54: 0.0[128,...] = (0.43784615384615383, 0.11545474791695684) (11.881926774978638 seconds)


2/54: 0.0[256,...] = (0.4765837104072398, 0.05999918906760437) (23.398439407348633 seconds)


3/54: 0.0[512,...] = (0.46576242147344377, 0.03861967938800877) (46.214720726013184 seconds)


4/54: 0.0[1024,...] = (0.4751123864179818, 0.033294503971132865) (124.35642886161804 seconds)


5/54: 0.0[2048,...] = (0.48705766593118255, 0.01723649469724518) (285.923570394516 seconds)


6/54: 0.0[8096,...] = (0.4948741411784442, 0.009190496443046547) (880.7563226222992 seconds)


7/54: 0.25[128,...] = (0.5190769230769231, 0.0586765277708438) (15.179469347000122 seconds)


8/54: 0.25[256,...] = (0.5253770739064856, 0.06325139134827444) (32.36412858963013 seconds)


9/54: 0.25[512,...] = (0.5195126594327052, 0.022111197102863717) (87.37766671180725 seconds)


10/54: 0.25[1024,...] = (0.5493208990913439, 0.025072287794564267) (520.8345577716827 seconds)


11/54: 0.25[2048,...] = (0.5649496093982944, 0.024611829681571454) (835.7785665988922 seconds)


12/54: 0.25[8096,...] = (0.5714554022830736, 0.007700613348966535) (2103.6959426403046 seconds)


13/54: 0.5[128,...] = (0.6839999999999999, 0.0982936067915644) (22.50494647026062 seconds)


14/54: 0.5[256,...] = (0.7185520361990951, 0.050611185052735895) (36.13365578651428 seconds)


15/54: 0.5[512,...] = (0.702170188463735, 0.026780235946993206) (57.48804187774658 seconds)


16/54: 0.5[1024,...] = (0.7329172644667623, 0.022864435492072725) (299.14223551750183 seconds)


17/54: 0.5[2048,...] = (0.7162991233824318, 0.016708368358987667) (929.4284360408783 seconds)


18/54: 0.5[8096,...] = (0.7074471743722309, 0.013474582414089411) (2254.7413051128387 seconds)


19/54: 0.75[128,...] = (0.8441538461538463, 0.04786099201033153) (12.906589984893799 seconds)


20/54: 0.75[256,...] = (0.8397812971342382, 0.051827693947208174) (25.5204336643219 seconds)


21/54: 0.75[512,...] = (0.8231772320578716, 0.032514147424770795) (74.54472470283508 seconds)


22/54: 0.75[1024,...] = (0.8295911047345769, 0.022421507692777582) (110.90081596374512 seconds)


23/54: 0.75[2048,...] = (0.8325099886695689, 0.014729883580079187) (218.03869938850403 seconds)


24/54: 0.75[8096,...] = (0.8310274975407774, 0.005503296987667779) (1231.364823102951 seconds)


25/54: 1.0[128,...] = (0.9256923076923078, 0.04030093895275112) (6.457927703857422 seconds)


26/54: 1.0[256,...] = (0.8985294117647058, 0.05509627871795327) (18.24162220954895 seconds)


27/54: 1.0[512,...] = (0.9121264039596421, 0.018979535216871325) (32.48452019691467 seconds)


28/54: 1.0[1024,...] = (0.9086967957914874, 0.013425205182141265) (95.90080428123474 seconds)


29/54: 1.0[2048,...] = (0.9152859443019858, 0.01326973939868088) (155.38958072662354 seconds)


30/54: 1.0[8096,...] = (0.9134144304897855, 0.0061867088345594) (837.5820505619049 seconds)


31/54: 1.25[128,...] = (0.9570769230769229, 0.03655877977583804) (8.01132845878601 seconds)


32/54: 1.25[256,...] = (0.9511689291101055, 0.02007737088083253) (16.43439483642578 seconds)


33/54: 1.25[512,...] = (0.948239101465829, 0.02047347276785048) (37.954694747924805 seconds)


34/54: 1.25[1024,...] = (0.9584911525585843, 0.01162702074716398) (65.02123880386353 seconds)


35/54: 1.25[2048,...] = (0.9592331087125052, 0.009051391905055814) (120.53906226158142 seconds)


36/54: 1.25[8096,...] = (0.9644884816873699, 0.00427263558300804) (580.1094274520874 seconds)


37/54: 1.5[128,...] = (0.9846153846153847, 0.030769230769230747) (6.886920690536499 seconds)


38/54: 1.5[256,...] = (0.9706636500754147, 0.026697684710007448) (9.637597560882568 seconds)


39/54: 1.5[512,...] = (0.9775461640967066, 0.010724416040111248) (15.713615417480469 seconds)


40/54: 1.5[1024,...] = (0.9838880918220948, 0.0053620161359407194) (39.14696025848389 seconds)


41/54: 1.5[2048,...] = (0.9868167451845669, 0.004904660552339394) (67.03842520713806 seconds)


42/54: 1.5[8096,...] = (0.9866598418472003, 0.00280699037985347) (329.3699767589569 seconds)


43/54: 1.75[128,...] = (0.9961538461538462, 0.01153846153846153) (4.159655809402466 seconds)


44/54: 1.75[256,...] = (0.9843514328808446, 0.011752674456636443) (7.52869176864624 seconds)


45/54: 1.75[512,...] = (0.9912050256996002, 0.009225967497263024) (11.210096836090088 seconds)


46/54: 1.75[1024,...] = (0.9902319464371114, 0.003787849346090159) (16.486979722976685 seconds)


47/54: 1.75[2048,...] = (0.9948738744111157, 0.004001138986942138) (54.54613780975342 seconds)


48/54: 1.75[8096,...] = (0.9954916157664767, 0.002158127099500247) (171.8548617362976 seconds)


49/54: 2.0[128,...] = (1.0, 0.0) (3.7789576053619385 seconds)


50/54: 2.0[256,...] = (0.996078431372549, 0.007843137254901976) (6.071957111358643 seconds)


51/54: 2.0[512,...] = (0.9980487340567296, 0.0039025899249536806) (8.293025016784668 seconds)


52/54: 2.0[1024,...] = (0.9985365853658535, 0.003123475237772124) (20.737361669540405 seconds)


53/54: 2.0[2048,...] = (0.998046395133878, 0.002129578406386084) (37.38613438606262 seconds)


54/54: 2.0[8096,...] = (0.9980855046934932, 0.001085598195463115) (108.4162175655365 seconds)


In [33]:
with open("nn_results.csv","w") as ocsv:
    ocsv.write(header+"\n")
    for k in samples.keys():
        for idx in range(len(sample_sizes)):
            scores = nn_results[k][idx]
            timings = nn_timings[k][idx]
            for cvid in range(len(timings)):
                score = scores[cvid]
                timing = timings[cvid]
                row = f"{NDIM},nn,{k},{sample_sizes[idx]},{score},{cvid},{timing},10,{sklversion}\n"
                ocsv.write(row)

In [34]:
!head nn_results.csv|column -t -s,

ndims  mode  dist_sigma  nsamples  c2st_score           crossvalid  total_cvtime_sec    nfolds  sklearn_version
10     nn    0.0         128       0.34615384615384615  0           11.881926774978638  10      1.0.2
10     nn    0.0         128       0.6538461538461539   1           11.881926774978638  10      1.0.2
10     nn    0.0         128       0.5384615384615384   2           11.881926774978638  10      1.0.2
10     nn    0.0         128       0.23076923076923078  3           11.881926774978638  10      1.0.2
10     nn    0.0         128       0.38461538461538464  4           11.881926774978638  10      1.0.2
10     nn    0.0         128       0.38461538461538464  5           11.881926774978638  10      1.0.2
10     nn    0.0         128       0.56                 6           11.881926774978638  10      1.0.2
10     nn    0.0         128       0.48                 7           11.881926774978638  10      1.0.2
10     nn    0.0         128       0.4                  8      