Add model.generate support #6

Merged · 1 commit · Oct 3, 2023

Conversation

tomaarsen (Owner) commented Oct 3, 2023

Closes #1

Hello!

Pull Request overview

  • Prevent crashes for model.generate

Details

The _update_model_kwargs_for_generation method in GenerationMixin endlessly grows the attention_mask so that it matches the length of past_key_values + 1, which is normally very reasonable. However, with attention_sinks the past_key_values are eventually capped, so the shapes stopped matching and generation crashed.

This change simply prevents the endless growth of the attention_mask, so it always matches the (capped) past_key_values + 1.
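
As a rough illustration of the idea (not the exact patch; the helper name and the max_cache_len argument are made up for this sketch), the mask gains one position for the newly generated token and is then trimmed so it never outgrows the capped cache:

import torch

def cap_attention_mask(attention_mask: torch.Tensor, max_cache_len: int) -> torch.Tensor:
    # Extend the mask by one position for the token generated this step,
    # as GenerationMixin normally does.
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )
    # With attention sinks the key/value cache stops growing, so also trim the
    # mask to the most recent max_cache_len + 1 positions to keep both in sync.
    if attention_mask.shape[-1] > max_cache_len + 1:
        attention_mask = attention_mask[:, -(max_cache_len + 1):]
    return attention_mask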

Usage

import torch
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModelForCausalLM


# model_id = "meta-llama/Llama-2-7b-hf"
# model_id = "mistralai/Mistral-7B-v0.1"
model_id = "mosaicml/mpt-7b"
# model_id = "tiiuae/falcon-7b"
# model_id = "EleutherAI/pythia-6.9b-deduped"

# Load the chosen model and corresponding tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # for efficiency:
    device_map="auto",
    torch_dtype=torch.float16,
    # `attention_sinks`-specific arguments:
    attention_sink_size=4,
    attention_sink_window_size=252,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Our input text
text = "Vaswani et al. (2017) introduced the Transformers"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

# Print tokens as they're being generated
streamer = TextStreamer(tokenizer)
generated_tokens = model.generate(
    input_ids,
    generation_config=GenerationConfig(
        # use_cache=True is required, the rest can be changed up.
        use_cache=True,
        min_new_tokens=20000,
        max_new_tokens=50000,
        penalty_alpha=0.6,
        top_k=5,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ),
    streamer=streamer,
)
# Decode the final generated text
output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
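
The streamer already prints tokens as they are generated; output_text just gives the full decoded string afterwards, for example to check how much was produced:

# Report how many tokens were generated beyond the prompt, then show the text.
num_new_tokens = generated_tokens.shape[-1] - input_ids.shape[-1]
print(f"Generated {num_new_tokens} new tokens.")
print(output_text)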
- Tom Aarsen

@sparverius

Confirmed working, even tested with a few GPTQ models! Just needed to install from git:

pip install git+https://github.com/tomaarsen/attention_sinks.git

@tomaarsen (Owner, Author)

That's awesome! I'm preparing a release now so the install is a bit easier - I'm just doing some edits on the README and CHANGELOG first :)

Thanks for helping with testing!

@tomaarsen (Owner, Author)

v0.2.2 has been released, which includes this PR.
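
With the release on PyPI, the git install above should no longer be needed; assuming the package is published under the same name, a plain pip install should work:

pip install attention_sinks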

Successfully merging this pull request may close these issues.

Trying a minimal example with LlamaForCasualLM, sadly it fails