In [1]:
import os
import time

import numpy as np
import torch
import pandas as pd
import random
from torch.distributions.binomial import Binomial
from torch.distributions.bernoulli import Bernoulli
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 

from transformers import DataCollatorForLanguageModeling, BertForMaskedLM
from transformers import Trainer, TrainingArguments

from tokens import WordLevelBertTokenizer
from vocab import create_vocab
from data import CausalBertDataset, MLMDataset
from causal_bert import CausalBert

In [2]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def real_phi_q(dataset):
    
    Q1 = dataset.treatment * trainset.response + (1 - dataset.treatment) * dataset.pesudo_response
    Q1 = Q1.cpu().data.numpy().squeeze()

    Q0 = dataset.treatment * dataset.pesudo_response + (1 - dataset.treatment) * dataset.response
    Q0 = Q0.cpu().data.numpy().squeeze()

    real_treatment = dataset.treatment.cpu().data.numpy().squeeze()
    prop_scores = dataset.prop_score.cpu().data.numpy().squeeze()
    phi_q = (real_treatment * (Q1 - Q0)).sum() / real_treatment.sum()
    
    return phi_q

# Prepare data

In [4]:
vocab = create_vocab(merged=True, uni_diag=True)
tokenizer = WordLevelBertTokenizer(vocab)

In [5]:
trainset = CausalBertDataset(tokenizer=tokenizer, data_type='merged', is_unidiag=True,
                            group=list(range(3)), max_length=512, min_length=10,
                            truncate_method='first', device=device)

In [6]:
testset = CausalBertDataset(tokenizer=tokenizer, data_type='merged', is_unidiag=True,
                            group=[9], max_length=512, min_length=10,
                            truncate_method='first', device=device)

In [7]:
train_loader = DataLoader(trainset, batch_size=2048, drop_last=True, shuffle=True)
test_loader = DataLoader(testset, batch_size=1024, drop_last=True, shuffle=True)

# Create and train a Causal-Bert

In [8]:
trained_bert = '/home/liutianc/emr/bert/results/behrt/MLM/merged/unidiag/checkpoint-4574003/'
model = BertForMaskedLM.from_pretrained(trained_bert)
token_embed = model.get_input_embeddings()

causalBert = CausalBert(token_embed).to(device)
optimizer = torch.optim.Adam(causalBert.parameters(), lr=5e-4)

q_loss = nn.BCELoss()
prop_score_loss = nn.BCELoss()

In [9]:
real_response, real_treatment = [], []
prop_scores, Q1, Q0 = [], [], []

causalBert.eval()
for idx, (tokens, treatment, response) in enumerate(test_loader):
    real_response.append(response.cpu().data.numpy().squeeze())
    real_treatment.append(treatment.cpu().data.numpy().squeeze())
    
    prop_score, q1, q0 = causalBert(tokens)
    
    prop_scores.append(prop_score.cpu().data.numpy().squeeze())
    Q1.append(q1.cpu().data.numpy().squeeze())
    Q0.append(q0.cpu().data.numpy().squeeze())

real_response = np.concatenate(real_response, axis=0)
real_treatment = np.concatenate(real_treatment, axis=0)

Q1 = np.concatenate(Q1, axis=0)
Q0 = np.concatenate(Q0, axis=0)
prop_scores = np.concatenate(prop_scores, axis=0)
phi_q = (real_treatment * (Q1 - Q0)).sum() / real_treatment.sum()

print(f'Before train: Phi_Q: {phi_q:.5f}')

print(f'Real Phi_Q: {real_phi_q(trainset):.5f}')

Before train: Phi_Q: -0.42509


AttributeError: 'CausalBertDataset' object has no attribute 'pesudo_response'

In [None]:
epoch = 10
rs_loss, rq1_loss, rq0_loss = [0.] * 3

for e in range(1, epoch + 1):
    causalBert.train()
    start = time.time()
    for idx, (tokens, treatment, response) in enumerate(train_loader):
        optimizer.zero_grad()
        prop_score, q1, q0 = causalBert(tokens)
        
        s_loss = prop_score_loss(prop_score, treatment)
        q1_loss = q_loss(q1[response==1], response[response==1])
        q0_loss = q_loss(q0[response==0], response[response==0])
        
        s_loss.backward(retain_graph=True)
        q1_loss.backward(retain_graph=True)
        q0_loss.backward()
        optimizer.step()
        
        rs_loss += s_loss.item()
        rq1_loss + q1_loss.item()
        rq0_loss += q0_loss.item()
        
    run_idx = idx

    # Evaluation.
    real_response, real_treatment = [], []
    prop_scores, Q1, Q0 = [], [], []

    causalBert.eval()
    for idx, (tokens, treatment, response) in enumerate(train_loader):
        real_response.append(response.cpu().data.numpy().squeeze())
        real_treatment.append(treatment.cpu().data.numpy().squeeze())
        prop_score, q1, q0 = causalBert(tokens)
        prop_scores.append(prop_score.cpu().data.numpy().squeeze())
        Q1.append(q1.cpu().data.numpy().squeeze())
        Q0.append(q0.cpu().data.numpy().squeeze())

    real_response = np.concatenate(real_response, axis=0)
    real_treatment = np.concatenate(real_treatment, axis=0)

    Q1 = np.concatenate(Q1, axis=0)
    Q0 = np.concatenate(Q0, axis=0)
    prop_scores = np.concatenate(prop_scores, axis=0)
    phi_q = (real_treatment * (Q1 - Q0)).sum() / real_treatment.sum()
#     print(f'''epoch: {e} / {epoch}, 
#               time cost: {(time.time() - start):.2f} sec, 
#               g_loss: {(rs_loss / (run_idx + 1)):.5f}, 
#               q1_loss: {(q1_loss / (run_idx + 1)):.5f}, 
#               q0_loss: {(q0_loss / (run_idx + 1)):.5f}, 
#               Estimated Phi_Q: {phi_q:.5f}''')
    
    print(f'''epoch: {e} / {epoch}, 
          time cost: {(time.time() - start):.2f} sec, 
          g_loss: {(rs_loss)  :.5f}, 
          q1_loss: {(q1_loss) :.5f}, 
          q0_loss: {(q0_loss) :.5f}, 
          Estimated Phi_Q: {phi_q:.5f}''')
    start = time.time()
    running_loss = 0.

print('Finish training...')