In [22]:
import torch
from datasets import load_dataset
import peft
import transformers as tf
from trl import SFTConfig, SFTTrainer

In [23]:
# Release cached, unused GPU memory, then report which CUDA device is in use.
torch.cuda.empty_cache()
current_device_idx = torch.cuda.current_device()
print(
    "Available devices:", torch.cuda.device_count(),
    "Current device:", current_device_idx,
    "Device name:", torch.cuda.get_device_name(current_device_idx),
)

Available devices: 1 Current device: 0 Device name: NVIDIA GeForce RTX 4090


In [24]:
# Hugging Face Hub id of the base model (used for both the model and the tokenizer).
repo_id = 'microsoft/Phi-3-mini-4k-instruct'

# Load weights in 4-bit NF4 with double quantization enabled; compute runs in float32.
# NOTE(review): bfloat16 is the more common (and faster) compute dtype on recent GPUs;
# float32 is kept here — confirm it matches the setting the adapter was trained with.
bnb_config = tf.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float32,
)

In [25]:
# Load the 4-bit quantized base model onto the first GPU, then wrap it with the
# PEFT adapter saved in the local "local-phi3-mini-yoda-adapter" directory.
base_model = tf.AutoModelForCausalLM.from_pretrained(
    repo_id,
    quantization_config=bnb_config,
    device_map="cuda:0",
)
model = peft.PeftModel.from_pretrained(base_model, "local-phi3-mini-yoda-adapter")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [26]:
# Parameter count (billions) and memory footprint (GB = 1e9 bytes) of the adapted model.
params_billions = model.num_parameters() / 1e9
footprint_gb = model.get_memory_footprint() / 1e9
print("Model size (in billions of parameters):", params_billions)
print("Memory footprint (in GB):", footprint_gb)

Model size (in billions of parameters): 3.833662464
Memory footprint (in GB): 2.25667296


In [27]:
tokenizer = tf.AutoTokenizer.from_pretrained(repo_id)

# Use the <unk> token for padding instead of the tokenizer's default pad token;
# both the token string and its id are set explicitly.
unk_token, unk_id = tokenizer.unk_token, tokenizer.unk_token_id
tokenizer.pad_token = unk_token
tokenizer.pad_token_id = unk_id

In [28]:
def gen_prompt(tokenizer, sentence):
    """Render ``sentence`` as a single user turn using the tokenizer's chat template.

    Args:
        tokenizer: tokenizer providing ``apply_chat_template``.
        sentence: the user message text.

    Returns:
        The templated prompt as a string (not tokenized), ending with the
        generation prompt so the model continues as the assistant.
    """
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": sentence}],
        add_generation_prompt=True,
        tokenize=False,
    )

In [29]:
def generate(model, tokenizer, prompt, max_new_tokens=64, skip_special_tokens=False,
             eos_token_id=None, **generate_kwargs):
    """Generate a completion for an already-templated prompt and decode it.

    Args:
        model: causal LM (or PEFT-wrapped model) exposing ``device``, ``eval`` and ``generate``.
        tokenizer: tokenizer used both to encode ``prompt`` and to decode the output.
        prompt: prompt text, e.g. from ``gen_prompt``. It already contains the chat
            template's special tokens, so no extra special tokens are added here.
        max_new_tokens: upper bound on the number of generated tokens.
        skip_special_tokens: whether to drop special tokens when decoding.
        eos_token_id: stop token id (or list of ids). Defaults to
            ``tokenizer.eos_token_id``, which was the previously hard-coded value.
            Pass a list to also stop on e.g. a chat-turn terminator.
        **generate_kwargs: extra options forwarded unchanged to ``model.generate``
            (e.g. ``do_sample``, ``temperature``).

    Returns:
        The decoded text of the first sequence, including the prompt tokens.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # The chat template already inserted its special tokens; don't add BOS/etc. again.
    tokenized_input = tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)

    model.eval()
    gen_output = model.generate(
        **tokenized_input,
        eos_token_id=eos_token_id,
        max_new_tokens=max_new_tokens,
        **generate_kwargs,
    )

    output = tokenizer.batch_decode(gen_output, skip_special_tokens=skip_special_tokens)
    return output[0]

In [30]:
# Translate a sentence into Yoda-speak with the adapted model.
sentence = 'The Force is strong in you!'
prompt = gen_prompt(tokenizer, sentence)
print(generate(model, tokenizer, prompt))

<|user|> The Force is strong in you!<|end|><|assistant|> Strong in you, the Force is! Yes, hrrrm.<|end|><|endoftext|>


In [31]:
# A technical sentence, to see how the adapter handles non-Star-Wars content.
sentence = 'ArithmeticError is a built-in exception in Python.'
prompt = gen_prompt(tokenizer, sentence)
print(generate(model, tokenizer, prompt))

<|user|> ArithmeticError is a built-in exception in Python.<|end|><|assistant|> A built-in exception in Python, ArithmeticError is.<|end|><|endoftext|>


In [32]:
# A short everyday sentence.
sentence = 'What a beautiful day.'
prompt = gen_prompt(tokenizer, sentence)
print(generate(model, tokenizer, prompt))

<|user|> What a beautiful day.<|end|><|assistant|> A beautiful day, yes, hrrrm.<|end|><|endoftext|>
