In [1]:
import warnings
import torch 
warnings.filterwarnings('ignore')
torch.cuda.empty_cache()
torch.cuda.device_count()

3

In [2]:
import os
import torch
from datasets import load_dataset, load_from_disk
from typing import List, Dict
from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
    LlamaForCausalLM,
    LlamaTokenizer,
)
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from typing import List, Dict

train_df = load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_train")
test_df = load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_test")
validate_df = load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_validate")

In [3]:
def flatten_table(table: Dict) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    
    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        flattened_rows.append("## " + row_text)

    flattened_table = " ".join(flattened_rows)
    return flattened_table

def generate_validate_prompt(examples):
    table = examples['table']
    query = examples['query']
    summary = examples['summary']
    table_title = table['title']
    system_prompt = "You are a helpful, respectful and honest assistant. Below is an instruction that describes a query-focused summarization task. Write a summary that appropriately response to the user query."
    
    task = "Using the information from the table, generate a paragraph-long summary to response to the following user query:"

    
    flattened_table = flatten_table(table)
    input_text = f"Table Title: {table_title}\n{flattened_table}\n{task}\nQuery: {query}\n\nSummary:\n"
    prompt = f"""<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{input_text} [/INST]"""
    #prompt = f"{system_prompt}\n{input_text}"
    return prompt

In [4]:
prompt = generate_validate_prompt(validate_df[1])
print(prompt)

<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Below is an instruction that describes a query-focused summarization task. Write a summary that appropriately response to the user query.
<</SYS>>
Table Title: Swiss Locomotive And Machine Works
## Row 0, Built:1895,Number:1,Type:Mountain Railway Rack Steam Locomotive,Slm Number:923,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway ## Row 1, Built:1895,Number:2,Type:Mountain Railway Rack Steam Locomotive,Slm Number:924,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway ## Row 2, Built:1895,Number:3,Type:Mountain Railway Rack Steam Locomotive,Slm Number:925,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway ## Row 3, Built:1896,Number:4,Type:Mountain Railway Rack Steam Locomotive,Slm Number:988,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway ## Row 4, Built:1896,Number:5,Type:Mountain Railway Rack Steam Locomotive,Slm Number:989,Wheel Arrangement:0 - 4 - 2 T,Loca

In [5]:
model_dir = "abacusai/Smaug-72B-v0.1"
cache_dir='smaug-cache'

nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(model_dir,
                                        token="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",
                                        quantization_config=nf4_config,
                                        device_map="auto",
                                        cache_dir=cache_dir
                                        )
tokenizer = AutoTokenizer.from_pretrained(model_dir, 
                                           token="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",
                                           trust_remote_code=True, 
                                           cache_dir=cache_dir
                                          )

Loading checkpoint shards: 100%|██████████| 30/30 [05:26<00:00, 10.89s/it]


In [6]:
generate_text = pipeline(
    model=model, tokenizer=tokenizer,
    return_full_text=False,  
    task="text-generation",
    temperature=0.001,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
    do_sample=True,
    top_k=20,
    max_new_tokens=400,  # max number of tokens to generate in the output
    repetition_penalty=1.1  # if output begins repeating increase
)

In [7]:
generated_summary = []

In [8]:
from tqdm import tqdm

for i in tqdm(range(200)):
    prompt = generate_validate_prompt(validate_df[i])
    res = generate_text(prompt)
    generated_summary.append(res[0]["generated_text"])

100%|██████████| 200/200 [1:01:39<00:00, 18.50s/it]


In [9]:
predicted_summary = []

In [10]:
validate = load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/data/validate")
for i in tqdm(range(200)):
    prompt = generate_validate_prompt(validate[i])
    res = generate_text(prompt)
    predicted_summary.append(res[0]["generated_text"])

100%|██████████| 200/200 [1:06:44<00:00, 20.02s/it]


In [11]:
import numpy as np
rougeL = []
bert = []
bertscore = evaluate.load("bertscore")
rougescore = evaluate.load("rouge")

bert_score = bertscore.compute(predictions=generated_summary, references=validate_df['summary'], lang = "en")
rouge_score = rougescore.compute(predictions=generated_summary, references=validate_df['summary'])
print(rouge_score, bert_score)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'rouge1': 0.4639356730374806, 'rouge2': 0.21625065684295014, 'rougeL': 0.33671618906916884, 'rougeLsum': 0.3369727012249729} {'precision': [0.9195883274078369, 0.9216462969779968, 0.8823371529579163, 0.8430105447769165, 0.907228946685791, 0.9630222916603088, 0.8906835317611694, 0.8986313343048096, 0.8711152076721191, 0.9169199466705322, 0.8346363306045532, 0.9277668595314026, 0.8855714201927185, 0.8789509534835815, 0.8557708263397217, 0.9388769865036011, 0.9087942838668823, 0.8560540080070496, 0.8907759189605713, 0.9131604433059692, 0.928785502910614, 0.9318164587020874, 0.8918784260749817, 0.9505659341812134, 0.8600481748580933, 0.935197651386261, 0.9470973014831543, 0.8832494616508484, 0.8809840679168701, 0.9033607244491577, 0.8334094285964966, 0.862772524356842, 0.8951584100723267, 0.8088887929916382, 0.895728349685669, 0.8742184638977051, 0.9235439300537109, 0.9346218109130859, 0.8845429420471191, 0.8767161965370178, 0.9346142411231995, 0.9067267179489136, 0.8625731468200684, 0.92

In [12]:
import numpy as np
rougeL = []
bert = []
bertscore = evaluate.load("bertscore")
rougescore = evaluate.load("rouge")

bert_score = bertscore.compute(predictions=predicted_summary, references=validate['summary'], lang = "en")
rouge_score = rougescore.compute(predictions=predicted_summary, references=validate['summary'])
print(rouge_score, bert_score)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'rouge1': 0.49010609919109493, 'rouge2': 0.2333976706539606, 'rougeL': 0.35744393042354217, 'rougeLsum': 0.3572235450808475} {'precision': [0.9313572645187378, 0.9072473049163818, 0.9084168076515198, 0.8732856512069702, 0.8348760604858398, 0.9630222320556641, 0.9059234261512756, 0.8986313343048096, 0.8775043487548828, 0.934466540813446, 0.8557215332984924, 0.9277669191360474, 0.8942205309867859, 0.8922761678695679, 0.8637212514877319, 0.9388770461082458, 0.9091837406158447, 0.8900796175003052, 0.8842190504074097, 0.9013169407844543, 0.928785502910614, 0.890892744064331, 0.874093234539032, 0.9505659341812134, 0.8939483761787415, 0.9039280414581299, 0.9131155014038086, 0.9033163785934448, 0.9208687543869019, 0.8964378833770752, 0.8639161586761475, 0.8713046312332153, 0.9062983989715576, 0.8219909071922302, 0.9212489724159241, 0.8666428923606873, 0.8412999510765076, 0.9319933652877808, 0.904621958732605, 0.9076036214828491, 0.9478894472122192, 0.9067267179489136, 0.8771636486053467, 0.90