# In [None]:
import torch
import os
import logging
import re
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from groq import Groq

# ─── CLEANUP ─────────────────────────────────────────────────────
# Drop model/client objects left over from a previous run of this cell so
# their memory can be reclaimed before everything is loaded again.
for var in ('base_model', 'peft_model', 'groq_client'):
    globals().pop(var, None)

# Collect garbage first: objects that are only kept alive by reference
# cycles still own their CUDA tensors until the collector runs, and
# empty_cache() can only release blocks that no tensor is using any more.
gc.collect()
torch.cuda.empty_cache()

# ─── CONFIG ──────────────────────────────────────────────────────
# Hub id of the base causal language model.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
# Checkpoint directory of the LoRA adapter attached for classification.
ADAPTER_PATH = "./mistral7b_reasoning_clf_optimized/checkpoint-200"
# Name under which that adapter is registered on the PEFT model.
CLASSIFIER_ADAPTER_NAME = "classifier_adapter"
# Directory handed to the model loader as its offload folder.
OFFLOAD_DIR = "./offload_cache"
# Groq API key; None when the environment variable is not set.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# ─── RUNTIME / LOGGING ───────────────────────────────────────────
# Allow TF32 kernels for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# The rest of the script targets cuda:0, so stop here when no GPU is visible.
if not torch.cuda.is_available():
    raise SystemError("CUDA device not available.")
device = torch.device("cuda:0")

# Use bfloat16 when the GPU supports it, float16 otherwise.
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
else:
    compute_dtype = torch.float16

# ─── TOKENIZER ───────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# The classifier prompt is tokenized with padding enabled, which requires a
# pad token; reuse EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ─── MODEL LOAD ──────────────────────────────────────────────────
# 4-bit NF4 weights, computing in the dtype chosen for this GPU.
# NOTE(review): if the automatic device map ever places modules on the CPU,
# bitsandbytes loading presumably also needs
# `llm_int8_enable_fp32_cpu_offload=True` — confirm before relying on the
# CPU budget / offload folder configured below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False
)
os.makedirs(OFFLOAD_DIR, exist_ok=True)

logging.info("⏳ Loading base model with 4-bit quantization...")
# `trust_remote_code` is intentionally not enabled: the Mistral architecture
# ships with transformers, so there is no reason to allow code hosted in the
# model repository to be executed on this machine.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map="auto",
    # Placement budget for the automatic device map: GPU 0 and CPU RAM.
    max_memory={0: "7GB", "cpu": "100GB"},
    offload_folder=OFFLOAD_DIR,
)

logging.info("⏳ Attaching LoRA adapter for classification...")
peft_model = PeftModel.from_pretrained(
    base_model,
    ADAPTER_PATH,
    adapter_name=CLASSIFIER_ADAPTER_NAME
)

# Inference only: put both handles into evaluation mode.
base_model.eval()
peft_model.eval()
logging.info("✅ Models ready.")

# ─── GROQ INIT ───────────────────────────────────────────────────
# The client used for answer generation cannot be created without a key.
if GROQ_API_KEY:
    groq_client = Groq(api_key=GROQ_API_KEY)
else:
    raise SystemError("GROQ_API_KEY environment variable not set.")

# ─── CLEAN TEXT ──────────────────────────────────────────────────
# One or more consecutive copies of "ACHE", in any letter case.
_ACHE_RUNS = re.compile(r'(ACHE)+', re.IGNORECASE)


def clean_artifacts(text: str) -> str:
    """Remove every run of "ACHE" (case-insensitive) from *text*.

    The match is not limited to whole words. The remaining text is returned
    with leading and trailing whitespace stripped.
    """
    without_artifacts = _ACHE_RUNS.sub('', text)
    return without_artifacts.strip()

# ─── REASONING CONFIG ────────────────────────────────────────────
# Per-level settings for the answer request, keyed by the classifier label:
# the system instruction, the token budget and the sampling parameters.
REASONING_CONFIG = {
    "Simple": dict(
        instruction=(
            "Respond with a direct, concise answer. You may use a short list or a few lines "
            "if needed . Do not explain or justify the answer."
        ),
        max_new_tokens=128,
        temperature=0.0,
        top_p=None,
    ),
    "Shallow": dict(
        instruction=(
            "Respond with a brief explanation or summary. You may use up to 2–3 paragraphs, "
            "but avoid deep technical breakdowns or extended reasoning."
        ),
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
    ),
    "Deep": dict(
        instruction=(
            "Respond with a detailed and structured explanation. Walk through the reasoning steps, "
            "make assumptions explicit, and provide a thorough breakdown."
        ),
        max_new_tokens=1024,
        temperature=0.7,
        top_p=0.9,
    ),
}

# Chat messages accumulated across calls and sent with every answer request.
conversation_history = list()

# ─── MAIN FUNCTION ───────────────────────────────────────────────
def _parse_reasoning_level(raw_output):
    """Return the REASONING_CONFIG key named in *raw_output*, or None if there is none."""
    cleaned = clean_artifacts(raw_output)

    # Exact answer: a single label, optionally followed by periods.
    exact = cleaned.lower().rstrip(".").capitalize()
    if exact in REASONING_CONFIG:
        return exact

    # Otherwise accept the first label that appears as a whole word, so
    # replies such as "Deep\n" followed by extra tokens are still understood.
    labels = {name.lower(): name for name in REASONING_CONFIG}
    for word in re.findall(r"[a-z]+", cleaned.lower()):
        if word in labels:
            return labels[word]
    return None


def process_query_with_reasoning_and_answer(
    query: str,
    max_class_tokens: int = 10,
    answer_model: str = "meta-llama/llama-4-scout-17b-16e-instruct",
):
    """Classify *query* by reasoning depth, then answer it through Groq.

    The local PEFT model labels the query as Simple, Shallow or Deep. The
    label selects an entry of REASONING_CONFIG whose instruction, token
    budget and sampling parameters are used for the Groq chat completion.
    The new system/user/assistant messages are appended to the module-level
    ``conversation_history`` only after the request has succeeded, so a
    failed call leaves the history unchanged.

    Args:
        query: The user's question.
        max_class_tokens: Maximum number of tokens the classifier may generate.
        answer_model: Groq model id used for the answer.

    Returns:
        A ``(level, answer)`` tuple: the REASONING_CONFIG key that was used
        and the stripped answer text ("" if the response has no content).

    Raises:
        Whatever the local generation or the Groq client raises.
    """
    peft_model.set_adapter(CLASSIFIER_ADAPTER_NAME)

    # NOTE(review): the prompt starts with a literal "<s>" and the tokenizer
    # call does not disable special tokens, so the sequence may begin with two
    # BOS tokens — confirm this matches how the adapter was trained before
    # changing it.
    cls_prompt = (
        "<s>[INST] You are an expert that classifies query complexity.\n"
        "Respond with only one of: Simple, Shallow, or Deep. Do not explain and respond with EXACTLY ONE WORD.\n"
        f"Query: \"{query}\"\n"
        "Classification:[/INST]"
    )

    encoded = tokenizer(cls_prompt, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        # Greedy decoding: with do_sample=False a temperature is not used, so
        # none is passed (passing one only triggers a warning).
        generated = peft_model.generate(
            **encoded,
            max_new_tokens=max_class_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = encoded.input_ids.shape[1]
    raw_cls = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)

    level = _parse_reasoning_level(raw_cls)
    if level is None:
        logging.warning(
            "⚠ Unexpected classification: '%s', defaulting to 'Simple'",
            clean_artifacts(raw_cls),
        )
        level = "Simple"

    logging.info("Classified reasoning level: %s", level)
    config = REASONING_CONFIG[level]

    # Build this turn separately and commit it to the shared history only
    # once the request has succeeded.
    new_turns = [
        {"role": "system", "content": config["instruction"]},
        {"role": "user", "content": query},
    ]

    response = groq_client.chat.completions.create(
        model=answer_model,
        messages=conversation_history + new_turns,
        max_tokens=config["max_new_tokens"],
        temperature=config["temperature"],
        top_p=config.get("top_p")
    )

    # The message content can be None; treat that as an empty answer.
    answer = (response.choices[0].message.content or "").strip()

    conversation_history.extend(new_turns)
    conversation_history.append({"role": "assistant", "content": answer})
    return level, answer

# ─── CLI LOOP ────────────────────────────────────────────────────
# Interactive prompt: classify and answer queries until the user quits.
if __name__ == "__main__":
    while True:
        try:
            user_query = input("Enter your query (or 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session like 'exit'.
            print()
            break
        if user_query.lower() == "exit":
            break
        if not user_query:
            # Nothing to classify or answer; ask again.
            continue
        try:
            lvl, ans = process_query_with_reasoning_and_answer(user_query)
        except Exception:
            # Top-level boundary: report the failure with its traceback and
            # keep the session (and the loaded model) alive.
            logging.exception("Failed to process query")
            continue
        print(f"\nReasoning Level: {lvl}\nAnswer: {ans}\n")
