# Building a CLIP Model for Zero-Shot Image Classification Using TT-NN

CLIP (Contrastive Language-Image Pre-Training) is a foundational multimodal AI model developed by OpenAI that learns visual concepts from natural language supervision. Unlike traditional computer vision models that are trained on fixed categories, CLIP can understand and classify images based on arbitrary text descriptions.

## What CLIP Does

CLIP bridges the gap between vision and language by learning to associate images with their textual descriptions. The model consists of two main components:

1. **Vision Encoder**: A Vision Transformer (ViT) that processes images and converts them into feature embeddings
2. **Text Encoder**: A Transformer that processes text descriptions and converts them into feature embeddings

![CLIP Diagram](https://media.githubusercontent.com/media/tenstorrent/tutorial-assets/nmaurice/clip-tutorial/media/clip_tutorial/CLIP.png)

During inference, CLIP can:
- **Zero-shot image classification**: Classify images into categories it has never explicitly seen during training by comparing image embeddings with text embeddings of category descriptions
- **Image-text similarity**: Measure how well an image matches a given text description
- **Content-based image retrieval**: Find images that best match a text query
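Once the two encoders have produced embeddings, zero-shot classification takes only a few lines of arithmetic. The pure-Python sketch below uses made-up 3-dimensional embeddings (real CLIP embeddings have hundreds of dimensions). It also assumes a fixed scale of 100 in place of the model's learned temperature. It normalizes each vector, takes dot products (cosine similarities), scales them, and applies a softmax over the prompts.

```python
import math

def normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def zero_shot_probs(image_embedding, prompt_embeddings, logit_scale=100.0):
    """Cosine similarity between one image and each prompt, scaled, then softmaxed."""
    img = normalize(image_embedding)
    logits = []
    for prompt in prompt_embeddings:
        txt = normalize(prompt)
        cosine = sum(a * b for a, b in zip(img, txt))
        logits.append(logit_scale * cosine)
    return softmax(logits)

# Toy embeddings (hypothetical values, not real CLIP outputs)
prompts = ["a diagram", "a dog", "a cat"]
image_embedding = [0.9, 0.1, 0.2]
prompt_embeddings = [[1.0, 0.0, 0.1], [0.0, 1.0, 0.3], [0.1, 0.3, 1.0]]

probs = zero_shot_probs(image_embedding, prompt_embeddings)
best = max(range(len(prompts)), key=lambda i: probs[i])
print(prompts[best], [round(p, 3) for p in probs])
```

The prompt whose embedding points in the most similar direction to the image embedding wins. The large scale factor turns modest differences in cosine similarity into a sharply peaked distribution.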



In this tutorial, we implement CLIP for image classification. Our application classifies an image using natural-language prompts such as "a diagram", "a dog", or "a cat".
We use the pre-trained weights of OpenAI's `clip-vit-base-patch32` model and focus on inference only.



## Imports and Dependencies

We start by importing the necessary libraries for our CLIP implementation:

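- **ttnn**: Tenstorrent's neural network operations library, used to run the model on Tenstorrent hardware
- **torch**: model loading and tensor pre-processing
- **transformers**: Hugging Face library for downloading pre-trained models and tokenizing prompts
- **PIL**: Python Imaging Library for loading images
- **torchvision**: image transforms (resize, crop, normalize) for preprocessing
- **requests** and **loguru**: downloading the demo image and logging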

In [None]:
import ttnn
import torch
from loguru import logger
import re
import os
import math
import numpy as np
from PIL import Image
from transformers import CLIPTokenizer, CLIPModel
import requests
from io import BytesIO
import time

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode

## TT-NN Device Management and Utility Functions

We define helper functions to manage TT-NN devices and handle tensor conversions between PyTorch and TT-NN formats. These utilities simplify device initialization, tensor format conversions, and memory management throughout our CLIP implementation.

In [None]:

# Global TT-NN device handle, set by open_ttnn() and cleared by close_ttnn()
device = None

def open_ttnn():
    """Open TT-NN device 0, reserving an 8 KiB "L1 small" region (l1_small_size)."""
    global device
    device = ttnn.open_device(device_id=0, l1_small_size=8192)

def close_ttnn():
    """Close the TT-NN device if it is open."""
    global device
    if device is not None:
        ttnn.close_device(device)
        device = None

def get_device():
    """Get the current TT-NN device handle."""
    global device
    return device

def convert_from_ttnn(x):
    """Convert a TT-NN tensor to a PyTorch tensor; return any other value unchanged."""
    if isinstance(x, ttnn.Tensor):
        return ttnn.to_torch(x)
    return x

def to_ttnn(torch_tensor, dtype=None, layout=ttnn.TILE_LAYOUT):
    """Convert PyTorch tensor to TT-NN tensor with specified dtype and layout."""
    global device
    ttnn_tensor = ttnn.from_torch(torch_tensor, device=device, layout=layout, dtype=dtype)
    return ttnn_tensor

def to_torch_shape(ttnn_shape):
    """Convert TT-NN shape to PyTorch-compatible tuple."""
    return tuple(ttnn_shape)

def convert_ttnn_dtype(ttnn_tensor, dtype, new_shape=None):
    """
    Change dtype of TT-NN tensor and optionally reshape.
    Note: Currently requires moving tensor to host for dtype conversion.
    """
    device = get_device()
    # Move tensor to host for dtype conversion (current TT-NN limitation)
    host_tensor = ttnn.from_device(ttnn_tensor)
    host_tensor = ttnn.to_dtype(host_tensor, dtype=dtype)
    if new_shape is not None:
        host_tensor = ttnn.reshape(host_tensor, new_shape)

    return ttnn.to_device(host_tensor, device=device)


## Model Weight Conversion

TT-NN cannot load pre-trained checkpoints directly, so we load the model with PyTorch (through Hugging Face `transformers`) and then convert its weights to TT-NN format. The following helper function converts an entire model's state dictionary from PyTorch tensors to TT-NN tensors, enabling us to run the pre-trained CLIP weights on Tenstorrent hardware.
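Most converted weights are later cast to bfloat16 by the `convert_ttnn_dtype` calls in the model classes. bfloat16 keeps float32's 8 exponent bits but only 7 mantissa bits, so each value keeps roughly 2 to 3 significant decimal digits. The pure-Python sketch below illustrates that precision loss. It assumes round-to-nearest-even; we have not verified which rounding mode TT-NN's own conversion uses, and NaN and infinity handling is omitted.

```python
import struct

def to_bfloat16(value: float) -> float:
    """Round a float to the nearest bfloat16 value (round-to-nearest-even).

    Sketch only: NaN and infinity handling is omitted.
    """
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Bias the low 16 bits so the upper 16 bits round to nearest, ties to even
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    # Reinterpret the shortened bit pattern as float32 again
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

for w in [1.0, 0.1, 0.02734375, -3.14159265]:
    b = to_bfloat16(w)
    print(f"{w:>12.8f} -> {b:>12.8f}  (relative error {abs(b - w) / abs(w):.2e})")
```

Values that fit in 8 significant bits (such as 1.0 or 0.02734375) survive unchanged. Other values pick up a relative error of at most about 0.4%, which is usually acceptable for inference.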

In [None]:
def convert_model_to_ttnn(state_dict):
    """
    Convert a PyTorch model's state dictionary to TT-NN format.
    
    Args:
        state_dict: PyTorch model state dictionary containing weights and biases
        
    Returns:
        dict: State dictionary with tensors converted to TT-NN format
    """
    ttnn_state_dict = {}
    logger.info("Converting model to TT-NN format")

    # Convert every PyTorch tensor to a TT-NN tensor on the device;
    # copy any non-tensor entries through unchanged
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            ttnn_state_dict[key] = to_ttnn(value)
        else:
            ttnn_state_dict[key] = value

    return ttnn_state_dict

## Generic Transformer Implementation

CLIP uses two types of transformers: a text transformer and a vision transformer. To maximize code reuse, we define a generic Transformer class that can be used for both modalities with appropriate configuration.
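The same class serves both modalities because the Hugging Face state dictionary namespaces each encoder's layers under a prefix such as `text_model.encoder` or `vision_model.encoder`. The constructor infers the number of layers by collecting the distinct indices that follow `<prefix>.layers.`. The pure-Python sketch below shows that idea on a hypothetical subset of keys. `count_layers` is an illustrative helper, and anchoring the pattern with `^` is our own variation.

```python
import re

def count_layers(keys, prefix):
    """Count distinct layer indices in keys of the form '<prefix>.layers.<i>.<...>'."""
    # re.escape treats the dots in the prefix literally; the raw string keeps \. and \d intact
    pattern = re.compile(rf"^{re.escape(prefix)}\.layers\.(\d+)\.")
    layer_ids = set()
    for key in keys:
        match = pattern.search(key)
        if match:
            layer_ids.add(int(match.group(1)))
    return len(layer_ids)

# Hypothetical subset of CLIP state-dict keys
keys = [
    "text_model.encoder.layers.0.layer_norm1.weight",
    "text_model.encoder.layers.0.self_attn.q_proj.weight",
    "text_model.encoder.layers.1.mlp.fc1.weight",
    "vision_model.encoder.layers.0.layer_norm1.weight",
    "vision_model.encoder.layers.1.layer_norm1.weight",
    "vision_model.encoder.layers.2.mlp.fc2.bias",
    "visual_projection.weight",
]
print(count_layers(keys, "text_model.encoder"), count_layers(keys, "vision_model.encoder"))
```

Counting distinct indices, as the tutorial does, assumes the layers are numbered contiguously from 0.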

### Transformer Architecture

The transformer models used by CLIP consist of multiple layers (residual blocks), each containing the following sub-operations in sequence:

1. **Layer Normalization**: Normalizes inputs for stable inference (and training)
2. **Multi-Head Self-Attention**: 
   - For text: Uses causal masking to prevent attending to future tokens
   - For vision: Uses full attention across all image patches
3. **Layer Normalization**: Second normalization layer
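4. **MLP (Multi-Layer Perceptron)**: Two linear layers with a GELU-style activation in between (Linear → GELU → Linear). OpenAI's CLIP checkpoints use QuickGELU, x·σ(1.702x), a close approximation of GELU.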

Each block uses residual connections, where the output of each sub-operation is added to its input, enabling deeper networks and better gradient flow. 
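To see how close QuickGELU is to the exact GELU, the pure-Python sketch below evaluates both on a grid over [-6, 6] and reports the largest absolute difference. It depends only on the two formulas, not on TT-NN.

```python
import math

def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def quick_gelu(x: float) -> float:
    """QuickGELU: x * sigmoid(1.702 * x)."""
    return x / (1.0 + math.exp(-1.702 * x))

grid = [i / 100.0 for i in range(-600, 601)]
max_diff = max(abs(gelu(x) - quick_gelu(x)) for x in grid)
print(f"max |GELU - QuickGELU| on [-6, 6]: {max_diff:.4f}")
```

The two functions differ by about 0.02 at most. That is small, but it still counts as a numerical mismatch when you compare a TT-NN implementation against the PyTorch reference.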

In [None]:

class Transformer:
    def __init__(self, state_dict, heads, attention_mask=None, prefix=""):
        """
        Initialize a generic Transformer that can be used for both text and vision encoding.
        
        Args:
            state_dict: Model weights dictionary
            heads: Number of attention heads
            attention_mask: Attention mask for causal attention (used for text, None for vision)
            prefix: Prefix for layer names in state_dict (e.g., "text_model.encoder" or "vision_model.encoder")
        """
        self.layers = []
        self.heads = heads
        self.attention_mask = attention_mask
        self.prefix = prefix

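        # Use a regex to find all layer indices in the state dictionary
        # (raw f-string so "\." and "\d" reach the regex engine; the prefix is matched literally)
        layer_pattern = re.compile(rf"{re.escape(prefix)}\.layers\.(\d+)\.")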

        # Count number of transformer layers by finding unique layer indices
        layers_ids = set()
        for k in state_dict.keys():
            re_match = re.search(layer_pattern, k)
            if re_match:
                layers_ids.add(re_match.group(1))

        num_layers = len(layers_ids)

        # Initialize each transformer layer with converted weights
        for i in range(0, num_layers):
            resblock_prefix = f"{prefix}.layers.{i}"

            # Extract and convert all weights for this layer to bfloat16 precision
            self.layers.append(
                {
                    # First layer normalization weights
                    "ln_1_weight": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.layer_norm1.weight"], ttnn.bfloat16
                    ),
                    "ln_1_bias": convert_ttnn_dtype(state_dict[f"{resblock_prefix}.layer_norm1.bias"], ttnn.bfloat16),
                    
                    # Multi-head attention projection weights (Q, K, V)
                    "q_proj_weight": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.self_attn.q_proj.weight"], ttnn.bfloat16
                    ),
                    "q_proj_bias": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.self_attn.q_proj.bias"], ttnn.bfloat16
                    ),
                    "k_proj_weight": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.self_attn.k_proj.weight"], ttnn.bfloat16
                    ),
                    "k_proj_bias": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.self_attn.k_proj.bias"], ttnn.bfloat16
                    ),
                    "v_proj_weight": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.self_attn.v_proj.weight"], ttnn.bfloat16
                    ),
                    "v_proj_bias": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.self_attn.v_proj.bias"], ttnn.bfloat16
                    ),
                    
                    # Attention output projection weights
                    "out_proj_weight": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.self_attn.out_proj.weight"], ttnn.bfloat16
                    ),
                    "out_proj_bias": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.self_attn.out_proj.bias"], ttnn.bfloat16
                    ),
                    
                    # Second layer normalization weights
                    "ln_2_weight": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.layer_norm2.weight"], ttnn.bfloat16
                    ),
                    "ln_2_bias": convert_ttnn_dtype(state_dict[f"{resblock_prefix}.layer_norm2.bias"], ttnn.bfloat16),
                    
                    # MLP weights (feed-forward network)
                    "mlp_c_fc_weight": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.mlp.fc1.weight"], ttnn.bfloat16
                    ),
                    "mlp_c_fc_bias": convert_ttnn_dtype(state_dict[f"{resblock_prefix}.mlp.fc1.bias"], ttnn.bfloat16),
                    "mlp_c_proj_weight": convert_ttnn_dtype(
                        state_dict[f"{resblock_prefix}.mlp.fc2.weight"], ttnn.bfloat16
                    ),
                    "mlp_c_proj_bias": convert_ttnn_dtype(state_dict[f"{resblock_prefix}.mlp.fc2.bias"], ttnn.bfloat16),
                }
            )


    def forward(self, x):
        def mlp(x, layer):
            x = ttnn.linear(x, layer["mlp_c_fc_weight"], bias=layer["mlp_c_fc_bias"], transpose_b=True)
            x = x * ttnn.sigmoid(x * 1.702)  # QuickGELU, the activation used by OpenAI's CLIP
            x = ttnn.linear(x, layer["mlp_c_proj_weight"], bias=layer["mlp_c_proj_bias"], transpose_b=True)
            return x

        def multi_head_attention(
            hidden_states,
            fused_qkv_weight,
            fused_qkv_bias,
            self_output_weight,
            self_output_bias,
            attention_mask=None,
            prefix="",
        ):
            seq_length, batch_size, hidden_size = hidden_states.shape

            self._embed_dim = hidden_size
            self._head_dim = hidden_size // self.heads
            self._scale = self._head_dim**-0.5
            self._attention_dropout = 0.0  # Unused

            compute_kernel_config = ttnn.WormholeComputeKernelConfig(
                math_fidelity=ttnn.MathFidelity.HiFi4,
                math_approx_mode=False,
                fp32_dest_acc_en=True,
                packer_l1_acc=True,
            )

            # Note: KV-caching not implemented (not needed for single forward pass)
            (q_weights, k_weights, v_weights) = fused_qkv_weight
            (q_bias, k_bias, v_bias) = fused_qkv_bias

            # Compute Q, K, V projections
            q = ttnn.linear(hidden_states, q_weights, bias=q_bias, transpose_b=True)
            k = ttnn.linear(hidden_states, k_weights, bias=k_bias, transpose_b=True)
            v = ttnn.linear(hidden_states, v_weights, bias=v_bias, transpose_b=True)

            # Split heads: (seq_length, batch_size, hidden) -> (seq_length, batch_size * num_heads, head_dim)
            q = ttnn.reshape(q, (seq_length, batch_size * self.heads, self._head_dim))
            k = ttnn.reshape(k, (seq_length, batch_size * self.heads, self._head_dim))
            v = ttnn.reshape(v, (seq_length, batch_size * self.heads, self._head_dim))

            # Transpose to (batch_size * num_heads, seq_length, head_dim) for batched attention
            q = ttnn.transpose(q, 0, 1)
            k = ttnn.transpose(k, 0, 1)
            v = ttnn.transpose(v, 0, 1)

            # Compute attention scores with proper scaling
            scores = ttnn.matmul(q, ttnn.transpose(k, -2, -1))
            scores = scores * self._scale

            # Apply the additive attention mask if provided (matching PyTorch MHA behavior):
            # the (seq_len, seq_len) mask broadcasts over the batch_size * num_heads dimension,
            # and its -inf entries become zero attention weights after the softmax
            if attention_mask is not None:
                scores = scores + attention_mask

            attn_weights = ttnn.softmax(
                scores, dim=-1, numeric_stable=True, compute_kernel_config=compute_kernel_config
            )

            # Apply attention weights to values
            attn_output = ttnn.matmul(attn_weights, v)

            # Merge heads: (batch_size * num_heads, seq_length, head_dim) -> (seq_length, batch_size, embed_dim)
            attn_output = ttnn.transpose(attn_output, 0, 1)
            attn_output = ttnn.reshape(attn_output, (seq_length, batch_size, self._embed_dim))

            # Apply output projection
            dense_out = ttnn.linear(
                attn_output,
                self_output_weight,
                bias=self_output_bias,
                compute_kernel_config=compute_kernel_config,
                transpose_b=True,
            )

            return dense_out

        def residual_attention_block(x, layer, i=0):
            # LayerNorm
            residual = x
            x = ttnn.layer_norm(x, weight=layer["ln_1_weight"], bias=layer["ln_1_bias"])

            # Multi-head self-attention; should match
            # torch.nn.MultiheadAttention(d_model, n_head)(x, x, x, need_weights=False, attn_mask=attn_mask)
            x_attn = multi_head_attention(
                x,
                fused_qkv_weight=(layer["q_proj_weight"], layer["k_proj_weight"], layer["v_proj_weight"]),
                fused_qkv_bias=(layer["q_proj_bias"], layer["k_proj_bias"], layer["v_proj_bias"]),
                self_output_weight=layer["out_proj_weight"],
                self_output_bias=layer["out_proj_bias"],
                attention_mask=self.attention_mask,  # None for the vision transformer
                prefix=f"{self.prefix}.layers.{i}.attn",
            )

            x = residual + x_attn

            # LayerNorm
            x_post_ln_2 = ttnn.layer_norm(x, weight=layer["ln_2_weight"], bias=layer["ln_2_bias"])

            # Multi-Layer Perceptron
            x = x + mlp(x_post_ln_2, layer)

            return x

        for i in range(len(self.layers)):
            layer = self.layers[i]
            x = residual_attention_block(x, layer, i)

        return x
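The head-splitting reshapes in `multi_head_attention` are easy to get wrong. The pure-Python check below is independent of TT-NN, and `flatten`, `reshape`, and `swap_first_two_axes` are illustrative helpers. It tracks where each element lands when a `(seq_length, batch_size, hidden)` tensor is reshaped to `(seq_length, batch_size * num_heads, head_dim)` and its first two axes are then swapped. PyTorch's `nn.MultiheadAttention` uses this same layout internally.

```python
def flatten(nested):
    """Flatten nested lists in row-major order (non-list values are leaves)."""
    if isinstance(nested, list):
        return [leaf for item in nested for leaf in flatten(item)]
    return [nested]

def reshape(flat, shape):
    """Row-major reshape of a flat list into nested lists of the given shape."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]

def swap_first_two_axes(t):
    """Transpose axes 0 and 1 of a nested list: result[j][i] == t[i][j]."""
    return [[t[i][j] for i in range(len(t))] for j in range(len(t[0]))]

seq_length, batch_size, num_heads, head_dim = 3, 2, 2, 4
hidden = num_heads * head_dim

# Each element records its own (seq, batch, channel) coordinates
x = [[[(s, b, c) for c in range(hidden)] for b in range(batch_size)] for s in range(seq_length)]

# Split heads, then move the (batch * heads) axis to the front
q = reshape(flatten(x), (seq_length, batch_size * num_heads, head_dim))
q = swap_first_two_axes(q)  # -> (batch_size * num_heads, seq_length, head_dim)

# Element (b * num_heads + h, s, d) came from token s, batch b, channel h * head_dim + d
ok = all(
    q[b * num_heads + h][s][d] == (s, b, h * head_dim + d)
    for b in range(batch_size) for h in range(num_heads)
    for s in range(seq_length) for d in range(head_dim)
)

# Merging heads (swap back, then reshape) restores the original layout exactly
merged = reshape(flatten(swap_first_two_axes(q)), (seq_length, batch_size, hidden))
print(ok, merged == x)
```

Each head therefore sees a contiguous `head_dim` slice of the hidden channels for every token, which is what the batched matmuls in the attention code expect.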

## Vision Transformer Implementation

The VisionTransformer class handles image processing for CLIP. It converts input images into patch embeddings, adds positional encodings, and processes them through transformer layers.

### Vision Processing Pipeline

1. **Patch Embedding**: Converts 2D image into sequence of patch embeddings using convolution
2. **Class Token**: Prepends a learnable classification token to the sequence
3. **Positional Encoding**: Adds positional information to each patch
4. **Transformer Layers**: Processes the sequence through multiple attention layers
5. **Projection Head**: Applies a final layer normalization to the class token's output and projects it into the shared image-text embedding space

The `forward()` method orchestrates this entire pipeline, preprocessing image embeddings and calling the generic transformer.
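All of the sequence sizes follow from the image size and the shape of the patch-embedding convolution weight. The pure-Python sketch below derives them. It assumes a weight shape of `(768, 3, 32, 32)` in PyTorch's `(out_channels, in_channels, kH, kW)` layout for `clip-vit-base-patch32`, and it uses the same 64-dimensional-head rule as the tutorial code. `vision_sequence_shape` is a hypothetical helper.

```python
def vision_sequence_shape(image_size, patch_weight_shape, head_dim=64):
    """Derive ViT token-sequence sizes from the patch-embedding conv weight shape.

    patch_weight_shape follows PyTorch's Conv2d layout: (out_channels, in_channels, kH, kW).
    """
    width, _in_channels, _kernel_h, patch_size = patch_weight_shape
    grid = image_size // patch_size   # patches per side (kernel size == stride)
    num_patches = grid * grid
    seq_len = num_patches + 1         # +1 for the prepended class token
    heads = width // head_dim         # same width // 64 rule as the tutorial code
    return {"grid": grid, "num_patches": num_patches, "seq_len": seq_len, "width": width, "heads": heads}

# Assumed weight shape for clip-vit-base-patch32
print(vision_sequence_shape(224, (768, 3, 32, 32)))
```

A 224×224 image therefore becomes a 7×7 grid of 49 patches. Together with the class token, that gives 50 positions, one for each row of the positional embedding table.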

In [None]:
class VisionTransformer:
    def __init__(self, state_dict):
        torch.manual_seed(0)
        self.output_dim = 0

        conv2_state_dict_name = "vision_model.embeddings.patch_embedding.weight"
        self.vision_width = state_dict[conv2_state_dict_name].shape[0]
        self.patch_size = state_dict[conv2_state_dict_name].shape[-1]
        self.vision_heads = self.vision_width // 64  # assumes a head dimension of 64

        self.class_embedding = convert_ttnn_dtype(
            state_dict["vision_model.embeddings.class_embedding"], dtype=ttnn.bfloat16
        )
        self.positional_embedding = convert_ttnn_dtype(
            state_dict["vision_model.embeddings.position_embedding.weight"], dtype=ttnn.bfloat16
        )

        self.proj = convert_ttnn_dtype(state_dict["visual_projection.weight"], dtype=ttnn.bfloat16)

        # Weights for the patch-embedding convolution, kept in DRAM.
        # (Height-sharding them in L1 across an 8x8 core grid failed here because
        # the resulting physical shard shape was not tile-aligned to 32x32.)
        memory_config = ttnn.DRAM_MEMORY_CONFIG
        self.conv1_weights = ttnn.to_layout(
            state_dict[conv2_state_dict_name],
            layout=ttnn.ROW_MAJOR_LAYOUT,
            memory_config=memory_config,
            dtype=ttnn.bfloat16,
        )
        self.conv1_weights = convert_ttnn_dtype(self.conv1_weights, dtype=ttnn.bfloat16)

        assert self.conv1_weights.dtype == ttnn.bfloat16

        # Layer normalization applied before the transformer layers
        # ("pre_layrnorm" is the key's actual spelling in the Hugging Face checkpoint)
        self.ln_pre_weights = state_dict["vision_model.pre_layrnorm.weight"]
        self.ln_pre_bias = state_dict["vision_model.pre_layrnorm.bias"]

        # Layer normalization applied after transformer layers (to class token)
        self.ln_post_weights = state_dict["vision_model.post_layernorm.weight"]
        self.ln_post_bias = state_dict["vision_model.post_layernorm.bias"]

        self.transformer = Transformer(
            state_dict, self.vision_heads, attention_mask=None, prefix="vision_model.encoder"
        )

    def forward(self, x):
        (batch_size, in_channels, height, width) = x.shape

        # Note: ttnn.conv2d expects a channels-last input tensor:
        #   (N, H, W, C_in)
        # whereas torch.nn.Conv2d expects a channels-first input tensor:
        #   (N, C_in, H, W)
        #
        # Moreover, ttnn.conv2d produces a flattened output tensor:
        #   (N, H, W, C_in) -> (1, 1, N * H_out * W_out, C_out)
        # whereas torch.nn.Conv2d produces a 4D tensor:
        #   (N, C_out, H_out, W_out)
        #
        # Also, ttnn.conv2d only accepts tuples for kernel_size and stride.

        # Change tensor layout to (N, H, W, C_in)
        x = ttnn.permute(x, [0, 2, 3, 1])  # (N, C_in, H, W) -> (N, H, W, C_in)

        # Convert the input tensor to row-major layout before calling ttnn.conv2d
        x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)

        out_channels = self.vision_width  # 768 for clip-vit-base-patch32

        x = ttnn.conv2d(
            input_tensor=x, 
            weight_tensor=self.conv1_weights,
            in_channels=in_channels,
            out_channels=out_channels,
            batch_size=batch_size,
            input_height=height,
            input_width=width,
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
            padding=(0, 0),
            dilation=(1, 1),
            groups=1,  # No grouped convolution (standard convolution)
            device=get_device(),
            return_weights_and_bias=False,
            return_output_dim=False,
        )

        output_height = height // self.patch_size
        output_width = width // self.patch_size

        # Back to tile layout, then unflatten the conv output into a sequence of patch embeddings:
        # (1, 1, N * H_out * W_out, C_out) -> (N, H_out * W_out, C_out)
        x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
        x = ttnn.reshape(x, (batch_size, output_height * output_width, out_channels))

        class_embedding = convert_ttnn_dtype(self.class_embedding, x.dtype, (x.shape[0], 1, x.shape[-1]))

        # Create zero tensor to ensure proper broadcasting and memory layout
        # This helps align the class embedding tensor with the expected shape and memory configuration
        zero_tensor = ttnn.zeros(
            shape=(x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=device, layout=ttnn.TILE_LAYOUT
        )

        class_embedding = ttnn.reshape(class_embedding, zero_tensor.shape)
        class_embedding = class_embedding + zero_tensor  # Addition with zero preserves values but ensures proper layout

        # Move tensor to DRAM memory for concatenation operation
        # Future optimization: Use L1 sharded memory for better performance
        x = ttnn.to_memory_config(x, memory_config=ttnn.DRAM_MEMORY_CONFIG)

        class_embedding = ttnn.reshape(
            class_embedding, (class_embedding.shape[0], class_embedding.shape[1], class_embedding.shape[2])
        )
        class_embedding = ttnn.to_memory_config(class_embedding, memory_config=x.memory_config())

        # Concatenate class embedding with patch embeddings
        # Note: Future optimization could avoid host transfers for better performance
        x = ttnn.concat([class_embedding, x], dim=1, memory_config=None)  # shape = [*, grid ** 2 + 1, width]

        positional_embedding = convert_ttnn_dtype(self.positional_embedding, x.dtype, (1, x.shape[1], x.shape[2]))
        x = x + positional_embedding

        # LayerNorm
        x = ttnn.layer_norm(x, weight=self.ln_pre_weights, bias=self.ln_pre_bias)

        # Permute
        x = ttnn.permute(x, (1, 0, 2))  # NLD -> LND

        # Transformer
        x = self.transformer.forward(x)

        # Permute
        x = ttnn.permute(x, (1, 0, 2))  # LND -> NLD

        # LayerNorm
        x = ttnn.layer_norm(x[:, 0, :], weight=self.ln_post_weights, bias=self.ln_post_bias)

        if self.proj is not None:
            x = ttnn.matmul(x, self.proj, transpose_b=True)

        return x
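The patch embedding relies on the fact that a convolution whose kernel size equals its stride applies the same linear map to every non-overlapping patch. The pure-Python sketch below uses tiny sizes, random data, and hypothetical helper names. It checks that a naive channels-first convolution agrees with a per-patch projection that emits one row per patch. The rows come out in channels-last flatten order, which is why the conv output `(1, 1, N * H_out * W_out, C_out)` can be reshaped into `(N, H_out * W_out, C_out)`. This assumes `ttnn.conv2d` emits rows in (n, y, x) order, as the code comments describe.

```python
import math
import random

def conv2d_kernel_equals_stride(image, kernels, patch):
    """Naive Conv2d (no padding) with kernel_size == stride == patch.

    image: [C_in][H][W]; kernels: [C_out][C_in][patch][patch] -> [C_out][H // patch][W // patch]
    """
    c_in, height, width = len(image), len(image[0]), len(image[0][0])
    out = []
    for k in kernels:
        plane = []
        for oy in range(height // patch):
            row = []
            for ox in range(width // patch):
                acc = 0.0
                for c in range(c_in):
                    for ky in range(patch):
                        for kx in range(patch):
                            acc += image[c][oy * patch + ky][ox * patch + kx] * k[c][ky][kx]
                row.append(acc)
            plane.append(row)
        out.append(plane)
    return out

def patch_tokens(image, kernels, patch):
    """Linear projection of each flattened patch; one output row per patch, in row-major patch order."""
    c_in, height, width = len(image), len(image[0]), len(image[0][0])
    flat_kernels = [[k[c][ky][kx] for c in range(c_in) for ky in range(patch) for kx in range(patch)]
                    for k in kernels]
    tokens = []
    for oy in range(height // patch):
        for ox in range(width // patch):
            patch_vec = [image[c][oy * patch + ky][ox * patch + kx]
                         for c in range(c_in) for ky in range(patch) for kx in range(patch)]
            tokens.append([sum(p * w for p, w in zip(patch_vec, fk)) for fk in flat_kernels])
    return tokens

random.seed(0)
C_IN, SIZE, PATCH, C_OUT = 3, 4, 2, 5
image = [[[random.uniform(-1, 1) for _ in range(SIZE)] for _ in range(SIZE)] for _ in range(C_IN)]
kernels = [[[[random.uniform(-1, 1) for _ in range(PATCH)] for _ in range(PATCH)] for _ in range(C_IN)]
           for _ in range(C_OUT)]

conv = conv2d_kernel_equals_stride(image, kernels, PATCH)  # (C_out, H', W'), channels-first
tokens = patch_tokens(image, kernels, PATCH)                # (H' * W', C_out), channels-last rows
grid = SIZE // PATCH
match = all(
    math.isclose(tokens[oy * grid + ox][o], conv[o][oy][ox], abs_tol=1e-12)
    for o in range(C_OUT) for oy in range(grid) for ox in range(grid)
)
print(len(tokens), len(tokens[0]), match)
```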

## Complete CLIP Model Implementation

We now define the main CLIP class that combines both text and vision processing capabilities. This class orchestrates the entire multimodal inference pipeline.

### CLIP Architecture Components

The CLIP class instantiates and manages:
- **Text Transformer**: Processes tokenized text inputs using causal attention masking
- **Vision Transformer**: Processes image inputs through patch-based attention
- **Shared Embedding Space**: Projects both modalities into a common feature space for comparison

### Key Methods
- `encode_text()`: Converts text tokens to feature embeddings
- `encode_image()`: Converts images to feature embeddings  
- `forward()`: Performs complete inference, computing similarity scores between images and text
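The text transformer's causal masking is additive: the mask holds 0 where a token may attend and -inf where it may not. It is added to the attention scores before the softmax, so masked positions receive exactly zero weight. The pure-Python sketch below builds the same pattern that `ttnn.triu(full(-inf), diagonal=1)` produces and applies it to a 3-token example. `causal_mask` and `masked_softmax_rows` are illustrative helpers.

```python
import math

def causal_mask(n):
    """Additive causal mask: 0 on and below the diagonal, -inf above it."""
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]

def masked_softmax_rows(scores, mask):
    """Add the mask to each row of scores, then apply a numerically stable softmax."""
    out = []
    for score_row, mask_row in zip(scores, mask):
        row = [s + m for s, m in zip(score_row, mask_row)]
        peak = max(row)                            # finite: the diagonal is never masked
        exps = [math.exp(v - peak) for v in row]   # math.exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

scores = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
weights = masked_softmax_rows(scores, causal_mask(3))
for row in weights:
    print([round(w, 3) for w in row])
```

The first token can attend only to itself, while the last token sees the whole prefix. The vision transformer passes no mask, so every patch attends to every other patch.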

In [None]:
class CLIP:
    def __init__(self, state_dict):
        self.token_embedding = ttnn.typecast(
            state_dict["text_model.embeddings.token_embedding.weight"], dtype=ttnn.bfloat16
        )
        self.positional_embedding = ttnn.typecast(
            state_dict["text_model.embeddings.position_embedding.weight"], dtype=ttnn.bfloat16
        )

        self.text_projection = ttnn.typecast(state_dict["text_projection.weight"], dtype=ttnn.bfloat16)
        self.context_length = self.positional_embedding.shape[0]
        self.vocab_size = self.token_embedding.shape[0]
        self.transformer_width = state_dict["text_model.final_layer_norm.weight"].shape[0]
        transformer_heads = self.transformer_width // 64

        self.ln_final_weights = state_dict["text_model.final_layer_norm.weight"]
        self.ln_final_bias = state_dict["text_model.final_layer_norm.bias"]

        self.logit_scale = state_dict["logit_scale"].item()

        self.visual = VisionTransformer(state_dict)

        self.transformer = Transformer(
            state_dict, transformer_heads, attention_mask=self.build_attention_mask(), prefix="text_model.encoder"
        )
        
    def build_attention_mask(self):
        # Causal mask for the text transformer: token i may only attend to tokens 0..i.
        # PyTorch-style additive mask: -inf above the diagonal, 0 on and below it.
        mask = ttnn.full(
            shape=[self.context_length, self.context_length],
            fill_value=float("-inf"),
            dtype=ttnn.bfloat16,
            device=get_device(),
            layout=ttnn.TILE_LAYOUT,
        )
        mask = ttnn.triu(mask, diagonal=1)
        return mask


    def encode_image(self, image):
        return self.visual.forward(image)

    def encode_text(self, tokens):
        tokens = convert_ttnn_dtype(tokens, dtype=ttnn.uint32)

        x = ttnn.embedding(tokens, weight=self.token_embedding, dtype=ttnn.bfloat16)

        # Add positional embedding
        x = x + self.positional_embedding

        # Permute
        x = ttnn.permute(x, (1, 0, 2))  # NLD -> LND

        # Call Text Transformer
        x = self.transformer.forward(x) 

        # Permute back
        x = ttnn.permute(x, (1, 0, 2))  # LND -> NLD

        # LayerNorm
        x = ttnn.layer_norm(x, weight=self.ln_final_weights, bias=self.ln_final_bias)

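        # Extract the features at the end-of-text (EOT) token and apply the text projection.
        # The EOT token has the largest id in CLIP's vocabulary, so argmax over the token ids
        # finds its position (argmax returns the first occurrence if padding repeats the EOT id).
        # The argmax currently falls back to PyTorch on the host.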
        torch_tokens = ttnn.to_torch(tokens)
        torch_x = ttnn.to_torch(x)

        torch_selected_features = torch_x[torch.arange(torch_x.shape[0]), torch_tokens.argmax(dim=-1)]
        
        # Put tensor back on device for text projection
        x = ttnn.from_torch(torch_selected_features, device=get_device(), layout=ttnn.TILE_LAYOUT)
        x = ttnn.matmul(x, self.text_projection, transpose_b=True)
        
        return x

    def forward(self, image, tokens):
        text_features = self.encode_text(tokens)
        image_features = self.encode_image(image)

        # Normalize features
        norm_image_features = ttnn.operations.moreh.norm(image_features, p=2.0, dim=1, keepdim=True)
        norm_text_features = ttnn.operations.moreh.norm(text_features, p=2.0, dim=1, keepdim=True)

        image_features = ttnn.divide(image_features, norm_image_features)
        text_features = ttnn.divide(text_features, norm_text_features)

        # Cosine similarity as logits
        logit_scale = math.exp(self.logit_scale)

        # Compute `logit_scale * image_features @ text_features.t()`
        logits_per_image = ttnn.matmul(logit_scale * image_features, text_features, transpose_b=True)
        logits_per_text = ttnn.transpose(logits_per_image, 0, 1)

        return logits_per_image, logits_per_text
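`encode_text` uses the argmax over token ids to locate the end-of-text token. The trick works because `<|endoftext|>` (id 49407) has the largest id in CLIP's vocabulary. It also relies on `torch.argmax` returning the first maximal index on ties. That matters if the tokenizer pads with the end-of-text id, which we believe the OpenAI checkpoints' tokenizer does (check `tokenizer.pad_token_id` to confirm). The pure-Python sketch below mimics this with hypothetical word ids and a shortened context length.

```python
EOT_ID = 49407  # <|endoftext|> in the CLIP BPE vocabulary
BOS_ID = 49406  # <|startoftext|>

def eot_positions(token_rows):
    """Position of the first maximal id in each row, matching torch.argmax(dim=-1) tie-breaking."""
    return [row.index(max(row)) for row in token_rows]

# Hypothetical rows with a context length of 8; the word ids 1000-1003 are made up.
# Padding reuses EOT_ID, so several positions tie for the maximum.
tokens = [
    [BOS_ID, 1000, 1001, EOT_ID, EOT_ID, EOT_ID, EOT_ID, EOT_ID],
    [BOS_ID, 1000, 1002, 1003, EOT_ID, EOT_ID, EOT_ID, EOT_ID],
]
print(eot_positions(tokens))
```

Because of the causal mask, the hidden state at that position has attended to the whole prompt, so it serves as the sentence-level feature.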


## Image Preprocessing

While input images can have any size and color mode, our CLIP model expects 224×224 RGB images. We therefore preprocess each image to match the model's expected input format.

### Preprocessing Pipeline

The preprocessing applies the following transformations in sequence:
1. **Resize**: Scale the image so that its shorter side is 224 pixels (preserving the aspect ratio), using bicubic interpolation
2. **Center Crop**: Crop the central 224×224 region
3. **RGB Conversion**: Convert to RGB (e.g., from RGBA or grayscale) if needed
4. **Tensor Conversion**: Convert to a `(3, 224, 224)` float tensor with values in [0, 1]
5. **Normalization**: Subtract the per-channel mean and divide by the per-channel standard deviation of CLIP's training data (these differ from the usual ImageNet statistics)

This preprocessing ensures a consistent input format regardless of the original image properties.
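The geometry and normalization steps reduce to simple arithmetic. The pure-Python sketch below computes the resized size and crop box for a hypothetical 640×480 input, then normalizes one white pixel. The rounding mirrors torchvision's `Resize`/`CenterCrop` behavior as we understand it, so treat it as an approximation that may differ by a pixel. The helper names are illustrative.

```python
def resize_shorter_side(width, height, target):
    """(new_width, new_height) after scaling the shorter side to `target`, keeping the aspect ratio."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target

def center_crop_box(width, height, size):
    """(left, top, right, bottom) of a centered size x size crop."""
    left = int(round((width - size) / 2.0))
    top = int(round((height - size) / 2.0))
    return left, top, left + size, top + size

CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

def normalize_pixel(rgb):
    """Scale 0-255 values to [0, 1] (as ToTensor does), then apply (x - mean) / std per channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, CLIP_MEAN, CLIP_STD))

w, h = resize_shorter_side(640, 480, 224)
print((w, h), center_crop_box(w, h, 224))
print([round(v, 3) for v in normalize_pixel((255, 255, 255))])
```

For a landscape image, the crop discards the left and right edges. A white pixel maps to values around 2 in every channel, while a mid-gray pixel lands near 0.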

In [None]:
def preprocess_image(image, model_resolution):
    def _convert_image_to_rgb(image):
        return image.convert("RGB")

    # Pre-process image on host with torch
    transform_fn = Compose(
        [
            Resize(model_resolution, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(model_resolution),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    return transform_fn(image)

## Image Download Utility

We use a utility function to download images from URLs for demonstration purposes. This function handles HTTP requests and converts the response into a PIL Image object that can be processed by our preprocessing pipeline.

In [None]:
def download_image(url):
    """
    Download an image from a URL and return it as a PIL Image object.
    
    Args:
        url (str): The URL of the image to download
        
    Returns:
        PIL.Image: The downloaded image
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()  # Raise an exception for bad status codes
        
        # Convert the response content to a PIL Image
        image = Image.open(BytesIO(response.content))
        return image
    except requests.RequestException as e:
        raise RuntimeError(f"Failed to download image from {url}: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Failed to process downloaded image: {e}") from e

## Running CLIP Inference

Having defined each component of our CLIP model, we can now perform inference on an input image and text prompts. This section demonstrates the complete inference pipeline from loading pre-trained weights to computing similarity scores.

### Inference Pipeline

1. **Model Loading**: Download pre-trained CLIP weights using `CLIPModel.from_pretrained()`
2. **Weight Conversion**: Convert PyTorch weights to TT-NN
3. **Image Processing**: Download, preprocess, and convert image to TT-NN tensor
4. **Text Processing**: Tokenize text prompts and convert to TT-NN tensors
5. **Forward Pass**: Compute image and text embeddings, then calculate similarity scores
6. **Results**: Apply softmax to get probability distribution over text prompts

### Text Tokenization

Since TT-NN does not handle tokenization natively, we use the `CLIPTokenizer` from the `transformers` library. The tokenizer converts text strings into token IDs from the vocabulary used during CLIP training and pads each prompt to the model's context length (77 tokens for this checkpoint). We then convert the token tensors to TT-NN format for processing.

In [None]:
if __name__ == "__main__":
    # Initialize the TT-NN device for hardware acceleration
    open_ttnn()

    try:
        # Load the pre-trained CLIP model and convert its weights to TT-NN format
        print("Loading pre-trained CLIP model...")
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        state_dict = convert_model_to_ttnn(model.state_dict())

        # Initialize our TT-NN CLIP implementation
        clip = CLIP(state_dict)

        # Download and preprocess the test image
        print("Downloading and preprocessing image...")
        image_url = "https://media.githubusercontent.com/media/tenstorrent/tutorial-assets/nmaurice/clip-tutorial/media/clip_tutorial/CLIP.png"
        image = download_image(image_url)

        # Preprocess to the model's input format (224x224, normalized) and add a batch dimension
        image = preprocess_image(image, 224).unsqueeze(0)

        # Convert the image to a TT-NN tensor with bfloat16 precision
        preferred_dtype = ttnn.bfloat16
        tt_image = to_ttnn(image, preferred_dtype)

        # Define text prompts for zero-shot classification
        prompts = ["a diagram", "a dog", "a cat"]

        # Tokenize the prompts with CLIP's tokenizer, padding each to the model's context length
        print("Tokenizing text prompts...")
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        tokenized_inputs = tokenizer(prompts, padding="max_length", max_length=clip.context_length, return_tensors="pt")
        tokens_pretrained_host = tokenized_inputs["input_ids"]
        tokens_pretrained = ttnn.from_torch(tokens_pretrained_host, device=get_device(), layout=ttnn.TILE_LAYOUT)

        # Run CLIP inference: compute the similarity between the image and each prompt.
        # The timing may include one-time costs (e.g., kernel compilation) on the first call,
        # so it is not a steady-state measurement.
        print("Running CLIP inference...")
        time_start = time.perf_counter()
        logits_per_image, logits_per_text = clip.forward(tt_image, tokens_pretrained)
        time_end = time.perf_counter()
        print(f"Time taken: {time_end - time_start:.3f} seconds")

        # Convert logits to probabilities using softmax
        probs = ttnn.softmax(logits_per_image, dim=-1)
        print("==== Classification probabilities:")

        # Display results
        probs_torch = ttnn.to_torch(probs)
        for i, prompt in enumerate(prompts):
            print(f"'{prompt}': {probs_torch[0][i].item():.4f}")
    finally:
        # Release the device even if inference fails
        close_ttnn()
