In [1]:
import os
import time

import numpy as np
import torch
import pandas as pd
import random
from torch.distributions.binomial import Binomial
from torch.distributions.bernoulli import Bernoulli
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 

from transformers import DataCollatorForLanguageModeling, BertForMaskedLM
from transformers import Trainer, TrainingArguments
from transformers import get_linear_schedule_with_warmup, AdamW

from tokens import WordLevelBertTokenizer
from vocab import create_vocab
from data import CausalBertDataset, MLMDataset
from causal_bert import CausalBOW

In [2]:
# Expose a single physical GPU to this process. CUDA_DEVICE_ORDER=PCI_BUS_ID orders
# devices by PCI bus id (the ordering nvidia-smi reports), so '3' names the same card.
# NOTE: these variables only take effect if CUDA has not been initialised yet in this
# kernel, so keep this cell ahead of any GPU work.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES='3')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def true_casual_effect(data_loader, effect='ate', estimation='q'):
    """Ground-truth causal effect of a simulated dataset.

    (The historical "casual" spelling is kept so existing callers keep working.)

    The dataset behind ``data_loader`` must expose tensors ``treatment``,
    ``response``, ``pseudo_response`` and ``prop_score``. ``pseudo_response`` is
    used as the outcome of the arm a unit did NOT receive, so together with
    ``response`` it gives both potential outcomes Q1 and Q0 for every unit.

    Args:
        data_loader: object with a ``.dataset`` attribute (e.g. a DataLoader). Only
            ``.dataset`` is read, so batching / ``drop_last`` do not matter here.
        effect: 'ate' (average treatment effect) or 'att' (effect on the treated).
        estimation: 'q' or 'plugin'. For 'ate' both return mean(Q1 - Q0), because
            the propensity score does not enter the ATE plug-in. For 'att', 'q'
            averages Q1 - Q0 over treated units and 'plugin' returns
            mean(prop_score * (Q1 - Q0)) / mean(treatment).

    Returns:
        The effect as a numpy scalar.

    Raises:
        ValueError: for an unsupported ``effect`` / ``estimation``.
    """
    effect = effect.lower()
    estimation = estimation.lower()
    # Explicit validation replaces the old `assert effect == 'ate' and estimation == 'q'`,
    # which made every other branch unreachable and let bad input fall through to None.
    if effect not in ('ate', 'att') or estimation not in ('q', 'plugin'):
        raise ValueError(f'unallowed effect/estimation: {effect}/{estimation}')

    dataset = data_loader.dataset

    def _flat(t):
        # reshape(-1) rather than squeeze(): a single-unit dataset would otherwise
        # become a 0-d array.
        return t.detach().cpu().numpy().reshape(-1)

    treatment_t = dataset.treatment
    # Potential outcomes: observed response for the received arm, pseudo_response
    # for the other arm.
    Q1 = _flat(treatment_t * dataset.response + (1 - treatment_t) * dataset.pseudo_response)
    Q0 = _flat(treatment_t * dataset.pseudo_response + (1 - treatment_t) * dataset.response)
    treatment = _flat(treatment_t)
    diff = Q1 - Q0

    if effect == 'ate':
        # mean(prop_score * (Q1 - Q0)) would equal P(T=1) * ATT, not the ATE.
        return diff.mean()
    if estimation == 'q':
        return (treatment * diff).sum() / treatment.sum()
    prop_scores = _flat(dataset.prop_score)
    return (prop_scores * diff).mean() / treatment.mean()
        
def est_casual_effect(data_loader, model, effect='ate', estimation='q', eval_loss=True, **kwargs):
    """Estimate a causal effect from a fitted model's propensity and outcome heads.

    (The historical "casual" spelling is kept so existing callers keep working.)

    Args:
        data_loader: iterable of ``(tokens, treatment, response)`` batches.
        model: module whose call ``model(tokens)`` returns ``(prop_score, q1, q0)``.
        effect: 'ate' or 'att'.
        estimation: 'q' or 'plugin'. For 'ate' both return mean(Q1 - Q0), because
            the propensity score does not enter the ATE plug-in. For 'att', 'q'
            averages Q1 - Q0 over treated units and 'plugin' returns
            mean(prop_score * (Q1 - Q0)) / mean(treatment).
        eval_loss: if True, also average per-batch losses. This requires keyword
            arguments ``g_loss`` (propensity loss) and ``q_loss`` (outcome loss).

    Returns:
        ``(estimate, g_loss, q1_loss, q0_loss)``. Each loss is a mean over batches
        (NaN if no batch contained units of that arm), or None when ``eval_loss``
        is False.

    Raises:
        ValueError: unsupported effect/estimation, missing loss functions, or a
            ``data_loader`` that yields no batches.

    The model runs in eval mode under ``torch.no_grad()``, and its previous
    train/eval mode is restored afterwards.
    """
    effect = effect.lower()
    estimation = estimation.lower()
    if effect not in ('ate', 'att') or estimation not in ('q', 'plugin'):
        raise ValueError(f'unallowed effect/estimation: {effect}/{estimation}')

    if eval_loss:
        g_loss_fn = kwargs.get('g_loss')
        q_loss_fn = kwargs.get('q_loss')
        if g_loss_fn is None or q_loss_fn is None:
            raise ValueError('eval_loss=True requires g_loss= and q_loss= loss functions')

    def _flat(t):
        # reshape(-1) rather than squeeze(): a batch of one would squeeze to a 0-d
        # array, which np.concatenate rejects.
        return t.detach().cpu().numpy().reshape(-1)

    # `real_treatment` emphasizes that the estimates use the observed treatment
    # rather than the model's estimated propensity.
    real_treatment, prop_scores, Q1, Q0 = [], [], [], []
    g_losses, q1_losses, q0_losses = [], [], []

    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for evaluation; avoid building autograd graphs.
        with torch.no_grad():
            for tokens, treatment, response in data_loader:
                prop_score, q1, q0 = model(tokens)

                real_treatment.append(_flat(treatment))
                prop_scores.append(_flat(prop_score))
                Q1.append(_flat(q1))
                Q0.append(_flat(q0))

                if eval_loss:
                    g_losses.append(g_loss_fn(prop_score, treatment).item())
                    treated = treatment == 1
                    control = treatment == 0
                    # A mean-reduced loss over an empty selection is NaN; skip that
                    # arm for this batch instead of poisoning the average.
                    if treated.any():
                        q1_losses.append(q_loss_fn(q1[treated], response[treated]).item())
                    if control.any():
                        q0_losses.append(q_loss_fn(q0[control], response[control]).item())
    finally:
        # Restore the caller's mode instead of unconditionally switching to train().
        model.train(was_training)

    if not Q1:
        raise ValueError('data_loader yielded no batches')

    real_treatment = np.concatenate(real_treatment)
    prop_scores = np.concatenate(prop_scores)
    diff = np.concatenate(Q1) - np.concatenate(Q0)

    if eval_loss:
        g_loss, q1_loss, q0_loss = (
            np.mean(values) if values else np.nan
            for values in (g_losses, q1_losses, q0_losses)
        )
    else:
        g_loss = q1_loss = q0_loss = None

    if effect == 'ate':
        # mean(prop_score * (Q1 - Q0)) would equal P(T=1) * ATT, not the ATE.
        estimate = diff.mean()
    elif estimation == 'q':
        estimate = (real_treatment * diff).sum() / real_treatment.sum()
    else:  # effect == 'att', estimation == 'plugin'
        estimate = (prop_scores * diff).mean() / real_treatment.mean()

    return estimate, g_loss, q1_loss, q0_loss

# Prepare data

In [4]:
# Simulation parameters forwarded to CausalBertDataset for both the train and test split.
# NOTE(review): their meaning is defined in data.py (not visible here) — document them there.
alpha, beta, c, i = 0.25, 5., 0.2, 0.

# Vocabulary built with merged=True, uni_diag=True, wrapped in a word-level tokenizer.
vocab = create_vocab(merged=True, uni_diag=True)
tokenizer = WordLevelBertTokenizer(vocab)

In [5]:
start = time.time()
trainset = CausalBertDataset(
    tokenizer=tokenizer,
    data_type='merged',
    is_unidiag=True,
    alpha=alpha,
    beta=beta,
    c=c,
    i=i,
    group=[0],  # first group only, i.e. list(range(1))
    max_length=512,
    min_length=10,
    truncate_method='first',
    device=device,
    seed=1,
)
elapsed = time.time() - start
print(f'Load training set in {elapsed:.2f} sec')

Load training set in 68.53 sec


In [6]:
start = time.time()
# Held-out group 9; the training set is built from group 0.
# NOTE(review): unlike `trainset`, no `seed` is passed here — confirm whether the
# test-set simulation is meant to be seeded for reproducibility.
testset = CausalBertDataset(tokenizer=tokenizer, data_type='merged', is_unidiag=True,
                            alpha=alpha, beta=beta, c=c, i=i,
                            group=[9], max_length=512, min_length=10,
                            truncate_method='first', device=device)

print(f'Load test set in {(time.time() - start):.2f} sec')

Load training set in 60.81 sec


In [7]:
# Training batches: shuffled, with the incomplete final batch dropped.
train_loader = DataLoader(trainset, batch_size=256, drop_last=True, shuffle=True)
# Evaluation batches: keep every test unit (drop_last=False), so the estimated effect
# covers the same units as true_casual_effect, which reads the full `testset`.
# The effect estimates do not depend on order, so there is no need to shuffle.
test_loader = DataLoader(testset, batch_size=2048, drop_last=False, shuffle=False)

In [8]:
# Ground-truth effect on the test set (true_casual_effect's defaults: ATE, Q-only).
real_att_q = true_casual_effect(test_loader)
print(f'Real: [effect: ate], [estimation: q], [value: {real_att_q:.5f}]')

# Naive treated-minus-control difference in mean response, with no adjustment.
treated_mask = testset.treatment == 1
control_mask = testset.treatment == 0
unadjusted = (testset.response[treated_mask].mean() - testset.response[control_mask].mean()).item()
print(f'Unadjusted: [value: {unadjusted:.4f}]')

Real: [effect: ate], [estimation: q], [value: 0.06024]
Unadjusted: [value: 0.1403]


# Create and train the causal model (CausalBOW on pretrained BERT token embeddings)

In [11]:
# Pretrained BERT MLM checkpoint whose input (token) embedding seeds the model.
# Set CAUSAL_BERT_CHECKPOINT to point at a different copy (e.g. on another machine).
trained_bert = os.environ.get(
    'CAUSAL_BERT_CHECKPOINT',
    '/nfs/turbo/lsa-regier/bert-results/results/behrt/MLM/merged/unidiag/checkpoint-6018425/',
)
# trained_bert = '/home/liutianc/emr/bert/results/behrt/MLM/merged/unidiag/checkpoint-6018425/'

# Only the token-embedding layer of the MLM is reused. Taking it directly avoids binding
# the MLM to `model` and then reusing that name for a different model.
token_embed = BertForMaskedLM.from_pretrained(trained_bert).get_input_embeddings()

# Number of training epochs. It also sizes the LR schedule below, so the training loop
# iterates over this same value.
epoch = 50

# learnable_docu_embed: True, False
model = CausalBOW(token_embed, learnable_docu_embed=True, hidden_size=128).to(device)

# lr: small: 1e-5, large: 5e-4
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
epoch_iter = len(train_loader)
total_steps = epoch * epoch_iter

# transformers.AdamW is deprecated in favour of torch.optim.AdamW. weight_decay=0.0 is
# passed explicitly because that was the transformers default, whereas torch.optim.AdamW
# defaults to 0.01. Updates are near-identical (eps is applied at a slightly different
# point relative to bias correction).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

q_loss = nn.BCELoss()
prop_score_loss = nn.BCELoss()

# Please specify the effect and estimation we want to use here.
effect = 'ate'
estimation = 'q'

effect = effect.lower()
estimation = estimation.lower()
assert effect in ['att', 'ate'], f'Wrong effect: {effect}...'
assert estimation in ['q', 'plugin'], f'Wrong estimation: {estimation}...'

In [12]:
# `epoch` comes from the setup cell, where it also sizes the LR schedule (total_steps).
# It is deliberately not redefined here, so the two cannot drift apart.
n_batches = len(train_loader)

for e in range(1, epoch + 1):
    model.train()
    start = time.time()
    run_loss = 0.
    for tokens, treatment, response in train_loader:
        optimizer.zero_grad()
        prop_score, q1, q0 = model(tokens)

        treated = treatment == 1
        control = treatment == 0
        g_loss = prop_score_loss(prop_score, treatment)
        # A mean-reduced loss over an empty selection is NaN, and a single NaN step
        # would corrupt the weights. A batch with no treated (or no control) units
        # therefore contributes zero outcome loss for that arm.
        q1_loss = q_loss(q1[treated], response[treated]) if treated.any() else q1.new_zeros(())
        q0_loss = q_loss(q0[control], response[control]) if control.any() else q0.new_zeros(())

        loss = q1_loss + q0_loss + g_loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        run_loss += loss.item()

    # Evaluation. Names reflect the split, not the estimand: `effect` may be 'ate' or 'att'.
    train_est, _, _, _ = est_casual_effect(train_loader, model, effect, estimation, eval_loss=False)
    test_est, g_loss_test, q1_loss_test, q0_loss_test = est_casual_effect(
        test_loader, model, effect, estimation, g_loss=prop_score_loss, q_loss=q_loss)
    test_loss = q1_loss_test + q0_loss_test + g_loss_test

    print(f'''epoch: {e} / {epoch}, time cost: {(time.time() - start):.2f} sec, 
          Loss: [Train: {run_loss / max(n_batches, 1):.5f}], [Test: {test_loss:.5f}],
          Effect: [{effect}-{estimation}], [train: {train_est:.5f}], [test: {test_est:.5f}]''')
    print('*'* 80)

print('Finish training...')

# With only 1 group(s) to train.

epoch: 1 / 50, time cost: 100.57 sec, 
          Loss: [Train: 1.92246], [Test: 1.89213],
          Effect: [ate-q], [train: 0.14620], [test: 0.14617]
********************************************************************************
epoch: 2 / 50, time cost: 46.68 sec, 
          Loss: [Train: 1.88668], [Test: 1.89105],
          Effect: [ate-q], [train: 0.14348], [test: 0.14344]
********************************************************************************
epoch: 3 / 50, time cost: 44.59 sec, 
          Loss: [Train: 1.88495], [Test: 1.89009],
          Effect: [ate-q], [train: 0.13846], [test: 0.13842]
********************************************************************************
epoch: 4 / 50, time cost: 48.23 sec, 
          Loss: [Train: 1.88337], [Test: 1.88850],
          Effect: [ate-q], [train: 0.13706], [test: 0.13702]
********************************************************************************
epoch: 5 / 50, time cost: 46.96 sec, 
          Loss: [Train: 1.88083], [Te

In [None]:
# Archived variant (commented out, not executed): the same combined-loss training loop
# as the active training cell, but without `scheduler.step()`, i.e. with a constant
# learning rate. Kept for reference only.
# epoch = 50
# rs_loss, rq1_loss, rq0_loss = [0.] * 3

# run_loss = 0.
# for e in range(1, epoch + 1):
#     model.train()
#     start = time.time()
#     for idx, (tokens, treatment, response) in enumerate(train_loader):
#         optimizer.zero_grad()
#         prop_score, q1, q0 = model(tokens)
        
#         g_loss  = prop_score_loss(prop_score, treatment)
#         q1_loss = q_loss(q1[treatment==1], response[treatment==1])
#         q0_loss = q_loss(q0[treatment==0], response[treatment==0])
#         loss = q1_loss + q0_loss + g_loss
#         loss.backward()
        
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)    
#         optimizer.step()
        
#         run_loss += loss.item()
        
#     run_idx = idx

#     # Evaluation.
#     train_est_att_q, _, _, _ = est_casual_effect(train_loader, model, effect, estimation, eval_loss=False)
#     test_est_att_q, g_loss_test, q1_loss_test, q0_loss_test = est_casual_effect(test_loader, model, effect, estimation, g_loss=prop_score_loss, q_loss=q_loss)
#     test_loss = q1_loss_test + q0_loss_test
#     test_loss += g_loss_test

#     print(f'''epoch: {e} / {epoch}, time cost: {(time.time() - start):.2f} sec, 
#           Loss: [Train: {(run_loss / (run_idx + 1))  :.5f}], [Test: {test_loss:.5f}],
#           Effect: [{effect}-{estimation}], [train: {train_est_att_q:.5f}], [test: {test_est_att_q:.5f}]''')
#     print('*'* 80)
#     start = time.time()
#     run_loss = 0.

# print('Finish training...')

# # With only 1 group(s) to train.

In [None]:
# Archived variant (commented out, not executed): back-propagates the propensity loss
# and the two outcome losses separately (retain_graph=True on the first two calls), with
# no gradient clipping and no LR scheduler. It logs each loss separately for train and
# test. Kept for reference only.
# epoch = 50
# rs_loss, rq1_loss, rq0_loss = [0.] * 3

# for e in range(1, epoch + 1):
#     model.train()
#     start = time.time()
#     for idx, (tokens, treatment, response) in enumerate(train_loader):
#         optimizer.zero_grad()
#         prop_score, q1, q0 = model(tokens)
        
#         s_loss  = prop_score_loss(prop_score, treatment)
#         q1_loss = q_loss(q1[treatment==1], response[treatment==1])
#         q0_loss = q_loss(q0[treatment==0], response[treatment==0])
        
#         s_loss.backward(retain_graph=True)
#         q1_loss.backward(retain_graph=True)
#         q0_loss.backward()
#         optimizer.step()
        
#         rs_loss  += s_loss.item()
#         rq1_loss += q1_loss.item()
#         rq0_loss += q0_loss.item()
        
#     run_idx = idx

#     # Evaluation.
#     train_est_att_q, _, _, _ = est_casual_effect(train_loader, model, effect, estimation, eval_loss=False)
#     test_est_att_q, g_loss_test, q1_loss_test, q0_loss_test = est_casual_effect(test_loader, model, effect, estimation, g_loss=prop_score_loss, q_loss=q_loss)

#     print(f'''epoch: {e} / {epoch}, time cost: {(time.time() - start):.2f} sec, 
#           Train: [g_loss: {(rs_loss / (run_idx + 1))  :.5f}], [q1_loss: {(rq1_loss / (run_idx + 1)):.5f}], [q0_loss: {(rq0_loss/ (run_idx + 1)) :.5f}]
#           Test: [g_loss: {(g_loss_test)  :.5f}], [q1_loss: {(q1_loss_test):.5f}], [q0_loss: {(q0_loss_test) :.5f}]
#           Effect: [effect: {effect}], [estimation: {estimation}], [train: {train_est_att_q:.5f}], [test: {test_est_att_q:.5f}]''')
#     print('*'* 80)
#     start = time.time()
#     rs_loss, rq1_loss, rq0_loss = [0.] * 3

# print('Finish training...')

# # With only 3 groups to train.

In [None]:
effect = 'att'
# The ground truth must be for the same estimand as the estimate. With default arguments
# true_casual_effect(test_loader) returns the ATE. The true ATT is computed directly from
# the simulated outcomes instead: for treated units Q1 - Q0 = response - pseudo_response.
real_att_q = (testset.response - testset.pseudo_response)[testset.treatment == 1].mean().item()
est_att_q, _, _, _ = est_casual_effect(test_loader, model, effect, estimation, eval_loss=False)

print(f'Real: [effect: {effect}], [estimation: {estimation}], [value: {real_att_q:.5f}]')
print(f'unadjusted: {(testset.response[testset.treatment == 1].mean() - testset.response[testset.treatment == 0].mean()).item():.4f}')
print(f'Estimated: [effect: {effect}], [estimation: {estimation}], [value: {est_att_q:.5f}]')

In [None]:
effect = 'ate'
# Ground truth (true_casual_effect defaults to the ATE) and the model's estimate on the test set.
real_att_q = true_casual_effect(test_loader)
est_att_q, _, _, _ = est_casual_effect(test_loader, model, effect, estimation, eval_loss=False)

# Naive treated-minus-control difference in mean response, for comparison.
treated_mean = testset.response[testset.treatment == 1].mean()
control_mean = testset.response[testset.treatment == 0].mean()

for line in (
    f'Real: [effect: {effect}], [estimation: {estimation}], [value: {real_att_q:.5f}]',
    f'unadjusted: {(treated_mean - control_mean).item():.4f}',
    f'Estimated: [effect: {effect}], [estimation: {estimation}], [value: {est_att_q:.5f}]',
):
    print(line)