In [2]:
import pandas as pd
import torch
import random
import os
from torch.distributions.binomial import Binomial
from torch.distributions.bernoulli import Bernoulli
import torch.nn as nn

In [15]:
def treat_portion(x):
    x = x.replace(' [SEP] ', '').strip()
    token_list = x.split(' ')
    treat_cnt = 0
    if len(token_list) > 0:
        for token in token_list:
            if 'diag' not in token:
                treat_cnt += 1
    
    score = (treat_cnt + 1e-8)/ (len(token_list) + 1e-8)
    score = max(score, 0.001)
    score = min(score, 0.999)

    return score

DATA_PATH = '/nfs/turbo/lsa-regier/emr-data/'
groups = [str(i) for i in range(10)]
beta_1 = torch.tensor([1.0, 10.0, 100.0])


for group in groups:
    print(group)
    data_path = os.path.join(DATA_PATH, f'group_{group}_merged.csv')
    data_temp = pd.read_csv(data_path, sep=',')
    
    ### calculate propensity score pi(z)
    data_temp = data_temp.assign(base_propensity_score = data_temp['document'].apply(lambda x: treat_portion(x)))
    
    ### simulate treatment based on propensity score
    all_treatments = Binomial(1, torch.tensor(data_temp["base_propensity_score"].values))
    data_temp = data_temp.assign(treatment = all_treatments.sample())
    
    ### simulate response from treatment and propensity score
    treatment_propensity = 0.25 * data_temp['treatment'] + beta_1[0] * (data_temp['base_propensity_score'] - 0.2)
    sigmoid_fn = torch.nn.Sigmoid()
    input = torch.tensor(treatment_propensity.values)
    reponse_prob = sigmoid_fn(input)
    all_response = torch.bernoulli(reponse_prob)
    data_temp = data_temp.assign(response = all_response)
    data = pd.concat([data, data_temp]) if group != 0 else data_temp

0
1
2
3
4
5
6
7
8
9


In [16]:
data["response"].sum()

1066811.0

In [17]:
data["treatment"].sum()

247961.0

In [18]:
data


Unnamed: 0,patid,document,base_propensity_score,treatment,response
0,560499201514209,OLMESARTAN_MEDOXOMIL [SEP] icd:9_diag:4011 icd...,0.140187,0.0,1.0
1,560499201790039,icd:9_diag:7831 [SEP] icd:9_diag:5589 icd:9_di...,0.001000,0.0,0.0
2,560499202032459,icd:9_diag:79093 [SEP] BLOOD_SUGAR_DIAGNOSTIC ...,0.036364,0.0,1.0
3,560499202053999,icd:9_diag:3804 icd:9_diag:38870 icd:9_diag:46...,0.021739,0.0,0.0
4,560499202061489,icd:9_diag:8820 icd:9_diag:8820 icd:9_diag:E84...,0.136364,0.0,0.0
...,...,...,...,...,...
199422,560499899999079,icd:9_diag:V700 icd:9_diag:V7644 [SEP] icd:9_d...,0.001000,0.0,1.0
199423,560499899999569,icd:9_diag:53081 icd:9_diag:78650 [SEP] icd:9_...,0.080000,0.0,0.0
199424,560499899999659,icd:9_diag:36617 [SEP] icd:9_diag:36617 icd:9_...,0.034483,0.0,0.0
199425,560499899999879,ATENOLOL FLUOXETINE_HCL [SEP] NAPROXEN_SODIUM ...,0.224215,0.0,1.0
