In [1]:
import warnings
import logging
import time
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

try:
    from ollama import chat
    OLLAMA_AVAILABLE = True
except ImportError:
    OLLAMA_AVAILABLE = False
    print("Ollama SDK not installed. Ollama selection will be skipped.")

# Settings: silence warnings and transformers log chatter to keep the output readable.
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Pick the accelerator: CUDA first, then Apple MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Model identifiers (passed to `from_pretrained`) whose encoding latency is compared.
EMBEDDING_MODELS = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/all-mpnet-base-v2",
    "sentence-transformers/paraphrase-MiniLM-L3-v2",
]

# Load dataset
from datasets import load_dataset

# Number of IMDB test-split reviews used as the benchmark corpus.
NUM_DOCUMENTS = 100
dataset = load_dataset("imdb", split="test").select(range(NUM_DOCUMENTS))
# Materialise as a plain list of strings: newer `datasets` releases may return a
# lazy column object for dataset["text"] rather than a list, and later code
# slices and iterates `documents`.
documents = list(dataset["text"])

# Encode function
def encode_texts(model_name, texts, *, tokenizer=None, model=None, device=None):
    """Embed each text with *model_name* using attention-masked mean pooling.

    Texts are tokenized and run through the model one at a time, truncated to
    the tokenizer's maximum length.

    Args:
        model_name: Identifier passed to ``from_pretrained`` when no preloaded
            tokenizer/model pair is supplied.
        texts: Iterable of strings to embed.
        tokenizer, model: Optional preloaded pair (pass both or neither) so a
            caller can reuse one loaded model across calls instead of
            reloading it every time.
        device: Device the tokenized inputs are moved to; defaults to
            ``DEVICE``. A preloaded *model* must already live on this device.

    Returns:
        ``np.ndarray`` with one pooled row per text, i.e. shape
        ``(n_texts, hidden_size)``; an empty ``(0, hidden_size)`` array when
        *texts* is empty.

    Raises:
        ValueError: if only one of *tokenizer* / *model* is given.
    """
    device = DEVICE if device is None else device
    if (tokenizer is None) != (model is None):
        raise ValueError("pass both tokenizer and model, or neither")
    if model is None:
        tokenizer, model = _load_model(model_name, device)

    embeddings = []
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
            outputs = model(**inputs)
            # Mean-pool token vectors, weighting by the attention mask so padded
            # positions contribute nothing.
            mask = inputs["attention_mask"].unsqueeze(-1)
            pooled = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
            # `.cpu()` copies the result back to host memory, so the work for this
            # text has finished before the loop continues.
            embeddings.append(pooled.cpu().numpy())
    if not embeddings:
        # np.vstack([]) raises; return an empty matrix with the right width instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.vstack(embeddings)


def _load_model(model_name, device=None):
    """Load tokenizer and model for *model_name*, move the model to *device* (``DEVICE`` if None) and set eval mode."""
    device = DEVICE if device is None else device
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()
    return tokenizer, model


# Evaluate model latency
def evaluate_model_latency(model_name, docs, queries, *, warmup=True):
    """Return the seconds *model_name* needs to embed *docs* and then *queries*.

    The model is loaded once, before the timer starts, so the figure measures
    encoding only. It excludes model download or disk loading, which would
    otherwise swamp the comparison between models. When *warmup* is true, a
    one-text pass runs first so first-call overhead is also kept out of the
    measurement.

    Args:
        model_name: Identifier passed to ``from_pretrained``.
        docs: Iterable of document strings to embed.
        queries: Iterable of query strings to embed.
        warmup: Run an untimed single-text pass before measuring.

    Returns:
        Elapsed time in seconds, measured with the monotonic
        ``time.perf_counter`` clock.
    """
    tokenizer, model = _load_model(model_name)
    if warmup:
        encode_texts(model_name, ["warm-up"], tokenizer=tokenizer, model=model)
    start = time.perf_counter()
    encode_texts(model_name, docs, tokenizer=tokenizer, model=model)
    encode_texts(model_name, queries, tokenizer=tokenizer, model=model)
    return time.perf_counter() - start

# Ollama selection
def ollama_select_model(results):
    if not OLLAMA_AVAILABLE or not results:
        return None

    prompt = "You are a model selection assistant.\n"
    prompt += "Here are model latencies:\n"
    for r in results:
        prompt += f"{r['model']}: latency={r['latency']:.2f}s\n"
    prompt += "\nSelect the fastest model. Reply ONLY with the model name."

    try:
        response = chat(model="qwen2.5:7b", messages=[{"role": "user", "content": prompt}])
        return response.message.content.strip()
    except Exception as e:
        print("Ollama error:", e)
        return None

# Main
if __name__ == "__main__":
    # The first 10 documents double as the query set.
    queries = documents[:10]

    results = []
    for model_name in EMBEDDING_MODELS:
        latency = evaluate_model_latency(model_name, documents, queries)
        results.append({"model": model_name, "latency": latency})

    # The LLM only reads the numbers back; compute the true answer locally so
    # a wrong or hallucinated pick is visible.
    fastest = min(results, key=lambda r: r["latency"])["model"] if results else None

    # Ollama selection
    selected_model = ollama_select_model(results)
    print("\nOllama Selected Model (lowest latency):", selected_model)
    if selected_model is not None and selected_model != fastest:
        print("Warning: Ollama's pick differs from the measured fastest model:", fastest)


  from .autonotebook import tqdm as notebook_tqdm


Using device: mps

Ollama Selected Model (lowest latency): sentence-transformers/paraphrase-MiniLM-L3-v2
