from litgpt import LLM
from litgpt.chat.base import generate as chat_generate
import lightning as L
import os
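# Sample a long-form response from an instruction-tuned Gemma or Llama checkpoint via
# litgpt's chat generation path, then save the prompt and the response to a text file.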

OUTPUT_DIRECTORY = './litgpt_model_responses'
SEED = 0
SELECTED_MODEL = 'llama'  # 'gemma' or 'llama'
SELECTED_PROMPT_INDEX = 0

# The same prompt set is used for both models.
PROMPTS = [
"""
Hail, Thor! Tell me a long story about your greatest quest. Make it a real saga with multiple adventures, battles, and trials -- at least 1000 words. Speak only as Thor.
""",
"""
Tell me a long story about a neanderthal boy. Make it a real saga with multiple adventures -- at least 1000 words.
""",
"""
Give me a detailed biography of Albert Einstein with special emphasis on his explanation of the photo-electric effect.
""",
]

# Keep the system instruction separate; the API will apply the model's chat template.
SYS_PROMPT = (
    "You are a helpful assistant. Follow instructions precisely. "
    "Stay in-character as Thor when asked."
)

if SELECTED_MODEL == 'gemma':
    # Gemma sampling configuration
    max_new_tokens = 1500
    temperature = 0.1
    top_k = 64
    top_p = 0.95
    model_path = "google/gemma-3-4b-it"

elif SELECTED_MODEL == 'llama':
    # Llama sampling configuration
    max_new_tokens = 1500
    temperature = 0.6
    top_k = 64
    top_p = 0.9
    model_path = "meta-llama/Llama-3.2-3B-Instruct"

else:
    raise ValueError(f"Unknown SELECTED_MODEL: {SELECTED_MODEL!r}")

prompt = PROMPTS[SELECTED_PROMPT_INDEX]

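# seed_everything seeds Python, NumPy, and PyTorch RNGs so sampling is reproducible.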
L.seed_everything(SEED)
llm = LLM.load(model_path)

# Use the chat generator with model-specific stop tokens to avoid long-form degeneration.
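# The prompt style defines the token-id sequences that mark the end of an assistant
# turn (for Llama 3 this includes <|eot_id|>); generation halts when one is produced.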
tokenizer = llm.preprocessor.tokenizer
prompt_style = llm.prompt_style
stop_tokens = prompt_style.stop_tokens(tokenizer)

# Prepare input ids with proper chat formatting via the prompt style.
input_ids = llm._text_to_token_ids(prompt, sys_prompt=SYS_PROMPT)
prompt_len = input_ids.size(0)
max_returned_tokens = prompt_len + max_new_tokens
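# chat_generate expects the *total* sequence length: prompt tokens plus the new-token budget.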

# Pre-size the KV cache for the full generation window to improve stability.
device = llm.fabric.device if llm.fabric is not None else llm.preprocessor.device
llm.model.max_seq_length = min(llm.model.config.block_size, max_returned_tokens)
llm.model.set_kv_cache(batch_size=1, device=device)
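# set_kv_cache preallocates key/value buffers for max_seq_length positions, so the
# cache never needs to grow mid-generation.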

# Generate using the same path as the CLI chat, with rich stop sequences.
stream = chat_generate(
    model=llm.model,
    prompt=input_ids,
    max_returned_tokens=max_returned_tokens,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    stop_tokens=stop_tokens,
)
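
# For reference, the high-level API offers a one-call alternative (whether `top_p` is
# accepted as a keyword depends on the litgpt version -- treat this as a sketch):
#   response = llm.generate(prompt, max_new_tokens=max_new_tokens,
#                           temperature=temperature, top_k=top_k, top_p=top_p)
# The chat-style path above is kept so the prompt-style stop tokens are applied
# explicitly during long-form generation.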

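# decode_stream detokenizes ids as they are generated; collect the text chunks and
# join them into the full response.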
parts = []
for s in tokenizer.decode_stream(stream, device=device):
    parts.append(s)
response = "".join(parts)

output_filename = f'{SELECTED_MODEL}_seed_{SEED}_promptindex_{SELECTED_PROMPT_INDEX}_maxtokens_{max_new_tokens}_temp_{temperature}_topk_{top_k}_topp_{top_p}.txt'

os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
with open(os.path.join(OUTPUT_DIRECTORY, output_filename), 'w', encoding='utf-8') as fp:
    fp.write(f'PROMPT:\n{prompt}\n')
    fp.write(f'MODEL RESPONSE:\n{response}')
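
# Free the preallocated KV cache and report where the transcript was written.
llm.model.clear_kv_cache()
print(f'Saved response to {os.path.join(OUTPUT_DIRECTORY, output_filename)}')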