In [None]:
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from collections import defaultdict

from sklearn.model_selection import KFold

from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler,SequentialSampler

from transformers import DistilBertTokenizer
from transformers import DistilBertModel, DistilBertPreTrainedModel
from transformers import get_linear_schedule_with_warmup

from torch.nn import CrossEntropyLoss

import torch
import torch.nn as nn
from torch.optim import AdamW

import numpy as np
from scipy.special import softmax
from scipy.special import logit
from sklearn.linear_model import LogisticRegression 

from tqdm import tqdm
import math

In [None]:
# True when torch reports at least one CUDA device; code in this notebook
# checks this flag before calling .cuda() on tensors and models.
CUDA = torch.cuda.device_count() >= 1
# Integer id 103 -- presumably the tokenizer's [MASK] token id; TODO confirm
# (it is not referenced by the code visible in this section).
MASK_IDX = 103

In [None]:
def platt_scale(outcome,probs):
    """Calibrate probabilities with Platt scaling.

    Converts ``probs`` to log-odds, fits a logistic regression of ``outcome``
    on that single feature, and returns the fitted model's predicted class
    probabilities for the same inputs.

    Args:
        outcome: labels, passed as ``y`` to ``LogisticRegression.fit``.
        probs: probabilities, passed to ``scipy.special.logit``. Values of
            exactly 0 or 1 map to infinite log-odds.

    Returns:
        The output of ``predict_proba`` on the log-odds column: one row per
        input, one column per class seen in ``outcome``.
    """
    # One feature column of log-odds, shape (n, 1), as sklearn expects 2-D X.
    log_odds = logit(probs).reshape(-1, 1)
    # NOTE(review): the string 'none' for `penalty` is version-sensitive in
    # scikit-learn (newer releases expect None) -- confirm against the pinned
    # version before upgrading.
    calibrator = LogisticRegression(penalty='none', warm_start=True, solver='lbfgs')
    # fit() returns the estimator itself, so the calls can be chained.
    return calibrator.fit(log_odds, outcome).predict_proba(log_odds)

def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit, applied element-wise.

    Args:
        x: a torch tensor.

    Returns:
        A tensor equal to ``x * Phi(x)``, where ``Phi`` is the standard normal
        CDF written as ``0.5 * (1 + erf(x / sqrt(2)))``.
    """
    # Standard normal CDF evaluated at x.
    normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * normal_cdf

In [None]:
def make_confound_vector(ids, vocab_size, use_counts = False):
    """Turn a tensor of ids into a per-row bag-of-ids vector.

    Args:
        ids: index tensor whose first dimension is the batch. It is used as
            the index argument of ``scatter_add_`` along dim 1, so it is
            presumably a 2-D integer tensor -- confirm against callers.
        vocab_size: number of columns in the returned vector.
        use_counts: when True the result holds occurrence counts; when False
            (default) it holds 0/1 indicators.

    Returns:
        A float tensor of shape ``(ids.shape[0], vocab_size)``. Column 1 is
        always zero. The tensor is on the GPU when the module-level ``CUDA``
        flag is set.
    """
    bow = torch.zeros(ids.shape[0], vocab_size)
    increments = torch.ones_like(ids, dtype=torch.float)
    if CUDA:
        # All three operands of the scatter must live on the same device.
        bow, increments, ids = bow.cuda(), increments.cuda(), ids.cuda()
    # Row-wise histogram: add 1 at every column named in `ids`.
    bow.scatter_add_(1, ids, increments)
    # Id 1 is always discarded -- presumably a padding id; TODO confirm.
    bow[:, 1] = 0.0
    return bow if use_counts else (bow != 0).float()

In [None]:
import timm
import torch
from torch import nn

class ImageCausalModel(nn.Module):
    """Image backbone with a treatment head (g) and two outcome heads (Q).

    A timm backbone turns each image into a pooled feature vector. That
    vector is concatenated with an encoding of the confounder and fed to:

    * ``g_cls``      -- linear head producing treatment logits;
    * ``Q_cls['0']`` -- MLP head producing outcome logits under treatment 0;
    * ``Q_cls['1']`` -- MLP head producing outcome logits under treatment 1.

    Args:
        num_labels: width of every head's output and of the confounder
            encoding built by ``make_confound_vector``.
        pretrained_model_names: timm model name for the backbone.
    """

    def __init__(self, num_labels=2, pretrained_model_names="resnet50"):
        super().__init__()

        self.num_labels = num_labels

        # num_classes=0 makes timm drop the classification head and return
        # pooled features, which works for any architecture rather than only
        # those whose head attribute happens to be called `fc`.
        self.base_model = timm.create_model(
            pretrained_model_names, pretrained=True, num_classes=0)

        # Additional layers for causal inference.
        # NOTE(review): `classifier` is not used by forward(); it is kept so
        # the module's attributes / state_dict layout stay unchanged.
        self.classifier = nn.Linear(self.base_model.num_features, num_labels)
        self.Q_cls = nn.ModuleDict()

        # self.base_model.num_features is the feature size produced by the
        # pretrained image model; the confounder encoding adds num_labels
        # more columns.
        input_size = self.base_model.num_features + self.num_labels

        for T in range(2):
            # ModuleDict keys have to be strings.
            self.Q_cls['%d' % T] = nn.Sequential(
                nn.Linear(input_size, 200),
                nn.ReLU(),
                nn.Linear(200, self.num_labels))

        # g_cls consumes the same `inputs` tensor as the Q heads, so its input
        # width must be input_size (not input_size + num_labels).
        self.g_cls = nn.Linear(input_size, self.num_labels)

        # No explicit init call: nn.Module has no init_weights(); the new
        # layers keep PyTorch's default initialisation.

    @staticmethod
    def _masked_cross_entropy(logits, targets, mask):
        """Mean cross-entropy over rows where ``mask`` is True (0 if none)."""
        if not bool(mask.any()):
            # An empty treatment arm would otherwise give a NaN mean.
            return logits.new_zeros(())
        return CrossEntropyLoss()(logits[mask], targets[mask])

    def forward(self, images, confounders, treatment, outcome=None):
        """Run the model; compute the training loss when ``outcome`` is given.

        Args:
            images: batch of images accepted by the timm backbone.
            confounders: integer confounder ids, shape ``(batch,)`` or
                ``(batch, k)``.
            treatment: treatment labels (0 or 1) per example; only used when
                ``outcome`` is given.
            outcome: optional outcome class labels per example.

        Returns:
            With ``outcome``: ``(g, Q_logits_T0, Q_logits_T1, total_loss)``
            where ``total_loss`` is the treatment loss plus both outcome
            losses, each outcome head being scored only on the examples that
            actually received its treatment.
            Without ``outcome``: ``(sigmoid(Q_logits_T0),
            sigmoid(Q_logits_T1), None)``.

        NOTE(review): ``g`` and the inference outputs apply an element-wise
        sigmoid to logits that are trained with cross-entropy; a softmax may
        be what was intended -- confirm before treating them as probabilities.
        """
        features = self.base_model(images)

        # make_confound_vector scatters along dim 1 and therefore needs 2-D ids.
        if confounders.dim() == 1:
            confounders = confounders.unsqueeze(1)
        C = make_confound_vector(confounders, self.num_labels)
        # The helper picks its device from a global flag; align with features.
        C = C.to(features.device)

        inputs = torch.cat((features, C), dim=1)

        g_logits = self.g_cls(inputs)
        g = torch.sigmoid(g_logits)

        # Computed unconditionally: both return paths need them.
        Q_logits_T0 = self.Q_cls['0'](inputs)
        Q_logits_T1 = self.Q_cls['1'](inputs)

        if outcome is None:
            return torch.sigmoid(Q_logits_T0), torch.sigmoid(Q_logits_T1), None

        treatment = treatment.reshape(-1)
        outcome = outcome.reshape(-1)

        g_loss = CrossEntropyLoss()(g_logits, treatment)

        # Each Q head is trained only on examples observed under its arm;
        # logits and targets are filtered with the same mask so shapes agree.
        Q_loss_T0 = self._masked_cross_entropy(Q_logits_T0, outcome, treatment == 0)
        Q_loss_T1 = self._masked_cross_entropy(Q_logits_T1, outcome, treatment == 1)

        Q_loss = Q_loss_T0 + Q_loss_T1

        total_loss = g_loss + Q_loss
        return g, Q_logits_T0, Q_logits_T1, total_loss
            

In [None]:
class CausalImageModelWrapper:
    def __init__(self, g_weight=1.0, Q_weight=0.1, mlm_weight=1.0, batch_size=32):
        self.model = ImageCausalModel(num_labels=2, pretrained_model_names="resnet50")
        if CUDA:
            self.model = self.model.cuda()

        self.loss_weights = {
            'g': g_weight,
            'Q': Q_weight
        }
        self.batch_size = batch_size