In [5]:
import warnings
import logging
import time
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer

# Ollama is an optional dependency: record whether the SDK can be imported so
# later code can skip model selection when it is missing.
try:
    from ollama import chat
except ImportError:
    # Catch only a failed import; a bare `except` would also swallow
    # KeyboardInterrupt/SystemExit and hide unrelated errors.
    OLLAMA_AVAILABLE = False
    print("Ollama SDK missing: pip install ollama")
else:
    OLLAMA_AVAILABLE = True

# Keep the cell output readable: ignore Python warnings globally and only let
# ERROR-level (or higher) records through the "transformers" logger.
warnings.filterwarnings("ignore")
_hf_logger = logging.getLogger("transformers")
_hf_logger.setLevel(logging.ERROR)

# ------------------------------------------------
# HuggingFace Text Generation Models
# ------------------------------------------------
# Checkpoint names handed to AutoTokenizer / AutoModelForSeq2SeqLM, evaluated
# in this order by run_batch_generation.
GENERATION_MODELS = [
    "t5-small",
    "google/mt5-small",
    "facebook/bart-base",
]

# ------------------------------------------------
# Detect device
# ------------------------------------------------
# Pick the compute device in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# ------------------------------------------------
# Load dataset (example: english summarization)
# ------------------------------------------------
# First 25 training examples of CNN/DailyMail (config "3.0.0").
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:25]")

# Articles are the model inputs; the "highlights" column is the reference text.
inputs, references = dataset["article"], dataset["highlights"]

# ------------------------------------------------
# Generation function
# ------------------------------------------------
def run_model_generate(model_name, inputs, max_new_tokens=128):
    """Generate one output string per input text with a seq2seq checkpoint.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        inputs: Iterable of source texts; each one is processed on its own.
        max_new_tokens: Upper bound on generated tokens per text.

    Returns:
        A list of decoded generations (special tokens stripped), in the same
        order as ``inputs``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(DEVICE)
    seq2seq.eval()

    def _generate_one(text):
        # Tokenize a single text (truncation enabled) and move it to DEVICE.
        batch = tok(
            text, return_tensors="pt", truncation=True, padding=True
        ).to(DEVICE)

        # Inference only: no gradients, sampling disabled.
        with torch.no_grad():
            token_ids = seq2seq.generate(
                **batch,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        return tok.decode(token_ids[0], skip_special_tokens=True)

    return [_generate_one(text) for text in inputs]


# ------------------------------------------------
# Evaluation: BLEU + ROUGE
# ------------------------------------------------
def evaluate_generation(preds, refs):
    """Score predictions against references with BLEU and ROUGE-L.

    Predictions and references are paired positionally with ``zip``, so any
    surplus items in the longer sequence are ignored. BLEU is computed per
    pair on whitespace-split tokens against a single reference; ROUGE-L uses
    the stemmed F-measure.

    Args:
        preds: Sequence of generated strings.
        refs: Sequence of reference strings.

    Returns:
        A dict with keys ``"bleu"`` and ``"rougeL"``: the mean score over all
        pairs, scaled to 0-100 and rounded to two decimals. Both are ``0.0``
        when there is nothing to score.
    """
    pairs = list(zip(preds, refs))
    if not pairs:
        # Nothing to average: return zero scores instead of dividing by zero.
        return {"bleu": 0.0, "rougeL": 0.0}

    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    # NOTE(review): sentence_bleu is called without a smoothing function;
    # presumably short or low-overlap outputs score close to 0 - confirm
    # whether smoothing is wanted here.
    bleu_scores = [sentence_bleu([r.split()], p.split()) for p, r in pairs]
    rouge_scores = [scorer.score(r, p)["rougeL"].fmeasure for p, r in pairs]

    return {
        "bleu": round(sum(bleu_scores) / len(bleu_scores) * 100, 2),
        "rougeL": round(sum(rouge_scores) / len(rouge_scores) * 100, 2),
    }


# ------------------------------------------------
# Unified model runner
# ------------------------------------------------
def run_any_generation(model_name, inputs):
    """Run generation for ``model_name`` over ``inputs``.

    Single dispatch point for generation backends; it currently forwards
    straight to ``run_model_generate`` and returns its list of outputs.
    """
    generated = run_model_generate(model_name, inputs)
    return generated


# ------------------------------------------------
# Batch evaluation
# ------------------------------------------------
def run_batch_generation(inputs, references, query_type="quality", target_bleu=0):
    """Benchmark every model in ``GENERATION_MODELS`` and collect candidates.

    Each model generates outputs for ``inputs``, is timed, and is scored
    against ``references``. A model that raises is reported and skipped.

    Args:
        inputs: Source texts to generate from.
        references: Reference texts, aligned with ``inputs``.
        query_type: ``"quality"`` keeps a model only if it is at least as good
            as the best seen so far (higher BLEU, or equal BLEU without being
            slower). Any other value is treated as latency mode, which keeps
            every model whose BLEU is at least ``target_bleu``.
        target_bleu: Minimum BLEU (0-100) used in latency mode.

    Returns:
        A list of dicts with keys ``"model"``, ``"bleu"``, ``"rougeL"`` and
        ``"latency"`` (seconds, rounded to two decimals), in evaluation order.
        In quality mode, earlier entries that a later model beat stay in the
        list.
    """
    results = []
    best_bleu = 0
    best_latency = float("inf")

    for model_name in GENERATION_MODELS:
        try:
            # perf_counter is a monotonic, high-resolution clock meant for
            # measuring elapsed time.
            start = time.perf_counter()
            preds = run_any_generation(model_name, inputs)
            # Stop the clock before scoring so metric computation is not
            # charged to the model. The interval covers everything
            # run_any_generation does.
            latency = round(time.perf_counter() - start, 2)

            metrics = evaluate_generation(preds, references)
            bleu = metrics["bleu"]
            rouge = metrics["rougeL"]

            if query_type == "quality":
                # Drop models that are worse than the best one seen so far.
                if bleu < best_bleu or (bleu == best_bleu and latency > best_latency):
                    continue
            elif bleu < target_bleu:
                # Latency mode: drop models below the BLEU floor.
                continue

            results.append({
                "model": model_name,
                "bleu": bleu,
                "rougeL": rouge,
                "latency": latency,
            })

            if bleu > best_bleu or (bleu == best_bleu and latency < best_latency):
                best_bleu = bleu
                best_latency = latency

        except Exception as e:
            # Best effort: one failing model must not abort the whole batch.
            print(f"Model failed: {model_name} | Error: {e}")

    return results


# ------------------------------------------------
# Ollama: Select best model
# ------------------------------------------------
def ollama_select_model(results, query_type="quality", target_bleu=0):
    """Ask a local Ollama model to pick the best entry from ``results``.

    Args:
        results: List of dicts with keys ``"model"``, ``"bleu"``,
            ``"rougeL"`` and ``"latency"``.
        query_type: ``"quality"`` asks for the highest-BLEU model (ties broken
            by lowest latency); any other value asks for the fastest model
            with BLEU of at least ``target_bleu``.
        target_bleu: BLEU floor quoted in the latency-mode prompt.

    Returns:
        The stripped text of the model's reply, or ``None`` when there are no
        results, the Ollama SDK is unavailable, or the chat call fails. The
        reply is not validated against the names in ``results``.
    """
    if not results or not OLLAMA_AVAILABLE:
        return None

    header = "You are a model selection assistant.\nHere are text generation model stats:\n"
    stat_lines = [
        f"{r['model']}: BLEU={r['bleu']} ROUGE-L={r['rougeL']} "
        f"latency={r['latency']} sec\n"
        for r in results
    ]

    if query_type == "quality":
        instruction = "\nSelect the highest BLEU model (tie → lowest latency). Return ONLY the model name."
    else:
        instruction = f"\nSelect fastest model with BLEU >= {target_bleu}. Return ONLY the model name."

    prompt = "".join([header, *stat_lines, instruction])

    try:
        response = chat(model="qwen2.5:7b", messages=[{"role": "user", "content": prompt}])
        return response.message.content.strip()
    except Exception as e:
        # Still best effort (returns None), but report the cause instead of
        # hiding it: a silent failure cannot be told apart from "no results".
        print(f"Ollama selection failed: {e}")
        return None


# ------------------------------------------------
# Main entry
# ------------------------------------------------
if __name__ == "__main__":
    # Interactive entry point: ask which selection mode to use, benchmark the
    # models, then let Ollama choose among the surviving candidates.
    print("\nQuery options:")
    print("1 - Best quality model (BLEU)")
    print("2 - Fastest model with min BLEU target")

    choice = input("Enter (1 or 2): ").strip()

    if choice != "1":
        # Anything other than "1" is treated as latency mode.
        query_type = "latency"
        target_bleu = float(input("Enter target BLEU score (0–100): "))
    else:
        query_type, target_bleu = "quality", 0

    results = run_batch_generation(inputs, references, query_type, target_bleu)
    selected_model = ollama_select_model(results, query_type, target_bleu)

    print("\nOllama Selected Model:", selected_model)


Using device: mps

Query options:
1 - Best quality model (BLEU)
2 - Fastest model with min BLEU target

Ollama Selected Model: None
