In [1]:
import dagsim.base as ds
import numpy as np
import csv
from pathlib import Path
import scipy
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

Variables from first example in dd4d tutorial
https://nbviewer.org/github/opensafely/mv-dummy-data/blob/main/prototype-report.html#example-1

```
age = node(
    variable_formula = ~(rnorm(n=1, mean=60, sd=15))
),
sex = node(
variable_formula = ~(rcat(n=1, levels = c("F", "M"), p = c(0.51, 0.49))),
),
diabetes = node(
variable_formula = ~(rbernoulli(n=1, p = plogis(-1 + age*0.002 + I(sex=='F')*-0.2))),
),
hosp_admission_count = node(
variable_formula = ~(rpois(n=1, lambda = exp(-2.5 + age*0.03 + I(sex=='F')*-0.2 +diabetes*1)))
),
time_to_death = node(
variable_formula = ~(round(rexp(n=1, rate = exp(-5 + age*0.01 + I(age^2)*0.0001 + diabetes*1.5 + hosp_admission_count*1)/365))),
)
```

In [2]:
rng = np.random.default_rng()

In [3]:
age = ds.Node(name="age", function=rng.normal, kwargs={"loc": 60, "scale": 15, "size": 1})

In [4]:
sex = ds.Node(name="sex", function=rng.choice, kwargs={"a": ["M", "F"], "p": [0.51, 0.49]})

In [5]:
def sex_as_int(sex):
    return 1 if sex == "F" else 0

In [6]:
def diabetes_fn(rng, age, sex):
    return rng.binomial(n=1, p=scipy.special.expit(-1 + age*0.002 + sex_as_int(sex) * -0.2))

In [7]:
diabetes = ds.Node(name="diabetes", function=diabetes_fn, kwargs={"rng": rng, "age": age, "sex": sex})

In [8]:
def hosp_admission_count_fn(rng, age, sex, diabetes):
    return rng.poisson(size=1, lam=np.exp(-2.5 + age*0.03 + sex_as_int(sex)*-0.2 +diabetes*1))

In [9]:
hosp_admission_count = ds.Node(name="hosp_admission_count", function=hosp_admission_count_fn, kwargs={"rng": rng, "age": age, "sex": sex, "diabetes": diabetes})

In [10]:
def time_to_death_fn(rng, age, diabetes, hosp_admission_count):
    return rng.exponential(
        size=1, 
        scale=np.exp(
            -5 + age*0.01 + ((age**2)*0.0001) + diabetes * 1.5 + hosp_admission_count * 1
        )/365
    )

In [11]:
time_to_death = ds.Node(name="time_to_death", function=time_to_death_fn, kwargs={"rng": rng, "age": age, "diabetes": diabetes, "hosp_admission_count": hosp_admission_count})

In [12]:
graph = ds.Graph(name="demo_graph", list_nodes=[age, sex, diabetes, hosp_admission_count, time_to_death])

In [16]:
graph.draw()

In [14]:
data = graph.simulate(num_samples=20, csv_name="demo_data")

2024-09-20 12:02:54.994077: Simulation started.
2024-09-20 12:02:54.994758: Simulating node "age".
2024-09-20 12:02:54.995604: Simulating node "sex".
2024-09-20 12:02:54.997245: Simulating node "diabetes".
2024-09-20 12:02:55.011791: Simulating node "hosp_admission_count".
2024-09-20 12:02:55.013576: Simulating node "time_to_death".
2024-09-20 12:02:55.020093: Simulation finished in 0.0260 seconds.


In [15]:
for row in Path("demo_data.csv").read_text().split("\n"):
    print(row)

age,sex,diabetes,hosp_admission_count,time_to_death
[33.86051237],M,[1],[2],[0.00191583]
[60.42863862],M,[0],[0],[3.27029842e-05]
[53.51355427],M,[1],[2],[0.00022324]
[58.70989768],F,[1],[1],[0.00031219]
[56.76104256],F,[1],[2],[0.00105921]
[86.36782201],M,[0],[2],[0.00033341]
[51.74176839],M,[0],[0],[4.922631e-06]
[53.51032054],F,[1],[4],[0.00139998]
[62.1428733],M,[0],[1],[3.29997181e-06]
[84.24442133],M,[0],[0],[2.9037925e-05]
[54.73395752],F,[1],[1],[0.00063925]
[46.00845351],M,[0],[0],[3.53652923e-05]
[54.64338671],F,[0],[0],[2.70418527e-05]
[49.11102081],F,[0],[0],[1.64008394e-06]
[62.61083576],M,[1],[1],[0.00059983]
[64.9909159],M,[0],[0],[2.94413712e-06]
[54.37972677],F,[1],[0],[0.00028093]
[64.59080979],M,[0],[0],[2.58387313e-05]
[21.24513734],F,[0],[0],[1.11831421e-05]
[36.83594173],M,[0],[0],[3.6615347e-05]

