In [16]:
import warnings
import logging
import time
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score

# Optional dependency: langid.py. The flag records whether the import
# succeeded so that run_langid() can fail with a clear message later.
try:
    import langid
except ImportError:
    LANGID_AVAILABLE = False
    print("langid.py not installed. Install using: pip install langid")
else:
    LANGID_AVAILABLE = True

# Optional dependency: Ollama SDK. Without it, model selection via Ollama
# is skipped (ollama_select_model returns None).
try:
    from ollama import chat
except ImportError:
    OLLAMA_AVAILABLE = False
    print("Ollama SDK not installed. Ollama recommendations will be skipped.")
else:
    OLLAMA_AVAILABLE = True

# Silence all Python warnings and anything below ERROR from transformers.
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# -----------------------------
# HuggingFace language models
# -----------------------------
# Identifiers of the detectors to benchmark, in evaluation order. Each entry
# is handed to run_any_model(): the literal "langid.py" is a sentinel that
# selects the langid package, every other entry is loaded through the
# transformers from_pretrained() calls in run_model_hf().
LANGUAGE_MODELS = [
    "papluca/xlm-roberta-base-language-detection",
    "Joshi-Aryan/Fine_Tuned_HF_Language_Identification_Model",
    "langid.py",  # sentinel, routed to run_langid() instead of transformers
]

# -----------------------------
# Detect device
# -----------------------------
# Pick the compute device in order of preference: CUDA, then MPS, then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# -----------------------------
# Load multi-language dataset
# -----------------------------
# Dataset configurations to sample; each code is also used as the
# ground-truth label for the rows loaded under it.
languages = ["en", "fr", "es", "de"]
texts = []  # input documents, appended language by language
true_labels = []  # language code for the document at the same index in `texts`

# NOTE(review): this loop runs at module load time, and load_dataset
# presumably needs network or a local cache — confirm before importing
# this file from other code.
for lang in languages:
    # Take the slice "train[:25]" of the wiki40b configuration for this language.
    dataset = load_dataset("wiki40b", lang, split="train[:25]")
    texts.extend(dataset['text'])
    # One label per loaded row keeps `texts` and `true_labels` aligned.
    true_labels.extend([lang] * len(dataset))

# -----------------------------
# HF model inference
# -----------------------------
def run_model_hf(model_name, texts):
    """Classify each text with a Hugging Face sequence-classification model.

    The tokenizer and model named by ``model_name`` are loaded on every call;
    the model is moved to ``DEVICE`` and switched to eval mode. Texts are
    scored one at a time, with tokenizer truncation enabled.

    Args:
        model_name: Identifier passed to the ``from_pretrained`` loaders.
        texts: Iterable of input texts.

    Returns:
        A list with one entry per text: the predicted label taken from
        ``model.config.id2label``, lower-cased and cut to its first two
        characters.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    classifier = classifier.to(DEVICE)
    classifier.eval()
    id2label = classifier.config.id2label

    def _predict(text):
        # Encode a single text and move the tensors next to the model.
        batch = tok(text, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
        with torch.no_grad():
            logits = classifier(**batch).logits
        label = id2label[logits.argmax(dim=1).item()]
        # Only the first two lower-cased characters of the label are kept.
        return label.lower()[:2]

    return [_predict(text) for text in texts]

# -----------------------------
# langid.py inference
# -----------------------------
def run_langid(texts):
    """Classify each text with langid and return the predicted labels.

    Args:
        texts: Iterable of input texts.

    Returns:
        A list holding, for each text, the first element of the result of
        ``langid.classify``.

    Raises:
        RuntimeError: If the langid package could not be imported.
    """
    if LANGID_AVAILABLE:
        codes = []
        for sample in texts:
            codes.append(langid.classify(sample)[0])
        return codes
    raise RuntimeError("langid.py not installed. Install using: pip install langid")

# -----------------------------
# Unified model runner
# -----------------------------
def run_any_model(model_name, texts):
    """Run ``texts`` through the backend selected by ``model_name``.

    The literal name ``"langid.py"`` selects ``run_langid``; any other name
    is forwarded to ``run_model_hf``.

    Args:
        model_name: ``"langid.py"`` or a name accepted by ``run_model_hf``.
        texts: Iterable of input texts.

    Returns:
        The list of predictions produced by the selected backend.
    """
    if model_name != "langid.py":
        return run_model_hf(model_name, texts)
    return run_langid(texts)

# -----------------------------
# Run all models with accuracy/latency
# -----------------------------
def run_batch_language_detection(texts, true_labels, query_type="accuracy", target_acc=0):
    """Benchmark every entry of ``LANGUAGE_MODELS`` and collect the results.

    Each model is run on ``texts`` through ``run_any_model``; its accuracy
    against ``true_labels`` and the elapsed time of that call are recorded.
    A model that raises is reported on stdout and skipped.

    Filtering depends on ``query_type``:

    * ``"accuracy"``: a model is dropped when it is less accurate than the
      best model seen so far, or equally accurate but slower. Models kept
      earlier are not removed when a better one shows up later, so the list
      may still contain weaker entries.
    * anything else (latency mode): a model is dropped when its accuracy is
      below ``target_acc``.

    Args:
        texts: Input texts to classify.
        true_labels: Expected label for each text, in the same order.
        query_type: ``"accuracy"`` or any other value for latency mode.
        target_acc: Minimum accuracy in percent, used in latency mode only.

    Returns:
        A list of dicts with keys ``"model"``, ``"accuracy"`` (percent,
        rounded to 2 decimals) and ``"latency"`` (seconds spent inside
        ``run_any_model``, rounded to 2 decimals).
    """
    results = []
    best_acc = 0
    best_latency = float('inf')

    for model_name in LANGUAGE_MODELS:
        try:
            # perf_counter is the clock intended for measuring intervals;
            # time.time() can jump if the system clock is adjusted.
            start = time.perf_counter()
            preds = run_any_model(model_name, texts)
            # Stop the clock before scoring so that metric computation is
            # not counted as model latency.
            latency = round(time.perf_counter() - start, 2)
            acc = round(accuracy_score(true_labels, preds) * 100, 2)

            skip = False
            if query_type == "accuracy":
                if acc < best_acc or (acc == best_acc and latency > best_latency):
                    skip = True
            else:  # latency mode
                if acc < target_acc:
                    skip = True

            if skip:
                continue

            results.append({
                "model": model_name,
                "accuracy": acc,
                "latency": latency
            })

            # Track the best (accuracy, then latency) pair seen so far.
            if acc > best_acc or (acc == best_acc and latency < best_latency):
                best_acc = acc
                best_latency = latency

        except Exception as e:
            # Best effort: one failing model must not abort the benchmark.
            print(f"Model failed: {model_name}, Error: {e}")

    return results

# -----------------------------
# Ollama selection
# -----------------------------
def ollama_select_model(results, query_type="accuracy", target_acc=0):
    """Ask an Ollama chat model to pick one model from benchmark results.

    Args:
        results: List of dicts with ``"model"``, ``"accuracy"`` and
            ``"latency"`` keys.
        query_type: ``"accuracy"`` asks for the most accurate model; any
            other value asks for the fastest model meeting ``target_acc``.
        target_acc: Accuracy threshold quoted in the latency-mode prompt.

    Returns:
        The stripped text of the chat reply, or ``None`` when the Ollama SDK
        is unavailable, ``results`` is empty, or the chat call fails.
    """
    if not (OLLAMA_AVAILABLE and results):
        return None

    # One line per benchmarked model.
    summary = "".join(
        f"{entry['model']}: accuracy={entry['accuracy']} latency={entry['latency']}\n"
        for entry in results
    )

    if query_type == "accuracy":
        instruction = "\nSelect the best model by accuracy (tie-break: latency). Return ONLY the model name."
    else:
        instruction = f"\nSelect fastest model with accuracy >= {target_acc}%. Return ONLY the model name."

    prompt = "You are a model selection assistant.\nHere are model results:\n" + summary + instruction

    try:
        reply = chat(model="qwen2.5:7b", messages=[{"role": "user", "content": prompt}])
        return reply.message.content.strip()
    except Exception as err:
        # Selection is optional: report the failure and carry on without it.
        print("Ollama error:", err)
        return None

# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    # Interactive entry point: ask which selection criterion to use, run the
    # benchmark, then let Ollama choose a model from the results.
    print("\nQuery options:")
    print("1 - Most accurate model")
    print("2 - Fastest model at target accuracy")

    # Re-prompt until a valid option is given, so that a typo does not
    # silently select latency mode.
    choice = input("Enter (1 or 2): ").strip()
    while choice not in ("1", "2"):
        print("Invalid choice. Please enter 1 or 2.")
        choice = input("Enter (1 or 2): ").strip()

    if choice == "1":
        query_type = "accuracy"
        target_acc = 0
    else:
        query_type = "latency"
        # Accuracies are computed as percentages, so the threshold must be a
        # number in the 0-100 range. Re-prompt instead of crashing on bad input.
        target_acc = None
        while target_acc is None:
            raw_target = input("Enter target accuracy: ").strip()
            try:
                candidate = float(raw_target)
            except ValueError:
                print("Invalid accuracy. Please enter a number between 0 and 100 (percent).")
                continue
            if 0 <= candidate <= 100:
                target_acc = candidate
            else:
                print("Invalid accuracy. Please enter a number between 0 and 100 (percent).")

    results = run_batch_language_detection(texts, true_labels, query_type, target_acc)
    selected_model = ollama_select_model(results, query_type, target_acc)

    print("Ollama Selected Model:", selected_model)


Using device: mps

Query options:
1 - Most accurate model
2 - Fastest model at target accuracy
Ollama Selected Model: langid.py
