In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
torch.cuda.empty_cache()

In [3]:
from sklearn.model_selection import KFold
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import datasets
import pandas as pd
import os
import logging
import nltk
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from random import sample


train_df = datasets.load_from_disk("./QTSumm/train_with_answer")
test_df = datasets.load_from_disk("./QTSumm/test_with_answer")

In [4]:
model_path = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

In [5]:
from typing import List, Dict

def tokenization_with_answer(examples):
    inputs = []
    targets = []

    for i, (query, table, answer, coordinates, summary) in enumerate(zip(examples['query'], examples['table'], examples['answers'], examples['coordinates'], examples['summary'])):
        flattened_table = flatten_table(table, i)
        input_text = f"{flattened_table} Query: {query} Potential answer: "

        # Get the header names
        header = table.get('header', [])

        # Append row and column names for each coordinate
        for coordinate in coordinates:
            row_idx, col_idx = coordinate
            row_name = f"Row {row_idx}"
            col_name = header[col_idx]
            input_text += f"{row_name}, {col_name} | "

        inputs.append(input_text[:-2])
        targets.append(summary)

    res = tokenizer(inputs, text_target=targets, truncation=True, padding=True)
    return res

def flatten_table(table: Dict, row_index: int) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    title = table.get('title', [])

    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        row_text += " ##"
        flattened_rows.append(row_text)

    flattened_table = f"Title: {' '.join(map(str, title))}" + " " + " ".join(flattened_rows)
    return flattened_table

tokenized_dataset_train = train_df.map(tokenization_with_answer, batched=True)
tokenized_dataset_test = test_df.map(tokenization_with_answer, batched=True)

processed_data_train = tokenized_dataset_train.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

In [6]:
def k_fold_split(dataset, num_folds=5):
    fold_size = len(dataset) // num_folds
    folds = []
    for i in range(num_folds):
        start = i * fold_size
        end = start + fold_size if i < num_folds - 1 else len(dataset)
        folds.append(dataset.select(range(start, end)))
    return folds

In [7]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import evaluate

def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)

    rouge = evaluate.load('rouge')
    bleu = evaluate.load('bleu')

    # Compute ROUGE scores
    rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)

    # Compute BLEU scores
    bleu_results = bleu.compute(predictions=decoded_predictions, references=decoded_labels)

    # Combine ROUGE and BLEU results
    results = {
        **rouge_results,
        **bleu_results
    }

    return rouge_results

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model= model)

train_args = Seq2SeqTrainingArguments(
    output_dir="./train_weights",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    num_train_epochs=25,
    evaluation_strategy="epoch",
    predict_with_generate=True,
    overwrite_output_dir= True
)

trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=processed_data_train,
    eval_dataset=processed_data_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=metric_fn
)

2024-03-26 23:17:39.750872: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [8]:
folds = k_fold_split(train_df, num_folds=10)

for i in range(len(folds)):
    val_fold = folds[i]
    train_folds = [folds[j] for j in range(len(folds)) if j != i]
    train_dataset = concatenate_datasets(train_folds)

    tokenized_train = train_dataset.map(tokenization_with_answer, batched=True)
    tokenized_val = val_fold.map(tokenization_with_answer, batched=True)

    # Remove unnecessary columns
    processed_train = tokenized_train.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
    processed_val = tokenized_val.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

    # Update your trainer's train_dataset and eval_dataset
    trainer.train_dataset = processed_train
    trainer.eval_dataset = processed_val

    # Train your model
    trainer.train()
    trainer.evaluate()

Map: 100%|██████████| 1800/1800 [00:03<00:00, 454.54 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 575.02 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,5.523032,0.245657,0.11883,0.20502,0.22023
2,No log,3.769721,0.266796,0.134698,0.225318,0.239636
3,No log,3.045618,0.273328,0.142341,0.22835,0.242471
4,No log,2.521496,0.278067,0.146266,0.233152,0.245867
5,No log,2.080003,0.279583,0.148878,0.236702,0.248468
6,No log,1.708181,0.281678,0.150447,0.238018,0.250348
7,No log,1.416697,0.277603,0.149991,0.236155,0.249369
8,No log,1.205658,0.281329,0.152844,0.237627,0.250746
9,No log,1.065841,0.279518,0.154226,0.239724,0.250701
10,No log,0.975928,0.279838,0.155713,0.239641,0.252112


Map: 100%|██████████| 1800/1800 [00:02<00:00, 621.53 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 649.93 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.499285,0.305139,0.184477,0.258233,0.274205
2,No log,0.497518,0.303608,0.181376,0.25796,0.272053
3,No log,0.501955,0.308621,0.182207,0.257013,0.274418
4,No log,0.497678,0.308843,0.184705,0.258684,0.274736
5,No log,0.499811,0.308579,0.182741,0.258033,0.275035
6,No log,0.501081,0.305922,0.183183,0.260084,0.276693
7,No log,0.500469,0.30684,0.184137,0.261553,0.278353
8,No log,0.501709,0.306796,0.184964,0.262432,0.277633
9,No log,0.5008,0.307256,0.184211,0.260193,0.276653
10,No log,0.502766,0.309581,0.184108,0.259004,0.277099


Map: 100%|██████████| 1800/1800 [00:02<00:00, 626.03 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 591.36 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.365235,0.333087,0.22604,0.299206,0.311277
2,No log,0.369385,0.333786,0.226223,0.300086,0.311641
3,No log,0.372671,0.333825,0.226107,0.300911,0.311363
4,No log,0.377099,0.327899,0.219727,0.293088,0.305014
5,No log,0.379234,0.330927,0.223395,0.296766,0.306989
6,No log,0.382934,0.332245,0.221423,0.295722,0.308179
7,No log,0.385114,0.328527,0.219555,0.295507,0.306068
8,No log,0.386122,0.333367,0.223118,0.297459,0.309531
9,No log,0.385745,0.333673,0.223414,0.296932,0.310483
10,No log,0.390435,0.333864,0.224311,0.297866,0.310517


Map: 100%|██████████| 1800/1800 [00:02<00:00, 631.00 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 573.31 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.295565,0.329258,0.225794,0.302534,0.311386
2,No log,0.301748,0.323582,0.220067,0.296243,0.303977
3,No log,0.304845,0.32984,0.225535,0.301164,0.308786
4,No log,0.307887,0.326469,0.220407,0.300202,0.30716
5,No log,0.30792,0.327133,0.219732,0.29766,0.306928
6,No log,0.311669,0.322103,0.212949,0.293136,0.30229
7,No log,0.314587,0.320439,0.210104,0.290086,0.298994
8,No log,0.317055,0.319971,0.211362,0.290849,0.300306
9,No log,0.318407,0.318664,0.211431,0.289616,0.297336
10,No log,0.317977,0.319377,0.213535,0.29299,0.30032


Map: 100%|██████████| 1800/1800 [00:03<00:00, 586.67 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 562.32 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.226269,0.360045,0.27177,0.329989,0.343307
2,No log,0.231051,0.356837,0.266822,0.326163,0.338265
3,No log,0.235508,0.357971,0.267605,0.328552,0.340788
4,No log,0.236446,0.355588,0.265856,0.327361,0.338423
5,No log,0.238843,0.351843,0.260314,0.323988,0.331565
6,No log,0.242649,0.352501,0.260415,0.322744,0.333076
7,No log,0.242208,0.351683,0.260124,0.321308,0.331162
8,No log,0.245006,0.35292,0.260817,0.322531,0.332385
9,No log,0.246507,0.350135,0.258482,0.320805,0.33144
10,No log,0.247965,0.353337,0.260694,0.321569,0.332749


Map: 100%|██████████| 1800/1800 [00:04<00:00, 447.63 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 640.71 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.121259,0.390732,0.331553,0.379583,0.384421
2,No log,0.124199,0.386846,0.324352,0.373655,0.379354
3,No log,0.126356,0.384589,0.321135,0.370521,0.375952
4,No log,0.131273,0.383061,0.317014,0.368792,0.374489
5,No log,0.129551,0.385201,0.322282,0.370812,0.377126
6,No log,0.131567,0.382425,0.31677,0.368584,0.373685
7,No log,0.132272,0.381201,0.31513,0.365491,0.372041
8,No log,0.13315,0.379737,0.314326,0.365051,0.371545
9,No log,0.134136,0.381975,0.312449,0.366396,0.372526
10,No log,0.135094,0.378944,0.311054,0.362973,0.369862


Map: 100%|██████████| 1800/1800 [00:02<00:00, 631.58 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 594.08 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.082238,0.396999,0.346291,0.38566,0.389581
2,No log,0.086043,0.392335,0.337158,0.380099,0.384417
3,No log,0.088524,0.39145,0.334945,0.377994,0.383216
4,No log,0.089416,0.393393,0.338936,0.380413,0.385319
5,No log,0.091438,0.390093,0.331762,0.377418,0.382588
6,No log,0.09147,0.389318,0.330234,0.376131,0.380327
7,No log,0.093509,0.388295,0.331374,0.376203,0.380534
8,No log,0.094054,0.386257,0.325897,0.371962,0.37767
9,No log,0.094622,0.384158,0.322006,0.369334,0.375016
10,No log,0.094725,0.385254,0.324151,0.370418,0.375426


Map: 100%|██████████| 1800/1800 [00:02<00:00, 630.05 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 602.90 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.078364,0.400366,0.362374,0.394714,0.396955
2,No log,0.081462,0.399611,0.362325,0.394107,0.396465
3,No log,0.084288,0.39746,0.35631,0.391094,0.393695
4,No log,0.085749,0.393448,0.35006,0.386771,0.388785
5,No log,0.087041,0.395117,0.351776,0.38721,0.389981
6,No log,0.087511,0.394633,0.350348,0.387438,0.390243
7,No log,0.089973,0.393697,0.349438,0.386646,0.388721
8,No log,0.091835,0.39568,0.350407,0.38747,0.390221
9,No log,0.091871,0.394231,0.34713,0.386229,0.389263
10,No log,0.0918,0.394374,0.349141,0.386612,0.389457


Map: 100%|██████████| 1800/1800 [00:03<00:00, 561.01 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 614.60 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.046214,0.396743,0.36477,0.392677,0.39435
2,No log,0.048003,0.397171,0.359536,0.390799,0.392024
3,No log,0.048116,0.39304,0.359173,0.388034,0.389725
4,No log,0.049674,0.393469,0.358797,0.388279,0.389856
5,No log,0.049169,0.390985,0.355983,0.385025,0.386756
6,No log,0.051286,0.389861,0.355775,0.383906,0.385412
7,No log,0.051315,0.392304,0.355898,0.386316,0.388407
8,No log,0.05095,0.393006,0.357632,0.38576,0.387612
9,No log,0.052029,0.390934,0.354639,0.38352,0.385382
10,No log,0.052857,0.387509,0.348611,0.380842,0.382961


Map: 100%|██████████| 1800/1800 [00:03<00:00, 537.52 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 610.10 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.02295,0.413151,0.387308,0.408972,0.411198
2,No log,0.023687,0.413695,0.387276,0.408726,0.411185
3,No log,0.023961,0.413647,0.387642,0.408856,0.411061
4,No log,0.024701,0.412148,0.384569,0.407186,0.409797
5,No log,0.02488,0.411235,0.382287,0.405081,0.408226
6,No log,0.026247,0.409601,0.378658,0.40344,0.406585
7,No log,0.025777,0.412153,0.382865,0.406028,0.409256
8,No log,0.026196,0.411645,0.381726,0.40555,0.408502
9,No log,0.026755,0.411108,0.380143,0.405263,0.408071
10,No log,0.027235,0.409364,0.376744,0.402838,0.405801


In [8]:
model.save_pretrained("BART-with-answer")
tokenizer.save_pretrained("BART-with-answer")

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


('BART-with-answer/tokenizer_config.json',
 'BART-with-answer/special_tokens_map.json',
 'BART-with-answer/vocab.json',
 'BART-with-answer/merges.txt',
 'BART-with-answer/added_tokens.json',
 'BART-with-answer/tokenizer.json')

In [9]:
tokenized_dataset_test = test_df.map(tokenization_with_answer, batched=True)
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

In [10]:
predictions, labels, metrics = trainer.predict(processed_data_test)
decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

In [11]:
decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)

In [13]:
import evaluate  

rouge = evaluate.load("rouge")
rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)
print(rouge_results)

{'rouge1': 0.2893337874714074, 'rouge2': 0.1616152299095607, 'rougeL': 0.24686856940027702, 'rougeLsum': 0.261382603691923}
