4 bit support #66

Merged 1 commit on Nov 27, 2023
8 changes: 8 additions & 0 deletions launcher/src/main.rs
@@ -22,6 +22,8 @@ mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
    Bitsandbytes,
    BitsandbytesNF4,
    BitsandbytesFP4,
    Gptq,
}

@@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::BitsandbytesNF4 => {
                write!(f, "bitsandbytes-nf4")
            }
            Quantization::BitsandbytesFP4 => {
                write!(f, "bitsandbytes-fp4")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
2 changes: 2 additions & 0 deletions server/lorax_server/cli.py
@@ -13,6 +13,8 @@

class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"


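The new members reuse the exact hyphenated strings that the launcher's Display impl emits, so the value the launcher passes on the server command line parses straight back into the enum. A minimal standalone round-trip sketch (the enum is re-declared here purely for illustration, it is not new code in the PR):

from enum import Enum

class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"

# The hyphenated CLI strings resolve to the matching members.
assert Quantization("bitsandbytes-nf4") is Quantization.bitsandbytes_nf4
assert Quantization.bitsandbytes_fp4.value == "bitsandbytes-fp4"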
4 changes: 4 additions & 0 deletions server/lorax_server/models/__init__.py
@@ -306,6 +306,10 @@ def get_model(
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `lorax-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise ValueError(
            "4bit quantization is not supported for AutoModel"
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
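Like gptq, the two new 4-bit modes are rejected before the generic AutoModel fallback is used; only values without a guard here (such as plain bitsandbytes 8-bit or no quantization) pass through. A minimal standalone sketch of the same check, using a hypothetical helper name that is not part of the PR:

def validate_automodel_quantize(quantize):
    # Mirrors the guard in get_model above: modes the AutoModel path cannot
    # handle are rejected up front with a ValueError.
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it "
            "with `lorax-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    elif quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")

validate_automodel_quantize("bitsandbytes")  # no guard for 8-bit, passes through
try:
    validate_automodel_quantize("bitsandbytes-nf4")
except ValueError as err:
    print(err)  # 4bit quantization is not supported for AutoModel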
51 changes: 51 additions & 0 deletions server/lorax_server/utils/layers.py
@@ -150,6 +150,45 @@ def forward(self, x: torch.Tensor):
                self.weight.data = self.state.CxB
        return out

class Linear4bit(nn.Module):
    def __init__(self, weight, bias, quant_type):
        super().__init__()

        # Initialize weight with 4-bit quantization
        self.weight = Params4bit(
            weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
        )
        self.weight.cuda(weight.device)

        # Initialize other attributes
        self.compute_dtype = None
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # Ensure bias has the same dtype as input x
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        # Check if quantization state is initialized
        if getattr(self.weight, "quant_state", None) is None:
            print("FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.")

        # Convert input to compute_dtype if specified
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        # Convert bias to compute_dtype if it exists
        bias = None if self.bias is None else self.bias.to(self.compute_dtype)

        # Perform 4-bit matrix multiplication
        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)

        # Convert output back to the input dtype
        out = out.to(inp_dtype)

        return out

def get_linear(weight, bias, quantize):
    if quantize is None:
        linear = FastLinear(weight, bias)
@@ -162,6 +201,18 @@ def get_linear(weight, bias, quantize):
        )
        if bias is not None:
            linear.bias = nn.Parameter(bias)
    elif quantize == "bitsandbytes-nf4":
        linear = Linear4bit(
            weight,
            bias,
            quant_type="nf4",
        )
    elif quantize == "bitsandbytes-fp4":
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4",
        )
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
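A minimal usage sketch of the new code path, assuming a CUDA device, the lorax_server package on the path, and illustrative tensor shapes (none of this is part of the diff): get_linear dispatches on the quantize string, and Linear4bit quantizes the weight at construction time, then dequantizes on the fly in bnb.matmul_4bit during forward.

import torch

from lorax_server.utils.layers import get_linear

# Hypothetical shapes; Linear4bit.__init__ calls .cuda() on the weight, so a GPU is required.
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
bias = torch.zeros(4096, dtype=torch.float16, device="cuda")

layer = get_linear(weight, bias, quantize="bitsandbytes-nf4")  # or "bitsandbytes-fp4"

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
out = layer(x)  # 4-bit matmul, output cast back to the input dtype
print(out.shape, out.dtype)  # expected: torch.Size([1, 4096]) torch.float16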