In [1]:
import torch
import torch.nn
import clip
import os
import torch.utils.data
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from scipy import interpolate
from torchvision import datasets
import torchvision

# Test Module

In [2]:
def get_output(image_inputs, text_inputs, t=1.0):
    image_features = model.encode_image(image_inputs).float()
    image_features /= image_features.norm(dim=-1, keepdim=True)

    text_features = model.encode_text(text_inputs).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
    output = (image_features @ text_features.T) / t
    return output


def get_result(ID_score, OOD_score, ood_name="None"):
    id_score = ID_score
    ood_score = OOD_score

    y_true = [1] * len(id_score) + [0] * len(ood_score)
    y_scores = id_score + ood_score
    roc_auc = roc_auc_score(y_true, y_scores)
    fpr, tpr, thresh = roc_curve(y_true, y_scores, pos_label=1)
    fpr95 = float(interpolate.interp1d(tpr, fpr)(0.95))
    print(f"OOD is {ood_name}, The AUC score is:[{roc_auc}], The FPR95 score is:[{fpr95}]")


def get_ood_score(ood_data_loader, text_inputs, len_of_ID, name="none", t=1.0):
    _pre_max = []

    for i, (images, class_ids) in enumerate(ood_data_loader):
        image_inputs = images.to(device)

        with torch.no_grad():
            logits = get_output(image_inputs, text_inputs, t=t)
            probs = logits.softmax(dim=-1).cpu().numpy()

        for j in range(probs.shape[0]):
            prob = probs[j][:len_of_ID + 1]
            _pre_max.append(prob[:len_of_ID].max())

        print(f"{name} has finished {i / len(ood_data_loader) * 100:.0f}%")
    return _pre_max

In [3]:
def test_in_ood_datasets(ID_dataloader, classes_list, ID_name="none", Agent="a photo of things."):
        input_text = classes_list.copy()
        L = len(classes_list)
        for i in range(L):
            input_text.append(Agent)

        text_inputs = clip.tokenize(input_text).to(device)
        print(f"text input shape {text_inputs.shape}")

        ID_pre_max = []

        for i, (images, class_ids) in enumerate(ID_dataloader):
            image_inputs = images.to(device)

            with torch.no_grad():
                logits = get_output(image_inputs, text_inputs)
                probs = logits.softmax(dim=-1).cpu().numpy()

            for j in range(probs.shape[0]):
                prob = probs[j][:L + 1]
                ID_pre_max.append(prob[:L].max())

            print(f"ID test dataset has finished {i / len(ID_dataloader) * 100:.2f}%")

        int_pre_max = get_ood_score(ood_data_loader=iNt_loader, name="iNt", text_inputs=text_inputs, len_of_ID=L)
        sun_pre_max = get_ood_score(ood_data_loader=sun_loader, name="SUN", text_inputs=text_inputs, len_of_ID=L)
        place_pre_max = get_ood_score(ood_data_loader=place_loader, name="Places", text_inputs=text_inputs, len_of_ID=L)
        textures_pre_max = get_ood_score(ood_data_loader=textures_loader, name="Textures", text_inputs=text_inputs,
                                         len_of_ID=L)

        get_result(ID_score=ID_pre_max, OOD_score=int_pre_max, ood_name="iNt")
        get_result(ID_score=ID_pre_max, OOD_score=sun_pre_max, ood_name="SUN")
        get_result(ID_score=ID_pre_max, OOD_score=place_pre_max, ood_name="Places")
        get_result(ID_score=ID_pre_max, OOD_score=textures_pre_max, ood_name="Textures")

# Load Model

In [5]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model, transform = clip.load("ViT-B/16", device=device)
print(device)

cuda:0


# Data Loader

In [6]:
textures = torchvision.datasets.ImageFolder(root=os.path.expanduser("dtd"), transform=transform)
textures_loader = torch.utils.data.DataLoader(textures, batch_size=256, shuffle=False, pin_memory=True, num_workers=10)

iNt = torchvision.datasets.ImageFolder(root=os.path.expanduser("iNaturalist"),
                                       transform=transform)
iNt_loader = torch.utils.data.DataLoader(iNt, batch_size=256, shuffle=False, pin_memory=True, num_workers=10)

place365 = torchvision.datasets.ImageFolder(root=os.path.expanduser("Places"),
                                            transform=transform)
place_loader = torch.utils.data.DataLoader(place365, batch_size=256, shuffle=False, pin_memory=True, num_workers=10)

sun = torchvision.datasets.ImageFolder(root=os.path.expanduser("SUN"), transform=transform)
sun_loader = torch.utils.data.DataLoader(sun, batch_size=256, shuffle=False, pin_memory=True, num_workers=10)

In [7]:
class IDDataset(torch.utils.data.Dataset):
    def __init__(self, root, is_train, preprocess):
        self.preprocess = preprocess
        self.root = root
        if is_train:
            self.dataset = datasets.ImageNet(root=root, transform=self.preprocess, split="train")
        else:
            self.dataset = datasets.ImageNet(root=root, transform=self.preprocess, split="val")
        self.dataset_class = [tup[0] for tup in self.dataset.classes]

    def get_classes(self):
        return self.dataset_class

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        return image, label

imagenet1k_dataset = IDDataset(root=os.path.expanduser("imagenet_1k"), is_train=False, preprocess=transform)
imagenet1k_dataloader = torch.utils.data.DataLoader(imagenet1k_dataset, batch_size=128, shuffle=False, pin_memory=True, num_workers=10)

imagenet1k_classes = [s.replace("_", " ") for s in imagenet1k_dataset.get_classes()]
print(imagenet1k_classes)

['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'robin', 'bulbul', 'jay', 'magpie', 'chickadee', 'water ouzel', 'kite', 'bald eagle', 'vulture', 'great grey owl', 'European fire salamander', 'common newt', 'eft', 'spotted salamander', 'axolotl', 'bullfrog', 'tree frog', 'tailed frog', 'loggerhead', 'leatherback turtle', 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'common iguana', 'American chameleon', 'whiptail', 'agama', 'frilled lizard', 'alligator lizard', 'Gila monster', 'green lizard', 'African chameleon', 'Komodo dragon', 'African crocodile', 'American alligator', 'triceratops', 'thunder snake', 'ringneck snake', 'hognose snake', 'green snake', 'king snake', 'garter snake', 'water snake', 'vine snake', 'night snake', 'boa constrictor', 'rock python', 'Indian cobra', 'green mamba', 'sea snake', 'horned viper', 'diamondback', 

# Test

In [8]:
Agent = "a photo of a thing that we can see in nature."
test_in_ood_datasets(ID_dataloader=imagenet1k_dataloader, classes_list=imagenet1k_classes, ID_name="ImageNet-1k", Agent=Agent)

text input shape torch.Size([2000, 77])
ID test dataset has finished 0.00%
ID test dataset has finished 0.26%
ID test dataset has finished 0.51%
ID test dataset has finished 0.77%
ID test dataset has finished 1.02%
ID test dataset has finished 1.28%
ID test dataset has finished 1.53%
ID test dataset has finished 1.79%
ID test dataset has finished 2.05%
ID test dataset has finished 2.30%
ID test dataset has finished 2.56%
ID test dataset has finished 2.81%
ID test dataset has finished 3.07%
ID test dataset has finished 3.32%
ID test dataset has finished 3.58%
ID test dataset has finished 3.84%
ID test dataset has finished 4.09%
ID test dataset has finished 4.35%
ID test dataset has finished 4.60%
ID test dataset has finished 4.86%
ID test dataset has finished 5.12%
ID test dataset has finished 5.37%
ID test dataset has finished 5.63%
ID test dataset has finished 5.88%
ID test dataset has finished 6.14%
ID test dataset has finished 6.39%
ID test dataset has finished 6.65%
ID test dataset