-
Notifications
You must be signed in to change notification settings - Fork 322
Open
Description
CUDA OOM when running AWQ int4-quantized Llama-3.1-8B at batch size 1.
Repro on an RTX 4070 Ti Super (16 GB):
conda create -yn test python=3.10
conda activate test
pip3 install --pre torch==2.9.0.dev20250818+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128
pip3 install --pre torchao==0.13.0.dev20250729+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128
pip3 install transformers datasets
python test.py
Script (`test.py`):
import os
import sys
import logging
# Detach every handler already installed on the root logger (for example by an
# earlier import) so that logging.basicConfig below is not a silent no-op.
for handler in list(logging.root.handlers):
    print("Before setup, remove log root handler: ", handler)
    logging.root.removeHandler(handler)

# Log level is taken from the LOG_LEVEL environment variable (default INFO).
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    stream=sys.stdout,
    level=LOG_LEVEL,
)
# Root logger, used for the script's own log messages.
logger = logging.getLogger()
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
)
# On Windows, switch the console to UTF-8 so generated text prints correctly.
if os.name == 'nt':
    os.system('chcp 65001')
    sys.stdout.reconfigure(encoding='utf-8')

model_id = "meta-llama/Llama-3.1-8B-Instruct"
# Single definition of the target device (it was previously assigned twice).
device = torch.device("cuda")
# NOTE: not referenced elsewhere in this script; weights are actually loaded in
# bfloat16 (see load_dtype further down). Kept so existing references still work.
dtype = torch.float16
group_size = 128
seq_len = 32        # not referenced elsewhere in this script
calib_batches = 4   # not referenced elsewhere in this script
infer_bs = 1
gen_tokens = 32
quant_mode = "woq"  # not referenced elsewhere in this script
# Passed as `trust_remote_code`. It must be a real bool: the string "False" is
# non-empty and therefore truthy, so any truthiness check would treat it as True.
use_hf_code = False
def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512, dataset=None,
                      max_line_tokens=512):
    """Build AWQ calibration batches from the pile-val validation split.

    Each text line is stripped and tokenized. Empty lines, and lines longer than
    ``max_line_tokens`` tokens, are skipped. The surviving token ids are
    concatenated and cut into consecutive blocks of ``block_size`` tokens.

    Args:
        tokenizer: Object with an ``encode(str) -> list[int]`` method (e.g. a
            Hugging Face tokenizer). Required despite the ``None`` default, which
            is kept for backward compatibility.
        n_samples: Number of blocks to return.
        block_size: Number of tokens per block.
        dataset: Optional iterable of ``{"text": str}`` records. When ``None``
            (the default), the ``mit-han-lab/pile-val-backup`` validation split
            is loaded with ``datasets.load_dataset``.
        max_line_tokens: Lines that encode to more tokens than this are skipped
            (defaults to the previously hard-coded 512).

    Returns:
        A list of int64 tensors of shape ``(1, block_size)``. It has
        ``n_samples`` entries unless the dataset runs out first; in that case
        only the complete blocks are returned.

    Raises:
        ValueError: If ``tokenizer`` is missing, ``block_size`` is not positive,
            or the dataset yields no usable line.
    """
    if tokenizer is None:
        raise ValueError("get_calib_dataset requires a tokenizer")
    if n_samples <= 0:
        return []
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if dataset is None:
        from datasets import load_dataset
        dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")

    n_tokens = n_samples * block_size
    samples = []
    n_collected = 0
    for data in dataset:
        line_encoded = tokenizer.encode(data["text"].strip())
        if not line_encoded or len(line_encoded) > max_line_tokens:
            continue
        samples.append(torch.tensor([line_encoded]))
        n_collected += len(line_encoded)
        # Stop only once every requested block can be completely filled; the
        # old `remaining <= n_samples` test could stop up to n_samples tokens
        # short and leave the last block truncated.
        if n_collected >= n_tokens:
            break

    if not samples:
        raise ValueError("calibration dataset produced no usable samples")
    cat_samples = torch.cat(samples, dim=1)
    # Return full blocks only; a short or empty trailing slice is dropped.
    n_blocks = min(n_samples, cat_samples.shape[1] // block_size)
    return [
        cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_blocks)
    ]
# Precision settings: the checkpoint is loaded in bfloat16. amp_enabled is not
# read anywhere else in this script.
amp_enabled = True
amp_dtype = torch.bfloat16
load_dtype = amp_dtype

# (model class, tokenizer class) keyed by a substring of the model id. The
# first key contained in the lower-cased id wins; "auto" is the fallback.
MODEL_CLASSES = {
    "auto": (AutoModelForCausalLM, AutoTokenizer),
    "llama": (AutoModelForCausalLM, AutoTokenizer),
}
model_type = next(
    filter(lambda key: key in model_id.lower(), MODEL_CLASSES),
    "auto",
)
model_class = MODEL_CLASSES[model_type]

config = AutoConfig.from_pretrained(model_id, trust_remote_code=use_hf_code)
# Plain (unquantized) load; quantization is applied later with torchao.
quantization_config = None

# NOTE(review): this is the call that fails in the attached traceback. It places
# the full bfloat16 checkpoint directly on `device`. The log shows ~13.9 GiB
# already allocated on a 15.56 GiB GPU when a further ~1 GiB tensor is
# requested, so the unquantized model does not fit before AWQ ever runs,
# independent of the inference batch size. Loading on CPU and moving only the
# quantized model to the GPU would avoid this. TODO: confirm that the torchao
# AWQ observer/quantize path supports CPU-resident weights before changing
# device_map.
model = model_class[0].from_pretrained(
    model_id,
    config=config,
    torch_dtype=load_dtype,
    device_map=device,
    low_cpu_mem_usage=True,
    trust_remote_code=use_hf_code,
    quantization_config=quantization_config,
)

tokenizer = model_class[1].from_pretrained(model_id, trust_remote_code=use_hf_code)
# Padding needs a pad token; fall back to EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
from torchao.quantization import quantize_
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
model.eval().to(device)
quant_dtype = torch.uint4
# group_size comes from the script configuration; it is intentionally not
# re-assigned here, so changing it at the top actually takes effect.

# Calibration settings shared by the observer and the calibration data, so the
# sequence lengths cannot drift apart. The two positional arguments of
# insert_awq_observer_ are presumably the number of validation examples and
# their sequence length (TODO: confirm against the torchao signature).
calib_n_validation = 1
calib_seq_len = 512
calib_n_samples = 10

logger.info(f"running {quant_dtype} calibration")
insert_awq_observer_(
    model, calib_n_validation, calib_seq_len, quant_dtype=quant_dtype, group_size=group_size
)
calibration_data = get_calib_dataset(
    tokenizer=tokenizer, n_samples=calib_n_samples, block_size=calib_seq_len
)
# Calibration only needs forward activations. Without no_grad every forward
# builds an autograd graph, and any tensor the observers keep would hold that
# graph alive, which is costly GPU memory in a script that is already at the
# OOM limit. The batches are CPU tensors; `.to(device)` makes a GPU copy and
# leaves them on the CPU, so no move back is needed.
with torch.no_grad():
    for batch in calibration_data:
        if batch.numel() == 0:
            continue
        model(batch.to(device))


def is_observed_linear(m, fqn):
    """quantize_ filter: select only the AWQObservedLinear modules (those that
    collected calibration statistics above)."""
    return isinstance(m, AWQObservedLinear)


use_hqq = False
awq_uintx_config = awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq)
quantize_(model, awq_uintx_config, is_observed_linear)
model = model.eval().to(device)
# Report the batch size actually used (the old banner said "bs>1" even at bs=1).
print(f"== Step 6: inference (bs={infer_bs}, no static cache) ==")
prompts = ["Hey, are you conscious? Can you talk to me?"] * infer_bs
# Decoder-only models must be left-padded for batched generation; with right
# padding, shorter prompts would continue generating after their pad tokens.
tokenizer.padding_side = "left"
inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
gen_kwargs = dict(do_sample=True, temperature=0.9, num_beams=1)
with torch.inference_mode():
    # gen_kwargs was previously built but never passed, so these sampling
    # settings were silently ignored.
    out_ids = model.generate(
        **inputs,
        max_new_tokens=gen_tokens,
        pad_token_id=tokenizer.pad_token_id,
        **gen_kwargs,
    )
outs = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
print("== Generations ==")
for i, o in enumerate(outs):
    # Show at most 200 characters per generation on a single line.
    print(f"[{i}]:", o[:200].replace("\n", " "))
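Log (note: the OOM is raised inside `from_pretrained`, while the unquantized bfloat16 checkpoint is being loaded onto the GPU — before any AWQ quantization or inference runs):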
Traceback (most recent call last):
File "/mnt/ssd1/mint/run.py", line 87, in <module>
model = model_class[0].from_pretrained(
File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 571, in from_pretrained
return model_class.from_pretrained(
File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/modeling_utils.py", line 279, in _wrapper
return func(*args, **kwargs)
File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4399, in from_pretrained
) = cls._load_pretrained_model(
File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4833, in _load_pretrained_model
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/modeling_utils.py", line 824, in _load_state_dict_into_meta_model
_load_parameter_into_model(model, param_name, param.to(param_device))
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1002.00 MiB. GPU 0 has a total capacity of 15.56 GiB of which 771.06 MiB is free. Including non-PyTorch memory, this process has 14.62 GiB memory in use. Of the allocated memory 13.87 GiB is allocated by PyTorch, and 527.52 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Metadata
Metadata
Assignees
Labels
No labels