In [1]:
import os
import os.path as osp
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader, Dataset
from wilds.datasets.fmow_dataset import FMoWDataset
from PIL import Image
from tqdm import trange
from sklearn.metrics import confusion_matrix
import open_clip

# Cap the CPU thread pools PyTorch uses on this machine.
torch.set_num_threads(5)   # Sets the number of threads used for intra-operations
torch.set_num_interop_threads(5)   # Sets the number of threads used for inter-operations

# Prefer the CUDA device with index 1; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# OpenCLIP ViT-L-14 with the 'laion2b_s32b_b82k' pretrained tag. The second returned
# value is discarded; the third is kept as `preprocess` and is applied to every image
# in ConfounderDataset.__getitem__.
# NOTE(review): presumably the third value is the evaluation-time image transform —
# confirm against the open_clip documentation.
model,_, preprocess =  open_clip.create_model_and_transforms("ViT-L-14", pretrained='laion2b_s32b_b82k') 
model = model.to(device)
tokenizer = open_clip.get_tokenizer('ViT-L-14')

# Batch size shared by the training and test data loaders built below.
batch_size = 256

# Dataset location, relative to the notebook's working directory.
root_dir = r"../../../Dataset/data"

def get_transform():
    """Return a torchvision pipeline that resizes to 224x224 and converts to a tensor."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
    
train_transform = get_transform()
dataset = FMoWDataset(root_dir=root_dir, download=False)
train_data, val_data, test_data = (
    dataset.get_subset(split) for split in ('train', 'val', 'test')
)

# The 'region' metadata field is the group attribute used throughout the notebook.
metadata_str = dataset.metadata_map['region']
meadata_idx = dataset.metadata_fields.index('region')
# Keep only the region column and store it back on the dataset as well, so the
# dataset's metadata array is now that single column.
metadata = dataset._metadata_array[:, meadata_idx]
dataset._metadata_array = metadata
n_groups = len(metadata_str)

split_dict = dataset.split_dict
split_array = dataset.split_array
y_array = dataset.y_array

# All three subsets share one id -> region-name lookup.
train_data.group_str = val_data.group_str = test_data.group_str = lambda id: metadata_str[id]

# Attach the group count and the per-split slice of region ids to each subset.
for _split_name, _subset in (('train', train_data), ('val', val_data), ('test', test_data)):
    setattr(_subset, 'n_groups', n_groups)
    _subset.get_group_array = metadata[split_array == split_dict[_split_name]]

class ConfounderDataset(Dataset):
    """Base class for datasets that yield ``(image, label, group)`` triples.

    Subclasses must define their own ``__init__`` that sets ``self.dataset`` to
    one of ``"train"``, ``"val"`` or ``"test"`` and provides the matching group
    array (``self.train_a``, ``self.val_a`` or ``self.test_a``).

    Samples are read from the module-level ``train_data`` / ``val_data`` /
    ``test_data`` subsets, and every image is passed through the module-level
    ``preprocess`` transform.
    """

    def __init__(self):
        # Not meant to be instantiated directly.
        raise NotImplementedError

    def _subset(self):
        """Return the module-level subset selected by ``self.dataset``.

        Raises:
            ValueError: if ``self.dataset`` is not a known split name. Failing
                here avoids silently returning ``None`` from ``__len__`` or
                ``__getitem__``.
        """
        if self.dataset == "train":
            return train_data
        if self.dataset == "val":
            return val_data
        if self.dataset == "test":
            return test_data
        raise ValueError(
            f"Unknown split {self.dataset!r}; expected 'train', 'val' or 'test'."
        )

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self._subset())

    def __getitem__(self, idx):
        """Return ``(x, y, a)``: preprocessed image, label and group id at ``idx``."""
        # The subset yields (image, label, metadata); the metadata item is ignored
        # and the group id is taken from the per-split group array instead.
        img, y, _ = self._subset()[idx]
        a = getattr(self, f"{self.dataset}_a")[idx]
        x = preprocess(img)
        return x, y, a
            


class FMOW(ConfounderDataset):
    """FMoW view over one split, using the region id as the group attribute.

    The constructor does not call the base ``__init__`` (which raises); it only
    records the split name and caches the per-split group and label arrays.
    """

    def __init__(self, dataset):
        # `dataset` is the split name compared against "train"/"val"/"test" by the
        # base class.
        self.dataset = dataset
        splits = (('train', train_data), ('val', val_data), ('test', test_data))
        for split_name, subset in splits:
            # Refresh the subset's group array from the module-level region column,
            # then cache it together with the labels on this instance as
            # <split>_a / <split>_y.
            subset.get_group_array = metadata[split_array == split_dict[split_name]]
            setattr(self, f'{split_name}_a', subset.get_group_array)
            setattr(self, f'{split_name}_y', subset.y_array)
        



training_dataset = FMOW('train')
test_dataset = FMOW('test')

# Settings common to both loaders; neither loader shuffles.
_loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=18)

# The training loader drops the final incomplete batch; the test loader keeps it.
training_data_loader = DataLoader(training_dataset, drop_last=True, **_loader_kwargs)
test_data_loader = DataLoader(test_dataset, drop_last=False, **_loader_kwargs)
print('Done')

  self.test_ood_mask = np.asarray(pd.to_datetime(self.metadata['timestamp'], errors='coerce', utc=True, infer_datetime_format=True) >= year_dt)
  self.val_ood_mask = np.asarray(pd.to_datetime(self.metadata['timestamp'], errors='coerce', utc=True, infer_datetime_format=True) >= year_minus_3_dt) & ~self.test_ood_mask
  ts = pd.to_datetime(self.metadata['timestamp'], errors='coerce', utc=True, infer_datetime_format=True)


Done


In [2]:
from tqdm import tqdm
from torch import nn
# One text prompt per region. The position of a prompt in this list is the group id
# it is matched against in compute_scale and test_epoch.
# NOTE(review): confirm that this order matches the id order of
# dataset.metadata_map['region'] (available above as `metadata_str`); when the true
# group labels are used, a mismatch would apply the wrong region direction.
texts = ["Over Europe.", "Over Asia", "Over America.", "Over Africa", "Over Oceania"]
text = tokenizer(texts).to(device)
# The embeddings are only used as constants, so encode them without autograd; this
# avoids keeping the text tower's graph alive behind these module-level tensors.
with torch.no_grad():
    text_features = model.encode_text(text)
# Each region direction is kept as a (1, D) row so it can be concatenated and
# broadcast later.
Eurobg, Asiabg, Americabg, Africabg, Oceaniabg = (
    text_features[i].unsqueeze(0) for i in range(5)
)


# Disabled code: the function below sits inside a bare string literal, so it is never
# defined or executed. It unpacks four items per batch (the loaders above yield three)
# and refers to `landbg` / `waterbg`, which are not defined in the cells shown here.
"""
def inference_a_test(vlm, spu_v0, spu_v1):
    correct_00, total_00 = 0, 0
    correct_01, total_01 = 0, 0
    correct_10, total_10 = 0, 0
    correct_11, total_11 = 0, 0
    
    for step, (test_input, test_target, sensitive, _) in enumerate(tqdm(test_data_loader, desc="Testing")):
        with torch.no_grad():
            test_target = test_target.to(device)
            sensitive = sensitive.to(device)
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            infered_a = inference_a(vlm, landbg, waterbg,z )
            
            mask_00 = ((test_target == 0) & (sensitive == 0))
            mask_01 = ((test_target == 0) & (sensitive == 1))
            mask_10 = ((test_target == 1) & (sensitive == 0))
            mask_11 = ((test_target == 1) & (sensitive == 1))


            correct_00 += (infered_a[mask_00] == sensitive[mask_00]).float().sum().item()
            total_00 += mask_00.float().sum().item()

            correct_01 += (infered_a[mask_01] == sensitive[mask_01]).float().sum().item()
            total_01 += mask_01.float().sum().item()

            correct_10 += (infered_a[mask_10] == sensitive[mask_10]).float().sum().item()
            total_10 += mask_10.float().sum().item()

            correct_11 += (infered_a[mask_11] == sensitive[mask_11]).float().sum().item()
            total_11 += mask_11.float().sum().item() 
    acc_00 = correct_00 / total_00
    acc_01 = correct_01 / total_01
    acc_10 = correct_10 / total_10
    acc_11 = correct_11 / total_11

    print(f'Accuracy for y=0, s=0: {acc_00}')
    print(f'Accuracy for y=0, s=1: {acc_01}')
    print(f'Accuracy for y=1, s=0: {acc_10}')
    print(f'Accuracy for y=1, s=1: {acc_11}')   

            
 """


def inference_a(vlm, spu_v0, spu_v1,spu_v2, spu_v3,spu_v4, z):
    """Assign each embedding in ``z`` to the best-matching group prompt.

    Args:
        vlm: unused; kept so existing call sites keep working.
        spu_v0 .. spu_v4: prompt embeddings, one row each; they are stacked in
            this order, so the returned index ``k`` refers to ``spu_vk``.
        z: image embeddings, one row per image. Used as given (not normalised
            here).

    Returns:
        A tensor of indices, one per row of ``z``: the prompt with the highest
        softmax score of ``z @ unit_prompts.T``.
    """
    prompts = torch.cat([spu_v0, spu_v1, spu_v2, spu_v3, spu_v4], dim=0)
    unit_prompts = prompts / prompts.norm(dim=1, keepdim=True)
    scores = torch.mm(z, unit_prompts.t())
    best = torch.max(scores.softmax(dim=1).data, 1)
    return best[1]

            
def supervised_inference_a(img):
    """Predict the group attribute of an image batch with a supervised ResNet-18.

    The classifier (ResNet-18 with a 2-way linear head, weights read from
    'res_net.pth') is built and loaded on the first call only and then reused;
    previously it was rebuilt and re-read from disk on every call.

    Args:
        img: batch of image tensors; moved to the module-level ``device``.

    Returns:
        A tensor with the arg-max class index for every image in the batch.

    NOTE(review): the head has 2 outputs while the rest of the notebook works
    with five region prompts — confirm that 'res_net.pth' is the intended
    checkpoint for this dataset.
    """
    # `models` was never imported in this notebook; import it locally so the
    # function no longer fails with a NameError.
    from torchvision import models

    res_model = getattr(supervised_inference_a, "_res_model", None)
    if res_model is None:
        # No pretrained weights are requested: all weights come from the checkpoint.
        res_model = models.resnet18()
        num_classes = 2
        res_model.fc = nn.Linear(res_model.fc.in_features, num_classes)
        # map_location lets the checkpoint load on whichever device is in use.
        res_model.load_state_dict(torch.load('res_net.pth', map_location=device))
        res_model = res_model.to(device)
        res_model.eval()
        supervised_inference_a._res_model = res_model

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        test_pred_ = res_model(img.to(device))
    _, predic = torch.max(test_pred_.data, 1)
    return predic
   
def process_group(mask, spu, z):
    """Project the selected, L2-normalised rows of ``z`` onto ``spu``.

    Args:
        mask: boolean tensor selecting rows of ``z``; it may live on a different
            device than ``z``.
        spu: direction(s) to project onto, one per row; used as given.
        z: embeddings, one row per sample.

    Returns:
        A NumPy array with one row per selected sample (zero rows when the mask
        selects nothing) holding ``normalised_row @ spu.T``.
    """
    # Follow the device of `z` instead of the module-level `device`, so the helper
    # works wherever the embeddings actually live.
    selected = z[mask.to(z.device)]
    unit_rows = selected / selected.norm(dim=1, keepdim=True)
    return torch.mm(unit_rows, spu.t()).detach().cpu().numpy()
    
def compute_scale(vlm, spu_v0, spu_v1, spu_v2, spu_v3, spu_v4):
    """Estimate, for each group, the mean projection of image embeddings onto its prompt.

    Iterates over the module-level ``training_data_loader``, encodes every batch
    with ``vlm`` and, for each group ``k``, collects the projection of the
    L2-normalised embeddings of that group's samples onto the normalised
    ``spu_vk``.

    Which group a sample belongs to is controlled by the module-level flags:
    ``a`` true uses the group labels from the loader; otherwise ``partial_a``
    false uses ``inference_a`` with the prompts passed to this function, and
    ``partial_a`` true uses ``supervised_inference_a`` on the image batch.

    Args:
        vlm: model providing ``encode_image``; moved to the module-level ``device``.
        spu_v0 .. spu_v4: prompt embeddings, one row each, for groups 0..4.

    Returns:
        A 5-tuple of scalar tensors, the mean projection for groups 0..4. A group
        with no samples yields 0.0 (i.e. no correction) instead of NaN. The five
        means are also printed, one per line.
    """
    vlm = vlm.to(device)
    prompts = (spu_v0, spu_v1, spu_v2, spu_v3, spu_v4)
    projections = [[] for _ in prompts]

    with torch.no_grad():
        unit_prompts = [prompt / prompt.norm(dim=1, keepdim=True) for prompt in prompts]

    for test_input, _, sensitive in tqdm(training_data_loader, desc="Computing Scale", dynamic_ncols=False, ascii=True):
        with torch.no_grad():
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            if not a:
                if partial_a:
                    # Previously passed the undefined name `img`; the classifier
                    # needs the image batch itself.
                    sensitive = supervised_inference_a(test_input)
                else:
                    # Use the prompts given to this function rather than the
                    # module-level region embeddings.
                    sensitive = inference_a(vlm, *prompts, z)

            for group_id, unit_prompt in enumerate(unit_prompts):
                projections[group_id].extend(
                    process_group(sensitive == group_id, unit_prompt, z)
                )

    scales = []
    for values in projections:
        # An empty group would make np.mean return NaN (with a warning) and later
        # poison every embedding it is applied to.
        scale = np.mean(np.array(values)) if values else np.float32(0.0)
        print(scale)
        scales.append(torch.tensor(scale))

    return tuple(scales)



def test_epoch(vlm,   dataloader):
    """Zero-shot FMoW classification after removing a per-group region direction.

    Steps:
      1. ``compute_scale`` estimates one scale per region prompt.
      2. The 62 class names are turned into "A satellite image of <name>."
         prompts (printed) and encoded with ``vlm``.
      3. For every batch of ``dataloader`` the image embeddings are normalised,
         ``scale_k * prompt_k / ||prompt_k||`` is subtracted from the samples of
         group ``k``, and the class with the highest cosine similarity is
         predicted.
      4. Accuracy is printed per true group (ids 0..4) and in total.

    The group used in step 3 follows the module-level flags: ``a`` true uses the
    labels from the loader; otherwise ``partial_a`` false uses ``inference_a`` and
    ``partial_a`` true uses ``supervised_inference_a``. Accuracy is always
    bucketed by the true group from the loader.

    Args:
        vlm: model providing ``encode_image`` / ``encode_text``.
        dataloader: yields ``(images, labels, groups)`` batches.

    Returns:
        None; results are printed.

    NOTE(review): only group ids 0..4 are counted, so samples whose group id is
    outside that range are excluded from the total as well — confirm against
    ``n_groups``.
    """
    region_prompts = (Eurobg, Asiabg, Americabg, Africabg, Oceaniabg)
    # Use the model that was passed in rather than the module-level `model`.
    scales = compute_scale(vlm, *region_prompts)

    texts = ["airport", "airport_hangar", "airport_terminal", "amusement_park", "aquaculture", "archaeological_site", "barn", "border_checkpoint", 
         "burial_site", "car_dealership", "construction_site", "crop_field", "dam", "debris_or_rubble", "educational_institution", 
         "electric_substation", "factory_or_powerplant", "fire_station", "flooded_road", "fountain", "gas_station", "golf_course", 
         "ground_transportation_station", "helipad", "hospital", "impoverished_settlement", "interchange", "lake_or_pond", "lighthouse", 
         "military_facility", "multi-unit_residential", "nuclear_powerplant", "office_building", "oil_or_gas_facility", "park", 
         "parking_lot_or_garage", "place_of_worship", "police_station", "port", "prison", "race_track", "railway_bridge", "recreational_facility", 
         "road_bridge", "runway", "shipyard", "shopping_mall", "single-unit_residential", "smokestack", "solar_farm", "space_facility", "stadium", 
         "storage_tank", "surface_mine", "swimming_pool", "toll_booth", "tower", "tunnel_opening", "waste_disposal", "water_treatment_facility", 
         "wind_farm", "zoo"
        ]
    expanded_texts =  ["A satellite image of " + text.replace('_', ' ') + "." for text in texts]
    print(expanded_texts)

    vlm = vlm.to(device)
    vlm.eval()
    # Encode the class prompts in eval mode and without autograd; they are constants.
    with torch.no_grad():
        text_label_tokened = tokenizer(expanded_texts).to(device)
        text_embeddings = vlm.encode_text(text_label_tokened)
        norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
        # Loop-invariant per-group correction: scale_k * prompt_k / ||prompt_k||.
        offsets = [
            scale * prompt / prompt.norm(dim=1, keepdim=True)
            for scale, prompt in zip(scales, region_prompts)
        ]

    num_groups = 5
    correct = [0] * num_groups
    total = [0] * num_groups
    total_correct = 0
    total_count = 0

    for test_input, test_target, sensitive_real in tqdm(dataloader, desc="Zero Shot Testing", dynamic_ncols=False, ascii=True):
        with torch.no_grad():
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            z = z/ z.norm(dim=1, keepdim=True)

            if a:
                sensitive = sensitive_real
            elif partial_a:
                # Previously passed the undefined name `img`.
                sensitive = supervised_inference_a(test_input)
            else:
                sensitive = inference_a(vlm, *region_prompts, z)

            # Remove each group's region direction from that group's embeddings.
            for group_id, offset in enumerate(offsets):
                group_mask = (sensitive == group_id).to(device)
                z[group_mask] -= offset

            norm_img_embeddings = z / z.norm(dim=1, keepdim=True)
            cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())
            probs = cosine_similarity.softmax(dim=1)
            _, predic = torch.max(probs.data, 1)
            predic = predic.detach().cpu()

            # reshape(-1) keeps a 1-D label vector even for a batch of one sample,
            # where squeeze() would produce a 0-d tensor that cannot be masked.
            label = test_target.reshape(-1).detach().cpu()

            for i in range(num_groups):
                mask = (sensitive_real == i)
                correct_predictions = (predic[mask] == label[mask]).float().sum().item()
                count = mask.float().sum().item()

                correct[i] += correct_predictions
                total[i] += count
                total_correct += correct_predictions
                total_count += count

    accuracies = [correct[i] / total[i] if total[i] != 0 else 0 for i in range(num_groups)]
    total_accuracy = total_correct / total_count if total_count != 0 else 0
    for i in range(num_groups):
        print(f'Accuracy for s={i}: {accuracies[i]:.4f}')
    print(f'Total accuracy: {total_accuracy:.4f}')

# Switches read as module-level globals by compute_scale and test_epoch:
#   a         - True: use the group labels coming from the data loader.
#               False: infer the group for every batch instead.
#   partial_a - only consulted when `a` is False: False selects inference_a
#               (text-prompt matching), True selects supervised_inference_a.
a = True
partial_a = False
    

# Run the zero-shot evaluation on the test loader; test_epoch prints its results.
model = model.to(device)
test_epoch(model, test_data_loader)




Computing Scale: 100%|################################################################| 305/305 [32:16<00:00,  6.35s/it]


0.14512545
0.1435668
0.16128178
0.13966314
0.1632383
['A satellite image of airport.', 'A satellite image of airport hangar.', 'A satellite image of airport terminal.', 'A satellite image of amusement park.', 'A satellite image of aquaculture.', 'A satellite image of archaeological site.', 'A satellite image of barn.', 'A satellite image of border checkpoint.', 'A satellite image of burial site.', 'A satellite image of car dealership.', 'A satellite image of construction site.', 'A satellite image of crop field.', 'A satellite image of dam.', 'A satellite image of debris or rubble.', 'A satellite image of educational institution.', 'A satellite image of electric substation.', 'A satellite image of factory or powerplant.', 'A satellite image of fire station.', 'A satellite image of flooded road.', 'A satellite image of fountain.', 'A satellite image of gas station.', 'A satellite image of golf course.', 'A satellite image of ground transportation station.', 'A satellite image of helipad

Zero Shot Testing: 100%|################################################################| 87/87 [09:44<00:00,  6.72s/it]

Accuracy for s=0: 0.2037
Accuracy for s=1: 0.2750
Accuracy for s=2: 0.2019
Accuracy for s=3: 0.3062
Accuracy for s=4: 0.4242
Total accuracy: 0.2662



