In [4]:
import warnings
import logging
import time
import json
import torch
import pandas as pd
import numpy as np
import faiss

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer

# The Ollama SDK is optional: when it cannot be imported, OLLAMA_AVAILABLE is
# False and ollama_select_model() returns None instead of calling the LLM.
# Only ImportError is caught so unrelated failures (and Ctrl-C) are not
# misreported as a missing package.
try:
    from ollama import chat
except ImportError:
    OLLAMA_AVAILABLE = False
    print("Ollama SDK missing: pip install ollama")
else:
    OLLAMA_AVAILABLE = True

# Keep the cell output readable: only errors from the "transformers" logger
# are shown, and all Python warnings are suppressed process-wide.
logging.getLogger(name="transformers").setLevel(level=logging.ERROR)
warnings.filterwarnings(action="ignore")

# Candidate seq2seq checkpoints, benchmarked in this order by
# run_batch_generation(). Each name is passed to `from_pretrained`.
GENERATION_MODELS = [
    "t5-small",
    "google/mt5-small",
    "facebook/bart-base",
]


# Pick the compute device once at import time.
# Preference order: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")

# Evaluation sample: the first 25 rows of the CNN/DailyMail 3.0.0 train split.
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:25]")

# "article" is the text fed to the models; "highlights" is the reference
# text the generated output is scored against.
inputs, references = dataset["article"], dataset["highlights"]

# Text generation
def run_model_generate(model_name, inputs, max_new_tokens=128):
    """Generate one output string per input text with a seq2seq checkpoint.

    The tokenizer and model are loaded via ``from_pretrained`` on every call
    and the model is moved to the module-level ``DEVICE``. Texts are processed
    one at a time, with truncation enabled and sampling disabled.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        inputs: Iterable of source texts.
        max_new_tokens: Cap on the number of newly generated tokens per text.

    Returns:
        A list with the decoded output (special tokens removed) for each
        input, in input order.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    net = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(DEVICE)
    net.eval()

    def _generate_one(text):
        batch = tok(text, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            sequences = net.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False)
        # A single text was encoded, so only the first sequence is decoded.
        return tok.decode(sequences[0], skip_special_tokens=True)

    return [_generate_one(text) for text in inputs]

# Evaluation (BLEU + ROUGE)
def evaluate_generation(preds, refs):
    """Score predictions against references with sentence BLEU and ROUGE-L.

    Predictions and references are paired positionally with ``zip``, so any
    surplus items in the longer sequence are ignored. BLEU is computed on
    whitespace-split tokens; ROUGE-L uses the F-measure with stemming enabled.

    Args:
        preds: Sequence of generated strings.
        refs: Sequence of reference strings.

    Returns:
        ``{"bleu": float, "rougeL": float}`` — the mean of each metric over
        all pairs, multiplied by 100 and rounded to two decimals. When there
        is nothing to score (either sequence is empty) both values are 0.0
        instead of raising ``ZeroDivisionError``.
    """
    pairs = list(zip(preds, refs))
    if not pairs:
        # No pairs means no average to compute; report zero scores.
        return {"bleu": 0.0, "rougeL": 0.0}

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    bleu_scores, rouge_scores = [], []

    for pred, ref in pairs:
        bleu_scores.append(sentence_bleu([ref.split()], pred.split()))
        rouge_scores.append(scorer.score(ref, pred)["rougeL"].fmeasure)

    count = len(pairs)
    return {
        "bleu": round(sum(bleu_scores) / count * 100, 2),
        "rougeL": round(sum(rouge_scores) / count * 100, 2),
    }

# Run all models
def run_batch_generation(inputs, references, query_type="quality", target_bleu=0):
    """Benchmark every model in ``GENERATION_MODELS`` and collect the keepers.

    For each model the texts are generated, timed and scored. A model that
    raises is reported on stdout and skipped, so one broken checkpoint does
    not abort the whole run.

    Filtering depends on ``query_type``:

    * ``"quality"``: a model is kept only if its BLEU is at least the best
      BLEU seen so far (ties must not be slower than the current best).
      Entries recorded earlier stay in the list even when a later model
      beats them.
    * anything else (latency mode): a model is kept when its BLEU is at
      least ``target_bleu``.

    Args:
        inputs: Source texts passed to ``run_model_generate``.
        references: Reference texts passed to ``evaluate_generation``.
        query_type: ``"quality"`` or any other value for latency mode.
        target_bleu: Minimum BLEU (0-100 scale) required in latency mode.

    Returns:
        A list of dicts with keys ``"model"``, ``"bleu"``, ``"rougeL"`` and
        ``"latency"`` (seconds, rounded to two decimals), in benchmark order.
    """
    results = []
    best_bleu = 0
    best_latency = float("inf")

    for model_name in GENERATION_MODELS:
        try:
            # perf_counter is monotonic, so the interval cannot be skewed by
            # system clock adjustments the way time.time() can.
            start = time.perf_counter()
            preds = run_model_generate(model_name, inputs)
            # Stop the clock before scoring: metric computation is not part
            # of the model's latency. Model loading is still included because
            # run_model_generate loads the checkpoint itself.
            latency = round(time.perf_counter() - start, 2)
            metrics = evaluate_generation(preds, references)

            bleu = metrics["bleu"]
            rouge = metrics["rougeL"]

            if query_type == "quality":
                skip = bleu < best_bleu or (bleu == best_bleu and latency > best_latency)
            else:  # latency mode
                skip = bleu < target_bleu

            if skip:
                continue

            results.append({
                "model": model_name,
                "bleu": bleu,
                "rougeL": rouge,
                "latency": latency,
            })

            # Track the running best (highest BLEU, then lowest latency).
            if bleu > best_bleu or (bleu == best_bleu and latency < best_latency):
                best_bleu = bleu
                best_latency = latency

        except Exception as e:
            # Deliberate best-effort: report and move on to the next model.
            print(f"Model failed: {model_name} | Error: {e}")

    return results

# Build a FAISS index over the benchmark CSV
def build_vector_db(csv_path="benchmarks_s.csv"):
    """Embed each benchmark row as a sentence and index it with FAISS.

    The CSV must provide the columns ``model``, ``dataset``, ``bleu``,
    ``rouge`` and ``latency``; every row is rendered into one descriptive
    string, embedded with the ``all-MiniLM-L6-v2`` sentence-transformer and
    added to an exact L2 index.

    Args:
        csv_path: Path of the benchmark CSV to read.

    Returns:
        A tuple ``(df, index, embedder, docs)``: the loaded DataFrame, the
        populated FAISS index, the embedding model, and the list of row
        descriptions in the same order as the index entries.
    """
    df = pd.read_csv(csv_path)
    embedder = SentenceTransformer("all-MiniLM-L6-v2")

    def _describe(row):
        return (
            f"Model: {row['model']}. Dataset: {row['dataset']}. BLEU: {row['bleu']}. "
            f"ROUGE: {row['rouge']}. Latency: {row['latency']}."
        )

    docs = [_describe(row) for _, row in df.iterrows()]

    embeddings = embedder.encode(docs, show_progress_bar=False)
    # The index dimensionality is taken from the embedding width.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    vectors = np.array(embeddings, dtype="float32")
    index.add(vectors)

    return df, index, embedder, docs


# RAG retrieval
def query_benchmark_rag(task_query, df, index, embedder, docs, k=5):
    """Return the benchmark descriptions closest to ``task_query``.

    Args:
        task_query: Natural-language query to embed.
        df: Benchmark DataFrame; accepted for interface symmetry with
            ``build_vector_db`` but not used here.
        index: Search index exposing ``search(vectors, k)``.
        embedder: Model exposing ``encode(list_of_texts)``.
        docs: Descriptions aligned with the index entries.
        k: Maximum number of neighbours to retrieve.

    Returns:
        The retrieved descriptions joined by newlines, nearest first. Fewer
        than ``k`` lines are returned when the index holds fewer entries.
    """
    q_emb = embedder.encode([task_query]).astype("float32")
    _, neighbours = index.search(q_emb, k)
    # FAISS pads the label row with -1 when fewer than k vectors exist;
    # without this filter docs[-1] would repeat the last document.
    return "\n".join(docs[i] for i in neighbours[0] if i >= 0)

# Ollama LLM model selection using RAG
def ollama_select_model(results, rag_context, query_type="quality", target_bleu=0):
    """Ask a local Ollama model to pick a winner from the benchmark results.

    The prompt combines the retrieved benchmark context, the JSON-serialised
    evaluation results and the selection rules, and is sent as a single user
    message to ``qwen2.5:7b``.

    Args:
        results: JSON-serialisable evaluation results.
        rag_context: Retrieved benchmark evidence, inserted verbatim.
        query_type: Selection goal reported to the LLM.
        target_bleu: BLEU threshold reported to the LLM.

    Returns:
        The stripped text of the LLM reply, or ``None`` when the Ollama SDK
        is unavailable or the call fails (the error is printed). The reply is
        not validated against the models listed in ``results``.
    """
    if not OLLAMA_AVAILABLE:
        return None

    results_json = json.dumps(results, indent=2)
    prompt = f"""
        You are a strict model selector.

        Retrieved benchmark evidence:
        {rag_context}

        Evaluation results:
        {results_json}

        Rules:
        - If goal = quality → pick highest BLEU (tie → lowest latency)
        - If goal = latency → pick lowest latency among models with BLEU ≥ target
        - Output MUST be only the model name.

        Goal: {query_type}
        Target BLEU: {target_bleu}
    """
    user_message = {"role": "user", "content": prompt}

    try:
        reply = chat(model="qwen2.5:7b", messages=[user_message])
        return reply.message.content.strip()
    except Exception as e:
        # Best-effort: any failure talking to Ollama yields "no selection".
        print("Ollama error:", e)
        return None


# Main
# Interactive entry point: ask for the selection goal, benchmark the models,
# retrieve benchmark context, and let the LLM name a winner. The body stays
# at module level so the intermediate values remain available afterwards.
if __name__ == "__main__":
    for menu_line in (
        "\nQuery options:",
        "1 - Highest quality model (BLEU)",
        "2 - Fastest model with minimum BLEU",
    ):
        print(menu_line)

    choice = input("Enter (1 or 2): ").strip()

    # Any answer other than "1" selects latency mode.
    if choice != "1":
        query_type = "latency"
        target_bleu = float(input("Enter minimum BLEU target: "))
    else:
        query_type = "quality"
        target_bleu = 0

    # Run evaluations
    results = run_batch_generation(inputs, references, query_type, target_bleu)

    # Build RAG DB
    df, index, embedder, docs = build_vector_db()

    rag_query = "best summarization model for cnn dailymail"
    rag_context = query_benchmark_rag(rag_query, df, index, embedder, docs)

    # LLM selector
    selected_model = ollama_select_model(results, rag_context, query_type, target_bleu)

    print("\n=== OLLAMA SELECTED MODEL ===")
    print(selected_model)


Using device: mps

Query options:
1 - Highest quality model (BLEU)
2 - Fastest model with minimum BLEU

=== OLLAMA SELECTED MODEL ===
facebook/bart-base
