In [1]:
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.bfloat16,
)
model = model.to("cuda")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.94it/s]


In [2]:
from typing import Tuple

def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Quantize a tensor using per-tensor static scaling factor.

    Args:
        tensor: The input tensor.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    min_val, max_val = tensor.aminmax()
    amax = min_val.abs().max(max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both float8 data and the inverse scale (as float),
    # as both required as inputs to torch._scaled_mm
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale

In [3]:
class LinearFP8(torch.nn.Module):
    def __init__(self, qweight, scale):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
    
    def forward(self, x):
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        qinput, x_scale = per_tensor_quantize(x)
        
        output, _ = torch._scaled_mm(
            qinput,
            self.weight.t(),
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=self.weight_scale,
            bias=None,
        )
        return output.reshape(shape[0], shape[1], -1)

In [4]:
SELF_ATTN_WEIGHTS = ["q_proj", "k_proj", "v_proj", "o_proj"]
MLP_WEIGHTS = ["gate_proj", "up_proj", "down_proj"]

def quantize_proj(module, proj_name):
    proj = getattr(module, proj_name)
    quant_weight, quant_scale = per_tensor_quantize(proj.weight)
    quant_proj = LinearFP8(quant_weight, quant_scale)
    
    del proj
    setattr(module, proj_name, quant_proj)

for layer in model.model.layers:
    for proj_name in SELF_ATTN_WEIGHTS:
        quantize_proj(layer.self_attn, proj_name)
    for proj_name in MLP_WEIGHTS:
        quantize_proj(layer.mlp, proj_name)

In [5]:
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): LinearFP8()
          (k_proj): LinearFP8()
          (v_proj): LinearFP8()
          (o_proj): LinearFP8()
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): LinearFP8()
          (up_proj): LinearFP8()
          (down_proj): LinearFP8()
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

In [6]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2"
)
tokenizer.pad_token_id = tokenizer.eos_token_id

In [7]:
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is your name?" }],
    return_tensors="pt"
).to("cuda")

In [8]:
output = model.generate(input_ids=input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> [INST] What is your name? [/INST] I don't have a name. I'm just a computer program designed to assist with information


In [9]:
# hacked transformers/modeling_utils/dtype_byte_size to make this work
model.save_pretrained("mistral-fp8-static")

In [None]:
mod