# Q2: Contrastive Language-Image Pretraining [6 points]

## 1. [1 point] Setup models.

In [1]:
import torch
import clip 
import torchvision.models as models

In [None]:
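# Use a second GPU when one exists; otherwise share the first GPU or fall back to the CPU.
device1 = "cuda:0" if torch.cuda.is_available() else "cpu"
device2 = "cuda:1" if torch.cuda.device_count() > 1 else device1
# ImageNet pretraining (eval mode, since the model is only used for inference)
resnet50_imagenet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).to(device1).eval()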

In [None]:
# OpenAI's CLIP
clip_model, preprocess = clip.load("RN50", device=device2)
clip_image_encoder = clip_model.visual.to(device2)

The torchvision ResNet-50 was trained with supervised classification on the 1000 ImageNet-1k classes. The CLIP visual encoder is not a further-trained version of that network: it is a modified ResNet-50 that OpenAI trained from scratch, jointly with a text encoder, on roughly 400 million image-text pairs using a contrastive objective. Its image features therefore live in an embedding space shared with text, which lets it score an image against arbitrary natural-language labels instead of a fixed set of classes.
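
To make the image-text training signal concrete, here is a minimal sketch of a symmetric contrastive loss over a tiny batch, written in plain Python so it runs without a GPU. The function names, the toy 2-D embeddings, and the fixed `logit_scale` of 100 are assumptions for illustration; this is not CLIP's actual training code.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def cross_entropy(logits, target):
    """Numerically stable cross-entropy of one row of logits against a target index."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]

def contrastive_loss(image_embs, text_embs, logit_scale=100.0):
    """Symmetric loss: image k should match text k, and text k should match image k."""
    imgs = [l2_normalize(v) for v in image_embs]
    txts = [l2_normalize(v) for v in text_embs]
    n = len(imgs)
    # logits[i][j] = scaled cosine similarity between image i and text j
    logits = [[logit_scale * sum(a * b for a, b in zip(i, t)) for t in txts] for i in imgs]
    loss_images = sum(cross_entropy(logits[k], k) for k in range(n)) / n
    loss_texts = sum(cross_entropy([logits[r][k] for r in range(n)], k) for k in range(n)) / n
    return (loss_images + loss_texts) / 2

# Toy batch: correctly paired embeddings versus the same embeddings with texts swapped.
matched = contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
shuffled = contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(f"matched pairs: {matched:.4f}, shuffled pairs: {shuffled:.4f}")
```

The loss is low when each image is closest to its own caption and high when the pairing is wrong, which is the pressure that aligns the two encoders.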

## 2. [1 point] Setup data. Understand the ImageNet challenge dataset

The ImageNet challenge (ILSVRC) dataset organizes its 1000 labels using the WordNet hierarchy; each label corresponds to one WordNet synset.

**(i) Label Hierarchy (WordNet):**

* WordNet is a lexical database that groups words into synonym sets (synsets) and connects them through relations such as "is-a" (hypernymy) and "part-of" (meronymy).
* ImageNet leverages this hierarchy by identifying each category with a WordNet synset ID and using the WordNet structure to relate categories to one another.
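
The "is-a" structure can be sketched with a small child-to-parent map. The map below is hand-written and heavily simplified for illustration (it is not the real WordNet graph), and the helper names are hypothetical.

```python
# Hand-written, simplified child -> parent ("is-a") map; NOT the actual WordNet hierarchy.
TOY_HYPERNYMS = {
    "golden retriever": "retriever",
    "retriever": "dog",
    "poodle": "dog",
    "dog": "animal",
    "tabby cat": "cat",
    "cat": "animal",
}

def hypernym_path(label, parents):
    """Follow "is-a" links from a label up to the root of the toy hierarchy."""
    path = [label]
    while path[-1] in parents:
        path.append(parents[path[-1]])
    return path

def lowest_common_ancestor(a, b, parents):
    """Return the most specific concept that both labels fall under, or None."""
    ancestors_a = hypernym_path(a, parents)
    for node in hypernym_path(b, parents):
        if node in ancestors_a:
            return node
    return None

print(hypernym_path("golden retriever", TOY_HYPERNYMS))
print(lowest_common_ancestor("golden retriever", "poodle", TOY_HYPERNYMS))
print(lowest_common_ancestor("poodle", "tabby cat", TOY_HYPERNYMS))
```

Two labels whose common ancestor is deep in the tree (two dog breeds) are semantically close, while labels that only meet near the root (a dog and a cat) are further apart.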

**(ii) What is a Synset?**

* A synset is a group of words or phrases that share the same meaning. In ImageNet, each of the 1000 labels is one synset, i.e. one object category.
* For example, the synset ID "n02099601" corresponds to the label "golden retriever."
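
ImageNet synset IDs follow the WordNet convention of a part-of-speech letter followed by an 8-digit, zero-padded offset. The helper below is a small sketch of how such an ID can be split apart; `parse_wnid` and `WNID_PATTERN` are hypothetical names, not part of any library.

```python
import re

# One part-of-speech letter (n = noun, v = verb, a = adjective, r = adverb) plus 8 digits.
WNID_PATTERN = re.compile(r"^([nvar])(\d{8})$")

def parse_wnid(wnid):
    """Split a WordNet-style ID into its part-of-speech letter and integer offset."""
    match = WNID_PATTERN.match(wnid)
    if match is None:
        raise ValueError(f"not a valid WordNet ID: {wnid!r}")
    pos, offset = match.groups()
    return pos, int(offset)

pos, offset = parse_wnid("n02099601")
print(pos, offset)
```

All ILSVRC classes are nouns, so every ImageNet ID starts with `n`.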

**(iii) Problems with Grouping by Synsets for Visual Recognition?**

* Yes. Synsets group images by meaning rather than by appearance, so a single label can cover visually diverse images:
    * **Intra-class Variability:** Synsets can encompass objects with significant visual differences. A broad synset such as "dog" covers golden retrievers, poodles, and chihuahuas, which look quite different from one another, yet the model must map all of them to the same concept.
    * **Background Clutter:** Images within a synset might vary considerably due to background clutter. A "chair" synset could have images of chairs in different rooms, with different objects around them. The model needs to focus on the chair itself despite these variations.
    * **Pose and Viewpoint:** Objects within a synset can appear in various poses and viewpoints. A "car" synset could have images of cars from the front, side, or back. The model needs to be robust to these pose variations.


**(iv) 3 Types of Visual Differences in Images with the Same Synset:**

1. **Object Appearance:** As mentioned earlier, objects within a synset can have significant visual differences in terms of breed, shape, size, color, or material.
2. **Background Complexity:** Images can vary in background complexity, with objects appearing in cluttered environments, outdoors, or with other objects around them.
3. **Pose and Viewpoint:** The pose and viewpoint of the object can differ significantly within a synset. Objects can be tilted, rotated, or partially occluded, requiring the model to recognize them from various perspectives.


In [None]:
# 3. Setup zero-shot CLIP
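import os
from tqdm import tqdm
from PIL import Image

# NOTE: IMAGENET2012_CLASSES is assumed to be defined earlier in the notebook as an
# ordered mapping from synset ID to label string; it is not defined in the cells shown here.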

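def get_image_features(image_path):
    """Load one image and return its CLIP embedding, shape [1, embed_dim]."""
    with open(image_path, "rb") as f:
        image = Image.open(f).convert("RGB")
    # The CLIP model lives on device2, so the input tensor must be moved there too.
    image = preprocess(image).unsqueeze(0).to(device2)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
    return image_features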

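def compute_probability_scores(image_features):
    """Return softmax probabilities over all ImageNet labels for one image."""
    text_features = []
    with torch.no_grad():
        for label in tqdm(IMAGENET2012_CLASSES.values()):
            text = clip.tokenize([label]).to(device2)
            text_features.append(clip_model.encode_text(text))
        text_features = torch.cat(text_features, dim=0)
        # Cosine similarity between the image and every label embedding, shape [num_classes].
        similarities = torch.cosine_similarity(image_features, text_features, dim=-1)
        # Scale by CLIP's learned temperature before the softmax; raw cosine
        # similarities span too narrow a range and give a nearly uniform distribution.
        logits_per_image = clip_model.logit_scale.exp() * similarities
        probs = torch.softmax(logits_per_image.float(), dim=-1)
    return probs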

# Test CLIP on a few example images
# Sort the listing so the same five images are picked on every run (os.listdir order is arbitrary).
example_images = sorted(
    name for name in os.listdir('imagenet-sample-images') if name.endswith('.JPEG')
)
example_images = example_images[:5]

classes_list = list(IMAGENET2012_CLASSES.values())
for image_path in example_images:
    image_features = get_image_features(os.path.join('imagenet-sample-images', image_path))
    probs = compute_probability_scores(image_features)
    top5_indices = torch.topk(probs, k=5, dim=-1).indices.tolist()
    top5_categories = [classes_list[idx] for idx in top5_indices]
    print(f"Image: {image_path}")
    print(f"Top 5 predicted categories: {top5_categories}")
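
The zero-shot scores come from a softmax over cosine similarities, and the scale applied before that softmax matters a great deal. The sketch below uses made-up similarity values and a scale of 100 (roughly the learned temperature in the released CLIP models, stated here as an assumption) to show how an unscaled softmax stays close to uniform while the scaled one produces a confident prediction.

```python
import math

def softmax(scores, scale=1.0):
    """Numerically stable softmax over scores multiplied by a temperature-like scale."""
    scaled = [scale * s for s in scores]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up cosine similarities between one image and four label prompts.
similarities = [0.30, 0.22, 0.20, 0.18]

unscaled = softmax(similarities)             # raw cosine similarities
scaled = softmax(similarities, scale=100.0)  # similarities times an assumed scale of 100

print("unscaled:", [round(p, 3) for p in unscaled])
print("scaled:  ", [round(p, 3) for p in scaled])
```

The ranking of labels is identical in both cases; only the sharpness of the distribution changes, so top-k predictions agree while the reported probabilities differ.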