In [1]:
# standard python imports
import os
# import pandas as pd
import torch

# huggingface libraries

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    # HfArgumentParser,
    # TrainingArguments,
    pipeline,
    # logging,
    LlamaForCausalLM
)
from peft import (
#     LoraConfig,
    PeftModel,
#     prepare_model_for_kbit_training,
#     get_peft_model,
)
from datasets import load_dataset, Dataset
# from trl import SFTTrainer, setup_chat_format

# import wandb

import polars as pl
# import pandas as pd

from transformers.pipelines.pt_utils import KeyDataset

In [2]:
def create_prompt(review):
    system_prompt = f"You read Yelp reviews and return a number (1, 2, 3, 4, or 5) that represents your besst guess of the number of star ratings that were given by that reviewer. Return just the number 1, 2, 3, 4, or 5, with no context, explanation, or special symbols."
    prompt = f"Here is the review to evaluate: [[[{review}]]]. Remember, you read Yelp reviews and return a number (1, 2, 3, 4, or 5) that represents your besst guess of the number of star ratings that were given by that reviewer. Return just the number 1, 2, 3, 4, or 5, with no context, explanation, or special symbols."
        
    return system_prompt, prompt

In [3]:
# df_train = pl.read_csv("../data/1_train_test_split/df_train.csv")
df_val = pl.read_csv("../data/1_train_test_split/df_validation.csv")

In [4]:
lst_system_prompt, lst_prompt = [], []
for row in df_val.iter_rows(named=True):
    system_prompt, prompt = create_prompt(row["text"])
    lst_system_prompt.append(system_prompt)
    lst_prompt.append(prompt)
df_val = df_val.with_columns(pl.Series(lst_system_prompt).alias("system_prompt"), pl.Series(lst_prompt).alias("prompt"))

In [5]:
test_texts = df_val["text"].to_list()
test_labels = df_val["stars"].to_list()

data_ = Dataset.from_polars(df_val)

In [6]:
# !ls /home/richardarcher/Dropbox/Sci24_LLM_Polarization/project_/weights_local/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659

In [7]:
base_model = "/home/richardarcher/Dropbox/Sci24_LLM_Polarization/project_/weights_local/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659"

# HERE

In [8]:
# PATH_adapter_custom_weights = "../weights/sft/run00/checkpoint-1000/"
# PATH_adapter_custom_weights = "../weights/sft/run01/checkpoint-1000/"
# PATH_adapter_custom_weights = "../weights/sft/run01/checkpoint-10000/"
# PATH_adapter_custom_weights = "../weights/sft/run01/checkpoint-20000/"
PATH_adapter_custom_weights = "../weights/sft/run01/checkpoint-29000/"

In [9]:
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    tokenizer_file=os.path.join(base_model, 'tokenizer.json'),
    tokenizer_config_file=os.path.join(base_model, 'tokenizer_config.json'),
    special_tokens_map_file=os.path.join(base_model, 'special_tokens_map.json'),
    trust_remote_code=True,
    padding_side='left'
)

tokenizer.padding_side = 'left'

In [10]:
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # load_in_8bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # Match input dtype

)

model = LlamaForCausalLM.from_pretrained(base_model, quantization_config=nf4_config)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [11]:
model = PeftModel.from_pretrained(model, PATH_adapter_custom_weights)
model = model.merge_and_unload() # This line merges the weights



In [12]:
if not tokenizer.pad_token_id:
    tokenizer.pad_token_id = tokenizer.eos_token_id
if model.config.pad_token_id is None:
    model.config.pad_token_id = model.config.eos_token_id

In [13]:
def remove_header(text, K_times):
    for _ in range(K_times):
        if "<|end_header_id|>" in text:
            text = text.split("<|end_header_id|>", 1)[1]
    return text

In [14]:
def create_format_chat_template(tokenizer):
    def format_chat_template(row):
        row_json = [{"role": "system", "content": row["system_prompt"]},
                    {"role": "user", "content": row["prompt"]}]

        # row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
        row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False, add_generation_prompt=True)
        return row
    return format_chat_template

In [15]:
batch_size = 8

In [16]:
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # torch_dtype=torch.float32,
    torch_dtype=torch.float16,
    device_map="auto",
    batch_size=batch_size, # CHANGE TO FOUR IF TOO SLOW
    max_new_tokens=5,
)

In [17]:
data_ = data_.map(
    create_format_chat_template(tokenizer)
)

Map:   0%|          | 0/1018 [00:00<?, ? examples/s]

In [18]:
res = []
ix = 0 
for out in pipe(KeyDataset(data_, "text")):
    ix = ix + 1
    # print(ix)
    if ix % batch_size == 0:
        print(f"{ix}/{data_.shape[0]}")
    
    cleaned_text = remove_header(out[0]["generated_text"], 3).strip()
    res.append(cleaned_text)

8/1018
16/1018
24/1018
32/1018
40/1018
48/1018
56/1018
64/1018
72/1018
80/1018
88/1018
96/1018
104/1018
112/1018
120/1018
128/1018
136/1018
144/1018
152/1018
160/1018
168/1018
176/1018
184/1018
192/1018
200/1018
208/1018
216/1018
224/1018
232/1018
240/1018
248/1018
256/1018
264/1018
272/1018
280/1018
288/1018
296/1018
304/1018
312/1018
320/1018
328/1018
336/1018
344/1018
352/1018
360/1018
368/1018
376/1018
384/1018
392/1018
400/1018
408/1018
416/1018
424/1018
432/1018
440/1018
448/1018
456/1018
464/1018
472/1018
480/1018
488/1018
496/1018
504/1018
512/1018
520/1018
528/1018
536/1018
544/1018
552/1018
560/1018
568/1018
576/1018
584/1018
592/1018
600/1018
608/1018
616/1018
624/1018
632/1018
640/1018
648/1018
656/1018
664/1018
672/1018
680/1018
688/1018
696/1018
704/1018
712/1018
720/1018
728/1018
736/1018
744/1018
752/1018
760/1018
768/1018
776/1018
784/1018
792/1018
800/1018
808/1018
816/1018
824/1018
832/1018
840/1018
848/1018
856/1018
864/1018
872/1018
880/1018
888/1018
896/1018
904/1

In [19]:
res_int = [int(i) for i in res]

In [20]:
right, total = 0, 0
for pred, actual in zip(res_int, test_labels):
    if pred==actual:
        right += 1
    total += 1

print(right/total)

0.7416502946954814


In [21]:
df_val = df_val.with_columns(pl.Series(res_int).alias("8b_quant_prediction"))

In [22]:
df_val.write_csv("../data/3_outputs/8b_quantized_predictions_for_eval_set_check20000.csv")