In [1]:
from collections import defaultdict

from sklearn.model_selection import KFold

from torch.utils.data import Dataset, DataLoader, RandomSampler,SequentialSampler

from transformers import get_linear_schedule_with_warmup

from torch.nn import CrossEntropyLoss

import torch
import torch.nn as nn
from torch.optim import AdamW

from PIL import Image
import torchvision.transforms as transforms

import numpy as np
from scipy.special import softmax
from scipy.special import logit
from sklearn.linear_model import LogisticRegression 
from sklearn.metrics import accuracy_score

from tqdm import tqdm
import math

  from .autonotebook import tqdm as notebook_tqdm


### wandbの設定

In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhayatarou-ay[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
CUDA = (torch.cuda.device_count() > 0)
MASK_IDX = 103

In [4]:
def platt_scale(outcome,probs):
    logits = logit(probs)
    logits = logits.reshape(-1,1)
    log_reg = LogisticRegression(penalty='none', warm_start = True, solver = 'lbfgs' )
    log_reg.fit(logits, outcome)
    return log_reg.predict_proba(logits)

def gelu(x):
    return 0.5 * x * (1.0 + torch.erf(x/math.sqrt(2.0)))

In [5]:
def make_confound_vector(ids, vocab_size, use_counts = False):
    vec = torch.zeros(ids.shape[0],vocab_size)
    ones = torch.ones_like(ids,dtype = torch.float)
    
    if CUDA:
        vec = vec.cuda()
        ones = ones.cuda()
        ids = ids.cuda()
    vec[:,1] = 0.0
    if not use_counts:
        vec = (vec != 0).float()
    return vec.float()

In [6]:
import timm
import torch
from torch import nn

class ImageCausalModel(nn.Module):
    """The model itself."""
    def __init__(self, num_labels = 2,pretrained_model_names = "resnet50"):
        super().__init__()

        self.num_labels = num_labels

        self.base_model = timm.create_model(pretrained_model_names,pretrained = True)
        self.base_model.fc = nn.Identity()

        # 因果推論用の追加レイヤー
        self.classifier = nn.Linear(self.base_model.num_features, num_labels)
        self.Q_cls = nn.ModuleDict()

        # self.base_model.num_features は、事前学習済みの画像モデルからの特徴量サイズです。
        input_size = self.base_model.num_features + self.num_labels

        for T in range(2):
            # ModuleDict keys have to be strings..
            self.Q_cls['%d' % T] = nn.Sequential(
                nn.Linear(input_size, 200),
                nn.ReLU(),
                nn.Linear(200, self.num_labels))
        

        self.g_cls = nn.Linear(self.base_model.num_features + self.num_labels, 
            self.num_labels)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self,images, confounds, treatment=None, outcome = None):
        features = self.base_model(images)
        # print("features")
        # print("C", confounds.shape)
        # print(confounds)
        # print(confounds.unsqueeze(1).shape)
        C = make_confound_vector(confounds.unsqueeze(1), self.num_labels)
        # print("C",C.shape) 
        inputs = torch.cat((features, C), dim =  1)
        g_logits = self.g_cls(inputs)
        g_prob = torch.sigmoid(g_logits)

        Q_logits_T0 = self.Q_cls['0'](inputs)
        Q_logits_T1 = self.Q_cls['1'](inputs)

        Q_prob_T0 = torch.sigmoid(Q_logits_T0)
        Q_prob_T1 = torch.sigmoid(Q_logits_T1)
        if outcome is not None:
            return g_prob, Q_prob_T0, Q_prob_T1, g_logits, Q_logits_T0, Q_logits_T1
        else:
            return g_prob, Q_prob_T0, Q_prob_T1,
        

In [7]:

class CausalImageModelWrapper:
    def __init__(self, g_weight=1.0, Q_weight=0.1, mlm_weight=1.0, batch_size=32):
        self.model = ImageCausalModel(num_labels=2, pretrained_model_names="resnet50")
        if CUDA:
            self.model = self.model.cuda()

        self.loss_weights = {
            'g': g_weight,
            'Q': Q_weight
        }
        self.batch_size = batch_size
        self.losses = []

    def train(self,images, confounds, treatments, outcomes , learning_rate = 2e-5, epochs  = 3):

        wandb.init(project = "image_causal_project", config = {
            "learning_rate":learning_rate,
            "epochs": epochs,
            "batch_size": self.batch_size,
            "g_weight": self.loss_weights['g'],
            "Q_weight": self.loss_weights['Q']
        })

        dataloader = self.build_dataloader(images, confounds, treatments, outcomes, batch_size = self.batch_size)
        self.model.train()
        optimizer = AdamW(self.model.parameters(), lr = learning_rate, eps = 1e-8)
        total_steps = len(dataloader) * epochs
        warmup_steps = total_steps * 0.1
        scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps = warmup_steps,num_training_steps = total_steps)

        for epoch in range(epochs):
            epoch_losses = []
            g_losses = []
            Q_losses = []
            g_accuracies = []
            Q_accuracies_T0 = []
            Q_accuracies_T1 = []
            for batch in dataloader:
                if CUDA:
                    batch = tuple(x.cuda() for x in batch)
                images, confounds, treatments, outcomes = batch

                self.model.zero_grad()
                g_prob, Q_prob_T0, Q_prob_T1, g_logits, Q_logits_T0, Q_logits_T1 = self.model(images, confounds, treatments, outcomes)
                g_loss = CrossEntropyLoss()(g_logits, treatments)
                Q_loss_T0 = CrossEntropyLoss()(Q_logits_T0, outcomes)
                Q_loss_T1 = CrossEntropyLoss()(Q_logits_T1, outcomes)

                # Accuracy計算
                g_preds = torch.argmax(g_logits, dim=1)
                Q_preds_T0 = torch.argmax(Q_logits_T0, dim=1)
                Q_preds_T1 = torch.argmax(Q_logits_T1, dim=1)

                g_accuracy = accuracy_score(treatments.cpu().numpy(), g_preds.cpu().numpy())
                Q_accuracy_T0 = accuracy_score(outcomes.cpu().numpy(), Q_preds_T0.cpu().numpy())
                Q_accuracy_T1 = accuracy_score(outcomes.cpu().numpy(), Q_preds_T1.cpu().numpy())
                g_accuracies.append(g_accuracy)
                Q_accuracies_T0.append(Q_accuracy_T0)
                Q_accuracies_T1.append(Q_accuracy_T1)

                loss = self.loss_weights['g'] * g_loss + self.loss_weights['Q'] * (Q_loss_T0 + Q_loss_T1)
                loss.backward()
                optimizer.step()
                scheduler.step()
                g_losses.append(g_loss.detach().cpu().item())
                Q_losses.append((Q_loss_T0 + Q_loss_T1).detach().cpu().item())
                epoch_losses.append(loss.detach().cpu().item())
            
            avg_loss = sum(epoch_losses) / len(epoch_losses)
            avg_g_loss = sum(g_losses) / len(g_losses)
            avg_Q_loss = sum(Q_losses) / len(Q_losses)
            avg_g_accuracy = sum(g_accuracies) / len(g_accuracies)
            avg_Q_accuracy_T0 = sum(Q_accuracies_T0) / len(Q_accuracies_T0)
            avg_Q_accuracy_T1 = sum(Q_accuracies_T1) / len(Q_accuracies_T1)
            wandb.log({
                "epoch": epoch, 
                "loss":avg_loss, 
                "g_losses":avg_g_loss, 
                "Q_losses" : avg_Q_loss,
                "g_accuracy": avg_g_accuracy,
                "Q_accuracy_T0": avg_Q_accuracy_T0,
                "Q_accuracy_T1": avg_Q_accuracy_T1
                })
        wandb.watch(self.model)
        wandb.finish()            
        return self.model
    
    def inference(self, images, confounds, outcome = None):
        self.model.eval()
        dataloader = self.build_dataloader(images, confounds,outcomes = outcome,
                                           sampler = 'sequential')
        Q0s = []
        Q1s = []
        Ys = []
        for i, batch in tqdm(enumerate(dataloader),total = len(dataloader)):
            if CUDA: 
                batch = (x.cuda() for x in batch)
            images, confounds, _ ,outcomes = batch
            g_prob, Q0, Q1 ,_,_,_= self.model(images, confounds, outcome = outcomes)

            
            Q0s += Q0.detach().cpu().numpy().tolist()
            Q1s += Q1.detach().cpu().numpy().tolist()
            Ys += outcomes.detach().cpu().numpy().tolist()

            ## [todo] inferenceメソッドの形式?
        probs = np.array(list(zip(Q0s, Q1s)))
        preds = np.argmax(probs, axis = 1)  

        return probs, preds, Ys
    
    def ATE(self,C,image, Y = None, platt_scaling = False):
        ## [todo] ATEの計算方法
        Q_probs,_,Ys = self.inference(image,C,outcome = Y)
        if platt_scaling and Y is not None:
            Q0 = platt_scale(Ys, Q_probs[:,0])[:,0]
            Q1 = platt_scale(Ys, Q_probs[:,1])[:,1]
        else:
            Q0 = Q_probs[:,0]
            Q1 = Q_probs[:,1]
        
        print("Q0:", Q0, "Q1:", Q1)
        return np.mean(Q0 - Q1)

    def build_dataloader(self,image_paths, confounds, treatments = None, outcomes = None,batch_size = 32,sampler = "random"):
        dataset = CausalImageDataset(image_paths, confounds, treatments, outcomes)
        sampler = RandomSampler(dataset) if sampler == "random" else SequentialSampler(dataset)
        dataloader = DataLoader(dataset, batch_size = batch_size,sampler = sampler)
        return dataloader
    


In [8]:
class CausalImageDataset(Dataset):
    def __init__(self,image_paths, confounds, treatments = None, outcomes = None,transform = None):
        self.image_paths = image_paths
        self.confounds = confounds
        self.treatments = treatments
        self.outcomes = outcomes

        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((224,224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406] , std = [0.229, 0.224, 0.225])
                ]
            )
        else:
            self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self,idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        
        confounds = self.confounds[idx]
        treatment = self.treatments[idx] if self.treatments is not None else -1
        outcome = self.outcomes[idx] if self.outcomes is not None else -1
        return image , confounds, treatment, outcome

    

In [9]:
import sys 
sys.path.append("../../")
import pandas as pd
df = pd.read_csv("../input/outputs_v4.csv")
df.head()

Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,name,main_category,sub_category,image,link,ratings,no_of_ratings,...,embedding_path,embedding,price_ave,output,output_2v,brightness,light_or_dark,outcome,y0,y1
0,0,0,72,AmazonBasics High Speed 55 Watt Oscillating Pe...,appliances,All Appliances,https://m.media-amazon.com/images/I/71QfUcEOg8...,https://www.amazon.in/AmazonBasics-400mm-Pedes...,4.1,6113,...,/root/graduation_thetis/causal-bert-pytorch/in...,29.492722,1,1863.892722,1,191.928625,light,1,0.768525,0.645656
1,1,1,73,Farberware Mini Blender Fruit Mixer Machine Po...,appliances,All Appliances,https://m.media-amazon.com/images/I/716mmFt0PG...,https://www.amazon.in/Farberware-Portable-Elec...,2.9,6071,...,/root/graduation_thetis/causal-bert-pytorch/in...,68.038506,0,1889.338506,1,199.540718,light,0,0.768525,0.645656
2,2,2,74,PHILIPS Handheld Garment Steamer STH3000/20 - ...,appliances,All Appliances,https://m.media-amazon.com/images/I/71W2XPQdBq...,https://www.amazon.in/PHILIPS-Handheld-Garment...,4.0,1553,...,/root/graduation_thetis/causal-bert-pytorch/in...,43.871647,1,510.271647,1,199.71166,light,1,0.768525,0.645656
3,3,3,75,"Cookwell Bullet Mixer Grinder (5 Jars, 3 Blade...",appliances,All Appliances,https://m.media-amazon.com/images/I/81yobRRV8n...,https://www.amazon.in/Cookwell-Bullet-Mixer-Gr...,4.1,9592,...,/root/graduation_thetis/causal-bert-pytorch/in...,45.319656,1,2923.419656,1,206.265611,light,1,0.768525,0.645656
4,4,4,76,"Bajaj ATX 4 750-Watt Pop-up Toaster, 2-Slice A...",appliances,All Appliances,https://m.media-amazon.com/images/I/51D5T7TGVb...,https://www.amazon.in/Bajaj-ATX-750-Watt-Pop-u...,4.3,9520,...,/root/graduation_thetis/causal-bert-pytorch/in...,51.195602,0,2907.195602,1,214.465127,light,1,0.768525,0.645656


In [10]:
df["light_or_dark"] = df["light_or_dark"].apply(lambda x : 1 if x == "light" else 0)

In [11]:
ci = CausalImageModelWrapper(batch_size = 32, g_weight=0.1, Q_weight=0.1)
ci.train(df["img_path"],df["light_or_dark"], df["price_ave"], df["outcome"],epochs = 32)


0,1
Q_accuracy_T0,▁▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇██████████
Q_accuracy_T1,▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▅▅▆▆▇▇▇▇▇▇▇▇█████
Q_losses,█▇▇▆▆▆▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███
g_accuracy,▁▃▅▅▅▅▆▆▇▇▇▇▇▇▇█████████████████
g_losses,█▇▆▆▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
Q_accuracy_T0,0.86789
Q_accuracy_T1,0.86809
Q_losses,0.61387
epoch,31.0
g_accuracy,0.84997
g_losses,0.36198
loss,0.09758


ImageCausalModel(
  (base_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act2): ReLU(inplace=True)
        (aa): Identity()
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps

In [12]:
preds, probs , Ys = ci.inference(df["img_path"], df["light_or_dark"],df["outcome"])

100%|██████████| 160/160 [00:30<00:00,  5.30it/s]


array([0, 0])

In [32]:
preds

array([[[0.13765283, 0.8347199 ],
        [0.10919077, 0.81690699]],

       [[0.45538336, 0.40710497],
        [0.41515949, 0.34420407]],

       [[0.24416827, 0.74157315],
        [0.20150211, 0.68436348]],

       ...,

       [[0.68406391, 0.38542143],
        [0.58817196, 0.31646666]],

       [[0.19861032, 0.76638579],
        [0.24582973, 0.81811768]],

       [[0.07397657, 0.92001796],
        [0.07027964, 0.90107787]]])

In [49]:
Q0 = preds[:,0]
Q0

array([[0.13765283, 0.8347199 ],
       [0.45538336, 0.40710497],
       [0.24416827, 0.74157315],
       ...,
       [0.68406391, 0.38542143],
       [0.19861032, 0.76638579],
       [0.07397657, 0.92001796]])

In [48]:
Q1 = preds[:,1]
Q1

array([[0.10919077, 0.81690699],
       [0.41515949, 0.34420407],
       [0.20150211, 0.68436348],
       ...,
       [0.58817196, 0.31646666],
       [0.24582973, 0.81811768],
       [0.07027964, 0.90107787]])

In [52]:
np.mean((Q0-Q1)[:,1])

0.028261613450628092

In [54]:
np.mean((Q0-Q1)[:,0])

0.02358581546500979

In [37]:
Q1 = preds[:,1]
Q11 = Q1[:,1]
Q11

array([0.81690699, 0.34420407, 0.68436348, ..., 0.31646666, 0.81811768,
       0.90107787])

In [27]:
Q1-Q0

array([[-0.02846206, -0.01781291],
       [-0.04022387, -0.0629009 ],
       [-0.04266615, -0.05720967],
       ...,
       [-0.09589195, -0.06895477],
       [ 0.04721941,  0.05173188],
       [-0.00369693, -0.01894009]])

In [19]:
print(preds.shape, probs.shape, len(Ys))

(5111, 2, 2) (5111, 2) 5111


In [19]:
print(ci.ATE(df["light_or_dark"], df["img_path"],df["outcome"], platt_scaling = False))

100%|██████████| 160/160 [00:43<00:00,  3.66it/s]


ValueError: Classification metrics can't handle a mix of binary and multilabel-indicator targets

In [15]:
from collections import defaultdict
import numpy as np
def ATE_unadjusted(T, Y):
    x = defaultdict(list)
    for t, y in zip(T, Y):
        x[t].append(y)
    T0 = np.mean(x[0])
    T1 = np.mean(x[1])
    return T0 - T1

def ATE_adjusted(C, T, Y):
    x = defaultdict(list)
    for c, t, y in zip(C, T, Y):
        x[c, t].append(y)

    C0_ATE = np.mean(x[0,0]) - np.mean(x[0,1])
    C1_ATE = np.mean(x[1,0]) - np.mean(x[1,1])
    print(C0_ATE, C1_ATE)
    print(x)
    return np.mean([C0_ATE, C1_ATE])


print("ATE_unadjusted: ", ATE_unadjusted(df["price_ave"], df["outcome"]))
print("ATE_adjusted: ", ATE_adjusted(df["light_or_dark"], df["price_ave"],df["outcome"]))

ATE_unadjusted:  0.10267981830743678
0.07905631183025796 0.11232609131180016
defaultdict(<class 'list'>, {(1, 1): [1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 

In [15]:
ci.inference(df["img_path"], df["light_or_dark"])

100%|██████████| 160/160 [00:56<00:00,  2.82it/s]


(array([[[0.21575224, 0.64928162],
         [0.35140994, 0.78701228]],
 
        [[0.54956341, 0.54114825],
         [0.55252564, 0.52543193]],
 
        [[0.31131324, 0.57314974],
         [0.45010456, 0.61553663]],
 
        ...,
 
        [[0.76871824, 0.30006468],
         [0.75267065, 0.24299857]],
 
        [[0.25023708, 0.70237416],
         [0.3343994 , 0.80563241]],
 
        [[0.05974026, 0.8897146 ],
         [0.15516284, 0.95918369]]]),
 array([[1, 1],
        [1, 0],
        [1, 1],
        ...,
        [0, 0],
        [1, 1],
        [1, 1]]),
 [-1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1,
  -1

In [16]:
preds, probs,ys = ci.inference(df["img_path"], df["light_or_dark"])

100%|██████████| 160/160 [00:57<00:00,  2.76it/s]


In [31]:
preds[:,0].mean()
preds[:,1].mean()

0.5407321434582374

In [23]:
len(ys) == -sum(ys)

True