Merge pull request #7 from wellecks/mmlu-minerva
Minerva MMLU-STEM Replication
zhangir-azerbayev committed Jul 12, 2023
2 parents 7dd95f4 + 6416483 commit fef9d47
Showing 4 changed files with 319 additions and 0 deletions.
3 changes: 3 additions & 0 deletions configs/config_mmlu_cot.json
@@ -0,0 +1,3 @@
{
"minerva-hendrycksTest-abstract_algebra": {"params": {"majority_voting": 16,"sampling_temperature":0.3,"eval_batch_size":16}}
}
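Only abstract_algebra is enabled in this config. The other MCQA subjects registered by create_all_mcqa_tasks() in lm_eval/tasks/hendrycks_test_cot.py (added below) follow the same key pattern; a hypothetical two-subject config would look like:

{
"minerva-hendrycksTest-abstract_algebra": {"params": {"majority_voting": 16, "sampling_temperature": 0.3, "eval_batch_size": 16}},
"minerva-hendrycksTest-college_mathematics": {"params": {"majority_voting": 16, "sampling_temperature": 0.3, "eval_batch_size": 16}}
}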
2 changes: 2 additions & 0 deletions lm_eval/tasks/__init__.py
@@ -55,6 +55,7 @@
from . import crowspairs
from . import lila
from . import proofnet
from . import hendrycks_test_cot

########################################
# Translation tasks
@@ -364,6 +365,7 @@
# ProofNet
"proofnet_autoformalize_statements": proofnet.ProofNetAutoformalizeStatements,
"proofnet_informalize_statements": proofnet.ProofNetInformalizeStatements,
**hendrycks_test_cot.create_all_mcqa_tasks(),
#
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
312 changes: 312 additions & 0 deletions lm_eval/tasks/hendrycks_test_cot.py
@@ -0,0 +1,312 @@
"""
Measuring Massive Multitask Language Understanding
https://arxiv.org/pdf/2009.03300.pdf
The Hendrycks Test is a benchmark that measures a text model's multitask accuracy.
The test covers 57 tasks including elementary mathematics, US history, computer
science, law, and more. To attain high accuracy on this test, models must possess
extensive world knowledge and problem solving ability. By comprehensively evaluating
the breadth and depth of a model’s academic and professional understanding,
the Hendrycks Test can be used to analyze models across many tasks and to identify
important shortcomings.
Homepage: https://github.com/hendrycks/test
"""

"""
Solving Quantitative Reasoning Problems with Language Models
https://arxiv.org/pdf/2206.14858.pdf
Minerva CoT version of MMLU-STEM. See Appendix G for prompt reference.
`SUBJECTS_MCQA` consists of the STEM subsets (those that "use equations") whose few-shot prompt is Listing 5 from
Minerva Appendix G. `SUBJECTS_CUSTOM` subjects use special subject-specific prompts, which are described as being in the
supplementary material but do not appear to be included in the currently downloadable zip.
"""

from lm_eval.base import Task, rf
from lm_eval.metrics import mean

import re

_CITATION = """
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
year={2021}
}
@misc{lewkowycz2022solving,
title={Solving Quantitative Reasoning Problems with Language Models},
author={Aitor Lewkowycz and Anders Andreassen and David Dohan and Ethan Dyer and Henryk Michalewski and Vinay Ramasesh and Ambrose Slone and Cem Anil and Imanol Schlag and Theo Gutman-Solo and Yuhuai Wu and Behnam Neyshabur and Guy Gur-Ari and Vedant Misra},
year={2022},
eprint={2206.14858},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""

SUBJECTS_MCQA = [
"abstract_algebra",
"college_mathematics",
"college_physics",
"elementary_mathematics",
"high_school_mathematics",
"high_school_physics",
"high_school_statistics",
]

SUBJECTS_CUSTOM = [
"astronomy",
"college_biology",
"college_chemistry",
"college_computer_science",
"computer_security",
"conceptual_physics",
"electrical_engineering",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"machine_learning"
]


SUBJECTS_STEM = SUBJECTS_MCQA + SUBJECTS_CUSTOM
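# Only SUBJECTS_MCQA is registered via create_all_mcqa_tasks() below; the
# SUBJECTS_CUSTOM / SUBJECTS_STEM lists are kept for reference until the
# subject-specific prompts become available.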


MCQA_PROMPT = """\
Problem:
Find the domain of the expression $\frac{\sqrt{x-2}}{\sqrt{5-x}}$.
What of the following is the right choice? Explain you answer.
(A) [-5,-2), (B) [2,5), (C) [-2,-5), (D) [5,2)
Solution:
The expressions inside each square root must be non-negative. Therefore, $x-2 \ge 0$, so $x\ge2$, and $5 - x \
ge 0$, so $x \le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.
Therefore, the domain of the expression is $\boxed{[2,5)}$.
Final Answer: The final answer is (B). I hope it is correct.
Problem:
If $\det \mathbf{A} = 2$ and $\det \mathbf{B} = 12,$ then find $\det (\mathbf{A} \mathbf{B}).$
What of the following is the right choice? Explain you answer.
(A) 14, (B) 4, (C) 2, (D) 24
Solution:
We have that $\det (\mathbf{A} \mathbf{B}) = (\det \mathbf{A})(\det \mathbf{B}) = (2)(12) = \boxed{24}.$
Final Answer: The final answer is (D). I hope it is correct.
Problem:
Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times \
must Terrell lift them in order to lift the same total weight?
What of the following is the right choice? Explain you answer.
(A) 12, (B) 20, (C) 16, (D) 15
Solution:
If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\cdot 12\cdot20=480$ pounds of weight. \
If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\cdot15\cdot n=30n$ \
pounds of weight. Equating this to 480 pounds, we can solve for $n$: \begin{align*}
30n&=480\\
\Rightarrow\qquad n&=480/30=\boxed{16}
\end{align*}
Final Answer: The final answer is (C). I hope it is correct.
Problem:
If the system of equations
\begin{align*}
6x-4y&=a,\\
6y-9x &=b.
\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\frac{a}{b},$ assuming $b$ is
nonzero.
What of the following is the right choice? Explain you answer.
(A) $-\frac{2}{3}$, (B) $\frac{2}{3}$, (C) $\frac{1}{3}$, (D) $\frac{4}{9}$
Solution:
If we multiply the first equation by $-\frac{3}{2}$, we obtain
$$6y-9x=-\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have
$$-\frac{3}{2}a=b\Rightarrow\frac{a}{b}=\boxed{-\frac{2}{3}}.$$
Final Answer: The final answer is (A). I hope it is correct.
"""


def create_all_mcqa_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task}
"""
return {f"minerva-hendrycksTest-{sub}": create_mcqa_task(sub) for sub in SUBJECTS_MCQA}


def create_mcqa_task(subject):
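# A fresh Task subclass is created per subject so that each registered task
# carries its own DATASET_NAME (the `hendrycks_test` config to load).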
class HendrycksTest(MinervaCoTMMLU):
def __init__(self):
super().__init__(subject)

return HendrycksTest


class MinervaCoTMMLU(Task):
VERSION = 0
DATASET_PATH = "hendrycks_test"
DATASET_NAME = None
MAJORITY_VOTING = "majority_voting"
SAMPLING_TEMPERATURE = "sampling_temperature"
EVAL_BATCH_SIZE = "eval_batch_size"

ANS_RE = re.compile(r"Final Answer: The final answer is (\([ABCD]\))\. I hope it is correct\.")
INVALID_ANS = "[not found]"

def __init__(self, subject):
self.DATASET_NAME = subject
super().__init__()

def has_training_docs(self):
return False

def has_validation_docs(self):
return True

def has_test_docs(self):
return True

def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])

def test_docs(self):
return map(self._process_doc, self.dataset["test"])

def _process_doc(self, doc):
def format_example(doc, keys):
"""
Problem: <prompt>
What of the following is the right choice? Explain you answer.
(A) <choice1>, (B) <choice2>, (C) <choice3>, (D) <choice4>
Solution:
"""
prompt = MCQA_PROMPT + "\n\n" + "Problem: " + doc["question"] + "\nWhat of the following is the right choice? Explain you answer.\n"
prompt += ", ".join(
[f"{key} {choice}" for key, choice in zip(keys, doc["choices"])]
)
prompt += "\nSolution:"
return prompt

keys = ["(A)", "(B)", "(C)", "(D)"]
return {
"query": format_example(doc, keys),
"choices": doc["choices"],
# HF `hendrycks_test` stores `answer` as an int index; map either an int
# index or a bare letter ("A"-"D") to the "(X)" form that _extract_answer returns.
"gold": keys[doc["answer"]]
if not isinstance(doc["answer"], str)
else keys[["A", "B", "C", "D"].index(doc["answer"])],
}

def doc_to_text(self, doc):
return doc["query"]

def construct_requests(self, doc, ctx, params={}):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
if params == {}:
return rf.generate(ctx, ["\n\n", "Problem:"])

majority_voting_value = int(params.get(self.MAJORITY_VOTING, 1))
sampling_temperature_value = float(params.get(self.SAMPLING_TEMPERATURE, 1.0))
eval_batch_size = params.get(self.EVAL_BATCH_SIZE, None)
eval_batch_size = int(eval_batch_size) if isinstance(eval_batch_size, str) else eval_batch_size
generation_params = {
'num_return_sequences': majority_voting_value,
'temperature': sampling_temperature_value,
'num_return_sequences_batch': eval_batch_size
}
completion = rf.generate(ctx, ["\n\n", "Problem:"], generation_params)
return completion

def majority_vote(self, candidates):
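# Plurality vote over the extracted answers, skipping invalid ones; ties keep the
# earliest-seen answer, e.g. ["(B)", "(A)", "(B)", "[not found]"] -> "(B)".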

answer_votes = {}
for answer in candidates:
answer_votes[answer] = answer_votes.get(answer, 0) + 1

max_vote = 0
elected = None
for answer, vote in answer_votes.items():
if vote > max_vote and answer is not None and answer != self.INVALID_ANS:
elected = answer
max_vote = vote
return elected

def process_results(self, doc, results, params={}):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
candidates = results[0]
assert isinstance(params, dict)
if params == {}:
completion = self._extract_answer(candidates)
elif self.MAJORITY_VOTING in params:
completion = self.majority_vote([self._extract_answer(c) for c in candidates])
else:
raise AssertionError

answer = doc["gold"]
return {
"acc": self._is_correct(completion, answer),
"metadata": {
"selected_answer": completion,
"candidates": candidates
}
}

def _extract_answer(self, completion):
match = self.ANS_RE.search(completion)
if match is not None:
# The capture group is the chosen option itself, e.g. "(B)".
return match.group(1)
else:
return self.INVALID_ANS

def _is_correct(self, completion, answer):
gold = answer
assert gold != self.INVALID_ANS, "No ground truth answer found in the document."
return completion == gold

def fewshot_examples(self, k, rnd):
# The few-shot exemplars are hardcoded into MCQA_PROMPT, so nothing is sampled
# here; this task must be run with num_fewshot=0.

assert k == 0, "Custom Minerva MMLU prompt hardcodes fewshot context! Must use num_fewshot=0 here"
return None

def doc_to_target(self, doc):
raise NotImplementedError("Should not rely on doc_to_target for pure-zeroshot Minerva-MMLU(STEM)")

def should_decontaminate(self):
return True

def doc_to_decontamination_query(self, doc):
return doc["query"]

def aggregation(self):
return {"acc": mean}

def higher_is_better(self):
return {"acc": True}

2 changes: 2 additions & 0 deletions setup.py
@@ -36,6 +36,8 @@
"tqdm-multiprocess",
"transformers>=4.1",
"zstandard",
"accelerate",
"timeout_decorator"
],
extras_require={
"dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
