Skip to content

Commit

Permalink
Update bitnet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ridgerchu authored Jun 10, 2024
1 parent 738259c commit 931ab43
Showing 1 changed file with 0 additions and 70 deletions.
70 changes: 0 additions & 70 deletions mmfreelm/ops/bitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,76 +41,6 @@ def weight_quant(w):
return u


class BitLinear_ReLU(nn.Linear):
    """
    Linear layer with BitNet-style quantization of both activations and
    weights, applying a ReLU to the quantized activations before the matmul.

    Intended for training only; a fused kernel is needed for efficient
    deployment.
    """

    def __init__(self, in_features, out_features, bias=True):
        """
        Initialize the layer.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: If set to False, the layer will not learn an additive bias.
                Default: True.
        """
        super().__init__(in_features, out_features, bias=bias)
        # RMS-normalize activations before quantizing them.
        self.norm = RMSNorm(in_features, eps=1e-8)

    def forward(self, x):
        """
        Forward pass with quantized activations and weights.

        Args:
            x: Input tensor of shape [n, in_features].

        Returns:
            Output tensor of shape [n, out_features].
        """
        normed = self.norm(x)

        # Straight-through estimator: the forward value is the quantized
        # tensor, while .detach() makes the backward pass see the identity,
        # so gradients flow through the un-quantized path.
        act = normed + (activation_quant(normed) - normed).detach()
        qweight = self.weight + (weight_quant(self.weight) - self.weight).detach()

        act = torch.relu(act)

        # NOTE(review): self.bias is never passed to F.linear, so any learned
        # bias is unused here — presumably intentional (BitNet recipe); confirm.
        return F.linear(act, qweight)


class BitLinear_Fuse(nn.Linear):
    """
    Linear layer that delegates to a fused RMSNorm + linear operator.

    Intended for training only; a fused kernel is needed for efficient
    deployment.
    """

    def __init__(self, in_features, out_features, bias=True):
        """
        Initialize the layer.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: If set to False, the layer will not learn an additive bias.
                Default: True.
        """
        super().__init__(in_features, out_features, bias=bias)
        # Fused RMS normalization + linear projection in a single kernel.
        self.fusedlinear = RMSNormLinear(in_features, eps=1e-8)

    def forward(self, x):
        # NOTE(review): the trailing None appears to stand in for the bias
        # argument, so self.bias is never applied — confirm this is intended.
        return self.fusedlinear(x, self.weight, None)


class BitLinear(nn.Linear):
"""
Expand Down

0 comments on commit 931ab43

Please sign in to comment.