From ba1530ec35b64f6e01e3d96af2d9ae50884ce790 Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Mon, 28 Oct 2024 16:12:15 -0700
Subject: [PATCH 1/4] Update

[ghstack-poisoned]
---
 examples/models/llama/llama_transformer.py | 197 +++++++++++----------
 examples/models/llama/rope.py              |   2 +-
 2 files changed, 102 insertions(+), 97 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 76e8730328b..e8d865161b2 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -143,16 +143,73 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
 
+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Optional[torch.LongTensor] = None):
+        if self.params.use_kv_cache:
+            assert (
+                    input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, use of the .item results in
+                # symints, due to querying the data from tensor.
+                # this path avoids that for mps backend, although probably mps backend
+                # can support dynamic shape?
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
-        self,
-        max_batch_size: int,
-        max_seq_length: int,
-        n_heads: int,
-        head_dim: int,
-        transpose_cache: bool,
-        enable_dynamic_shape: bool,
-        dtype=torch.float32,
+            self,
+            max_batch_size: int,
+            max_seq_length: int,
+            n_heads: int,
+            head_dim: int,
+            transpose_cache: bool,
+            enable_dynamic_shape: bool,
+            dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
@@ -175,7 +232,7 @@ def __init__(
         )
 
     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+            self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
@@ -213,13 +270,13 @@ def update(
 
 class SDPA(nn.Module):
     def __init__(
-        self,
-        kv_cache: KVCache,
-        dim: int,
-        head_dim: int,
-        n_rep: int,
-        max_seq_len: int,
-        enable_dynamic_shape: bool,
+            self,
+            kv_cache: KVCache,
+            dim: int,
+            head_dim: int,
+            n_rep: int,
+            max_seq_len: int,
+            enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
@@ -230,14 +287,14 @@ def __init__(
         self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
-        bsz,
-        seqlen,
-        mask: torch.Tensor,
+            self,
+            input_pos: torch.Tensor,
+            q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
+            k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
+            v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+            bsz,
+            seqlen,
+            mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
@@ -262,7 +319,7 @@ def forward(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +341,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
 
         self.layer_id = layer_id
 
+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +359,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,
+                # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
         self.SDPA = SDPA(
@@ -311,17 +371,11 @@ def __init__(self, args: ModelArgs, layer_id: int):
             max_seq_len=self.max_seq_len,
             enable_dynamic_shape=args.enable_dynamic_shape,
         )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
+            self,
+            x: torch.Tensor,
+            input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
 
@@ -333,7 +387,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)
 
         if self.use_kv_cache:
             assert input_pos is not None
@@ -421,13 +475,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -456,9 +510,10 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
 
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
@@ -466,31 +521,14 @@ def __init__(self, params: ModelArgs):
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
     def forward(
-        self,
-        tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
-        h: Optional[torch.FloatTensor] = None,  # embeddings
+            self,
+            tokens: Optional[torch.LongTensor] = None,  # tokens
+            input_pos: Optional[
+                torch.LongTensor
+            ] = None,  # Scalar tensor indicating size of window of the caches
+            h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
@@ -498,42 +536,9 @@ def forward(
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
 
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 0383c798988..97f28126a91 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -159,4 +159,4 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    return q_embed, k_embed
\ No newline at end of file

From 411fb82f0a2df3aedd5187f2d61e7af12045b133 Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Mon, 28 Oct 2024 16:13:09 -0700
Subject: [PATCH 2/4] Update

[ghstack-poisoned]
---
 examples/models/llama/rope.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 97f28126a91..0383c798988 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -159,4 +159,4 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
\ No newline at end of file
+    return q_embed, k_embed

From 8792a4db460296fa58180f749bc0bd21968961cd Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Mon, 28 Oct 2024 16:17:42 -0700
Subject: [PATCH 3/4] Update

[ghstack-poisoned]
---
 examples/models/llama/llama_transformer.py | 76 ++++++++++++----------
 1 file changed, 41 insertions(+), 35 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index e8d865161b2..b976d64f0ba 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -169,10 +169,16 @@ def __init__(self, params: ModelArgs):
         else:
             self.apply_rotary_emb = RotaryEmbedding()
 
-    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Optional[torch.LongTensor] = None):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.LongTensor] = None,
+    ):
         if self.params.use_kv_cache:
             assert (
-                    input_pos is not None
+                input_pos is not None
             ), "input_pos must be provided when use_kv_cache is True"
 
             if self.params.enable_dynamic_shape:
@@ -202,14 +208,14 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Opt
 
 class KVCache(nn.Module):
     def __init__(
-            self,
-            max_batch_size: int,
-            max_seq_length: int,
-            n_heads: int,
-            head_dim: int,
-            transpose_cache: bool,
-            enable_dynamic_shape: bool,
-            dtype=torch.float32,
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        transpose_cache: bool,
+        enable_dynamic_shape: bool,
+        dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
@@ -232,7 +238,7 @@ def __init__(
         )
 
     def update(
-            self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
@@ -270,13 +276,13 @@ def update(
 
 class SDPA(nn.Module):
     def __init__(
-            self,
-            kv_cache: KVCache,
-            dim: int,
-            head_dim: int,
-            n_rep: int,
-            max_seq_len: int,
-            enable_dynamic_shape: bool,
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+        max_seq_len: int,
+        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
@@ -287,14 +293,14 @@ def __init__(
         self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
-            self,
-            input_pos: torch.Tensor,
-            q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-            k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-            v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
-            bsz,
-            seqlen,
-            mask: torch.Tensor,
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
+        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        bsz,
+        seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
@@ -373,9 +379,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         )
 
     def forward(
-            self,
-            x: torch.Tensor,
-            input_pos: Optional[torch.Tensor] = None,
+        self,
+        x: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
 
@@ -523,12 +529,12 @@ def __init__(self, params: ModelArgs):
         self.output_prune_map = params.output_prune_map
 
     def forward(
-            self,
-            tokens: Optional[torch.LongTensor] = None,  # tokens
-            input_pos: Optional[
-                torch.LongTensor
-            ] = None,  # Scalar tensor indicating size of window of the caches
-            h: Optional[torch.FloatTensor] = None,  # embeddings
+        self,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
+        input_pos: Optional[
+            torch.LongTensor
+        ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
         if (tokens is None) ^ (h is not None):
             raise ValueError(

From 57fb2e2edf0fe22c776ee3cdef6508f0afacc8cb Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Mon, 28 Oct 2024 16:28:51 -0700
Subject: [PATCH 4/4] Update

[ghstack-poisoned]
---
 examples/models/llama/llama_transformer.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index b976d64f0ba..2a1970d5f99 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -365,8 +365,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,
-                # if we are using the custom op don't transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
         self.SDPA = SDPA(
@@ -495,10 +494,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, input_pos=None):  # x: 1xN
+        h = self.attention.forward(self.attention_norm(x), input_pos)
 
         h = x + h
         if hasattr(self, "block_sparse_moe"):