In [1]:
import sys
from pathlib import Path

reporoot = Path(".").absolute().parent


In [2]:
sys.path.append(str(reporoot))


In [3]:
from c2st.check import c2st_


In [4]:
help(c2st)

NameError: name 'c2st' is not defined

In [12]:
from __future__ import annotations
import numpy as np
from functools import partial
from numpy.random import default_rng
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn import __version__ as sklversion
import time
FIXEDSEED = 1309

In [6]:
def nn_c2st_(
    X: np.ndarray,
    Y: np.ndarray,
    seed: int = FIXEDSEED,
    n_folds: int = 5,
    scoring: str = "accuracy",
    z_score: bool = True,
    noise_scale: Optional[float] = None,
    verbosity: int = 0,
) -> np.ndarray:

    ndim = X.shape[1]
    clf_class = MLPClassifier
    clf_kwargs = {
        "activation": "relu",
        "hidden_layer_sizes": (10 * ndim, 10 * ndim),
        "max_iter": 1000,
        "solver": "adam",
    }

    return c2st_(
        X,
        Y,
        seed,
        n_folds,
        scoring,
        z_score,
        noise_scale,
        verbosity,
        clf_class,
        clf_kwargs,
    )

In [7]:
NDIM = 10
max_nsamples = 8096
sample_sizes = [ 2**it for it in range(7,12)]
sample_sizes.append(max_nsamples)
RNG = default_rng(FIXEDSEED)
print(sample_sizes)

[128, 256, 512, 1024, 2048, 8096]


In [8]:
center_normal = partial(RNG.multivariate_normal, mean=np.zeros(NDIM), cov=np.eye(NDIM))
distributions = { 0. : center_normal}

for alpha in np.linspace(0,2,9):
    distributions[alpha] = partial(RNG.multivariate_normal, mean=np.zeros(NDIM) + alpha, cov=np.eye(NDIM))



In [9]:
center_samples = center_normal(size=max_nsamples)
samples = {}
for k,v in distributions.items():
    samples[k] = v(size=max_nsamples)

assert samples[0.].shape == (8096, 10)

In [10]:
rf_results = {}
rf_timings = {}
total = len(samples.values())*len(sample_sizes)
cnt = 0


for k,v in samples.items():
    for size in sample_sizes:

        try:
            start = time.time()
            scores = c2st_(X = center_samples[:size,...],
                          Y = v[:size,...],
                          n_folds=10)
            end = time.time()
        except Exception as ex:
            print(ex)
            continue

        if not k in rf_results.keys():
            rf_results[k] = []
            rf_timings[k] = []

        mean = np.mean(scores)
        std = np.std(scores)
        rf_results[k].append((mean,std))
        rf_timings[k].append(end-start)
        cnt += 1

        print(f"{cnt}/{total}: {k}[{size},...] = {mean,std} ({end-start} seconds)")

1/54: 0.0[128,...] = (0.42153846153846153, 0.12246385772832794) (1.0489280223846436 seconds)


2/54: 0.0[256,...] = (0.47835595776772244, 0.07806781384146226) (1.3461580276489258 seconds)


3/54: 0.0[512,...] = (0.5009232819341329, 0.048259547167832775) (2.162215232849121 seconds)


4/54: 0.0[1024,...] = (0.48974414155906265, 0.019876207761461956) (4.205496311187744 seconds)


5/54: 0.0[2048,...] = (0.5007263402707377, 0.01581535273212949) (9.193150758743286 seconds)


6/54: 0.0[8096,...] = (0.5003095189074188, 0.007796416628318716) (50.5123770236969 seconds)


7/54: 0.25[128,...] = (0.6201538461538462, 0.0922985892947385) (2.1747143268585205 seconds)


8/54: 0.25[256,...] = (0.5820889894419305, 0.053563623821549874) (1.361140251159668 seconds)


9/54: 0.25[512,...] = (0.5878926327812678, 0.044028780012037166) (2.0011892318725586 seconds)


10/54: 0.25[1024,...] = (0.6147321855571497, 0.03028625429776028) (3.710228204727173 seconds)


11/54: 0.25[2048,...] = (0.6310996481602957, 0.02729351243745489) (7.187965631484985 seconds)


12/54: 0.25[8096,...] = (0.6420456919756898, 0.009978095785152703) (31.963542938232422 seconds)


13/54: 0.5[128,...] = (0.7461538461538462, 0.084270626326962) (0.9898128509521484 seconds)


14/54: 0.5[256,...] = (0.7130844645550527, 0.055577278557293136) (1.2809925079345703 seconds)


15/54: 0.5[512,...] = (0.7509423186750428, 0.045301667986281394) (1.8889577388763428 seconds)


16/54: 0.5[1024,...] = (0.7641535150645624, 0.014418230248470484) (3.2316195964813232 seconds)


17/54: 0.5[2048,...] = (0.7678221718647504, 0.019321606896966624) (6.415999889373779 seconds)


18/54: 0.5[8096,...] = (0.7739629324609765, 0.008799769004857676) (28.750935316085815 seconds)


19/54: 0.75[128,...] = (0.8515384615384616, 0.038110707622700454) (0.9564695358276367 seconds)


20/54: 0.75[256,...] = (0.8377828054298643, 0.038485486289529275) (1.1988365650177002 seconds)


21/54: 0.75[512,...] = (0.8476584808680755, 0.032474086886125006) (1.7847466468811035 seconds)


22/54: 0.75[1024,...] = (0.8510616929698708, 0.01759884234274294) (3.067359685897827 seconds)


23/54: 0.75[2048,...] = (0.8647468543145089, 0.01774062606967589) (6.041255235671997 seconds)


24/54: 0.75[8096,...] = (0.8708621386467794, 0.004054615967776685) (27.22282075881958 seconds)


25/54: 1.0[128,...] = (0.8984615384615384, 0.06938785597071738) (0.9390015602111816 seconds)


26/54: 1.0[256,...] = (0.9219457013574661, 0.03888361550904267) (1.1617100238800049 seconds)


27/54: 1.0[512,...] = (0.9208833047782219, 0.030780055386088787) (1.6966526508331299 seconds)


28/54: 1.0[1024,...] = (0.9311573409851744, 0.0168358066396484) (2.9206058979034424 seconds)


29/54: 1.0[2048,...] = (0.9350593356789313, 0.01492591725480785) (5.7851057052612305 seconds)


30/54: 1.0[8096,...] = (0.9344123792311976, 0.003614051898303775) (27.331603050231934 seconds)


31/54: 1.25[128,...] = (0.9452307692307693, 0.04008983993822111) (0.9335131645202637 seconds)


32/54: 1.25[256,...] = (0.9530165912518853, 0.02941717860216238) (1.2945091724395752 seconds)


33/54: 1.25[512,...] = (0.9551208833047783, 0.027305324953413746) (1.7083992958068848 seconds)


34/54: 1.25[1024,...] = (0.9604519368723098, 0.012053293529227824) (2.9349284172058105 seconds)


35/54: 1.25[2048,...] = (0.9667940843222613, 0.0070194217423293044) (5.822935104370117 seconds)


36/54: 1.25[8096,...] = (0.9725166807738355, 0.003931870642098374) (30.33746099472046 seconds)


37/54: 1.5[128,...] = (0.9763076923076923, 0.026353985360447817) (0.9330744743347168 seconds)


38/54: 1.5[256,...] = (0.9805429864253394, 0.021343315778351402) (1.1498992443084717 seconds)


39/54: 1.5[512,...] = (0.9824195697696554, 0.009602337614476528) (1.6228747367858887 seconds)


40/54: 1.5[1024,...] = (0.9843758967001435, 0.008942707086403262) (2.856144666671753 seconds)


41/54: 1.5[2048,...] = (0.9855978293279264, 0.00710716215025746) (5.847603797912598 seconds)


42/54: 1.5[8096,...] = (0.9891923073990194, 0.002652581628563902) (29.06977105140686 seconds)


43/54: 1.75[128,...] = (0.9883076923076924, 0.024799312836442926) (0.9283976554870605 seconds)


44/54: 1.75[256,...] = (0.9922322775263952, 0.00951422551306799) (1.091942310333252 seconds)


45/54: 1.75[512,...] = (0.9941557205406435, 0.0078004425129916686) (1.6763129234313965 seconds)


46/54: 1.75[1024,...] = (0.9941391678622669, 0.004256450194349834) (2.739069938659668 seconds)


47/54: 1.75[2048,...] = (0.9958488878287317, 0.0024549937141941465) (5.715183258056641 seconds)


48/54: 1.75[8096,...] = (0.9958003721242346, 0.0018280116523806229) (29.078433513641357 seconds)


49/54: 2.0[128,...] = (1.0, 0.0) (0.8969318866729736 seconds)


50/54: 2.0[256,...] = (0.998076923076923, 0.0057692307692307826) (1.078972339630127 seconds)


51/54: 2.0[512,...] = (0.9980487340567296, 0.00390258992495368) (1.5362026691436768 seconds)


52/54: 2.0[1024,...] = (0.9985365853658535, 0.002235402778027241) (2.6288297176361084 seconds)


53/54: 2.0[2048,...] = (0.998533603673445, 0.0019556928463501473) (5.543246269226074 seconds)


54/54: 2.0[8096,...] = (0.9986412508864639, 0.0006052481616375465) (28.65970253944397 seconds)


In [13]:
header = "ndims,mode,dist_sigma,nsamples,mean_c2st_score,std_c2st_score,time_sec,nfolds,sklearn_version"
with open("rf_results.csv","w") as ocsv:
    ocsv.write(header+"\n")
    for k in samples.keys():
        for idx in range(len(sample_sizes)):
            mn,std = rf_results[k][idx]
            row = f"{NDIM},rf,{k},{sample_sizes[idx]},{mn},{std},{rf_timings[k][idx]},10,{sklversion}\n"
            print(row)
            ocsv.write(row)

10,rf,0.0,128,0.42153846153846153,0.12246385772832794,1.0489280223846436,10,1.0.2

10,rf,0.0,256,0.47835595776772244,0.07806781384146226,1.3461580276489258,10,1.0.2

10,rf,0.0,512,0.5009232819341329,0.048259547167832775,2.162215232849121,10,1.0.2

10,rf,0.0,1024,0.48974414155906265,0.019876207761461956,4.205496311187744,10,1.0.2

10,rf,0.0,2048,0.5007263402707377,0.01581535273212949,9.193150758743286,10,1.0.2

10,rf,0.0,8096,0.5003095189074188,0.007796416628318716,50.5123770236969,10,1.0.2

10,rf,0.25,128,0.6201538461538462,0.0922985892947385,2.1747143268585205,10,1.0.2

10,rf,0.25,256,0.5820889894419305,0.053563623821549874,1.361140251159668,10,1.0.2

10,rf,0.25,512,0.5878926327812678,0.044028780012037166,2.0011892318725586,10,1.0.2

10,rf,0.25,1024,0.6147321855571497,0.03028625429776028,3.710228204727173,10,1.0.2

10,rf,0.25,2048,0.6310996481602957,0.02729351243745489,7.187965631484985,10,1.0.2

10,rf,0.25,8096,0.6420456919756898,0.009978095785152703,31.963542938232422,10,1.0.2

10,r

In [14]:
!head -n10 rf_results.csv

ndims,mode,dist_sigma,nsamples,mean_c2st_score,std_c2st_score,time_sec,nfolds,sklearn_version
10,rf,0.0,128,0.42153846153846153,0.12246385772832794,1.0489280223846436,10,1.0.2
10,rf,0.0,256,0.47835595776772244,0.07806781384146226,1.3461580276489258,10,1.0.2
10,rf,0.0,512,0.5009232819341329,0.048259547167832775,2.162215232849121,10,1.0.2
10,rf,0.0,1024,0.48974414155906265,0.019876207761461956,4.205496311187744,10,1.0.2
10,rf,0.0,2048,0.5007263402707377,0.01581535273212949,9.193150758743286,10,1.0.2
10,rf,0.0,8096,0.5003095189074188,0.007796416628318716,50.5123770236969,10,1.0.2
10,rf,0.25,128,0.6201538461538462,0.0922985892947385,2.1747143268585205,10,1.0.2
10,rf,0.25,256,0.5820889894419305,0.053563623821549874,1.361140251159668,10,1.0.2
10,rf,0.25,512,0.5878926327812678,0.044028780012037166,2.0011892318725586,10,1.0.2


In [15]:
nn_results = {}
nn_timings = {}
cnt = 0

for k,v in samples.items():
    for size in sample_sizes:

        try:
            start = time.time()
            scores = nn_c2st_(center_samples[:size,...],
                             v[:size,...])
            end = time.time()
        except Exception as ex:
            print(ex)
            continue
        
        if not k in nn_results.keys():
            nn_results[k] = []
            nn_timings[k] = []

        mean = np.mean(scores)
        std = np.std(scores)
        nn_results[k].append((mean,std))
        nn_timings[k].append(end-start)
        cnt += 1

        print(f"{cnt}/{total}: {k}[{size},...] = {mean,std} ({end-start} seconds)")

1/54: 0.0[128,...] = (0.47639517345399707, 0.06944370621739751) (3.9503777027130127 seconds)


2/54: 0.0[256,...] = (0.47270131353512274, 0.014611290821653794) (14.877474546432495 seconds)


3/54: 0.0[512,...] = (0.46679100908656146, 0.02653006985499437) (27.8264639377594 seconds)


4/54: 0.0[1024,...] = (0.4936513805235852, 0.01827138663151398) (62.5630042552948 seconds)


5/54: 0.0[2048,...] = (0.49633878316805147, 0.013397183951439668) (132.11983609199524 seconds)


6/54: 0.0[8096,...] = (0.4998762953282654, 0.005431641934181123) (463.0335690975189 seconds)


7/54: 0.25[128,...] = (0.5739819004524886, 0.03610501852829092) (9.77917218208313 seconds)


8/54: 0.25[256,...] = (0.5196268798781649, 0.05430573278705191) (27.03339195251465 seconds)


9/54: 0.25[512,...] = (0.5527307508369201, 0.02066399167358149) (36.92687106132507 seconds)


10/54: 0.25[1024,...] = (0.5424783827300377, 0.0155300323600427) (69.17806816101074 seconds)


11/54: 0.25[2048,...] = (0.5720215015336967, 0.008651104135964238) (152.62547278404236 seconds)


12/54: 0.25[8096,...] = (0.5718257127606889, 0.007568655171052921) (582.9192526340485 seconds)


13/54: 0.5[128,...] = (0.69947209653092, 0.04175691311166378) (7.267455577850342 seconds)


14/54: 0.5[256,...] = (0.7264801066057491, 0.028329113928785933) (18.055891036987305 seconds)


15/54: 0.5[512,...] = (0.7118938307030129, 0.0131882243029468) (30.425918579101562 seconds)


16/54: 0.5[1024,...] = (0.7280314866718349, 0.020240400863217968) (51.75952458381653 seconds)


17/54: 0.5[2048,...] = (0.7155746150868103, 0.01159427086422544) (114.5666275024414 seconds)


18/54: 0.5[8096,...] = (0.7097948279738464, 0.005346606546113173) (612.9915263652802 seconds)


19/54: 0.75[128,...] = (0.8594268476621417, 0.02851277200221321) (11.615701913833618 seconds)


20/54: 0.75[256,...] = (0.841766609556444, 0.04944163036708353) (26.568673372268677 seconds)


21/54: 0.75[512,...] = (0.8164466762314682, 0.023166476509885802) (40.89202094078064 seconds)


22/54: 0.75[1024,...] = (0.830562347188264, 0.006579611937979073) (88.67784833908081 seconds)


23/54: 0.75[2048,...] = (0.837646743500402, 0.003169646464086814) (208.0921094417572 seconds)


24/54: 0.75[8096,...] = (0.8349178413715943, 0.00529144628132102) (602.3718740940094 seconds)


25/54: 1.0[128,...] = (0.9255656108597285, 0.040095990326560416) (6.844011068344116 seconds)


26/54: 1.0[256,...] = (0.8906148867313917, 0.02661921851192302) (25.916473388671875 seconds)


27/54: 1.0[512,...] = (0.9082113821138211, 0.005564207685290608) (41.892882108688354 seconds)


28/54: 1.0[1024,...] = (0.911616673623949, 0.006314998708401032) (70.7614254951477 seconds)


29/54: 1.0[2048,...] = (0.9179662884540933, 0.006992492696744581) (91.480872631073 seconds)


30/54: 1.0[8096,...] = (0.9160084180962371, 0.00518478691269885) (363.63505005836487 seconds)


31/54: 1.25[128,...] = (0.9569381598793363, 0.02294444371596286) (5.173536062240601 seconds)


32/54: 1.25[256,...] = (0.9492670854749667, 0.01872270968974444) (12.426153898239136 seconds)


33/54: 1.25[512,...] = (0.9462936394069823, 0.011523740205567238) (27.137492895126343 seconds)


34/54: 1.25[1024,...] = (0.955556085634206, 0.008979577176816664) (19.78152060508728 seconds)


35/54: 1.25[2048,...] = (0.9614270824026923, 0.006100491626951001) (41.298165798187256 seconds)


36/54: 1.25[8096,...] = (0.9670208532094469, 0.001771033416762611) (217.97580647468567 seconds)


37/54: 1.5[128,...] = (0.984389140271493, 0.01465377316848801) (3.6381447315216064 seconds)


38/54: 1.5[256,...] = (0.972663240053303, 0.012989894212332874) (7.363297700881958 seconds)


39/54: 1.5[512,...] = (0.9755954088952654, 0.009240355708280776) (7.77424693107605 seconds)


40/54: 1.5[1024,...] = (0.9809565269246825, 0.003245877751347846) (11.301787376403809 seconds)


41/54: 1.5[2048,...] = (0.9865716668155692, 0.0030910387501745726) (26.72587251663208 seconds)


42/54: 1.5[8096,...] = (0.9857954351507768, 0.002327595026958495) (103.90563941001892 seconds)


43/54: 1.75[128,...] = (0.996078431372549, 0.007843137254901978) (2.4891104698181152 seconds)


44/54: 1.75[256,...] = (0.9863125832857416, 0.004834132877613558) (4.85188364982605 seconds)


45/54: 1.75[512,...] = (0.990234337637494, 0.0030852082431564814) (6.342369318008423 seconds)


46/54: 1.75[1024,...] = (0.9916977756574632, 0.003314449925174374) (7.313701391220093 seconds)


47/54: 1.75[2048,...] = (0.9943854790196254, 0.0019808149449771937) (20.251517295837402 seconds)


48/54: 1.75[8096,...] = (0.9949974265538074, 0.0010947661700702277) (69.76494026184082 seconds)


49/54: 2.0[128,...] = (1.0, 0.0) (1.4720993041992188 seconds)


50/54: 2.0[256,...] = (1.0, 0.0) (1.5427064895629883 seconds)


51/54: 2.0[512,...] = (0.9980439980870397, 0.002395615245328274) (4.2837560176849365 seconds)


52/54: 2.0[1024,...] = (0.9985365853658535, 0.0029268292682926855) (6.391915321350098 seconds)


53/54: 2.0[2048,...] = (0.9987789987789988, 0.0010920966923075857) (13.088403701782227 seconds)


54/54: 2.0[8096,...] = (0.9983324755179359, 0.000315090857989144) (44.79178190231323 seconds)


In [16]:
with open("nn_results.csv","w") as ocsv:
    ocsv.write(header+"\n")
    for k in samples.keys():
        for idx in range(len(sample_sizes)):
            mn,std = rf_results[k][idx]
            row = f"{NDIM},rf,{k},{sample_sizes[idx]},{mn},{std},{nn_timings[k][idx]},10,{sklversion}\n"
            print(row)
            ocsv.write(row)

10,rf,0.0,128,0.42153846153846153,0.12246385772832794,3.9503777027130127,10,1.0.2

10,rf,0.0,256,0.47835595776772244,0.07806781384146226,14.877474546432495,10,1.0.2

10,rf,0.0,512,0.5009232819341329,0.048259547167832775,27.8264639377594,10,1.0.2

10,rf,0.0,1024,0.48974414155906265,0.019876207761461956,62.5630042552948,10,1.0.2

10,rf,0.0,2048,0.5007263402707377,0.01581535273212949,132.11983609199524,10,1.0.2

10,rf,0.0,8096,0.5003095189074188,0.007796416628318716,463.0335690975189,10,1.0.2

10,rf,0.25,128,0.6201538461538462,0.0922985892947385,9.77917218208313,10,1.0.2

10,rf,0.25,256,0.5820889894419305,0.053563623821549874,27.03339195251465,10,1.0.2

10,rf,0.25,512,0.5878926327812678,0.044028780012037166,36.92687106132507,10,1.0.2

10,rf,0.25,1024,0.6147321855571497,0.03028625429776028,69.17806816101074,10,1.0.2

10,rf,0.25,2048,0.6310996481602957,0.02729351243745489,152.62547278404236,10,1.0.2

10,rf,0.25,8096,0.6420456919756898,0.009978095785152703,582.9192526340485,10,1.0.2

10,rf,0

In [17]:
!head -n10 nn_results.csv

ndims,mode,dist_sigma,nsamples,mean_c2st_score,std_c2st_score,time_sec,nfolds,sklearn_version
10,rf,0.0,128,0.42153846153846153,0.12246385772832794,3.9503777027130127,10,1.0.2
10,rf,0.0,256,0.47835595776772244,0.07806781384146226,14.877474546432495,10,1.0.2
10,rf,0.0,512,0.5009232819341329,0.048259547167832775,27.8264639377594,10,1.0.2
10,rf,0.0,1024,0.48974414155906265,0.019876207761461956,62.5630042552948,10,1.0.2
10,rf,0.0,2048,0.5007263402707377,0.01581535273212949,132.11983609199524,10,1.0.2
10,rf,0.0,8096,0.5003095189074188,0.007796416628318716,463.0335690975189,10,1.0.2
10,rf,0.25,128,0.6201538461538462,0.0922985892947385,9.77917218208313,10,1.0.2
10,rf,0.25,256,0.5820889894419305,0.053563623821549874,27.03339195251465,10,1.0.2
10,rf,0.25,512,0.5878926327812678,0.044028780012037166,36.92687106132507,10,1.0.2
