In [2]:
import os

# Pin this process to GPUs 0-3 and redirect the Hugging Face hub cache.
# These are assigned before torch/transformers are imported further down,
# presumably so both libraries see them when they initialize.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1,2,3',
    'HF_HUB_CACHE': '/next_share/hf_cache/hub',
})

import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoModelForSeq2SeqLM, 
    AutoModelForSequenceClassification, AutoConfig, AutoModel, BitsAndBytesConfig
)

from peft import get_peft_model, LoraConfig

In [5]:
def build_gen_model(
    model_name,
    lora=False,
    dtype=torch.bfloat16,
    device_map=None,
    quantization=False,
    lora_kwargs=None,
):
    """Load a generation model, optionally 4-bit quantized and/or wrapped with LoRA.

    The auto class is picked from the checkpoint config: encoder-decoder
    models are loaded with ``AutoModelForSeq2SeqLM``, everything else with
    ``AutoModelForCausalLM``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        lora: If True, wrap the model with a LoRA adapter and print the
            trainable-parameter summary.
        dtype: ``torch_dtype`` used to load the model. When it is a
            ``torch.dtype`` it is also used as the 4-bit compute dtype;
            otherwise (e.g. ``"auto"``) the compute dtype falls back to
            ``torch.bfloat16``.
        device_map: Forwarded to ``from_pretrained``; ``None`` means device 0.
        quantization: If True, load weights in 4-bit NF4 with double
            quantization via bitsandbytes (QLoRA-style).
        lora_kwargs: Optional dict overriding the ``LoraConfig`` defaults
            (``r=16, lora_alpha=16, target_modules='all-linear',
            lora_dropout=0.1, bias='none'`` and the auto-detected
            ``task_type``). Ignored when ``lora`` is False.

    Returns:
        The loaded model, or the ``get_peft_model`` wrapper when ``lora`` is True.
    """
    # Use the same trust_remote_code setting as the model load below;
    # otherwise checkpoints with custom architectures fail at the config step.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    is_seq2seq = getattr(config, 'is_encoder_decoder', False)
    mod_cls = AutoModelForSeq2SeqLM if is_seq2seq else AutoModelForCausalLM

    if device_map is None:
        device_map = 0  # default to the first visible GPU

    quant_config = None
    if quantization:
        # Keep the 4-bit compute dtype in line with the requested load dtype
        # (previously always bf16, regardless of `dtype`).
        compute_dtype = dtype if isinstance(dtype, torch.dtype) else torch.bfloat16
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )

    model = mod_cls.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device_map,
        quantization_config=quant_config,
    )

    if lora:
        lora_settings = dict(
            r=16,
            lora_alpha=16,
            target_modules='all-linear',
            lora_dropout=0.1,
            bias="none",
            task_type="SEQ_2_SEQ_LM" if is_seq2seq else "CAUSAL_LM",
        )
        lora_settings.update(lora_kwargs or {})
        # NOTE(review): for QLoRA training, peft's prepare_model_for_kbit_training
        # is commonly applied before get_peft_model; it is not done here.
        model = get_peft_model(model, LoraConfig(**lora_settings))
        model.print_trainable_parameters()
    return model

True
