Skip to content

CUDA OOM When Running AWQ int4 Quantized llama3.1-8b at Batch Size 1 #2867

@MingxuZh

Description

@MingxuZh

CUDA OOM when running AWQ int4-quantized Llama-3.1-8B at batch size 1.

Repro on an RTX 4070 Ti Super (the log below reports 15.56 GiB total GPU memory):

# Create and activate a fresh Python 3.10 conda environment named "test".
conda create -yn test python=3.10
conda activate test
# Install pinned pre-release (nightly) builds of torch and torchao from the
# PyTorch cu128 nightly wheel index.
pip3 install --pre torch==2.9.0.dev20250818+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128  
pip3 install --pre torchao==0.13.0.dev20250729+cu128 --index-url https://download.pytorch.org/whl/nightly/cu128
# Install the Hugging Face packages the script imports.
pip3 install transformers datasets
# Run the reproduction script listed below.
python test.py

script:

import os
import sys
import logging

# Detach every handler already attached to the root logger (announcing each
# one on stdout) so that basicConfig() below actually installs our handler.
for stale_handler in list(logging.root.handlers):
    print("Before setup, remove log root handler: ", stale_handler)
    logging.root.removeHandler(stale_handler)

# The log level is taken from the LOG_LEVEL environment variable
# (case-insensitive) and defaults to INFO.
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()

# Route all log records to stdout with a timestamped format.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    stream=sys.stdout,
    level=LOG_LEVEL,
)

# The script logs through the root logger.
logger = logging.getLogger()

import torch

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
)

# On Windows, switch the console to code page 65001 (UTF-8) and make stdout
# encode as UTF-8 so non-ASCII generated text can be printed.
if os.name == 'nt':
    os.system('chcp 65001')
    sys.stdout.reconfigure(encoding='utf-8')

# ---- Run configuration ----------------------------------------------------
# Hugging Face model id to load.
model_id = "meta-llama/Llama-3.1-8B-Instruct"
# Device used for loading, calibration and generation.
device = torch.device("cuda")
# NOTE(review): dtype, seq_len, calib_batches and quant_mode are not
# referenced in the visible part of the script; kept in case other code
# reads them.
dtype = torch.float16
# Quantization group size passed to the AWQ observer and config.
group_size = 128
seq_len = 32
calib_batches = 4
# Number of identical prompts generated in one batch.
infer_bs = 1
# Value passed as max_new_tokens to generate().
gen_tokens = 32
quant_mode = "woq"
# Must be a real boolean: it is forwarded as ``trust_remote_code``. The
# previous value was the string "False", which is truthy and therefore
# silently *enabled* remote code execution.
use_hf_code = False

def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512, dataset=None):
    """Build fixed-size token blocks for calibration.

    Documents are stripped, tokenized and concatenated along the sequence
    dimension until at least ``n_samples * block_size`` tokens have been
    collected; the concatenation is then cut into ``n_samples`` consecutive
    slices of ``block_size`` tokens.

    Args:
        tokenizer: Object exposing ``encode(text)`` that returns a list of
            token ids. Required despite the ``None`` default.
        n_samples: Number of blocks to return.
        block_size: Number of tokens per block.
        dataset: Optional iterable of records with a ``"text"`` field. When
            omitted, the ``validation`` split of
            ``mit-han-lab/pile-val-backup`` is loaded via ``datasets``.

    Returns:
        A list of ``n_samples`` tensors sliced from a ``(1, total_tokens)``
        tensor. Blocks are ``(1, block_size)`` when the corpus holds enough
        tokens; if it runs out early, trailing blocks are shorter or empty.

    Raises:
        ValueError: If ``tokenizer`` is ``None``.
        RuntimeError: If no document yields a usable token sequence.
    """
    if tokenizer is None:
        # Fail before the (potentially slow) dataset download.
        raise ValueError("get_calib_dataset() requires a tokenizer")

    if dataset is None:
        from datasets import load_dataset
        dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")

    # Documents longer than this many tokens are skipped entirely.
    max_line_tokens = 512

    samples = []
    tokens_needed = n_samples * block_size
    for data in dataset:
        line_encoded = tokenizer.encode(data["text"].strip())
        if not line_encoded or len(line_encoded) > max_line_tokens:
            continue
        samples.append(torch.tensor([line_encoded]))
        tokens_needed -= len(line_encoded)
        # Stop once the token budget is met. The previous check compared the
        # remaining *token* budget against the *block* count, which could
        # stop early and leave the final block short.
        if tokens_needed <= 0:
            break

    if not samples:
        raise RuntimeError(
            "get_calib_dataset(): no usable calibration text found in dataset"
        )

    cat_samples = torch.cat(samples, dim=1)
    return [
        cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)
    ]

# Checkpoint weights are loaded in bfloat16.
amp_enabled = True
amp_dtype = torch.bfloat16
load_dtype = amp_dtype

# Model-family keyword -> (model class, tokenizer class).
MODEL_CLASSES = {
    "auto": (AutoModelForCausalLM, AutoTokenizer),
    "llama": (AutoModelForCausalLM, AutoTokenizer),
}

# Pick the first registry key that occurs in the lower-cased model id,
# falling back to "auto" when none matches.
model_type = next(
    filter(lambda family: family in model_id.lower(), MODEL_CLASSES), "auto"
)
model_class = MODEL_CLASSES[model_type]
config = AutoConfig.from_pretrained(model_id, trust_remote_code=use_hf_code)
# No load-time quantization; AWQ is applied after loading.
quantization_config = None

# NOTE(review): device_map=device places the bf16 weights directly on the
# target device; the OOM traceback quoted in this report is raised from
# inside this from_pretrained() call.
load_kwargs = dict(
    config=config,
    torch_dtype=load_dtype,
    device_map=device,
    low_cpu_mem_usage=True,
    trust_remote_code=use_hf_code,
    quantization_config=quantization_config,
)
model = model_class[0].from_pretrained(model_id, **load_kwargs)

tokenizer = model_class[1].from_pretrained(model_id, trust_remote_code=use_hf_code)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
from torchao.quantization import quantize_
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_

model.eval().to(device)
quant_dtype = torch.uint4
group_size = 128
logger.info(f"running {quant_dtype} calibration")

# Wrap the model's linear layers with AWQ observers, then feed calibration
# blocks through the model so the observers can record statistics.
insert_awq_observer_(model, 1, 512, quant_dtype=quant_dtype, group_size=group_size)
calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=10, block_size=512)

# Calibration only needs forward passes. Disabling autograd avoids keeping
# activations for a backward pass that never happens, which lowers peak
# memory on an already tight GPU.
with torch.no_grad():
    for batch in calibration_data:
        if batch.numel() == 0:
            # get_calib_dataset may yield empty trailing blocks.
            continue
        # Tensor.to() returns a copy; the CPU-side `batch` is left untouched,
        # so there is nothing to move back afterwards (the former
        # `batch.to("cpu")` discarded its result and did nothing).
        model(batch.to(device))


def is_observed_linear(m, fqn):
    """Filter for quantize_: select only modules wrapped by the AWQ observer."""
    return isinstance(m, AWQObservedLinear)


use_hqq = False
awq_uintx_config = awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq)

# Replace every observed linear layer with its AWQ-quantized counterpart.
quantize_(model, awq_uintx_config, is_observed_linear)


model = model.eval().to(device)

# Report the batch size actually in use; the banner previously hard-coded
# "bs>1" even though infer_bs is configurable (and set to 1 in this script).
print(f"== Step 6: inference (bs={infer_bs}, no static cache) ==")

# Same prompt repeated infer_bs times, tokenized as one padded batch.
prompts = ["Hey, are you conscious? Can you talk to me?"] * infer_bs
inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)

# Sampling settings. These were previously built but never handed to
# generate(), so they had no effect.
gen_kwargs = dict(do_sample=True, temperature=0.9, num_beams=1)
with torch.inference_mode():
    out_ids = model.generate(**inputs, max_new_tokens=gen_tokens, **gen_kwargs)

outs = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
print("== Generations ==")
for i, o in enumerate(outs):
    # Show at most 200 characters per sample, flattened onto one line.
    print(f"[{i}]:", o[:200].replace("\n", " "))

log:

Traceback (most recent call last):
  File "/mnt/ssd1/mint/run.py", line 87, in <module>
    model = model_class[0].from_pretrained(
  File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 571, in from_pretrained
    return model_class.from_pretrained(
  File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/modeling_utils.py", line 279, in _wrapper
    return func(*args, **kwargs)
  File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4399, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4833, in _load_pretrained_model
    disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
  File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniforge3/envs/2025ww30_torchao/lib/python3.10/site-packages/transformers/modeling_utils.py", line 824, in _load_state_dict_into_meta_model
    _load_parameter_into_model(model, param_name, param.to(param_device))
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1002.00 MiB. GPU 0 has a total capacity of 15.56 GiB of which 771.06 MiB is free. Including non-PyTorch memory, this process has 14.62 GiB memory in use. Of the allocated memory 13.87 GiB is allocated by PyTorch, and 527.52 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions