
Inference #619

Merged: 7 commits into main on Mar 31, 2024

Conversation

@kartikayk (Contributor) commented Mar 29, 2024

Context

Our current story for inference is sub-optimal; this PR takes a first pass at fixing that. The inference code is heavily inspired by gpt-fast, though it doesn't yet include the compile-related functionality. I'll add that in a follow-up PR. This PR focuses on getting inference to actually work end to end.

Changelog

  • Update the way we set up the KV cache, including removing curr_pos, which was confusing and never made much sense to me. I've replaced it with input_pos (naming consistent with gpt-fast), which does exactly what you'd expect: it's a tensor holding the current position(s). When we first start inference, it contains the positions of all tokens in the prompt, since the K and V tensors for those need to be computed and cached. See the sketch after this list.
  • Update all affected components and tests.
  • Remove the current GenerationUtils class and the logit_transforms file, replacing them with stand-alone utilities for generation.
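
A minimal sketch of how input_pos drives prefill vs. incremental decoding. This assumes a decoder whose forward accepts input_pos and whose attention layers write K/V into the cache at those positions (as in this PR); the greedy argmax is just for illustration and is not the recipe's sampling code.

import torch

def greedy_generate(model, prompt: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    bsz, prompt_len = prompt.shape

    # Prefill: input_pos covers every prompt position, so all of their K/V
    # entries are written into the cache in a single forward pass.
    input_pos = torch.arange(0, prompt_len)
    logits = model(prompt, input_pos=input_pos)            # [bsz, prompt_len, vocab]
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

    # Incremental decode: feed one token at a time, advancing input_pos by one
    # so each step only computes (and caches) K/V for the new position.
    generated = [next_token]
    input_pos = torch.tensor([prompt_len])
    for _ in range(max_new_tokens - 1):
        logits = model(next_token, input_pos=input_pos)    # [bsz, 1, vocab]
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
        input_pos += 1
    return torch.cat(generated, dim=-1)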

Test plan

  • All unit tests and recipe tests pass
  • Output generation is sensible for a 13B model

[screenshot: sample generated output from the 13B model]

Comparison with gpt-fast: Without compile, our generation speed is on par with gpt-fast. Adding compile support is beyond the scope of this PR

gpt-fast:

[screenshot: gpt-fast generation speed]

torchtune:

[screenshot: torchtune generation speed]

pytorch-bot bot commented Mar 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/619

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 24e9ff0 with merge base 08f8235:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Mar 29, 2024
# input has shape [b, s, d]
bsz, seq_len, _ = x.shape

# self.wqkv.weight.data = torch.cat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight ])
Contributor:

Old code?

@@ -166,6 +170,8 @@ def forward(
k = self.k_proj(x)
v = self.v_proj(x)

# pdb.set_trace()
Contributor:

🙃

k_val (Tensor): New k value.
v_val (Tensor): New v value.
def update(self, input_pos, k_val, v_val) -> Tuple[Tensor, Tensor]:
# input_pos: [S], k_val: [B, H, S, D]
Contributor:

Love this comment

Contributor:

WOOOOOOOOOOO

model.setup_caches(max_batch_size=1, dtype=self._dtype)
return model

def _multinomial_sample_one_no_sync(self, probs_sort):
Contributor:

Can you leave a comment here explaining this?

model, cur_token, input_pos, temperature, top_k
)
input_pos += 1
new_tokens.append(next_token.clone())
Contributor:

Do we care about performance here? How expensive is the double clone operation?

Contributor (author):

The memory impact of this is quite minimal; these are on the order of ~300 ints.

@joecummings (Contributor):

Should we also add a test ensuring that this does in fact speed up inference? Something like running the same inference twice, once with the cache and once without, and checking the time difference?
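
A rough sketch of what such a test could look like. Everything here is illustrative: make_model, generate_fn, and prompt are assumed fixtures, generate_fn stands in for whichever generation entry point ends up public, and the assertion is deliberately loose since wall-clock timings are noisy.

import time
import torch

def time_generation(make_model, generate_fn, prompt, max_new_tokens, use_cache):
    # Build a fresh model, optionally enable the KV cache, and time one
    # generation pass end to end.
    model = make_model().eval()
    if use_cache:
        model.setup_caches(max_batch_size=1, dtype=torch.bfloat16)
    start = time.perf_counter()
    with torch.no_grad():
        generate_fn(model, prompt, max_new_tokens=max_new_tokens)
    return time.perf_counter() - start

def test_kv_cache_speeds_up_inference(make_model, generate_fn, prompt):
    no_cache = time_generation(make_model, generate_fn, prompt, 64, use_cache=False)
    cached = time_generation(make_model, generate_fn, prompt, 64, use_cache=True)
    # Assert a strict improvement rather than a specific ratio.
    assert cached < no_cache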

# Sampling via the exponential trick: argmax(probs / Exp(1)) draws an index
# with probability proportional to probs, without the device sync that
# torch.multinomial would trigger.
q = torch.empty_like(probs_sort).exponential_(1)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def _logits_to_probs(self, logits: torch.Tensor, temperature: float, top_k: int):
Contributor:

We're pretty inconsistent with where our sampling code lives. I'd be in favor of putting it all in one file.

Contributor (author):

I don't want to generalize this code prematurely. For now I expect it to stay limited to this recipe, but we can generalize it in a follow-up if that makes sense.

Contributor:

A couple of questions here:

  1. Does this obviate the need for logits_transforms.py? It seems we are now handling all of that in the recipe. (As a follow-up, if we are going to keep it, can we at least drop the LogitsTransform ABC? It's literally just a Callable[FloatTensor, FloatTensor].)
  2. I notice we are missing top-p sampling, which we otherwise have support for. Any particular reason for omitting it?

Contributor (author):

I removed the logits_transform.py file completely. As for top-p, I didn't really find a good reference implementation; maybe we can add it back if it's needed?
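
For reference, a standard top-p (nucleus) filter is just a cumulative-probability cutoff applied to the sorted logits before sampling. A minimal sketch (not part of this PR, and not torchtune API):

import torch

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Keep the smallest set of tokens whose cumulative probability reaches
    # top_p; everything else is masked to -inf before sampling.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    # Drop a token only if the tokens ranked above it already cover top_p.
    mask = (cum_probs - probs) > top_p
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    # Scatter the filtered values back to the original vocab order.
    filtered = torch.full_like(logits, float("-inf"))
    return filtered.scatter(-1, sorted_idx, sorted_logits)

The filtered logits could then go through the existing temperature / softmax / sample path unchanged.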

dtype: bf16
seed: 1234

temperature: 0.8
Contributor:

Can you leave a comment on this PR explaining how you chose these hyperparams?

path: /tmp/llama2/tokenizer.model

# Generation arguments
prompt: "Hello, my name is"
Contributor:

Maybe it would make sense to have this be a common default prompt. (Coming back to update this comment with examples)


# Model arguments
model:
_component_: torchtune.models.llama2.llama2_13b
Member:

Should we call it generate_13b?

Contributor (author):

Not really; the model itself doesn't change for inference. The inference script is responsible for setting up the caches and then calling eval() on the model to disable all of the stochastic operations.
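
Roughly, the setup pattern in the recipe looks like this (a sketch; checkpoint loading is omitted, and the dtype just mirrors the config above):

import torch
from torchtune.models.llama2 import llama2_13b

model = llama2_13b()          # same builder used for training
# ... load checkpoint weights here ...
model.setup_caches(max_batch_size=1, dtype=torch.bfloat16)  # allocate KV caches
model.eval()                  # disable dropout and other train-time behavior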

pytorch_model-00002-of-00003.bin,
pytorch_model-00003-of-00003.bin
]
recipe_checkpoint: null
Member:

Is this pretty much always going to be null? We don't really need to pass in recipe state into generate, IIUC?

Contributor (author):

Oh yeah, good point. I should remove this.


# [b, n_h, s, h_d]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Update key-value cache
if self.kv_cache is not None:
Member:

Just for my understanding, any benefit in updating kv-cache post transpose? Does it enable better access patterns or the like?

Contributor (author):

It's a good question - let me think a bit about it

 if incremental_decode:
-    outputs = self.decoder_lm(input_ids, curr_pos=prev_pos)
+    outputs = self.decoder_lm(input_ids, input_pos=input_pos)
 else:
     outputs = self.decoder_lm(input_ids)
 if logits_accessor:
Contributor:

I know it's not part of these changes, but I find the name "accessor" a bit confusing tbh. Is the idea here to do some sort of slicing along the vocab dim or something? What is the difference between this and logits_transforms? Is it just that one is applied before softmax and the other is applied after?

Contributor:

OK coming back to this now I'm a bit confused (prob should have read the full recipe code first). Why isn't this file deleted entirely? Seems you've covered all the generation logic in the recipe already. Or am I missing something?

Contributor (author):

No particular reason; I was leaving this as a follow-up. Let me think about it; maybe I should fold that change into this PR.

temperature: float,
top_k: int,
) -> torch.Tensor:
# input_pos: [B, S]
Contributor:

One thing that's confusing to me.. do we support batch generation? The way we pass the prompt seems to indicate that we don't, but then comments like this would indicate that we do.

Contributor (author):

There's nothing fundamentally stopping multi-sample generation; the inference recipe just doesn't currently handle it.

-    def forward(
-        self, tokens: Tensor, mask: Optional[Tensor] = None, curr_pos: int = 0
-    ) -> Tensor:
+    def setup_caches(self, max_batch_size: int, dtype: torch.dtype) -> None:
Contributor:

I'm curious about the choice to define this as a method on the TransformerDecoder class (especially with the addition of mandatory max_seq_len, num_heads, head_dim which are only needed for inference). Why not just define a utility method, e.g.

def setup_caches(
    model: TransformerDecoder,
    max_batch_size: int,
    max_seq_len: int,
    num_heads: int,
    head_dim: int,
):
    # same body as setup_caches but with self -> model

Then optionally call this on the builder gated behind a flag. The slight drawback is that we do add one extra param to builders so it's not a 1:1 swap on the config side as it is now (maybe there's a way to do that though?). But honestly I think that's OK? Like if we are changing the model arch a bit for inference (which we are) there's no harm in being explicit about that.

I think the benefit is that anyone using TransformerDecoder directly (which should be a decent percentage of the people adding new models) does not have to worry about these extra params that are really only relevant for inference.

Contributor (author):

This is a really interesting point.

I think you touch on this a bit. The reason I think having a method on the class makes sense is that the behavior of forward changes depending on whether you're calling it during training or during inference, and I'd rather the model handle this based on its internal state (e.g. causal_mask). If this were a utility, you'd also need to pass the causal_mask around explicitly on every call, when it can easily be handled by the class itself. I also think having the decoder's forward take just the input and the corresponding positions is nicer than explicitly passing the mask (though that might be a bit of a personal preference).

I'm actually not too concerned about num_heads, head_dim, and max_seq_len, because the model already has this information in its components; it's related to the model, isn't it? In fact, I think this is strictly better than passing in something like max_batch_size, which was only ever used during inference :) Let me know if that makes sense.
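
To illustrate the pattern being described (internal state set up once by setup_caches, then consumed by forward so callers never pass a mask), here is a minimal toy sketch; it is not the actual torchtune implementation:

import torch
from torch import nn, Tensor
from typing import Optional

class TinyDecoder(nn.Module):
    def __init__(self, max_seq_len: int = 16):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.causal_mask: Optional[Tensor] = None  # populated by setup_caches

    def setup_caches(self, max_batch_size: int, dtype: torch.dtype) -> None:
        # The real model would also allocate per-layer KV caches here, using
        # the num_heads / head_dim / max_seq_len it already knows about.
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool)
        )

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        if self.causal_mask is not None and input_pos is not None:
            # Inference: slice the precomputed mask for the positions being
            # processed; the caller only supplies tokens and input_pos.
            mask = self.causal_mask[input_pos]
        else:
            # Training: no cache, rely on the usual causal attention path.
            mask = None
        # ... attention layers would consume tokens, mask, and input_pos ...
        return tokens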

return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def _logits_to_probs(self, logits: torch.Tensor, temperature: float, top_k: int):
logits = logits / max(temperature, 1e-5)
Contributor:

nit: maybe just add a value check on temperature somewhere?
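
For instance, a simple guard along these lines (illustrative; exact placement TBD):

if temperature <= 0.0:
    raise ValueError(f"temperature must be positive, got {temperature}")

The max(temperature, 1e-5) clamp above already prevents division by zero at sampling time, so a check like this would mainly surface negative or accidentally-zero configs earlier.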

@@ -103,7 +94,7 @@ def llama2(
             v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
             output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
             pos_embeddings=rope,
-            kv_cache=kv_cache,
+            kv_cache=None,
Contributor:

Is this a useless arg now? Seems if we are building KV caches at the decoder level it will always be none?

Contributor (author):

Hmm, yeah, it defaults to None, so I don't think it needs to be passed in this call.

Comment on lines +51 to +52
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
Contributor:

If I understand correctly, is it now the case that max_batch_size really just means batch size? If so, maybe we should update its name

Contributor (author):

I don't know why it was ever called max_batch_size, but this seems to be the convention. Let me do some more research on this.

Comment on lines +20 to +22
num_heads (int): number of heads. We take num_heads instead of num_kv_heads because
the cache is created after we've expanded the key and value tensors to have the
same shape as the query tensor. See attention.py for more details
Contributor:

So this makes sense to me, but given that you didn't really change what we were doing on the attention side of things, how was this working properly before?

Contributor (author):

I think this is a case where we haven't actually run inference with num_kv_heads != num_heads. If we had, this would have broken.
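
To make the docstring concrete, here is a small shape-only sketch of the expansion it refers to (illustrative, not the exact attention.py code): with grouped-query attention, each KV head is repeated so k and v match the query tensor's head count before the cache update.

import torch

bsz, seq_len = 1, 8
num_heads, num_kv_heads, head_dim = 32, 8, 128

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_kv_heads, seq_len, head_dim)
v = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

# Repeat each KV head num_heads // num_kv_heads times so k/v line up with q.
q_per_kv = num_heads // num_kv_heads
k = k.repeat_interleave(q_per_kv, dim=1)   # -> [bsz, num_heads, seq_len, head_dim]
v = v.repeat_interleave(q_per_kv, dim=1)

Only these expanded tensors get written into the [B, H, S, D] cache, which is why the cache is sized with num_heads rather than num_kv_heads.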

t0 = time.perf_counter()
generated_tokens, _ = self._decode_n_tokens(
self._model,
next_token.view(1, -1),
Contributor:

What is the view actually doing here? Isn't this just a single token?

Contributor:

Follow-up, and just a general comment on a bunch of these methods: it'd be nice to add docstrings with examples for each one (e.g. given a prompt of length n, prefill fills the KV cache's first n entries and returns token n+1; something similar for _decode_one_token and/or _decode_n_tokens).

Contributor (author):

The view converts this into the [bsz, seq_len] form that the transformer expects; for a single token that's a [1, 1] tensor.
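
Concretely (shapes only; the token id is made up):

import torch

next_token = torch.tensor([42])     # a single sampled token id, shape [1]
batched = next_token.view(1, -1)    # shape [1, 1], i.e. [bsz=1, seq_len=1]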

Comment on lines 114 to 116
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_mem_efficient=False, enable_math=True
):
Contributor:

I assume this is for performance reasons? Might be good to add a bit of detail on why we're doing this though (is it because the math implementation is faster when seq len and batch size are small?)

top_k: int,
):
new_tokens, new_probs = [], []
for i in range(num_new_tokens):
Contributor:

Do we always generate a fixed # of tokens? What if e.g. we get EOS before then?

Contributor (author):

Currently we don't respect the EOS token. Let me add that functionality
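
A minimal sketch of what respecting EOS inside the decode loop could look like (decode_one_token, eos_id, and max_new_tokens are illustrative stand-ins, not the exact recipe API, and this assumes batch size 1):

new_tokens = []
for _ in range(max_new_tokens):
    next_token, _ = decode_one_token(model, cur_token, input_pos, temperature, top_k)
    input_pos += 1
    new_tokens.append(next_token.clone())
    if next_token.item() == eos_id:
        # Stop as soon as the model emits the end-of-sequence token.
        break
    cur_token = next_token.view(1, -1)

Prefill would stay unchanged; only the decode loop needs the check.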

logger = utils.get_logger("DEBUG")


class InferenceRecipe:
Contributor:

Can probably add documentation here around what's supported and what's not (e.g. explicitly call out we don't have speculative decoding, we support temperature and top-k sampling, etc.)

@@ -22,6 +22,7 @@
     validate_no_params_on_meta_device,
     wrap_fsdp,
 )
+from ._generation import generate  # noqa
Contributor:

nit: why do you need the #noqa?

Contributor:

Please, someone add an analogous ignore to our .flake8 so we can delete all these godforsaken noqas in our __init__ files.

Contributor (author):

I know you've been asking for this @ebsmothers! I'm going to be a pain and punt this to a follow up!

logger.info(
f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
)
logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
Contributor:

For my own understanding, why are you including this information?

Contributor (author):

This seems to be a standard for every generation tool. I added it for my own debugging, but then realized it might be useful for folks running this recipe as well. Anything look out of place?

@joecummings (Contributor) left a comment:

kewl!

@kartikayk merged commit c543a5b into main on Mar 31, 2024
20 checks passed
@kartikayk deleted the inference branch on Mar 31, 2024 at 23:52
tcapelle pushed a commit to tcapelle/torchtune that referenced this pull request on Apr 5, 2024