In [1]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import skimage
import torch
import torch.nn as nn
import torchvision
from clip import clip
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.io import read_image
from tqdm import tqdm

import Utilities as ut

In [2]:
# Presumably the number of DataLoader worker processes.
# NOTE(review): no visible cell reads this constant yet -- confirm whether the
# DataLoaders are meant to use it.
NUM_WORKERS = 8

In [2]:
from transformers import CLIPProcessor, CLIPModel

# Single source of truth for the checkpoint id, so the model and its processor
# are always loaded from the same pretrained checkpoint.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [4]:
# Move the model to the GPU and switch it to evaluation mode.
model = model.cuda().eval()

In [5]:
def test_step_zero_shot_clip(net, data_loader, texts_z, device='cuda'):
    samples = 0.0
    cumulative_accuracy = 0.0

    # Set the network to evaluation mode
    net.eval()

    with torch.no_grad():
        # Iterate over the test set
        for batch_idx, (inputs, targets) in tqdm(enumerate(data_loader), total=len(data_loader), position=0, leave=True):
            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            images_z = model.encode_image(inputs).float()
            # the @ is the dot product
            outputs = (100 * images_z @ texts_z.T).softmax(dim=-1)

            # Fetch prediction and loss value
            samples += inputs.shape[0]
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_accuracy / samples * 100

In [None]:
class DatasetImageNetA(Dataset):
    """Folder-based image dataset indexed through its ``README.txt``.

    After a 12-line header, the README is expected to list one class per line
    as ``n<digits> <class name>``. The images of a class are looked up in
    ``<dataset_path>/n<id zero-padded to 8 digits>/``.

    Attributes:
        mapping: ``{numeric id (int): lower-cased class name}``.
        labels: list of ``(numeric id, file name)`` pairs, one per image,
            grouped by class and sorted by file name within a class.

    Note:
        ``__getitem__`` returns the numeric id parsed from the README as the
        label, not a contiguous class index; pass ``target_transform`` to
        remap it if the consumer needs indices.
    """

    # Number of leading README lines skipped before the class list starts
    _README_HEADER_LINES = 12

    def __init__(self, dataset_path, transform=None, target_transform=None):
        """Parse the README and index every file of every listed class.

        Args:
            dataset_path: Root folder containing ``README.txt`` and one
                sub-folder per class.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.

        Raises:
            FileNotFoundError: If the README or a listed class folder is missing.
        """
        self.dataset_path = dataset_path
        self.transform = transform
        self.target_transform = target_transform

        # Read the mapping file
        with open(os.path.join(dataset_path, "README.txt"), "r") as f:
            lines = f.readlines()[self._README_HEADER_LINES:]

        # Create the mapping dictionary
        self.mapping = {}
        for line in lines:
            tokens = line.split()
            if len(tokens) < 2:
                continue
            wnid = tokens[0]
            # Ignore any line that is not "n<digits> <name>" instead of
            # crashing in int() on free-form text
            if not (wnid.startswith("n") and wnid[1:].isdecimal()):
                continue
            self.mapping[int(wnid[1:])] = " ".join(tokens[1:]).lower()

        # Create the labels list; directory listings are sorted so that sample
        # indices are the same on every run and every file system
        self.labels = [
            (class_id, file_name)
            for class_id in self.mapping
            for file_name in sorted(os.listdir(self._class_dir(class_id)))
        ]

    def _class_dir(self, class_id):
        """Return the folder that holds the images of ``class_id``."""
        return os.path.join(self.dataset_path, f"n{str(class_id).zfill(8)}")

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(image, label)``."""
        class_id, file_name = self.labels[idx]
        image = read_image(os.path.join(self._class_dir(class_id), file_name))
        label = class_id
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
# Index the images under the relative folder dataset/imagenet-a. No transform is
# passed, so samples are returned exactly as read_image() loads them.
val_data = DatasetImageNetA(os.path.join("dataset", "imagenet-a"))

In [None]:
def encode_data(images_path, texts, model_preprocessor):
    """Embed image files and text descriptions with the notebook-level ``model``.

    Args:
        images_path: Iterable of image file paths.
        texts: Iterable of description strings; each one is prefixed with
            ``"This is "`` before tokenization.
        model_preprocessor: Callable applied to every opened PIL image. Its
            outputs must be stackable with ``np.stack``.

    Returns:
        Tuple ``(images_z, texts_z)`` with the float image and text embeddings.
    """
    # Preprocess the images: file name -> PIL image -> preprocessed array.
    # The context manager closes each file as soon as it has been consumed.
    images = []
    for image_path in images_path:
        with Image.open(image_path) as image:
            images.append(model_preprocessor(image))
    images = torch.tensor(np.stack(images)).cuda()

    # Tokenize the texts; the "This is " prefix is meant to improve CLIP's precision
    text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()

    # OpenAI CLIP exposes encode_image / encode_text, Hugging Face's CLIPModel
    # exposes get_image_features / get_text_features
    encode_image = getattr(model, "encode_image", None)
    if encode_image is None:
        encode_image = model.get_image_features
    encode_text = getattr(model, "encode_text", None)
    if encode_text is None:
        encode_text = model.get_text_features

    # Encode the inputs
    with torch.no_grad():
        images_z = encode_image(images).float()
        texts_z = encode_text(text_tokens).float()

    return images_z, texts_z

In [None]:
# Registry mapping a dataset name to the torchvision dataset class that is
# instantiated for it.
DATASETS = {
    "mnist": torchvision.datasets.MNIST,
    "cifar10": torchvision.datasets.CIFAR10,
}

def embed_dataset_classnames(dataset_name, templates=("a photo of a {}.",)):
    """Embed the class names of a dataset with the notebook-level ``model``.

    Every template is formatted with every class name and encoded. The
    per-template embeddings are L2-normalised, averaged over the templates
    and normalised again.

    Args:
        dataset_name: Key of ``DATASETS``.
        templates: Non-empty sequence of prompt templates, each containing
            one ``{}`` placeholder for the class name.

    Returns:
        Tuple ``(classnames, texts_z)``: the dataset's class names and one
        normalised text embedding per class.

    Raises:
        ValueError: If ``templates`` is empty.
    """
    if not templates:
        raise ValueError("templates must contain at least one prompt template")

    # Only the class names are needed, so no image transform is passed
    dataset = DATASETS[dataset_name]("./data", download=True, train=False)
    classnames = dataset.classes

    # OpenAI CLIP exposes encode_text, Hugging Face's CLIPModel get_text_features
    encode_text = getattr(model, "encode_text", None)
    if encode_text is None:
        encode_text = model.get_text_features

    texts_z_views = []
    for template in templates:
        # Create the list of descriptions and tokenize them
        descriptions = [template.format(c) for c in classnames]
        text_tokens = clip.tokenize(descriptions).cuda()

        # Get the normalized textual features
        with torch.no_grad():
            texts_z = encode_text(text_tokens).float()
            texts_z /= texts_z.norm(dim=-1, keepdim=True)
            texts_z_views.append(texts_z)

    # Evaluate the mean representation over the templates
    texts_z = torch.stack(texts_z_views).mean(dim=0)

    # Renormalise: the mean of unit vectors is generally not a unit vector
    texts_z /= texts_z.norm(dim=-1, keepdim=True)

    return classnames, texts_z

In [None]:
def get_data(dataset_name, batch_size=64, transform=None, test_batch_size=256,
             num_workers=8, seed=None):
    """Build the train, validation and test loaders of a registered dataset.

    The official training split is divided at random into a training part
    (half of the samples plus one) and a validation part (the rest).

    Args:
        dataset_name: Key of ``DATASETS``.
        batch_size: Batch size of the training loader.
        transform: Callable applied to every image; when ``None`` the images
            are only converted to tensors.
        test_batch_size: Batch size of the validation and test loaders.
        num_workers: Number of worker processes used by each loader.
        seed: Optional integer seeding the train/validation split so it can
            be reproduced; ``None`` keeps the split driven by the global RNG.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    dataset = DATASETS[dataset_name]

    if transform is None:
        # Convert the PIL images to Tensors
        transform = torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor()])

    # Load data
    full_training_data = dataset(
        './data', train=True, transform=transform, download=True)
    test_data = dataset('./data', train=False,
                        transform=transform, download=True)

    # Create train and validation splits
    num_samples = len(full_training_data)
    training_samples = int(num_samples * 0.5 + 1)
    validation_samples = num_samples - training_samples

    # Only pass a generator when a seed was requested, so the default
    # behaviour is unchanged
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    training_data, validation_data = torch.utils.data.random_split(
        full_training_data, [training_samples, validation_samples], **split_kwargs)

    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(
        training_data, batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        validation_data, test_batch_size, shuffle=False, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_data, test_batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [None]:
def test_step_zero_shot_clip(net, data_loader, texts_z, device='cuda'):
    samples = 0.0
    cumulative_accuracy = 0.0

    # Set the network to evaluation mode
    net.eval()

    with torch.no_grad():
        # Iterate over the test set
        for batch_idx, (inputs, targets) in tqdm(enumerate(data_loader), total=len(data_loader), position=0, leave=True):
            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            images_z = model.encode_image(inputs).float()
            # the @ is the dot product
            outputs = (100 * images_z @ texts_z.T).softmax(dim=-1)

            # Fetch prediction and loss value
            samples += inputs.shape[0]
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_accuracy / samples * 100

In [None]:
def test_step_zero_shot_clip(net, data_loader, texts_z, device='cuda'):
    samples = 0.0
    cumulative_accuracy = 0.0

    # Set the network to evaluation mode
    net.eval()

    with torch.no_grad():
        # Iterate over the test set
        for batch_idx, (inputs, targets) in tqdm(enumerate(data_loader), total=len(data_loader), position=0, leave=True):
            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            images_z = model.encode_image(inputs).float()
            # the @ is the dot product
            outputs = (100 * images_z @ texts_z.T).softmax(dim=-1)

            # Fetch prediction and loss value
            samples += inputs.shape[0]
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_accuracy / samples * 100

In [None]:
dataset_name = "cifar10"


def clip_image_transform(image):
    """Turn one PIL image into the pixel tensor produced by the CLIP processor.

    The processor returns a batch of one image, so the leading batch axis is
    dropped to let the DataLoader build the batches itself.
    """
    return processor(images=image, return_tensors="pt")["pixel_values"][0]


# Only the test split is evaluated in the zero-shot setting
_, _, test_loader = get_data(
    dataset_name, transform=clip_image_transform, batch_size=128)
# Embed one prompt per class name
texts, texts_z = embed_dataset_classnames(dataset_name)
test_accuracy = test_step_zero_shot_clip(model, test_loader, texts_z)

print(f"Test accuracy {test_accuracy:.2f}")

In [3]:
# Print the module tree of the loaded model to inspect its layers.
print(model)

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

In [10]:
# Inspect the vision model's post-LayerNorm module; the output shows it is
# 768 wide, unlike the 512-wide text layers printed above.
model.vision_model.post_layernorm

LayerNorm((768,), eps=1e-05, elementwise_affine=True)

In [17]:
# Encode a short prompt into token ids; the output wraps the three word ids in
# 49406 / 49407, presumably the tokenizer's start- and end-of-text ids.
processor.tokenizer.encode("a photo of")


[49406, 320, 1125, 539, 49407]