Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented automated broadcasting in weight rescale when number of model shards is fewer than number of experts #265

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jacobthebanana
Copy link

@jacobthebanana jacobthebanana commented Mar 21, 2024

There are use cases where one might want to shard the quantized model across fewer devices than the number of experts. However, doing so would result in a shape mismatch when attempting to re-scale the weights back to bfloat16 during inference. For example, when generating using the grok-1 weights across one model-parallel shard, one would observe a TypeError from jax.numpy:

jax/_src/numpy/ufuncs.py:100: 
TypeError: mul got incompatible shapes for broadcasting: (8, 32768, 6144), (8, 8, 6144).

Below is a brief analysis of why this exception happened, and how this PR would address this issue.

In the quantized weights released with this repo, each tensor is represented as a QuantizedWeight8bit. For example, the w parameter of decoder_layer_0/moe/linear consists of:

  • a weight tensor of shape (8, 32768, 6144) and dtype int8, and
  • a scales tensor of shape (8, 8, 6144) and dtype bfloat16.

The modelling code in the grok-1 repo leverages the jax.experimental.shard_map.shard_map decorator to ensure that re-scaling the weight matrix does not require cross-device communication. Specifically, moe_slow_matmul1 and moe_slow_matmul2 are wrapped in the shard_map decorator to handle parameters from one expert at a time. Note that as seen in the example above, scales is not directly broadcastable to weight when computing weight = weight * scale. Rather, shard_map would partition weight into eight (8, 4096, 6144) blocks and scales into eight (8, 1, 6144) blocks before supplying the partitioned tensors to moe_slow_matmul1. Each block of the scales tensor is then broadcasted along axis=1 of the corresponding block of the weight tensor.

This approach works as expected as long as each model-parallel partition contains exactly one expert. However, when partitioning the pretrained model across fewer devices than experts, the input to moe_slow_matmul1 would no longer be broadcastable. For example, when running a total of 4 devices for 2 experts per device, the tensors supplied to moe_slow_matmul1 would be of shape (8, 8192, 6144) for weights and (8, 2, 6144) for scales. Promptly, jax.numpy would complain about how the two inputs cannot be broadcasted in the multiplication weight = weight * shape (source).

This PR proposes a workaround that reshape the tensors prior to re-scaling. Since the proposed changes are wrapped entirely inside the shard_map decorator, the proposed reshape logic will not require communication between devices. When the number of experts matches the number of model parallelism shards, the proposed behavior would be equivalent to that of the original reference implementation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant