Commit

Address PR comments
gokulavasan committed Jan 23, 2024
1 parent d4851e0 commit 97f536b
Showing 4 changed files with 11 additions and 8 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -2,3 +2,6 @@
 style = 'google'
 check-return-types = 'False'
 exclude = 'tests/torchtune/models/llama2/scripts/'
+
+[tool.pytest.ini_options]
+addopts = ["--showlocals"] # show local variables in tracebacks
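
For context: --showlocals makes pytest print the local variables of each frame in a failure traceback. A minimal sketch of the effect (hypothetical test, not part of this commit):

    # hypothetical test; fails so pytest prints a traceback
    def test_sum():
        values = [1, 2, 3]
        total = sum(values)
        # with --showlocals, the failure output also lists
        # values = [1, 2, 3] and total = 6
        assert total == 7
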
5 changes: 5 additions & 0 deletions tests/test_utils.py
@@ -7,6 +7,7 @@
 import math
 import unittest
 import uuid
+from pathlib import Path
 from typing import Any, Union
 
 import torch
@@ -19,6 +20,10 @@
 )
 
 
+def get_assets_path():
+    return Path(__file__).parent / "assets"
+
+
 def init_weights_with_constant(model: nn.Module, constant: float = 1.0) -> None:
     for p in model.parameters():
         nn.init.constant_(p, constant)
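
The new helper centralizes asset lookup for the test suite; a usage sketch, assuming shared assets live under tests/assets (which is the directory the helper resolves to, and mirrors the slimorca test change below):

    from tests.test_utils import get_assets_path

    # resolves relative to tests/, regardless of the calling test's location
    tokenizer_path = get_assets_path() / "m.model"
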
9 changes: 2 additions & 7 deletions tests/torchtune/datasets/test_slimorca_dataset.py
@@ -4,26 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import random
-from pathlib import Path
 
 import pytest
 
 from torchtune import datasets
 from torchtune.modules.tokenizer import Tokenizer
 
-ASSETS = Path(__file__).parent.parent.parent / "assets"
+from tests.test_utils import get_assets_path
 
 
 class TestSlimOrcaDataset:
     @pytest.fixture
     def tokenizer(self):
         # m.model is a pretrained Sentencepiece model using the following command:
         # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
-        return Tokenizer.from_file(str(ASSETS / "m.model"))
-
-    def test_slim_orca_dataset(self, tokenizer):
-        dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer)
-        assert len(dataset) == 363_491
+        return Tokenizer.from_file(str(get_assets_path() / "m.model"))
 
     def test_prompt_label_generation(self, tokenizer):
        dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer)
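
For reference, the fixture's comment gives the command that produced the m.model test asset; a sketch of it using the sentencepiece package (the <TRAIN_FILE> placeholder stands in for whatever corpus was originally used):

    import sentencepiece as spm

    # trains a 2000-token SentencePiece model, writing m.model and m.vocab
    spm.SentencePieceTrainer.train(
        "--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000"
    )
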
2 changes: 1 addition & 1 deletion torchtune/datasets/slimorca.py
@@ -47,7 +47,7 @@ class SlimOrcaDataset(Dataset):
         **kwargs: Additional keyword arguments to pass to the SlimOrca Dataset.
 
     Keyword Arguments:
-        max_token_length (int): Maximum number of tokens in the returned input and label token id lists. This value needs to be at least 4 though it is generally set it to max sequence length accepted by the model. Default is 1024.
+        max_token_length (int): Maximum number of tokens in the returned input and label token id lists. This value needs to be at least 4 though it is generally set to max sequence length accepted by the model. Default is 1024.
 
     Raises:
         ValueError: If `max_token_length` is less than 4.
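
For illustration, a hedged usage sketch of the keyword the corrected docstring describes, assuming get_dataset forwards keyword arguments to SlimOrcaDataset as the **kwargs line above suggests (the model path is a placeholder):

    from torchtune import datasets
    from torchtune.modules.tokenizer import Tokenizer

    tokenizer = Tokenizer.from_file("<PATH_TO_SPM_MODEL>")
    # values below 4 raise ValueError; 1024 is the documented default
    dataset = datasets.get_dataset(
        "slimorca", tokenizer=tokenizer, max_token_length=1024
    )
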
