In [1]:
import os
import time

import numpy as np
import torch
import pandas as pd
import random
from torch.distributions.binomial import Binomial
from torch.distributions.bernoulli import Bernoulli
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 
from torch.nn.utils import clip_grad_norm_

from transformers import DataCollatorForLanguageModeling, BertModel
from transformers import Trainer, TrainingArguments
from transformers import get_linear_schedule_with_warmup, AdamW

from tokens import WordLevelBertTokenizer
from vocab import create_vocab
from data import CausalBertDataset, MLMDataset
from causal_bert import CausalBert

In [2]:
# Expose a single GPU to this process: devices are numbered in PCI bus order
# and only index 6 is made visible. NOTE(review): the index is hard-coded for
# one particular machine — adjust it before running elsewhere.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '6',
})

# Use CUDA when PyTorch can see a device, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def true_casual_effect(data_loader, effect='ate', estimation='Q'):
    """Compute the reference causal effect from the outcomes stored on a dataset.

    Only ``data_loader.dataset`` is read; no batches are drawn from the loader.
    The dataset must expose the tensor attributes ``treatment``, ``response``
    and ``pseudo_response`` and, for the plugin estimator, ``prop_score``.

    Per-sample outcomes are assembled as::

        Q1 = treatment * response        + (1 - treatment) * pseudo_response
        Q0 = treatment * pseudo_response + (1 - treatment) * response

    (presumably ``treatment`` is 0/1 and ``pseudo_response`` is the outcome
    under the opposite treatment — TODO confirm against the dataset class).

    Args:
        data_loader: Object with a ``dataset`` attribute as described above.
        effect: ``'ate'`` or ``'att'`` (case-insensitive).
        estimation: ``'q'`` or ``'plugin'`` (case-insensitive). The default
            ``'Q'`` is accepted; previously it matched no branch and the
            function silently returned ``None``.

    Returns:
        A numpy scalar:
          * q / ate:      mean(Q1 - Q0)
          * q / att:      sum(treatment * (Q1 - Q0)) / sum(treatment)
          * plugin / ate: mean(prop_score * (Q1 - Q0))
          * plugin / att: mean(prop_score * (Q1 - Q0)) / mean(treatment)

    Raises:
        ValueError: If ``effect`` or ``estimation`` is not a supported value.
    """
    # Normalise case so that the declared defaults actually select a branch.
    effect = effect.lower()
    estimation = estimation.lower()
    if effect not in ('att', 'ate'):
        raise ValueError(f'Wrong effect: {effect}...')
    if estimation not in ('q', 'plugin'):
        raise ValueError(f'Wrong estimation: {estimation}...')

    def _as_array(tensor):
        # Move to host memory and drop singleton dimensions.
        return tensor.cpu().data.numpy().squeeze()

    dataset = data_loader.dataset
    treated = dataset.treatment
    untreated = 1 - treated

    Q1 = _as_array(treated * dataset.response + untreated * dataset.pseudo_response)
    Q0 = _as_array(treated * dataset.pseudo_response + untreated * dataset.response)
    treatment = _as_array(treated)

    if estimation == 'q':
        if effect == 'att':
            phi = treatment * (Q1 - Q0)
            return phi.sum() / treatment.sum()
        return (Q1 - Q0).mean()

    # estimation == 'plugin': weight the contrast by the stored propensity score.
    prop_scores = _as_array(dataset.prop_score)
    phi = (prop_scores * (Q1 - Q0)).mean()
    if effect == 'att':
        return phi / treatment.mean()
    return phi

def est_casual_effect(data_loader, causal_bert, effect='ate', estimation='Q'):
    """Estimate a causal effect from the model's predictions over a data loader.

    The model is put in eval mode, run under ``torch.no_grad()`` on every batch,
    and switched back to train mode before returning (also when an exception is
    raised, and regardless of the mode it was in on entry).

    Args:
        data_loader: Iterable of ``(tokens, treatment, response)`` batches.
            ``response`` is not used. The observed treatment from the batch,
            not a predicted one, is used for the ATT normalisation.
        causal_bert: Callable model with ``eval()``/``train()`` that maps
            ``tokens`` to a ``(prop_score, q1, q0)`` triple of tensors.
        effect: ``'ate'`` or ``'att'`` (case-insensitive).
        estimation: ``'q'`` or ``'plugin'`` (case-insensitive). The default
            ``'Q'`` is accepted; previously it matched no branch and the
            function silently returned ``None``.

    Returns:
        A numpy scalar:
          * q / ate:      mean(q1 - q0)
          * q / att:      sum(treatment * (q1 - q0)) / sum(treatment)
          * plugin / ate: mean(prop_score * (q1 - q0))
          * plugin / att: mean(prop_score * (q1 - q0)) / mean(treatment)

    Raises:
        ValueError: If ``effect`` or ``estimation`` is not a supported value.
    """
    # Normalise case so that the declared defaults actually select a branch.
    effect = effect.lower()
    estimation = estimation.lower()
    if effect not in ('att', 'ate'):
        raise ValueError(f'Wrong effect: {effect}...')
    if estimation not in ('q', 'plugin'):
        raise ValueError(f'Wrong estimation: {estimation}...')

    def _flat(tensor):
        # reshape(-1) instead of squeeze(): a batch holding a single sample would
        # squeeze to a 0-d array, which np.concatenate refuses to join.
        return tensor.cpu().data.numpy().reshape(-1)

    real_treatment = []
    prop_scores, Q1, Q0 = [], [], []

    causal_bert.eval()
    try:
        with torch.no_grad():
            for tokens, treatment, _response in data_loader:
                real_treatment.append(_flat(treatment))

                prop_score, q1, q0 = causal_bert(tokens)
                prop_scores.append(_flat(prop_score))
                Q1.append(_flat(q1))
                Q0.append(_flat(q0))
    finally:
        # Hand the model back in training mode, as callers expect.
        causal_bert.train()

    real_treatment = np.concatenate(real_treatment, axis=0)
    Q1 = np.concatenate(Q1, axis=0)
    Q0 = np.concatenate(Q0, axis=0)
    prop_scores = np.concatenate(prop_scores, axis=0)

    if estimation == 'q':
        if effect == 'att':
            phi = real_treatment * (Q1 - Q0)
            return phi.sum() / real_treatment.sum()
        return (Q1 - Q0).mean()

    # estimation == 'plugin': weight the contrast by the predicted propensity.
    phi = (prop_scores * (Q1 - Q0)).mean()
    if effect == 'att':
        return phi / real_treatment.mean()
    return phi

# Prepare data

In [10]:
# Vocabulary for the merged, uni-diagnosis data, wrapped in the word-level tokenizer.
vocab = create_vocab(uni_diag=True, merged=True)
tokenizer = WordLevelBertTokenizer(vocab)

# Settings passed unchanged to CausalBertDataset when the splits are built.
# NOTE(review): their exact meaning is defined by the dataset class, not here.
alpha, beta, c, i = 0.25, 10., 0.2, 0

In [5]:
# Build the training split (group 0 only) and report how long loading took.
start = time.time()
trainset = CausalBertDataset(
    tokenizer=tokenizer,
    data_type='merged',
    is_unidiag=True,
    alpha=alpha,
    beta=beta,
    c=c,
    i=i,
    group=[0],
    max_length=512,
    min_length=10,
    truncate_method='first',
    device=device,
)

print(f'Load training set in {(time.time() - start):.2f} sec')

Load training set in 210.39 sec


In [11]:
# Build the held-out test split (group 9) with the same settings as the
# training split, and report how long loading took.
start = time.time()
testset = CausalBertDataset(tokenizer=tokenizer, data_type='merged', is_unidiag=True,
                            alpha=alpha, beta=beta, c=c, i=i,
                            group=[9], max_length=512, min_length=10,
                            truncate_method='first', device=device)

# The message previously said "training set" although this cell loads the test set.
print(f'Load testing set in {(time.time() - start):.2f} sec')

Load training set in 71.29 sec


In [7]:
# Both splits are served in shuffled mini-batches of 20; an incomplete final
# batch is dropped.
train_loader, test_loader = (
    DataLoader(split, batch_size=20, drop_last=True, shuffle=True)
    for split in (trainset, testset)
)

In [13]:
# Rebuild the test loader with mini-batches of 16.
test_loader = DataLoader(testset, shuffle=True, drop_last=True, batch_size=16)

# Effect / estimator selection, lower-cased to match the branches in the
# effect functions.
effect = 'ate'.lower()
estimation = 'Q'.lower()

# Reference effect computed from the outcomes stored on the test set.
real_effect = true_casual_effect(test_loader, effect, estimation)
print(real_effect)

# Naive comparison: mean response of treated minus mean response of untreated.
treated_mean = testset.response[testset.treatment == 1].mean()
untreated_mean = testset.response[testset.treatment == 0].mean()
print(f'unadjusted: {(treated_mean - untreated_mean).item():.4f}')

0.06045784
unadjusted: 0.1481


# Create and train a Causal-Bert

In [8]:
# Pretrained MLM checkpoint used to initialise the encoder.
# Alternative checkpoint kept for reference:
#   '/nfs/turbo/lsa-regier/bert-results/results/behrt/MLM/merged/unidiag/checkpoint-4574003/'
# NOTE(review): absolute, user-specific path — will not resolve on other machines.
trained_bert = '/home/liutianc/emr/bert/results/behrt/MLM/merged/unidiag/checkpoint-6018425/'

bert = BertModel.from_pretrained(trained_bert).to(device)
causal_bert = CausalBert(bert, hidden_size=64, learnable_docu_embed=False).to(device)

# Training length: the scheduler needs the total number of optimizer steps.
epoch = 4
epoch_iter = len(train_loader)
total_steps = epoch_iter * epoch

# lr: args.learning_rate default is 5e-5, this notebook uses 2e-5.
# eps: args.adam_epsilon default is 1e-8.
# (A plain torch.optim.Adam with lr=5e-5 was tried as an alternative.)
optimizer = AdamW(causal_bert.parameters(), lr=2e-5, eps=1e-8)

# Linear decay with no warm-up (num_warmup_steps=0 is the run_glue.py default).
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Separate binary cross-entropy criteria for the outcome heads and the
# propensity head.
q_loss, prop_score_loss = nn.BCELoss(), nn.BCELoss()

# Please specify the effect and estimation we want to use here.
effect = 'ate'.lower()
estimation = 'Q'.lower()
assert effect in ('att', 'ate'), f'Wrong effect: {effect}...'
assert estimation in ('q', 'plugin'), f'Wrong estimation: {estimation}...'

In [9]:
# Baseline before any fine-tuning: reference effect versus the estimate from
# the not-yet-trained causal heads. (The variable names say "att_q", but they
# hold whichever effect/estimation pair is currently selected.)
real_att_q = true_casual_effect(data_loader=test_loader, effect=effect, estimation=estimation)
est_att_q = est_casual_effect(data_loader=test_loader, causal_bert=causal_bert,
                              effect=effect, estimation=estimation)

print(
    f'Real: [effect: {effect}], [estimation: {estimation}], [value: {real_att_q:.5f}]',
    f'Estimated: [effect: {effect}], [estimation: {estimation}], [value: {est_att_q:.5f}]',
    sep='\n',
)

Real: [effect: ate], [estimation: q], [value: 0.06272]
Estimated: [effect: ate], [estimation: q], [value: 0.65282]


In [10]:
# Fine-tune Causal-BERT for `epoch` passes over the training loader.
#
# Each step combines three losses:
#   * g_loss  - propensity head against the observed treatment (all samples);
#   * q1_loss - treated-outcome head, only on samples with treatment == 1;
#   * q0_loss - control-outcome head, only on samples with treatment == 0.
# The losses are summed and back-propagated once. This accumulates the same
# gradients as calling backward() on each loss separately, but traverses the
# graph a single time and needs no retain_graph=True.
for e in range(1, epoch + 1):
    causal_bert.train()
    start = time.time()

    # Running per-step losses since the last progress report.
    rg_loss, rq1_loss, rq0_loss = [], [], []
    for idx, (tokens, treatment, response) in enumerate(train_loader):
        optimizer.zero_grad()
        prop_score, q1, q0 = causal_bert(tokens)

        g_loss = prop_score_loss(prop_score, treatment)
        rg_loss.append(g_loss.item())
        total_loss = g_loss

        # An outcome head is only trained when the batch contains samples
        # from its treatment arm.
        treated = treatment == 1
        if bool(treated.any()):
            q1_loss = q_loss(q1[treated], response[treated])
            rq1_loss.append(q1_loss.item())
            total_loss = total_loss + q1_loss

        control = treatment == 0
        if bool(control.any()):
            q0_loss = q_loss(q0[control], response[control])
            rq0_loss.append(q0_loss.item())
            total_loss = total_loss + q0_loss

        total_loss.backward()

        # Clip the global gradient norm to 1.0 before the update.
        clip_grad_norm_(causal_bert.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        # Progress report every 2000 steps (skipping step 0), then reset the
        # running losses.
        if not idx % 2000 and idx > 0:
            rg_loss = np.array(rg_loss)
            rq1_loss = np.array(rq1_loss)
            rq0_loss = np.array(rq0_loss)

            print(f'''epoch: {e}/{epoch}, iteration: {idx + 1}/{epoch_iter}, time: {(time.time() - start):.2f} sec, 
                  g_loss: {(rg_loss.mean()) :.5f}, q1_loss: {(rq1_loss.mean()):.5f}/{rq1_loss.shape[0]}, q0_loss: {(rq0_loss.mean()) :.5f}/{rq0_loss.shape[0]}, 
                  ''')
            rg_loss, rq1_loss, rq0_loss = [], [], []

    # Index of the last batch processed in this epoch.
    run_idx = idx

    # Evaluation. est_casual_effect leaves the model in train mode afterwards.
    test_est_att_q = est_casual_effect(test_loader, causal_bert, effect, estimation)

    print(f'''epoch: {e} / {epoch}, 
          time cost: {(time.time() - start):.2f} sec, 
          Testing set effect: [effect: {effect}], [estimation: {estimation}], [value: {test_est_att_q:.5f}]''')
    start = time.time()

print('Finish training...')

# With only 3 groups to train.

epoch: 1/4, iteration: 2001/36073, time: 507.66 sec, 
                  g_loss: 28.56193, q1_loss: 3.53932/1978, q0_loss: 2.54718/2001, 
                  
epoch: 1/4, iteration: 4001/36073, time: 1025.19 sec, 
                  g_loss: 23.72188, q1_loss: 1.19654/1971, q0_loss: 0.90608/2000, 
                  
epoch: 1/4, iteration: 6001/36073, time: 1534.32 sec, 
                  g_loss: 23.75000, q1_loss: 0.96711/1971, q0_loss: 0.84969/2000, 
                  
epoch: 1/4, iteration: 8001/36073, time: 2043.89 sec, 
                  g_loss: 23.83750, q1_loss: 0.91957/1980, q0_loss: 0.79217/2000, 
                  
epoch: 1/4, iteration: 10001/36073, time: 2550.37 sec, 
                  g_loss: 23.71250, q1_loss: 0.86243/1971, q0_loss: 0.79140/2000, 
                  
epoch: 1/4, iteration: 12001/36073, time: 3056.05 sec, 
                  g_loss: 23.29688, q1_loss: 0.78312/1965, q0_loss: 0.78666/2000, 
                  
epoch: 1/4, iteration: 14001/36073, time: 3561.93 sec, 
 