<a href="https://colab.research.google.com/github/rohitrnath/LLM-Training-Colab-Sync/blob/main/OpenCLIP_Model_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from typing import Tuple

import numpy as np
import torch
from PIL.Image import Image


class ClipApp:
    """
    Light-weight "app code" required to perform end to end inference with Clip.

    The app uses 1 model:
        * Clip

    For a given image input, the app will:
        * pre-process the image
        * pre-process the text
        * Run Clip inference
    """

    def __init__(
        self,
        clip_model: torch.nn.Module,
    ):
        """Store the encoders and pre-processing callables exposed by clip_model.

        Inputs:
            clip_model: object exposing the attributes `text_encoder`,
                `image_encoder`, `preprocess` and `tokenizer_func`.
        """
        # Open AI Clip
        self.text_encoder = clip_model.text_encoder
        self.image_encoder = clip_model.image_encoder
        # Preprocess Compose function from Open AI clip
        self.preprocess = clip_model.preprocess
        self.tokenizer = clip_model.tokenizer_func

    def predict(self, *args, **kwargs):
        """Alias for predict_similarity; all arguments are forwarded unchanged."""
        return self.predict_similarity(*args, **kwargs)

    def predict_similarity(
        self, image: torch.Tensor, text: torch.Tensor
    ) -> np.ndarray:
        """
        Inputs:
            image: torch.Tensor (Shape: [num_images, 3, 224, 224])
                Processed image tensor, e.g. the output of process_image.
            text: torch.Tensor (Shape: [num_text_prompts, 77])
                Already tokenized text, e.g. the output of process_text.

        Outputs:
            logits_per_image: np.ndarray (Shape: [num_images, num_text_prompts])

                Matrix product of the image encoder output with the transposed
                text encoder output, moved to the CPU and converted to a NumPy
                array. With the Clip encoders these values are presumably the
                cosine similarities between each image and each text prompt,
                times the model's logit scale. The logits of text per image
                can be computed by doing a transpose.

        """
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            image_features = self.image_encoder(image)
            text_features = self.text_encoder(text)
            logits_per_image = image_features @ text_features.t()
        return logits_per_image.cpu().numpy()

    def process_image(self, image: "Image") -> torch.Tensor:
        """Process image before calling forward.

        Inputs:
            image: PIL.Image
                Image loaded by Pillow must be provided.
                Example: image = Image.open('<path>')

        Outputs:
            processed_image: torch.Tensor (shape [1, 3, 224, 224])
                Output of the model's preprocess transform with a leading
                batch dimension of size 1 added.
        """
        return self.preprocess(image).unsqueeze(0)

    def process_text(self, text: str) -> torch.Tensor:
        """Process text into tokens for forward call.

        Input:
            text: str
                Text prompt intended for inference.
                Example: "golden hour"

        Output:
            tokenized_tensor: torch.Tensor (shape: [1, 77])
            Example: tensor([[49406,  3878,  2232, 49407, 0, 0...]])

        """
        return self.tokenizer(text)

    def get_input_spec(
        self,
        image_size: Tuple[int, int] = (224, 224),
        text_size: Tuple[int, int] = (3, 77),
    ):
        """Get the input specification ordered (name -> (shape, type)) pairs.

        This can be used with the qai_hub python API to declare
        the model input specification upon submitting a profile job.

        Inputs:
            image_size: (height, width) of the image input. A single int is
                also accepted and is used for both dimensions.
            text_size: full shape of the text input.

        Outputs:
            dict with an "image" entry of shape (1, 3, *image_size) and type
            "float32", and a "text" entry of shape text_size and type "int32".
        """
        # A bare int means a square image.
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        return {
            "image": ((1, 3, *image_size), "float32"),
            "text": (text_size, "int32"),
        }

In [None]:
from typing import Callable

import torch
import torchvision

from qai_hub_models.utils.asset_loaders import SourceAsRoot, callback_with_retry
from qai_hub_models.utils.base_model import BaseModel, CollectionModel
from qai_hub_models.utils.input_spec import InputSpec

# Checkpoint name passed to `clip.load`.
PRETRAINED_WEIGHTS = "ViT-B/16"
# The model ID is the second-to-last component of this module's dotted name
# (the containing package). When the code runs outside a package -- e.g. in a
# notebook cell, where __name__ is "__main__" -- there is no such component and
# indexing [-2] would raise IndexError, so fall back to the package name used
# by the qai_hub_models import path (qai_hub_models.models.openai_clip.model).
_MODULE_PATH = __name__.split(".")
MODEL_ID = _MODULE_PATH[-2] if len(_MODULE_PATH) > 1 else "openai_clip"
MODEL_ASSET_VERSION = 1
# Upstream repository and the exact commit the CLIP source is taken from.
OPENAI_CLIP_SOURCE_REPOSITORY = "https://github.com/openai/CLIP"
OPENAI_CLIP_SOURCE_REPO_COMMIT = "a1d071733d7111c9c014f024669f959182114e33"


def load_clip_and_tokenizer():
    """Load the pretrained OpenAI CLIP network together with its helpers.

    Returns:
        A (net, preprocess, tokenizer_func) tuple: the two values returned by
        `clip.load(PRETRAINED_WEIGHTS)` followed by `clip.tokenize`.
    """
    source_root = SourceAsRoot(
        OPENAI_CLIP_SOURCE_REPOSITORY,
        OPENAI_CLIP_SOURCE_REPO_COMMIT,
        MODEL_ID,
        MODEL_ASSET_VERSION,
    )
    with source_root:
        # Imported inside the context; presumably `clip` is only importable
        # while the pinned source repository is active -- TODO confirm.
        import clip

        model, image_transform = clip.load(PRETRAINED_WEIGHTS)
        return model, image_transform, clip.tokenize


class Clip(CollectionModel):
    """Collection holding the Clip text encoder, image encoder, image
    pre-processing transform and tokenizer function."""

    def __init__(
        self,
        text_encoder: torch.nn.Module,
        image_encoder: torch.nn.Module,
        preprocess: torchvision.transforms.transforms.Compose,
        tokenizer_func: Callable,
    ):
        """Store the four components as attributes of the same names."""
        super().__init__()
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.preprocess = preprocess
        self.tokenizer_func = tokenizer_func

    @staticmethod
    def from_pretrained():
        """Load the pretrained network (up to 5 attempts) and wrap it in a Clip."""
        model, image_transform, tokenize = callback_with_retry(
            num_retries=5, callback=load_clip_and_tokenizer
        )
        return Clip.from_source_model(model, image_transform, tokenize)

    @staticmethod
    def from_source_model(net, preprocess, tokenizer_func):
        """Build a Clip from an already loaded network.

        The network is switched to eval mode and shared by both encoder
        wrappers.
        """
        eval_net = net.eval()
        return Clip(
            ClipTextEncoder(eval_net),
            ClipImageEncoder(eval_net),
            preprocess,
            tokenizer_func,
        )


class ClipTextEncoder(BaseModel):
    """Wrapper exposing the text branch of OpenAI CLIP."""

    def __init__(self, net: torch.nn.Module):
        """Keep a reference to the CLIP network."""
        super().__init__()
        self.net = net
        # Upper bound used when clamping token ids in forward.
        self.eot_token = 49407

    def forward(self, text: torch.Tensor):
        """Forward call on Open AI CLIP model.

        Inputs:
            text: torch.Tensor (Shape: [1, 77] context_length=77)
                Tokenized text.

        Outputs:
            text_features: torch.Tensor
                Output of `encode_text`, divided by its norm along dim 1
                (presumably shape [num_text_prompts, 512] -- TODO confirm).
                When image features are multiplied by the transpose of these,
                you can obtain a matrix of cosine similarities between the
                corresponding image and text input.

        """
        # Clamp token ids into [0, eot_token] before encoding.
        token_ids = torch.clip(text, min=0, max=self.eot_token)
        embeddings = self.net.encode_text(token_ids)
        return embeddings / embeddings.norm(dim=1, keepdim=True)

    def get_input_spec(
        self,
        batch_size: int = 1,
        text_length: int = 77,
    ) -> InputSpec:
        """Get the input specification ordered (name -> (shape, type)) pairs.

        This can be used with the qai_hub python API to declare
        the model input specification upon submitting a profile job.
        """
        text_shape = (batch_size, text_length)
        return {"text": (text_shape, "int32")}

    @classmethod
    def from_pretrained(cls):
        """Load the full pretrained Clip collection and return its text encoder."""
        pretrained = Clip.from_pretrained()
        return pretrained.text_encoder


class ClipImageEncoder(BaseModel):
    """Wrapper exposing the image branch of OpenAI Clip."""

    def __init__(self, net: torch.nn.Module):
        """Keep a reference to the Clip network."""
        super().__init__()
        self.net = net
        # NOTE(review): not read anywhere in this class; kept as in the original.
        self.eot_token = 49407

    def forward(self, image: torch.Tensor):
        """Forward call on Open AI Clip model.

        Inputs:
            image: torch.Tensor (Shape: [1, 3, 224, 224])
                Processed image tensor.
                Channel Layout: RGB

        Outputs:
            image_features: torch.Tensor
                Output of `encode_image`, divided by its norm along dim 1 and
                multiplied by exp(logit_scale) of the network (presumably
                shape [num_images, 512] -- TODO confirm).
                When multiplied to text features, you can obtain a
                matrix of scaled cosine similarities between the corresponding
                image and text input.

        """
        embeddings = self.net.encode_image(image)
        unit_embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
        scale = self.net.logit_scale.exp()
        return scale * unit_embeddings

    def get_input_spec(
        self,
        height: int = 224,
        width: int = 224,
    ) -> InputSpec:
        """Get the input specification ordered (name -> (shape, type)) pairs.

        This can be used with the qai_hub python API to declare
        the model input specification upon submitting a profile job.
        """
        image_shape = (1, 3, height, width)
        return {"image": (image_shape, "float32")}

    @classmethod
    def from_pretrained(cls):
        """Load the full pretrained Clip collection and return its image encoder."""
        pretrained = Clip.from_pretrained()
        return pretrained.image_encoder

In [3]:
import argparse
def add_output_dir_arg(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the optional `--output-dir` / `-o` string option on parser.

    The option defaults to None. The same parser object is returned so the
    call can be chained.
    """
    option_flags = ("--output-dir", "-o")
    help_text = "If specified, saves demo output (e.g. image) to this directory instead of displaying."
    parser.add_argument(*option_flags, type=str, default=None, help=help_text)
    return parser

In [None]:
import os

import numpy as np
import torch

from qai_hub_models.models.openai_clip.model import  Clip
from qai_hub_models.utils.display import display_or_save_image


# Run Clip on a directory of images with a query text.
# The demo will display similarity score for each image.
def main(is_test: bool = False):
    """Score a set of images against a text prompt and show the best match.

    Inputs:
        is_test: bool
            When True, command line arguments are ignored (defaults are used)
            and the most relevant image is not displayed or saved.

    Command line options: --image-dir, --image-names (comma separated),
    --text and --output-dir / -o.
    """
    # Import the PIL.Image module locally so that `Image.open` resolves to the
    # module-level function; the `PIL.Image.Image` class has no `open`.
    from PIL import Image

    # Demo parameters
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image-dir",
        type=str,
        default=None,
        help="Path to image directory",
    )
    parser.add_argument(
        "--image-names",
        type=str,
        default="image1.jpg,image2.jpg,image3.jpg",
        help="Specify names of the images in the folder.",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="camping under the stars",
        help="Text prompt for image search",
    )
    add_output_dir_arg(parser)
    args = parser.parse_args([] if is_test else None)

    # Load model
    clip_model = Clip.from_pretrained()
    app = ClipApp(clip_model=clip_model)

    text = app.process_text(args.text)
    supported_extensions = (".jpg", ".jpeg", ".png")

    # Names and paths of the images that were actually loaded, kept in the
    # same order as `images` so scores can be matched back to files even when
    # some entries are skipped.
    used_names = []
    used_paths = []
    images = []

    # Iterate through images provided by user
    for filename in args.image_names.split(","):
        # Make sure the file is an image
        if os.path.splitext(filename)[1].lower() not in supported_extensions:
            print(f"Skipping file {filename}")
            continue

        # Without --image-dir the name is used as a path on its own.
        if args.image_dir:
            image_path = os.path.join(args.image_dir, filename)
        else:
            image_path = filename

        # Preprocess the image; the file handle is released right after.
        with Image.open(image_path) as pil_image:
            images.append(app.process_image(pil_image))
        used_names.append(filename)
        used_paths.append(image_path)

    if not images:
        # torch.stack cannot be called on an empty list.
        print("No supported image files were provided.")
        return

    # Each processed image has a leading batch dimension of 1; drop it after
    # stacking so the batch is [num_images, ...].
    images = torch.stack(images).squeeze(1)

    # Compute similarity
    predictions = app.predict_similarity(images, text).flatten()

    # Display all the images and their score wrt to the text prompt provided.
    print(f"Searching images by prompt: {args.text}")
    for name, score in zip(used_names, predictions):
        print(f"\t Image with name: {name} has a similarity score={score}")

    # Show image
    print("Displaying the most relevant image")

    most_relevant_image = Image.open(used_paths[np.argmax(predictions)])

    if not is_test:
        display_or_save_image(most_relevant_image, args.output_dir)


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()