### Dataset: EuroSAT

In [1]:
import numpy as np
import torch
from pkg_resources import packaging

print("Torch version:", torch.__version__)

Torch version: 2.0.1+cu117


In [2]:
import clip

clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [3]:
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


In [4]:
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7fdce04d6840>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [5]:
import os
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
from torchvision import transforms, models, datasets
from tqdm import tqdm

In [6]:
# Define transform and add image resolution
transform = transforms.Compose([preprocess])

# Load EuroSAT dataset
data_dir = '/data/scratch/public/eurosat/team_bassconnections_eurosat/Train_Test_Splits'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), transform) for x in ['train', 'test']}
dataloaders = {
    'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=400, shuffle=True, num_workers=20),
    'test': torch.utils.data.DataLoader(image_datasets['test'], batch_size=400, shuffle=False, num_workers=20)
}

# Get classes
classes = image_datasets['test'].classes

In [7]:
# text descriptions
templates = [
    '{}',
    '{}.',
    'a photo of a {}.',
    'an image of a {}.',
    'a color image of a {}.',
    'a remote sensing image of many {}.',
    'a remote sensing image of a {}.',
    'a remote sensing image of the {}.',
    'a remote sensing image of the hard to see {}.',
    'a remote sensing image of a hard to see {}.',
    'a low resolution remote sensing image of the {}.',
    'a low resolution remote sensing image of a {}.',
    'a bad remote sensing image of the {}.',
    'a bad remote sensing image of a {}.',
    'a cropped remote sensing image of the {}.',
    'a cropped remote sensing image of a {}.',
    'a bright remote sensing image of the {}.',
    'a bright remote sensing image of a {}.',
    'a dark remote sensing image of the {}.',
    'a dark remote sensing image of a {}.',
    'a close-up remote sensing image of the {}.',
    'a close-up remote sensing image of a {}.',
    'a black and white remote sensing image of the {}.',
    'a black and white remote sensing image of a {}.',
    'a jpeg corrupted remote sensing image of the {}.',
    'a jpeg corrupted remote sensing image of a {}.',
    'a blurry remote sensing image of the {}.',
    'a blurry remote sensing image of a {}.',
    'a good remote sensing image of the {}.',
    'a good remote sensing image of a {}.',
    'a remote sensing image of the large {}.',
    'a remote sensing image of a large {}.',
    'a remote sensing image of the nice {}.',
    'a remote sensing image of a nice {}.',
    'a remote sensing image of the small {}.',
    'a remote sensing image of a small {}.',
    'a remote sensing image of the weird {}.',
    'a remote sensing image of a weird {}.',
    'a remote sensing image of the cool {}.',
    'a remote sensing image of a cool {}.',
    'an aerial image of many {}.',
    'an aerial image of a {}.',
    'an aerial image of the {}.',
    'an aerial image of the hard to see {}.',
    'an aerial image of a hard to see {}.',
    'a low resolution aerial image of the {}.',
    'a low resolution aerial image of a {}.',
    'a bad aerial image of the {}.',
    'a bad aerial image of a {}.',
    'a cropped aerial image of the {}.',
    'a cropped aerial image of a {}.',
    'a bright aerial image of the {}.',
    'a bright aerial image of a {}.',
    'a dark aerial image of the {}.',
    'a dark aerial image of a {}.',
    'a close-up aerial image of the {}.',
    'a close-up aerial image of a {}.',
    'a black and white aerial image of the {}.',
    'a black and white aerial image of a {}.',
    'a jpeg corrupted aerial image of the {}.',
    'a jpeg corrupted aerial image of a {}.',
    'a blurry aerial image of the {}.',
    'a blurry aerial image of a {}.',
    'a good aerial image of the {}.',
    'a good aerial image of a {}.',
    'an aerial image of the large {}.',
    'an aerial image of a large {}.',
    'an aerial image of the nice {}.',
    'an aerial image of a nice {}.',
    'an aerial image of the small {}.',
    'an aerial image of a small {}.',
    'an aerial image of the weird {}.',
    'an aerial image of a weird {}.',
    'an aerial image of the cool {}.',
    'an aerial image of a cool {}.',
    'a satellite image of many {}.',
    'a satellite image of a {}.',
    'a satellite image of the {}.',
    'a satellite image of the hard to see {}.',
    'a satellite image of a hard to see {}.',
    'a low resolution satellite image of the {}.',
    'a low resolution satellite image of a {}.',
    'a bad satellite image of the {}.',
    'a bad satellite image of a {}.',
    'a cropped satellite image of the {}.',
    'a cropped satellite image of a {}.',
    'a bright satellite image of the {}.',
    'a bright satellite image of a {}.',
    'a dark satellite image of the {}.',
    'a dark satellite image of a {}.',
    'a close-up satellite image of the {}.',
    'a close-up satellite image of a {}.',
    'a black and white satellite image of the {}.',
    'a black and white satellite image of a {}.',
    'a jpeg corrupted satellite image of the {}.',
    'a jpeg corrupted satellite image of a {}.',
    'a blurry satellite image of the {}.',
    'a blurry satellite image of a {}.',
    'a good satellite image of the {}.',
    'a good satellite image of a {}.',
    'a satellite image of the large {}.',
    'a satellite image of a large {}.',
    'a satellite image of the nice {}.',
    'a satellite image of a nice {}.',
    'a satellite image of the small {}.',
    'a satellite image of a small {}.',
    'a satellite image of the weird {}.',
    'a satellite image of a weird {}.',
    'a satellite image of the cool {}.',
    'a satellite image of a cool {}.',
]

In [8]:
accs = []
recalls = []

for template in templates:
    text_descriptions = [template.format(label) for label in classes]
    model.eval()

    # Lists to store embeddings, images, and labels
    embeddings = []
    images = []
    image_labels = []

    # Iterate through the test dataloader
    for inputs, labels in tqdm(dataloaders['test']):
        inputs = inputs.to('cuda')

        # Forward pass to get embeddings
        with torch.no_grad():
            features = model.encode_image(inputs).detach().cpu()
            features /= features.norm(dim=-1, keepdim=True)

        # Store image embeddings and labels
        images.extend(inputs.cpu().numpy())
        embeddings.extend(features.numpy())
        image_labels.extend(labels.cpu().numpy())

    # Convert lists to numpy arrays
    embeddings_array = np.array(embeddings)
    image_labels_array = np.array(image_labels)

    # Tokenize text descriptions
    text_tokens = clip.tokenize(text_descriptions).cuda()

    # Convert image embeddings to PyTorch tensor
    embeddings_tensor = torch.tensor(embeddings_array)

    # Move the tensor to GPU if CUDA is available
    if torch.cuda.is_available():
        embeddings_tensor = embeddings_tensor.to('cuda')

    with torch.no_grad():
        # Encode text features
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Ensure data types are consistent
        if text_features.dtype == torch.float32:
            embeddings_tensor = embeddings_tensor.to(torch.float32)

    # Calculate text probabilities
    text_probs = (100.0 * embeddings_tensor @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

    # Convert image labels to PyTorch tensor
    true_labels_tensor = torch.tensor(image_labels_array, dtype=torch.long)

    # Ensure that both tensors are on the CPU
    predicted = top_labels[:, 0].cpu()
    true_labels = true_labels_tensor.cpu()
    correct_predictions = torch.sum(predicted == true_labels)

    # Calculate accuracy
    acc = correct_predictions.item() / len(image_labels)
    accs.append(acc)
    print(f"\nText #{len(accs)} => Accuracy: {acc * 100:.2f}%")
    print()

    # Calculate mean recall
    true_positives = torch.sum((predicted == true_labels) & (predicted == 1)).item()
    actual_positives = torch.sum(true_labels == 1).item()
    recalls.append(true_positives / actual_positives)

# Print average accuracy
avg_acc = sum(accs) / len(accs)
highest_acc = max(accs)
lowest_acc = min(accs)
print(f"\nAverage Accuracy over {len(accs)} different texts: {avg_acc * 100:.2f}%")
print(f"Highest Accuracy: {highest_acc * 100:.2f}%")
print(f"Lowest Accuracy: {lowest_acc * 100:.2f}%")

# Print mean recall
mean_recall = sum(recalls) / len(recalls)
print(f"\nMean Recall: {mean_recall * 100:.2f}%")

100%|██████████| 14/14 [00:05<00:00,  2.55it/s]



Text #1 => Accuracy: 37.30%



100%|██████████| 14/14 [00:04<00:00,  2.94it/s]



Text #2 => Accuracy: 41.87%



100%|██████████| 14/14 [00:05<00:00,  2.73it/s]



Text #3 => Accuracy: 30.87%



100%|██████████| 14/14 [00:05<00:00,  2.79it/s]



Text #4 => Accuracy: 31.46%



100%|██████████| 14/14 [00:04<00:00,  2.94it/s]



Text #5 => Accuracy: 32.19%



100%|██████████| 14/14 [00:05<00:00,  2.74it/s]



Text #6 => Accuracy: 46.91%



100%|██████████| 14/14 [00:05<00:00,  2.70it/s]



Text #7 => Accuracy: 42.76%



100%|██████████| 14/14 [00:05<00:00,  2.68it/s]



Text #8 => Accuracy: 40.78%



100%|██████████| 14/14 [00:04<00:00,  3.00it/s]



Text #9 => Accuracy: 48.09%



100%|██████████| 14/14 [00:05<00:00,  2.62it/s]



Text #10 => Accuracy: 47.39%



100%|██████████| 14/14 [00:05<00:00,  2.61it/s]



Text #11 => Accuracy: 42.35%



100%|██████████| 14/14 [00:04<00:00,  2.94it/s]



Text #12 => Accuracy: 43.96%



100%|██████████| 14/14 [00:05<00:00,  2.76it/s]



Text #13 => Accuracy: 40.02%



100%|██████████| 14/14 [00:05<00:00,  2.65it/s]



Text #14 => Accuracy: 41.39%



100%|██████████| 14/14 [00:05<00:00,  2.79it/s]



Text #15 => Accuracy: 38.19%



100%|██████████| 14/14 [00:04<00:00,  2.95it/s]



Text #16 => Accuracy: 38.04%



100%|██████████| 14/14 [00:05<00:00,  2.74it/s]



Text #17 => Accuracy: 40.20%



100%|██████████| 14/14 [00:05<00:00,  2.77it/s]



Text #18 => Accuracy: 42.41%



100%|██████████| 14/14 [00:04<00:00,  2.93it/s]



Text #19 => Accuracy: 38.48%



100%|██████████| 14/14 [00:05<00:00,  2.53it/s]



Text #20 => Accuracy: 43.33%



100%|██████████| 14/14 [00:05<00:00,  2.75it/s]



Text #21 => Accuracy: 40.44%



100%|██████████| 14/14 [00:04<00:00,  2.87it/s]



Text #22 => Accuracy: 41.70%



100%|██████████| 14/14 [00:05<00:00,  2.61it/s]



Text #23 => Accuracy: 38.61%



100%|██████████| 14/14 [00:05<00:00,  2.74it/s]



Text #24 => Accuracy: 43.13%



100%|██████████| 14/14 [00:05<00:00,  2.72it/s]



Text #25 => Accuracy: 41.39%



100%|██████████| 14/14 [00:04<00:00,  2.91it/s]



Text #26 => Accuracy: 41.43%



100%|██████████| 14/14 [00:05<00:00,  2.68it/s]



Text #27 => Accuracy: 39.46%



100%|██████████| 14/14 [00:05<00:00,  2.65it/s]



Text #28 => Accuracy: 40.43%



100%|██████████| 14/14 [00:04<00:00,  2.88it/s]



Text #29 => Accuracy: 42.07%



100%|██████████| 14/14 [00:05<00:00,  2.51it/s]



Text #30 => Accuracy: 40.46%



100%|██████████| 14/14 [00:05<00:00,  2.69it/s]



Text #31 => Accuracy: 40.39%



100%|██████████| 14/14 [00:05<00:00,  2.75it/s]



Text #32 => Accuracy: 41.54%



100%|██████████| 14/14 [00:04<00:00,  3.04it/s]



Text #33 => Accuracy: 38.09%



100%|██████████| 14/14 [00:05<00:00,  2.64it/s]



Text #34 => Accuracy: 38.39%



100%|██████████| 14/14 [00:05<00:00,  2.64it/s]



Text #35 => Accuracy: 44.65%



100%|██████████| 14/14 [00:04<00:00,  3.01it/s]



Text #36 => Accuracy: 46.80%



100%|██████████| 14/14 [00:05<00:00,  2.50it/s]



Text #37 => Accuracy: 39.06%



100%|██████████| 14/14 [00:05<00:00,  2.72it/s]



Text #38 => Accuracy: 36.89%



100%|██████████| 14/14 [00:04<00:00,  2.86it/s]



Text #39 => Accuracy: 42.28%



100%|██████████| 14/14 [00:05<00:00,  2.65it/s]



Text #40 => Accuracy: 39.87%



100%|██████████| 14/14 [00:05<00:00,  2.61it/s]



Text #41 => Accuracy: 42.56%



100%|██████████| 14/14 [00:05<00:00,  2.66it/s]



Text #42 => Accuracy: 40.67%



100%|██████████| 14/14 [00:04<00:00,  2.96it/s]



Text #43 => Accuracy: 39.26%



100%|██████████| 14/14 [00:05<00:00,  2.69it/s]



Text #44 => Accuracy: 38.37%



100%|██████████| 14/14 [00:05<00:00,  2.68it/s]



Text #45 => Accuracy: 41.80%



100%|██████████| 14/14 [00:04<00:00,  2.86it/s]



Text #46 => Accuracy: 43.83%



100%|██████████| 14/14 [00:05<00:00,  2.46it/s]



Text #47 => Accuracy: 45.00%



100%|██████████| 14/14 [00:05<00:00,  2.73it/s]



Text #48 => Accuracy: 36.37%



100%|██████████| 14/14 [00:04<00:00,  2.85it/s]



Text #49 => Accuracy: 39.57%



100%|██████████| 14/14 [00:05<00:00,  2.78it/s]



Text #50 => Accuracy: 41.76%



100%|██████████| 14/14 [00:05<00:00,  2.61it/s]



Text #51 => Accuracy: 43.93%



100%|██████████| 14/14 [00:05<00:00,  2.63it/s]



Text #52 => Accuracy: 35.31%



100%|██████████| 14/14 [00:04<00:00,  2.96it/s]



Text #53 => Accuracy: 34.48%



100%|██████████| 14/14 [00:05<00:00,  2.56it/s]



Text #54 => Accuracy: 32.65%



100%|██████████| 14/14 [00:05<00:00,  2.75it/s]



Text #55 => Accuracy: 33.72%



100%|██████████| 14/14 [00:04<00:00,  2.86it/s]



Text #56 => Accuracy: 45.00%



100%|██████████| 14/14 [00:05<00:00,  2.65it/s]



Text #57 => Accuracy: 43.13%



100%|██████████| 14/14 [00:05<00:00,  2.71it/s]



Text #58 => Accuracy: 27.67%



100%|██████████| 14/14 [00:05<00:00,  2.64it/s]



Text #59 => Accuracy: 36.54%



100%|██████████| 14/14 [00:04<00:00,  2.85it/s]



Text #60 => Accuracy: 42.15%



100%|██████████| 14/14 [00:05<00:00,  2.37it/s]



Text #61 => Accuracy: 43.13%



100%|██████████| 14/14 [00:05<00:00,  2.64it/s]



Text #62 => Accuracy: 35.70%



100%|██████████| 14/14 [00:05<00:00,  2.74it/s]



Text #63 => Accuracy: 35.67%



100%|██████████| 14/14 [00:04<00:00,  2.85it/s]



Text #64 => Accuracy: 36.07%



100%|██████████| 14/14 [00:05<00:00,  2.59it/s]



Text #65 => Accuracy: 44.13%



100%|██████████| 14/14 [00:05<00:00,  2.60it/s]



Text #66 => Accuracy: 35.70%



100%|██████████| 14/14 [00:04<00:00,  2.82it/s]



Text #67 => Accuracy: 36.54%



100%|██████████| 14/14 [00:05<00:00,  2.55it/s]



Text #68 => Accuracy: 32.78%



100%|██████████| 14/14 [00:05<00:00,  2.69it/s]



Text #69 => Accuracy: 35.96%



100%|██████████| 14/14 [00:05<00:00,  2.79it/s]



Text #70 => Accuracy: 34.43%



100%|██████████| 14/14 [00:05<00:00,  2.76it/s]



Text #71 => Accuracy: 34.81%



100%|██████████| 14/14 [00:05<00:00,  2.71it/s]



Text #72 => Accuracy: 33.50%



100%|██████████| 14/14 [00:05<00:00,  2.76it/s]



Text #73 => Accuracy: 32.17%



100%|██████████| 14/14 [00:05<00:00,  2.74it/s]



Text #74 => Accuracy: 34.26%



100%|██████████| 14/14 [00:05<00:00,  2.51it/s]



Text #75 => Accuracy: 34.94%



100%|██████████| 14/14 [00:05<00:00,  2.76it/s]



Text #76 => Accuracy: 42.63%



100%|██████████| 14/14 [00:04<00:00,  2.85it/s]



Text #77 => Accuracy: 32.89%



100%|██████████| 14/14 [00:05<00:00,  2.64it/s]



Text #78 => Accuracy: 33.89%



100%|██████████| 14/14 [00:05<00:00,  2.64it/s]



Text #79 => Accuracy: 43.70%



100%|██████████| 14/14 [00:05<00:00,  2.58it/s]



Text #80 => Accuracy: 42.30%



100%|██████████| 14/14 [00:04<00:00,  2.95it/s]



Text #81 => Accuracy: 39.13%



100%|██████████| 14/14 [00:05<00:00,  2.43it/s]



Text #82 => Accuracy: 38.52%



100%|██████████| 14/14 [00:05<00:00,  2.68it/s]



Text #83 => Accuracy: 39.81%



100%|██████████| 14/14 [00:05<00:00,  2.78it/s]



Text #84 => Accuracy: 46.17%



100%|██████████| 14/14 [00:05<00:00,  2.80it/s]



Text #85 => Accuracy: 37.35%



100%|██████████| 14/14 [00:05<00:00,  2.70it/s]



Text #86 => Accuracy: 36.20%



100%|██████████| 14/14 [00:05<00:00,  2.64it/s]



Text #87 => Accuracy: 33.76%



100%|██████████| 14/14 [00:04<00:00,  2.82it/s]



Text #88 => Accuracy: 31.76%



100%|██████████| 14/14 [00:05<00:00,  2.58it/s]



Text #89 => Accuracy: 36.22%



100%|██████████| 14/14 [00:05<00:00,  2.67it/s]



Text #90 => Accuracy: 37.13%



100%|██████████| 14/14 [00:05<00:00,  2.75it/s]



Text #91 => Accuracy: 41.98%



100%|██████████| 14/14 [00:05<00:00,  2.78it/s]



Text #92 => Accuracy: 38.37%



100%|██████████| 14/14 [00:05<00:00,  2.68it/s]



Text #93 => Accuracy: 31.15%



100%|██████████| 14/14 [00:05<00:00,  2.69it/s]



Text #94 => Accuracy: 30.87%



100%|██████████| 14/14 [00:04<00:00,  2.92it/s]



Text #95 => Accuracy: 37.76%



100%|██████████| 14/14 [00:05<00:00,  2.57it/s]



Text #96 => Accuracy: 38.67%



100%|██████████| 14/14 [00:05<00:00,  2.72it/s]



Text #97 => Accuracy: 39.89%



100%|██████████| 14/14 [00:05<00:00,  2.79it/s]



Text #98 => Accuracy: 37.59%



100%|██████████| 14/14 [00:05<00:00,  2.74it/s]



Text #99 => Accuracy: 37.61%



100%|██████████| 14/14 [00:05<00:00,  2.62it/s]



Text #100 => Accuracy: 42.37%



100%|██████████| 14/14 [00:05<00:00,  2.66it/s]



Text #101 => Accuracy: 34.98%



100%|██████████| 14/14 [00:04<00:00,  2.90it/s]



Text #102 => Accuracy: 33.28%



100%|██████████| 14/14 [00:05<00:00,  2.63it/s]



Text #103 => Accuracy: 35.74%



100%|██████████| 14/14 [00:05<00:00,  2.71it/s]



Text #104 => Accuracy: 38.35%



100%|██████████| 14/14 [00:04<00:00,  2.82it/s]



Text #105 => Accuracy: 34.24%



100%|██████████| 14/14 [00:05<00:00,  2.62it/s]



Text #106 => Accuracy: 35.26%



100%|██████████| 14/14 [00:05<00:00,  2.65it/s]



Text #107 => Accuracy: 34.74%



100%|██████████| 14/14 [00:05<00:00,  2.77it/s]



Text #108 => Accuracy: 33.07%



100%|██████████| 14/14 [00:04<00:00,  2.92it/s]



Text #109 => Accuracy: 36.30%



100%|██████████| 14/14 [00:05<00:00,  2.64it/s]


Text #110 => Accuracy: 35.93%


Average Accuracy over 110 different texts: 38.62%
Highest Accuracy: 48.09%
Lowest Accuracy: 27.67%

Mean Recall: 37.91%



