diff --git a/torchchat/model.py b/torchchat/model.py
index 1c78d4c63..f50d2a8be 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -287,6 +287,11 @@ class TransformerArgs:
     feed_forward_bias: bool = False
     # Whether or not to tie the input word embeddings to the output
     tie_word_embeddings: bool = False
+    # Granite architecture multipliers
+    embedding_multiplier: Optional[float] = None
+    attention_multiplier: Optional[float] = None
+    residual_multiplier: Optional[float] = None
+    logits_scaling: Optional[float] = None

     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -723,6 +728,10 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
         if self.tok_embeddings:
             x = self.tok_embeddings(x)

+        # For Granite architectures
+        if self.config.embedding_multiplier:
+            x = x * self.config.embedding_multiplier
+
         for _, layer in self.layers.items():
             x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)

@@ -730,6 +739,9 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
             x = self.norm(x)
         if self.output:
             x = self.output(x)
+        # For granite architectures
+        if self.config.logits_scaling:
+            x = x / self.config.logits_scaling
         # print(f"output shape: {x.shape}")
         return x

@@ -741,6 +753,12 @@ def __init__(self, config: TransformerArgs) -> None:
         self.feed_forward = FeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        # None for llama architecture, set for granite architectures
+        self.residual_multiplier = (
+            config.residual_multiplier
+            if config.residual_multiplier is not None
+            else 1.0
+        )

     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
@@ -751,8 +769,8 @@ def forward(
     ) -> Tensor:
         h = x + self.attention(
             self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
-        )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        ) * self.residual_multiplier
+        out = h + self.feed_forward(self.ffn_norm(h)) * self.residual_multiplier
         return out

@@ -779,6 +797,7 @@ def __init__(self, config: TransformerArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.attention_scale = config.attention_multiplier
         self._register_load_state_dict_pre_hook(self.load_hook)

     def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
@@ -875,7 +894,16 @@ def forward(
         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)

-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(
+            query=q,
+            key=k,
+            value=v,
+            attn_mask=mask,
+            dropout_p=0.0,
+            # This is None (default) for llama architecture and set for granite
+            # architectures
+            scale=self.attention_scale,
+        )

         # -1 = self.dim
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
diff --git a/torchchat/model_config/models.json b/torchchat/model_config/models.json
index 8791601fb..d2252e6dd 100644
--- a/torchchat/model_config/models.json
+++ b/torchchat/model_config/models.json
@@ -178,5 +178,33 @@
     "distribution_path": "ibm-granite/granite-8b-code-instruct-128k",
     "transformer_params_key": "Granite-8B-Code",
     "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-3.0-2b-instruct": {
+    "aliases": ["granite3-2b", "granite3"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3.0-2b-instruct",
+    "transformer_params_key": "Granite-3.0-2B-Instruct",
+    "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-3.0-8b-instruct": {
+    "aliases": ["granite3-8b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3.0-8b-instruct",
+    "transformer_params_key": "Granite-3.0-8B-Instruct",
+    "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-3.1-2b-instruct": {
+    "aliases": ["granite3.1-2b", "granite3.1"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3.1-2b-instruct",
+    "transformer_params_key": "Granite-3.1-2B-Instruct",
+    "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-3.1-8b-instruct": {
+    "aliases": ["granite3.1-8b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3.1-8b-instruct",
+    "transformer_params_key": "Granite-3.1-8B-Instruct",
+    "tokenizer_file": "tokenizer.json"
   }
 }
diff --git a/torchchat/model_params/Granite-3.0-2B-Instruct.json b/torchchat/model_params/Granite-3.0-2B-Instruct.json
new file mode 100644
index 000000000..1e9779cb3
--- /dev/null
+++ b/torchchat/model_params/Granite-3.0-2B-Instruct.json
@@ -0,0 +1,21 @@
+{
+  "block_size": 8192,
+  "dim": 2048,
+  "hidden_dim": 8192,
+  "n_heads": 32,
+  "n_local_heads": 8,
+  "n_layers": 40,
+  "rope_base": 10000,
+  "vocab_size": 49155,
+  "use_hf_tokenizer": true,
+  "tokenizer_prepend_bos": false,
+  "norm_eps": 0.00001,
+  "rope_scaling": null,
+  "attention_bias": false,
+  "feed_forward_bias": false,
+  "tie_word_embeddings": true,
+  "embedding_multiplier": 12.0,
+  "attention_multiplier": 0.015625,
+  "residual_multiplier": 0.22,
+  "logits_scaling": 8.0
+}
diff --git a/torchchat/model_params/Granite-3.0-8B-Instruct.json b/torchchat/model_params/Granite-3.0-8B-Instruct.json
new file mode 100644
index 000000000..35db0f90d
--- /dev/null
+++ b/torchchat/model_params/Granite-3.0-8B-Instruct.json
@@ -0,0 +1,20 @@
+{
+  "attention_multiplier": 0.0078125,
+  "embedding_multiplier": 12.0,
+  "dim": 4096,
+  "block_size": 12800,
+  "hidden_dim": 12800,
+  "logits_scaling": 16.0,
+  "n_heads": 32,
+  "n_layers": 40,
+  "n_local_heads": 8,
+  "residual_multiplier": 0.22,
+  "norm_eps": 1e-05,
+  "rope_base": 10000,
+  "tie_word_embeddings": true,
+  "vocab_size": 49155,
+  "use_hf_tokenizer": true,
+  "tokenizer_prepend_bos": false,
+  "attention_bias": false,
+  "feed_forward_bias": false
+}
diff --git a/torchchat/model_params/Granite-3.1-2B-Instruct.json b/torchchat/model_params/Granite-3.1-2B-Instruct.json
new file mode 100644
index 000000000..1e82036ab
--- /dev/null
+++ b/torchchat/model_params/Granite-3.1-2B-Instruct.json
@@ -0,0 +1,20 @@
+{
+  "attention_multiplier": 0.015625,
+  "embedding_multiplier": 12.0,
+  "dim": 2048,
+  "block_size": 8192,
+  "hidden_dim": 8192,
+  "logits_scaling": 8.0,
+  "n_heads": 32,
+  "n_layers": 40,
+  "n_local_heads": 8,
+  "residual_multiplier": 0.22,
+  "norm_eps": 1e-05,
+  "rope_base": 5000000.0,
+  "tie_word_embeddings": true,
+  "vocab_size": 49155,
+  "use_hf_tokenizer": true,
+  "tokenizer_prepend_bos": false,
+  "attention_bias": false,
+  "feed_forward_bias": false
+}
diff --git a/torchchat/model_params/Granite-3.1-8B-Instruct.json b/torchchat/model_params/Granite-3.1-8B-Instruct.json
new file mode 100644
index 000000000..646340580
--- /dev/null
+++ b/torchchat/model_params/Granite-3.1-8B-Instruct.json
@@ -0,0 +1,20 @@
+{
+  "attention_multiplier": 0.0078125,
+  "embedding_multiplier": 12.0,
+  "dim": 4096,
+  "block_size": 12800,
+  "hidden_dim": 12800,
+  "logits_scaling": 16.0,
+  "n_heads": 32,
+  "n_layers": 40,
+  "n_local_heads": 8,
+  "residual_multiplier": 0.22,
+  "norm_eps": 1e-05,
+  "rope_base": 10000000.0,
+  "tie_word_embeddings": true,
+  "vocab_size": 49155,
+  "use_hf_tokenizer": true,
+  "tokenizer_prepend_bos": false,
+  "attention_bias": false,
+  "feed_forward_bias": false
+}