In [1]:
import sys
from pathlib import Path

# Parent of the current working directory, as an absolute path.
reporoot = Path.cwd().parent


In [2]:
# Put the directory above the working directory on the import path
# (presumably so the local `c2st` package imported in the next cell resolves).
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate entry each time.
repo_path = str(reporoot)
if repo_path not in sys.path:
    sys.path.append(repo_path)


In [3]:
from c2st.check import c2st


In [4]:
# Print the signature and docstring of c2st for reference.
help(c2st)

Help on function c2st in module c2st.check:

c2st(X: numpy.ndarray, Y: numpy.ndarray, seed: int = 1, n_folds: int = 5, scoring: str = 'accuracy', z_score: bool = True, noise_scale: Optional[float] = None, verbosity: int = 0, clf_class=<class 'sklearn.ensemble._forest.RandomForestClassifier'>, clf_kwargs={}) -> numpy.ndarray
    Return accuracy of classifier trained to distinguish samples from supposedly
    two distributions <X> and <Y>. For details on the method, see [1,2].
    If the returned accuracy is 0.5, <X> and <Y> are considered to be from the
    same generating PDF, i.e. they can not be differentiated.
    If the returned accuracy is around 1., <X> and <Y> are considered to be from
    two different generating PDFs.
    
    Trains classifiers with N-fold cross-validation [3]. By default, a `RandomForestClassifier`
    by scikit-learn is used. This can be adopted using <clf_class> and
    <clf_kwargs> as in:
    
    ``` py
    clf = clf_class(random_state=seed, **clf_kwargs

In [5]:
from __future__ import annotations
import numpy as np
from functools import partial
from numpy.random import default_rng
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neural_network import MLPClassifier
import time
# Seed shared by the sample generator (RNG) and the default `seed` of nn_c2st.
FIXEDSEED = 1309

In [6]:
def nn_c2st(
    X: np.ndarray,
    Y: np.ndarray,
    seed: int = FIXEDSEED,
    n_folds: int = 5,
    scoring: str = "accuracy",
    z_score: bool = True,
    noise_scale: float | None = None,
    verbosity: int = 0,
) -> np.ndarray:
    """Run `c2st` with a scikit-learn `MLPClassifier` instead of its default classifier.

    The network has two hidden layers of ``10 * X.shape[1]`` units each, ReLU
    activations, the Adam solver and at most 1000 training iterations.

    Parameters
    ----------
    X, Y:
        Sample arrays of the two distributions to compare. ``X`` must have at
        least two dimensions; its second axis length sizes the hidden layers.
    seed, n_folds, scoring, z_score, noise_scale, verbosity:
        Forwarded unchanged to `c2st` under the same names.

    Returns
    -------
    np.ndarray
        Whatever `c2st` returns for the given inputs.

    Raises
    ------
    ValueError
        If ``X`` has fewer than two dimensions.
    """
    if X.ndim < 2:
        raise ValueError(
            f"X must be at least 2-dimensional (n_samples, n_dims); got shape {X.shape}"
        )

    ndim = X.shape[1]
    mlp_kwargs = {
        "activation": "relu",
        "hidden_layer_sizes": (10 * ndim, 10 * ndim),
        "max_iter": 1000,
        "solver": "adam",
    }

    # Keyword arguments keep this wrapper correct even if the parameter order
    # of c2st changes; the names match the signature printed by help(c2st).
    return c2st(
        X,
        Y,
        seed=seed,
        n_folds=n_folds,
        scoring=scoring,
        z_score=z_score,
        noise_scale=noise_scale,
        verbosity=verbosity,
        clf_class=MLPClassifier,
        clf_kwargs=mlp_kwargs,
    )

In [7]:
# Experiment configuration: dimensionality, largest sample count and the
# seeded generator used to draw every sample.
NDIM = 10
max_nsamples = 8096  # NOTE(review): not a power of two (2**13 == 8192) — confirm this is intended
RNG = default_rng(FIXEDSEED)

# Powers of two from 2**7 to 2**11, followed by the full sample count.
sample_sizes = [1 << exponent for exponent in range(7, 12)] + [max_nsamples]
print(sample_sizes)

[128, 256, 512, 1024, 2048, 8096]


In [19]:
# One sampler per mean shift alpha in {0, 0.25, ..., 2}: an NDIM-dimensional
# normal with identity covariance whose mean is alpha in every coordinate.
# Keys are plain Python floats so lookups such as distributions[0.] and the
# formatted keys stay independent of NumPy scalar types.
distributions = {
    float(alpha): partial(
        RNG.multivariate_normal,
        mean=np.full(NDIM, alpha),
        cov=np.eye(NDIM),
    )
    for alpha in np.linspace(0, 2, 9)
}

# alpha == 0 is the centred reference distribution. It used to be inserted by
# hand first and then silently overwritten by the loop's first iteration.
center_normal = distributions[0.0]



In [20]:
# Draw max_nsamples points from every shifted distribution; the benchmark
# loops later take leading slices of these arrays for the smaller sizes.
samples = {alpha: draw(size=max_nsamples) for alpha, draw in distributions.items()}

# Independent reference draw from the centred normal. The benchmark loops
# compare every entry of `samples` against `center_samples`, which no other
# cell defines — without this line the notebook only runs on a stale kernel.
center_samples = center_normal(size=max_nsamples)

# Shape check expressed through the configuration constants rather than
# hard-coded numbers, so it stays valid when NDIM or max_nsamples change.
for drawn in (center_samples, *samples.values()):
    assert drawn.shape == (max_nsamples, NDIM)

In [21]:
# Benchmark c2st with its default classifier: for every distribution key and
# every sample size, compare the leading `size` reference samples against the
# leading `size` samples of that distribution and record score and wall time.
rf_results = {}
rf_timings = {}
total = len(samples) * len(sample_sizes)
cnt = 0

for k, v in samples.items():
    # Create both lists up front so every key of `samples` exists downstream,
    # even if all runs for that key fail.
    rf_results[k] = []
    rf_timings[k] = []

    for size in sample_sizes:
        cnt += 1  # counts attempts, so the progress display always reaches `total`

        # Slice outside the try block: an undefined `center_samples` (stale
        # kernel state) must fail loudly instead of being printed and skipped.
        reference = center_samples[:size, ...]
        candidate = v[:size, ...]

        start = time.perf_counter()  # monotonic clock intended for interval timing
        try:
            scores = c2st(reference, candidate)
        except Exception as ex:
            # Record NaN placeholders so position i in the result lists still
            # corresponds to sample_sizes[i]; the CSV cell indexes by position.
            print(f"{cnt}/{total}: {k}[{size},...] failed: {ex!r}")
            rf_results[k].append(float("nan"))
            rf_timings[k].append(float("nan"))
            continue
        elapsed = time.perf_counter() - start

        # c2st returned a one-element array in the recorded runs, so extend()
        # adds exactly one score per sample size.
        rf_results[k].extend(scores)
        rf_timings[k].append(elapsed)

        print(f"{cnt}/{total}: {k}[{size},...] = {scores} ({elapsed} seconds)")

1/54: 0.0[128,...] = [0.5309201] (0.5159294605255127 seconds)


2/54: 0.0[256,...] = [0.4959452] (0.65325927734375 seconds)


3/54: 0.0[512,...] = [0.5487709] (0.994964599609375 seconds)


4/54: 0.0[1024,...] = [0.5092993] (1.8338210582733154 seconds)


5/54: 0.0[2048,...] = [0.50682926] (3.970792293548584 seconds)


6/54: 0.0[8096,...] = [0.5030264] (21.190895557403564 seconds)


7/54: 0.25[128,...] = [0.5702866] (0.4957125186920166 seconds)


8/54: 0.25[256,...] = [0.6290501] (0.6247773170471191 seconds)


9/54: 0.25[512,...] = [0.63085127] (0.9626574516296387 seconds)


10/54: 0.25[1024,...] = [0.623547] (1.6151199340820312 seconds)


11/54: 0.25[2048,...] = [0.6367167] (3.2666213512420654 seconds)


12/54: 0.25[8096,...] = [0.6366715] (14.225350379943848 seconds)


13/54: 0.5[128,...] = [0.69947207] (0.48907971382141113 seconds)


14/54: 0.5[256,...] = [0.7423187] (0.6067798137664795 seconds)


15/54: 0.5[512,...] = [0.75392634] (0.8784573078155518 seconds)


16/54: 0.5[1024,...] = [0.77295965] (1.4654514789581299 seconds)


17/54: 0.5[2048,...] = [0.76513475] (2.8497238159179688 seconds)


18/54: 0.5[8096,...] = [0.76790947] (12.819429397583008 seconds)


19/54: 0.75[128,...] = [0.8555053] (0.47698473930358887 seconds)


20/54: 0.75[256,...] = [0.8887112] (0.5860490798950195 seconds)


21/54: 0.75[512,...] = [0.88086563] (0.8349876403808594 seconds)


22/54: 0.75[1024,...] = [0.86864334] (1.4279158115386963 seconds)


23/54: 0.75[2048,...] = [0.8623017] (2.6851935386657715 seconds)


24/54: 0.75[8096,...] = [0.86962694] (12.227771520614624 seconds)


25/54: 1.0[128,...] = [0.8749623] (0.476351261138916 seconds)


26/54: 1.0[256,...] = [0.9042833] (0.5802707672119141 seconds)


27/54: 1.0[512,...] = [0.913099] (0.8170440196990967 seconds)


28/54: 1.0[1024,...] = [0.9272479] (1.3732187747955322 seconds)


29/54: 1.0[2048,...] = [0.9282248] (2.628946542739868 seconds)


30/54: 1.0[8096,...] = [0.93181795] (12.15562629699707 seconds)


31/54: 1.25[128,...] = [0.94932127] (0.46962904930114746 seconds)


32/54: 1.25[256,...] = [0.96874166] (0.5610203742980957 seconds)


33/54: 1.25[512,...] = [0.9599378] (0.7886550426483154 seconds)


34/54: 1.25[1024,...] = [0.9692361] (1.3193495273590088 seconds)


35/54: 1.25[2048,...] = [0.9677739] (2.590153217315674 seconds)


36/54: 1.25[8096,...] = [0.9703555] (12.321477174758911 seconds)


37/54: 1.5[128,...] = [0.976546] (0.4641392230987549 seconds)


38/54: 1.5[256,...] = [0.9785075] (0.5498793125152588 seconds)


39/54: 1.5[512,...] = [0.98534197] (0.7691514492034912 seconds)


40/54: 1.5[1024,...] = [0.98584175] (1.294787883758545 seconds)


41/54: 1.5[2048,...] = [0.9875485] (2.600203037261963 seconds)


42/54: 1.5[8096,...] = [0.98906845] (12.577839612960815 seconds)


43/54: 1.75[128,...] = [0.9883107] (0.5092556476593018 seconds)


44/54: 1.75[256,...] = [0.99023414] (0.6388118267059326 seconds)


45/54: 1.75[512,...] = [0.99316597] (1.0531127452850342 seconds)


46/54: 1.75[1024,...] = [0.9941392] (1.4125792980194092 seconds)


47/54: 1.75[2048,...] = [0.9958495] (2.5202813148498535 seconds)


48/54: 1.75[8096,...] = [0.99542993] (12.592328786849976 seconds)


49/54: 2.0[128,...] = [0.99230766] (0.45456409454345703 seconds)


50/54: 2.0[256,...] = [0.99805826] (0.5216641426086426 seconds)


51/54: 2.0[512,...] = [0.9980488] (0.7174582481384277 seconds)


52/54: 2.0[1024,...] = [0.9980476] (1.1915857791900635 seconds)


53/54: 2.0[2048,...] = [0.9982912] (2.37479829788208 seconds)


54/54: 2.0[8096,...] = [0.998456] (12.236571788787842 seconds)


In [23]:
# Write the random-forest benchmark to rf_results.csv, one row per
# (distribution key, sample size). The `dist_sigma` column holds the key of
# `samples`; each row is also echoed to the cell output.
header = "ndims,mode,dist_sigma,nsamples,c2st_score,time_sec"
with open("rf_results.csv", "w") as ocsv:
    ocsv.write(header + "\n")
    for k in samples:
        for idx, nsamples in enumerate(sample_sizes):
            row = f"{NDIM},rf,{k},{nsamples},{rf_results[k][idx]},{rf_timings[k][idx]}\n"
            print(row)
            ocsv.write(row)

10,rf,0.0,128,0.5309200882911682,0.5159294605255127

10,rf,0.0,256,0.49594518542289734,0.65325927734375

10,rf,0.0,512,0.5487709045410156,0.994964599609375

10,rf,0.0,1024,0.5092992782592773,1.8338210582733154

10,rf,0.0,2048,0.5068292617797852,3.970792293548584

10,rf,0.0,8096,0.5030264258384705,21.190895557403564

10,rf,0.25,128,0.5702865719795227,0.4957125186920166

10,rf,0.25,256,0.629050076007843,0.6247773170471191

10,rf,0.25,512,0.6308512687683105,0.9626574516296387

10,rf,0.25,1024,0.6235470175743103,1.6151199340820312

10,rf,0.25,2048,0.6367167234420776,3.2666213512420654

10,rf,0.25,8096,0.6366714835166931,14.225350379943848

10,rf,0.5,128,0.6994720697402954,0.48907971382141113

10,rf,0.5,256,0.7423186898231506,0.6067798137664795

10,rf,0.5,512,0.7539263367652893,0.8784573078155518

10,rf,0.5,1024,0.7729596495628357,1.4654514789581299

10,rf,0.5,2048,0.7651347517967224,2.8497238159179688

10,rf,0.5,8096,0.7679094672203064,12.819429397583008

10,rf,0.75,128,0.8555052876472473,

In [24]:
# Show the contents of the CSV file written by the previous cell.
!cat rf_results.csv

ndims,mode,dist_sigma,nsamples,c2st_score,time_sec
10,rf,0.0,128,0.5309200882911682,0.5159294605255127
10,rf,0.0,256,0.49594518542289734,0.65325927734375
10,rf,0.0,512,0.5487709045410156,0.994964599609375
10,rf,0.0,1024,0.5092992782592773,1.8338210582733154
10,rf,0.0,2048,0.5068292617797852,3.970792293548584
10,rf,0.0,8096,0.5030264258384705,21.190895557403564
10,rf,0.25,128,0.5702865719795227,0.4957125186920166
10,rf,0.25,256,0.629050076007843,0.6247773170471191
10,rf,0.25,512,0.6308512687683105,0.9626574516296387
10,rf,0.25,1024,0.6235470175743103,1.6151199340820312
10,rf,0.25,2048,0.6367167234420776,3.2666213512420654
10,rf,0.25,8096,0.6366714835166931,14.225350379943848
10,rf,0.5,128,0.6994720697402954,0.48907971382141113
10,rf,0.5,256,0.7423186898231506,0.6067798137664795
10,rf,0.5,512,0.7539263367652893,0.8784573078155518
10,rf,0.5,1024,0.7729596495628357,1.4654514789581299
10,rf,0.5,2048,0.7651347517967224,2.8497238159179688
10,rf,0.5,8096,0.7679094672203064,12

In [25]:
# Benchmark nn_c2st (c2st with an MLP classifier): for every distribution key
# and every sample size, compare the leading `size` reference samples against
# the leading `size` samples of that distribution and record score and time.
nn_results = {}
nn_timings = {}
# Recomputed here so this cell does not depend on the random-forest cell.
total = len(samples) * len(sample_sizes)
cnt = 0

for k, v in samples.items():
    # Create both lists up front so every key of `samples` exists downstream,
    # even if all runs for that key fail.
    nn_results[k] = []
    nn_timings[k] = []

    for size in sample_sizes:
        cnt += 1  # counts attempts, so the progress display always reaches `total`

        # Slice outside the try block: an undefined `center_samples` (stale
        # kernel state) must fail loudly instead of being printed and skipped.
        reference = center_samples[:size, ...]
        candidate = v[:size, ...]

        start = time.perf_counter()  # monotonic clock intended for interval timing
        try:
            scores = nn_c2st(reference, candidate)
        except Exception as ex:
            # Record NaN placeholders so position i in the result lists still
            # corresponds to sample_sizes[i]; the CSV cell indexes by position.
            print(f"{cnt}/{total}: {k}[{size},...] failed: {ex!r}")
            nn_results[k].append(float("nan"))
            nn_timings[k].append(float("nan"))
            continue
        elapsed = time.perf_counter() - start

        # The recorded runs show a one-element score array per call, so
        # extend() adds exactly one score per sample size.
        nn_results[k].extend(scores)
        nn_timings[k].append(elapsed)

        print(f"{cnt}/{total}: {k}[{size},...] = {scores} ({elapsed} seconds)")

1/54: 0.0[128,...] = [0.4649321] (6.982034206390381 seconds)


2/54: 0.0[256,...] = [0.5021702] (18.391536712646484 seconds)


3/54: 0.0[512,...] = [0.5439694] (28.89311933517456 seconds)


4/54: 0.0[1024,...] = [0.526871] (54.0725200176239 seconds)


5/54: 0.0[2048,...] = [0.5070788] (123.58810925483704 seconds)


6/54: 0.0[8096,...] = [0.5031488] (509.06522130966187 seconds)


7/54: 0.25[128,...] = [0.5816742] (10.589311361312866 seconds)


8/54: 0.25[256,...] = [0.60930896] (20.84875774383545 seconds)


9/54: 0.25[512,...] = [0.58786225] (32.01209735870361 seconds)


10/54: 0.25[1024,...] = [0.5791043] (62.982462644577026 seconds)


11/54: 0.25[2048,...] = [0.58666486] (135.8197956085205 seconds)


12/54: 0.25[8096,...] = [0.5701574] (690.5073640346527 seconds)


13/54: 0.5[128,...] = [0.70746607] (10.090089082717896 seconds)


14/54: 0.5[256,...] = [0.69333714] (28.960004329681396 seconds)


15/54: 0.5[512,...] = [0.7070349] (43.92514443397522 seconds)


16/54: 0.5[1024,...] = [0.72313076] (76.67817330360413 seconds)


17/54: 0.5[2048,...] = [0.73168737] (148.7741618156433 seconds)


18/54: 0.5[8096,...] = [0.7093626] (577.8812856674194 seconds)


19/54: 0.75[128,...] = [0.8751131] (6.228252649307251 seconds)


20/54: 0.75[256,...] = [0.85537785] (21.79594659805298 seconds)


21/54: 0.75[512,...] = [0.86128646] (32.100399017333984 seconds)


22/54: 0.75[1024,...] = [0.85107994] (56.44421625137329 seconds)


23/54: 0.75[2048,...] = [0.84252447] (112.76290535926819 seconds)


24/54: 0.75[8096,...] = [0.8411555] (402.5190885066986 seconds)


25/54: 1.0[128,...] = [0.9062594] (5.21303391456604 seconds)


26/54: 1.0[256,...] = [0.90820485] (16.347495079040527 seconds)


27/54: 1.0[512,...] = [0.9111143] (16.1702241897583 seconds)


28/54: 1.0[1024,...] = [0.92235196] (26.505642890930176 seconds)


29/54: 1.0[2048,...] = [0.9187031] (54.359280824661255 seconds)


30/54: 1.0[8096,...] = [0.9159462] (281.2870545387268 seconds)


31/54: 1.25[128,...] = [0.93755656] (2.3409054279327393 seconds)


32/54: 1.25[256,...] = [0.9550543] (6.887072563171387 seconds)


33/54: 1.25[512,...] = [0.95214254] (10.649713039398193 seconds)


34/54: 1.25[1024,...] = [0.9663045] (19.645933628082275 seconds)


35/54: 1.25[2048,...] = [0.9665535] (33.377681732177734 seconds)


36/54: 1.25[8096,...] = [0.9639944] (182.7131426334381 seconds)


37/54: 1.5[128,...] = [0.9843137] (2.26505446434021 seconds)


38/54: 1.5[256,...] = [0.9824291] (3.7571702003479004 seconds)


39/54: 1.5[512,...] = [0.9843711] (8.773359298706055 seconds)


40/54: 1.5[1024,...] = [0.98340154] (12.13750672340393 seconds)


41/54: 1.5[2048,...] = [0.98437387] (23.305691480636597 seconds)


42/54: 1.5[8096,...] = [0.98591906] (102.44566059112549 seconds)


43/54: 1.75[128,...] = [0.9882353] (2.3607404232025146 seconds)


44/54: 1.75[256,...] = [0.9823529] (3.0930984020233154 seconds)


45/54: 1.75[512,...] = [0.9863271] (4.608965635299683 seconds)


46/54: 1.75[1024,...] = [0.9912088] (8.245190620422363 seconds)


47/54: 1.75[2048,...] = [0.9953614] (11.485974311828613 seconds)


48/54: 1.75[8096,...] = [0.99487394] (67.4766194820404 seconds)


49/54: 2.0[128,...] = [0.9883107] (1.8419666290283203 seconds)


50/54: 2.0[256,...] = [0.9882353] (2.3699238300323486 seconds)


51/54: 2.0[512,...] = [0.9931564] (3.8780014514923096 seconds)


52/54: 2.0[1024,...] = [0.994627] (5.921618461608887 seconds)


53/54: 2.0[2048,...] = [0.9973147] (9.838232278823853 seconds)


54/54: 2.0[8096,...] = [0.99895006] (28.462652683258057 seconds)


In [26]:
# Write the neural-network benchmark to nn_results.csv, one row per
# (distribution key, sample size), reusing the column header defined with the
# random-forest export. Each row is also echoed to the cell output.
with open("nn_results.csv", "w") as ocsv:
    ocsv.write(header + "\n")
    for k in samples:
        for idx, nsamples in enumerate(sample_sizes):
            row = f"{NDIM},nn,{k},{nsamples},{nn_results[k][idx]},{nn_timings[k][idx]}\n"
            print(row)
            ocsv.write(row)

10,nn,0.0,128,0.4649321138858795,6.982034206390381

10,nn,0.0,256,0.502170205116272,18.391536712646484

10,nn,0.0,512,0.5439693927764893,28.89311933517456

10,nn,0.0,1024,0.5268710255622864,54.0725200176239

10,nn,0.0,2048,0.5070788264274597,123.58810925483704

10,nn,0.0,8096,0.5031487941741943,509.06522130966187

10,nn,0.25,128,0.5816742181777954,10.589311361312866

10,nn,0.25,256,0.6093089580535889,20.84875774383545

10,nn,0.25,512,0.5878622531890869,32.01209735870361

10,nn,0.25,1024,0.5791043043136597,62.982462644577026

10,nn,0.25,2048,0.5866648554801941,135.8197956085205

10,nn,0.25,8096,0.5701574087142944,690.5073640346527

10,nn,0.5,128,0.7074660658836365,10.090089082717896

10,nn,0.5,256,0.6933371424674988,28.960004329681396

10,nn,0.5,512,0.7070348858833313,43.92514443397522

10,nn,0.5,1024,0.7231307625770569,76.67817330360413

10,nn,0.5,2048,0.7316873669624329,148.7741618156433

10,nn,0.5,8096,0.7093626260757446,577.8812856674194

10,nn,0.75,128,0.8751131296157837,6.22825264

In [27]:
# Show the first ten lines of the CSV file written by the previous cell.
!head -n10 nn_results.csv

ndims,mode,dist_sigma,nsamples,c2st_score,time_sec
10,nn,0.0,128,0.4649321138858795,6.982034206390381
10,nn,0.0,256,0.502170205116272,18.391536712646484
10,nn,0.0,512,0.5439693927764893,28.89311933517456
10,nn,0.0,1024,0.5268710255622864,54.0725200176239
10,nn,0.0,2048,0.5070788264274597,123.58810925483704
10,nn,0.0,8096,0.5031487941741943,509.06522130966187
10,nn,0.25,128,0.5816742181777954,10.589311361312866
10,nn,0.25,256,0.6093089580535889,20.84875774383545
10,nn,0.25,512,0.5878622531890869,32.01209735870361
