In [1]:
try:
    import peft
    LOCAL = True
    MODEL_PATH = "deepseek-ai/deepseek-math-7b-rl"
    from functions import *
except:
    LOCAL = False
    MODEL_PATH = "/kaggle/input/deepseek-math"
    from functions_math import *
    

import torch
import gc
if not LOCAL:torch.backends.cuda.enable_mem_efficient_sdp(False)
import pandas as pd

In [2]:
if LOCAL:
    import json
    with open('../Data/AMC/aime_normal.json', 'r') as file:
        data = json.load(file)
    # to have consistent format as in Kaggle
    data = pd.DataFrame(data)
    data.rename(columns={'question': 'problem'}, inplace=True)
else:
    data = pd.read_csv('/kaggle/input/ai-mathematical-olympiad-prize/test.csv')
    if len(data) < 5:
        data = pd.read_csv('/kaggle/input/ai-mathematical-olympiad-prize/train.csv')
        PRIVATE = False
    else:
        PRIVATE = True

#### Generation

In [None]:
repeat_times = 5

In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map = "auto",
    torch_dtype="auto",
    trust_remote_code = True,
    use_flash_attention_2=LOCAL,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
outs = []
no_repeat_processor = [NoRepeatTokenLogitsProcessor()]
for index, row in data.iterrows():
    answers = []
    problem = row['problem']
    query_prompt = gen_prompt(problem)
    messages = [
        {
            "role": "user",
            "content": query_prompt
        }
    ]
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

    if LOCAL: monitor = [problem,int(row['final_answer'][0])] # LOCAL only
    for _ in range(repeat_times):
        # problem -> answer #
        with torch.no_grad():
            encoded_output = model.generate(inputs, max_new_tokens=1500, do_sample=True, pad_token_id=tokenizer.eos_token_id,\
                                            logits_processor=no_repeat_processor,temperature=0.7)

        decoded_output = tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(query_prompt, '')
        try:
            answer = decoded_output.split('\n')[-1]
            answer = naive_parse(answer) % 1000
        except:
            answer = 'parsing error'
        # End problem -> answer #
        answers.append(answer)
        if LOCAL:
            monitor.append(decoded_output)
            monitor.append(answer)
        elif not PRIVATE:
            print(decoded_output)

    final_answer = aggregate(answers)
    if LOCAL:
        monitor.append(final_answer)
        outs.append(monitor)
    else:
        outs.append(final_answer)
        torch.cuda.empty_cache()
        gc.collect()

Token indices sequence length is longer than the specified maximum sequence length for this model (5214 > 4096). Running this sequence through the model will result in indexing errors
This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


In [17]:
if LOCAL:
    outs_df = pd.DataFrame(outs,columns=['problem','y'] + ['output','yhat_i']*repeat_times + ['yhat'])
    print(sum(outs_df.yhat == outs_df.y))
    out_path = create_next_model_folder('../llmOutputs')
    print(out_path)
    outs_df.to_csv(out_path+'/outputs.csv', header=True, index=False)
    # 49
    # ../llmOutputs/model2
else:
    if not PRIVATE:
        answers = data.answer.tolist()
        correct = sum([y==yhat for y,yhat in zip(answers,outs)])
        print(f'{correct} correct answers')    
    data['answer'] = outs
    data[['id','answer']].to_csv("submission.csv", header=True, index=False)