In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
torch.cuda.empty_cache()
torch.cuda.device_count()

3

In [3]:
from sklearn.model_selection import KFold
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import datasets
import pandas as pd
import os
import logging
import nltk
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from random import sample


train_df = datasets.load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/question_answered/train_with_answer")
test_df = datasets.load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/question_answered/test_with_answer")
validate_df = datasets.load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/question_answered/validate_with_answer")

In [4]:
model_path = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

In [5]:
from typing import List, Dict

def tokenization_with_answer(examples):
    inputs = []
    targets = []
    
    task_prefix = "Given a query and a table, generate a summary that answers the query based on the information in the table: "

    for i, (query, table, answer, coordinates, summary) in enumerate(zip(examples['query'], examples['table'], examples['answers'], examples['coordinates'], examples['summary'])):
        flattened_table = flatten_table(table, i)
        input_text = f"{task_prefix} Table {flattened_table}. Query: {query} Potential answer: "
        
        # Get the header names
        header = table.get('header', [])

        # Append row and column names for each coordinate
        for coordinate in coordinates:
            row_idx, col_idx = coordinate
            row_name = f"Row {row_idx}"
            col_name = header[col_idx]
            input_text += f"{row_name}, {col_name} | "

        inputs.append(input_text[:-2])

        targets.append(summary)
        
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True,padding='max_length')
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=512, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"] 

    res = tokenizer(inputs, text_target=targets, truncation=True, padding=True)
    return model_inputs

def flatten_table(table: Dict, row_index: int) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    title = table.get('title', [])

    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        flattened_rows.append("## "+row_text)

    flattened_table = f"Title: {' '.join(map(str, title))}" + " " + " ".join(flattened_rows)
    return flattened_table

tokenized_dataset_train = train_df.map(tokenization_with_answer, batched=True)
tokenized_dataset_test = test_df.map(tokenization_with_answer, batched=True)

processed_data_train = tokenized_dataset_train.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

Map: 100%|██████████| 500/500 [00:01<00:00, 287.10 examples/s]


In [6]:
def k_fold_split(dataset, num_folds=5):
    fold_size = len(dataset) // num_folds
    folds = []
    for i in range(num_folds):
        start = i * fold_size
        end = start + fold_size if i < num_folds - 1 else len(dataset)
        folds.append(dataset.select(range(start, end)))
    return folds

In [8]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import evaluate

def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)

    rouge = evaluate.load('rouge')

    # Compute ROUGE scores
    rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)

    return rouge_results

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model= model)

train_args = Seq2SeqTrainingArguments(
    output_dir="./train_weights_flan_answer",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    num_train_epochs=20,
    evaluation_strategy="epoch",
    save_strategy = "epoch",
    weight_decay=0.01,
    save_total_limit=5,
    warmup_ratio=0.03,
    load_best_model_at_end=True,
    predict_with_generate=True,
    overwrite_output_dir= True
)

trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=processed_data_train,
    eval_dataset=processed_data_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=metric_fn
)

In [11]:
folds = k_fold_split(train_df, num_folds=10)

for i in range(len(folds)):
    val_fold = folds[i]
    train_folds = [folds[j] for j in range(len(folds)) if j != i]
    train_dataset = concatenate_datasets(train_folds)

    tokenized_train = train_dataset.map(tokenization_with_answer, batched=True)
    tokenized_val = val_fold.map(tokenization_with_answer, batched=True)

    # Remove unnecessary columns
    processed_train = tokenized_train.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
    processed_val = tokenized_val.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

    # Update your trainer's train_dataset and eval_dataset
    trainer.train_dataset = processed_train
    trainer.eval_dataset = processed_val

    # Train your model
    trainer.train()
    trainer.evaluate()

Map: 100%|██████████| 1600/1600 [00:05<00:00, 306.61 examples/s]
Map: 100%|██████████| 400/400 [00:01<00:00, 327.69 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.618289,0.25437,0.124666,0.2106,0.228948
2,No log,2.581428,0.262828,0.130392,0.220403,0.237117
3,No log,2.556853,0.266105,0.133399,0.223644,0.24083
4,No log,2.539484,0.269337,0.135087,0.225102,0.241243
5,No log,2.526888,0.268227,0.133962,0.224918,0.240426
6,No log,2.517482,0.268164,0.134784,0.225554,0.24042
7,No log,2.510343,0.267518,0.134848,0.225212,0.239706
8,No log,2.505365,0.266503,0.134211,0.224013,0.237883
9,No log,2.502565,0.266286,0.134048,0.224172,0.237551
10,No log,2.501655,0.266158,0.133912,0.223672,0.237705


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Map: 100%|██████████| 1600/1600 [00:04<00:00, 340.31 examples/s]
Map: 100%|██████████| 400/400 [00:01<00:00, 334.35 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.453408,0.261614,0.133144,0.219999,0.236413
2,No log,2.442244,0.262958,0.133671,0.221529,0.237981
3,No log,2.433028,0.265959,0.135093,0.223472,0.23993
4,No log,2.42443,0.267618,0.136415,0.22479,0.241049
5,No log,2.419046,0.267264,0.1369,0.225006,0.240635
6,No log,2.414053,0.267229,0.136766,0.225165,0.240774
7,No log,2.410739,0.267668,0.137388,0.224898,0.241273
8,No log,2.408202,0.267389,0.13756,0.224895,0.241315
9,No log,2.406666,0.267476,0.138023,0.225366,0.241381
10,No log,2.406191,0.267406,0.138102,0.225589,0.241379


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Map: 100%|██████████| 1600/1600 [00:04<00:00, 330.84 examples/s]
Map: 100%|██████████| 400/400 [00:01<00:00, 336.81 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.181035,0.286924,0.160675,0.246306,0.26334
2,No log,2.175056,0.28701,0.160776,0.245423,0.262726
3,No log,2.169339,0.287421,0.160695,0.244862,0.262331
4,No log,2.164569,0.286005,0.160258,0.243635,0.261413
5,No log,2.160555,0.286578,0.161387,0.244143,0.262084
6,No log,2.157568,0.286091,0.161049,0.243686,0.261725
7,No log,2.155209,0.286411,0.161016,0.244082,0.26223
8,No log,2.153767,0.286321,0.160996,0.24412,0.262184
9,No log,2.152323,0.286973,0.161355,0.245323,0.262809
10,No log,2.151888,0.286781,0.161265,0.245355,0.262874


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Map: 100%|██████████| 1600/1600 [00:04<00:00, 327.87 examples/s]
Map: 100%|██████████| 400/400 [00:01<00:00, 329.21 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.271812,0.274334,0.144917,0.230656,0.248517
2,No log,2.266146,0.274655,0.144873,0.231075,0.249649
3,No log,2.26194,0.274797,0.145698,0.230783,0.249498
4,No log,2.257811,0.273551,0.145466,0.230249,0.248293
5,No log,2.255082,0.273716,0.145741,0.229383,0.248282
6,No log,2.253033,0.27368,0.14581,0.229592,0.24821
7,No log,2.25146,0.274253,0.145898,0.230236,0.248794
8,No log,2.250434,0.273321,0.14561,0.22934,0.247821
9,No log,2.249604,0.273827,0.145723,0.22945,0.247954
10,No log,2.249302,0.274097,0.145723,0.22963,0.248066


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Map: 100%|██████████| 1600/1600 [00:06<00:00, 247.49 examples/s]
Map: 100%|██████████| 400/400 [00:02<00:00, 176.62 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.117867,0.278674,0.150166,0.23965,0.252323
2,No log,2.113866,0.277603,0.147211,0.237857,0.250928
3,No log,2.111988,0.278287,0.148594,0.238531,0.251415
4,No log,2.109615,0.278875,0.149805,0.239923,0.252185
5,No log,2.106859,0.27928,0.149646,0.239908,0.252453
6,No log,2.105979,0.279668,0.149935,0.23991,0.253043
7,No log,2.104753,0.280305,0.150484,0.240014,0.253541
8,No log,2.103471,0.281162,0.151202,0.240771,0.25424
9,No log,2.102698,0.28198,0.15187,0.240875,0.254359
10,No log,2.102501,0.28198,0.15187,0.240875,0.254359


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


In [None]:
model.save_pretrained("Flan-Answer")
tokenizer.save_pretrained("Flan-Answer")