In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
torch.cuda.empty_cache()

In [3]:
from sklearn.model_selection import KFold
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import datasets
import pandas as pd
import os
import logging
import nltk
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from random import sample


train_df = datasets.load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/data/train")
test_df = datasets.load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/data/test")
validate_df = datasets.load_from_disk("/home/y.khan/cai6307-y.khan/Query-Focused-Tabular-Summarization/data/data/validate")

In [4]:
model_path = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

In [5]:
from typing import List, Dict

def tokenization_with_answer(examples):
    inputs = []
    targets = []
    
    task_prefix = "Given a query and a table, generate a summary that answers the query based on the information in the table: "

    for i, (query, table, answer, coordinates, summary) in enumerate(zip(examples['query'], examples['table'], examples['answers'], examples['coordinates'], examples['summary'])):
        flattened_table = flatten_table(table, i)
        input_text = f"{task_prefix} Table {flattened_table}. Query: {query}"

        inputs.append(input_text)

        targets.append(summary)
        
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True,padding='max_length')
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=512, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"] 

    res = tokenizer(inputs, text_target=targets, truncation=True, padding=True)
    return model_inputs

def flatten_table(table: Dict, row_index: int) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    title = table.get('title', [])

    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        flattened_rows.append("## "+row_text)

    flattened_table = f"Title: {' '.join(map(str, title))}" + " " + " ".join(flattened_rows)
    return flattened_table

tokenized_dataset_train = train_df.map(tokenization_with_answer, batched=True)
tokenized_dataset_test = test_df.map(tokenization_with_answer, batched=True)

processed_data_train = tokenized_dataset_train.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

In [6]:
def k_fold_split(dataset, num_folds=5):
    fold_size = len(dataset) // num_folds
    folds = []
    for i in range(num_folds):
        start = i * fold_size
        end = start + fold_size if i < num_folds - 1 else len(dataset)
        folds.append(dataset.select(range(start, end)))
    return folds

In [7]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import evaluate

def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)

    rouge = evaluate.load('rouge')

    # Compute ROUGE scores
    rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)

    return rouge_results

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model= model)

train_args = Seq2SeqTrainingArguments(
    output_dir="./train_weights_flan",
    learning_rate=4e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    num_train_epochs=20,
    evaluation_strategy="epoch",
    save_strategy = "epoch",
    weight_decay=0.01,
    save_total_limit=5,
    warmup_ratio=0.05,
    load_best_model_at_end=True,
    predict_with_generate=True,
    overwrite_output_dir= True
)

trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=processed_data_train,
    eval_dataset=processed_data_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=metric_fn
)

2024-03-31 00:06:00.631488: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-31 00:06:35.660803: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/slurm/lib64:/opt/slurm/lib64:
2024-03-31 00:06:35.660851: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2024-03-31 00:06:39.342784: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-31 00:0

In [8]:
folds = k_fold_split(train_df, num_folds=10)

for i in range(len(folds)):
    val_fold = folds[i]
    train_folds = [folds[j] for j in range(len(folds)) if j != i]
    train_dataset = concatenate_datasets(train_folds)

    tokenized_train = train_dataset.map(tokenization_with_answer, batched=True)
    tokenized_val = val_fold.map(tokenization_with_answer, batched=True)

    # Remove unnecessary columns
    processed_train = tokenized_train.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
    processed_val = tokenized_val.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

    # Update your trainer's train_dataset and eval_dataset
    trainer.train_dataset = processed_train
    trainer.eval_dataset = processed_val

    # Train your model
    trainer.train()
    trainer.evaluate()

Map: 100%|██████████| 1600/1600 [00:12<00:00, 125.36 examples/s]
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.675925,0.231667,0.111879,0.19326,0.208033
2,No log,2.619363,0.255414,0.126686,0.21456,0.231322
3,No log,2.590219,0.260424,0.129248,0.219548,0.23599
4,No log,2.570375,0.261291,0.130606,0.219762,0.235988
5,No log,2.556111,0.26154,0.131276,0.220525,0.236998
6,No log,2.546391,0.262669,0.132502,0.221296,0.237736
7,No log,2.53927,0.263719,0.133289,0.222642,0.238432
8,No log,2.534162,0.263373,0.133416,0.222379,0.238276
9,No log,2.531149,0.263562,0.133325,0.22267,0.238652
10,No log,2.530097,0.263791,0.133288,0.222684,0.238904


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Map: 100%|██████████| 1600/1600 [00:05<00:00, 312.52 examples/s]
Map: 100%|██████████| 400/400 [00:01<00:00, 378.26 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.488495,0.258827,0.131527,0.218786,0.234492
2,No log,2.476234,0.257165,0.130927,0.21815,0.233543
3,No log,2.468179,0.259555,0.130894,0.21917,0.235392
4,No log,2.460413,0.260465,0.131071,0.219415,0.236087
5,No log,2.452877,0.260768,0.132062,0.219449,0.236474
6,No log,2.448592,0.262342,0.132677,0.220935,0.238068
7,No log,2.444681,0.262435,0.133284,0.221922,0.238003
8,No log,2.442795,0.262014,0.133213,0.22158,0.237402
9,No log,2.441258,0.262474,0.133182,0.22169,0.237469
10,No log,2.440725,0.263077,0.133504,0.222207,0.237904


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Map: 100%|██████████| 1600/1600 [00:07<00:00, 218.57 examples/s]
Map: 100%|██████████| 400/400 [00:02<00:00, 191.02 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.218807,0.280233,0.155163,0.241367,0.25701
2,No log,2.211464,0.279708,0.154958,0.24143,0.256669
3,No log,2.206455,0.278935,0.154512,0.240615,0.255691
4,No log,2.201962,0.279469,0.154444,0.240906,0.255908
5,No log,2.198446,0.279351,0.153709,0.240748,0.25664
6,No log,2.195461,0.279396,0.154236,0.24076,0.256333
7,No log,2.1931,0.279628,0.154582,0.240898,0.256634
8,No log,2.190982,0.280213,0.154741,0.24113,0.257238
9,No log,2.189837,0.280469,0.155081,0.241552,0.257257
10,No log,2.189398,0.280577,0.155591,0.241272,0.257446


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Map: 100%|██████████| 1600/1600 [00:08<00:00, 178.33 examples/s]
Map: 100%|██████████| 400/400 [00:05<00:00, 79.04 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.293929,0.270907,0.143296,0.228227,0.245784
2,No log,2.289042,0.269584,0.143143,0.228029,0.244953
3,No log,2.285436,0.269725,0.143729,0.2282,0.245374
4,No log,2.282755,0.269577,0.143637,0.22773,0.243866
5,No log,2.279566,0.26943,0.143908,0.227753,0.24413
6,No log,2.277495,0.268709,0.143877,0.227316,0.243231
7,No log,2.275257,0.268814,0.144116,0.227981,0.244203
8,No log,2.273634,0.268616,0.143625,0.228071,0.244099
9,No log,2.272855,0.268046,0.143527,0.227572,0.243505
10,No log,2.272622,0.268065,0.143303,0.227403,0.243538


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Map: 100%|██████████| 1600/1600 [00:06<00:00, 258.75 examples/s]
Map: 100%|██████████| 400/400 [00:01<00:00, 333.88 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.155701,0.272835,0.144475,0.235408,0.250005
2,No log,2.151587,0.274659,0.146182,0.237634,0.251736
3,No log,2.148744,0.274449,0.146614,0.237266,0.251504
4,No log,2.146787,0.273802,0.146053,0.236548,0.250673
5,No log,2.144011,0.272997,0.144809,0.235036,0.249484
6,No log,2.141752,0.273437,0.145588,0.235403,0.250035
7,No log,2.140942,0.273803,0.145631,0.235668,0.250017
8,No log,2.140033,0.273637,0.145725,0.235806,0.250092
9,No log,2.139505,0.273431,0.146033,0.235857,0.249894
10,No log,2.139302,0.27319,0.146043,0.235743,0.249652


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


In [None]:
model.save_pretrained("Flan")
tokenizer.save_pretrained("Flan")