In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
torch.cuda.empty_cache()

In [3]:
from sklearn.model_selection import KFold
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import datasets
import pandas as pd
import os
import logging
import nltk
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from random import sample


train_df = datasets.load_from_disk("./QTSumm/decomposed_train")
test_df = datasets.load_from_disk("./QTSumm/decomposed_test")

In [5]:
model_path = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

In [6]:
from typing import List, Dict

def tokenization_with_answer(examples):
    inputs = []
    targets = []

    for i, (query, table, answer, coordinates, summary) in enumerate(zip(examples['query'], examples['table'], examples['answers'], examples['coordinates'], examples['summary'])):
        flattened_table = flatten_table(table, i)
        input_text = f"Table {flattened_table}. Query: {query}"

        inputs.append(input_text)
        targets.append(summary)

    res = tokenizer(inputs, text_target=targets, truncation=True, padding=True)
    return res

def flatten_table(table: Dict, row_index: int) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    title = table.get('title', [])

    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        flattened_rows.append("## "+row_text)

    flattened_table = f"Title: {' '.join(map(str, title))}" + " " + " ".join(flattened_rows)
    return flattened_table

tokenized_dataset_train = train_df.map(tokenization_with_answer, batched=True)
tokenized_dataset_test = test_df.map(tokenization_with_answer, batched=True)

processed_data_train = tokenized_dataset_train.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

In [7]:
def k_fold_split(dataset, num_folds=5):
    fold_size = len(dataset) // num_folds
    folds = []
    for i in range(num_folds):
        start = i * fold_size
        end = start + fold_size if i < num_folds - 1 else len(dataset)
        folds.append(dataset.select(range(start, end)))
    return folds

In [8]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import evaluate

def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)

    rouge = evaluate.load('rouge')
    bleu = evaluate.load('bleu')

    # Compute ROUGE scores
    rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)

    # Compute BLEU scores
    bleu_results = bleu.compute(predictions=decoded_predictions, references=decoded_labels)

    # Combine ROUGE and BLEU results
    results = {
        **rouge_results,
        **bleu_results
    }

    return rouge_results

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model= model)

train_args = Seq2SeqTrainingArguments(
    output_dir="./train_weights",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    num_train_epochs=25,
    evaluation_strategy="epoch",
    predict_with_generate=True,
    overwrite_output_dir= True
)

trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=processed_data_train,
    eval_dataset=processed_data_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=metric_fn
)

2024-03-26 07:06:48.210450: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [8]:
folds = k_fold_split(train_df, num_folds=10)

for i in range(len(folds)):
    val_fold = folds[i]
    train_folds = [folds[j] for j in range(len(folds)) if j != i]
    train_dataset = concatenate_datasets(train_folds)

    tokenized_train = train_dataset.map(tokenization_with_answer, batched=True)
    tokenized_val = val_fold.map(tokenization_with_answer, batched=True)

    # Remove unnecessary columns
    processed_train = tokenized_train.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
    processed_val = tokenized_val.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

    # Update your trainer's train_dataset and eval_dataset
    trainer.train_dataset = processed_train
    trainer.eval_dataset = processed_val

    # Train your model
    trainer.train()
    trainer.evaluate()

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,5.798208,0.278444,0.142267,0.233791,0.247268
2,No log,3.808136,0.293421,0.157435,0.244946,0.261295
3,No log,3.073267,0.291851,0.159595,0.244333,0.258026
4,No log,2.553326,0.29619,0.162036,0.24849,0.26219
5,No log,2.113953,0.296415,0.163582,0.249275,0.262931
6,No log,1.743766,0.299833,0.166211,0.253515,0.26763
7,No log,1.453032,0.302246,0.168323,0.257585,0.269918
8,No log,1.243703,0.303203,0.171639,0.25854,0.270227
9,No log,1.103662,0.301691,0.170529,0.25788,0.269885
10,No log,1.013255,0.301075,0.170257,0.256507,0.26966


Map: 100%|██████████| 1800/1800 [00:02<00:00, 618.34 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 781.81 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.527897,0.318053,0.192827,0.276788,0.290706
2,No log,0.528775,0.310035,0.187992,0.269854,0.283776
3,No log,0.532232,0.31254,0.188343,0.270181,0.283187
4,No log,0.528204,0.315362,0.188326,0.272667,0.286637
5,No log,0.530072,0.313953,0.186863,0.271766,0.285376
6,No log,0.529209,0.316174,0.188786,0.273299,0.287399
7,No log,0.532105,0.311289,0.189126,0.271585,0.285233
8,No log,0.530424,0.313228,0.187497,0.27259,0.286295
9,No log,0.531246,0.315294,0.188349,0.27137,0.285337
10,No log,0.530984,0.31615,0.189334,0.276301,0.289049


Map: 100%|██████████| 1800/1800 [00:02<00:00, 735.11 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 762.16 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.421653,0.34728,0.244317,0.315141,0.327118
2,No log,0.426799,0.345494,0.239327,0.312069,0.325656
3,No log,0.428242,0.34278,0.238965,0.310421,0.324226
4,No log,0.431966,0.346087,0.239486,0.31276,0.326051
5,No log,0.434894,0.34403,0.23849,0.311155,0.324036
6,No log,0.437852,0.346003,0.239202,0.312994,0.326467
7,No log,0.440021,0.344459,0.237867,0.312122,0.325196
8,No log,0.441112,0.342181,0.234523,0.308331,0.321016
9,No log,0.441621,0.342847,0.234536,0.309167,0.321091
10,No log,0.44451,0.344164,0.235762,0.30974,0.323345


Map: 100%|██████████| 1800/1800 [00:02<00:00, 730.90 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 757.30 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.334987,0.340703,0.236368,0.313675,0.324256
2,No log,0.339237,0.343233,0.23605,0.313463,0.324401
3,No log,0.34324,0.340228,0.237812,0.313982,0.322788
4,No log,0.345778,0.338953,0.229422,0.309161,0.31897
5,No log,0.347948,0.337538,0.231297,0.309873,0.319681
6,No log,0.350753,0.337091,0.226976,0.306462,0.317271
7,No log,0.351559,0.335982,0.230768,0.30945,0.318589
8,No log,0.352814,0.331385,0.225376,0.303807,0.31388
9,No log,0.3538,0.334664,0.227936,0.307018,0.315117
10,No log,0.355779,0.335852,0.228445,0.307405,0.318019


Map: 100%|██████████| 1800/1800 [00:02<00:00, 797.90 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 757.17 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.240109,0.371416,0.287736,0.347025,0.35569
2,No log,0.242494,0.369763,0.28572,0.345313,0.354555
3,No log,0.244598,0.370618,0.283973,0.345374,0.35419
4,No log,0.246455,0.370567,0.282194,0.344274,0.353985
5,No log,0.249446,0.370532,0.280098,0.344381,0.354104
6,No log,0.250859,0.36978,0.280805,0.342972,0.353711
7,No log,0.250964,0.367621,0.281058,0.341794,0.352301
8,No log,0.253085,0.36555,0.277172,0.341195,0.350101
9,No log,0.254736,0.367916,0.280204,0.343221,0.353329
10,No log,0.256546,0.36448,0.274275,0.339512,0.348152


Map: 100%|██████████| 1800/1800 [00:02<00:00, 624.23 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 799.99 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.155613,0.393132,0.335088,0.38565,0.387609
2,No log,0.158226,0.389363,0.329863,0.37936,0.382194
3,No log,0.163385,0.392442,0.332347,0.382529,0.385854
4,No log,0.16394,0.390537,0.326237,0.37942,0.383069
5,No log,0.164889,0.389794,0.328304,0.379312,0.383306
6,No log,0.166986,0.388495,0.326953,0.378362,0.381178
7,No log,0.168328,0.389426,0.325912,0.378579,0.383148
8,No log,0.167983,0.387718,0.325451,0.375648,0.380329
9,No log,0.170305,0.390006,0.324501,0.377429,0.38168
10,No log,0.170759,0.387773,0.323901,0.375794,0.379803


Map: 100%|██████████| 1800/1800 [00:02<00:00, 776.05 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 759.40 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.085647,0.412273,0.368304,0.403996,0.407594
2,No log,0.089153,0.408497,0.360569,0.398923,0.402709
3,No log,0.091541,0.4086,0.362078,0.399296,0.403282
4,No log,0.092134,0.407802,0.359944,0.397638,0.401723
5,No log,0.093365,0.405497,0.352778,0.393657,0.399078
6,No log,0.095969,0.408664,0.356287,0.395968,0.400982
7,No log,0.097662,0.404595,0.350629,0.392622,0.397972
8,No log,0.096601,0.400552,0.346008,0.386674,0.393172
9,No log,0.097243,0.401799,0.34805,0.389172,0.394838
10,No log,0.09692,0.404966,0.348613,0.391456,0.39701


Map: 100%|██████████| 1800/1800 [00:02<00:00, 782.76 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 779.98 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.092535,0.416275,0.383123,0.413274,0.415161
2,No log,0.095666,0.413186,0.376386,0.410918,0.412173
3,No log,0.098289,0.41223,0.374048,0.408966,0.410754
4,No log,0.099036,0.409278,0.370164,0.406038,0.407978
5,No log,0.100639,0.410546,0.369193,0.406026,0.408525
6,No log,0.099501,0.411239,0.373025,0.407496,0.409635
7,No log,0.101404,0.410248,0.371188,0.405908,0.408175
8,No log,0.103276,0.410134,0.370871,0.406073,0.408439
9,No log,0.104076,0.408403,0.366721,0.403556,0.406627
10,No log,0.103961,0.409354,0.369338,0.405164,0.40745


Map: 100%|██████████| 1800/1800 [00:02<00:00, 774.29 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 784.23 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.049461,0.416594,0.379146,0.413278,0.414448
2,No log,0.049843,0.416357,0.379704,0.413233,0.414899
3,No log,0.051408,0.414376,0.37607,0.411389,0.412811
4,No log,0.051869,0.415116,0.377396,0.412583,0.41373
5,No log,0.050885,0.411036,0.378313,0.4085,0.409354
6,No log,0.052846,0.41263,0.372467,0.408588,0.410112
7,No log,0.053377,0.407523,0.373352,0.404517,0.405248
8,No log,0.052947,0.407339,0.372147,0.403699,0.405039
9,No log,0.053737,0.411332,0.369611,0.405858,0.409295
10,No log,0.054018,0.407328,0.371911,0.40356,0.404945


Map: 100%|██████████| 1800/1800 [00:02<00:00, 731.08 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 749.12 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.025506,0.419538,0.399468,0.418104,0.418392
2,No log,0.027062,0.417883,0.395333,0.415758,0.416397
3,No log,0.026185,0.417928,0.396065,0.415874,0.416663
4,No log,0.027684,0.416739,0.393909,0.414355,0.41523
5,No log,0.027929,0.417296,0.393552,0.41464,0.416025
6,No log,0.029139,0.417447,0.39344,0.414162,0.415669
7,No log,0.028757,0.41624,0.392118,0.413773,0.414661
8,No log,0.029358,0.416023,0.39097,0.412749,0.413626
9,No log,0.029095,0.415732,0.390464,0.412774,0.413565
10,No log,0.029602,0.416654,0.392232,0.413663,0.414847


In [9]:
model.save_pretrained("BART-decomposed")
tokenizer.save_pretrained("BART-decomposed")

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


('BART-decomposed/tokenizer_config.json',
 'BART-decomposed/special_tokens_map.json',
 'BART-decomposed/vocab.json',
 'BART-decomposed/merges.txt',
 'BART-decomposed/added_tokens.json',
 'BART-decomposed/tokenizer.json')

In [54]:
# from transformers import BartForConditionalGeneration, BartTokenizer

# # Load the model and tokenizer
# model = BartForConditionalGeneration.from_pretrained("BART-decomposed")
# tokenizer = BartTokenizer.from_pretrained("BART-decomposed")

In [55]:
# train_args = Seq2SeqTrainingArguments(
#     output_dir="./train_weights",
#     learning_rate=2e-5,
#     per_device_train_batch_size=32,
#     per_device_eval_batch_size=64,
#     num_train_epochs=25,
#     evaluation_strategy="epoch",
#     predict_with_generate=True,
#     overwrite_output_dir= True
# )

# trainer = Seq2SeqTrainer(
#     model,
#     train_args,
#     train_dataset=processed_data_train,
#     eval_dataset=processed_data_test,
#     tokenizer=tokenizer,
#     data_collator=data_collator,
#     compute_metrics=metric_fn
# )



In [56]:
import evaluate  

validate = datasets.load_from_disk("./QTSumm/validate")

tokenized_dataset_test = validate.map(tokenization_with_answer, batched=True)
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

predictions, labels, metrics = trainer.predict(processed_data_test)
decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)


rouge = evaluate.load("rouge")
rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)
print(rouge_results)

Map: 100%|██████████| 200/200 [00:00<00:00, 340.72 examples/s]


{'rouge1': 0.2972235879231314, 'rouge2': 0.1716006115513433, 'rougeL': 0.2539837486618498, 'rougeLsum': 0.2712466135462311}


In [58]:
print(decoded_predictions[5])
print("The correct summary is:\n")
print(decoded_labels[5])

The two players from Argentina that have appeared in Real Salt Lake are Javier Morales and Fab
The correct summary is:

The two players from Argentina that have appeared in Real Salt Lake are Javier Morales and Fabian Espíndola.
Javier Morales has made 155 appearances with 28 goals while Fabian Espíndola has made 125 appearances with 35 goals.
Both players played for Real Salt Lake between 2007-2012.


In [None]:
bertscore = evaluate.load("bertscore")
bert_score = bertscore.compute(predictions=decoded_predictions, references=decoded_labels, lang = "en")