Making progress
gokulavasan committed Jan 21, 2024
1 parent b061969 commit 1983024
Showing 3 changed files with 40 additions and 6 deletions.
4 changes: 2 additions & 2 deletions recipes/configs/alpaca_llama2_finetune.yaml
@@ -1,6 +1,6 @@
 # Dataset and Dataloader
-dataset: alpaca
-seed: null
+dataset: slimorca
+seed: 10
 shuffle: True
 
 # Model Arguments
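The config swap above points the recipe at SlimOrca instead of Alpaca and pins the shuffle seed. For context, here is a minimal sketch of how a recipe might consume these keys; the `cfg` object and the `build_dataloader` helper are illustrative assumptions, not part of this commit — the only call confirmed by the new test below is `datasets.get_dataset("slimorca", tokenizer=...)`:

```python
# Hypothetical sketch of wiring the YAML keys above into a dataloader.
# `cfg` and `build_dataloader` are invented for illustration; only
# datasets.get_dataset("slimorca", ...) is confirmed by this commit.
import torch
from torch.utils.data import DataLoader

from torchtune import datasets


def build_dataloader(cfg, tokenizer, batch_size=8):
    if cfg.seed is not None:
        torch.manual_seed(cfg.seed)  # seed: 10
    ds = datasets.get_dataset(cfg.dataset, tokenizer=tokenizer)  # dataset: slimorca
    return DataLoader(ds, batch_size=batch_size, shuffle=cfg.shuffle)  # shuffle: True
```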
25 changes: 25 additions & 0 deletions tests/torchtune/datasets/test_slimorca_dataset.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from pathlib import Path
+
+import pytest
+
+from torchtune import datasets
+from torchtune.modules.tokenizer import Tokenizer
+
+ASSETS = Path(__file__).parent.parent.parent / "assets"
+
+
+class TestSlimOrcaDataset:
+    @pytest.fixture
+    def tokenizer(self):
+        # m.model is a SentencePiece model pretrained using the following command:
+        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
+        return Tokenizer.from_file(str(ASSETS / "m.model"))
+
+    def test_slim_orca_dataset(self, tokenizer):
+        dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer)
+        assert dataset
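For reference, the `m.model` asset used by the fixture is a small SentencePiece model checked into the test assets; per the fixture's comment, it can be regenerated with roughly the snippet below (`<TRAIN_FILE>` is a placeholder for the training corpus, exactly as in the comment):

```python
# Sketch of regenerating the test asset, following the command quoted
# in the fixture comment. <TRAIN_FILE> is a placeholder path.
import sentencepiece as spm

spm.SentencePieceTrainer.train(
    "--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000"
)
# Writes m.model (and m.vocab) to the current directory.
```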
17 changes: 13 additions & 4 deletions torchtune/datasets/slimorca.py
@@ -21,7 +21,7 @@ class SlimOrcaDataset(Dataset):
     from Hugging Face.
 
     The data is formatted to adhere to Llama2 Chat Format.
-    It will work only for Llama2 models.
+    This format is required if the base model is a Llama2 Chat model.
 
     Args:
         tokenizer (Tokenizer): Tokenizer used to encode data. Tokenizer must implement an `encode` and `decode` method.
@@ -45,13 +45,22 @@ class SlimOrcaDataset(Dataset):
         Batch size: 8
     """
 
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
     def __init__(self, tokenizer: Tokenizer, **kwargs) -> None:
         self._data = load_dataset("Open-Orca/SlimOrca-Dedup", split="train")
         self._tokenizer = tokenizer
 
     def __len__(self):
         return len(self._data)
 
+    def prompt_with_system(self, content: str) -> str:
+        return f"{self.B_INST} {self.B_SYS}{content}{self.E_SYS} {self.E_INST}"
+
+    def prompt_without_system(self, content: str) -> str:
+        return f"{self.B_INST} {content} {self.E_INST}"
+
     def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
         data = self._data[index]["conversations"]
         agent_text_dict = {}
@@ -63,11 +72,11 @@ def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:

         # Llama2 Chat Format - https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L284
         if len(agent_text_dict["system"]) > 0:
-            prompt = f"[INST] <<SYS>> {agent_text_dict['system']} <</SYS>> {agent_text_dict['human']} [/INST] "
+            prompt = f"{self.B_INST} {self.B_SYS}{agent_text_dict['system']}{self.E_SYS}{agent_text_dict['human']} {self.E_INST}"
         else:
-            prompt = f"[INST] {agent_text_dict['human']} [/INST] "
+            prompt = f"{self.B_INST} {agent_text_dict['human']} {self.E_INST}"
 
-        response = f"{agent_text_dict['gpt']} "
+        response = f" {agent_text_dict['gpt']} "
 
         prompt_tokens = self._tokenizer.encode(prompt, add_bos=True, add_eos=False)
         label_tokens = self._tokenizer.encode(response, add_bos=False, add_eos=True)
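To make the new formatting concrete, here is what the Llama2 chat strings built in `__getitem__` evaluate to for a toy conversation (the sample system/human strings are invented):

```python
# Illustration of the Llama2 chat format assembled above; the sample
# strings are invented, the tag constants match the new class attributes.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

system = "You are a helpful assistant."
human = "Name the two moons of Mars."

print(f"{B_INST} {B_SYS}{system}{E_SYS}{human} {E_INST}")
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Name the two moons of Mars. [/INST]

print(f"{B_INST} {human} {E_INST}")
# [INST] Name the two moons of Mars. [/INST]
```

Note the asymmetric tokenization in `__getitem__`: the prompt is encoded with a BOS token and no EOS, while the response gets an EOS and no BOS, so concatenating the two token lists yields a single well-formed sequence.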
