In [None]:
# !pip install fastapi nest-asyncio pyngrok uvicorn torch pillow pydantic transformers

In [None]:
import os
from contextlib import asynccontextmanager
from typing import Optional

import nest_asyncio
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
from pyngrok import ngrok
from torch import nn
from transformers import SiglipVisionModel
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import (
    SiglipEncoder,
    SiglipMultiheadAttentionPoolingHead,
    SiglipVisionModel,
)
from transformers.utils import auto_docstring, can_return_tuple


In [None]:
class SiglipVisionModelNoEmbeddings(SiglipVisionModel):
    """SigLIP vision model whose forward pass takes hidden states directly.

    The `vision_model` attribute is the `SiglipVisionTransformer` defined in
    this file (not the one shipped with transformers). `forward` hands its
    `hidden_states` argument straight to that module instead of accepting
    pixel values.
    """

    config_class = SiglipVisionConfig
    # NOTE(review): forward() below takes `hidden_states`, not `pixel_values`;
    # confirm nothing relies on this name matching the forward signature.
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        # Run the parent constructor first; whatever `vision_model` it may have
        # set up is replaced by the assignment below.
        # NOTE(review): presumably the parent builds a full vision tower that is
        # then discarded — confirm if start-up time or memory matters.
        super().__init__(config)

        # Looked up when an instance is created, so the class being defined
        # further down in the file is not a problem.
        self.vision_model = SiglipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    # No docstring is added to forward(): `auto_docstring` wraps it, so the
    # documentation is kept in comments.
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        hidden_states,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> BaseModelOutputWithPooling:
        # Pure delegation: every argument is forwarded by keyword, unchanged.
        # `interpolate_pos_encoding` is accepted by the SiglipVisionTransformer
        # in this file but not used there.
        # assumes hidden_states is (batch, sequence, hidden_size) — TODO confirm
        return self.vision_model(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )


class SiglipVisionTransformer(nn.Module):
    """Encoder, post-layernorm and optional pooling head, with no embedding layer.

    The caller supplies the encoder input directly; this module never sees
    pixel values.
    """

    def __init__(self, config: SiglipVisionConfig):
        """Build the sub-modules described by `config`."""
        super().__init__()
        self.config = config

        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

        # The pooling head is enabled unless the config carries an explicit
        # `vision_use_head` attribute saying otherwise.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(config)

    # forward() is wrapped by `auto_docstring`, so its documentation lives in
    # comments rather than in a docstring.
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        hidden_states,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> BaseModelOutputWithPooling:
        # Flags left as None fall back to the values stored on the config.
        # `interpolate_pos_encoding` is accepted but not used in this body.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # The given hidden states go to the encoder as `inputs_embeds`.
        encoded: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        normed_states = self.post_layernorm(encoded.last_hidden_state)

        # Without a head there is no pooled representation to return.
        pooled = self.head(normed_states) if self.use_head else None

        return BaseModelOutputWithPooling(
            last_hidden_state=normed_states,
            pooler_output=pooled,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )


def load_vit(path_vit_pt: str, device: str = "cuda"):
    """Build a `SiglipVisionModelNoEmbeddings` and load its weights from a checkpoint.

    Args:
        path_vit_pt: Path to a file readable by `torch.load` that holds a
            mapping with the keys "encoder", "head" and "post_layernorm".
            Each value is loaded with `load_state_dict` into the sub-module of
            the same name on `model.vision_model`.
        device: Device the model is moved to once the weights are loaded.
            Defaults to "cuda", which is what this function always used before
            the parameter existed.

    Returns:
        The model with loaded weights, moved to `device`.

    Raises:
        FileNotFoundError: If `path_vit_pt` is empty (raised here) or does not
            exist (raised by `torch.load`).
        KeyError: If the checkpoint lacks one of the three expected sections.
    """
    sections = ("encoder", "head", "post_layernorm")

    # Fail before building the model when the path placeholder was never filled in.
    if not path_vit_pt:
        raise FileNotFoundError(
            "No checkpoint path given: pass the path to the SigLIP weights file."
        )

    # `name_or_path` is the only value passed, so every architecture setting
    # comes from SiglipVisionConfig's defaults.
    # NOTE(review): presumably those defaults match the named checkpoint — confirm.
    vision_config = SiglipVisionConfig(name_or_path="google/siglip-base-patch16-224")
    model = SiglipVisionModelNoEmbeddings(vision_config)

    # map_location="cpu": deserialize on the CPU regardless of the device the
    # checkpoint was saved from; the model is moved to `device` afterwards.
    # weights_only=True: torch.load is pickle-based, so restrict it to tensors
    # and plain containers instead of executing arbitrary pickled objects.
    vit_weights = torch.load(path_vit_pt, map_location="cpu", weights_only=True)

    # Check every section up front so a wrong checkpoint is reported before
    # any weights have been partially loaded.
    missing = [name for name in sections if name not in vit_weights]
    if missing:
        raise KeyError(
            f"Checkpoint {path_vit_pt!r} is missing section(s): {', '.join(missing)}"
        )

    for name in sections:
        getattr(model.vision_model, name).load_state_dict(vit_weights[name])

    return model.to(device)

In [None]:
# Module-level state shared by the server cells.

# Annotation only — no value is bound here; `lifespan` assigns the model when
# the app starts.
siglip_vision_model: SiglipVisionModel
device: str = "cuda"
# `Image` is the PIL.Image module, so the element type is spelled Image.Image.
images: list[Image.Image] = []
# NOTE(review): both are annotated as tensors but initialised as empty lists —
# confirm which type is intended.
image_embeddings: torch.Tensor = []
text_embeddings: torch.Tensor = []

In [None]:
class Embeddings(BaseModel):
    # Request body of the /generate-embeddings/ route.
    # (Comments rather than a docstring, so the published schema is unchanged.)

    # One flat list of floats per item; `generate_embeddings` reshapes each
    # item into a 2-D array before batching.
    # NOTE(review): despite the name, these are passed to the model as hidden
    # states, not as raw pixels — confirm with the client code.
    pixel_values: list[list[float]]

In [None]:
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SigLIP vision model before the application starts serving.

    Everything before `yield` runs at start-up; there is no teardown code
    after it.
    """
    global siglip_vision_model, device

    # Path to your SigLIP model weights
    weights_path = r""

    siglip_vision_model = load_vit(weights_path)
    siglip_vision_model.eval()

    print("Ready for incoming...")
    yield

app = FastAPI(title="SigLIP Image Search API", lifespan=lifespan)

# Enable CORS for React frontend.
# Credentials stay disabled: wildcard origins must not be combined with
# credentialed requests (FastAPI documents that combination as invalid, and it
# would let any website call this public endpoint with the visitor's cookies).
# The routes in this file read no cookies or auth headers. If the frontend ever
# needs credentials, list its exact origin(s) in `allow_origins` first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    """Get the service status"""
    # The docstring is left as-is because FastAPI publishes it as the route
    # description. The payload reports a fixed "OK" plus whether torch can see
    # a CUDA device right now.
    cuda_ready = torch.cuda.is_available()
    return dict(status="OK", cuda=cuda_ready)


@app.post("/generate-embeddings/")
async def generate_embeddings(embeddings_request: Embeddings):
    """Generate embeddings for the uploaded images"""
    # The docstring is kept short because FastAPI publishes it as the route
    # description; the contract is documented here instead.
    #
    # Input:   `pixel_values` — one flat list of floats per item, each of
    #          length num_patches * hidden_size (196 * 768 with the default
    #          SiglipVisionConfig).
    # Output:  {"image_embeddings": one L2-normalised vector per item,
    #           "batch_size": number of items}.
    # Errors:  HTTP 422 when the batch is empty or an item has the wrong
    #          length (previously an unhandled ValueError, i.e. HTTP 500).
    #
    # Results are kept in local variables rather than written to the
    # module-level `image_embeddings`, so request data does not linger in
    # global state between calls.

    # Derive the expected layout from the loaded model instead of hard-coding it.
    model_config = siglip_vision_model.config
    hidden_size = model_config.hidden_size
    num_patches = (model_config.image_size // model_config.patch_size) ** 2
    expected_len = num_patches * hidden_size

    batch = embeddings_request.pixel_values
    if not batch:
        raise HTTPException(
            status_code=422,
            detail="pixel_values must contain at least one item.",
        )
    for index, item in enumerate(batch):
        if len(item) != expected_len:
            raise HTTPException(
                status_code=422,
                detail=(
                    f"pixel_values[{index}] has {len(item)} values; expected "
                    f"{expected_len} ({num_patches} x {hidden_size})."
                ),
            )

    # All items have equal length at this point, so one conversion plus one
    # reshape yields the (batch, num_patches, hidden_size) array.
    patch_array = np.asarray(batch, dtype=np.float32).reshape(
        len(batch), num_patches, hidden_size
    )

    # Follow the model to whatever device it was loaded on.
    model_device = next(siglip_vision_model.parameters()).device
    patch_tensor = torch.from_numpy(patch_array).to(model_device)

    print(f"Generating image embeddings from received embeddings of shape {patch_tensor.shape}...")
    with torch.no_grad():
        outputs = siglip_vision_model(patch_tensor)
        pooled = torch.nn.functional.normalize(outputs["pooler_output"], p=2, dim=-1)
    print(f"Embeddings generated: {pooled.shape}")
    print()

    return {
        "image_embeddings": pooled.cpu().tolist(),
        "batch_size": pooled.shape[0],
    }

if __name__ == "__main__":
    # Read the ngrok token from the environment so the secret never has to be
    # saved inside the notebook (where it would leak on share or commit).
    auth_token = os.environ.get("NGROK_AUTHTOKEN", "")
    if not auth_token:
        raise ValueError(
            "Please set your ngrok auth token "
            "(NGROK_AUTHTOKEN environment variable)."
        )

    # Single source of truth for the port shared by the tunnel and the server.
    port = 8000

    ngrok.set_auth_token(auth_token)

    # Re-running this cell reuses the first open tunnel instead of opening
    # another one.
    # NOTE(review): that tunnel is assumed to point at `port` — confirm if
    # other tunnels may be open in the same ngrok session.
    tunnels = ngrok.get_tunnels()
    if tunnels:
        public_url = tunnels[0].public_url
    else:
        public_url = ngrok.connect(port).public_url

    print('Public URL:', public_url)

    # Applied before uvicorn.run, presumably so the server can start inside
    # the notebook's already-running event loop.
    nest_asyncio.apply()

    uvicorn.run(app, port=port)
