In [1]:
import os
import time

import numpy as np
import torch
import pandas as pd
import random
from torch.distributions.binomial import Binomial
from torch.distributions.bernoulli import Bernoulli
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 

from transformers import DataCollatorForLanguageModeling, BertForMaskedLM
from transformers import Trainer, TrainingArguments

from tokens import WordLevelBertTokenizer
from vocab import create_vocab
from data import CausalBertDataset, MLMDataset
from causal_bert import CausalBert

In [2]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '4'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare data

In [3]:
vocab = create_vocab(merged=True, uni_diag=True)
tokenizer = WordLevelBertTokenizer(vocab)

In [4]:
dataset = CausalBertDataset(tokenizer=tokenizer, data_type='merged', is_unidiag=True,
                            group=[9], max_length=512, min_length=10,
                            truncate_method='first', device=device)

train_loader = DataLoader(dataset, batch_size=64, drop_last=True, shuffle=True)

# Create Causal-Bert

In [13]:
trained_bert = '/home/liutianc/emr/bert/results/behrt/MLM/merged/unidiag/checkpoint-2166633/'
model = BertForMaskedLM.from_pretrained(trained_bert)
token_embed = model.get_input_embeddings()

causalBert = CausalBert(token_embed).to(device)
optimizer = torch.optim.Adam(causalBert.parameters(), lr=5e-4)

q_loss = nn.BCELoss()
prop_score_loss = nn.BCELoss()

epoch = 10
running_loss = 0.
for e in range(1, epoch + 1):
    for idx, (tokens, treatment, response) in enumerate(train_loader):
        optimizer.zero_grad()
        prop_score, q1, q0 = causalBert(tokens)
        s_loss = prop_score_loss(prop_score, treatment)
        q1_loss = q_loss(q1[response==1], response[response==1])
        q0_loss = q_loss(q0[response==0], response[response==0])
        s_loss.backward(retain_graph=True)
        q1_loss.backward(retain_graph=True)
        q0_loss.backward()
        optimizer.step()
        
        running_loss += s_loss.item() + q1_loss.item() + q0_loss.item()

    print(f'epoch: {e} / {epoch}, running loss: {round(running_loss / (idx + 1), 5)}.')
    running_loss = 0.
    

epoch: 1 / 10, running loss: 0.48283.
epoch: 2 / 10, running loss: 0.41434.
epoch: 3 / 10, running loss: 0.40337.
epoch: 4 / 10, running loss: 0.39869.
epoch: 5 / 10, running loss: 0.39505.
epoch: 6 / 10, running loss: 0.39134.
epoch: 7 / 10, running loss: 0.38911.
epoch: 8 / 10, running loss: 0.38652.
epoch: 9 / 10, running loss: 0.38459.
epoch: 10 / 10, running loss: 0.38243.
