Skip to content

Flux1 dev model quantization and compilation encountered errors during the second execution #48

@one-over

Description

@one-over
import os
import torch
import time
from diffusers import FluxPipeline
from torchao.quantization import autoquant

# Path of the base FLUX.1-dev checkpoint handed to FluxPipeline.from_pretrained on the first run.
model_id = "../newmodel/FLUX.1-dev"
# Directory the pipeline is saved to on the first run; its existence selects the reload branch in main().
SAVED_MODEL_PATH = "./flux_quantized"
# Passed as torch_dtype to from_pretrained in both load paths.
dtype = torch.bfloat16
# Device the pipeline is moved to after loading.
device = "cuda"
quant_type = "int8wo"  # Selected quantization type. NOTE(review): not referenced anywhere else in the script as shown.

def main() -> None:
    """Build or reload the FLUX pipeline, warm it up, and time one inference.

    First run (SAVED_MODEL_PATH does not exist): load the pipeline from
    ``model_id``, switch the transformer to channels_last, wrap it with
    ``torch.compile`` and then ``autoquant``, and save the whole pipeline to
    SAVED_MODEL_PATH with ``safe_serialization=False``.

    Later runs (SAVED_MODEL_PATH exists): load the pipeline straight from
    SAVED_MODEL_PATH. This branch does not call ``torch.compile`` or
    ``autoquant`` again.

    Both paths then run three warm-up generations, one timed generation,
    print the elapsed time and ``torch.cuda.max_memory_reserved()``, and
    write the image to ``output.png``.
    """
    # Check whether a previously saved quantized model exists
    if not os.path.exists(SAVED_MODEL_PATH):
        print("🚀 正在量化并编译模型(首次运行较慢)...")
        start_load = time.time()
        
        # Load the model with bfloat16 weights and move it to the target device
        pipeline = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)

        # Convert the transformer to the channels_last memory format
        pipeline.transformer.to(memory_format=torch.channels_last)

        # Compile the model
        pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

        # Quantize the model
        pipeline.transformer = autoquant(pipeline.transformer, error_on_unseen=False)
        # Save the quantized model
        # NOTE(review): the save happens before the transformer has processed any
        # input. Presumably autoquant only picks and applies its quantized kernels
        # after a calibration forward pass, so the checkpoint written here may not
        # hold finalized weights -- confirm against the torchao autoquant docs.
        pipeline.save_pretrained(SAVED_MODEL_PATH, safe_serialization=False)
        print(f"⏱️ 模型加载、编译与量化耗时: {time.time() - start_load:.2f}秒")

    else:
        print("🔍 发现已保存的量化模型,直接加载...")
        start_load = time.time()
        # NOTE(review): use_safetensors is not passed here although the pipeline was
        # saved with safe_serialization=False; this is the call the reported
        # traceback points at.
        pipeline = FluxPipeline.from_pretrained(SAVED_MODEL_PATH, torch_dtype=dtype).to(device)
        print(f"⏱️ 模型加载耗时: {time.time() - start_load:.2f}秒")

    # Warm-up runs (trigger compilation)
    # Only the first-run branch wraps the transformer with torch.compile; the
    # reload branch above does not. These calls sit outside the torch.no_grad()
    # block that is used for the timed run below.
    print("\n🔥 预热运行...")
    for _ in range(3):
        _ = pipeline("a forest", num_inference_steps=30, guidance_scale=3.5)

    # Timed inference test (20 steps, versus 30 steps per warm-up run)
    print("\n⏳ 开始推理测试...")
    start_inference = time.time()
    with torch.no_grad():
        image = pipeline(
            "A robot made of exotic candies, chocolates, and sugar decorations, with joints made of candy jelly, adorned with golden sequins and sugar beads. The background is filled with confetti and colorful ribbons, scattered gifts surround the scene, and the air is filled with a sweet aroma, like a dreamlike celebration.",
            num_inference_steps=20,
            guidance_scale=3.5,
        ).images[0]
    elapsed = time.time() - start_inference

    print(f"\n✅ 推理完成!总耗时: {elapsed:.2f}秒")
    # max_memory_reserved() is divided by 1024**3 before being printed with a GB label
    print(f"💻 GPU内存使用: {torch.cuda.max_memory_reserved()/1024**3:.2f}GB")
    image.save("output.png")

# Run the benchmark only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

On the second execution of this code, the error is as follows:

(myenv) root@iZj6c65344j8baeqhb5j0tZ:~/data2/dev-test# python3 kk.py
🔍 发现已保存的量化模型,直接加载...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 9.60it/s]
Loading pipeline components...: 14%|█████████████████████ | 1/7 [00:00<00:01, 3.90it/s]An error occurred while trying to fetch ./flux_quantized/transformer: Error no file named diffusion_pytorch_model.safetensors found in directory ./flux_quantized/transformer.
Defaulting to unsafe serialization. Pass allow_pickle=False to raise an error instead.
Loading pipeline components...: 29%|██████████████████████████████████████████ | 2/7 [00:20<00:50, 10.03s/it]
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/model_loading_utils.py", line 150, in load_state_dict
return torch.load(
File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 1432, in load
with _open_zipfile_reader(opened_file) as opened_zipfile:
File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 763, in __init__
super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
RuntimeError: PytorchStreamReader failed reading zip archive: invalid header or archive is corrupted

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/model_loading_utils.py", line 158, in load_state_dict
if f.read().startswith("version"):
File "/usr/lib/python3.10/codecs.py", line 322, in decode
(result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 128: invalid start byte

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/data2/dev-test/kk.py", line 62, in <module>
main()
File "/root/data2/dev-test/kk.py", line 38, in main
pipeline = FluxPipeline.from_pretrained(SAVED_MODEL_PATH, torch_dtype=dtype).to(device)
File "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_utils.py", line 924, in from_pretrained
loaded_sub_model = load_sub_model(
File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_loading_utils.py", line 725, in load_sub_model
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
File "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_utils.py", line 886, in from_pretrained
state_dict = load_state_dict(model_file, variant=variant)
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/model_loading_utils.py", line 170, in load_state_dict
raise OSError(
OSError: Unable to load weights from checkpoint file for './flux_quantized/transformer/diffusion_pytorch_model.bin' at './flux_quantized/transformer/diffusion_pytorch_model.bin'.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions