<a href="https://colab.research.google.com/github/rahiakela/genai-research-and-practice/blob/main/vector-databases/02_zero_shot_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [1]:
from sentence_transformers import SentenceTransformer, util
import torch

## Zero-shot classification

In [2]:
class ZeroShotClassifier:
    """Zero-shot text classifier based on sentence-embedding similarity.

    Every candidate label is turned into a short natural-language prompt.
    The input text and the prompts are embedded with the same
    SentenceTransformer model, and the cosine similarity between the text
    and each prompt is reported as that label's score.
    """

    def __init__(self, model_name='all-mpnet-base-v2'):
        """Create the classifier.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model_name)

    def classify(self, text, candidate_labels, hypothesis_template="This text is about {}"):
        """Score ``text`` against each candidate label.

        Args:
            text: The text to classify.
            candidate_labels: The labels to score. Either a single string
                (treated as one label) or any iterable of labels, including
                one-shot iterables such as generators.
            hypothesis_template: ``str.format`` template used to turn a label
                into a prompt; the label is substituted for ``{}``. The
                default reproduces the prompt "This text is about <label>".

        Returns:
            A dict mapping each label to its similarity score as a Python
            float, in the order the labels were supplied. The scores are raw
            similarities; they are not normalized across labels. If the same
            label appears more than once it occupies a single key. An empty
            dict is returned, without calling the model, when no labels are
            given.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(candidate_labels, str):
            labels = [candidate_labels]
        else:
            # Materialize once: the labels are needed both to build the
            # prompts and to key the result, so a generator must not be
            # consumed by the first pass.
            labels = list(candidate_labels)

        # Nothing to compare against; skip the encoder entirely.
        if not labels:
            return {}

        # Encode the input text
        text_embedding = self.model.encode(text, convert_to_tensor=True)

        # Prepare label prompts
        label_prompts = [hypothesis_template.format(label) for label in labels]
        label_embeddings = self.model.encode(label_prompts, convert_to_tensor=True)

        # Calculate similarities; [0] selects the row belonging to the text
        similarities = util.pytorch_cos_sim(text_embedding, label_embeddings)[0]

        # Pair every label with its score as a plain float
        return {label: float(score) for label, score in zip(labels, similarities)}

In [None]:
# Example usage: build the classifier once, with the constructor's default
# model_name, so the cell below can reuse the same instance.
classifier = ZeroShotClassifier()

In [6]:
# Sample sentence and the topics it is scored against.
text = (
    "The new quantum computer can perform calculations in seconds "
    "that would take classical computers thousands of years."
)
labels = ["technology", "sports", "cooking", "politics"]

# Score the sentence against every label, then list the labels best match first.
results = classifier.classify(text, labels)
print("Zero-shot classification results:")
for label in sorted(results, key=results.get, reverse=True):
    score = results[label]
    print(f"{label}: {score:.3f}")

Zero-shot classification results:
technology: 0.236
cooking: 0.024
politics: 0.023
sports: 0.018
