In [1]:
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from collections import defaultdict

from sklearn.model_selection import KFold

from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler,SequentialSampler

from transformers import DistilBertTokenizer
from transformers import DistilBertModel, DistilBertPreTrainedModel
from transformers import get_linear_schedule_with_warmup

from torch.nn import CrossEntropyLoss

import torch
import torch.nn as nn
from torch.optim import AdamW

import numpy as np
from scipy.special import softmax
from scipy.special import logit
from sklearn.linear_model import LogisticRegression 

from tqdm import tqdm
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
CUDA = (torch.cuda.device_count() > 0)
MASK_IDX = 103

In [3]:
def platt_scale(outcome,probs):
    logits = logit(probs)
    logits = logits.reshape(-1,1)
    log_reg = LogisticRegression(penalty='none', warm_start = True, solver = 'lbfgs' )
    log_reg.fit(logits, outcome)
    return log_reg.predict_proba(logits)

def gelu(x):
    return 0.5 * x * (1.0 + torch.erf(x/math.sqrt(2.0)))


In [24]:
def make_bow_vector(ids, vocab_size, use_counts = False):
    vec = torch.zeros(ids.shape[0],vocab_size)
    ones = torch.ones_like(ids,dtype = torch.float)
    if CUDA:
        vec = vec.cuda()
        ones = ones.cuda()
        ids = ids.cuda()
    vec.scatter_add_(1, ids,ones)
    vec[:,1] = 0.0
    if not use_counts:
        vec = (vec != 0).float()
    return vec

In [25]:
class CausalBert(DistilBertPreTrainedModel):
    """The model itself."""
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vocab_size = config.vocab_size

        self.distilbert = DistilBertModel(config)
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.Q_cls = nn.ModuleDict()

        for T in range(2):
            # ModuleDict keys have to be strings..
            self.Q_cls['%d' % T] = nn.Sequential(
                nn.Linear(config.hidden_size + self.num_labels, 200),
                nn.ReLU(),
                nn.Linear(200, self.num_labels))

        self.g_cls = nn.Linear(config.hidden_size + self.num_labels, 
            self.config.num_labels)

        self.init_weights()

    def forward(self, W_ids, W_len, W_mask, C, T, Y=None, use_mlm=True):
        if use_mlm:
            W_len = W_len.unsqueeze(1) - 2 # -2 because of the +1 below
            mask_class = torch.cuda.FloatTensor if CUDA else torch.FloatTensor
            mask = (mask_class(W_len.shape).uniform_() * W_len.float()).long() + 1 # + 1 to avoid CLS
            target_words = torch.gather(W_ids, 1, mask)
            mlm_labels = torch.ones(W_ids.shape).long() * -100
            if CUDA:
                mlm_labels = mlm_labels.cuda()
            mlm_labels.scatter_(1, mask, target_words)
            W_ids.scatter_(1, mask, MASK_IDX)

        outputs = self.distilbert(W_ids, attention_mask=W_mask)
        seq_output = outputs[0]
        pooled_output = seq_output[:, 0]
        # seq_output, pooled_output = outputs[:2]
        # pooled_output = self.dropout(pooled_output)

        if use_mlm:
            prediction_logits = self.vocab_transform(seq_output)  # (bs, seq_length, dim)
            prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
            prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
            prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)
            mlm_loss = CrossEntropyLoss()(
                prediction_logits.view(-1, self.vocab_size), mlm_labels.view(-1))
        else:
            mlm_loss = 0.0
    
        C_bow = make_bow_vector(C.unsqueeze(1), self.num_labels)
        inputs = torch.cat((pooled_output, C_bow), 1)

        print(inputs.shape)    
        # g logits
        g = self.g_cls(inputs)
        
        if Y is not None:  # TODO train/test mode, this is a lil hacky
            g_loss = CrossEntropyLoss()(g.view(-1, self.num_labels), T.view(-1))
        else:
            g_loss = 0.0

        # conditional expected outcome logits: 
        # run each example through its corresponding T matrix
        # TODO this would be cleaner with sigmoid and BCELoss, but less general 
        #   (and I couldn't get it to work as well)
        Q_logits_T0 = self.Q_cls['0'](inputs)
        Q_logits_T1 = self.Q_cls['1'](inputs)

        if Y is not None:
            T0_indices = (T == 0).nonzero().squeeze()
            Y_T1_labels = Y.clone().scatter(0, T0_indices, -100)

            T1_indices = (T == 1).nonzero().squeeze()
            Y_T0_labels = Y.clone().scatter(0, T1_indices, -100)

            Q_loss_T1 = CrossEntropyLoss()(
                Q_logits_T1.view(-1, self.num_labels), Y_T1_labels)
            Q_loss_T0 = CrossEntropyLoss()(
                Q_logits_T0.view(-1, self.num_labels), Y_T0_labels)

            Q_loss = Q_loss_T0 + Q_loss_T1
        else:
            Q_loss = 0.0

        sm = nn.Softmax(dim=1)
        Q0 = sm(Q_logits_T0)[:, 1]
        Q1 = sm(Q_logits_T1)[:, 1]
        g = sm(g)[:, 1]

        return g, Q0, Q1, g_loss, Q_loss, mlm_loss

In [26]:
class CausalBertLightningModule(pl.LightningModule):
    def __init__(self,g_weight = 1.0, Q_weight = 0.1, mlm_weight = 1.0, learning_rate = 2e-5,total_training_steps= None):
        super().__init__()
        self.save_hyperparameters()
        self.model = CausalBert.from_pretrained(
            "distilbert-base-uncased",
            num_labels = 2,
            output_attentions = False,
            output_hidden_states = False
        )

        self.loss_weights = {
            'g':g_weight,
            'Q':Q_weight,
            'mlm': mlm_weight
        }

        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()
        self.total_training_steps = total_training_steps
    
    def forward(self,W_ids, W_len, W_mask, C,T,Y = None, use_mlm = None):
        return self.model(W_ids, W_len, W_mask,  C,T,Y, use_mlm)
    
    def training_step(self,batch,batch_idx):
        W_ids,W_len,W_mask, C,T,Y = batch
        g,Q0,Q1,g_loss,Q_loss,mlm_loss = self(W_ids, W_len,W_mask, C,T,Y)
        loss = (self.loss_weights['g'] * g_loss + 
                self.loss_weights['Q'] * Q_loss + 
                self.loss_weights['mlm'] * mlm_loss)
        self.log("train_loss",loss)
        return loss
    
    def validation_step(self,batch,batch_idx):
        W_ids, W_len, W_mask, C,T,Y = batch
        g, Q0, Q1, g_loss, Q_loss, _ = self(W_ids, W_len, W_mask, C, T, Y, use_mlm=False)
        loss = self.loss_weights['g'] * g_loss + self.loss_weights['Q'] * Q_loss
        self.log("val_loss",loss)
        return loss


    def predict_step(self,batch,batch_idx):
        W_ids, W_len, W_mask, C, T, Y = batch
        g, Q0, Q1, _, _, _ = self(W_ids, W_len, W_mask, C, T, Y=None, use_mlm=False)
        return Q0, Q1, Y
    
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr = self.learning_rate, eps = 1e-8)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * self.total_training_steps),
            num_training_steps=self.total_training_steps
        )

        return [optimizer], [scheduler]

In [27]:
class CausalBertDataModule(pl.LightningDataModule):
    def __init__(self,texts, confounds, treatments, outcomes, tokenizer, batch_size = 32):
        super().__init__()
        self.texts = texts
        self.confounds = confounds
        self.treatments = treatments
        self.outcomes = outcomes
        self.tokenizer = tokenizer
        self.batch_size = batch_size
    def setup(self,stage = None):
        inputs = self.tokenizer(
            self.texts.tolist(),
            add_special_tokens = True,
            max_length = 128,
            truncation = True,
            padding = "max_length",
            return_tensors = "pt"
        )
        W_len = inputs['attention_mask'].sum(dim=1)
        dataset = TensorDataset(
            inputs["input_ids"],
            W_len,
            inputs["attention_mask"],
            torch.tensor(self.confounds.values),
            torch.tensor(self.treatments.values),
            torch.tensor(self.outcomes.values)
        )
        if stage == 'fit' or stage is None:
            # データの分割（例：80%をトレーニング、20%を検証）
            train_size = int(0.8 * len(dataset))
            val_size = len(dataset) - train_size
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
        if stage == 'predict' or stage is None:
            self.predict_dataset = dataset  # 推論用データセッ

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size = self.batch_size, shuffle = True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
    
    def predict_dataloader(self):
        return DataLoader(self.predict_dataset, batch_size=self.batch_size)


In [28]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', do_lower_case = True)



In [29]:
import pandas as pd
df = pd.read_csv("../testdata.csv")

In [30]:
data_module = CausalBertDataModule(
    texts = df['text'],
    confounds = df['C'],
    treatments = df['T'],
    outcomes = df['Y'],
    tokenizer = tokenizer,
    batch_size = 32
)

data_module.setup(stage = "fit")
train_dataset_size = len(data_module.train_dataset)
steps_per_epoch = train_dataset_size // data_module.batch_size
total_training_steps = steps_per_epoch * 3  # max_epochs = 3

model = CausalBertLightningModule(
    g_weight = 0.1,
    Q_weight = 0.1,
    mlm_weight = 1.0,
    learning_rate = 2e-5,
    total_training_steps=total_training_steps
)

trainer = Trainer(
    max_epochs = 3,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1 if torch.cuda.is_available() else None,
    enable_progress_bar=True
)
trainer.fit(model,data_module)

Epoch 0:  14%|█▍        | 66/461 [12:04<1:12:17,  0.09it/s, v_num=10]


Some weights of CausalBert were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['Q_cls.0.0.bias', 'Q_cls.0.0.weight', 'Q_cls.0.2.bias', 'Q_cls.0.2.weight', 'Q_cls.1.0.bias', 'Q_cls.1.0.weight', 'Q_cls.1.2.bias', 'Q_cls.1.2.weight', 'g_cls.bias', 'g_cls.weight', 'vocab_projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | CausalBert       | 90.7 M | eval 
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
90.7 M    Trainable params
0         Non-trainable params
90.7 M    Total params
362.949   Total estimated model params size (MB)
1

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]C tensor([0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 0, 1, 1], device='cuda:0')
torch.Size([32])
torch.Size([32, 1])
C_bow torch.Size([32, 2])
C_bow_val tensor([[1., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  8.12it/s]

/root/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


C tensor([0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0,
        1, 0, 1, 0, 1, 1, 0, 0], device='cuda:0')
torch.Size([32])
torch.Size([32, 1])
C_bow torch.Size([32, 2])
C_bow_val tensor([[1., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [1., 0.]], device='cuda:0')
                                                                           

/root/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/461 [00:00<?, ?it/s] C tensor([1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        0, 1, 1, 0, 1, 0, 1, 1], device='cuda:0')
torch.Size([32])
torch.Size([32, 1])
C_bow torch.Size([32, 2])
C_bow_val tensor([[0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
Epoch 0:   0%|          | 1/461 [00:00<01:15,  6.06it/s, v_num=12]C tensor([0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
    

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 461/461 [02:21<00:00,  3.26it/s, v_num=12]


In [31]:
# データモジュールのセットアップ
data_module.setup(stage='predict')

# 推論の実行
predictions = trainer.predict(model, datamodule=data_module)

# 予測結果の取得
Q0s = []
Q1s = []
Ys = []

for batch in predictions:
    Q0_batch, Q1_batch, Y_batch = batch
    Q0s.extend(Q0_batch.detach().cpu().numpy())
    Q1s.extend(Q1_batch.detach().cpu().numpy())
    Ys.extend(Y_batch.detach().cpu().numpy() if Y_batch is not None else [None]*len(Q0_batch))

# ATEの計算
Q0s = np.array(Q0s)
Q1s = np.array(Q1s)
ATE = np.mean(Q1s - Q0s)
print(f"Average Treatment Effect (ATE): {ATE}")

KeyboardInterrupt: 

Unnamed: 0.1,Unnamed: 0,text,T,C,Y
0,0,this is a great cd full of worship favorites!!...,1,1,0
1,1,keith green had his special comedy style of ch...,1,1,1
2,2,keith green was a true gift of god. his music ...,1,0,0
3,3,keith's music is a timeless message. since hi...,1,0,0
4,4,"if you're looking for a meditative, contemplat...",1,0,1
