diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 9bd16fa7c07..707bce88249 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -216,3 +216,30 @@ runtime.python_test(
         "//executorch/examples/models/llama:llama_transformer",
     ],
 )
+
+runtime.python_library(
+    name = "test_8da4w_library",
+    srcs = [
+        "test_8da4w.py"
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//bento/...",
+        "//bento_kernels/...",
+        "//executorch/examples/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//pytorch/ao:torchao",
+    ]
+)
+
+runtime.python_binary(
+    name = "test_8da4w",
+    main_function = "executorch.examples.models.llama.test_8da4w.main",
+    deps = [
+        ":test_8da4w_library",
+        "//caffe2:torch",
+    ]
+)
diff --git a/examples/models/llama/test_8da4w.py b/examples/models/llama/test_8da4w.py
new file mode 100644
index 00000000000..c477bd9579b
--- /dev/null
+++ b/examples/models/llama/test_8da4w.py
@@ -0,0 +1,64 @@
+import os
+
+import torch.cuda
+
+from torch import nn
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+
+# Wraps a single Int8DynActInt4WeightLinear layer (8-bit dynamic
+# activations, 4-bit grouped weights) standing in for an attention
+# query projection.
+class Attention(nn.Module):
+
+    def __init__(self, device):
+        super().__init__()
+        self.wq = Int8DynActInt4WeightLinear(
+            in_features=2048,
+            out_features=2048,
+            bias=False,
+            device=device,
+            groupsize=32,
+            precision=torch.float32,
+            scales_precision=torch.float32,
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.wq.forward(x)
+
+
+def main() -> None:
+    seed = 42
+    torch.manual_seed(seed)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    input = torch.load(f"{os.path.dirname(__file__)}/x.pt").to(device=device)
+    checkpoint = torch.load(
+        f"{os.path.dirname(__file__)}/wq.pth",
+        map_location=device,
+        mmap=True,
+    )
+    print(f"input {input}")
+    results = []
+    iterations = 10
+    # Re-instantiate the model from the same checkpoint each iteration and
+    # rerun the same forward pass, collecting every distinct output; a
+    # deterministic path should produce exactly one result.
+    for i in range(iterations):
+        model = Attention(device).to(device=device)
+        model.load_state_dict(checkpoint, strict=False, assign=True)
+
+        result = model.forward(input)
+        exist = False
+        for existing_result in results:
+            if torch.allclose(result, existing_result):
+                exist = True
+                break
+        if not exist:
+            results.append(result)
+    print(f"Generated {len(results)} results with {iterations} iterations")
+    for i, result in enumerate(results):
+        print(f"result {i} {result}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama/wq.pth b/examples/models/llama/wq.pth
new file mode 100644
index 00000000000..ab1fa1bc349
Binary files /dev/null and b/examples/models/llama/wq.pth differ
diff --git a/examples/models/llama/x.pt b/examples/models/llama/x.pt
new file mode 100644
index 00000000000..1214840f860
Binary files /dev/null and b/examples/models/llama/x.pt differ