In [4]:
%reload_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '..')

import numpy as np
from causal_simulation_code import Sine_Base, Simulation

In [80]:
causes = np.array([0,1])    # treatment variables (must match amount of base functions!)
dims = 5                   # covariate dimension count

sim = Simulation(
    D=dims,                                                # covariates dimension
    C=causes,                                              # causes
    base_functions=np.array([
        Sine_Base(f=6, D=dims, G=np.zeros(dims)),          # functions for causes
        Sine_Base(f=2, D=dims, G=((4)*np.random.rand(dims)-2))]),
    std=0)                                                 # noise

In [81]:
n = 1000     # data count
data = np.ndarray((n, dims + len(causes) * 2))

for d in range(n):
    covariates = sim.get_new_state()
    for c in causes:
        p = sim.get_response_rate(c, covariates, drift=False)
        """
        Note that the parameters of the Sine_Base are
        highly sensitive. It is good practice to keep
        in mind the given response rates. Also, the
        higher the dimension of the covariates the more
        sensitive the parameters will be.
        """
        print(p)
        covariates = np.append(covariates, [c, np.random.binomial(1,p)])
    
    data[d] = covariates
    

0.5194546368822334
0.5064863344088797
0.5416727317171718
0.5139052501098343
0.5180013849953493
0.5060016145953736
0.5155480711464583
0.5051834331468136
0.5138447845492788
0.5046154525576382
0.50122205895219
0.5004073533445694
0.5093309545706151
0.5031104786921944
0.582633059824257
0.5276571827441575
0.9500165439969459
0.6823366001675649
0.6632697409621159
0.5553264757275395
0.5014272768729423
0.5004757595319771
0.25987219460789057
0.9685416501955972
0.6193479474777965
0.5401272504791732
0.5143491985794932
0.5047836500104113
0.5526661094933125
0.5175843685666724
0.6403633710316772
0.5473541236989212
0.5139123826599729
0.5046379929814264
0.7497436568600098
0.5867269361839044
0.9776345595868325
0.7054872790560398
0.5184789329664196
0.5061608915029213
0.5323882843833752
0.5108028185198658
0.82763050460176
0.6179652706723264
0.5312569639358418
0.510425030674704
0.5186243431579195
0.5062093912533897
0.9997088911487579
0.7549093084967339
0.5576992851555017
0.5192712656921272
0.799933956338470

0.6934382653952564
0.5660136879947694
0.6195101102032671
0.540182736184498
0.5217312031096657
0.5072457632255097
0.8464182076484854
0.6261896561234732
0.501342356394362
0.5004474526092471
0.5234139511945518
0.5078071883460582
0.5020698500544486
0.5006899517698302
0.5073200809687137
0.5024401044757051
0.7775182681968564
0.5974402582445756
0.5696684277199117
0.5232901871699327
0.50582747431984
0.5019425305332428
0.5390574028353896
0.5130309354604962
0.5852418024664494
0.5285378892186897
0.5015144127710828
0.5005048049430993
0.5018839332775564
0.5006279790799769
0.5020594911657028
0.5006864987807464
0.5248076182538554
0.5082722251030163
0.5782054530555727
0.5261640081439853
0.519401078526277
0.5064684696307886
0.8045008872731427
0.6082691067854893
0.500679521370634
0.5002265071855236
0.9380569586719272
0.6742235395198437
0.5126092430007656
0.5042034771184498
0.5001722547010032
0.5000574182346773
0.5791467986193424
0.5264813076607695
0.5032037682523651
0.5010679292464916
0.5888009297551173

In [83]:
data

array([[0.47912421, 0.77664246, 0.67185503, ..., 0.        , 1.        ,
        0.        ],
       [0.17111171, 0.6280657 , 0.49966338, ..., 1.        , 1.        ,
        0.        ],
       [0.46535226, 0.26265062, 0.12480625, ..., 1.        , 1.        ,
        1.        ],
       ...,
       [0.27360197, 0.73726894, 0.31800897, ..., 0.        , 1.        ,
        0.        ],
       [0.32330085, 0.87825974, 0.05439212, ..., 1.        , 1.        ,
        1.        ],
       [0.38104782, 0.54149087, 0.02135069, ..., 1.        , 1.        ,
        1.        ]])