In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/y.khan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [5]:
import torch
torch.cuda.empty_cache()

In [6]:
from sklearn.model_selection import KFold
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import datasets
import pandas as pd
import os
import logging
import nltk
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from random import sample


train_df = datasets.load_from_disk("/home/y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_train")
test_df = datasets.load_from_disk("/home/y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_test")
validate_df = datasets.load_from_disk("/home/y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_validate")

In [7]:
model_path = "google-t5/t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

In [24]:
from typing import List, Dict

def tokenization_with_answer(examples):
    inputs = []
    targets = []
    
    task_prefix = "Given a query and a table, generate a summary that answers the query based on the information in the table: "

    for i, (query, table, answer, coordinates, summary) in enumerate(zip(examples['query'], examples['table'], examples['answers'], examples['coordinates'], examples['summary'])):
        flattened_table = flatten_table(table, i)
        input_text = f"{task_prefix} Table {flattened_table}. Query: {query}"

        inputs.append(input_text)

        targets.append(summary)
        
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True,padding='max_length')
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=512, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"] 

    res = tokenizer(inputs, text_target=targets, truncation=True, padding=True)
    return model_inputs

def flatten_table(table: Dict, row_index: int) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    title = table.get('title', [])

    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        flattened_rows.append("## "+row_text)

    flattened_table = f"Title: {' '.join(map(str, title))}" + " " + " ".join(flattened_rows)
    return flattened_table

tokenized_dataset_train = train_df.map(tokenization_with_answer, batched=True)
tokenized_dataset_test = test_df.map(tokenization_with_answer, batched=True)

processed_data_train = tokenized_dataset_train.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

Map: 100%|██████████| 2000/2000 [00:04<00:00, 451.45 examples/s]
Map: 100%|██████████| 500/500 [00:01<00:00, 438.47 examples/s]


In [11]:
def k_fold_split(dataset, num_folds=5):
    fold_size = len(dataset) // num_folds
    folds = []
    for i in range(num_folds):
        start = i * fold_size
        end = start + fold_size if i < num_folds - 1 else len(dataset)
        folds.append(dataset.select(range(start, end)))
    return folds

In [12]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import evaluate

def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)

    rouge = evaluate.load('rouge')

    # Compute ROUGE scores
    rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)

    return rouge_results

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model= model)

train_args = Seq2SeqTrainingArguments(
    output_dir="./train_weights_flan",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    evaluation_strategy="epoch",
    save_strategy = "epoch",
    weight_decay=0.01,
    save_total_limit=5,
    warmup_ratio=0.03,
    load_best_model_at_end=True,
    predict_with_generate=True,
    overwrite_output_dir= True
)

trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=processed_data_train,
    eval_dataset=processed_data_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=metric_fn
)

2024-03-27 17:03:33.483390: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [13]:
folds = k_fold_split(train_df, num_folds=10)

for i in range(len(folds)):
    val_fold = folds[i]
    train_folds = [folds[j] for j in range(len(folds)) if j != i]
    train_dataset = concatenate_datasets(train_folds)

    tokenized_train = train_dataset.map(tokenization_with_answer, batched=True)
    tokenized_val = val_fold.map(tokenization_with_answer, batched=True)

    # Remove unnecessary columns
    processed_train = tokenized_train.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
    processed_val = tokenized_val.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

    # Update your trainer's train_dataset and eval_dataset
    trainer.train_dataset = processed_train
    trainer.eval_dataset = processed_val

    # Train your model
    trainer.train()
    trainer.evaluate()

Map: 100%|██████████| 1800/1800 [00:04<00:00, 366.96 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 521.78 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,3.470807,0.090623,0.037867,0.08219,0.083153
2,No log,2.80721,0.167339,0.069255,0.14193,0.149478
3,No log,2.672187,0.211538,0.093243,0.176523,0.186869
4,No log,2.594445,0.23508,0.104441,0.195455,0.208912
5,No log,2.543628,0.236238,0.106482,0.195903,0.210327
6,No log,2.515201,0.232623,0.104587,0.194746,0.20705
7,No log,2.486205,0.237209,0.104608,0.198788,0.212192
8,No log,2.466686,0.239346,0.104905,0.199139,0.212536
9,No log,2.448865,0.243551,0.109226,0.20306,0.217794
10,No log,2.438308,0.242191,0.108038,0.201109,0.21505


Map: 100%|██████████| 1800/1800 [00:04<00:00, 438.06 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 445.40 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.405347,0.251307,0.124543,0.211957,0.228256
2,No log,2.395933,0.255187,0.128773,0.214458,0.230545
3,No log,2.387444,0.252399,0.128286,0.213388,0.226792
4,No log,2.38167,0.25354,0.126783,0.215289,0.229166
5,No log,2.3723,0.253322,0.126485,0.213914,0.227866
6,No log,2.367026,0.251857,0.126535,0.214462,0.228042
7,No log,2.361145,0.252129,0.126185,0.214135,0.228301
8,No log,2.354269,0.249929,0.122615,0.209343,0.223205
9,No log,2.351137,0.249147,0.1221,0.209319,0.222345
10,No log,2.347899,0.253831,0.125357,0.214635,0.227809


Map: 100%|██████████| 1800/1800 [00:03<00:00, 463.36 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 425.65 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.959844,0.267355,0.136048,0.222453,0.23812
2,No log,1.955663,0.267856,0.13822,0.223738,0.239382
3,No log,1.956679,0.266495,0.140458,0.223551,0.23876
4,No log,1.953912,0.267554,0.14257,0.224849,0.239994
5,No log,1.950219,0.263838,0.140175,0.22185,0.238519
6,No log,1.948362,0.265663,0.142524,0.225552,0.240618
7,No log,1.947157,0.265139,0.142889,0.225225,0.24022
8,No log,1.946655,0.26751,0.142655,0.22697,0.24189
9,No log,1.94595,0.265872,0.142678,0.223703,0.239678
10,No log,1.94534,0.265658,0.142811,0.224583,0.239958


Map: 100%|██████████| 1800/1800 [00:03<00:00, 459.29 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 421.41 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.943124,0.253177,0.121965,0.214618,0.229132
2,No log,1.944735,0.252925,0.121342,0.21255,0.227818
3,No log,1.947076,0.254003,0.121611,0.213982,0.230865
4,No log,1.945403,0.252894,0.122831,0.213549,0.229507
5,No log,1.944802,0.2526,0.121606,0.212881,0.22901
6,No log,1.945775,0.253312,0.120664,0.212262,0.228755
7,No log,1.946347,0.251283,0.119793,0.211561,0.227506
8,No log,1.945463,0.251179,0.119521,0.210084,0.227349
9,No log,1.946585,0.253075,0.121203,0.211286,0.228884
10,No log,1.9461,0.253296,0.121357,0.211756,0.228707


Map: 100%|██████████| 1800/1800 [00:03<00:00, 462.18 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 415.36 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.647086,0.271039,0.146483,0.23003,0.244653
2,No log,1.648747,0.271487,0.148121,0.229411,0.243833
3,No log,1.650491,0.271958,0.14575,0.228894,0.243586
4,No log,1.653091,0.274247,0.14696,0.229984,0.245771
5,No log,1.654174,0.273455,0.148877,0.228867,0.244321
6,No log,1.655367,0.272576,0.148388,0.228921,0.243606
7,No log,1.656462,0.270722,0.146409,0.227266,0.24381
8,No log,1.657126,0.272678,0.148206,0.229639,0.245507
9,No log,1.656978,0.275497,0.148659,0.229578,0.246317
10,No log,1.6589,0.273356,0.147156,0.229263,0.245594


Map: 100%|██████████| 1800/1800 [00:04<00:00, 417.77 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 443.98 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.422752,0.29,0.171893,0.247394,0.261353
2,No log,1.425373,0.291494,0.175401,0.249482,0.262123
3,No log,1.428516,0.290624,0.17374,0.247352,0.262417
4,No log,1.429762,0.294789,0.177129,0.25038,0.265314
5,No log,1.432192,0.2901,0.172834,0.247585,0.262648
6,No log,1.432874,0.289817,0.173599,0.246784,0.261607
7,No log,1.433792,0.290709,0.173619,0.248285,0.264332
8,No log,1.43715,0.287818,0.17048,0.244988,0.26007
9,No log,1.438319,0.290459,0.173029,0.246041,0.26139
10,No log,1.437448,0.291247,0.173936,0.247869,0.263209


Map: 100%|██████████| 1800/1800 [00:03<00:00, 462.51 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 410.89 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.307511,0.281093,0.161959,0.246432,0.260129
2,No log,1.310905,0.280732,0.159724,0.243753,0.258511
3,No log,1.313836,0.274489,0.151946,0.237085,0.251352
4,No log,1.312257,0.276519,0.154367,0.23816,0.253776
5,No log,1.316366,0.283862,0.162138,0.245769,0.260826
6,No log,1.316945,0.279466,0.159443,0.242646,0.255739
7,No log,1.317602,0.280622,0.15985,0.24296,0.25698
8,No log,1.318575,0.27913,0.15604,0.241194,0.253525
9,No log,1.31977,0.281967,0.157721,0.242905,0.256746
10,No log,1.319354,0.280427,0.157167,0.241238,0.255214


Map: 100%|██████████| 1800/1800 [00:03<00:00, 456.16 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 426.57 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.495537,0.281447,0.159418,0.240431,0.254719
2,No log,1.49579,0.283999,0.163565,0.245079,0.259839
3,No log,1.498654,0.283613,0.163492,0.244689,0.259624
4,No log,1.500804,0.279684,0.160634,0.242136,0.25575
5,No log,1.502459,0.283032,0.163399,0.244426,0.259542
6,No log,1.503678,0.282062,0.160458,0.241179,0.257626
7,No log,1.504952,0.281475,0.161173,0.240895,0.258226
8,No log,1.506422,0.282464,0.159867,0.242423,0.258054
9,No log,1.507198,0.284886,0.162511,0.245143,0.261598
10,No log,1.509098,0.284997,0.163406,0.244678,0.261999


Map: 100%|██████████| 1800/1800 [00:03<00:00, 459.86 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 435.58 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.252972,0.29843,0.17424,0.264581,0.27513
2,No log,1.255789,0.298815,0.175729,0.265033,0.276126
3,No log,1.258242,0.300055,0.175819,0.264761,0.276904
4,No log,1.259341,0.298919,0.174322,0.264444,0.276417
5,No log,1.261856,0.297783,0.175181,0.264345,0.274368
6,No log,1.262981,0.297313,0.174759,0.262516,0.273892
7,No log,1.264676,0.299973,0.174606,0.264751,0.276241
8,No log,1.265384,0.29943,0.173449,0.265401,0.275952
9,No log,1.266247,0.299068,0.175179,0.265193,0.276132
10,No log,1.266555,0.299509,0.175891,0.266101,0.277574


Map: 100%|██████████| 1800/1800 [00:03<00:00, 456.48 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 419.41 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.146784,0.298083,0.175249,0.260035,0.277532
2,No log,1.151021,0.299143,0.175414,0.261687,0.27809
3,No log,1.151621,0.292927,0.169776,0.257409,0.272368
4,No log,1.155335,0.293415,0.169802,0.2561,0.272288
5,No log,1.156792,0.292454,0.169878,0.255876,0.271962
6,No log,1.158569,0.293507,0.170677,0.256134,0.272809
7,No log,1.159877,0.295186,0.171293,0.256645,0.273546
8,No log,1.162236,0.295218,0.170506,0.25717,0.272651
9,No log,1.162804,0.2974,0.170121,0.257086,0.273932
10,No log,1.163963,0.29407,0.168546,0.25576,0.271556


In [14]:
model.save_pretrained("T5-decomposed")
tokenizer.save_pretrained("T5-decomposed")

('T5-decomposed/tokenizer_config.json',
 'T5-decomposed/special_tokens_map.json',
 'T5-decomposed/spiece.model',
 'T5-decomposed/added_tokens.json',
 'T5-decomposed/tokenizer.json')

In [23]:
validate_df = validate_df.map(tokenization_with_answer, batched=True)
predict_results = trainer.predict(validate_df, max_length=256)

Map: 100%|██████████| 200/200 [00:00<00:00, 412.32 examples/s]


OverflowError: out of range integral type conversion attempted

In [18]:
metrics = predict_results.metrics
metrics

{'test_loss': 2.3074774742126465,
 'test_rouge1': 0.2610107969166739,
 'test_rouge2': 0.12761580719455787,
 'test_rougeL': 0.22002125622999238,
 'test_rougeLsum': 0.2349292080542475,
 'test_runtime': 4.8351,
 'test_samples_per_second': 41.364,
 'test_steps_per_second': 0.62}

In [19]:
predictions = tokenizer.batch_decode(predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
predictions = [pred.strip() for pred in predictions]

In [20]:
predictions[1]

'Snowdon Mountain Railway has two locomotives with 0-4-2 twitch arrangement.'

In [22]:
validate_df[1]['summary']

'Swiss Locomotive and Machine Works built a Mountain Railway Rack Steam Locomotive with an slm number of 988 in 1896. It has a wheel arrangement of 0 - 4 - 2 T, and is located on the Snowdon Mountain Railway. Its name is Snowdon.'