144 changes: 142 additions & 2 deletions build/gguf_loader.py
@@ -7,19 +7,22 @@

import copy
import logging
from typing import Any
from typing import Any, Optional

import gguf

import torch
import torch.nn.functional as F

from build.gguf_util import Q4_0, to_float
from build.model import Model, ModelArgs, TransformerArgs

from gguf import GGUFValueType
from quantization.qops import LinearInt4 as WeightOnlyInt4Linear
from quantization.quantize import pack_scales_and_zeros

from build.utils import find_multiple, get_precision


logger: logging.Logger = logging.getLogger(__name__)


@@ -97,6 +100,143 @@ def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
return metadata


#########################################################################
# Note: int4 quantization is migrated to torchao for general quantization.
# TODO: GGUF workflow needs migration to torchao
#########################################################################
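# A rough sketch of the torchao-based path that note refers to (not part of
# this file): torchao exposes a `quantize_` API with an int4 weight-only
# config. The exact import path and argument names below are assumptions and
# may vary across torchao releases:
#
#     from torchao.quantization import quantize_, int4_weight_only
#     quantize_(model, int4_weight_only(group_size=128))
#
# Until the GGUF workflow is migrated, this file keeps its own int4 linear
# module below.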


def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
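    # Flattens `input` to 2D, runs the packed int4 weight-only matmul via
    # torch.ops.aten._weight_int4pack_mm, then restores the original leading
    # dimensions with the last dimension replaced by `out_features`. On CUDA
    # the activation and scales/zeros are cast to bfloat16 for the kernel and
    # the result is cast back to the input dtype.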
origin_input_size = input.size()
input = input.reshape(-1, origin_input_size[-1])

if "cuda" in str(input.device):
c = torch.ops.aten._weight_int4pack_mm(
input.to(torch.bfloat16),
weight_int4pack,
groupsize,
scales_and_zeros.to(torch.bfloat16),
).to(
input.dtype
) # cast back to input.dtype
else:
c = torch.ops.aten._weight_int4pack_mm(
input,
weight_int4pack,
groupsize,
scales_and_zeros,
)
new_shape = origin_input_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c


class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
scales_and_zeros: torch.Tensor

def __init__(
self,
in_features: int,
out_features: int,
bias=True,
device=None,
dtype=None,
*,
groupsize: int = 128,
inner_k_tiles: int = 8,
weight: Optional[torch.Tensor] = None,
scales_and_zeros: Optional[torch.Tensor] = None,
) -> None:
super().__init__()
self.padding = not self._check_k(
k=in_features,
groupsize=groupsize,
inner_k_tiles=inner_k_tiles,
)
if self.padding:
self.origin_in_features = in_features
in_features = find_multiple(in_features, 1024)

self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert (
in_features % (inner_k_tiles * 16) == 0
), "require in_features % (innerKTiles * 16) == 0"
assert (weight is None) == bool(
scales_and_zeros is None
), "must specify both weights and scales_and_zeros, or neither"

if weight is None:
weight = torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
)
scales_and_zeros = torch.empty(
(in_features // groupsize, out_features, 2),
dtype=get_precision(),
device=device,
)

self.register_buffer(
"weight",
weight,
)
self.register_buffer(
"scales_and_zeros",
scales_and_zeros,
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.padding:
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_int4(
input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)

@classmethod
def _check_k(cls, *, k, groupsize=1, inner_k_tiles=1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

@classmethod
def _prepare_weight_and_scales_and_zeros(
cls, weight_bf16, groupsize, inner_k_tiles
):
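        # Groupwise 4-bit quantization of a bf16 weight: group_quantize_tensor
        # produces per-group scales/zeros and quantized values, two 4-bit
        # values are packed per uint8, and _convert_weight_to_int4pack
        # rearranges the result into the tiled layout _weight_int4pack_mm
        # expects.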
from quantization.quantize import group_quantize_tensor

weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
torch.uint8
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_uint8, inner_k_tiles
)
return weight_int4pack, scales_and_zeros

@classmethod
    def _calc_padded_size(cls, *, k, groupsize=1, inner_k_tiles=1):
return find_multiple(k, 1024)
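# A minimal usage sketch of the module above (an illustration, not part of the
# original change). With the defaults groupsize=128 and inner_k_tiles=8, an
# in_features of 4096 needs no padding, and the empty buffers created by the
# constructor have the shapes derived there:
#
#     layer = WeightOnlyInt4Linear(4096, 4096, bias=False, device="meta")
#     assert layer.weight.shape == (512, 32, 32, 4)          # (out/8, in/(tiles*16), 32, tiles/2)
#     assert layer.scales_and_zeros.shape == (32, 4096, 2)   # (in/groupsize, out, 2)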


#########################################################################


def load_model(gguf_file: str) -> torch.nn.Module:
"""
Parses the GGUF file and returns an nn.Module on meta device.