Fix token generation
gokulavasan committed Jan 12, 2024
1 parent 4c37715 commit d744fe2
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions torchtune/datasets/slimorca.py
@@ -56,29 +56,21 @@ def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
             text = conversation["value"]
             agent_text_dict[agent] = text
 
         # If system value is present
         # Llama2 Chat Format - https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L284
         if len(agent_text_dict["system"]) > 0:
-            prompt = f"<s>[INST] <<SYS>> {agent_text_dict['system']} <</SYS>> {agent_text_dict['human']} [/INST]"
+            prompt = f"[INST] <<SYS>> {agent_text_dict['system']} <</SYS>> {agent_text_dict['human']} [/INST] "
         else:
-            prompt = f"<s>[INST] {agent_text_dict['human']} [/INST]"
+            prompt = f"[INST] {agent_text_dict['human']} [/INST] "
 
-        response = f"{agent_text_dict['gpt']} </s>"
-        prompt_and_response = prompt + f"{response}"
+        response = f"{agent_text_dict['gpt']} "
 
-        prompt_tokens = self._tokenizer.encode(prompt, add_bos=False, add_eos=False)
-        print("Prompt Token len ", len(prompt_tokens))
-        input = self._tokenizer.encode(
-            prompt_and_response, add_bos=False, add_eos=False
-        )
-        print("Prompt and Response Token len ", len(input))
-        label_tokens = self._tokenizer.encode(response, add_bos=False, add_eos=False)
-        print("Label Token len ", len(label_tokens))
+        prompt_tokens = self._tokenizer.encode(prompt, add_bos=True, add_eos=False)
+        label_tokens = self._tokenizer.encode(response, add_bos=False, add_eos=True)
+        input = prompt_tokens + label_tokens
 
-        length_input_tokens = len(prompt)
         label = [
             _CROSS_ENTROPY_IGNORE_IDX for _ in range(len(prompt_tokens))
        ] + label_tokens
-        print("input length ", len(input), " label length ", len(label))
-        assert len(input) == len(label)
 
         return input, label
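
The diff moves BOS/EOS handling from literal "<s>"/"</s>" strings in the prompt and response into the tokenizer's add_bos/add_eos flags (with a SentencePiece-style tokenizer, a literal "<s>" in the text is generally encoded as plain characters rather than the BOS id), and it builds the label by masking every prompt position with the cross-entropy ignore index so the loss is computed only on the response. Below is a minimal runnable sketch of that scheme, not torchtune's actual implementation: the ToyTokenizer is hypothetical and only mimics an encode(text, add_bos=..., add_eos=...) API, and the ignore-index value of -100 is assumed from PyTorch's cross-entropy default.

# Minimal sketch of the post-commit tokenization scheme (hypothetical tokenizer).
from typing import List, Tuple

_CROSS_ENTROPY_IGNORE_IDX = -100  # assumed: PyTorch cross-entropy ignore_index default


class ToyTokenizer:
    BOS_ID = 1
    EOS_ID = 2

    def encode(self, text: str, add_bos: bool, add_eos: bool) -> List[int]:
        # Stand-in tokenization: one id per whitespace-separated word.
        ids = [(sum(map(ord, word)) % 1000) + 3 for word in text.split()]
        if add_bos:
            ids = [self.BOS_ID] + ids
        if add_eos:
            ids = ids + [self.EOS_ID]
        return ids


def build_sample(tok: ToyTokenizer, prompt: str, response: str) -> Tuple[List[int], List[int]]:
    # BOS/EOS come from the encode flags, not from literal "<s>"/"</s>" strings,
    # so they map to the tokenizer's special ids instead of being split as text.
    prompt_tokens = tok.encode(prompt, add_bos=True, add_eos=False)
    label_tokens = tok.encode(response, add_bos=False, add_eos=True)
    input_ids = prompt_tokens + label_tokens
    # Mask every prompt position so the loss is computed only on the response.
    labels = [_CROSS_ENTROPY_IGNORE_IDX] * len(prompt_tokens) + label_tokens
    assert len(input_ids) == len(labels)
    return input_ids, labels


tok = ToyTokenizer()
input_ids, labels = build_sample(
    tok,
    prompt="[INST] <<SYS>> You are a helpful assistant. <</SYS>> Hello [/INST] ",
    response="Hi there! ",
)
print(input_ids)
print(labels)  # prompt positions hold -100, response positions hold real token ids

Because the masked prefix and the response tokens are concatenated, the input and label sequences stay the same length, which is what the removed assert in the old code was checking at runtime.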
