In [1]:
import pandas as pd
import torch
import random
import os
import re
from torch.distributions.binomial import Binomial
from torch.distributions.bernoulli import Bernoulli
import torch.nn as nn

In [19]:
def treat_portion(x, pattern):
    score = 0.2
    if pattern in x:
        score = 0.8
    return score

In [21]:
DATA_PATH = '/nfs/turbo/lsa-regier/emr-data/'
groups = [str(i) for i in range(10)]
beta_1 = torch.tensor([1.0, 10.0, 100.0])
pattern = 'diag:J45' # Asthma


for group in groups:
    data_path = os.path.join(DATA_PATH, f'group_{group}_merged.csv')
    data_temp = pd.read_csv(data_path, sep=',')
    
    ### calculate propensity score pi(z)
    data_temp = data_temp.assign(base_propensity_score = data_temp['document'].apply(lambda x: treat_portion(x, pattern)))
    
    ### simulate treatment based on propensity score
    all_treatments = Binomial(1, torch.tensor(data_temp["base_propensity_score"].values))
    data_temp = data_temp.assign(treatment = all_treatments.sample().int())

    sigmoid_fn = torch.nn.Sigmoid()
    ### simulate response from treatment and propensity score
    for i in range(2):
        treatment_propensity = 0.25 * i + beta_1[1] * (data_temp['base_propensity_score'] - 0.2)
        input = torch.tensor(treatment_propensity.values)
        reponse_prob = sigmoid_fn(input)
        all_response = torch.bernoulli(reponse_prob)
        data_temp[f'response_{str(i)}_'] = all_response
        
    data_temp['response'] = (data_temp['treatment'] == 0)* data_temp['response_0_'] \
                            + (data_temp['treatment'] == 1)* data_temp['response_1_']
    data = pd.concat([data, data_temp]) if group != '0' else data_temp
    break
    
    
    

In [22]:
data = data.rename(columns={'response_0_': 'potential_response_0', 'response_1_': 'potential_response_1','response_2_': 'response'})

In [23]:
data["response"].sum()

108361.0

In [24]:
data["treatment"].sum()

47604

In [25]:
SATE = (data["potential_response_1"]-data["potential_response_0"]).mean()
SATE

0.05901331963636545

In [26]:
data["potential_response_1"].mean()

0.5877395514907919

In [27]:
data["potential_response_0"].mean()

0.5287262318544265

In [28]:
data

Unnamed: 0,patid,document,base_propensity_score,treatment,potential_response_0,potential_response_1,response
0,560499201141940,SIMVASTATIN METOPROLOL_SUCCINATE LISINOPRIL ES...,0.2,0,1.0,1.0,1.0
1,560499201299620,icd:9_diag:7822 [SEP] icd:9_diag:V700 icd:9_di...,0.2,0,1.0,1.0,1.0
2,560499202033650,icd:10_diag:M7051 icd:10_diag:M7052,0.2,1,1.0,1.0,1.0
3,560499202033740,icd:9_diag:V7612 icd:9_diag:V7612 [SEP] icd:9_...,0.2,0,0.0,1.0,0.0
4,560499202037510,icd:9_diag:5246 icd:9_diag:83901 [SEP] icd:9_d...,0.2,1,0.0,0.0,0.0
...,...,...,...,...,...,...,...
200526,560499899998870,icd:9_diag:8472 CYCLOBENZAPRINE_HCL [SEP] icd:...,0.2,1,1.0,1.0,1.0
200527,560499899999040,icd:9_diag:41401 icd:9_diag:71536 icd:9_diag:7...,0.2,0,1.0,0.0,1.0
200528,560499899999220,icd:9_diag:V202 [SEP] icd:9_diag:38870 NEOMY_S...,0.2,0,0.0,1.0,0.0
200529,560499899999970,AMOX_TR/POTASSIUM_CLAVULANATE GUAIFENESIN/CODE...,0.2,0,0.0,1.0,0.0


In [29]:
((data["treatment"])*data["response"]).sum()

30908.0

In [30]:
((1-data["treatment"])*data["response"]).sum()

77453.0

In [31]:
data["base_propensity_score"].max()

0.8

In [32]:
treat_portion("diag:J4500", pattern)

0.8

In [33]:
data_temp['document'].apply(lambda x: "diag:J45" in x)

0         False
1         False
2         False
3         False
4         False
          ...  
200526    False
200527    False
200528    False
200529    False
200530    False
Name: document, Length: 200531, dtype: bool