In [2]:
import os

import pandas as pd
import torch
from huggingface_hub import notebook_login
from peft import PeftModel
from transformers import (
    BitsAndBytesConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
)
# Hub identifier of the base checkpoint used throughout the notebook.
MODEL_NAME = "openlm-research/open_llama_3b_v2"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"


In [3]:
# bitsandbytes settings handed to `from_pretrained` in the next cell.
# Only the fp32 CPU-offload flag is set; neither `load_in_8bit` nor
# `load_in_4bit` is passed to this config.
quantization_config = BitsAndBytesConfig(
    llm_int8_enable_fp32_cpu_offload=True,
)

In [4]:
# Load the base LLaMA checkpoint in float16.
#
# The Hugging Face access token is read from the HF_TOKEN environment
# variable instead of being embedded in the notebook. `os.environ.get`
# yields None when the variable is unset, in which case no explicit token
# is passed. SECURITY: a token was previously hardcoded in this cell; it
# should be treated as leaked and revoked.
#
# `device_map` follows the DEVICE constant from the configuration cell so
# the model lands on the same device the input tensors are moved to, and
# the cell no longer hard-requires CUDA.
model = LlamaForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
    token=os.environ.get("HF_TOKEN"),
    device_map=DEVICE,
)



In [5]:
# Load the tokenizer that matches MODEL_NAME.
# The access token comes from the HF_TOKEN environment variable (None when
# unset) rather than a literal in the notebook. SECURITY: the token that was
# previously hardcoded here should be treated as leaked and revoked.
tokenizer = LlamaTokenizer.from_pretrained(
    MODEL_NAME,
    token=os.environ.get("HF_TOKEN"),
)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [6]:
# Wrap the base model with the PEFT adapter stored in the local
# "openlm-research-open_llama_3b_v2/" directory. The name `model` is rebound
# to the wrapped model, which is what the generation cell uses.
model = PeftModel.from_pretrained(
    model=model,
    model_id="openlm-research-open_llama_3b_v2/",
    torch_dtype=torch.float16,
)

In [7]:
# Read the held-out evaluation split from its Feather file.
test = pd.read_feather(
    path="datasets/movie_datasets/imdb/test_llm_ds_v1.feather",
)

In [8]:
# Preview the first five rows of the evaluation frame.
test.head(n=5)

Unnamed: 0,tconst,originalTitle,data,question,answer,prompt
0,tt16252240,The Pitch,Description: A tech geek and a gorgeous secret...,What is the secretary's name?,Unknown,Below is a question regarding movies and shows...
1,tt11172868,Unbreakable,Description: Mariel and Deena have been best f...,What is the movie's genre?,"Comedy, Drama, Romance",Below is a question regarding movies and shows...
2,tt12448312,Posts to the Pope,Description: RTE News asked a range of people ...,What is the name of the person who is committe...,Unanswerable,Below is a question regarding movies and shows...
3,tt11229886,Les Misérables: The Staged Concert,Description: Seen by over 120 million people w...,"Where can you watch ""Les Misérables: The Stage...",cinemas,Below is a question regarding movies and shows...
4,tt11994944,"Plymouth, Michigan - A Rich History","Description: Founded in 1825, the Plymouth com...","What fires are mentioned in ""Plymouth, Michiga...","The Great Fire of 1871, the Plymouth Train Sta...",Below is a question regarding movies and shows...


In [9]:
# Pick one evaluation example by position and show its prompt followed by
# the reference answer.
idx = 0
prompt = test.iloc[idx]["prompt"]
answer = test.iloc[idx]["answer"]
print(prompt, answer, sep="\n")

Below is a question regarding movies and shows paired with an input that provides further context. Write a response that appropriately completes the request.
###Instruction: What is the secretary's name?
###Input: Description: A tech geek and a gorgeous secretary meet during a pitch. They forge a risky partnership without guarantees. 
Release Year: 2019 
Genre: Drama,Romance
###Response:
Unknown


In [10]:
# Tokenize the prompt into PyTorch tensors, then move the token ids onto
# DEVICE for generation.
inputs = tokenizer(
    prompt,
    return_tensors="pt",
)
input_ids = inputs.input_ids.to(DEVICE)

In [11]:
# Sampling is enabled below (do_sample=True), so seed torch's RNG first to
# make the generated text reproducible when the notebook is re-run.
torch.manual_seed(42)

# Generate a completion without tracking gradients.
with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        # Forward the tokenizer's attention mask explicitly instead of
        # leaving generate() to infer one.
        attention_mask=inputs["attention_mask"].to(DEVICE),
        return_dict_in_generate=True,  # result exposes `.sequences`
        output_scores=True,
        max_new_tokens=1024,
        temperature=0.8,
        do_sample=True,
    )

In [12]:
# Take the first generated sequence (index 0) and convert its token ids back
# to text, dropping special tokens.
s = generation_output["sequences"][0]
output = tokenizer.decode(
    s,
    skip_special_tokens=True,
)

In [13]:
# Show the decoded text; as the recorded output shows, it contains the
# prompt followed by the model's response.
print(output, end="\n")

Below is a question regarding movies and shows paired with an input that provides further context. Write a response that appropriately completes the request.
###Instruction: What is the secretary's name?
###Input: Description: A tech geek and a gorgeous secretary meet during a pitch. They forge a risky partnership without guarantees. 
Release Year: 2019 
Genre: Drama,Romance
###Response: N/A
