diff --git a/docs/prompt_templates.md b/docs/prompt_templates.md
index 47c2542c35..d36ad2d6ba 100644
--- a/docs/prompt_templates.md
+++ b/docs/prompt_templates.md
@@ -117,3 +117,26 @@ This is intended to use for [rinna/bilingual-gpt-neox-4b-instruction-sft](https:
 システム: {answer}
 ```
 For formats for other tasks, please see `lm_eval/tasks/TASK.py`.
+
+
+## `0.6`
+This is intended to be used with Llama2-chat variants.
+
+- **Reference:** https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+- **Supported Tasks:** `jsquad`, `jaquad`, `jcommonsenseqa`, `jnli`, `marc_ja`, `jaqket_v2`, `xlsum_ja`, `mgsm`
+- **Usage:** Set the desired system prompt via the environment variable `SYSTEM_PROMPT`.
+- **Format:**
+  e.g. JCommonsenseQA
+  ```
+  [INST] <<SYS>>
+  {{ SYSTEM_PROMPT }}
+  <</SYS>>
+
+  与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:
+  - choice0
+  ...
+  - choice4
+
+  質問:... [/INST] {{ answer }}
+  ```
+  For formats for other tasks, please see `lm_eval/tasks/TASK.py`.
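
For orientation, the snippet below sketches how every `0.6` task class in this patch derives its shared prompt header: the `SYSTEM_PROMPT` environment variable wins over the built-in Llama2 default, and the result is baked into the class-level `DESCRIPTION`. The default string is shortened here; the full text appears in the classes below.

```python
import os

# Fallback mirrors DEFAULT_SYSTEM_PROMPT in the task classes below (shortened here).
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. ..."

# The env var takes precedence, as in os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT).
system_prompt = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)

# Same shape as the DESCRIPTION f-string on each *WithLlama2 class.
description = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
print(description)
```
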
diff --git a/lm_eval/base.py b/lm_eval/base.py
index 83a9a5a500..3749289f68 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -305,7 +305,6 @@ def _collate(x):
         for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
             chunk, multi_logits, inps, inplens, cont_toks_list
         ):
-
             # Slice to original seq length
             contlen = len(cont_toks)
             logits = logits[inplen - contlen : inplen].unsqueeze(
diff --git a/lm_eval/tasks/ja/jaqket_v2.py b/lm_eval/tasks/ja/jaqket_v2.py
index 92981f49d1..f7f4cf1ecd 100644
--- a/lm_eval/tasks/ja/jaqket_v2.py
+++ b/lm_eval/tasks/ja/jaqket_v2.py
@@ -470,12 +470,58 @@ class JAQKETV2WithRinnaBilingualInstructionSFT(JAQKETV2WithRinnaInstructionSFT):
     FEWSHOT_SEP = "\n"


+class JAQKETV2WithLlama2(JAQKETV2WithJAAlpacaPrompt):
+    """
+    This prompt version follows the Llama2-chat prompt format:
+    ```
+    [INST] <<SYS>>
+    {{ system_prompt }}
+    <</SYS>>
+
+    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
+    ```
+    Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+    """
+
+    PROMPT_VERSION = 0.6
+    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
+    DESCRIPTION = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+    FEWSHOT_SEP = " [INST] "
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
+        ```
+        与えられた文脈から、質問に対する答えを抜き出してください。
+
+        文脈:{context}
+        質問:{question} [/INST]
+        ```
+        """
+        context = self.SEP.join([text for text in doc["ctxs"]["text"]])
+        answer_candidate = "文脈:" + context
+        qa_prompt = self.doc_to_qa_prompt(doc)
+        return f"{self.INSTRUCTION}\n\n{answer_candidate}\n{qa_prompt} [/INST] "
+
+    def doc_to_answering_text(self, doc):
+        has_answer = doc["ctxs"]["has_answer"]
+        answering_index = has_answer.index(True)
+        answering_contexts = {
+            k: v[answering_index : answering_index + 1] for k, v in doc["ctxs"].items()
+        }
+        answer_candidate = "文脈:" + answering_contexts["text"][0]
+        qa_prompt = self.doc_to_qa_prompt(doc)
+        return f"{self.INSTRUCTION}\n\n{answer_candidate}\n{qa_prompt} [/INST] "
+
+
 VERSIONS = [
     JAQKETV2,
     JAQKETV2WithFintanPrompt,
     JAQKETV2WithJAAlpacaPrompt,
     JAQKETV2WithRinnaInstructionSFT,
     JAQKETV2WithRinnaBilingualInstructionSFT,
+    JAQKETV2WithLlama2,
 ]
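
To make `doc_to_answering_text` concrete: it keeps only the first retrieved passage whose `has_answer` flag is set. A minimal sketch with a hypothetical two-passage document (field names follow the code above; the passage texts are invented):

```python
# Hypothetical JAQKET v2 doc; "ctxs" holds parallel lists, one entry per retrieved passage.
doc = {
    "ctxs": {
        "text": ["無関係な段落です。", "日本の首都は東京です。"],
        "has_answer": [False, True],
    }
}

has_answer = doc["ctxs"]["has_answer"]
answering_index = has_answer.index(True)  # first passage flagged as containing the gold answer
answering_contexts = {
    k: v[answering_index : answering_index + 1] for k, v in doc["ctxs"].items()
}
print(answering_contexts["text"][0])  # -> 日本の首都は東京です。
```

Since `list.index(True)` raises `ValueError` when no flag is set, the method assumes every doc has at least one gold passage.
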
diff --git a/lm_eval/tasks/ja/jaquad.py b/lm_eval/tasks/ja/jaquad.py
index cf00bd0ab3..0ea51f82a4 100644
--- a/lm_eval/tasks/ja/jaquad.py
+++ b/lm_eval/tasks/ja/jaquad.py
@@ -16,6 +16,7 @@
     JSQuADWithJAAlpacaPrompt,
     JSQuADWithRinnaInstructionSFT,
     JSQuADWithRinnaBilingualInstructionSFT,
+    JSQuADWithLlama2,
 )


@@ -75,12 +76,17 @@ class JaQuADWithRinnaBilingualInstructionSFT(
     PROMPT_VERSION = 0.5


+class JaQuADWithLlama2(JSQuADWithLlama2, JaQuAD):
+    PROMPT_VERSION = 0.6
+
+
 VERSIONS = [
     JaQuAD,
     JaQuADWithFintanPrompt,
     JaQuADWithJAAlpacaPrompt,
     JaQuADWithRinnaInstructionSFT,
     JaQuADWithRinnaBilingualInstructionSFT,
+    JaQuADWithLlama2,
 ]
diff --git a/lm_eval/tasks/ja/jcommonsenseqa.py b/lm_eval/tasks/ja/jcommonsenseqa.py
index e098a9c0d0..0cc6e91bd6 100644
--- a/lm_eval/tasks/ja/jcommonsenseqa.py
+++ b/lm_eval/tasks/ja/jcommonsenseqa.py
@@ -7,6 +7,7 @@
 Homepage: https://github.com/yahoojapan/JGLUE
 """
+import os
 from lm_eval.base import MultipleChoiceTask, rf
 import numpy as np
@@ -204,12 +205,51 @@ class JCommonsenseQAWithRinnaBilingualInstructionSFT(
     FEWSHOT_SEP = "\n"


+class JCommonsenseQAWithLlama2(JCommonsenseQA):
+    """
+    This prompt version follows the Llama2-chat prompt format:
+    ```
+    [INST] <<SYS>>
+    {{ system_prompt }}
+    <</SYS>>
+
+    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
+    ```
+    Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+    """
+
+    PROMPT_VERSION = 0.6
+    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
+    DESCRIPTION = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+    INSTRUCTION = "与えられた5つの選択肢の中から、最適な答えを選んでください。"
+    FEWSHOT_SEP = " [INST] "
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
+        ```
+        与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:
+        - choice0
+        ...
+        - choice4
+
+        質問:... [/INST]
+        ```
+        """
+        choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
+        instruction_text = self.INSTRUCTION + f"出力は以下から選択してください:\n{choices}"
+        input_text = f"質問:{doc['goal']}"
+        return f"{instruction_text}\n\n{input_text} [/INST] "
+
+
 VERSIONS = [
     JCommonsenseQA,
     JCommonsenseQAWithFintanPrompt,
     JCommonsenseQAWithJAAlpacaPrompt,
     JCommonsenseQAWithRinnaInstructionSFT,
     JCommonsenseQAWithRinnaBilingualInstructionSFT,
+    JCommonsenseQAWithLlama2,
 ]
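
Rendering `doc_to_text` on a hypothetical JCommonsenseQA doc produces a user message of the shape documented in `docs/prompt_templates.md`. The string building below is copied from the method above; the doc values are invented for illustration:

```python
doc = {
    "goal": "会社の最高責任者を何という?",  # hypothetical question
    "choices": ["社長", "部長", "課長", "係長", "主任"],  # hypothetical choices
}

INSTRUCTION = "与えられた5つの選択肢の中から、最適な答えを選んでください。"
choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
# INSTRUCTION and the "出力は..." sentence are concatenated without a separator,
# yielding the single-line instruction shown in the docs example.
instruction_text = INSTRUCTION + f"出力は以下から選択してください:\n{choices}"
input_text = f"質問:{doc['goal']}"
print(f"{instruction_text}\n\n{input_text} [/INST] ")
```
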
diff --git a/lm_eval/tasks/ja/jnli.py b/lm_eval/tasks/ja/jnli.py
index a457dc7317..2f474d14f5 100644
--- a/lm_eval/tasks/ja/jnli.py
+++ b/lm_eval/tasks/ja/jnli.py
@@ -7,6 +7,7 @@
 Homepage: https://github.com/yahoojapan/JGLUE
 """
+import os
 from lm_eval.base import BalancedMultipleChoiceTask, rf

 _CITATION = """
@@ -156,11 +157,50 @@ class JNLIWithRinnaBilingualInstructionSFT(JNLIWithRinnaInstructionSFT):
     FEWSHOT_SEP = "\n"


+class JNLIWithLlama2(JNLIWithJAAlpacaPrompt):
+    """
+    This prompt version follows the Llama2-chat prompt format:
+    ```
+    [INST] <<SYS>>
+    {{ system_prompt }}
+    <</SYS>>
+
+    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
+    ```
+    Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+    """
+
+    PROMPT_VERSION = 0.6
+    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
+    DESCRIPTION = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+    FEWSHOT_SEP = " [INST] "
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
+        ```
+        与えられた前提と仮説の関係を回答してください。
+
+        出力は以下から選択してください:
+        entailment
+        contradiction
+        neutral
+
+        前提:{premise}
+        仮説:{hypothesis} [/INST]
+        ```
+        """
+        input_text = f"前提:{doc['premise']}\n仮説:{doc['hypothesis']}"
+        return f"{self.INSTRUCTION}\n\n{input_text} [/INST] "
+
+
 VERSIONS = [
     JNLIWithFintanPrompt,
     JNLIWithJAAlpacaPrompt,
     JNLIWithRinnaInstructionSFT,
     JNLIWithRinnaBilingualInstructionSFT,
+    JNLIWithLlama2,
 ]
diff --git a/lm_eval/tasks/ja/jsquad.py b/lm_eval/tasks/ja/jsquad.py
index ee14f4cb12..4ec544f654 100644
--- a/lm_eval/tasks/ja/jsquad.py
+++ b/lm_eval/tasks/ja/jsquad.py
@@ -367,6 +367,59 @@ def doc_to_text(self, doc):
         return f"ユーザー: {input_text}{self.SEP}システム: "


+class JSQuADWithLlama2(JSQuAD):
+    """
+    This prompt version follows the Llama2-chat prompt format:
+    ```
+    [INST] <<SYS>>
+    {{ system_prompt }}
+    <</SYS>>
+
+    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
+    ```
+    Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+    """
+
+    PROMPT_VERSION = 0.6
+    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
+    DESCRIPTION = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+    INSTRUCTION = "与えられた文脈から、質問に対する答えを抜き出してください。"
+    FEWSHOT_SEP = " [INST] "
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`
+        ```
+        与えられた文脈から、質問に対する答えを抜き出してください。
+
+        文脈:...
+        質問:... [/INST]
+        ```
+        """
+        input_text = (
+            f"文脈:{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
+        )
+        return f"{self.INSTRUCTION}\n\n{input_text} [/INST] "
+
+
+class JSQuADWithLlama2V12(JSQuADWithLlama2):
+    VERSION = 1.2
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`
+        ```
+        与えられた文脈から、質問に対する答えを抜き出してください。
+
+        文脈:...
+        質問:... [/INST]
+        ```
+        """
+        input_text = f"文脈:{doc['title']}\n{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
+        return f"{self.INSTRUCTION}\n\n{input_text} [/INST] "
+
+
 VERSIONS = [
     JSQuAD,
     JSQuADWithFintanPrompt,
@@ -377,6 +430,8 @@ def doc_to_text(self, doc):
     JSQuADWithRinnaInstructionSFTV12,
     JSQuADWithRinnaBilingualInstructionSFT,
     JSQuADWithRinnaBilingualInstructionSFTV12,
+    JSQuADWithLlama2,
+    JSQuADWithLlama2V12,
 ]
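
Both `doc_to_text` methods split on `[SEP]` because JGLUE stores each JSQuAD context as `タイトル [SEP] 本文`; the `0.6` prompt keeps only the passage body, while the V1.2 variant re-adds `doc['title']` on its own line before it. A toy illustration with invented values:

```python
# JSQuAD contexts come as "タイトル [SEP] 本文"; keep only the passage body.
context = "富士山 [SEP] 富士山は日本で最も高い山である。"
passage = context.split("[SEP]")[-1].strip()
print(passage)  # -> 富士山は日本で最も高い山である。
```
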
diff --git a/lm_eval/tasks/ja/marc_ja.py b/lm_eval/tasks/ja/marc_ja.py
index b1d13712f0..46f4261843 100644
--- a/lm_eval/tasks/ja/marc_ja.py
+++ b/lm_eval/tasks/ja/marc_ja.py
@@ -7,6 +7,7 @@
 Homepage: https://github.com/yahoojapan/JGLUE
 """
+import os
 from lm_eval.base import BalancedMultipleChoiceTask, rf

 _CITATION = """
@@ -154,11 +155,44 @@ class MARCJaWithRinnaBilingualInstructionSFT(MARCJaWithRinnaInstructionSFT):
     FEWSHOT_SEP = "\n"


+class MARCJaWithLlama2(MARCJaWithJAAlpacaPrompt):
+    """
+    This prompt version follows the Llama2-chat prompt format:
+    ```
+    [INST] <<SYS>>
+    {{ system_prompt }}
+    <</SYS>>
+
+    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
+    ```
+    Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+    """
+
+    PROMPT_VERSION = 0.6
+    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
+    DESCRIPTION = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+    FEWSHOT_SEP = " [INST] "
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
+        ```
+        以下の製品レビューを、ポジティブまたはネガティブの感情クラスのいずれかに分類してください。
+
+        {query} [/INST]
+        ```
+        """
+        input_text = doc["query"]
+        return f"{self.INSTRUCTION}\n\n{input_text} [/INST] "
+
+
 VERSIONS = [
     MARCJaWithFintanPrompt,
     MARCJaWithJAAlpacaPrompt,
     MARCJaWithRinnaInstructionSFT,
     MARCJaWithRinnaBilingualInstructionSFT,
+    MARCJaWithLlama2,
 ]
diff --git a/lm_eval/tasks/ja/mgsm.py b/lm_eval/tasks/ja/mgsm.py
index 0d3596c8ca..e818427799 100644
--- a/lm_eval/tasks/ja/mgsm.py
+++ b/lm_eval/tasks/ja/mgsm.py
@@ -4,6 +4,7 @@
 Multilingual Grade School Math problems with a numerical answer and a chain-of-thought prompt.
 """
+import os
 from lm_eval.base import rf
 from lm_eval.tasks.gsm8k import GradeSchoolMath8K, INVALID_ANS
 import re
@@ -24,7 +25,6 @@
 class MGSM(GradeSchoolMath8K):
-
     DATASET_PATH = "juletxara/mgsm"
     DATASET_NAME = "ja"
@@ -164,11 +164,44 @@ class MGSMWithRinnaBilingualInstructionSFT(MGSMWithRinnaInstructionSFT):
     FEWSHOT_SEP = "\n"


+class MGSMWithLlama2(MGSMWithJAAlpacaPrompt):
+    """
+    This prompt version follows the Llama2-chat prompt format:
+    ```
+    [INST] <<SYS>>
+    {{ system_prompt }}
+    <</SYS>>
+
+    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
+    ```
+    Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+    """
+
+    PROMPT_VERSION = 0.6
+    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
+    DESCRIPTION = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+    FEWSHOT_SEP = " [INST] "
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
+        ```
+        与えられた問題に対して、ステップごとに答えを導き出してください。
+
+        {question} [/INST]
+        ```
+        """
+        input_text = f"{doc['question'].replace('問題:','')}"
+        return f"{self.INSTRUCTION}\n\n{input_text} [/INST] "
+
+
 VERSIONS = [
     MGSM,
     MGSMWithJAAlpacaPrompt,
     MGSMWithRinnaInstructionSFT,
     MGSMWithRinnaBilingualInstructionSFT,
+    MGSMWithLlama2,
 ]
diff --git a/lm_eval/tasks/ja/wikilingua.py b/lm_eval/tasks/ja/wikilingua.py
index 5feef947b8..ab91cc2143 100644
--- a/lm_eval/tasks/ja/wikilingua.py
+++ b/lm_eval/tasks/ja/wikilingua.py
@@ -6,6 +6,7 @@
 Homepage: https://github.com/esdurmus/Wikilingua
 """
+import os
 import numpy as np
 import datasets
 from lm_eval.base import rf, Task
@@ -164,11 +165,44 @@ class WikilinguaWithRinnaBilingualInstructionSFT(WikilinguaWithRinnaInstructionS
     FEWSHOT_SEP = "\n"


+class WikilinguaWithLlama2(Wikilingua):
+    """
+    This prompt version follows the Llama2-chat prompt format:
+    ```
+    [INST] <<SYS>>
+    {{ system_prompt }}
+    <</SYS>>
+
+    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
+    ```
+    Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+    """
+
+    PROMPT_VERSION = 0.6
+    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
+    DESCRIPTION = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+    FEWSHOT_SEP = " [INST] "
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
+        ```
+        与えられたニュース記事を要約してください。
+
+        ニュース記事:{doc} [/INST]
+        ```
+        """
+        input_text = f"ニュース記事:{doc['text']}"
+        return f"{self.INSTRUCTION}\n\n{input_text} [/INST] "
+
+
 VERSIONS = [
     Wikilingua,
     WikilinguaWithJAAlpacaPrompt,
     WikilinguaWithRinnaInstructionSFT,
     WikilinguaWithRinnaBilingualInstructionSFT,
+    WikilinguaWithLlama2,
 ]
diff --git a/lm_eval/tasks/ja/xlsum_ja.py b/lm_eval/tasks/ja/xlsum_ja.py
index 56ac3fac7f..014193a89b 100644
--- a/lm_eval/tasks/ja/xlsum_ja.py
+++ b/lm_eval/tasks/ja/xlsum_ja.py
@@ -238,11 +238,53 @@ class XLSumJaWithRinnaBilingualInstructionSFT(XLSumJaWithRinnaInstructionSFT):
     FEWSHOT_SEP = "\n"


+class XLSumJaWithLlama2(XLSumJa):
+    """
+    This prompt version follows the Llama2-chat prompt format:
+    ```
+    [INST] <<SYS>>
+    {{ system_prompt }}
+    <</SYS>>
+
+    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
+    ```
+    Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+    """
+
+    PROMPT_VERSION = 0.6
+    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
+    DESCRIPTION = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+    INSTRUCTION = "与えられたニュース記事を要約してください。"
+    FEWSHOT_SEP = " [INST] "
+
+    def doc_to_text(self, doc):
+        """
+        Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
+        ```
+        与えられたニュース記事を要約してください。
+
+        ニュース記事:{doc} [/INST]
+        ```
+        """
+        input_text = f"ニュース記事:{doc['text']}"
+        return f"{self.INSTRUCTION}\n\n{input_text} [/INST] "
+
+    def preprocess_ctx(self, ctx, max_length):
+        return super().preprocess_ctx(
+            ctx,
+            max_length,
+            ctx_prompt=f"{self.INSTRUCTION}\n\n",
+            summary_prompt=" [/INST] ",
+        )
+
+
 VERSIONS = [
     XLSumJa,
     XLSumJaWithJAAlpacaPrompt,
     XLSumJaWithRinnaInstructionSFT,
     XLSumJaWithRinnaBilingualInstructionSFT,
+    XLSumJaWithLlama2,
 ]
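
Putting the pieces together: `DESCRIPTION` opens the conversation, each `doc_to_text` ends its user turn with ` [/INST] `, the few-shot answer follows inline, and `FEWSHOT_SEP` (` [INST] `) opens the next turn. The sketch below assumes the harness concatenates description, examples, and the evaluated doc in that order; the actual assembly happens in the harness's few-shot machinery, which this diff does not touch.

```python
description = "[INST] <<SYS>>\n{system prompt}\n<</SYS>>\n\n"  # shape of DESCRIPTION
fewshot_sep = " [INST] "                                        # FEWSHOT_SEP

def doc_to_text(user_msg: str) -> str:
    # Every 0.6 doc_to_text ends the user turn this way.
    return f"{user_msg} [/INST] "

examples = [("質問1 ...", "答え1"), ("質問2 ...", "答え2")]  # hypothetical few-shot pairs
eval_msg = "質問3 ..."

ctx = description
for user_msg, answer in examples:
    ctx += doc_to_text(user_msg) + answer + fewshot_sep
ctx += doc_to_text(eval_msg)
print(ctx)
# [INST] <<SYS>> ... <</SYS>>
#
# 質問1 ... [/INST] 答え1 [INST] 質問2 ... [/INST] 答え2 [INST] 質問3 ... [/INST]
```
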