In [1]:
import warnings
import gc
warnings.filterwarnings('ignore')

In [2]:
import torch
import evaluate
import pandas as pd
from tqdm import tqdm
from peft import PeftModel, PeftConfig
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tqdm.pandas()
gc.collect()
torch.manual_seed(42)

<torch._C.Generator at 0x1a069a62730>

In [3]:
full_data_test = pd.read_csv('../dataset/full_test_data_summarization.csv')

In [3]:
checkpoint = './model_checkpoint/checkpoint-1576'

In [None]:
config = PeftConfig.from_pretrained(checkpoint)
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    device_map={"":0},
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    quantization_config=BitsAndBytesConfig(    
        load_in_4bit=True,
        load_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.bfloat16
    )
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
model = PeftModel.from_pretrained(base_model, checkpoint, device_map={"":0})
model.eval()

In [8]:
def create_prompt(sample):
    template = """<s>[INST] Bạn là một trợ lý AI. Bạn sẽ được giao một nhiệm vụ. Hãy tóm lược ngắn gọn nội dung sau bằng tiếng Việt:
{} [/INST] """
    prompt = template.format(sample)
    return prompt

In [8]:
references = []
predictions = []
def generate_text():
  for batch_1, batch_2 in tqdm(zip(torch.utils.data.DataLoader(full_data_test['context'], batch_size=16, shuffle=False), torch.utils.data.DataLoader(full_data_test['summarization'], batch_size=16, shuffle=False), strict=True), total=int(round(len(full_data_test)/16, 0))):
    prompts = [create_prompt(context) for context in batch_1]
    inputs = tokenizer(prompts, max_length=4096, truncation=True, padding=False, add_special_tokens=True, return_tensors="pt").to(device)
    outputs = model.generate(
      **inputs,
      early_stopping=False,
      max_new_tokens=1024,
      temperature=0.7,
      top_p=0.9,
      top_k=50,
      repetition_penalty=1.2,
      penalty_alpha=0.6,
      pad_token_id=tokenizer.eos_token_id
    )
    references.extend(batch_2)
    predictions.extend([tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False).split('[/INST]')[1].strip() for output in outputs])
    torch.cuda.empty_cache()

In [None]:
generate_text()

In [None]:
full_data_test['summarization_predictions'] = full_data_test['context'].progress_apply(lambda x: generate_text(x))

In [None]:
full_data_test['summarization_predictions'] = predictions

In [None]:
rouge_metric = evaluate.load("rouge")
rouge_scores = rouge_metric.compute(references=full_data_test['summarization'].tolist(), predictions=full_data_test['summarization_predictions'].tolist(), use_stemmer=True, rouge_types=['rouge1', 'rouge2', 'rougeL'])

In [11]:
rouge_scores

{'rouge1': 0.6073709993163005,
 'rouge2': 0.3535813077835422,
 'rougeL': 0.4224855441941776}

In [12]:
full_data_test.to_csv('test_mistral_lora.csv', index=False)

In [3]:
tmp_test_1 = pd.read_csv('./test_mistral_lora.csv')
tmp_test_2 = pd.read_csv('./validation_mistral_lora.csv')

In [13]:
tmp = pd.concat([tmp_test_1, tmp_test_2], axis=0)
tmp.reset_index(drop=True, inplace=True)

In [4]:
tmp1 = pd.concat([tmp_test_1[0:1179], tmp_test_2[0:590]], axis=0)
tmp1.reset_index(drop=True, inplace=True)

In [5]:
# VLSP dataset
tmp2 = pd.concat([tmp_test_1[1179:1427], tmp_test_2[590:673]], axis=0)
tmp2.reset_index(drop=True, inplace=True)

In [6]:
# Wikilingua
tmp3= pd.concat([tmp_test_1[1427:4095], tmp_test_2[673:1563]], axis=0)
tmp3.reset_index(drop=True, inplace=True)

In [7]:
tmp4 = pd.concat([tmp_test_1[4095:], tmp_test_2[1563:]], axis=0)
tmp4.reset_index(drop=True, inplace=True)

In [8]:
# News dataset
tmp5 = pd.concat([tmp1, tmp4], axis=0)
tmp5.reset_index(drop=True, inplace=True)

In [9]:
rouge_metric = evaluate.load("rouge")

In [10]:
rouge_metric.compute(references=tmp2['summarization'].tolist(), predictions=tmp2['summarization_predictions'].tolist(), use_stemmer=True, rouge_types=['rouge1', 'rouge2', 'rougeL'])

{'rouge1': 0.5556730900726031,
 'rouge2': 0.2461537974616837,
 'rougeL': 0.33873741074411445}

In [11]:
rouge_metric.compute(references=tmp3['summarization'].tolist(), predictions=tmp3['summarization_predictions'].tolist(), use_stemmer=True, rouge_types=['rouge1', 'rouge2', 'rougeL'])

{'rouge1': 0.5576205043038466,
 'rouge2': 0.28404084704064686,
 'rougeL': 0.3909748540769764}

In [12]:
rouge_metric.compute(references=tmp5['summarization'].tolist(), predictions=tmp5['summarization_predictions'].tolist(), use_stemmer=True, rouge_types=['rouge1', 'rouge2', 'rougeL'])

{'rouge1': 0.654149485144364,
 'rouge2': 0.4175033018006436,
 'rougeL': 0.4538408078557469}

In [15]:
rouge_metric.compute(references=tmp['summarization'].tolist(), predictions=tmp['summarization_predictions'].tolist(), use_stemmer=True, rouge_types=['rouge1', 'rouge2', 'rougeL'])

{'rouge1': 0.608391039289374,
 'rouge2': 0.3527924081481474,
 'rougeL': 0.42188942920528405}