In [5]:
from collections import defaultdict

from sklearn.model_selection import KFold

from torch.utils.data import Dataset, DataLoader, RandomSampler,SequentialSampler

from transformers import get_linear_schedule_with_warmup

from torch.nn import CrossEntropyLoss

import torch
import torch.nn as nn
from torch.optim import AdamW

from PIL import Image
import torchvision.transforms as transforms

import numpy as np
from scipy.special import softmax
from scipy.special import logit
from sklearn.linear_model import LogisticRegression 
from sklearn.metrics import accuracy_score

from tqdm import tqdm
import math

In [6]:
import random
import numpy as np
import torch

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


### wandbの設定

In [28]:
import wandb
wandb.login()

True

In [29]:
CUDA = (torch.cuda.device_count() > 0)
MASK_IDX = 103

In [30]:
def platt_scale(outcome,probs):
    logits = logit(probs)
    logits = logits.reshape(-1,1)
    log_reg = LogisticRegression(penalty='none', warm_start = True, solver = 'lbfgs' )
    log_reg.fit(logits, outcome)
    return log_reg.predict_proba(logits)

def gelu(x):
    return 0.5 * x * (1.0 + torch.erf(x/math.sqrt(2.0)))

In [31]:
def make_confound_vector(ids, vocab_size, use_counts = False):
    vec = torch.zeros(ids.shape[0],vocab_size)
    ones = torch.ones_like(ids,dtype = torch.float)
    
    if CUDA:
        vec = vec.cuda()
        ones = ones.cuda()
        ids = ids.cuda()
    vec[:,1] = 0.0
    if not use_counts:
        vec = (vec != 0).float()
    return vec.float()

In [32]:
import timm
import torch
from torch import nn

class ImageCausalModel(nn.Module):
    """The model itself."""
    def __init__(self, num_labels = 2,pretrained_model_names = "timm/eva02_tiny_patch14_224.mim_in22k"):
        super().__init__()

        self.num_labels = num_labels

        self.base_model = timm.create_model(pretrained_model_names,pretrained = True,num_classes = 0)
        self.base_model.fc = nn.Identity()

        # 因果推論用の追加レイヤー ここいらない説
        #self.classifier = nn.Linear(self.base_model.num_features, num_labels)
        self.Q_cls = nn.ModuleDict()

        # self.base_model.num_features は、事前学習済みの画像モデルからの特徴量サイズです。
        input_size = self.base_model.num_features + self.num_labels
        print(self.base_model.num_features)
        print(input_size)
        for T in range(2):
            # ModuleDict keys have to be strings..
            self.Q_cls['%d' % T] = nn.Sequential(
                nn.Linear(input_size, 200),
                nn.ReLU(),
                nn.Linear(200, self.num_labels))
        

        self.g_cls = nn.Linear(self.base_model.num_features + self.num_labels, 
            self.num_labels)

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self,images, confounds, treatment=None, outcome = None):
        features = self.base_model(images)
        # print("features")
        # print("C", confounds.shape)
        # print(confounds)
        # print(confounds.unsqueeze(1).shape)
        C = make_confound_vector(confounds.unsqueeze(1), self.num_labels)
        inputs = torch.cat((features, C), dim =  1)
        g = self.g_cls(inputs)

        if outcome is not None:
            g_loss = CrossEntropyLoss()(g.view(-1, self.num_labels),treatment.view(-1))
        else:
            g_loss = 0.0

        Q_logits_T0 = self.Q_cls['0'](inputs)
        Q_logits_T1 = self.Q_cls['1'](inputs)

        #[todo] ここ元論文と実装が異なってたなぜ?
        #Q_prob_T0 = torch.sigmoid(Q_logits_T0)
        #Q_prob_T1 = torch.sigmoid(Q_logits_T1)
        if outcome is not None:
            T0_indices = (treatment == 0).nonzero().squeeze()
            Y_T1_labels = outcome.clone().scatter(0,T0_indices, -100)

            T1_indices = (treatment == 1).nonzero().squeeze()
            Y_T0_labels = outcome.clone().scatter(0,T1_indices, -100)
            Q_loss_T1 = CrossEntropyLoss()(Q_logits_T1.view(-1,self.num_labels), Y_T1_labels)
            Q_loss_T0 = CrossEntropyLoss()(Q_logits_T0.view(-1, self.num_labels), Y_T0_labels)
            Q_loss = Q_loss_T1 + Q_loss_T0
        else:
            Q_loss = 0.0
            
        sm = torch.nn.Softmax(dim = 1)
        Q_prob_T0 = sm(Q_logits_T0)[:,1]
        Q_prob_T1 = sm(Q_logits_T1)[:,1]
        g_prob = sm(g)[:,1]
    
        
        return g_prob, Q_prob_T0, Q_prob_T1, g_loss, Q_loss
        

In [33]:

class CausalImageModelWrapper:
    def __init__(self, g_weight=1.0, Q_weight=0.1, mlm_weight=1.0, batch_size=32):
        self.model = ImageCausalModel(num_labels=2, pretrained_model_names="timm/eva02_tiny_patch14_224.mim_in22k")
        if CUDA:
            self.model = self.model.cuda()

        self.loss_weights = {
            'g': g_weight,
            'Q': Q_weight
        }
        self.batch_size = batch_size
        self.losses = []
        
    
    def train(self,images, confounds, treatments, outcomes , learning_rate = 2e-5, epochs  = 3):

        wandb.init(project = "image_causal_project", config = {
            "learning_rate":learning_rate,
            "epochs": epochs,
            "batch_size": self.batch_size,
            "g_weight": self.loss_weights['g'],
            "Q_weight": self.loss_weights['Q']
        })

        dataloader = self.build_dataloader(images, confounds, treatments, outcomes, batch_size = self.batch_size,sampler="seqiemtial")
        self.model.train()
        optimizer = AdamW(self.model.parameters(), lr = learning_rate, eps = 1e-8)
        total_steps = len(dataloader) * epochs
        warmup_steps = total_steps * 0.1
        print(learning_rate)
        print(warmup_steps)
        scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps = warmup_steps,num_training_steps = total_steps)
        print(total_steps)
        for epoch in range(epochs):
            epoch_losses = []
            g_losses = []
            Q_losses = []
            self.model.train()
            for batch in dataloader:
                if CUDA:
                    batch = tuple(x.cuda() for x in batch)
                images, confounds, treatments, outcomes = batch

                self.model.zero_grad()
                g, Q0, Q1, g_loss, Q_loss = self.model(images, confounds, treatments, outcomes)
                loss = self.loss_weights['g'] * g_loss + self.loss_weights['Q'] * (Q_loss)
                loss.backward()
                optimizer.step()
                scheduler.step()  
                g_losses.append(g_loss.detach().cpu().item())
                Q_losses.append((Q_loss).detach().cpu().item())
                epoch_losses.append(loss.detach().cpu().item())
                wandb.log({"g_loss": g_loss.detach().cpu().item(),
                           "Q_loss": (Q_loss).detach().cpu().item(),
                           "epoch_loss": loss.detach().cpu().item()
                })
                self.losses.append(loss.detach().cpu().item())
     
            avg_loss = sum(epoch_losses) / len(epoch_losses)
            avg_g_loss = sum(g_losses) / len(g_losses)
            avg_Q_loss = sum(Q_losses) / len(Q_losses)
            wandb.log({
                "epoch": epoch, 
                "loss":avg_loss, 
                "g_losses":avg_g_loss, 
                "Q_losses" : avg_Q_loss
                })
            print(len(self.losses))
        wandb.watch(self.model)
        wandb.finish()            
        return self.model
    
    def inference(self, images, confounds, outcome = None):
        self.model.eval()
        dataloader = self.build_dataloader(images, confounds,outcomes = outcome,
                                           sampler = 'sequential',batch_size= self.batch_size)
        Q0s = []
        Q1s = []
        Ys = []
        for i, batch in tqdm(enumerate(dataloader),total = len(dataloader)):
            if CUDA: 
                batch = (x.cuda() for x in batch)
            images, confounds, _ ,outcomes = batch

            g, Q0,Q1,_,_= self.model(images, confounds, outcome = None)
    
            Q0s += Q0.detach().cpu().numpy().tolist()
            Q1s += Q1.detach().cpu().numpy().tolist()
            Ys += outcomes.detach().cpu().numpy().tolist()

            ## [todo] inferenceメソッドの形式?
        probs = np.array(list(zip(Q0s, Q1s)))
        preds = np.argmax(probs, axis = 1)  
        return probs, preds, Ys
    
    def ATE(self,C,image, Y = None, platt_scaling = False):
        ## [todo] ATEの計算方法
        Q_probs,_,Ys = self.inference(image,C,outcome = Y)
        if platt_scaling and Y is not None:
            Q0 = platt_scale(Ys, Q_probs[:,0])[:,0]
            Q1 = platt_scale(Ys, Q_probs[:,1])[:,1]
        else:
            Q0 = Q_probs[:,0]
            Q1 = Q_probs[:,1]
        
        print("Q0:", Q0, "Q1:", Q1)
        return np.mean(Q0 - Q1)

    def build_dataloader(self,image_paths, confounds, treatments = None, outcomes = None,batch_size = 32,sampler = "random"):
        dataset = CausalImageDataset(image_paths, confounds, treatments, outcomes)
        sampler = RandomSampler(dataset) if sampler == "random" else SequentialSampler(dataset)
        dataloader = DataLoader(dataset, batch_size = batch_size,sampler = sampler,num_workers=0)
        return dataloader
    


In [34]:
class CausalImageDataset(Dataset):
    def __init__(self,image_paths, confounds, treatments = None, outcomes = None,transform = None):
        self.image_paths = image_paths
        self.confounds = confounds
        self.treatments = treatments
        self.outcomes = outcomes

        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((224,224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406] , std = [0.229, 0.224, 0.225])
                ]
            )
        else:
            self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self,idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        
        confounds = self.confounds[idx]
        treatment = self.treatments[idx] if self.treatments is not None else -1
        outcome = self.outcomes[idx] if self.outcomes is not None else -1
        return image , confounds, treatment, outcome

    

In [7]:
import sys 
sys.path.append("../../")
import pandas as pd
df = pd.read_csv("/root/graduation_thetis/causal-bert-pytorch/input/Appliances_preprocess_1116.csv")
df.head()

Unnamed: 0.1,Unnamed: 0,name,main_category,sub_category,image,link,ratings,no_of_ratings,discount_price,actual_price,img_path,actual_price_yen,sharpness,sharpness_ave,light_or_dark,outcome,y0,y1
0,72,AmazonBasics High Speed 55 Watt Oscillating Pe...,appliances,All Appliances,https://m.media-amazon.com/images/I/71QfUcEOg8...,https://www.amazon.in/AmazonBasics-400mm-Pedes...,4.1,6113,"₹2,099",3300.0,/root/graduation_thetis/causal-bert-pytorch/in...,363000.0,1997.908563,0,1,1,0.56755,0.533931
1,73,Farberware Mini Blender Fruit Mixer Machine Po...,appliances,All Appliances,https://m.media-amazon.com/images/I/716mmFt0PG...,https://www.amazon.in/Farberware-Portable-Elec...,2.9,6071,₹499,1199.0,/root/graduation_thetis/causal-bert-pytorch/in...,131890.0,4380.767889,0,1,1,0.56755,0.533931
2,74,PHILIPS Handheld Garment Steamer STH3000/20 - ...,appliances,All Appliances,https://m.media-amazon.com/images/I/71W2XPQdBq...,https://www.amazon.in/PHILIPS-Handheld-Garment...,4.0,1553,"₹3,995",4095.0,/root/graduation_thetis/causal-bert-pytorch/in...,450450.0,1789.675186,0,1,1,0.56755,0.533931
3,75,"Cookwell Bullet Mixer Grinder (5 Jars, 3 Blade...",appliances,All Appliances,https://m.media-amazon.com/images/I/81yobRRV8n...,https://www.amazon.in/Cookwell-Bullet-Mixer-Gr...,4.1,9592,"₹2,479",6000.0,/root/graduation_thetis/causal-bert-pytorch/in...,660000.0,5434.679447,1,1,0,0.109457,0.259581
4,76,"Bajaj ATX 4 750-Watt Pop-up Toaster, 2-Slice A...",appliances,All Appliances,https://m.media-amazon.com/images/I/51D5T7TGVb...,https://www.amazon.in/Bajaj-ATX-750-Watt-Pop-u...,4.3,9520,"₹1,499",2250.0,/root/graduation_thetis/causal-bert-pytorch/in...,247500.0,1697.735543,0,1,1,0.56755,0.533931


In [36]:
df["light_or_dark"] = df["light_or_dark"].apply(lambda x : 1 if x == "light" else 0)

In [37]:
set_seed(42)
ci = CausalImageModelWrapper(batch_size = 32, g_weight=0.1, Q_weight=0.1)
ci.train(df["img_path"],df["light_or_dark"], df["price_ave"], df["outcome"],epochs = 10)


192
194


2e-05
160.0
1600
160
320
480
640
800
960
1120
1280
1440
1600


0,1
Q_loss,▅▆▅▆██▄▃▄▄▃▂▄▂▂▂▃▂▂▂▂▁▁▁▁▁▄▂▂▁▂▁▂▂▁▂▂▂▂▂
Q_losses,█▄▂▂▁▁▁▁▁▁
epoch,▁▂▃▃▄▅▆▆▇█
epoch_loss,▅█▆▄▇▇▇▆▅▄▃▂▃▂▂▂▂▂▁▂▂▂▁▂▃▂▁▂▂▂▁▁▁▂▁▂▂▂▂▁
g_loss,█▆▆▄▄▃▂▂▁▃▂▁▂▁▂▂▂▂▂▂▂▂▁▂▂▂▁▁▂▂▂▃▂▁▂▁▂▂▂▂
g_losses,█▅▂▁▁▁▁▁▁▁
loss,█▄▂▂▁▁▁▁▁▁

0,1
Q_loss,1.32982
Q_losses,1.22824
epoch,9.0
epoch_loss,0.18803
g_loss,0.55049
g_losses,0.61236
loss,0.18406


ImageCausalModel(
  (base_model): Eva(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(14, 14), stride=(14, 14))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (rope): RotaryEmbeddingCat()
    (blocks): ModuleList(
      (0-11): 12 x EvaBlock(
        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (attn): EvaAttention(
          (qkv): Linear(in_features=192, out_features=576, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (proj): Linear(in_features=192, out_features=192, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path1): Identity()
        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): GluMlp(
          (fc1): Linear(in_features=192, out_features=1024, bias=True)
          (act): SiLU()
          (drop1): Dropout(p=0.0, inplace=False)
          (norm): Identity()
         

In [None]:
print(ci.model)

In [None]:
probs,preds,Ys = ci.inference(df["img_path"], df["light_or_dark"],df["outcome"])

In [None]:
probs.shape

In [None]:
probs

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

accuracy = accuracy_score(Ys, preds)
precision = precision_score(Ys, preds)
recall = recall_score(Ys, preds)
f1 = f1_score(Ys, preds)

print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")

In [39]:
print(ci.ATE(df["light_or_dark"], df["img_path"],Y = None,platt_scaling = True))

100%|██████████| 160/160 [00:27<00:00,  5.86it/s]

Q0: [0.76335073 0.76769769 0.72861868 ... 0.78121001 0.72307813 0.8110612 ] Q1: [0.6971913  0.38844791 0.60665202 ... 0.46023688 0.54316306 0.57524842]
0.14347405117161818





In [8]:
from collections import defaultdict
import numpy as np
def ATE_unadjusted(T, Y):
    x = defaultdict(list)
    for t, y in zip(T, Y):
        x[t].append(y)
    T0 = np.mean(x[0])
    T1 = np.mean(x[1])
    return T0 - T1

def ATE_adjusted(C, T, Y):
    x = defaultdict(list)
    for c, t, y in zip(C, T, Y):
        x[c, t].append(y)

    C0_ATE = np.mean(x[0,0]) - np.mean(x[0,1])
    C1_ATE = np.mean(x[1,0]) - np.mean(x[1,1])
    print(C0_ATE, C1_ATE)
    print(x)
    return np.mean([C0_ATE, C1_ATE])


print("ATE_unadjusted: ", ATE_unadjusted(df["light_or_dark"], df["outcome"]))
print("ATE_adjusted: ", ATE_adjusted(df["sharpness_ave"], df["light_or_dark"],df["outcome"]))

ATE_unadjusted:  -0.07329184114618686
0.06866410085521002 -0.1281966930336857
defaultdict(<class 'list'>, {(0, 1): [1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1,