Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance lora tests with more layer and rank variations #3243

Merged
merged 18 commits into from
Mar 10, 2024

Conversation

tterrysun
Copy link
Contributor

Enhance lora tests with more layer and rank variations.

requirements.txt Outdated
@@ -8,6 +8,7 @@ transformers >= 4.38.0 # Required for Gemma.
xformers == 0.0.23.post1 # Required for CUDA 12.1.
fastapi
uvicorn[standard]
peft == 0.8.2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this is only used for testing, it should go to requirements-dev.txt

"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
]
TMP_PATH = "/mnt/local_storage/"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this :)

@@ -121,6 +124,14 @@ def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need this fixture and TMP_PATH above. Pytest already has a built in tmpdir fixture that you can use https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-fixture

# Test the functionality when layer and rank are varied
@pytest.mark.parametrize("target_modules", TARGET_MODULES_LIST)
@pytest.mark.parametrize("rank", [8, 16, 32, 64])
def test_layer_variation_functionality(target_modules, rank, tmp_path):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should just remove this test -- it is completely subsumed by test_layer_variation_verify_reference, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed!

TMP_PATH = "/mnt/local_storage/"


def get_lora_model(model_id: str, target_modules: List[str], rank: int):
Copy link
Collaborator

@pcmoritz pcmoritz Mar 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I currently don't understand this function -- what are the lora model weights that are actually applied on top of the meta-llama/Llama-2-7b-hf base model?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a default initialized lora, we use the merged one as golden reference to verify the correctness, the lora weights won't matter as long as we're using the same one

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you point to where in the docs it says it is a default LoRA and what it is? That part was not clear to me (maybe add a comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
generated_logprobs.append([
list(logprob.keys()) for logprob in outputs[0].outputs[0].logprobs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be output.outputs[0].logprobs? Otherwise you will only ever use the first prompt, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, but this is only ever using the first prompt (i.e. PROMPT[0]) if I understand this correctly -- that can't possibly be your intention. Otherwise why even include the other prompts?

Copy link
Contributor Author

@tterrysun tterrysun Mar 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes you're right, I misread the comment. fixed

@pcmoritz
Copy link
Collaborator

pcmoritz commented Mar 7, 2024

On a high level, I think these tests would be much better understandable if it was a single test that evaluates the reference and then tests against the target, instead of having two tests and hardcoding the taget.

I also find it very irritating that we only test correctness of the first token for each sequence. Is there a way to test the rest of the sequence too and still have it be robust? Other tests do similar things, right? Maybe instead of testing the predicted tokens, you can test that the logits are numerically close within some error.

@@ -21,6 +21,8 @@
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel, initialize_model_parallel)

TMP_PATH = "/mnt/local_storage/"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be removed

@tterrysun tterrysun requested a review from pcmoritz March 7, 2024 02:48
@pytest.mark.parametrize("target_modules", TARGET_MODULES_LIST)
@pytest.mark.parametrize("rank", [8, 16, 32, 64])
def test_layer_variation_correctness(tp_size, target_modules, rank, tmpdir):
if torch.cuda.device_count() < tp_size:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will make it skip in the CI -- try tp_size 1 like in test_llama.py

merged_probs = do_sample(llm, tmp_dir_lora, 1, logprobs=5, n_tokens=32)
del llm
cleanup()
shutil.rmtree(str(tmpdir))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you need to delete the temp dir, it will be better to use

with tempfile.TemporaryDirectory(delete=True) as tmpdir:
    # Code that uses temp dir here

n_tokens: int = 256):
prompts = PROMPTS
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason btw you are not setting max_tokens=n_tokens here and then skip the slicing below?

@tterrysun tterrysun marked this pull request as ready for review March 7, 2024 20:49

model = get_lora_model(MODEL_PATH, target_modules, rank)
with tempfile.TemporaryDirectory() as tmpdir:
tmp_dir_merged = os.path.join(tmpdir, "tmp_dir_merged")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a need to introduce an additional layer of tmp directories? Same above.

tokenizer=MODEL_PATH,
enable_lora=False,
max_num_seqs=16,
tensor_parallel_size=4,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be hard coded (this won't work in the CI)

@tterrysun tterrysun requested a review from pcmoritz March 9, 2024 00:23
Copy link
Collaborator

@pcmoritz pcmoritz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! There are other tests where

    if torch.cuda.device_count() < tp_size:

is causing problems, it might be better to remove that?

@richardliaw
Copy link
Collaborator

can we merge?

@simon-mo simon-mo merged commit 0bba88d into vllm-project:main Mar 10, 2024
23 checks passed
dtransposed pushed a commit to afeldman-nm/vllm that referenced this pull request Mar 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants