Skip to content

Commit

Permalink
add llama2 format (Stability-AI#100)
Browse files Browse the repository at this point in the history
* add llama2 format

* add 0.6 in prompt_templates.md

* make pre-commit pass

* remove debugging line
  • Loading branch information
mkshing authored and polm committed Oct 11, 2023
1 parent 5edb386 commit fac8377
Show file tree
Hide file tree
Showing 11 changed files with 354 additions and 2 deletions.
23 changes: 23 additions & 0 deletions docs/prompt_templates.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,26 @@ This is intended to use for [rinna/bilingual-gpt-neox-4b-instruction-sft](https:
システム: {answer}
```
For formats for other tasks, please see `lm_eval/tasks/TASK.py`.


## `0.6`
This is intended to be used for Llama2-chat variants.

- **Reference:** https://huggingface.co/blog/llama2#how-to-prompt-llama-2
- **Supported Tasks:** `jsquad`, `jaquad`, `jcommonsenseqa`, `jnli`, `marc_ja`, `jaqket_v2`, `xlsum_ja`, `mgsm`
- **Usage:** Set the desired system prompt in the environment variable `SYSTEM_PROMPT`.
- **Format:**
e.g. JCommonsenseQA
```
<s>[INST] <<SYS>>
{{ SYSTEM_PROMPT }}
<</SYS>>
与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:
- choice0
...
- choice4
質問:... [/INST] {{ answer }} </s>
```
For formats for other tasks, please see `lm_eval/tasks/TASK.py`.
1 change: 0 additions & 1 deletion lm_eval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ def _collate(x):
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
chunk, multi_logits, inps, inplens, cont_toks_list
):

# Slice to original seq length
contlen = len(cont_toks)
logits = logits[inplen - contlen : inplen].unsqueeze(
Expand Down
46 changes: 46 additions & 0 deletions lm_eval/tasks/ja/jaqket_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,58 @@ class JAQKETV2WithRinnaBilingualInstructionSFT(JAQKETV2WithRinnaInstructionSFT):
FEWSHOT_SEP = "\n"


class JAQKETV2WithLlama2(JAQKETV2WithJAAlpacaPrompt):
    """
    This prompt version follows the Llama2-chat's prompt format:
    ```
    <s>[INST] <<SYS>>
    {{ system_prompt }}
    <</SYS>>
    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
    ```
    reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
    """

    PROMPT_VERSION = 0.6
    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
    # The system prompt can be overridden through the SYSTEM_PROMPT env var.
    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
    DESCRIPTION = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
    FEWSHOT_SEP = " </s><s>[INST] "

    def doc_to_text(self, doc):
        """
        Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
        ```
        与えられた文脈から、質問に対する答えを抜き出してください。
        文脈:{context}
        質問:{question} [/INST]
        ```
        """
        # doc["ctxs"]["text"] is already an iterable of strings; join it
        # directly instead of copying it through a comprehension first.
        context = self.SEP.join(doc["ctxs"]["text"])
        answer_candidate = "文脈:" + context
        qa_prompt = self.doc_to_qa_prompt(doc)
        return f"{self.INSTRUCTION}\n\n{answer_candidate}\n{qa_prompt} [/INST] "

    def doc_to_answering_text(self, doc):
        """Like `doc_to_text`, but keeps only the first context that contains the answer."""
        # Index of the first gold (answer-bearing) passage.
        answering_index = doc["ctxs"]["has_answer"].index(True)
        # Read the answer-bearing passage directly; the previous version built
        # a throwaway per-key-sliced dict although only "text" was ever used.
        answer_candidate = "文脈:" + doc["ctxs"]["text"][answering_index]
        qa_prompt = self.doc_to_qa_prompt(doc)
        return f"{self.INSTRUCTION}\n\n{answer_candidate}\n{qa_prompt} [/INST] "


# Prompt-format variants of this task; each entry carries its own
# PROMPT_VERSION class attribute used to select a variant.
VERSIONS = [
    JAQKETV2,
    JAQKETV2WithFintanPrompt,
    JAQKETV2WithJAAlpacaPrompt,
    JAQKETV2WithRinnaInstructionSFT,
    JAQKETV2WithRinnaBilingualInstructionSFT,
    JAQKETV2WithLlama2,
]


Expand Down
6 changes: 6 additions & 0 deletions lm_eval/tasks/ja/jaquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
JSQuADWithJAAlpacaPrompt,
JSQuADWithRinnaInstructionSFT,
JSQuADWithRinnaBilingualInstructionSFT,
JSQuADWithLlama2,
)


Expand Down Expand Up @@ -75,12 +76,17 @@ class JaQuADWithRinnaBilingualInstructionSFT(
PROMPT_VERSION = 0.5


class JaQuADWithLlama2(JSQuADWithLlama2, JaQuAD):
    """JaQuAD with the Llama2-chat prompt format (prompt version 0.6).

    Prompt construction is inherited from ``JSQuADWithLlama2``; dataset
    handling comes from ``JaQuAD``.
    """

    PROMPT_VERSION = 0.6


# Prompt-format variants of this task; each entry carries its own
# PROMPT_VERSION class attribute used to select a variant.
VERSIONS = [
    JaQuAD,
    JaQuADWithFintanPrompt,
    JaQuADWithJAAlpacaPrompt,
    JaQuADWithRinnaInstructionSFT,
    JaQuADWithRinnaBilingualInstructionSFT,
    JaQuADWithLlama2,
]


Expand Down
40 changes: 40 additions & 0 deletions lm_eval/tasks/ja/jcommonsenseqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Homepage: https://github.com/yahoojapan/JGLUE
"""
import os
from lm_eval.base import MultipleChoiceTask, rf
import numpy as np

Expand Down Expand Up @@ -204,12 +205,51 @@ class JCommonsenseQAWithRinnaBilingualInstructionSFT(
FEWSHOT_SEP = "\n"


class JCommonsenseQAWithLlama2(JCommonsenseQA):
    """JCommonsenseQA rendered in the Llama2-chat prompt format:
    ```
    <s>[INST] <<SYS>>
    {{ system_prompt }}
    <</SYS>>
    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
    ```
    reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
    """

    PROMPT_VERSION = 0.6
    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
    DESCRIPTION = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
    INSTRUCTION = "与えられた5つの選択肢の中から、最適な答えを選んでください。"
    FEWSHOT_SEP = " </s><s>[INST] "

    def doc_to_text(self, doc):
        """Render the `{{ user_msg }}` portion (based on prompt version 0.3):
        ```
        与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:
        - choice0
        ...
        - choice4
        質問:... [/INST]
        ```
        """
        # One "- <choice>" bullet per answer candidate.
        bullets = "\n".join(f"- {option}" for option in doc["choices"])
        header = f"{self.INSTRUCTION}出力は以下から選択してください:\n{bullets}"
        question = f"質問:{doc['goal']}"
        return f"{header}\n\n{question} [/INST] "


# Prompt-format variants of this task; each entry carries its own
# PROMPT_VERSION class attribute used to select a variant.
VERSIONS = [
    JCommonsenseQA,
    JCommonsenseQAWithFintanPrompt,
    JCommonsenseQAWithJAAlpacaPrompt,
    JCommonsenseQAWithRinnaInstructionSFT,
    JCommonsenseQAWithRinnaBilingualInstructionSFT,
    JCommonsenseQAWithLlama2,
]


Expand Down
40 changes: 40 additions & 0 deletions lm_eval/tasks/ja/jnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Homepage: https://github.com/yahoojapan/JGLUE
"""
import os
from lm_eval.base import BalancedMultipleChoiceTask, rf

_CITATION = """
Expand Down Expand Up @@ -156,11 +157,50 @@ class JNLIWithRinnaBilingualInstructionSFT(JNLIWithRinnaInstructionSFT):
FEWSHOT_SEP = "\n"


class JNLIWithLlama2(JNLIWithJAAlpacaPrompt):
    """JNLI rendered in the Llama2-chat prompt format:
    ```
    <s>[INST] <<SYS>>
    {{ system_prompt }}
    <</SYS>>
    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
    ```
    reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
    """

    PROMPT_VERSION = 0.6
    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
    DESCRIPTION = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
    FEWSHOT_SEP = " </s><s>[INST] "

    def doc_to_text(self, doc):
        """Render the `{{ user_msg }}` portion (based on prompt version 0.3):
        ```
        与えられた前提と仮説の関係を回答してください。
        出力は以下から選択してください:
        entailment
        contradiction
        neutral
        前提:{premise}
        仮説:{hypothesis} [/INST]
        ```
        """
        premise_line = f"前提:{doc['premise']}"
        hypothesis_line = f"仮説:{doc['hypothesis']}"
        return f"{self.INSTRUCTION}\n\n{premise_line}\n{hypothesis_line} [/INST] "


# Prompt-format variants of this task; each entry carries its own
# PROMPT_VERSION class attribute used to select a variant.
VERSIONS = [
    JNLIWithFintanPrompt,
    JNLIWithJAAlpacaPrompt,
    JNLIWithRinnaInstructionSFT,
    JNLIWithRinnaBilingualInstructionSFT,
    JNLIWithLlama2,
]


Expand Down
55 changes: 55 additions & 0 deletions lm_eval/tasks/ja/jsquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,59 @@ def doc_to_text(self, doc):
return f"ユーザー: {input_text}{self.SEP}システム: "


class JSQuADWithLlama2(JSQuAD):
    """JSQuAD rendered in the Llama2-chat prompt format:
    ```
    <s>[INST] <<SYS>>
    {{ system_prompt }}
    <</SYS>>
    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
    ```
    reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
    """

    PROMPT_VERSION = 0.6
    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
    DESCRIPTION = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
    INSTRUCTION = "与えられた文脈から、質問に対する答えを抜き出してください。"
    FEWSHOT_SEP = " </s><s>[INST] "

    def doc_to_text(self, doc):
        """Render the `{{ user_msg }}` portion:
        ```
        与えられた文脈から、質問に対する答えを抜き出してください。
        文脈:...
        質問:... [/INST]
        ```
        """
        # The raw context may carry a leading title separated by "[SEP]";
        # keep only the passage after the last separator.
        passage = doc["context"].split("[SEP]")[-1].strip()
        body = f"文脈:{passage}\n質問:{doc['question']}"
        return f"{self.INSTRUCTION}\n\n{body} [/INST] "


class JSQuADWithLlama2V12(JSQuADWithLlama2):
    """Version 1.2 of the Llama2-chat JSQuAD prompt: additionally prepends
    the document title to the context."""

    VERSION = 1.2

    def doc_to_text(self, doc):
        """
        Insert the following prompt into `{{ user_msg }}`
        ```
        与えられた文脈から、質問に対する答えを抜き出してください。
        文脈:{title}
        {context}
        質問:... [/INST]
        ```
        """
        input_text = f"文脈:{doc['title']}\n{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
        return f"{self.INSTRUCTION}\n\n{input_text} [/INST] "


VERSIONS = [
JSQuAD,
JSQuADWithFintanPrompt,
Expand All @@ -377,6 +430,8 @@ def doc_to_text(self, doc):
JSQuADWithRinnaInstructionSFTV12,
JSQuADWithRinnaBilingualInstructionSFT,
JSQuADWithRinnaBilingualInstructionSFTV12,
JSQuADWithLlama2,
JSQuADWithLlama2V12,
]


Expand Down
34 changes: 34 additions & 0 deletions lm_eval/tasks/ja/marc_ja.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Homepage: https://github.com/yahoojapan/JGLUE
"""
import os
from lm_eval.base import BalancedMultipleChoiceTask, rf

_CITATION = """
Expand Down Expand Up @@ -154,11 +155,44 @@ class MARCJaWithRinnaBilingualInstructionSFT(MARCJaWithRinnaInstructionSFT):
FEWSHOT_SEP = "\n"


class MARCJaWithLlama2(MARCJaWithJAAlpacaPrompt):
    """MARC-ja rendered in the Llama2-chat prompt format:
    ```
    <s>[INST] <<SYS>>
    {{ system_prompt }}
    <</SYS>>
    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
    ```
    reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
    """

    PROMPT_VERSION = 0.6
    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
    SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT)
    DESCRIPTION = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
    FEWSHOT_SEP = " </s><s>[INST] "

    def doc_to_text(self, doc):
        """Render the `{{ user_msg }}` portion (based on prompt version 0.3):
        ```
        以下の製品レビューを、ポジティブまたはネガティブの感情クラスのいずれかに分類してください。
        {query} [/INST]
        ```
        """
        review = doc["query"]
        return f"{self.INSTRUCTION}\n\n{review} [/INST] "


# Prompt-format variants of this task; each entry carries its own
# PROMPT_VERSION class attribute used to select a variant.
VERSIONS = [
    MARCJaWithFintanPrompt,
    MARCJaWithJAAlpacaPrompt,
    MARCJaWithRinnaInstructionSFT,
    MARCJaWithRinnaBilingualInstructionSFT,
    MARCJaWithLlama2,
]


Expand Down
Loading

0 comments on commit fac8377

Please sign in to comment.