
Alpaca Dataset Updates and Fixes #303

Merged: 5 commits merged into main from fix_alpaca on Feb 4, 2024

Conversation

kartikayk (Contributor) commented Feb 4, 2024

Context

Our current Alpaca dataset implementation doesn't allow us to train on the inputs, i.e., to leave the input unmasked during training. Looking at reference implementations, this is pretty common and is the only way we can replicate their training curves.

The class is also written in a way that doesn't let the user easily switch between the different variations of the Alpaca dataset. Using the cleaned version of the dataset allows the loss to go down faster.

This PR adds both of these features, along with tests for the Alpaca dataset.

Thanks @ebsmothers for helping find some of these issues!

Changelog

  • Rewrite the Alpaca dataset so users can easily switch between the original and cleaned versions of the dataset via the use_clean flag (see the usage sketch after this list)
  • Allow training on the input (i.e., no prompt masking) via the train_on_input flag
  • Add tests for the Alpaca dataset
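
For illustration, a minimal sketch of how the two new flags combine. The get_dataset call and flag names appear in this PR; passing both flags together through get_dataset kwargs is an assumption, and the tokenizer setup is assumed and omitted:

# Sketch only: flag names come from this PR's changelog; `tokenizer` is
# assumed to be an already-constructed torchtune tokenizer.
from torchtune import datasets

# Cleaned Alpaca variant, with the prompt included in the loss:
alpaca_ds = datasets.get_dataset(
    "alpaca", tokenizer=tokenizer, use_clean=True, train_on_input=True
)

# Original Alpaca variant, with the prompt masked out of the loss:
alpaca_ds = datasets.get_dataset(
    "alpaca", tokenizer=tokenizer, use_clean=False, train_on_input=False
)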

Test plan

  • Unit tests, including the newly added test_alpaca_dataset, succeed:
    pytest tests
    [Screenshot: all unit tests passing]
  • Training loss is closer to expectation (loss < 1). In the screenshot, blue is the run with the cleaned dataset; orange is the run with the original dataset.
    [Screenshot: training loss curves for both runs]

Comment on why we're changing the loss values in test_finetune_llm.py

The loss changes because of a small difference in how the input and label tokens are generated:

  • [old]: tokenizer.encode(text=prompt, add_bos=True, add_eos=False) + tokenizer.encode(text=response, add_bos=False, add_eos=True)
  • [new]: tokenizer.encode(text=prompt+response, add_bos=True, add_eos=True)

This creates a small difference in the tokenized output, which results in a change in the loss:

[Screenshot: loss value comparison]
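
To make the boundary effect concrete, a sketch using the encode signature from the bullets above (the tokenizer itself is assumed):

# Old behavior: prompt and response encoded separately, then concatenated.
old_tokens = tokenizer.encode(text=prompt, add_bos=True, add_eos=False) + \
    tokenizer.encode(text=response, add_bos=False, add_eos=True)

# New behavior: prompt and response encoded as a single string.
new_tokens = tokenizer.encode(text=prompt + response, add_bos=True, add_eos=True)

# A subword tokenizer can merge the end of the prompt with the start of the
# response into tokens that neither piece produces on its own, so the two
# sequences (and therefore the loss) differ slightly.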

netlify bot commented Feb 4, 2024

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: 1ada47f
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/65bf051b7888af0008c4b36e
😎 Deploy Preview: https://deploy-preview-303--torchtune-preview.netlify.app

facebook-github-bot added the "CLA Signed" label on Feb 4, 2024
ebsmothers (Contributor) left a comment:

Looks great, thanks for getting this up and tested so quickly! My few comments are all nits, so feel free to take or leave any of them.

Comment on lines 93 to 95
instruction = self._data[index]["instruction"],
input = self._data[index]["input"],
output = self._data[index]["output"]
ebsmothers (Contributor):

nit: could just define sample = self._data[index] to avoid multiple calls

kartikayk (Contributor, Author) replied:

Great catch!
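
The suggestion amounts to something like this (a sketch; field names come from the snippet above):

# Index the backing data once instead of three times.
sample = self._data[index]
instruction = sample["instruction"]
input = sample["input"]
output = sample["output"]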



class AlpacaDataset(Dataset):
"""
PyTorch Representation of the Alpaca Dataset
Support for the Alpaca dataset and it's variants from HuggingFace Datasets.
ebsmothers (Contributor):

nit

Suggested change
Support for the Alpaca dataset and it's variants from HuggingFace Datasets.
Support for the Alpaca dataset and its variants from HuggingFace Datasets.

alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)

# alpaca_dataset._data contains the raw data loaded from HF's dataset. We need the raw data
# to test the prompt generation since calling __get__item on the alpaca_dataset object will
ebsmothers (Contributor):

nit

Suggested change
# to test the prompt generation since calling __get__item on the alpaca_dataset object will
# to test the prompt generation since calling __getitem__ on the alpaca_dataset object will

@patch("torchtune.datasets.alpaca.load_dataset")
def test_prompt_generation(self, load_dataset, tokenizer):
"""
Test the the prompt generation based on the alpaca template is correct.
ebsmothers (Contributor):

nit

Suggested change
Test the the prompt generation based on the alpaca template is correct.
Test that the prompt generation based on the alpaca template is correct.

kartikayk (Contributor, Author):

Thanks so much @ebsmothers for the quick review! Addressed all comments.

@kartikayk kartikayk merged commit aaf43de into main Feb 4, 2024
15 checks passed
@kartikayk kartikayk deleted the fix_alpaca branch February 4, 2024 03:51
]

alpaca_dataset = datasets.get_dataset(
"alpaca", tokenizer=tokenizer, use_clean=True
Reviewer (Member):

Can we parametrize this together with test_label_masking instead, since the only difference is the use_clean flag?
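
For reference, a sketch of what that parametrization could look like inside the existing test class (test and fixture names follow the snippets in this thread; assertions elided):

import pytest
from unittest.mock import patch

from torchtune import datasets

@pytest.mark.parametrize("use_clean", [False, True])
@patch("torchtune.datasets.alpaca.load_dataset")
def test_label_masking(self, load_dataset, tokenizer, use_clean):
    # One test body covers both dataset variants; only the flag differs.
    alpaca_dataset = datasets.get_dataset(
        "alpaca", tokenizer=tokenizer, use_clean=use_clean
    )
    ...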

where `instruction`, `input`, and `output` are fields from the dataset.

Masking of the prompt during training is controlled by the `train_on_input` flag, which is
set to `True` by default (ref: https://github.com/tloen/alpaca-lora/blob/main/finetune.py#L49)
Reviewer (Member):

What are our thoughts on referring to reference implementations in torchtune? Citing them might imply that we, as torchtune, are certifying that repo as a reference we endorse and want to compare against publicly.
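
Separately from the citation question, the masking behavior the quoted docstring describes can be sketched as follows (the ignore-index value is an assumption based on PyTorch's cross-entropy default, not the exact torchtune code):

# When train_on_input is False, prompt tokens are excluded from the loss by
# setting their labels to the ignore index.
IGNORE_IDX = -100  # assumed; torch.nn.CrossEntropyLoss ignores this by default

prompt_tokens = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)
tokens = tokenizer.encode(text=prompt + response, add_bos=True, add_eos=True)

labels = list(tokens)
if not train_on_input:
    labels[: len(prompt_tokens)] = [IGNORE_IDX] * len(prompt_tokens)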
