34 changes: 31 additions & 3 deletions torchchat/model.py
@@ -287,6 +287,11 @@ class TransformerArgs:
feed_forward_bias: bool = False
# Whether or not to tie the input word embeddings to the output
tie_word_embeddings: bool = False
# Granite architecture multipliers
embedding_multiplier: Optional[float] = None
attention_multiplier: Optional[float] = None
residual_multiplier: Optional[float] = None
logits_scaling: Optional[float] = None

def __post_init__(self):
if self.n_local_heads == -1:
@@ -723,13 +728,20 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
if self.tok_embeddings:
x = self.tok_embeddings(x)

# For Granite architectures
if self.config.embedding_multiplier:
x = x * self.config.embedding_multiplier

for _, layer in self.layers.items():
x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)

if self.norm:
x = self.norm(x)
if self.output:
x = self.output(x)
# For granite architectures
if self.config.logits_scaling:
x = x / self.config.logits_scaling
# print(f"output shape: {x.shape}")
return x
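
A minimal standalone sketch of the two scalings added above (plain tensors only, with the 12.0 and 8.0 values assumed from the Granite-3.0-2B params file below; this is not the torchchat code path): a `None` multiplier leaves the Llama behavior untouched, while the Granite values scale embeddings up by 12.0 and divide the logits by 8.0.

```python
from typing import Optional

import torch

def scale_like_granite(
    x: torch.Tensor,
    embedding_multiplier: Optional[float] = None,
    logits_scaling: Optional[float] = None,
) -> torch.Tensor:
    # Mirrors the truthiness checks above: None (the Llama default) is a no-op.
    if embedding_multiplier:
        x = x * embedding_multiplier
    # ... transformer layers and the final norm would run here ...
    if logits_scaling:
        x = x / logits_scaling
    return x

x = torch.ones(2, 3)
print(scale_like_granite(x))               # unchanged -> all 1.0 (Llama path)
print(scale_like_granite(x, 12.0, 8.0))    # 1.0 * 12.0 / 8.0 -> all 1.5
```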

@@ -741,6 +753,12 @@ def __init__(self, config: TransformerArgs) -> None:
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
# None for llama architecture, set for granite architectures
self.residual_multiplier = (
config.residual_multiplier
if config.residual_multiplier is not None
else 1.0
)

def distribute(self, device_mesh: DeviceMesh):
self.attention.distribute(device_mesh)
@@ -751,8 +769,8 @@ def forward(
) -> Tensor:
h = x + self.attention(
self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
- )
- out = h + self.feed_forward(self.ffn_norm(h))
+ ) * self.residual_multiplier
+ out = h + self.feed_forward(self.ffn_norm(h)) * self.residual_multiplier
return out
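
Put as an equation (my reading of the hunk above, not wording from the PR): with residual multiplier α the block computes `h = x + α * attn(norm(x))` followed by `out = h + α * ffn(norm(h))`. α defaults to 1.0, which is exactly the stock Llama block, and the Granite params files below set α = 0.22.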


@@ -779,6 +797,7 @@ def __init__(self, config: TransformerArgs):
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self.attention_scale = config.attention_multiplier
self._register_load_state_dict_pre_hook(self.load_hook)

def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
@@ -875,7 +894,16 @@ def forward(

k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
- y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+ y = F.scaled_dot_product_attention(
+     query=q,
+     key=k,
+     value=v,
+     attn_mask=mask,
+     dropout_p=0.0,
+     # This is None (default) for llama architecture and set for granite
+     # architectures
+     scale=self.attention_scale,
+ )

# -1 = self.dim
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
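
One detail worth spelling out (my note; head_dim assumed to be dim // n_heads as elsewhere in torchchat): when `scale` is `None`, `F.scaled_dot_product_attention` uses its default of `1 / sqrt(head_dim)`, so the Llama path is unchanged, while Granite passes `attention_multiplier`, which for these checkpoints works out to `1 / head_dim` (0.015625 = 1/64 for the 2B, 0.0078125 = 1/128 for the 8B). A minimal check:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, n_heads, seqlen, head_dim = 1, 2, 4, 64
q, k, v = (torch.randn(bsz, n_heads, seqlen, head_dim) for _ in range(3))

# scale=None (the default) means 1/sqrt(head_dim) -- the usual Llama scaling.
default = F.scaled_dot_product_attention(q, k, v)
explicit = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(head_dim))
print(torch.allclose(default, explicit))  # True

# Granite-3.0-2B instead passes attention_multiplier = 0.015625 = 1/64 = 1/head_dim.
granite = F.scaled_dot_product_attention(q, k, v, scale=0.015625)
print(torch.allclose(default, granite))   # False -- a genuinely different scaling
```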
28 changes: 28 additions & 0 deletions torchchat/model_config/models.json
@@ -178,5 +178,33 @@
"distribution_path": "ibm-granite/granite-8b-code-instruct-128k",
"transformer_params_key": "Granite-8B-Code",
"tokenizer_file": "tokenizer.json"
},
"ibm-granite/granite-3.0-2b-instruct": {
"aliases": ["granite3-2b", "granite3"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "ibm-granite/granite-3.0-2b-instruct",
"transformer_params_key": "Granite-3.0-2B-Instruct",
"tokenizer_file": "tokenizer.json"
},
"ibm-granite/granite-3.0-8b-instruct": {
"aliases": ["granite3-8b"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "ibm-granite/granite-3.0-8b-instruct",
"transformer_params_key": "Granite-3.0-8B-Instruct",
"tokenizer_file": "tokenizer.json"
},
"ibm-granite/granite-3.1-2b-instruct": {
"aliases": ["granite3.1-2b", "granite3.1"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "ibm-granite/granite-3.1-2b-instruct",
"transformer_params_key": "Granite-3.1-2B-Instruct",
"tokenizer_file": "tokenizer.json"
},
"ibm-granite/granite-3.1-8b-instruct": {
"aliases": ["granite3.1-8b"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "ibm-granite/granite-3.1-8b-instruct",
"transformer_params_key": "Granite-3.1-8B-Instruct",
"tokenizer_file": "tokenizer.json"
}
}
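
Not part of this PR, but a quick way to confirm that every `transformer_params_key` referenced here (including the four new Granite entries) resolves to a params file added under `torchchat/model_params/` -- paths assumed from this diff, run from the repo root:

```python
import json
from pathlib import Path

models = json.loads(Path("torchchat/model_config/models.json").read_text())
for name, cfg in models.items():
    key = cfg.get("transformer_params_key")
    if key is None:
        continue  # entries without an explicit key are resolved elsewhere
    params = Path("torchchat/model_params") / f"{key}.json"
    print(f"{'ok' if params.exists() else 'MISSING':7s} {name} -> {params}")
```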
21 changes: 21 additions & 0 deletions torchchat/model_params/Granite-3.0-2B-Instruct.json
@@ -0,0 +1,21 @@
{
"block_size": 8192,
"dim": 2048,
"hidden_dim": 8192,
"n_heads": 32,
"n_local_heads": 8,
"n_layers": 40,
"rope_base": 10000,
"vocab_size": 49155,
"use_hf_tokenizer": true,
"tokenizer_prepend_bos": false,
"norm_eps": 0.00001,
"rope_scaling": null,
"attention_bias": false,
"feed_forward_bias": false,
"tie_word_embeddings": true,
"embedding_multiplier": 12.0,
"attention_multiplier": 0.015625,
"residual_multiplier": 0.22,
"logits_scaling": 8.0
}
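
A couple of derived numbers from this file (my arithmetic, not part of the PR): with n_heads = 32 and n_local_heads = 8 this is a grouped-query-attention config, so the `repeat_interleave(self.n_heads // self.n_local_heads, dim=1)` in the attention hunk above expands each of the 8 KV heads across 32 // 8 = 4 query heads; and with dim = 2048 the per-head dimension is 2048 // 32 = 64, consistent with attention_multiplier = 1/64.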
20 changes: 20 additions & 0 deletions torchchat/model_params/Granite-3.0-8B-Instruct.json
@@ -0,0 +1,20 @@
{
"attention_multiplier": 0.0078125,
"embedding_multiplier": 12.0,
"dim": 4096,
"block_size": 12800,
"hidden_dim": 12800,
"logits_scaling": 16.0,
"n_heads": 32,
"n_layers": 40,
"n_local_heads": 8,
"residual_multiplier": 0.22,
"norm_eps": 1e-05,
"rope_base": 10000,
"tie_word_embeddings": true,
"vocab_size": 49155,
"use_hf_tokenizer": true,
"tokenizer_prepend_bos": false,
"attention_bias": false,
"feed_forward_bias": false
}
20 changes: 20 additions & 0 deletions torchchat/model_params/Granite-3.1-2B-Instruct.json
@@ -0,0 +1,20 @@
{
"attention_multiplier": 0.015625,
"embedding_multiplier": 12.0,
"dim": 2048,
"block_size": 8192,
"hidden_dim": 8192,
"logits_scaling": 8.0,
"n_heads": 32,
"n_layers": 40,
"n_local_heads": 8,
"residual_multiplier": 0.22,
"norm_eps": 1e-05,
"rope_base": 5000000.0,
"tie_word_embeddings": true,
"vocab_size": 49155,
"use_hf_tokenizer": true,
"tokenizer_prepend_bos": false,
"attention_bias": false,
"feed_forward_bias": false
}
20 changes: 20 additions & 0 deletions torchchat/model_params/Granite-3.1-8B-Instruct.json
@@ -0,0 +1,20 @@
{
"attention_multiplier": 0.0078125,
"embedding_multiplier": 12.0,
"dim": 4096,
"block_size": 12800,
"hidden_dim": 12800,
"logits_scaling": 16.0,
"n_heads": 32,
"n_layers": 40,
"n_local_heads": 8,
"residual_multiplier": 0.22,
"norm_eps": 1e-05,
"rope_base": 10000000.0,
"tie_word_embeddings": true,
"vocab_size": 49155,
"use_hf_tokenizer": true,
"tokenizer_prepend_bos": false,
"attention_bias": false,
"feed_forward_bias": false
}
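
Comparing against the 3.0 files above, the only substantive change in the two 3.1 params files is rope_base (5000000.0 for the 2B and 10000000.0 for the 8B, versus 10000), the usual lever for longer-context checkpoints; the multipliers, dimensions, and tokenizer settings are unchanged.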