@@ -34,6 +34,8 @@
     TensorCoreTiledLayout,
 )
 from torchao.quantization import (
+    Int4TilePackedTo4dTensor,
+    IntxUnpackedToInt8Tensor,
     LinearActivationQuantizedTensor,
     PerGroup,
 )
@@ -57,9 +59,6 @@
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
-    IntxUnpackedToInt8Tensor,
-)
 from torchao.quantization.subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -691,6 +690,100 @@ def test_module_fqn_to_config_module_name(self):
         assert isinstance(model.linear2.weight, AffineQuantizedTensor)
         assert isinstance(model.linear2.weight._layout, PlainLayout)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_module_fqn_to_config_regex_basic(self):
+        config1 = Int4WeightOnlyConfig(
+            group_size=32, int4_packing_format="tile_packed_to_4d"
+        )
+        config = ModuleFqnToConfig({"re:linear.*": config1})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(model.linear2.weight, Int4TilePackedTo4dTensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_module_fqn_to_config_regex_precedence(self):
+        """Test that a full-path config takes precedence over a
+        regex config in ModuleFqnToConfig.
+        """
+        config1 = Int4WeightOnlyConfig(
+            group_size=32, int4_packing_format="tile_packed_to_4d"
+        )
+        config2 = IntxWeightOnlyConfig()
+        config = ModuleFqnToConfig({"linear1": config1, "re:linear.*": config2})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_module_fqn_to_config_regex_precedence2(self):
+        """Test that a full-path config takes precedence over a
+        regex config in ModuleFqnToConfig, with the order of
+        `re:linear.*` and `linear1` swapped, to make sure the
+        `linear1` config takes precedence even when it comes after `re:linear.*`.
+        """
+        config1 = Int4WeightOnlyConfig(
+            group_size=32, int4_packing_format="tile_packed_to_4d"
+        )
+        config2 = IntxWeightOnlyConfig()
+        config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
+        assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_module_fqn_to_config_regex_fullmatch(self):
+        """Test that an fqn is only matched when it fully
+        matches the regex.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self, dtype, device):
+                super().__init__()
+                self.dtype = dtype
+                self.device = device
+                self.linear1 = torch.nn.Linear(32, 64, dtype=dtype, device=device)
+                self.not_full_match_linear2 = torch.nn.Linear(
+                    64, 32, dtype=dtype, device=device
+                )
+                self.linear3_full_match = torch.nn.Linear(
+                    32, 32, dtype=dtype, device=device
+                )
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.not_full_match_linear2(x)
+                x = self.linear3_full_match(x)
+                return x
+
+            def example_inputs(self):
+                return (torch.randn(1, 32, dtype=self.dtype, device=self.device),)
+
+        config1 = Int4WeightOnlyConfig(
+            group_size=32, int4_packing_format="tile_packed_to_4d"
+        )
+        config2 = IntxWeightOnlyConfig()
+        config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
+        model = M(dtype=torch.bfloat16, device="cuda")
+        example_inputs = model.example_inputs()
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
+        # the fqn does not fully match `linear.*`, so it should not be quantized
+        assert not isinstance(
+            model.not_full_match_linear2.weight, IntxUnpackedToInt8Tensor
+        )
+        # linear3_full_match fully matches `linear.*`, so it should be quantized
+        assert isinstance(model.linear3_full_match.weight, IntxUnpackedToInt8Tensor)
+
     def test_module_fqn_to_config_embedding_linear(self):
         weight_dtype = torch.int8
         granularity = PerGroup(8)