From 0bef22ad8454c0ba9bdafce29d1d2e6380e54416 Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Tue, 22 Oct 2024 22:51:09 -0700
Subject: [PATCH 1/2] Update

[ghstack-poisoned]
---
 .../llama/TestInt8DynActInt4WeightLinear.py   | 33 +++++++++++++++++++
 1 file changed, 33 insertions(+)
 create mode 100644 examples/models/llama/TestInt8DynActInt4WeightLinear.py

diff --git a/examples/models/llama/TestInt8DynActInt4WeightLinear.py b/examples/models/llama/TestInt8DynActInt4WeightLinear.py
new file mode 100644
index 00000000000..f25859e8b06
--- /dev/null
+++ b/examples/models/llama/TestInt8DynActInt4WeightLinear.py
@@ -0,0 +1,33 @@
+import torch.cuda
+from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
+
+from torch import nn
+class Attention(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.wq = Int8DynActInt4WeightLinear(
+            in_features=2048,
+            out_features=2048,
+            bias=False,
+            device="cuda" if torch.cuda.is_available() else "cpu",
+            groupsize=32,
+            precision=torch.float32,
+            scales_precision=torch.float32
+        )
+
+    def forward(self, x: torch.tensor):
+        return self.wq.forward(x)
+
+
+def main() -> None:
+    input = torch.load("file/to/input/tensor")
+    checkpoint = torch.load("/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth", map_location="cpu",
+                            mmap=True)
+    model = Attention()
+    model.load_state_dict(checkpoint, strict=False, assign=True)
+
+    print(model.forward(input))
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file

From f853812be5468f42d9c3272bdbd7dc5f8df44073 Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Wed, 23 Oct 2024 09:44:32 -0700
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 .../llama/TestInt8DynActInt4WeightLinear.py   | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/examples/models/llama/TestInt8DynActInt4WeightLinear.py b/examples/models/llama/TestInt8DynActInt4WeightLinear.py
index f25859e8b06..a3116413013 100644
--- a/examples/models/llama/TestInt8DynActInt4WeightLinear.py
+++ b/examples/models/llama/TestInt8DynActInt4WeightLinear.py
@@ -4,13 +4,13 @@
 from torch import nn
 class Attention(nn.Module):
 
-    def __init__(self):
+    def __init__(self, device):
         super().__init__()
         self.wq = Int8DynActInt4WeightLinear(
             in_features=2048,
             out_features=2048,
             bias=False,
-            device="cuda" if torch.cuda.is_available() else "cpu",
+            device=device,
             groupsize=32,
             precision=torch.float32,
             scales_precision=torch.float32
@@ -21,13 +21,15 @@ def forward(self, x: torch.tensor):
         return self.wq.forward(x)
 
 def main() -> None:
-    input = torch.load("file/to/input/tensor")
-    checkpoint = torch.load("/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth", map_location="cpu",
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    input = torch.load("file/to/input/tensor", map_location=device)
+    checkpoint = torch.load("/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth", map_location=device,
+                            mmap=True)
-    model = Attention()
-    model.load_state_dict(checkpoint, strict=False, assign=True)
+    for i in range(5):
+        model = Attention(device)
+        model.load_state_dict(checkpoint, strict=False, assign=True)
 
-    print(model.forward(input))
+        print(model.forward(input))
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
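
Note on reproducing this outside the author's environment: the script in the patch depends on a local checkpoint path and a saved input tensor that are not part of the series. The sketch below keeps the same Int8DynActInt4WeightLinear construction shown in the diff (2048x2048, groupsize=32, fp32 precision) but feeds a random fp32 activation instead of the loaded tensor. This is only a smoke test under my own assumptions: the random input, and the assumption that a freshly constructed layer can run forward before real quantized weights are loaded, are not from the patch, so the printed values are meaningless.

# Minimal smoke-test sketch (assumptions noted above): same layer config as the
# patch, random input so it runs without the private checkpoint.
import torch
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear


def smoke_test() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    wq = Int8DynActInt4WeightLinear(
        in_features=2048,
        out_features=2048,
        bias=False,
        device=device,
        groupsize=32,
        precision=torch.float32,
        scales_precision=torch.float32,
    )
    # Random activation stands in for the torch.load("file/to/input/tensor")
    # tensor from the patch; without the real checkpoint this only checks that
    # the dynamic-int8-activation / int4-weight linear path executes.
    x = torch.randn(1, 2048, dtype=torch.float32, device=device)
    y = wq(x)
    print(y.shape, y.dtype)


if __name__ == "__main__":
    smoke_test()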