
Commit ed83ae2

Unpin nightly version (#593)
Summary: Previously there were some inductor errors, so we pinned the nightly version. Those should be fixed by pytorch/pytorch#132096, and we can no longer call `unwrap_tensor_subclass` before `torch.compile` on 2.5+ nightlies (the calls are now guarded to torch versions before 2.5).

Test Plan: fix CI errors

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 00529fa commit ed83ae2
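
Every code change below instantiates the same pattern: guard `unwrap_tensor_subclass` so it only runs on torch versions before 2.5. A minimal sketch of that pattern (using `int8_weight_only` as a stand-in for whichever quantization API a given call site applies):

```python
import torch
from torchao.quantization import quantize_, int8_weight_only
from torchao.utils import TORCH_VERSION_AFTER_2_5, unwrap_tensor_subclass

def quantize_and_compile(model: torch.nn.Module) -> torch.nn.Module:
    quantize_(model, int8_weight_only())
    # On torch < 2.5, tensor subclass weights must be unwrapped into
    # plain parameters before torch.compile; 2.5+ nightlies compile
    # them directly, so the unwrap step is skipped there.
    if not TORCH_VERSION_AFTER_2_5:
        unwrap_tensor_subclass(model)
    return torch.compile(model, mode="max-autotune")
```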

File tree

12 files changed: +58 -35 lines

.github/workflows/regression_test.yml

Lines changed: 2 additions & 2 deletions

@@ -33,7 +33,7 @@ jobs:
       gpu-arch-version: "12.1"
     - name: CUDA Nightly
       runs-on: linux.g5.12xlarge.nvidia.gpu
-      torch-spec: '--pre torch==2.5.0.dev20240728+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
+      torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
       gpu-arch-type: "cuda"
       gpu-arch-version: "12.1"
     - name: CPU 2.2.2
@@ -48,7 +48,7 @@ jobs:
       gpu-arch-version: ""
     - name: CPU Nightly
       runs-on: linux.4xlarge
-      torch-spec: '--pre torch==2.5.0.dev20240728 --index-url https://download.pytorch.org/whl/nightly/cpu'
+      torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
       gpu-arch-type: "cpu"
       gpu-arch-version: ""
 

test/integration/test_integration.py

Lines changed: 15 additions & 5 deletions

@@ -101,21 +101,24 @@
 def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
         quantize_(mod, int8_weight_only(), set_inductor_config=False)
-        unwrap_tensor_subclass(mod)
+        if not TORCH_VERSION_AFTER_2_5:
+            unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)
 
 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
         quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
-        unwrap_tensor_subclass(mod)
+        if not TORCH_VERSION_AFTER_2_5:
+            unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
         quantize_(mod, int4_weight_only(), set_inductor_config=False)
-        unwrap_tensor_subclass(mod)
+        if not TORCH_VERSION_AFTER_2_5:
+            unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
 
@@ -853,7 +856,8 @@ def api(mod):
             kwargs_copy["group_size"] = groupsize
             del kwargs_copy["groupsize"]
             quantize_(mod, int4_weight_only(**kwargs_copy))
-            unwrap_tensor_subclass(mod)
+            if not TORCH_VERSION_AFTER_2_5:
+                unwrap_tensor_subclass(mod)
         else:
             change_linear_weights_to_int4_woqtensors(mod, **kwargs)
 
@@ -985,6 +989,9 @@ def forward(self, x):
         # save quantized state_dict
         api(model)
 
+        # make sure the model is still runnable
+        model(x)
+
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
         model_qc = torch.compile(model, mode="max-autotune")
@@ -1004,7 +1011,9 @@ def forward(self, x):
         model.load_state_dict(state_dict, assign=True)
         model = model.to(device=test_device, dtype=test_dtype).eval()
 
-        # get quantized reference
+        # make sure the model is still runnable
+        model(x)
+
         model_qc = torch.compile(model, mode="max-autotune")
         test = model_qc(x).detach()
 
@@ -1013,6 +1022,7 @@ def forward(self, x):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Can't save local lambda function for tensor subclass")
     @torch.no_grad()
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
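
The two added `model(x)` calls give the save/load test an eager-mode smoke check on both sides of serialization, so a broken quantized model fails with a plain traceback before `torch.compile` is involved. A condensed sketch of that flow (the model class and input shape are hypothetical stand-ins for the test fixtures, and the quantization step is elided):

```python
import torch
from torch import nn

class M(nn.Module):  # hypothetical stand-in for the test model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)

    def forward(self, x):
        return self.linear(x)

model, x = M().eval(), torch.randn(2, 64)
# (quantization via one of the _int*_api helpers above would happen here)
model(x)  # make sure the model is still runnable before saving
torch.save(model.state_dict(), "test.pth")

model2 = M().eval()
model2.load_state_dict(torch.load("test.pth"), assign=True)
model2(x)  # make sure the model is still runnable after reload
test = torch.compile(model2, mode="max-autotune")(x)
```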

test/sparsity/test_sparse_api.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
     int8_dynamic_activation_int8_weight,
     quantize_,
 )
-from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass
+from torchao.utils import TORCH_VERSION_AFTER_2_3
 from torch.testing._internal.common_utils import TestCase
 

torchao/_models/llama/eval.py

Lines changed: 3 additions & 1 deletion

@@ -22,6 +22,7 @@
 import time
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
 from torchao._models.llama.model import prepare_inputs_for_model
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 def run_evaluation(
     checkpoint_path: Path,
@@ -88,7 +89,8 @@ def run_evaluation(
         model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
         model = quantizer.quantize(model, inputs).to(device)
     else:
-        unwrap_tensor_subclass(model)
+        if not TORCH_VERSION_AFTER_2_5:
+            unwrap_tensor_subclass(model)
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)

torchao/_models/llama/generate.py

Lines changed: 4 additions & 2 deletions

@@ -13,6 +13,7 @@
 import torch._dynamo.config
 import torch._inductor.config
 from torchao.utils import get_model_size_in_bytes
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 def device_sync(device):
     if "cuda" in device:
@@ -115,7 +116,7 @@ def generate(
         from model import AffineQuantizedKVCache
         from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
         _replace_with_custom_fn_if_matches_filter(
-            model,
+            model, 
             AffineQuantizedKVCache.from_float,
             lambda x, y: isinstance(x, torchao._models.llama.model.KVCache),
         )
@@ -232,7 +233,8 @@ def main(
             # do autoquantization
             model.finalize_autoquant()
         else:
-            unwrap_tensor_subclass(model)
+            if not TORCH_VERSION_AFTER_2_5:
+                unwrap_tensor_subclass(model)
 
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
 
torchao/_models/sam/eval_combo.py

Lines changed: 6 additions & 3 deletions

@@ -12,6 +12,7 @@
 from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
 from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight
 from torchao.utils import unwrap_tensor_subclass
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 torch._dynamo.config.cache_size_limit = 50000
 
@@ -284,7 +285,8 @@ def run(
 
     if compress == "int8_dynamic_quant":
         quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
-        predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
+        if not TORCH_VERSION_AFTER_2_5:
+            predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
     elif compress == "sparse_mlp_only":
         def mlp_only(mod, name):
             return isinstance(mod, torch.nn.Linear) and 'mlp' in name
@@ -316,7 +318,8 @@ def mlp_only(mod, name):
         sparsify_(predictor.model.image_encoder,
                   semi_sparse_weight(),
                   mlp_lin2_only)
-        predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
+        if not TORCH_VERSION_AFTER_2_5:
+            predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
 
     else:
         assert compress is None, f"Unsupported compress mode {compress}"
@@ -401,6 +404,6 @@ def mlp_only(mod, name):
     vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile,
                               use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path]))
     f.write(vals+"\n")
-
+
 if __name__ == '__main__':
     fire.Fire(run)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 4 deletions

@@ -607,25 +607,24 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
             quantize_affine,
         )
         from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
+        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
 
         cur_shape = self.shape
         assert len(cur_shape) == 4
         inner_k_tiles = cur_shape[-1] * 2
         original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
         eye_shape = original_shape[1]
-        block_size = (1, 32)
+        groupsize = int(original_shape[1] / scale.shape[-2])
+        block_size = (1, groupsize)
         device = self.device
         original_dtype = torch.bfloat16
-        groupsize = 32
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
         zero_point_domain = ZeroPointDomain.FLOAT
         assert len(block_size) == 2 and block_size[0] == 1
-        groupsize = block_size[-1]
         dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero)
         dequantized = dequantized.t().contiguous()
-        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
         # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
         scale = scale.reshape(scale.shape[:-1]).contiguous()
         zero = zero.reshape(zero.shape[:-1]).contiguous()
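
The `get_plain` fix stops hardcoding `groupsize = 32` and instead recovers the group size from the unpacked scale tensor, so int4 tensors packed with other group sizes dequantize correctly. A worked example of the arithmetic (the shapes are illustrative; the real `scale` layout comes from `unpack_tinygemm_scales_and_zeros`):

```python
# Suppose the unpacked weight is (N, K) = (4096, 4096) and it was
# quantized with group_size = 128, giving K / 128 = 32 groups per row.
original_shape = (4096, 4096)
scale_shape = (4096, 32, 1)  # illustrative: scale.shape[-2] is the group count

groupsize = int(original_shape[1] / scale_shape[-2])
assert groupsize == 128      # recovered from the shapes, not assumed to be 32
block_size = (1, groupsize)
```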

torchao/quantization/README.md

Lines changed: 4 additions & 1 deletion

@@ -109,8 +109,11 @@ group_size = 32
 quantize_(m, int4_weight_only(group_size=group_size))
 
 # temporary workaround for tensor subclass + torch.compile
+# NOTE: this is only needed for torch versions before 2.5
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 from torchao.utils import unwrap_tensor_subclass
-m = unwrap_tensor_subclass(m)
+if not TORCH_VERSION_AFTER_2_5:
+    unwrap_tensor_subclass(m)
 # compile the model to improve performance
 m = torch.compile(m, mode='max-autotune')
 
torchao/quantization/quant_api.py

Lines changed: 7 additions & 7 deletions

@@ -337,6 +337,9 @@ def int8_dynamic_activation_int4_weight(group_size=32):
     size is more fine grained
     """
     def apply_int8_dynamic_activation_int4_weight_quant(weight):
+        if weight.shape[-1] % group_size != 0:
+            return weight
+
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
 
@@ -379,6 +382,9 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
     `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
     """
     def apply_int4_weight_only_quant(weight):
+        if weight.shape[-1] % group_size != 0:
+            return weight
+
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
         from torchao.dtypes import TensorCoreTiledLayoutType
@@ -438,18 +444,12 @@ def get_weight_block_size(x):
         zero_point_dtype = torch.int64
 
         # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size)-1):
-                block_size[i] = 1
-            return block_size
-
         input_mapping_type = MappingType.SYMMETRIC
         input_target_dtype = torch.int8
         input_eps = 1e-5
         input_quant_min = -127
         input_quant_max = 127
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
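
The two new early returns make `int4_weight_only` and `int8_dynamic_activation_int4_weight` leave a layer in full precision when its last weight dimension is not evenly divisible by `group_size`, instead of failing inside the quantization kernel. A small sketch of the guard in isolation (hypothetical shapes):

```python
import torch

def can_group(weight: torch.Tensor, group_size: int = 128) -> bool:
    # mirrors the new guard: weights that cannot be split into equal
    # groups along the last dimension are returned unquantized
    return weight.shape[-1] % group_size == 0

print(can_group(torch.randn(32, 4096)))  # True:  4096 % 128 == 0 -> quantize
print(can_group(torch.randn(32, 1000)))  # False: 1000 % 128 != 0 -> skip
```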

torchao/quantization/quant_primitives.py

Lines changed: 3 additions & 1 deletion

@@ -233,6 +233,7 @@ def _quantize_affine_no_dtype_cast(
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
+    assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     original_shape = input.shape
     input = input.view(shape_for_reduction)
@@ -349,6 +350,7 @@ def _dequantize_affine_no_dtype_check(
     zero_point_domain: str = ZeroPointDomain.INT.name,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
+    assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     original_shape = input.shape
     input = input.view(shape_for_reduction)
@@ -589,7 +591,7 @@ def _choose_qparams_affine(
     if zero_point_dtype is None:
         zero_point_dtype = input.dtype
 
-    assert len(block_size) == input.dim()
+    assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     input = input.view(shape_for_reduction)

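Each assertion now fires with the offending shapes in its message whenever `block_size` does not have exactly one entry per input dimension, which is what a caller sees when passing, say, a rank-3 block size for a rank-2 tensor:

```python
import torch

x = torch.randn(8, 1024)   # rank-2 input
block_size = (1, 1, 1024)  # rank-3 block_size: one entry too many

try:
    assert len(block_size) == x.dim(), \
        f"Got input dim:{x.dim()}, block_size: {block_size}"
except AssertionError as e:
    print(e)  # Got input dim:2, block_size: (1, 1, 1024)
```
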
torchao/quantization/utils.py

Lines changed: 7 additions & 7 deletions

@@ -129,6 +129,13 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
     if size is not None and tensor_arg.size() != size:
         raise ValueError(f"Expected Tensor argument {arg_name} to have size {size}, but got {tensor_arg.size()} instead.")
 
+def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
+    block_size = []
+    for _ in range(len(x.shape)-1):
+        block_size.append(1)
+    block_size.append(x.shape[-1])
+    return block_size
+
 # taken from
 # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
 # and slightly modified
@@ -492,10 +499,3 @@ def recommended_inductor_config_setter():
     torch._inductor.config.fx_graph_cache = True
     torch._inductor.config.triton.unique_kernel_names = True
     torch.set_float32_matmul_precision("high")
-
-def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
-    block_size = []
-    for i in range(len(x.shape)-1):
-        block_size.append(1)
-    block_size.append(x.shape[-1])
-    return block_size
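
`_get_per_token_block_size` is relocated within `utils.py` (unchanged in behavior) and is now the shared helper that the `quant_api.py` lambda above calls. Its output, for reference (a sketch, assuming the function is importable from `torchao.quantization.utils` at this commit):

```python
import torch
from torchao.quantization.utils import _get_per_token_block_size

x = torch.randn(4, 16, 1024)         # (batch, seq_len, hidden)
print(_get_per_token_block_size(x))  # [1, 1, 1024]: one quantization
                                     # group per token, covering the
                                     # full hidden dimension
```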

tutorials/quantize_vit/run_vit_b_quant.py

Lines changed: 3 additions & 1 deletion

@@ -31,8 +31,10 @@
 ## compilation configs end
 
 # temporary workaround for the API to work with torch.compile
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 from torchao.utils import unwrap_tensor_subclass
-unwrap_tensor_subclass(model)
+if not TORCH_VERSION_AFTER_2_5:
+    unwrap_tensor_subclass(model)
 
 model = torch.compile(model, mode='max-autotune')