From d19d1b8ff3181bdab0fe5307dd5ccf76bf5039d2 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 29 Oct 2024 18:56:44 -0700 Subject: [PATCH] Remove deprecated InstructTemplate from llm_pte_finetuning example (#6557) Summary: See https://github.com/pytorch/executorch/issues/6552 for context. Here, we remove the InstructTemplate classes and instead directly replace with a dataset builder that uses the necessary torchtune data components. The associated config is also updated. Reviewed By: larryliu0820 Differential Revision: D65168404 Pulled By: RdoubleA --- .../phi3_alpaca_code_config.yaml | 11 +--- examples/llm_pte_finetuning/training_lib.py | 62 ++++++------------- 2 files changed, 22 insertions(+), 51 deletions(-) diff --git a/examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml b/examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml index 88e5bfac700..4ca3804f086 100644 --- a/examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml +++ b/examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml @@ -4,15 +4,8 @@ tokenizer: max_seq_len: 1024 dataset: - _component_: torchtune.datasets.instruct_dataset - template: papaya.toolkit.experimental.llm_pte_finetuning.utils.DatabricksDolly - source: iamtarun/python_code_instructions_18k_alpaca - split: train - column_map: - instruction: instruction - prompt: prompt - input: input - output: output + _component_: executorch.examples.llm_pte_finetuning.training_lib.python_code_instructions_alpaca + seed: null shuffle: True batch_size: 1 diff --git a/examples/llm_pte_finetuning/training_lib.py b/examples/llm_pte_finetuning/training_lib.py index 6324d93814e..dfdaf9b115a 100644 --- a/examples/llm_pte_finetuning/training_lib.py +++ b/examples/llm_pte_finetuning/training_lib.py @@ -7,15 +7,17 @@ # pyre-strict from functools import partial -from typing import Any, Dict, Mapping, Optional +from typing import Any import torch from executorch.extension.pybindings.aten_lib import ExecuTorchModule # @manual from torch.nn import functional as F from torch.utils.data import DataLoader, Dataset, DistributedSampler -from torchtune.data import InstructTemplate +from torchtune.data import AlpacaToMessages from torchtune.data._collate import padded_collate_sft +from torchtune.datasets import PackedDataset, SFTDataset +from torchtune.modules.tokenizers import ModelTokenizer from tqdm import tqdm @@ -44,49 +46,25 @@ def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: return self.loss(logits, labels) -class DatabricksDolly(InstructTemplate): +def python_code_instructions_alpaca(tokenizer: ModelTokenizer) -> PackedDataset: """ - Used for the Dolly dataset from Databricks. - - https://huggingface.co/datasets/databricks/databricks-dolly-15k - """ - - template = "Instruction:\n{instruction}\n\nContext:\n{input}\n\nResponse: " - - @classmethod - def format( - cls, - sample: Mapping[str, Any], - column_map: Optional[Dict[str, str]], - ) -> str: - assert column_map is not None - instruction = sample[column_map["instruction"]] - input = sample[column_map["input"]] - return cls.template.format(instruction=instruction, input=input) - - -class PythonCodeInstructions(InstructTemplate): - """ - https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca + Python code instruction-input-output pairs from iamtarun/python_code_instructions_18k_alpaca templated with Alpaca. """ - - template = ( - "{prompt}\n\n" - "Instruction:\n{instruction}" - "\n\nContext:\n{input}\n\nResponse: " + ds = SFTDataset( + # pyre-ignore[6]: Incompatible parameter type + model_transform=tokenizer, + source="iamtarun/python_code_instructions_18k_alpaca", + message_transform=AlpacaToMessages( + train_on_input=False, + ), + # pyre-ignore[6]: Incompatible parameter type + split="train", ) - - @classmethod - def format( - cls, - sample: Mapping[str, Any], - column_map: Optional[Dict[str, str]], - ) -> str: - assert column_map is not None - instruction = sample[column_map["instruction"]] - input = sample[column_map["input"]] - prompt = sample[column_map["prompt"]] - return cls.template.format(instruction=instruction, input=input, prompt=prompt) + if tokenizer.max_seq_len is None: + raise ValueError( + "PackedDataset requires a max_seq_len to be set on the tokenizer." + ) + return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=False) def update_function(