Adding SlimOrca Dataset to the datasets collection #116

Merged: 20 commits, Jan 23, 2024
12 changes: 12 additions & 0 deletions docs/source/api_ref_datasets.rst
@@ -0,0 +1,12 @@
==================
torchtune.datasets
==================

.. currentmodule:: torchtune.datasets

.. autosummary::
:toctree: generated/
:template: class.rst

AlpacaDataset
SlimOrcaDataset
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -74,5 +74,6 @@ TorchTune tutorials.
:caption: API Reference
:hidden:

api_ref_datasets
api_ref_models
api_ref_modules
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -2,3 +2,6 @@
style = 'google'
check-return-types = 'False'
exclude = 'tests/torchtune/models/llama2/scripts/'

[tool.pytest.ini_options]
addopts = ["--showlocals"] # show local variables in tracebacks
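This has the same effect as passing the flag on every invocation, e.g. `pytest --showlocals tests/`.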
5 changes: 5 additions & 0 deletions tests/test_utils.py
@@ -7,6 +7,7 @@
import math
import unittest
import uuid
from pathlib import Path
from typing import Any, Union

import torch
@@ -19,6 +20,10 @@
)


def get_assets_path():
return Path(__file__).parent / "assets"


def init_weights_with_constant(model: nn.Module, constant: float = 1.0) -> None:
for p in model.parameters():
nn.init.constant_(p, constant)
106 changes: 106 additions & 0 deletions tests/torchtune/datasets/test_slimorca_dataset.py
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import random

import pytest

from torchtune import datasets
from torchtune.datasets.slimorca import _Llama2ChatFormatConstants
from torchtune.modules.tokenizer import Tokenizer

from tests.test_utils import get_assets_path


class TestSlimOrcaDataset:
@pytest.fixture
def tokenizer(self):
# m.model is a SentencePiece model pretrained with the following command:
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
return Tokenizer.from_file(str(get_assets_path() / "m.model"))

def test_prompt_label_generation(self, tokenizer):
dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer)
sample = [
{
"from": "system",
"value": "hi",
},
{
"from": "human",
"value": "mid",
},
{
"from": "gpt",
"value": "lo",
},
]
prompt, label = dataset._generate_prompt_label(sample)
assert (
prompt
== f"{_Llama2ChatFormatConstants.B_INST} {_Llama2ChatFormatConstants.B_SYS}hi{_Llama2ChatFormatConstants.E_SYS}mid {_Llama2ChatFormatConstants.E_INST}" # noqa: B950
)
assert label == " lo "

sample = [
{
"from": "human",
"value": "mid",
},
{
"from": "gpt",
"value": "lo",
},
]
prompt, label = dataset._generate_prompt_label(sample)
assert (
prompt
== f"{_Llama2ChatFormatConstants.B_INST} mid {_Llama2ChatFormatConstants.E_INST}"
)
assert label == " lo "

def test_token_generation(self, tokenizer):
dataset = datasets.get_dataset(
"slimorca", tokenizer=tokenizer, max_token_length=4096
)
input, label = dataset._generate_tokens("Hello ", "world!")
assert input == [tokenizer.bos_id, 12, 1803, 1024, 103, tokenizer.eos_id]
assert label == ([-100] * 3 + [1024, 103, tokenizer.eos_id])

def test_truncated_token_generation(self, tokenizer):
dataset = datasets.get_dataset(
"slimorca", tokenizer=tokenizer, max_token_length=5
)
# 5 is enough for full prompt, but not for label
input, label = dataset._generate_tokens("Hello ", "world!")
assert input == [tokenizer.bos_id, 12, 1803, 1024, tokenizer.eos_id]
assert label == ([-100] * 3 + [1024, tokenizer.eos_id])

# 4 is not enough for the full prompt or the response, but truncation
# is still feasible
dataset = datasets.get_dataset(
"slimorca", tokenizer=tokenizer, max_token_length=4
)
input, label = dataset._generate_tokens("Hello ", "world!")
assert input == [tokenizer.bos_id, 12, 1024, tokenizer.eos_id]
assert label == ([-100] * 2 + [1024, tokenizer.eos_id])

def test_value_error(self, tokenizer):
with pytest.raises(ValueError):
datasets.get_dataset("slimorca", tokenizer=tokenizer, max_token_length=3)

@pytest.mark.parametrize("max_token_length", [128, 512, 1024, 4096])
def test_dataset_get_item(self, tokenizer, max_token_length):
ds = datasets.get_dataset(
"slimorca", tokenizer=tokenizer, max_token_length=max_token_length
)
index = random.randint(0, len(ds) - 1)  # randint is inclusive on both ends; avoid an IndexError
input, label = ds[index]
assert len(input) <= max_token_length
assert len(label) <= max_token_length
assert len(input) == len(label)
assert input[0] == tokenizer.bos_id
assert input[-1] == tokenizer.eos_id
assert label[-1] == tokenizer.eos_id
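These tests run under the repo's pytest configuration above (which adds `--showlocals`); note that constructing the dataset downloads SlimOrca from Hugging Face, so network access is needed, e.g.:

pytest tests/torchtune/datasets/test_slimorca_dataset.py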
3 changes: 2 additions & 1 deletion torchtune/datasets/__init__.py
@@ -7,8 +7,9 @@
from torch.utils.data import Dataset

from .alpaca import AlpacaDataset
from .slimorca import SlimOrcaDataset # noqa

_DATASET_DICT = {"alpaca": AlpacaDataset}
_DATASET_DICT = {"alpaca": AlpacaDataset, "slimorca": SlimOrcaDataset}


def get_dataset(name: str, **kwargs) -> Dataset:
28 changes: 14 additions & 14 deletions torchtune/datasets/alpaca.py
@@ -17,25 +17,24 @@


class AlpacaDataset(Dataset):
"""PyTorch Representation of the Alpaca Dataset from Hugging Face.
"""
PyTorch Representation of the Alpaca Dataset
https://huggingface.co/datasets/tatsu-lab/alpaca
from Hugging Face.

Data input format: https://huggingface.co/datasets/tatsu-lab/alpaca#data-instances


Args:
tokenizer (Tokenizer): Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
**kwargs: Additional keyword arguments to pass to the Alpaca Dataset.

Data input format:
{
"instruction": "Create a classification task by clustering the given list of items.",
"input": "Apples, oranges, bananas, strawberries, pineapples",
"output": "Class 1: Apples, Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples",
"text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a classification task by clustering the given list of items.\n\n### Input:\nApples, oranges, bananas, strawberries, pineapples\n\n### Response:\nClass 1: Apples,
Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples", # noqa: B950
}

Example:
>>> alpaca_ds = AlpacaDataset(tokenizer=tokenizer)
>>> for batch in Dataloader(alpaca_ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8
>>> alpaca_ds = AlpacaDataset(tokenizer=tokenizer)
>>> for batch in Dataloader(alpaca_ds, batch_size=8):
>>> print(f"Batch size: {len(batch)}")
>>> Batch size: 8
"""

def __init__(self, tokenizer: Tokenizer, **kwargs) -> None:
@@ -49,7 +48,8 @@ def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
return self._transform(self._data[index]["text"])

def _transform(self, sample: str) -> Tuple[List[int], List[int]]:
"""Split a sample on 'response' tag to create input and labels.
"""
Split a sample on ``response`` tag to create input and labels.

Args:
sample (str): Sample text.
135 changes: 135 additions & 0 deletions torchtune/datasets/slimorca.py
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset

# Not ideal to import this type here but it's needed for the transform function
from torchtune.modules import Tokenizer


class _Llama2ChatFormatConstants:
"""
Contains constants that are used in Llama2 Chat Format.
"""

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
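# For reference, a prompt assembled from these constants (see
# `_generate_prompt_label` below) renders as:
#   [INST] <<SYS>>
#   {system prompt}
#   <</SYS>>
#
#   {user message} [/INST]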


_CROSS_ENTROPY_IGNORE_IDX = -100


class SlimOrcaDataset(Dataset):
"""
PyTorch Representation of the SlimOrca Dataset
https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
from Hugging Face.

The data is formatted to adhere to Llama2 Chat Format.
This format is required if the base model is Llama2 Chat Model.
The base Llama2 Model doesn't prescribe a particular format.

The returned data is a tuple of an input token id list and a label token id
list. If the `max_token_length` keyword argument is provided, the returned
input token id list is guaranteed (by truncation if necessary) to be within
that length.

Data input format: https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup#dataset-format

Args:
tokenizer (Tokenizer): Tokenizer used to encode data. Tokenizer must implement an `encode` and `decode` method.
**kwargs: Additional keyword arguments to pass to the SlimOrca Dataset.

Keyword Arguments:
max_token_length (int): Maximum number of tokens in the returned input and label token id lists. This value needs to be at least 4, though it is generally set to the max sequence length accepted by the model. Default is 1024.

Raises:
ValueError: If `max_token_length` is less than 4.

Example:
>>> ds = SlimOrcaDataset(tokenizer=tokenizer, max_token_length=10)
>>> for input, label in ds:
>>> print(input)
>>> print(label)
>>>
>>> Sample Ouput:
>>> [1, 351, 82, 391, 221, 220, 193, 12, 471, ..., 2]
>>> [-100, -100, -100, -100, -100, -100, -100, -100, 471, ..., 2]
""" # noqa

def __init__(self, tokenizer: Tokenizer, **kwargs) -> None:
self._data = load_dataset("Open-Orca/SlimOrca-Dedup", split="train")
self._tokenizer = tokenizer
self._max_token_length = kwargs.get("max_token_length", 1024)
if self._max_token_length < 4:
# Input token needs to have 1 bos, 1 eos,
# and 1 token from prompt, 1 from label
raise ValueError("max_token_length must be at least 4")

def __len__(self):
return len(self._data)

def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
data = self._data[index]["conversations"]
prompt, label = self._generate_prompt_label(data)
return self._generate_tokens(prompt, label)

def _generate_tokens(self, prompt: str, label: str) -> Tuple[List[int], List[int]]:
"""
Given a prompt string and label string, generate input and label token id lists.

The tokenizer is used to tokenize both strings.
The prompt token list is truncated to `max_token_length` - 2
(so that there is at least one label token, as EOS takes one token).

The label token list is truncated to `max_token_length` - len(prompt_token_list).

Finally, the input token list is the concatenation of the prompt and label token lists.

The label token list is padded with the cross entropy ignore idx value to match the length of the input token list.
"""
prompt_tokens = self._tokenizer.encode(prompt, add_bos=True, add_eos=False)
# Truncate to max token length - 2 (so that there is at least one label token)
prompt_tokens = prompt_tokens[: self._max_token_length - 2]

# Calculate space left for label tokens
label_tokens_length = self._max_token_length - len(prompt_tokens)
label_tokens = self._tokenizer.encode(label, add_bos=False, add_eos=True)

# Truncate label tokens
label_tokens = label_tokens[: label_tokens_length - 1]
if label_tokens[-1] != self._tokenizer.eos_id:
label_tokens.append(self._tokenizer.eos_id)

input = prompt_tokens + label_tokens
label = [
_CROSS_ENTROPY_IGNORE_IDX for _ in range(len(prompt_tokens))
] + label_tokens
return input, label

def _generate_prompt_label(self, data: List[Dict[str, str]]) -> Tuple[str, str]:
"""
Construct prompt and label strings adhering to Llama2 Chat Format.
This method supports only a single back-and-forth exchange per sample (which is sufficient for the SlimOrca dataset).
"""
agent_text_dict = {}
# agents can be {system, human, gpt}
for conversation in data:
agent = conversation["from"]
text = conversation["value"]
agent_text_dict[agent] = text

# Llama2 Chat Format - https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L284
if "system" in agent_text_dict:
prompt = f"{_Llama2ChatFormatConstants.B_INST} {_Llama2ChatFormatConstants.B_SYS}{agent_text_dict['system']}{_Llama2ChatFormatConstants.E_SYS}{agent_text_dict['human']} {_Llama2ChatFormatConstants.E_INST}" # noqa: B950
else:
prompt = f"{_Llama2ChatFormatConstants.B_INST} {agent_text_dict['human']} {_Llama2ChatFormatConstants.E_INST}"

response = f" {agent_text_dict['gpt']} "
return prompt, response
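For a quick end-to-end check, here is a minimal usage sketch (the `m.model` path and the `max_token_length` value are illustrative, and constructing the dataset downloads SlimOrca-Dedup from Hugging Face):

from torchtune import datasets
from torchtune.modules import Tokenizer

# Load a trained SentencePiece tokenizer (illustrative path)
tokenizer = Tokenizer.from_file("m.model")
ds = datasets.get_dataset("slimorca", tokenizer=tokenizer, max_token_length=1024)

# Each item is an (input token ids, label ids) pair of equal length
input_ids, labels = ds[0]
assert len(input_ids) == len(labels) <= 1024
assert input_ids[0] == tokenizer.bos_id and input_ids[-1] == tokenizer.eos_id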