In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch
from datasets import load_dataset
from trl import SFTTrainer

In [None]:
# Model directory / hub id loaded below. It points at the fine-tuned
# chart-summarization checkpoint; use "microsoft/phi-2" instead to run the
# untuned base model.
BASE_MODEL_CHECKPOINT = "phi-2-chartSummarization/checkpoint-479640"
base_model = BASE_MODEL_CHECKPOINT

In [None]:
import pandas as pd

# Held-out rows to summarize; the prediction loop reads their "text" column.
TEST_CSV_PATH = "./test.csv"
test_data = pd.read_csv(TEST_CSV_PATH)

In [None]:
# Load the Phi-2 checkpoint quantized to 4-bit NF4 (bitsandbytes, bf16 compute,
# no double quantization), plus its tokenizer.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map=0,  # place the whole model on CUDA device 0
    trust_remote_code=True,
)

# This notebook only generates text, so keep the KV cache enabled.
# use_cache=False is a fine-tuning setting (required alongside gradient
# checkpointing) and can make generation recompute the full prefix at every
# decoding step.
model.config.use_cache = True
# NOTE(review): pretraining_tp is a Llama config field; presumably a no-op for
# Phi-2. Kept only for parity with the training setup.
model.config.pretraining_tp = 1


# Load tokenizer. EOS is reused as the padding token, presumably because the
# tokenizer ships without a dedicated pad token; confirm if padding matters.
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# NOTE(review): this cell is training setup carried over from fine-tuning.
# prepare_model_for_kbit_training readies the quantized model for k-bit
# training, and get_peft_model attaches a newly created LoRA adapter built from
# peft_config (not the fine-tuned weights). The rest of the notebook only runs
# generation, so neither step looks necessary here; confirm before removing.
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"]

model = prepare_model_for_kbit_training(model)

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=LORA_TARGET_MODULES,
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, peft_config)

In [None]:
from tqdm import tqdm

In [None]:
# Text-generation pipeline around the model and tokenizer loaded above.
# NOTE(review): max_length bounds prompt + generated tokens together, so a long
# chart text leaves little or no room for the summary. max_new_tokens may be the
# intended knob; confirm against prompt lengths and whether the model reliably
# emits EOS before switching.
GEN_MAX_LENGTH = 256
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=GEN_MAX_LENGTH,
)

In [None]:
# Generate one summary per test row and write them, one per line and in
# test.csv row order, to PRED_PATH.
PRED_PATH = "pred_summary_phi2.txt"
SUMMARY_TAG = "<|summary|>"


def build_prompt(text: str) -> str:
    """Wrap one source text in the generation prompt template.

    NOTE(review): the opening tag is "<|text| " with no closing ">", unlike
    "<|summary|>". It is kept byte-for-byte because it must match whatever
    format the checkpoint was fine-tuned on; confirm against the training data.
    """
    return "<|text| " + text + " " + SUMMARY_TAG


def extract_summary(generated_text: str) -> str:
    """Return the text after the first SUMMARY_TAG, flattened onto one line.

    The text-generation pipeline returns prompt + continuation by default, so
    everything before the tag is the prompt. Raises IndexError if the tag is
    missing from ``generated_text``.
    """
    summary = generated_text.split(SUMMARY_TAG)[1]
    return " ".join(summary.splitlines())


# Open with "w", not "a", so re-running the cell replaces earlier predictions
# instead of appending a second copy, which would break the line-for-line
# alignment with test.csv.
with open(PRED_PATH, "w", encoding="utf-8") as pred_file:
    # Iterate the column itself. The old `tqdm(0, len(test_data))` passed the
    # int 0 as the iterable and raised TypeError before the first row.
    for row_idx, text in enumerate(tqdm(test_data["text"], desc="summarizing")):
        try:
            generated = pipe(build_prompt(text))[0]["generated_text"]
            summary = extract_summary(generated)
        except Exception as exc:
            # Best-effort: a failed row (non-string/NaN text, generation error,
            # missing tag) still gets an empty line so rows stay aligned.
            # KeyboardInterrupt is no longer swallowed.
            tqdm.write(f"row {row_idx}: no summary ({exc!r})")
            summary = ""
        pred_file.write(summary + "\n")