In [None]:
import os

from torch import cuda, bfloat16
from accelerate import infer_auto_device_map, init_empty_weights
import transformers

# Hub checkpoint to load. The commented-out Llama-2 chat checkpoint is an
# alternative; it may require an access token (see `hf_auth` below).
# model_id = 'meta-llama/Llama-2-7b-chat-hf'
model_id = 'openlm-research/open_llama_3b_v2'

# Default to the CPU and switch to the active CUDA device when one is visible.
device = 'cpu'
if cuda.is_available():
    device = f'cuda:{cuda.current_device()}'
device

In [None]:
# Optional 4-bit quantization config for loading a larger model with less GPU memory.
# This requires the `bitsandbytes` library. To enable it, uncomment this block AND the
# `quantization_config=bnb_config` argument in the `from_pretrained` call below.
# `bfloat16` (imported in the first cell) is used only here, as the compute dtype.
# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=bfloat16
# )

In [None]:
# Initialize the HF items. Gated checkpoints need an access token.
# The token is read from the environment instead of being typed into the notebook,
# so it can never be committed along with it. An unset or empty variable becomes
# `None`, which lets the Hub client fall back to a cached `huggingface-cli login`
# token, if there is one.
hf_auth = os.environ.get('HF_TOKEN') or None

model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    use_auth_token=hf_auth
)

# Build the model skeleton without allocating parameter memory. It is used only to
# plan device placement in the next cell and is replaced by the real model later.
with init_empty_weights():
    model = transformers.AutoModelForCausalLM.from_config(model_config)
model

In [None]:
# Plan where each module of the weightless model should live, keeping every
# decoder layer whole on a single device.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=["LlamaDecoderLayer"],
)
device_map

In [None]:
# Load the real weights, placing modules with the `device_map` computed in the
# previous cell. This makes the map displayed there the one actually used, rather
# than a second, independent 'auto' placement.
# `trust_remote_code` is intentionally omitted: `model_config` was loaded without it,
# so the model resolves to a built-in transformers class and no repository code
# needs to run.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    config=model_config,
    # quantization_config=bnb_config,  # uncomment together with `bnb_config` above
    device_map=device_map,
    use_auth_token=hf_auth
)
model.eval()  # inference mode (affects layers such as dropout)

# Report where the weights actually ended up. With a device map the model may be
# split across several devices (or the CPU), so the single `device` string computed
# earlier does not describe the placement.
placement = sorted({str(dev) for dev in model.hf_device_map.values()})
print(f"Model loaded on {', '.join(placement)}")

In [None]:
# NOTE(review): the slow `LlamaTokenizer` is used explicitly rather than
# `AutoTokenizer`. This is presumably deliberate, since open_llama's model card
# advises against its auto-converted fast tokenizer. Confirm before switching.
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_id, use_auth_token=hf_auth)

In [None]:
# Wrap the model and tokenizer in a text-generation pipeline. The generation
# settings passed here (sampling mode, token budget, repetition penalty) apply to
# every call.
generate_text = transformers.pipeline(
    model=model, tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text (prompt + completion)
    task='text-generation',
    # Greedy decoding, which is what `temperature=0.0` was meant to express.
    # Temperature only takes effect when sampling and must then be strictly positive
    # (it is not capped at 1.0), so disable sampling explicitly instead. This also
    # stays deterministic for checkpoints whose generation config enables sampling.
    do_sample=False,
    max_new_tokens=512,  # max number of tokens to generate in the output
    repetition_penalty=1.1  # without this output begins repeating
)

In [None]:
# Quick sanity check of the pipeline on a single prompt.
res = generate_text("Explain to me the difference between nuclear fission and fusion.")
answer = res[0]["generated_text"]
print(answer)

In [None]:
"""Remember to delete access token before committing.""";