In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset


In [None]:
# Model identifier handed to the tokenizer / model `from_pretrained` calls below.
pretrained_model_dir = 'Qwen/Qwen-14B'
# Base name of the output directory for the quantized checkpoint.
quantized_model_dir = 'Qwen-14B-8bit'

In [None]:
def get_wikitext2(tokenizer, num_example=120, seqlen=4096, seed=0):
    """Sample fixed-length calibration windows from the WikiText-2 raw test split.

    Every row of the split is concatenated into one string (empty rows are
    replaced by ``' \\n'``), the string is tokenized in a single call, and
    ``num_example`` windows of ``seqlen`` consecutive tokens are cut out at
    random start offsets.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``.
            Its result must expose ``input_ids`` with the token axis at
            ``shape[1]`` (it is sliced as ``input_ids[:, i:j]``).
        num_example: Number of windows to draw. Defaults to 120.
        seqlen: Number of tokens per window. Defaults to 4096.
        seed: Value used to seed the global ``random``, ``numpy`` and
            ``torch`` generators before sampling. Defaults to 0.

    Returns:
        A list of ``num_example`` dicts, each with the keys ``'input_ids'``
        (the token window) and ``'attention_mask'`` (ones of the same shape).

    Raises:
        ValueError: If ``seqlen`` is not positive, or the tokenized corpus
            does not contain more than ``seqlen`` tokens.
    """
    import random

    import numpy as np
    import torch

    if seqlen < 1:
        raise ValueError(f"seqlen must be a positive integer, got {seqlen}")

    wikidata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    text = ''.join(' \n' if s == '' else s for s in wikidata['text'])
    input_ids = tokenizer(text, return_tensors='pt').input_ids

    # Same upper bound as the original sampling (total - seqlen - 1), so the
    # default call keeps producing exactly the same windows.
    total_tokens = input_ids.shape[1]
    max_start = total_tokens - seqlen - 1
    if max_start < 0:
        raise ValueError(
            f"tokenized corpus has {total_tokens} tokens; "
            f"more than seqlen={seqlen} are required to sample a window"
        )

    # Seed all three global generators so repeated runs give identical samples.
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    traindataset = []
    for _ in range(num_example):
        start = random.randint(0, max_start)
        window = input_ids[:, start:start + seqlen]
        traindataset.append({
            'input_ids': window,
            'attention_mask': torch.ones_like(window),
        })
    return traindataset

In [None]:
# Load the tokenizer for the base checkpoint, then build the calibration
# samples that the quantization step consumes.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_dir,
    trust_remote_code=True,
    use_fast=True,
)
examples = get_wikitext2(tokenizer)

In [None]:
# GPTQ settings used for this run.
quantize_config = BaseQuantizeConfig(
    # Quantize the model to 8-bit.
    bits=8,
    # 128 is the recommended group size.
    group_size=128,
    # False can significantly speed up inference, at the cost of slightly
    # worse perplexity.
    desc_act=False,
)

In [None]:
# Load the un-quantized model; by default it is always loaded into CPU memory.
# To place it across devices instead, pass e.g. device_map="auto" and
# max_memory={0: "22GIB", 1: "22GIB"} (both intentionally left disabled here).
model = AutoGPTQForCausalLM.from_pretrained(
    pretrained_model_dir,
    quantize_config=quantize_config,
    trust_remote_code=True,
)

In [None]:
# Run quantization. `examples` must be a list of dicts whose only keys are
# "input_ids" and "attention_mask".
model.quantize(examples=examples)

In [None]:
# Write the quantized model to "<quantized_model_dir>-hf".
model.save_pretrained(f"{quantized_model_dir}-hf")