In [1]:
import os
import time

import numpy as np
import torch
import pandas as pd
import random
from torch.distributions.binomial import Binomial
from torch.distributions.bernoulli import Bernoulli
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 

from transformers import DataCollatorForLanguageModeling, BertForMaskedLM
from transformers import Trainer, TrainingArguments

from tokens import WordLevelBertTokenizer
from vocab import create_vocab
from data import CausalBertDataset, MLMDataset
from causal_bert import CausalBert

In [2]:
# Restrict this process to a single GPU. With CUDA_DEVICE_ORDER=PCI_BUS_ID the
# index below follows PCI bus order. Both variables only take effect if they are
# set before CUDA is initialised in this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook relies on (DataLoader shuffling, model
# initialisation) so that Restart & Run All reproduces the same numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
def _normalise_effect_args(effect, estimation):
    """Lower-case and validate the effect / estimation selectors.

    Returns:
        `(effect, estimation)` lower-cased.

    Raises:
        ValueError: if `effect` is not 'ate'/'att' or `estimation` is not
            'q'/'plugin' (case-insensitive).
    """
    effect, estimation = effect.lower(), estimation.lower()
    if effect not in ('ate', 'att'):
        raise ValueError(f'Wrong effect: {effect}...')
    if estimation not in ('q', 'plugin'):
        raise ValueError(f'Wrong estimation: {estimation}...')
    return effect, estimation


def _to_numpy_1d(tensor):
    """Detach a tensor, move it to host memory and squeeze it to a 1-D array.

    A plain `.squeeze()` collapses a single-sample tensor to a 0-d array, which
    `np.concatenate` rejects; `np.atleast_1d` keeps it as shape (1,).
    """
    return np.atleast_1d(tensor.detach().cpu().numpy().squeeze())


def _combine_effect(Q1, Q0, treatment, prop_scores, effect, estimation):
    """Combine per-sample potential outcomes into an ATE or ATT estimate.

    Args:
        Q1, Q0: per-sample outcomes under treatment / under control.
        treatment: observed treatment per sample (0/1).
        prop_scores: per-sample propensity score; only the 'plugin' ATT uses it.
        effect, estimation: selectors already normalised by
            `_normalise_effect_args`.
    """
    ite = Q1 - Q0
    if effect == 'ate':
        # Both estimators reduce to E[Q1 - Q0] for the ATE. Weighting by the
        # propensity g(X) would give E[g(X)(Q1 - Q0)] = P(T=1) * ATT instead.
        return ite.mean()
    if estimation == 'q':
        # Average effect over the treated samples only.
        return (treatment * ite).sum() / treatment.sum()
    # Plug-in ATT: E[g(X)(Q1 - Q0)] / P(T=1).
    return (prop_scores * ite).mean() / treatment.mean()


def true_casual_effect(data_loader, effect='ate', estimation='Q'):
    """Ground-truth causal effect of the dataset behind `data_loader`.

    Both potential outcomes are rebuilt from the dataset itself: the observed
    `response` for the arm a sample actually received and the counterfactual
    `pseudo_response` for the other arm. These are combined with the dataset's
    stored `prop_score`. The whole `data_loader.dataset` is used, regardless of
    batching or `drop_last`.

    Args:
        data_loader: loader whose `.dataset` exposes `treatment`, `response`,
            `pseudo_response` and `prop_score` tensors.
        effect: 'ate' or 'att' (case-insensitive).
        estimation: 'q' or 'plugin' (case-insensitive).

    Returns:
        The effect as a numpy scalar.

    Raises:
        ValueError: on an unknown `effect` or `estimation`.
    """
    effect, estimation = _normalise_effect_args(effect, estimation)
    dataset = data_loader.dataset
    t = dataset.treatment

    Q1 = t * dataset.response + (1 - t) * dataset.pseudo_response
    Q0 = t * dataset.pseudo_response + (1 - t) * dataset.response

    return _combine_effect(
        _to_numpy_1d(Q1),
        _to_numpy_1d(Q0),
        _to_numpy_1d(t),
        _to_numpy_1d(dataset.prop_score),
        effect,
        estimation,
    )


def est_casual_effect(data_loader, causal_bert, effect='ate', estimation='Q'):
    """Causal effect estimated from CausalBert's predicted outcomes.

    Runs `causal_bert` over every batch of `data_loader` without gradient
    tracking. The predicted propensity scores and outcomes are combined with the
    *observed* treatments. The model's train/eval mode is restored afterwards.

    Args:
        data_loader: iterable of `(tokens, treatment, response)` batches.
        causal_bert: model whose call on `tokens` returns `(prop_score, q1, q0)`.
        effect: 'ate' or 'att' (case-insensitive).
        estimation: 'q' or 'plugin' (case-insensitive).

    Returns:
        The estimated effect as a numpy scalar.

    Raises:
        ValueError: on an unknown `effect`/`estimation`, or if `data_loader`
            yields no batches.
    """
    effect, estimation = _normalise_effect_args(effect, estimation)

    # `real_treatment` is the observed treatment, not the model's estimate of it.
    real_treatment, prop_scores, Q1, Q0 = [], [], [], []

    was_training = getattr(causal_bert, 'training', True)
    causal_bert.eval()
    try:
        with torch.no_grad():
            for tokens, treatment, _response in data_loader:
                prop_score, q1, q0 = causal_bert(tokens)
                real_treatment.append(_to_numpy_1d(treatment))
                prop_scores.append(_to_numpy_1d(prop_score))
                Q1.append(_to_numpy_1d(q1))
                Q0.append(_to_numpy_1d(q0))
    finally:
        causal_bert.train(was_training)

    if not Q1:
        raise ValueError('data_loader yielded no batches')

    return _combine_effect(
        np.concatenate(Q1),
        np.concatenate(Q0),
        np.concatenate(real_treatment),
        np.concatenate(prop_scores),
        effect,
        estimation,
    )

# Prepare data

In [4]:
# Build the vocabulary (merged, uni_diag variant) and a word-level BERT tokenizer over it.
vocab = create_vocab(uni_diag=True, merged=True)
tokenizer = WordLevelBertTokenizer(vocab)

In [5]:
# Training split: groups 0-8.
start = time.time()
trainset = CausalBertDataset(
    tokenizer=tokenizer,
    data_type='merged',
    is_unidiag=True,
    alpha=0.75,
    beta=1.,
    c=0.5,
    group=list(range(9)),
    max_length=512,
    min_length=10,
    truncate_method='first',
    device=device,
)
print(f'Load training set in {time.time() - start:.2f} sec')

In [6]:
# Test split: group 9, disjoint from the training groups 0-8.
start = time.time()
testset = CausalBertDataset(tokenizer=tokenizer, data_type='merged', is_unidiag=True,
                            alpha=0.75, beta=1., c=0.5,
                            group=[9], max_length=512, min_length=10,
                            truncate_method='first', device=device)

print(f'Load testing set in {(time.time() - start):.2f} sec')

In [7]:
# Training batches are shuffled and fixed-size.
train_loader = DataLoader(trainset, batch_size=2048, drop_last=True, shuffle=True)
# The test loader is only used for evaluation. It keeps every sample (no
# drop_last) in a fixed order, so the estimated effect covers the same samples
# that true_casual_effect reads from `testset` directly, and repeated
# evaluations are deterministic.
test_loader = DataLoader(testset, batch_size=1024, drop_last=False, shuffle=False)

# Create and train a CausalBert model

In [8]:
# Which effect to report and how to estimate it. Validated first so a typo fails
# before the (slow) checkpoint load.
effect = 'ate'        # 'ate' or 'att'
estimation = 'Q'      # 'q' or 'plugin'

effect = effect.lower()
estimation = estimation.lower()
assert effect in ['att', 'ate'], f'Wrong effect: {effect}...'
assert estimation in ['q', 'plugin'], f'Wrong estimation: {estimation}...'

# MLM-pretrained BERT checkpoint; its input token embeddings are handed to
# CausalBert. Set BEHRT_MLM_CHECKPOINT to point elsewhere without editing code.
trained_bert = os.environ.get(
    'BEHRT_MLM_CHECKPOINT',
    '/nfs/turbo/lsa-regier/bert-results/results/behrt/MLM/merged/unidiag/checkpoint-4574003/',
)
model = BertForMaskedLM.from_pretrained(trained_bert)
token_embed = model.get_input_embeddings()

causal_bert = CausalBert(token_embed, learnable_docu_embed=False).to(device)
optimizer = torch.optim.Adam(causal_bert.parameters(), lr=5e-4)

# BCE losses for the propensity head (target: treatment) and the outcome heads
# (target: response); both targets are presumably 0/1 — confirm in the dataset.
q_loss = nn.BCELoss()
prop_score_loss = nn.BCELoss()

In [9]:
# Baseline before any CausalBert training: ground truth vs. current estimate on the test split.
real_att_q = true_casual_effect(test_loader, effect=effect, estimation=estimation)
est_att_q = est_casual_effect(test_loader, causal_bert, effect=effect, estimation=estimation)

print(f'Real: [effect: {effect}], [estimation: {estimation}], [value: {real_att_q:.5f}]',
      f'Estimated: [effect: {effect}], [estimation: {estimation}], [value: {est_att_q:.5f}]',
      sep='\n')

Real: [effect: ate], [estimation: q], [value: 0.18407]
Estimated: [effect: ate], [estimation: q], [value: -0.91199]


In [None]:
epoch = 10  # number of passes over train_loader

for e in range(1, epoch + 1):
    causal_bert.train()
    start = time.time()
    # Per-epoch running sums of the batch losses.
    rs_loss, rq1_loss, rq0_loss = [0.] * 3
    n_batches = 0

    for tokens, treatment, response in train_loader:
        optimizer.zero_grad()
        prop_score, q1, q0 = causal_bert(tokens)

        # The propensity head is fit on every sample. Each outcome head is fit
        # only on the samples from its own treatment arm.
        s_loss = prop_score_loss(prop_score, treatment)
        total_loss = s_loss
        rs_loss += s_loss.item()

        # BCELoss over an empty selection is NaN. A head whose arm is absent
        # from the batch is skipped rather than poisoning the running loss.
        treated = treatment == 1
        if treated.any():
            q1_loss = q_loss(q1[treated], response[treated])
            total_loss = total_loss + q1_loss
            rq1_loss += q1_loss.item()

        control = treatment == 0
        if control.any():
            q0_loss = q_loss(q0[control], response[control])
            total_loss = total_loss + q0_loss
            rq0_loss += q0_loss.item()

        # One backward of the summed loss accumulates the same gradients as a
        # separate backward per term, without needing retain_graph.
        total_loss.backward()
        optimizer.step()
        n_batches += 1

    # Evaluation.
    train_est_att_q = est_casual_effect(train_loader, causal_bert, effect, estimation)
    test_est_att_q = est_casual_effect(test_loader, causal_bert, effect, estimation)

    # max(..., 1) avoids ZeroDivisionError if train_loader yielded no batches.
    denom = max(n_batches, 1)
    print(f'''epoch: {e} / {epoch},
          time cost: {(time.time() - start):.2f} sec,
          g_loss: {rs_loss / denom:.5f},
          q1_loss: {rq1_loss / denom:.5f},
          q0_loss: {rq0_loss / denom:.5f},
          Training set effect: [effect: {effect}], [estimation: {estimation}], [value: {train_est_att_q:.5f}],
          Testing set effect: [effect: {effect}], [estimation: {estimation}], [value: {test_est_att_q:.5f}]''')

print('Finish training...')

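# Note: the output below was produced when training on only 3 groups; the training-set cell above currently uses `group=list(range(9))`.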

epoch: 1 / 10, 
          time cost: 306.29 sec, 
          g_loss: 0.67316, 
          q1_loss: 0.68353, 
          q0_loss: 0.70161, 
          Training set effect: [effect: ate], [estimation: q], [value: 0.16977],
          Testing set effect: [effect: ate], [estimation: q], [value: 0.16999]
