In [1]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

Collecting ftfy
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.1.3
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-tqqz4j8r
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-tqqz4j8r
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369497 sha256=6e605a104f465ce9f8a405817c4aa85f1451067cd026416a2c9e8481748e771f
  Stored in directory: /tmp/pip-ephem-wheel-cache-icvfwlj6/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b

In [2]:
import numpy as np
import torch
from pkg_resources import packaging

print("Torch version:", torch.__version__)

Torch version: 2.1.0+cu121


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Loading the model

In [4]:
import clip

clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [5]:
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 165MiB/s]


Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


## Image preprocessing

In [6]:
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7b7fa06abc70>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

## Example of tokenizing text

In [7]:
clip.tokenize("We are a bass connections team!")

tensor([[49406,   649,   631,   320,  5992, 37161,  1027,   256, 49407,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]], dtype=torch.int32)

## Setting up input images and texts

In [8]:
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
from torchvision import transforms, models, datasets
from tqdm import tqdm

In [9]:
# Define transform and add image resolution
transform = transforms.Compose([preprocess])

# Load EuroSAT dataset
data_dir = '/content/drive/MyDrive/Train_Test_Splits'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), transform) for x in ['train', 'test']}
dataloaders = {
    'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=400, shuffle=True, num_workers=20),
    'test': torch.utils.data.DataLoader(image_datasets['test'], batch_size=400, shuffle=False, num_workers=20)
}

# Get classes
classes = image_datasets['test'].classes



In [14]:
# text descriptions
templates = [
    '{}',
    '{}.',
    'a photo of a {}.',
    'an image of a {}.',
    'a color image of a {}.',
    'a remote sensing image of many {}.',
    'a remote sensing image of a {}.',
    'a remote sensing image of the {}.',
    'a remote sensing image of the hard to see {}.',
    'a remote sensing image of a hard to see {}.',
    'a low resolution remote sensing image of the {}.',
    'a low resolution remote sensing image of a {}.',
    'a bad remote sensing image of the {}.',
    'a bad remote sensing image of a {}.',
    'a cropped remote sensing image of the {}.',
    'a cropped remote sensing image of a {}.',
    'a bright remote sensing image of the {}.',
    'a bright remote sensing image of a {}.',
    'a dark remote sensing image of the {}.',
    'a dark remote sensing image of a {}.',
    'a close-up remote sensing image of the {}.',
    'a close-up remote sensing image of a {}.',
    'a black and white remote sensing image of the {}.',
    'a black and white remote sensing image of a {}.',
    'a jpeg corrupted remote sensing image of the {}.',
    'a jpeg corrupted remote sensing image of a {}.',
    'a blurry remote sensing image of the {}.',
    'a blurry remote sensing image of a {}.',
    'a good remote sensing image of the {}.',
    'a good remote sensing image of a {}.',
    'a remote sensing image of the large {}.',
    'a remote sensing image of a large {}.',
    'a remote sensing image of the nice {}.',
    'a remote sensing image of a nice {}.',
    'a remote sensing image of the small {}.',
    'a remote sensing image of a small {}.',
    'a remote sensing image of the weird {}.',
    'a remote sensing image of a weird {}.',
    'a remote sensing image of the cool {}.',
    'a remote sensing image of a cool {}.',
    'an aerial image of many {}.',
    'an aerial image of a {}.',
    'an aerial image of the {}.',
    'an aerial image of the hard to see {}.',
    'an aerial image of a hard to see {}.',
    'a low resolution aerial image of the {}.',
    'a low resolution aerial image of a {}.',
    'a bad aerial image of the {}.',
    'a bad aerial image of a {}.',
    'a cropped aerial image of the {}.',
    'a cropped aerial image of a {}.',
    'a bright aerial image of the {}.',
    'a bright aerial image of a {}.',
    'a dark aerial image of the {}.',
    'a dark aerial image of a {}.',
    'a close-up aerial image of the {}.',
    'a close-up aerial image of a {}.',
    'a black and white aerial image of the {}.',
    'a black and white aerial image of a {}.',
    'a jpeg corrupted aerial image of the {}.',
    'a jpeg corrupted aerial image of a {}.',
    'a blurry aerial image of the {}.',
    'a blurry aerial image of a {}.',
    'a good aerial image of the {}.',
    'a good aerial image of a {}.',
    'an aerial image of the large {}.',
    'an aerial image of a large {}.',
    'an aerial image of the nice {}.',
    'an aerial image of a nice {}.',
    'an aerial image of the small {}.',
    'an aerial image of a small {}.',
    'an aerial image of the weird {}.',
    'an aerial image of a weird {}.',
    'an aerial image of the cool {}.',
    'an aerial image of a cool {}.',
    'a satellite image of many {}.',
    'a satellite image of a {}.',
    'a satellite image of the {}.',
    'a satellite image of the hard to see {}.',
    'a satellite image of a hard to see {}.',
    'a low resolution satellite image of the {}.',
    'a low resolution satellite image of a {}.',
    'a bad satellite image of the {}.',
    'a bad satellite image of a {}.',
    'a cropped satellite image of the {}.',
    'a cropped satellite image of a {}.',
    'a bright satellite image of the {}.',
    'a bright satellite image of a {}.',
    'a dark satellite image of the {}.',
    'a dark satellite image of a {}.',
    'a close-up satellite image of the {}.',
    'a close-up satellite image of a {}.',
    'a black and white satellite image of the {}.',
    'a black and white satellite image of a {}.',
    'a jpeg corrupted satellite image of the {}.',
    'a jpeg corrupted satellite image of a {}.',
    'a blurry satellite image of the {}.',
    'a blurry satellite image of a {}.',
    'a good satellite image of the {}.',
    'a good satellite image of a {}.',
    'a satellite image of the large {}.',
    'a satellite image of a large {}.',
    'a satellite image of the nice {}.',
    'a satellite image of a nice {}.',
    'a satellite image of the small {}.',
    'a satellite image of a small {}.',
    'a satellite image of the weird {}.',
    'a satellite image of a weird {}.',
    'a satellite image of the cool {}.',
    'a satellite image of a cool {}.',
]

In [15]:
accs = []

for template in templates:
    text_descriptions = [template.format(label) for label in classes]
    model.eval()

    # Lists to store embeddings, images, and labels
    embeddings = []
    images = []
    image_labels = []

    # Iterate through the test dataloader
    for inputs, labels in tqdm(dataloaders['test']):
        inputs = inputs.to('cuda')

        # Forward pass to get embeddings
        with torch.no_grad():
            features = model.encode_image(inputs).detach().cpu()
            features /= features.norm(dim=-1, keepdim=True)

        # Store image embeddings and labels
        images.extend(inputs.cpu().numpy())
        embeddings.extend(features.numpy())
        image_labels.extend(labels.cpu().numpy())

    # Convert lists to numpy arrays
    embeddings_array = np.array(embeddings)
    image_labels_array = np.array(image_labels)

    # Tokenize text descriptions
    text_tokens = clip.tokenize(text_descriptions).cuda()

    # Convert image embeddings to PyTorch tensor
    embeddings_tensor = torch.tensor(embeddings_array)

    # Move the tensor to GPU if CUDA is available
    if torch.cuda.is_available():
        embeddings_tensor = embeddings_tensor.to('cuda')

    with torch.no_grad():
        # Encode text features
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Ensure data types are consistent
        if text_features.dtype == torch.float32:
            embeddings_tensor = embeddings_tensor.to(torch.float32)

    # Calculate text probabilities
    text_probs = (100.0 * embeddings_tensor @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

    # Convert image labels to PyTorch tensor
    true_labels_tensor = torch.tensor(image_labels_array, dtype=torch.long)

    # Ensure that both tensors are on the CPU
    predicted = top_labels[:, 0].cpu()
    true_labels = true_labels_tensor.cpu()
    correct_predictions = torch.sum(predicted == true_labels)

    # Calculate accuracy
    acc = correct_predictions.item() / len(image_labels)
    accs.append(acc)
    print(f"\nText #{len(accs)} => Accuracy: {acc * 100:.2f}%")
    print()

# Print the average accuracy
avg_acc = sum(accs) / len(accs)
print(f"\nAverage Accuracy over {len(accs)} different texts: {avg_acc * 100:.2f}%")

100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #1 => Accuracy: 37.19%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #2 => Accuracy: 41.87%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #3 => Accuracy: 30.91%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #4 => Accuracy: 31.50%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #5 => Accuracy: 32.15%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #6 => Accuracy: 46.91%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #7 => Accuracy: 42.87%



100%|██████████| 14/14 [00:17<00:00,  1.25s/it]



Text #8 => Accuracy: 40.80%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #9 => Accuracy: 48.13%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #10 => Accuracy: 47.57%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #11 => Accuracy: 42.50%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #12 => Accuracy: 44.11%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #13 => Accuracy: 40.09%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #14 => Accuracy: 41.37%



100%|██████████| 14/14 [00:18<00:00,  1.30s/it]



Text #15 => Accuracy: 38.17%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #16 => Accuracy: 38.04%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #17 => Accuracy: 40.17%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #18 => Accuracy: 42.52%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #19 => Accuracy: 38.46%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #20 => Accuracy: 43.28%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #21 => Accuracy: 40.50%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #22 => Accuracy: 41.70%



100%|██████████| 14/14 [00:17<00:00,  1.25s/it]



Text #23 => Accuracy: 38.63%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #24 => Accuracy: 43.11%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #25 => Accuracy: 41.48%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #26 => Accuracy: 41.39%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #27 => Accuracy: 39.50%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #28 => Accuracy: 40.61%



100%|██████████| 14/14 [00:17<00:00,  1.29s/it]



Text #29 => Accuracy: 41.98%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #30 => Accuracy: 40.54%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #31 => Accuracy: 40.43%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #32 => Accuracy: 41.56%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #33 => Accuracy: 38.19%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #34 => Accuracy: 38.52%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #35 => Accuracy: 44.67%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #36 => Accuracy: 46.80%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #37 => Accuracy: 39.13%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #38 => Accuracy: 36.85%



100%|██████████| 14/14 [00:17<00:00,  1.29s/it]



Text #39 => Accuracy: 42.22%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #40 => Accuracy: 40.00%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #41 => Accuracy: 42.52%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #42 => Accuracy: 40.63%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #43 => Accuracy: 39.19%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #44 => Accuracy: 38.26%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #45 => Accuracy: 41.94%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #46 => Accuracy: 43.87%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #47 => Accuracy: 44.96%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #48 => Accuracy: 36.19%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #49 => Accuracy: 39.56%



100%|██████████| 14/14 [00:18<00:00,  1.30s/it]



Text #50 => Accuracy: 41.80%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #51 => Accuracy: 43.81%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #52 => Accuracy: 35.46%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #53 => Accuracy: 34.46%



100%|██████████| 14/14 [00:18<00:00,  1.30s/it]



Text #54 => Accuracy: 32.56%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #55 => Accuracy: 34.00%



100%|██████████| 14/14 [00:18<00:00,  1.31s/it]



Text #56 => Accuracy: 45.28%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #57 => Accuracy: 42.98%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #58 => Accuracy: 27.57%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #59 => Accuracy: 36.78%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #60 => Accuracy: 42.20%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #61 => Accuracy: 43.13%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #62 => Accuracy: 35.78%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #63 => Accuracy: 35.57%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #64 => Accuracy: 35.98%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #65 => Accuracy: 44.06%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #66 => Accuracy: 35.61%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #67 => Accuracy: 36.57%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #68 => Accuracy: 32.80%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #69 => Accuracy: 36.00%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #70 => Accuracy: 34.37%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #71 => Accuracy: 34.74%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #72 => Accuracy: 33.43%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #73 => Accuracy: 32.20%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #74 => Accuracy: 34.31%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #75 => Accuracy: 35.04%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #76 => Accuracy: 42.57%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #77 => Accuracy: 33.02%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #78 => Accuracy: 33.85%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #79 => Accuracy: 43.72%



100%|██████████| 14/14 [00:17<00:00,  1.26s/it]



Text #80 => Accuracy: 42.31%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #81 => Accuracy: 38.98%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #82 => Accuracy: 38.48%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #83 => Accuracy: 39.83%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #84 => Accuracy: 46.13%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #85 => Accuracy: 37.39%



100%|██████████| 14/14 [00:18<00:00,  1.34s/it]



Text #86 => Accuracy: 36.19%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #87 => Accuracy: 33.74%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #88 => Accuracy: 31.69%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #89 => Accuracy: 36.19%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #90 => Accuracy: 37.13%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #91 => Accuracy: 42.07%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #92 => Accuracy: 38.35%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #93 => Accuracy: 31.09%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #94 => Accuracy: 30.89%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #95 => Accuracy: 37.83%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #96 => Accuracy: 38.56%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #97 => Accuracy: 39.94%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #98 => Accuracy: 37.67%



100%|██████████| 14/14 [00:18<00:00,  1.30s/it]



Text #99 => Accuracy: 37.69%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #100 => Accuracy: 42.35%



100%|██████████| 14/14 [00:18<00:00,  1.30s/it]



Text #101 => Accuracy: 35.00%



100%|██████████| 14/14 [00:18<00:00,  1.30s/it]



Text #102 => Accuracy: 33.28%



100%|██████████| 14/14 [00:18<00:00,  1.30s/it]



Text #103 => Accuracy: 35.67%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #104 => Accuracy: 38.35%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #105 => Accuracy: 34.30%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #106 => Accuracy: 35.35%



100%|██████████| 14/14 [00:17<00:00,  1.28s/it]



Text #107 => Accuracy: 34.69%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]



Text #108 => Accuracy: 33.09%



100%|██████████| 14/14 [00:18<00:00,  1.29s/it]



Text #109 => Accuracy: 36.35%



100%|██████████| 14/14 [00:17<00:00,  1.27s/it]


Text #110 => Accuracy: 36.02%


Average Accuracy over 110 different texts: 38.63%





In [12]:
# plt.figure(figsize=(8, 2 * len(images)))  # Adjust the figure size
# for i, image in enumerate(images):
#     plt.subplot(len(images), 1, i + 1)  # Arrange subplots in a single column
#     image_to_plot = np.transpose(image, (1, 2, 0))
#     plt.imshow(image_to_plot)
#     plt.axis("off")

# plt.subplots_adjust(hspace=0.5)  # Adjust horizontal spacing
# plt.show()