In [1]:
import warnings
import torch 
warnings.filterwarnings('ignore')
torch.cuda.empty_cache()
torch.cuda.device_count()

3

In [2]:
import os
import torch
from datasets import load_dataset, load_from_disk
from typing import List, Dict
from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
    LlamaForCausalLM,
    LlamaTokenizer,
)
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from typing import List, Dict

train_df = load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/data/train")
test_df = load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/data/test")
validate_df = load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/data/validate")

In [3]:
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)
cache_dir='./llama3-70B_cache'
model_dir = "meta-llama/Meta-Llama-3-70B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    quantization_config=nf4_config,
    token="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",
    device_map="auto",
    cache_dir=cache_dir
)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, 
                                          token="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",
                                          cache_dir=cache_dir
                                         )

Loading checkpoint shards: 100%|██████████| 30/30 [04:17<00:00,  8.59s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [17]:
def flatten_table(table: Dict) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    
    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        flattened_rows.append("## " + row_text)

    flattened_table = " ".join(flattened_rows)
    return flattened_table

task = f"""Given a table and a user query, your task is to extract only the columns that are essential for summarizing the information relevant to the query. The summarization should capture the key insights and statistics related to the user's information need.
To achieve this, you will need to:
1. Understand the user's query and identify the main entities, attributes, and analysis that the user is interested in.
2. Analyze the column names, data types, and sample values to determine which columns contain the core entities and numeric/categorical data relevant for summarization.
3. Exclude any columns that seem peripheral or redundant for the purposes of summarizing the key information requested.
4. Output a comma-separated list of only the column names that are essential for summarizing the query.
5. If no columns seem directly relevant for summarizing the user's query, output a message stating that the table does not contain suitable information.

Your response should be a single line listing the relevant column names, or a statement that no columns apply, for example:
"product_name, sales_amount, region, quarter"
or 
"This table does not have columns relevant for summarizing [query]"
I'll provide the table and query next, and you can demonstrate extracting the essential columns for summarization."""

def generate_validate_prompt(examples):
    table = examples['table']
    query = examples['query']
    summary = examples['summary']
    table_title = table['title']
    system_prompt = "You are a helpful, respectful and honest assistant. Below is an instruction that describes a query-focused summarization task. Write a summary that appropriately response to the user query."
    
    task = "Using the information from the above table, give the name of relevant columns in the given table that support or oppose the following user query."

    
    flattened_table = flatten_table(table)
    input_text = f"Table Title: {table_title}\n{flattened_table}\n{task}\nQuery: {query}\n\nColumn names:\n"
    prompt = f"""<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{input_text} [/INST]"""
    prompt = input_text
    return prompt

In [18]:
print(generate_validate_prompt(validate_df[1]))

Table Title: Swiss Locomotive And Machine Works
## Row 0, Built:1895,Number:1,Type:Mountain Railway Rack Steam Locomotive,Slm Number:923,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Ladas ## Row 1, Built:1895,Number:2,Type:Mountain Railway Rack Steam Locomotive,Slm Number:924,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Enid ## Row 2, Built:1895,Number:3,Type:Mountain Railway Rack Steam Locomotive,Slm Number:925,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Wyddfa ## Row 3, Built:1896,Number:4,Type:Mountain Railway Rack Steam Locomotive,Slm Number:988,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Snowdon ## Row 4, Built:1896,Number:5,Type:Mountain Railway Rack Steam Locomotive,Slm Number:989,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Moel Siabod ## Row 5, Built:1922,Number:6,Type:Mountain Railway Rack Steam Locomotive,Slm Number:2838,Wheel Arrangement:0 - 4 - 2 T

In [5]:
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)
cache_dir='./llama3-70B_cache'
model_dir = "meta-llama/Meta-Llama-3-70B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    quantization_config=nf4_config,
    token="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",
    device_map="auto",
    cache_dir=cache_dir
)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, 
                                          token="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",
                                          cache_dir=cache_dir
                                         )

Loading checkpoint shards: 100%|██████████| 30/30 [05:05<00:00, 10.17s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
prompt = generate_validate_prompt(validate_df[1])
print(prompt)

Table Title: Swiss Locomotive And Machine Works
## Row 0, Built:1895,Number:1,Type:Mountain Railway Rack Steam Locomotive,Slm Number:923,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Ladas ## Row 1, Built:1895,Number:2,Type:Mountain Railway Rack Steam Locomotive,Slm Number:924,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Enid ## Row 2, Built:1895,Number:3,Type:Mountain Railway Rack Steam Locomotive,Slm Number:925,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Wyddfa ## Row 3, Built:1896,Number:4,Type:Mountain Railway Rack Steam Locomotive,Slm Number:988,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Snowdon ## Row 4, Built:1896,Number:5,Type:Mountain Railway Rack Steam Locomotive,Slm Number:989,Wheel Arrangement:0 - 4 - 2 T,Location:Snowdon Mountain Railway,Notes:Moel Siabod ## Row 5, Built:1922,Number:6,Type:Mountain Railway Rack Steam Locomotive,Slm Number:2838,Wheel Arrangement:0 - 4 - 2 T

In [7]:
pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [20]:
from tqdm import tqdm
output_summary = []
for i in tqdm(range(200)):
    prompt = generate_validate_prompt(validate_df[i])
    messages = [
        {"role": "system", "content": "You are a helpful, respectful and honest assistant. Below is an instruction that describes a query-based table decomposition task"},
        {"role": "user", "content": prompt},
    ]

    prompt = pipeline.tokenizer.apply_chat_template(
            messages, 
            tokenize=False, 
            add_generation_prompt=True
    )

    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = pipeline(
        prompt,
        max_new_tokens=400,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.0001,
        top_k=10,
        num_return_sequences=1,
    )
    output_summary.append(outputs[0]["generated_text"][len(prompt):])
    if (i == 4) :
        break

  0%|          | 0/200 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 1/200 [00:26<1:27:55, 26.51s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  1%|          | 2/200 [00:39<1:00:16, 18.26s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  2%|▏         | 3/200 [00:54<55:58, 17.05s/it]  Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  2%|▏         | 4/200 [01:13<58:27, 17.89s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  2%|▏         | 4/200 [01:23<1:08:20, 20.92s/it]


In [21]:
for column in output_summary:
    print (column)

Based on the provided table, the relevant columns that support the user query are:

* `Player`
* `Points`

These columns are necessary to identify the players who scored more than 600 points in the 2008-09 Connecticut Huskies Women's Basketball Team.

After analyzing the table, I found that there are two players who scored more than 600 points:

1. **Moore Maya Moore**: She scored 754 points.
2. **Montgomery Renee Montgomery**: She scored 644 points. (Although she didn't exactly score more than 600 points, she came close with 644 points.)

Note that there are no players who scored exactly more than 600 points, but these two players are the closest to achieving that milestone.
Based on the provided table, the relevant column names that support the user query are:

* Built
* Slm Number
* Number
* Type
* Wheel Arrangement
* Location
* Notes

These columns provide the basic information about the locomotive(s) built by Swiss Locomotive and Machine Works with slm number 988.
To support the u