In [1]:
import os
import getpass
import langchain

# Backend selector: 'gemini' uses the Google key, any other value uses the Fireworks key.
#Model = "gemini"
Model = "llama"

# Langsmith: debug output off, tracing off, endpoint and API key blanked.
langchain.debug = False
for _tracing_var, _tracing_value in (
    ("LANGSMITH_TRACING", "false"),
    ("LANGCHAIN_ENDPOINT", ""),
    ("LANGCHAIN_API_KEY", ""),
):
    os.environ[_tracing_var] = _tracing_value

# Pick the credential variable (and its interactive prompt) for the chosen backend.
if Model == 'gemini':
    _api_key_var, _api_key_prompt = "GOOGLE_API_KEY", "Enter API key for Google Gemini: "
else:
    _api_key_var, _api_key_prompt = "FIREWORKS_API_KEY", "Enter API key for Fireworks: "

# A missing variable is stored as "" and then handled like an empty one:
# in both cases the user is asked for the key without echoing it.
os.environ.setdefault(_api_key_var, "")
if not os.environ.get(_api_key_var):
    os.environ[_api_key_var] = getpass.getpass(_api_key_prompt)






In [2]:
import mmmlu_preparer
from mmmlu_preparer.read_mmmlu_dataset import (
    TARGET_SUBTASKS,
    MMMLULanguage,
    create_mmmlu_dataset,
    sample_first_n_data_from_subtask
)

# Candidate language codes; the first entry is the one used for this run.
lang_list = ["EN", "JA_JP"]
curr_language = lang_list[0]
# Look up the MMMLULanguage member whose name equals the selected code.
dataset_language_enum = MMMLULanguage[curr_language]

# Build the dataset for that language, then reduce it to the target subtasks.
# NOTE(review): judging by its name the helper keeps the first n rows per
# subtask; n is not passed here, so it uses the helper's default — confirm
# in mmmlu_preparer.
mmmlu_ds = create_mmmlu_dataset(dataset_language_enum)
chosen_subtasks = TARGET_SUBTASKS
mmmlu_subset = sample_first_n_data_from_subtask(mmmlu_ds, chosen_subtasks)
# Last expression of the cell: display the resulting subset.
mmmlu_subset

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['Question id in subtask', 'Question', 'Subject', 'Answer', 'A', 'B', 'C', 'D'],
    num_rows: 1700
})

In [3]:
# Display the record at index 1 of the subset to check its fields.
mmmlu_subset[1]

{'Question id in subtask': 1,
 'Question': 'Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the index of <p> in S_5.',
 'Subject': 'abstract_algebra',
 'Answer': 'C',
 'A': '8',
 'B': '2',
 'C': '24',
 'D': '120'}

In [4]:
from mmmlu_preparer.answer_extract import extract_answer_from_response
# Smoke-test the answer extractor on a handful of response shapes:
# free text, plain "Answer:", XML-like tag, and single-/double-quoted keys.
for _sample_response in (
    "TEST A TEST",
    "Answer: B",
    "<Answer> D",
    "'Answer': C",
    '"Answer": A',
):
    print(extract_answer_from_response(_sample_response))

None
B
D
C
A


In [5]:
import pandas as pd

# Draft
# Field layout of one experiment record; every field starts as an empty string.
experiment_save_dict = dict.fromkeys(
    [
        "Model",
        "Question id",
        "Shuffle method",
        "Original to shuffled",
        "Input format",
        "Output format",
        "Query",
        "Language",
        "Subtask",
        "Original correct answer",
        "Shuffled correct answer",
        "Response answer",
        "Model output",   # Output text only
        "Full response",  # All the output
    ],
    "",
)

# One placeholder row, so the frame shows the column layout.
experiment_list = [experiment_save_dict]
experiment_df = pd.DataFrame(experiment_list)

In [6]:
# Display the one-row placeholder frame to check the column layout.
experiment_df

Unnamed: 0,Model,Question id,Shuffle method,Original to shuffled,Input format,Output format,Query,Language,Subtask,Original correct answer,Shuffled correct answer,Response answer,Model output,Full response
0,,,,,,,,,,,,,,


In [7]:
import os
# Directory that receives the result files written later in the notebook.
save_dir = "mmmlu_output"
# exist_ok=True makes the cell safe to re-run and avoids the gap between
# checking for the directory and creating it.
os.makedirs(save_dir, exist_ok=True)

In [31]:
import getpass
import os
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_fireworks import ChatFireworks
from langchain.chat_models import init_chat_model

try:
    # load environment variables from .env file (requires `python-dotenv`)
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    # python-dotenv is optional: without it the environment is used as-is.
    pass


# llama3 rate: 6000 / 100 qps

# Client-side throttle handed to whichever chat model is built below.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=80,
    check_every_n_seconds=0.1,
    max_bucket_size=80,
)

if Model == "gemini":
    model_name = "gemini-2.0-flash"
    model = init_chat_model(model_name,
                        model_provider="google_genai",
                        rate_limiter=rate_limiter,
                        temperature=0.0,
                        max_tokens=4096
                        )
else:
    #model_name = "llama-v3p1-405b-instruct"
    model_name = "llama-v3p1-8b-instruct"

    model = ChatFireworks(
        # Derive the model path from model_name so the model that is queried
        # always matches the name used for the output files.
        model=f"accounts/fireworks/models/{model_name}",
        temperature=0,
        max_tokens=4096,
        # langchain warns that this argument is moved into model_kwargs.
        logprobs = 5,
        rate_limiter=rate_limiter

    )


# text: the complete query string for one question; it is sent as a single
# user message.
prompt_template = ChatPromptTemplate.from_messages(
    [("user", "{text}")]
)

                logprobs was transferred to model_kwargs.
                Please confirm that logprobs is what you intended.
  if await self.run_code(code, result, async_=asy):


In [39]:
from mmmlu_preparer.query_formats import (
    get_current_queries,
    InputFormat,
    OutputFormat,
    ShuffleMethod
)

def _as_save_token(label):
    """Return ``label`` lower-cased with underscores replaced by hyphens."""
    return label.lower().replace("_", "-")


# BASE, JSON, XML
curr_input_format = InputFormat.JSON

# BASE, JSON_FULL, XML_FULL
curr_output_format = OutputFormat.BASE

# DEFAULT, REVERSE, LONGEST_FIRST, SHORTEST_FIRST, MOST_KANA_RATIO, FEWEST_KANA_RATIO
curr_shuffle_method = ShuffleMethod.SHORTEST_FIRST

# NOTE(review): the input format token comes from `.value` while the other two
# come from `.name`; confirm that difference is intended.
input_format_save_name = _as_save_token(curr_input_format.value)
output_format_save_name = _as_save_token(curr_output_format.name)
shuffle_method_save_name = _as_save_token(curr_shuffle_method.name)

language_name = _as_save_token(curr_language)

# File stem that identifies this model / language / format / shuffle combination.
save_name = f"{model_name}_{language_name}_{input_format_save_name}_input_{output_format_save_name}_output_{shuffle_method_save_name}_shuffle"
save_name

'llama-v3p1-8b-instruct_en_json_input_base_output_shortest-first_shuffle'

In [40]:
# Print the lower-cased `.name` of the selected output format.
print(curr_output_format.name.lower())

base


In [41]:
# Build the queries for the selected language, subtasks, input format,
# output format and shuffle method. Later cells index the result by position
# and read each entry's 'Query' field.
curr_queries = get_current_queries(mmmlu_subset,
                                   dataset_language_enum,
                                   chosen_subtasks,
                                   curr_input_format,
                                   curr_output_format,
                                   curr_shuffle_method,
                                   )

In [42]:
# Show the question text of the first record whose Subject is "abstract_algebra".
mmmlu_subset.filter(lambda x: x['Subject'] == "abstract_algebra")[0]['Question']

'Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.'

In [43]:
# Print the generated query text of the entry at index 1000 to eyeball the format.
print(curr_queries[1000]['Query'])

{
    "Task": "Answer the following multiple choice question.",
    "Output_format": "The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.",
    "Instruction": "Think step by step before answering.",
    "Question": "As of 2016, about what percentage of adults aged 18 years or older were overweight?",
    "Options": {
        "A": "80%",
        "B": "40%",
        "C": "20%",
        "D": "10%"
    }
}


In [44]:
import json
from tqdm.auto import trange
from pathlib import Path
from typing import Optional
from mmmlu_preparer.logprobs import extract_answer_logprobs
import asyncio

async def retry_bad_response(bad_prompt, model, curr_output_format, max_retries=5,
                             max_completion_tokens=4000, retry_delay=2):
    """Re-send one prompt with brevity hints until the completion is short enough.

    Each attempt appends one of a fixed list of "be brief" instructions to the
    original query text. Once the list is exhausted, the unmodified text is
    sent again. An attempt succeeds when the reported completion token count
    is at most ``max_completion_tokens``.

    Args:
        bad_prompt: Prompt whose first message holds the query text
            (read through ``bad_prompt.messages[0].content``).
        model: Chat model used for the retries (called through ``abatch``).
        curr_output_format: Accepted for call compatibility; not used here.
        max_retries: Maximum number of attempts.
        max_completion_tokens: Largest completion token count that still
            counts as a good response.
        retry_delay: Seconds to wait between two attempts.

    Returns:
        ``(response, response_dict)`` for the first good attempt, where
        ``response_dict`` is ``response.to_json()``; ``(None, None)`` when
        every attempt was too long or raised.
    """
    retry_instructions = [
        "\n\nSkip the reasoning steps, give the answer directly.",
        "\n\nProvide only the final answer without explanation.",
        "\n\nBe concise and direct in your response.",
        "\n\nAnswer briefly without showing work."
    ]

    original_text = bad_prompt.messages[0].content
    for attempt in range(max_retries):
        try:
            # Create modified text with retry instruction
            if attempt < len(retry_instructions):
                modified_text = original_text + retry_instructions[attempt]
            else:
                modified_text = original_text  # Fallback to original

            # Re-render through the notebook-level prompt template so the
            # retry has the same message structure as the first request.
            modified_prompt = prompt_template.batch([{"text": modified_text}])[0]

            response = (await model.abatch([modified_prompt]))[0]
            response_dict = response.to_json()

            completion_tokens = response_dict['kwargs']['response_metadata']['token_usage']['completion_tokens']
            if completion_tokens <= max_completion_tokens:
                return response, response_dict
            print(f"Retry attempt {attempt + 1} still bad ({completion_tokens} tokens)")

        except Exception as e:
            # Best effort: a failed attempt is reported and the next one is tried.
            print(f"Error in retry attempt {attempt + 1}: {e}")

        # Pause between attempts only; waiting after the last one would just
        # delay reporting the failure.
        if attempt < max_retries - 1:
            await asyncio.sleep(retry_delay)

    # If all retries failed, return None
    print(f"All {max_retries} retry attempts failed")
    return None, None

async def run_experiemnts(queries: list[dict], save_path: str, try_first_n: Optional[int] = None, curr_output_format = None, mini_batch_size: int = 80, batch_pause: float = 1) -> list[dict]:
    """Send the queries to the model in mini-batches and append results to a JSONL file.

    The run is resumable: the number of non-empty lines already in the output
    file is used as the index of the first query to send. For that to work,
    exactly one line is written per query, in query order. This includes
    over-long responses whose retries all failed, which are saved as they are.

    Args:
        queries: Dicts with a ``'Query'`` entry holding the full query text.
        save_path: Output path; a suffix other than ``.jsonl`` is replaced.
        try_first_n: If given, process at most this many queries beyond the
            ones already saved.
        curr_output_format: Passed on to ``retry_bad_response`` and
            ``extract_answer_logprobs``.
        mini_batch_size: Number of prompts sent per ``abatch`` call.
        batch_pause: Seconds to wait after each mini-batch.

    Returns:
        The response objects saved during this call, in order. These are the
        model's message objects rather than plain dicts, and a successfully
        retried response replaces the original one.
    """
    text_queries = [{"text": query['Query']} for query in queries]

    input_prompts = prompt_template.batch(text_queries)

    results = []
    # Completions above this size are treated as runaway output (for example
    # the model repeating itself until it reaches the token limit).
    max_completion_tokens = 4000

    target_save_path = Path(save_path)
    if target_save_path.suffix != ".jsonl":
        print("Output should be jsonl file")
        target_save_path = target_save_path.with_suffix(".jsonl")
    # touch() fails when the directory is missing, so create it first.
    target_save_path.parent.mkdir(parents=True, exist_ok=True)
    target_save_path.touch()

    with target_save_path.open('r', encoding='utf-8') as file:
        # Count the number of lines already saved; that is the resume index.
        start_idx = sum(1 for line in file if line.strip())
    print(f"Start from {start_idx = }")

    total_process_num = len(text_queries)
    if try_first_n is not None:
        total_process_num = min(total_process_num, start_idx + try_first_n)

    for batch_i in trange(start_idx, total_process_num, mini_batch_size):
        # Clamp the final batch so try_first_n and the dataset size are respected.
        batch_end = min(batch_i + mini_batch_size, total_process_num)

        try:
            batched_prompts = input_prompts[batch_i:batch_end]
            responses = await model.abatch(batched_prompts)
            with target_save_path.open('a', encoding='utf-8') as file:
                for i, response in enumerate(responses):
                    if Model == "gemini":
                        response_dict = response.to_json()
                    else:
                        response_dict = response.to_json()
                        completion_tokens = response_dict['kwargs']['response_metadata']['token_usage']['completion_tokens']

                        retry_failed = False
                        # Check if response is bad (> 4000 tokens)
                        if completion_tokens > max_completion_tokens:
                            print(f"Bad response detected at index {i} ({completion_tokens} tokens), retrying...")

                            retried_response, retried_dict = await retry_bad_response(
                                batched_prompts[i], model, curr_output_format
                            )

                            if retried_response is not None:
                                print(f"Successfully retried response {i}")
                                response, response_dict = retried_response, retried_dict
                            else:
                                # Keep the original response: skipping the line
                                # would shift every later record against its
                                # query and break the line-count resume.
                                print(f"use bad response")
                                retry_failed = True

                        # add processed logprobs
                        try:
                            answer_probs = extract_answer_logprobs(response, curr_output_format)
                        except Exception as e:
                            if not retry_failed:
                                raise
                            # A runaway response may have no usable answer;
                            # store it without logprobs.
                            print(f"Could not extract logprobs from bad response {i}: {e}")
                            answer_probs = None
                        response_dict['kwargs']['response_metadata']['logprobs'] = answer_probs

                    json.dump(response_dict, file, ensure_ascii=False)
                    file.write("\n")
                    results.append(response)
                file.flush()
            print(f"Finish {batch_end} data")
            await asyncio.sleep(batch_pause)

        except Exception as e:
            # Rate limit break
            print(f"Current idx:{batch_i}\nencounters exception: {e}\nIt might be daily rate limit or error.")
            break
    return results


In [45]:
# Run the currently selected configuration and write the responses to
# <save_dir>/<save_name>.jsonl. try_first_n=None processes every query not yet
# present in that file.
save_path = f"{save_dir}/{save_name}.jsonl"
results = await run_experiemnts(curr_queries, save_path, try_first_n=None, curr_output_format = curr_output_format)

Start from start_idx = 0


  0%|          | 0/22 [00:00<?, ?it/s]

Bad response detected at index 9 (4096 tokens), retrying...
Successfully retried response 9
Bad response detected at index 13 (4096 tokens), retrying...
Successfully retried response 13
Bad response detected at index 14 (4096 tokens), retrying...
Successfully retried response 14
Bad response detected at index 17 (4096 tokens), retrying...
Successfully retried response 17
Bad response detected at index 36 (4096 tokens), retrying...
Successfully retried response 36
Bad response detected at index 38 (4096 tokens), retrying...
Successfully retried response 38
Format-specific parsing failed for BASE, trying fallback methods...
🔍 Starting FIXED XML answer search...
📍 Found 'Answer' token at position 358
📋 Context: prev='.

' | current='Answer' | next=':'
❌ Not XML pattern - prev doesn't end with '<' or next doesn't start with '>'
❌ No XML answer pattern found
All format-specific methods failed, using universal answer search...
Found answer keyword at position 358: 'Answer'
Found embedded ans

  5%|▍         | 1/22 [00:21<07:39, 21.87s/it]

Bad response detected at index 18 (4096 tokens), retrying...
Successfully retried response 18
Finish 160 data


  9%|▉         | 2/22 [00:39<06:25, 19.29s/it]

Bad response detected at index 45 (4096 tokens), retrying...
Successfully retried response 45
Finish 240 data


 14%|█▎        | 3/22 [00:56<05:46, 18.24s/it]

Finish 320 data


 18%|█▊        | 4/22 [01:04<04:18, 14.35s/it]

Finish 400 data


 23%|██▎       | 5/22 [01:09<03:03, 10.81s/it]

Finish 480 data


 27%|██▋       | 6/22 [01:14<02:20,  8.77s/it]

Bad response detected at index 31 (4096 tokens), retrying...
Successfully retried response 31
Bad response detected at index 42 (4096 tokens), retrying...
Successfully retried response 42
Bad response detected at index 49 (4096 tokens), retrying...
Successfully retried response 49
Bad response detected at index 72 (4096 tokens), retrying...
Successfully retried response 72
Bad response detected at index 74 (4096 tokens), retrying...
Successfully retried response 74
Finish 560 data


 32%|███▏      | 7/22 [01:34<03:09, 12.64s/it]

Bad response detected at index 2 (4096 tokens), retrying...
Successfully retried response 2
Bad response detected at index 15 (4096 tokens), retrying...
Successfully retried response 15
Bad response detected at index 25 (4096 tokens), retrying...
Successfully retried response 25
Bad response detected at index 29 (4096 tokens), retrying...
Successfully retried response 29
Bad response detected at index 38 (4096 tokens), retrying...
Successfully retried response 38
Finish 640 data


 36%|███▋      | 8/22 [01:54<03:31, 15.08s/it]

Bad response detected at index 6 (4096 tokens), retrying...
Successfully retried response 6
Bad response detected at index 33 (4096 tokens), retrying...
Successfully retried response 33
Format-specific parsing failed for BASE, trying fallback methods...
🔍 Starting FIXED XML answer search...
📍 Found 'Answer' token at position 345
📋 Context: prev='**' | current='Answer' | next=':**'
❌ Not XML pattern - prev doesn't end with '<' or next doesn't start with '>'
❌ No XML answer pattern found
All format-specific methods failed, using universal answer search...
Found answer keyword at position 1: ' answer'
Found answer keyword at position 345: 'Answer'
Found embedded answer letter 'A' in keyword token 'Answer' at position 345
Format-specific parsing failed for BASE, trying fallback methods...
🔍 Starting FIXED XML answer search...
❌ No XML answer pattern found
All format-specific methods failed, using universal answer search...
Found answer keyword at position 1: ' answer'
Found answer keyword 

 41%|████      | 9/22 [02:13<03:30, 16.21s/it]

Format-specific parsing failed for BASE, trying fallback methods...
🔍 Starting FIXED XML answer search...
❌ No XML answer pattern found
All format-specific methods failed, using universal answer search...
Found answer keyword at position 357: ' answer'
Found answer letter 'C' at position 360
Bad response detected at index 26 (4096 tokens), retrying...
Successfully retried response 26
Format-specific parsing failed for BASE, trying fallback methods...
🔍 Starting FIXED XML answer search...
📍 Found 'Answer' token at position 597
📋 Context: prev='.

' | current='Answer' | next=':'
❌ Not XML pattern - prev doesn't end with '<' or next doesn't start with '>'
❌ No XML answer pattern found
All format-specific methods failed, using universal answer search...
Found answer keyword at position 1: ' answer'
Found answer keyword at position 572: ' answers'
Found answer letter 'A' at position 575
Finish 800 data


 45%|████▌     | 10/22 [02:30<03:16, 16.35s/it]

Bad response detected at index 19 (4096 tokens), retrying...
Successfully retried response 19
Format-specific parsing failed for BASE, trying fallback methods...
🔍 Starting FIXED XML answer search...
📍 Found 'Answer' token at position 329
📋 Context: prev='.

' | current='Answer' | next=':'
❌ Not XML pattern - prev doesn't end with '<' or next doesn't start with '>'
❌ No XML answer pattern found
All format-specific methods failed, using universal answer search...
Found answer keyword at position 329: 'Answer'
Found embedded answer letter 'A' in keyword token 'Answer' at position 329
Finish 880 data


 50%|█████     | 11/22 [02:47<03:02, 16.55s/it]

Bad response detected at index 0 (4096 tokens), retrying...
Successfully retried response 0
Finish 960 data


 55%|█████▍    | 12/22 [03:04<02:47, 16.80s/it]

Finish 1040 data


 59%|█████▉    | 13/22 [03:16<02:16, 15.13s/it]

Finish 1120 data


 64%|██████▎   | 14/22 [03:20<01:34, 11.77s/it]

Format-specific parsing failed for BASE, trying fallback methods...
🔍 Starting FIXED XML answer search...
📍 Found 'Answer' token at position 302
📋 Context: prev=':

' | current='Answer' | next=':'
❌ Not XML pattern - prev doesn't end with '<' or next doesn't start with '>'
❌ No XML answer pattern found
All format-specific methods failed, using universal answer search...
Found answer keyword at position 299: ' answer'
Found embedded answer letter 'A' in token 'Answer' at position 302
Finish 1200 data


 68%|██████▊   | 15/22 [03:24<01:06,  9.48s/it]

Finish 1280 data


 73%|███████▎  | 16/22 [03:28<00:48,  8.02s/it]

Bad response detected at index 41 (4096 tokens), retrying...
Successfully retried response 41
Finish 1360 data


 77%|███████▋  | 17/22 [03:44<00:52, 10.47s/it]

Finish 1440 data


 82%|████████▏ | 18/22 [03:49<00:34,  8.68s/it]

Finish 1520 data


 86%|████████▋ | 19/22 [03:53<00:21,  7.31s/it]

Finish 1600 data


 91%|█████████ | 20/22 [03:57<00:12,  6.33s/it]

Finish 1680 data


 95%|█████████▌| 21/22 [04:02<00:05,  5.75s/it]

Finish 1760 data


100%|██████████| 22/22 [04:05<00:00, 11.16s/it]


In [None]:
from mmmlu_preparer.query_formats import (
    get_current_queries,
    InputFormat,
    OutputFormat,
    ShuffleMethod
)
for method in ['LONGEST_FIRST', 'SHORTEST_FIRST']:
    for input_format in ['BASE', 'JSON', 'XML']:
        for output_format in ['BASE', 'JSON_FULL', 'XML_FULL']:
            
            # BASE, JSON, XML
            if input_format == 'BASE':
                curr_input_format = InputFormat.BASE
            elif input_format == 'JSON':
                curr_input_format = InputFormat.JSON
            elif input_format == 'XML':
                curr_input_format = InputFormat.XML

            if output_format == 'BASE':
                # BASE, JSON_FULL, XML_FULL
                curr_output_format = OutputFormat.BASE
            elif output_format == 'JSON_FULL':  
                curr_output_format = OutputFormat.JSON_FULL
            elif output_format == 'XML_FULL':
                curr_output_format = OutputFormat.XML_FULL

            if method == 'LONGEST_FIRST':
                # DEFAULT, REVERSE, LONGEST_FIRST, SHORTEST_FIRST, MOST_KANA_RATIO, FEWEST_KANA_RATIO
                curr_shuffle_method = ShuffleMethod.LONGEST_FIRST
            elif method == 'SHORTEST_FIRST':
                # DEFAULT, REVERSE, LONGEST_FIRST, SHORTEST_FIRST, MOST_KANA_RATIO, FEWEST_KANA_RATIO
                curr_shuffle_method = ShuffleMethod.SHORTEST_FIRST


                input_format_save_name = curr_input_format.value.lower().replace("_", "-")
                output_format_save_name = curr_output_format.name.lower().replace("_", "-")
                shuffle_method_save_name = curr_shuffle_method.name.lower().replace("_", "-")

                language_name = curr_language.lower().replace("_", "-")

                save_name = f"{model_name}_{language_name}_{input_format_save_name}_input_{output_format_save_name}_output_{shuffle_method_save_name}_shuffle"
                save_name

                curr_queries = get_current_queries(mmmlu_subset,
                                                dataset_language_enum,
                                                chosen_subtasks,
                                                curr_input_format,
                                                curr_output_format,
                                                curr_shuffle_method,
                                                )


                save_path = f"{save_dir}/{save_name}.jsonl"
                results = await run_experiemnts(curr_queries, save_path, try_first_n=None, curr_output_format = curr_output_format)
                await asyncio.sleep(20)

In [None]:
from pathlib import Path
import json
# Read back the JSONL file of the most recent configuration, one dict per
# non-blank line.
save_path = f"{save_dir}/{save_name}.jsonl"
target_save_path = Path(save_path)
result_dicts = []
with target_save_path.open('r', encoding='utf-8') as file:
    for _raw_line in file:
        if _raw_line.strip():
            result_dicts.append(json.loads(_raw_line))

In [None]:
# Print the text content of the saved response at index 1.
print(result_dicts[1]['kwargs']['content'])

In [None]:
# Extract an answer from every saved response and collect the ones where
# extraction returned None.
output_text = [result['kwargs']['content'] for result in result_dicts]
output_answer, none_answer_indice, none_answer_output = [], [], []
for idx, output in enumerate(output_text):
    extracted_answer = extract_answer_from_response(output)
    output_answer.append(extracted_answer)
    if extracted_answer is not None:
        continue
    # No answer found: remember where, and show the raw text for inspection.
    none_answer_indice.append(idx)
    none_answer_output.append(output)
    print(f"{idx}:\n{output = }\n")

print(f"{len(none_answer_indice) = }")

In [None]:
import numpy as np
# Token usage per saved response. Note that the key read is `total_tokens`,
# not an output-only count, despite the variable name.
output_tokens_list = [result['kwargs']['usage_metadata']['total_tokens'] for result in result_dicts]
token_counts = np.array(output_tokens_list)
# Response indices ordered from the largest token count to the smallest.
descending_order = np.argsort(-token_counts)
print(descending_order.tolist())
# The matching token counts in the same order, printed as the real positive
# values rather than their negation.
print(token_counts[descending_order].tolist())
print(f"median: {np.median(output_tokens_list)}")
print(f"mean: {np.mean(output_tokens_list)}")

In [None]:
# Print the text content of the saved response at index 7.
print(result_dicts[7]['kwargs']['content'])

In [None]:
# from langchain_core.output_parsers.json import JsonOutputParser
# import re
# from tqdm.auto import tqdm
# parser = JsonOutputParser()

# for result in tqdm(result_dicts):
#     #try:
#     string = result['kwargs']['content']

#     def escape_single_backslash(match):
#         c = match.group(0)
#         return c.replace("\\", "\\\\")

#     # ChatGPT
#     string = re.sub(r'(?<!\\)\\(?![\\ntbrf"u])', escape_single_backslash, string)

#     x = (parser.parse(string))