"""Generate one LLM response with LitGPT and save it to a text file."""
import os

import lightning as L
from litgpt import LLM

OUTPUT_DIRECTORY = './litgpt_model_responses'
SEED = 0
SELECTED_MODEL = 'gemma'  # 'gemma' or 'llama'
SELECTED_PROMPT_INDEX = 0

# Both models share the same prompt set, so define it once.
prompts = [
"""
Hail, Thor! Tell me a long story about your greatest quest. Make it a real saga with multiple adventures, battles, and trials -- at least 1000 words. Speak only as Thor.
""",
"""
Tell me a long story about a neanderthal boy. Make it a real saga with multiple adventures -- at least 1000 words.
""",
"""
Give me a detailed biography of Albert Einstein with special emphasis on his explanation of the photo-electric effect.
""",
]

if SELECTED_MODEL == 'gemma':
    # Gemma sampling config
    max_new_tokens = 1500
    temperature = 1.0
    top_k = 64
    top_p = 0.95
    model_path = "google/gemma-3-4b-it"

elif SELECTED_MODEL == 'llama':
    # Llama sampling config
    max_new_tokens = 1500
    temperature = 0.6
    top_k = 64
    top_p = 0.9
    model_path = "meta-llama/Llama-3.2-3B-Instruct"

else:
    # Fail loudly on a typo instead of relying on assert, which is
    # stripped when Python runs with optimizations (-O).
    raise ValueError(f"Unknown SELECTED_MODEL: {SELECTED_MODEL!r}")

prompt = prompts[SELECTED_PROMPT_INDEX]

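# Seed all RNGs (Python, NumPy, torch) so sampling is reproducible.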
L.seed_everything(SEED)
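# Load the checkpoint (LitGPT downloads it automatically on first use).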
llm = LLM.load(model_path)

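# Sample one response using the model-specific decoding parameters.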
response = llm.generate(prompt,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p)

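# Encode every run parameter in the filename so each output file is self-describing.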
output_filename = f'{SELECTED_MODEL}_seed_{SEED}_promptindex_{SELECTED_PROMPT_INDEX}_maxtokens_{max_new_tokens}_temp_{temperature}_topk_{top_k}_topp_{top_p}.txt'

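# Write the prompt and response to the output directory.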
os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
with open(os.path.join(OUTPUT_DIRECTORY, output_filename), 'w', encoding='utf-8') as fp:
    fp.write(f'PROMPT:\n{prompt}\n')
    fp.write(f'MODEL RESPONSE:\n{response}\n')