In [None]:
from enum import Enum

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

class ChatmlSpecialTokens(str, Enum):
    """Special-token strings for a ChatML-style prompt format.

    Members are ``str`` subclasses, so each one compares equal to its literal
    text. Requires ``Enum`` from the standard-library ``enum`` module to be
    imported in the notebook's import block.
    """

    # Role headers that open a turn.
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    # Sequence / padding markers.
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        """Return the string value of every member, in definition order."""
        return [member.value for member in cls]

from peft import PeftModel


In [None]:
# Replace the model name and checkpoint path to evaluate a different base
# model (e.g. Llama2-7b) together with its fine-tuned adapter.

# The hub id lives under its own name. Previously `model` held the id string and
# was then rebound to the loaded model, so the tokenizer loader below received a
# model object instead of a model id.
model_name = "meta-llama/Meta-Llama-3-8B"
checkpoint_path = "llama3_8b_checkpoint"

# Special-token set used to configure the tokenizer.
special_tokens = ChatmlSpecialTokens

# ChatML chat template assigned to the tokenizer below.
# NOTE(review): assumed to match the template used when the adapter was
# fine-tuned -- confirm against the training code.
DEFAULT_CHATML_CHAT_TEMPLATE = (
    "{% for message in messages %}\n"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}"
    "{% endfor %}"
)

# 4-bit NF4 quantization with double quantization; compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage="uint8",
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    attn_implementation="eager",
    torch_dtype="auto",
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    pad_token=special_tokens.pad_token.value,
    bos_token=special_tokens.bos_token.value,
    eos_token=special_tokens.eos_token.value,
    additional_special_tokens=special_tokens.list(),
    trust_remote_code=True,
)

tokenizer.chat_template = DEFAULT_CHATML_CHAT_TEMPLATE

# Resize the embedding matrix to the tokenizer's length (rounded up to a
# multiple of 8) so the special tokens registered above have embedding rows.
# TODO: make embedding resizing configurable?
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

# Attach the fine-tuned adapter weights for inference only.
finetuned_model = PeftModel.from_pretrained(
    model,
    checkpoint_path,
    torch_dtype="auto",
    is_trainable=False,
    device_map="auto",
)


In [None]:
# Hand-built ChatML prompt: one user turn followed by an opened turn header.
# NOTE(review): the prompt ends on a `system` header rather than `assistant`;
# confirm this matches the role the adapter was trained to answer under.
sample_input = "<|im_start|>user \nTo the best of your knowledge, given a smiles string and a question, pick the correct option between 1-5. Answer with single integer only. SMILES string: CCC(C)C(NC(=O)C1CCCN1C(=O)C(CCC(=O)O)NC(=O)C(NC(=O)C1CCCN1C(=O)C(N)Cc1ccc(O)cc1)C(C)C)C(=O)O, Question: What type of protein is the molecule a fragment of?, Option: ['casein', 'whey', 'lactalbumin', 'lactoferrin', 'albumin'].<|im_end|> \n<|im_start|>system \n"

# Move the encoded prompt onto the model's device; the tokenizer returns CPU
# tensors, which would otherwise mismatch a model placed on an accelerator.
inputs = tokenizer(sample_input, return_tensors="pt").to(finetuned_model.device)

# Pass tensors by keyword and include the attention mask so generation does not
# have to infer it from the pad token id.
output = finetuned_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=500,
    num_return_sequences=1,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
)

In [None]:
# Show the first returned sequence as text, keeping special tokens visible.
tokenizer.decode(
    token_ids=output[0],
    skip_special_tokens=False,
)