<a href="https://colab.research.google.com/github/selectwait/colab/blob/main/fuse_quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Import and Load GPT-OSS one MLP Tensor
import os

import gdown
import numpy as np

MXFP4_TENSOR_LINK = 'https://drive.google.com/uc?id=1EMCfy_FWfkpICZ7j6oeINsiRS7jAUoqe'
tensor_npz = 'tensor.npz'

# The archive is ~521 MB: only fetch it when no local copy exists, so that
# re-running the notebook (Restart & Run All) does not download it again.
# Delete tensor.npz to force a fresh download.
if not os.path.exists(tensor_npz):
  gdown.download(MXFP4_TENSOR_LINK, tensor_npz, quiet=False)

# np.load on an .npz returns an NpzFile that keeps the archive open; use it as a
# context manager so the handle is closed once both arrays are read into memory.
with np.load(tensor_npz) as mlp_tensors:
  mlp_weight, mlp_scale = mlp_tensors['arr_0'], mlp_tensors['arr_1']
print(f'weight tensor: {mlp_weight.shape}, scale tensor: {mlp_scale.shape}')

Downloading...
From (original): https://drive.google.com/uc?id=1EMCfy_FWfkpICZ7j6oeINsiRS7jAUoqe
From (redirected): https://drive.google.com/uc?id=1EMCfy_FWfkpICZ7j6oeINsiRS7jAUoqe&confirm=t&uuid=ff97a624-0a66-466e-8e7e-1a4fa0380d9b
To: /content/tensor.npz
100%|██████████| 521M/521M [00:03<00:00, 165MB/s]


weight tensor: (128, 2880, 90, 16), scale tensor: (128, 2880, 90)


In [2]:
#@title Dequantization Tensor
def make_fp4_e2m1_lut() -> np.ndarray:
  """Build the 16-entry decode table for FP4 E2M1 codes.

  A 4-bit code is laid out as [sign | exponent (2 bits) | mantissa (1 bit)]
  with exponent bias 1. An exponent field of 0 encodes a subnormal
  (mantissa * 0.5); otherwise the value is (1 + mantissa * 0.5) * 2**(exp - 1).
  Codes 0-7 therefore decode to 0, 0.5, 1, 1.5, 2, 3, 4, 6 and codes 8-15 to
  their negatives (code 8 is -0.0).

  Returns:
    np.ndarray(float32) of shape (16,), indexed by the raw 4-bit code.
  """
  codes = np.arange(16)
  sign = np.where(codes & 0x8, -1.0, 1.0)
  exp_field = (codes >> 1) & 0x3
  mantissa = codes & 0x1
  subnormal = 0.5 * mantissa
  normal = (1.0 + 0.5 * mantissa) * np.exp2(exp_field - 1)
  magnitude = np.where(exp_field == 0, subnormal, normal)
  return (sign * magnitude).astype(np.float32)

# Decoded once when this cell runs; mxfp4_dequantize indexes it with raw 4-bit codes.
_FP4_LUT = make_fp4_e2m1_lut()

def e8m0_decode(scales_u8: np.ndarray) -> np.ndarray:
  """Decode E8M0 scale bytes into power-of-two multipliers.

  Each uint8 code c maps to 2 ** (c - 127). The bytes are widened to int16
  first so the subtraction of the bias cannot wrap around, and the result is
  float32. Note: code 255 gives 2**128, which overflows float32 to inf.

  Args:
    scales_u8: uint8 array of biased exponents (any shape).

  Returns:
    Array of the same shape holding the decoded scale factors.
  """
  unbiased_exponent = scales_u8.astype(np.int16) - 127
  return np.exp2(unbiased_exponent)


def mxfp4_dequantize(packed_fp4: np.ndarray,
                     scales_u8: np.ndarray) -> np.ndarray:
  """Dequantize MXFP4 tensor with scale.

  Args:
    packed_fp4: Packed MXFP4 tensor. np.ndarray(uint8), shape (..., B)
                Each byte holds 2 FP4 values (low nibble first): byte i
                decodes to elements 2*i (low nibble) and 2*i + 1 (high nibble).
    scales_u8: np.ndarray(uint8), shape (...) — same as packed_fp4.shape[:-1]
               One E8M0 scale per block of 2*B elements.

  Returns:
    np.ndarray(float32), shape (..., 2*B)
  """
  assert packed_fp4.dtype == np.uint8
  assert scales_u8.dtype == np.uint8
  assert packed_fp4.shape[:-1] == scales_u8.shape, \
      f"scales shape {scales_u8.shape} must match packed_fp4.shape[:-1] {packed_fp4.shape[:-1]}"

  # Unpack nibbles and INTERLEAVE them → (..., 2*B) as [lo0, hi0, lo1, hi1, ...].
  # Concatenating [all lows, all highs] instead would put every even element
  # before every odd one and scramble the column order within each block.
  low = packed_fp4 & 0x0F
  high = packed_fp4 >> 4
  out_shape = packed_fp4.shape[:-1] + (2 * packed_fp4.shape[-1],)
  nibbles = np.stack([low, high], axis=-1).reshape(out_shape)

  # FP4 decode via LUT
  elems = _FP4_LUT[nibbles]

  # Decode scales and broadcast one scale across its whole block (last axis).
  scales = e8m0_decode(scales_u8)[..., None]
  return elems * scales


def mxfp4_mlp_matmul_activation(
    x: np.ndarray,
    weight_packed: np.ndarray,
    scale_u8: np.ndarray,
    expert_idx: int,
    bias: np.ndarray | None = None
) -> np.ndarray:
  """Fused MXFP4 dequantize + matmul for one expert: y = x @ W^T (+ bias).

  The expert's weight is dequantized one block column
  (out_features x block_size) at a time and immediately multiplied into the
  accumulator, so the full float32 (out_features, in_features) matrix is
  never materialized.

  Args:
    x: (..., in_features) activations (cast to float32). in_features must
       equal n_blocks * 2 * b, i.e. the input width encoded by the packed
       blocks; in general this differs from out_features.
    weight_packed: Packed weight tensor (n_experts, out_features, n_blocks, b)
       uint8; every byte holds two FP4 values, so a block spans 2*b inputs.
    scale_u8: Scale tensor (n_experts, out_features, n_blocks) uint8 (E8M0).
    expert_idx: which expert slice of the packed tensors to use.
    bias: Optional (out_features,) bias added to the result.

  Returns:
    (..., out_features) float32.
  """
  assert weight_packed.dtype == np.uint8 and scale_u8.dtype == np.uint8
  assert weight_packed.ndim == 4, \
      f"expected 4-D packed weight, got shape {weight_packed.shape}"
  assert weight_packed.shape[:-1] == scale_u8.shape
  n_experts, out_features, n_blocks, bytes_per_block = weight_packed.shape
  assert 0 <= expert_idx < n_experts

  block_size = 2 * bytes_per_block       # FP4 values per block (32 for MXFP4)
  in_features = n_blocks * block_size

  x = np.asarray(x)
  # Validate against the input width implied by the blocks, not
  # weight_packed.shape[1] (= out_features); they only coincide for square
  # projections.
  assert x.shape[-1] == in_features, f"expected last dim {in_features}, got {x.shape[-1]}"
  x = x.astype(np.float32, copy=False)

  # Grab the expert slice once.
  Wp_e = weight_packed[expert_idx]       # (out_features, n_blocks, b)
  Sc_e = scale_u8[expert_idx]            # (out_features, n_blocks)

  y = np.zeros(x.shape[:-1] + (out_features,), dtype=np.float32)
  for j in range(n_blocks):
    cols = slice(j * block_size, (j + 1) * block_size)
    W_block = mxfp4_dequantize(Wp_e[:, j, :], Sc_e[:, j])  # (out_features, block_size)
    y += x[..., cols] @ W_block.T                           # (..., out_features)

  if bias is not None:
    bias = np.asarray(bias, dtype=np.float32)
    assert bias.shape == (out_features,)
    y += bias

  return y


In [3]:
# Create a reproducible random activation batch and pick 4 distinct random experts.
SEED = 0
rng = np.random.default_rng(SEED)
dummy_activations = rng.standard_normal((10, 2880), dtype=np.float32)
# Sample without replacement so the 4 experts are distinct (top-k routing
# never selects the same expert twice for one token).
experts = rng.choice(128, size=4, replace=False)

In [6]:
#@title Fused Performance
%%time

# --- Perform the fused multiplication ---
output = np.zeros((10, 2880))
for expert in experts:
  output += mxfp4_mlp_matmul_activation(dummy_activations, mlp_weight, mlp_scale, expert)

print(f"\nOutput shape: {output.shape}")
print(output[0, :5])



Output shape: (10, 2880)
[-59.60875702  69.62116623  -5.2908659   82.61049032  -7.77480698]
CPU times: user 1.07 s, sys: 0 ns, total: 1.07 s
Wall time: 545 ms


In [5]:
#@title Dequant and Matmul Performance
%%time
mat = mxfp4_dequantize(mlp_weight, mlp_scale)
out = np.zeros((10, 2880))
for expert in experts:
  expert_mat = mat[expert].reshape((2880, -1))
  out += np.matmul(dummy_activations, expert_mat.T)
print(f"\nOutput shape: {out.shape}")
print(out[0, :5])



Output shape: (10, 2880)
[-59.60876846  69.62115955  -5.29086876  82.61053276  -7.77484131]
CPU times: user 6.51 s, sys: 2.88 s, total: 9.39 s
Wall time: 9.42 s
