In [None]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
torch.cuda.empty_cache()

In [20]:
from sklearn.model_selection import KFold
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import datasets
import pandas as pd
import os
import logging
import nltk
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from random import sample


train_df = datasets.load_from_disk("/home/y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_train")
test_df = datasets.load_from_disk("/home/y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_test")
validate_df = datasets.load_from_disk("/home/y.khan/Query-Focused-Tabular-Summarization/data/decomposed/decomposed_validate")

In [4]:
model_path = "facebook/bart-large"
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

In [4]:
model_path = "BART-decomposed-large"
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

In [5]:
from typing import List, Dict

def tokenization_with_answer(examples):
    inputs = []
    targets = []
    
    task_prefix = "Given a query and a table, generate a summary that answers the query based on the information in the table: "

    for i, (query, table, answer, coordinates, summary) in enumerate(zip(examples['query'], examples['table'], examples['answers'], examples['coordinates'], examples['summary'])):
        flattened_table = flatten_table(table, i)
        input_text = f"{task_prefix} Table {flattened_table}. Query: {query}"

        inputs.append(input_text)
        targets.append(summary)
        
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True,padding='max_length')
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=512, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"] 

    res = tokenizer(inputs, text_target=targets, truncation=True, padding=True)
    return model_inputs

def flatten_table(table: Dict, row_index: int) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    title = table.get('title', [])

    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        flattened_rows.append("## "+row_text)

    flattened_table = f"Title: {' '.join(map(str, title))}" + " " + " ".join(flattened_rows)
    return flattened_table

tokenized_dataset_train = train_df.map(tokenization_with_answer, batched=True)
tokenized_dataset_test = test_df.map(tokenization_with_answer, batched=True)

processed_data_train = tokenized_dataset_train.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

In [6]:
def k_fold_split(dataset, num_folds=5):
    fold_size = len(dataset) // num_folds
    folds = []
    for i in range(num_folds):
        start = i * fold_size
        end = start + fold_size if i < num_folds - 1 else len(dataset)
        folds.append(dataset.select(range(start, end)))
    return folds

In [7]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import evaluate

def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)

    rouge = evaluate.load('rouge')

    # Compute ROUGE scores
    rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)

    return rouge_results

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model= model)

train_args = Seq2SeqTrainingArguments(
    output_dir="./train_weights_bart_decomposed",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    num_train_epochs=20,
    evaluation_strategy="epoch",
    save_strategy = "epoch",
    weight_decay=0.01,
    save_total_limit=5,
    warmup_ratio=0.03,
    load_best_model_at_end=True,
    predict_with_generate=True,
    overwrite_output_dir= True
)


trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=processed_data_train,
    eval_dataset=processed_data_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=metric_fn
)

2024-03-29 04:19:28.769139: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [10]:
folds = k_fold_split(train_df, num_folds=10)

for i in range(len(folds)):
    val_fold = folds[i]
    train_folds = [folds[j] for j in range(len(folds)) if j != i]
    train_dataset = concatenate_datasets(train_folds)

    tokenized_train = train_dataset.map(tokenization_with_answer, batched=True)
    tokenized_val = val_fold.map(tokenization_with_answer, batched=True)

    # Remove unnecessary columns
    processed_train = tokenized_train.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
    processed_val = tokenized_val.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

    # Update your trainer's train_dataset and eval_dataset
    trainer.train_dataset = processed_train
    trainer.eval_dataset = processed_val

    # Train your model
    trainer.train()
    trainer.evaluate()

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,1.74164,0.310922,0.190126,0.274251,0.284224
2,No log,1.746532,0.3117,0.186292,0.2722,0.282696
3,No log,1.781855,0.308137,0.18452,0.27046,0.282074
4,No log,1.792227,0.310651,0.180755,0.268887,0.28113
5,No log,1.859723,0.304483,0.180463,0.263685,0.27692
6,No log,1.847143,0.306652,0.182138,0.268237,0.279202
7,0.953000,1.914978,0.309419,0.179245,0.267391,0.280251
8,0.953000,1.905485,0.309323,0.185684,0.272819,0.283899
9,0.953000,1.961449,0.309507,0.182913,0.271276,0.283882
10,0.953000,2.014781,0.312802,0.183745,0.271571,0.283286


Downloading builder script: 100%|██████████| 6.27k/6.27k [00:00<00:00, 21.8MB/s]


Map: 100%|██████████| 1800/1800 [00:13<00:00, 134.93 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 411.28 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.364045,0.378481,0.317903,0.364101,0.371057
2,No log,0.362913,0.377242,0.305969,0.360126,0.367271
3,No log,0.388991,0.386919,0.31768,0.369138,0.377016
4,No log,0.406257,0.381557,0.308955,0.364593,0.371286
5,No log,0.404682,0.377502,0.302452,0.358389,0.366657
6,No log,0.412435,0.364652,0.289373,0.346564,0.355266
7,0.581300,0.405055,0.371139,0.291156,0.348776,0.357577
8,0.581300,0.413587,0.369951,0.292602,0.35138,0.35885
9,0.581300,0.425972,0.364491,0.288418,0.347132,0.353575
10,0.581300,0.436486,0.361531,0.281145,0.340903,0.348076


Map: 100%|██████████| 1800/1800 [00:04<00:00, 427.80 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 417.10 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.1565,0.419272,0.38606,0.416496,0.418385
2,No log,0.136608,0.413219,0.374962,0.407938,0.410821
3,No log,0.129492,0.402444,0.365389,0.394082,0.39672
4,No log,0.141372,0.407959,0.366141,0.40056,0.404135
5,No log,0.16156,0.412432,0.371556,0.406299,0.409584
6,No log,0.135184,0.404734,0.368046,0.398335,0.400729
7,0.299500,0.152729,0.39722,0.357946,0.390428,0.392657
8,0.299500,0.143979,0.400395,0.360895,0.392811,0.395079
9,0.299500,0.145063,0.40478,0.365985,0.397151,0.399483
10,0.299500,0.159074,0.409068,0.364837,0.401305,0.404562


Map: 100%|██████████| 1800/1800 [00:04<00:00, 422.56 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 412.09 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.06197,0.412412,0.392762,0.411311,0.412276
2,No log,0.076129,0.403392,0.379299,0.400439,0.401373
3,No log,0.064934,0.408537,0.38547,0.406839,0.407348
4,No log,0.066857,0.404594,0.382167,0.403075,0.403389
5,No log,0.06565,0.404224,0.379157,0.401967,0.40276
6,No log,0.066976,0.408126,0.385139,0.406402,0.407079
7,0.164700,0.06545,0.404206,0.380197,0.402028,0.402371
8,0.164700,0.063854,0.405026,0.382238,0.403485,0.404045
9,0.164700,0.081352,0.398497,0.371701,0.395192,0.396353
10,0.164700,0.06617,0.405291,0.379943,0.403946,0.404182


Map: 100%|██████████| 1800/1800 [00:08<00:00, 212.22 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 376.07 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.142327,0.416846,0.394999,0.41309,0.414817
2,No log,0.057901,0.41956,0.398587,0.41592,0.41707
3,No log,0.073529,0.414845,0.393639,0.411853,0.413005
4,No log,0.056515,0.417103,0.395399,0.413033,0.414709
5,No log,0.046258,0.418374,0.395948,0.41382,0.415693
6,No log,0.045891,0.417661,0.393274,0.412191,0.41444
7,0.107100,0.055041,0.416188,0.392838,0.411618,0.413287
8,0.107100,0.045833,0.418086,0.394226,0.412785,0.414664
9,0.107100,0.045942,0.418353,0.394515,0.413284,0.415322
10,0.107100,0.072462,0.417684,0.394105,0.412779,0.414353


Map: 100%|██████████| 1800/1800 [00:05<00:00, 324.95 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 431.59 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.038347,0.421725,0.403774,0.421493,0.421319
2,No log,0.043123,0.422697,0.404328,0.422105,0.42206
3,No log,0.042001,0.422435,0.403945,0.421828,0.421818
4,No log,0.039421,0.420669,0.400036,0.41859,0.41858
5,No log,0.059225,0.422748,0.404081,0.421586,0.421749
6,No log,0.035079,0.422347,0.404089,0.4217,0.421723
7,0.070500,0.048578,0.420912,0.402408,0.420228,0.420133
8,0.070500,0.053893,0.419583,0.401293,0.41922,0.418818
9,0.070500,0.036044,0.421895,0.403194,0.421129,0.421212
10,0.070500,0.037174,0.421917,0.403557,0.421226,0.421442


Map: 100%|██████████| 1800/1800 [00:10<00:00, 170.84 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 408.49 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.02959,0.433875,0.414325,0.432634,0.433279
2,No log,0.086528,0.433805,0.413301,0.432305,0.433125
3,No log,0.032434,0.434199,0.415285,0.433461,0.43385
4,No log,0.071422,0.433651,0.413588,0.432202,0.43306
5,No log,0.036016,0.434232,0.414824,0.433066,0.433587
6,No log,0.03371,0.433443,0.412903,0.431857,0.432724
7,0.050900,0.058399,0.434282,0.414148,0.432986,0.433688
8,0.050900,0.033302,0.433633,0.413721,0.432533,0.433193
9,0.050900,0.056583,0.431724,0.412366,0.43041,0.431015
10,0.050900,0.041597,0.43255,0.413259,0.431068,0.431909


Map: 100%|██████████| 1800/1800 [00:05<00:00, 328.24 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 416.67 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.046131,0.42979,0.411103,0.428886,0.429829
2,No log,0.038581,0.431093,0.411854,0.4304,0.431146
3,No log,0.03697,0.430814,0.41157,0.430334,0.431029
4,No log,0.037363,0.430978,0.412036,0.43034,0.43105
5,No log,0.034091,0.430564,0.411423,0.429776,0.430746
6,No log,0.04005,0.430499,0.411234,0.429707,0.430536
7,0.040300,0.033764,0.430148,0.410287,0.429183,0.430078
8,0.040300,0.033171,0.4302,0.410615,0.429509,0.430359
9,0.040300,0.037319,0.430509,0.411107,0.429838,0.43064
10,0.040300,0.04015,0.430615,0.411002,0.429721,0.430649


Map: 100%|██████████| 1800/1800 [00:04<00:00, 409.70 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 421.83 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.028518,0.422896,0.396384,0.422651,0.422708
2,No log,0.059398,0.422887,0.395844,0.42261,0.422595
3,No log,0.02757,0.423011,0.396418,0.42269,0.422768
4,No log,0.03619,0.422772,0.396218,0.422488,0.422492
5,No log,0.033121,0.423097,0.396397,0.422695,0.42274
6,No log,0.029683,0.423715,0.396686,0.423392,0.423469
7,0.032800,0.030137,0.422892,0.39634,0.422562,0.422652
8,0.032800,0.045365,0.422892,0.39634,0.422562,0.422652
9,0.032800,0.030146,0.422772,0.396047,0.422488,0.422492
10,0.032800,0.030626,0.422892,0.39634,0.422562,0.422652


Map: 100%|██████████| 1800/1800 [00:05<00:00, 305.91 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 413.77 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.061752,0.418781,0.400395,0.418817,0.418648
2,No log,0.031002,0.419056,0.400699,0.41894,0.418918
3,No log,0.054678,0.419509,0.400324,0.419117,0.418978
4,No log,0.030299,0.419509,0.400324,0.419117,0.418978
5,No log,0.045297,0.418246,0.398419,0.416998,0.417552
6,No log,0.055381,0.418933,0.40027,0.418866,0.418798
7,0.028000,0.031221,0.418005,0.398201,0.416582,0.417426
8,0.028000,0.066742,0.418498,0.400009,0.417808,0.418028
9,0.028000,0.030686,0.418132,0.399209,0.417553,0.417892
10,0.028000,0.033476,0.417809,0.398796,0.417323,0.417575


In [11]:
model.save_pretrained("BART-decomposed")
tokenizer.save_pretrained("BART-decomposed")

('BART-decomposed-large/tokenizer_config.json',
 'BART-decomposed-large/special_tokens_map.json',
 'BART-decomposed-large/vocab.json',
 'BART-decomposed-large/merges.txt',
 'BART-decomposed-large/added_tokens.json',
 'BART-decomposed-large/tokenizer.json')

In [33]:
valid1 = validate_df.select(range(0, 10))
valid2 = validate_df.select(range(10, 20))
valid3 = validate_df.select(range(20, 30))
valid4 = validate_df.select(range(30, 40))
valid5 = validate_df.select(range(40, 50))
valid6 = validate_df.select(range(50, 60))
valid7 = validate_df.select(range(60, 70))
valid8 = validate_df.select(range(70, 80))
valid9 = validate_df.select(range(80, 90))
valid10 = validate_df.select(range(90, 100))
valid11 = validate_df.select(range(100, 110))
valid12 = validate_df.select(range(110, 120))
valid13 = validate_df.select(range(120, 130))
valid14 = validate_df.select(range(130, 140))
valid15= validate_df.select(range(140, 150))
valid16 = validate_df.select(range(150, 160))
valid17 = validate_df.select(range(160, 170))
valid18 = validate_df.select(range(170, 180))
valid19 = validate_df.select(range(180, 190))
valid20 = validate_df.select(range(190, 200))

In [43]:
valid = [valid1, valid2, valid3, valid4, valid5, valid6, valid7, valid8, valid9, valid10, valid11, valid12, valid13, valid14, valid15, valid16, valid17, valid18, valid19,valid20]

In [52]:
rougeL = []
for i in range(20):
    validate_df = valid[i].map(tokenization_with_answer, batched=True)
    predict_results = trainer.predict(validate_df, max_length = 1024)
    metrics = predict_results.metrics

    rougeL.append(metrics['test_rougeLsum'])

Map: 100%|██████████| 10/10 [00:00<00:00, 248.65 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 262.08 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 294.57 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 286.65 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 254.29 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 256.79 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 210.35 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 238.61 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 250.34 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 258.32 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 274.77 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 275.89 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 97.27 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 230.71 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 290.25 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 273.12 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 265.13 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 281.26 examples/s]


Map: 100%|██████████| 10/10 [00:00<00:00, 278.00 examples/s]


In [53]:
sum(rougeL)/20

0.3994643853204672

In [54]:
rougeL

[0.4249032347992855,
 0.3977325553762604,
 0.3988015079677397,
 0.3970669486004853,
 0.47750503288214874,
 0.38922545450437884,
 0.38459774420980286,
 0.40872599518530695,
 0.3946639845161596,
 0.3973327160291645,
 0.37783619933499957,
 0.3758676980043921,
 0.47964852191307583,
 0.3084137150736289,
 0.3908993112468482,
 0.41420683147736537,
 0.40518408812365425,
 0.3532071575434926,
 0.361786290776044,
 0.45168271884511124]

In [55]:
validate_df = valid13.map(tokenization_with_answer, batched=True)
predict_results = trainer.predict(validate_df, max_length = 1024)

In [56]:
metrics = predict_results.metrics
metrics['test_rougeLsum']

0.47964852191307583

In [57]:
predictions = tokenizer.batch_decode(predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
predictions = [pred.strip() for pred in predictions]

In [58]:
predictions

['The artist who presented "Ara" at the Lisboa International Art Festival 2012 was Claudia Crabuzza & Claudio Gabriel Sanna. They won the jury award and public award SUNS 2012, which helped their placing and point total in the competition.',
 'In the 2004 United States presidential election in Vermont, John Kerry, a Democratic candidate, received 58.94% of the popular vote, while George W. Bush, a Republican candidate, secured 121,180 votes, or 38.80%. Kerry was awarded 3 electoral votes, while Bush was awarded 0 electoral votes.',
 'In April, the Philadelphia Flyers accumulate a total of 83 points. This is the most points accumulated in April by any month.',
 "In table given, we can see the relation between number of floors and building height in Fresno's tallest buildings not same clearly. Number of floors go up when building height also go up, like in the case of Robert E. Coyle United States Courthouse, which have 9 floors and is highest building in list, it also tallest building. 

In [59]:
validate_df['summary']

['The artist who show "Ara," Claudia Crabuzza & Claudio Gabriel Sanna, have the qualification for win both the Jury award and public award at SUNS 2012. The artist get 5th place in the Liet International 2012, with total of 64 point.',
 'In the 2004 United States president election in Vermont, John Kerry get 58.94% of the popular votes and George W. Bush get 38.80% of the popular votes. John Kerry was give 3 electoral votes, while George W. Bush not get any electoral votes in Vermont.',
 'The Philadelphia Flyers accumulated 16 points in April. They won eight games, and tied one in overtime. This gave them a record of 36 - 37 - 11 for the month and a total of 83 points.',
 "The relation between number of floors and building height in tallest buildings in Fresno mostly positive. It means when floors more, building height also more. But this relation not always straight line because other things can affect building height, like each floor's height. For example, Golden State County Plaza h

In [60]:
bertscore = evaluate.load("bertscore")
bert_score = bertscore.compute(predictions=predictions, references=validate_df['summary'], lang = "en")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [61]:
bert_score

{'precision': [0.9332228899002075,
  0.9267280697822571,
  0.9194029569625854,
  0.8933998942375183,
  0.8934261202812195,
  0.8771347999572754,
  0.9359485507011414,
  0.8356510400772095,
  0.9432274103164673,
  0.8096977472305298],
 'recall': [0.9222641587257385,
  0.947601318359375,
  0.8829821348190308,
  0.9095994234085083,
  0.9000557661056519,
  0.9136006832122803,
  0.9224000573158264,
  0.8718510866165161,
  0.8724163770675659,
  0.8919053673744202],
 'f1': [0.9277111291885376,
  0.9370484948158264,
  0.9008246064186096,
  0.9014269113540649,
  0.8967287540435791,
  0.8949964642524719,
  0.9291248917579651,
  0.8533673286437988,
  0.9064410924911499,
  0.8488156795501709],
 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.32.1)'}