In [1]:
###### data loader####
import os
import pandas as pd
from torch.utils.data import Dataset
import torchvision.transforms as tfms
from PIL import Image
import random
from tqdm import trange
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from torch import nn, optim
import torch
import numpy as np
from tqdm import tqdm
import open_clip
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8
from sklearn.model_selection import train_test_split
import os.path as osp

torch.set_num_threads(5)   # Sets the number of threads used for intra-operations
torch.set_num_interop_threads(5)   # Sets the number of threads used for inter-operations

import open_clip

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
logabs = lambda x: torch.log(torch.abs(x))
batch_size = 256


def seed_everything(seed):
    """
    Changes the seed for reproducibility. 
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    


model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
model = model.to(device)
model = model.eval()
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')


seed_everything(1024)



def get_transform():
    transform = tfms.Compose([
        tfms.Resize((224,224)),
        tfms.ToTensor()
    ])
    return transform

class ConfounderDataset(Dataset):
    def __init__(self, root_dir,
                 target_name, confounder_names,
                 model_type=None, augment_data=None):
        
        raise NotImplementedError

    def __len__(self):
        if self.split == 'train':
            return len(self.training_sample)
        if self.split == 'val':
            return len(self.valid_sample)
        if self.split == 'test':
            return len(self.test_sample)

    def __getitem__(self, idx):
        if self.split == 'train': 
            y = self.training_sample_y_array[idx]
            y = torch.tensor(y)
            a = self.training_sample_confounder_array[idx]
            a = torch.tensor(a)
            img_filename = os.path.join(
                self.data_dir,
                self.training_sample[idx]) 
            img = Image.open(img_filename).convert('RGB')

            x = preprocess(img)
            img_for_res = self.transform(img)
            
            
        if self.split == 'val': 
            y = self.valid_sample_y_array[idx]
            y = torch.tensor(y)
            a = self.valid_sample_confounder_array[idx]
            a = torch.tensor(a)
            img_filename = os.path.join(
                self.data_dir,
                self.valid_sample[idx])       
            img = Image.open(img_filename).convert('RGB')

            x = preprocess(img)
            img_for_res = self.transform(img)
            
        if self.split == 'test': 
            y = self.test_sample_y_array[idx]
            a = self.test_sample_confounder_array[idx]
            y = torch.tensor(y)
            a = torch.tensor(a)
            img_filename = os.path.join(
                self.data_dir,
                self.test_sample[idx])       
            img = Image.open(img_filename).convert('RGB')

            x = preprocess(img)
            img_for_res = self.transform(img)
        return x,y,a, img_for_res



    
class ISICDataset(ConfounderDataset):
    def __init__(self, 
                 root_dir,
                 seed,
                 split,
                 target_name = ['label'], 
                 confounder_names=['patches'],
                 model_type=None,
                 augment_data=False,
                 mix_up=False,
                 group_id=None,
                 id_val=True):
        self.split = split
        self.augment_data = augment_data
        self.group_id = group_id
        self.mix_up = mix_up
        self.model_type = model_type
        self.target_name = target_name
        self.confounder_names = confounder_names
        self.split_dir = osp.join(root_dir, 'trap-sets')
        self.data_dir = osp.join(root_dir, 'ISIC2018_Task1-2_Training_Input')
        
        metadata = {}
        metadata['train'] = pd.read_csv(osp.join(self.split_dir, f'isic_annotated_train{seed}.csv'))
        if id_val:
            test_val_data = pd.read_csv(osp.join(self.split_dir, f'isic_annotated_test{seed}.csv'))
            idx_val, idx_test = train_test_split(np.arange(len(test_val_data)), 
                                                test_size=0.8, random_state=0)
            metadata['test'] = test_val_data.iloc[idx_test]
            metadata['val'] = test_val_data.iloc[idx_val]
        else:
            metadata['test'] = pd.read_csv(osp.join(self.split_dir, f'isic_annotated_test{seed}.csv'))
            metadata['val'] = pd.read_csv(osp.join(self.split_dir, f'isic_annotated_val{seed}.csv'))
            metadata_new = metadata['train'].merge(metadata['val'], how='left', indicator=True)
            metadata_new = metadata_new[metadata_new['_merge'] == 'left_only']
            metadata['train'] = metadata_new.drop(columns=['_merge'])
        
        
        self.precomputed = False
        self.pretransformed = False
        self.n_classes = 2
        self.n_confounders = 1
        confounder = confounder_names[0]
        
        self.training_sample = metadata['train']['image'].values
        self.training_sample_y_array = metadata['train'][target_name].values
        self.training_sample_confounder_array = metadata['train'][confounder].values
        
        self.valid_sample = metadata['val']['image'].values
        self.valid_sample_y_array = metadata['val'][target_name].values
        self.valid_sample_confounder_array = metadata['val'][confounder].values
        
        self.test_sample = metadata['test']['image'].values
        self.test_sample_y_array = metadata['test'][target_name].values
        self.test_sample_confounder_array = metadata['test'][confounder].values
        self.transform = get_transform()
        

    
data_dir = r"../isic"
seed = 1

training_isic_dataset  = ISICDataset(data_dir, seed, 'train')
valid_isic_dataset  = ISICDataset(data_dir, seed, 'val')
test_isic_dataset  = ISICDataset(data_dir, seed, 'test')


training_data_loader  = torch.utils.data.DataLoader(dataset = training_isic_dataset,
                                                batch_size= batch_size,
                                                shuffle=True,
                                                num_workers=0)

valid_data_loader  = torch.utils.data.DataLoader(dataset = valid_isic_dataset,
                                                batch_size= batch_size,
                                                shuffle=False,
                                                num_workers=0)

test_data_loader  = torch.utils.data.DataLoader(dataset = test_isic_dataset,
                                                batch_size= batch_size,
                                                shuffle=False,
                                                num_workers=0)
print('Done')


  warn(


Done


In [2]:
spurious_text = ["There exists a color patch",  "There exists no color patch"] 
texts = tokenizer(spurious_text).to(device)
null_image = torch.rand((1,3,224,224)).to(device)
model = model.to(device)
_, spurious_embedding, _ = model(null_image, texts)

no_patch = spurious_embedding[1].unsqueeze(0).to(device)
patch = spurious_embedding[0].unsqueeze(0).to(device)

In [3]:
def inference_a_test(vlm, spu_v0, spu_v1):
    correct_00, total_00 = 0, 0
    correct_01, total_01 = 0, 0
    correct_10, total_10 = 0, 0
    correct_11, total_11 = 0, 0
    
    for step, (test_input, test_target, sensitive, _) in enumerate(tqdm(test_data_loader, desc="Testing")):
        with torch.no_grad():
            test_target = test_target.to(device)
            sensitive = sensitive.to(device)
            test_target = test_target.squeeze()
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            infered_a = inference_a(vlm, no_patch, patch,z )
            
            mask_00 = ((test_target == 0) & (sensitive == 0))
            mask_01 = ((test_target == 0) & (sensitive == 1))
            mask_10 = ((test_target == 1) & (sensitive == 0))
            mask_11 = ((test_target == 1) & (sensitive == 1))

            correct_00 += (infered_a[mask_00] == sensitive[mask_00]).float().sum().item()
            total_00 += mask_00.float().sum().item()

            correct_01 += (infered_a[mask_01] == sensitive[mask_01]).float().sum().item()
            total_01 += mask_01.float().sum().item()

            correct_10 += (infered_a[mask_10] == sensitive[mask_10]).float().sum().item()
            total_10 += mask_10.float().sum().item()

            correct_11 += (infered_a[mask_11] == sensitive[mask_11]).float().sum().item()
            total_11 += mask_11.float().sum().item() 
    acc_00 = correct_00 / total_00
    acc_01 = correct_01 / total_01
    acc_10 = correct_10 / total_10
    acc_11 = correct_11 / (total_11+1e-9)

    print(f'Accuracy for y=0, s=0: {acc_00}')
    print(f'Accuracy for y=0, s=1: {acc_01}')
    print(f'Accuracy for y=1, s=0: {acc_10}')
    print(f'Accuracy for y=1, s=1: {acc_11}')   

            
def inference_a(vlm, spu_v0, spu_v1, z):
    text_embeddings = torch.cat((spu_v0, spu_v1), dim=0)
    norm_img_embeddings = z 
    norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
    cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())
    logits_per_image = cosine_similarity 
    probs = logits_per_image.softmax(dim=1)
    _, predic = torch.max(probs.data, 1)
    return predic

            
def supervised_inference_a(img):
    resnet18 = models.resnet18(pretrained=False)
    num_classes = 2 
    resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)
    res_model = resnet18
    res_model.load_state_dict(torch.load('res_net.pth'))
    res_model = res_model.to(device)
    res_model.eval()
    img = img.to(device)
    test_pred_ = res_model(img)
    _, predic = torch.max(test_pred_.data, 1)
    return predic            
            
    
def compute_scale(vlm, spu_v0, spu_v1):
    vlm = vlm.to(device)
    scale_0 = []
    scale_1 = []
    spu0 = spu_v0  / spu_v0.norm(dim=1, keepdim=True)
    spu1 = spu_v1 / spu_v1.norm(dim=1, keepdim=True)

    
    for step, (test_input, _, sensitive, img) in enumerate(tqdm(training_data_loader, desc="Computing Scale")):
        with torch.no_grad():
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            if a ==True:
                sensitive = sensitive
            else:
                if partial_a == False:
                    sensitive = inference_a(vlm, no_patch, patch,z )
                elif partial_a == True:
                    sensitive = supervised_inference_a(img)
            
            
            mask_0 = sensitive == 0
            mask_0 = mask_0.to(device)
            h = z[mask_0]
            inner_no_patch = torch.mm(h/ h.norm(dim=1, keepdim=True), spu0.t())
            scale_0.extend(inner_no_patch.detach().cpu().numpy())
                
            mask_1 = sensitive == 1
            mask_1 = mask_1.to(device)
            g = z[mask_1]
            inner_patch = torch.mm(g/ g.norm(dim=1, keepdim=True), spu1.t())
            scale_1.extend(inner_patch.detach().cpu().numpy())
    scale_0 = np.array(scale_0)
    scale_1 = np.array(scale_1)
    print(np.mean(scale_0))
    print(np.mean(scale_1))
    return torch.tensor(np.mean(scale_0)), torch.tensor(np.mean(scale_1))



def test_epoch(vlm,   dataloader):
    scale_0, scale_1 = compute_scale(model, no_patch, patch)

    texts_label = ["This is a benign lesion",  "This is a malignant lesion"] 
    text_label_tokened = tokenizer(texts_label).to(device)
    
    vlm = vlm.to(device)
    vlm.eval()   
    test_pred = []
    test_gt = []
    sense_gt = []
    female_predic = []
    female_gt = []
    male_predic = []
    male_gt = []
    correct_00, total_00 = 0, 0
    correct_01, total_01 = 0, 0
    correct_10, total_10 = 0, 0
    correct_11, total_11 = 0, 0
    cos = nn.CosineSimilarity(dim = 0)
    feature_a0 = []
    feature_a1 = []

    for step, (test_input, test_target, sensitive_real,img) in enumerate(tqdm(dataloader, desc="Zero Shot Testing")):
        test_target = test_target.squeeze()
        with torch.no_grad():
            gt = test_target.detach().cpu().numpy()
            sen = sensitive_real.detach().cpu().numpy()
            test_gt.extend(gt)
            sense_gt.extend(sen)
            test_input = test_input.to(device)

            text_label_tokened
            z = vlm.encode_image(test_input)
            z = z/ z.norm(dim=1, keepdim=True)
            
            if a == True:
                sensitive = sensitive_real
            if a == False:
                if partial_a == False:
                    sensitive = inference_a(vlm, no_patch, patch,z )
                    sensitive = torch.tensor(sensitive)
                elif partial_a == True:
                    sensitive = supervised_inference_a(img)
            
            mask_0 = sensitive == 0
            mask_0 = mask_0.to(device)
            z[mask_0] -= scale_0 * no_patch/ no_patch.norm(dim=1, keepdim=True)
                
            mask_1 = sensitive == 1
            mask_1 = mask_1.to(device)
            z[mask_1] -= scale_1 * patch/ patch.norm(dim=1, keepdim=True)
            
        
            
            
            feature_a0.extend(z[mask_0].detach().cpu().numpy())
            feature_a1.extend(z[mask_1].detach().cpu().numpy())
            
            text_embeddings = vlm.encode_text(text_label_tokened)
            img_embeddings = z
            norm_img_embeddings = img_embeddings / img_embeddings.norm(dim=1, keepdim=True)
            norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
            cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())
                    
            logits_per_image = cosine_similarity 
            probs = logits_per_image.softmax(dim=1)
            _, predic = torch.max(probs.data, 1)
            predic = predic.detach().cpu()
            test_pred.extend(predic.numpy())
            label = test_target.squeeze().detach().cpu()
            mask_00 = ((label == 0) & (sensitive_real == 0))
            mask_01 = ((label == 0) & (sensitive_real == 1))
            mask_10 = ((label == 1) & (sensitive_real == 0))
            mask_11 = ((label == 1) & (sensitive_real == 1))


            correct_00 += (predic[mask_00] == label[mask_00]).float().sum().item()
            total_00 += mask_00.float().sum().item()

            correct_01 += (predic[mask_01] == label[mask_01]).float().sum().item()
            total_01 += mask_01.float().sum().item()

            correct_10 += (predic[mask_10] == label[mask_10]).float().sum().item()
            total_10 += mask_10.float().sum().item()

            correct_11 += (predic[mask_11] == label[mask_11]).float().sum().item()
            total_11 += mask_11.float().sum().item() 
    acc_00 = correct_00 / total_00
    acc_01 = correct_01 / total_01
    acc_10 = correct_10 / total_10

    print(f'Accuracy for y=0, s=0: {acc_00}')
    print(f'Accuracy for y=0, s=1: {acc_01}')
    print(f'Accuracy for y=1, s=0: {acc_10}')
    acc = accuracy_score(test_gt, test_pred)
    print('acc', accuracy_score(test_gt, test_pred))

a = True
partial_a = False
    
model = model.to(device)
test_epoch(model, test_data_loader)

Computing Scale: 100%|███████████████████████████████████████████████████████████████████████████| 8/8 [06:24<00:00, 48.06s/it]


0.30454728
0.29942143


Zero Shot Testing: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [02:07<00:00, 42.62s/it]

Accuracy for y=0, s=0: 0.6910755148741419
Accuracy for y=0, s=1: 0.8545454545454545
Accuracy for y=1, s=0: 0.6587301587301587
acc 0.6990291262135923



