In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
torch.cuda.empty_cache()

In [3]:
from sklearn.model_selection import KFold
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import datasets
import pandas as pd
import os
import logging
import nltk
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from random import sample


train_df = datasets.load_from_disk("./QTSumm/train_with_answer")
test_df = datasets.load_from_disk("./QTSumm/test_with_answer")

In [4]:
model_path = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

In [5]:
from typing import List, Dict

def tokenization_with_answer(examples):
    inputs = []
    targets = []

    for i, (query, table, answer, coordinates, summary) in enumerate(zip(examples['query'], examples['table'], examples['answers'], examples['coordinates'], examples['summary'])):
        flattened_table = flatten_table(table, i)
        input_text = f"{flattened_table} Query: {query} Potential answer: "

        # Get the header names
        header = table.get('header', [])

        # Append row and column names for each coordinate
        for coordinate in coordinates:
            row_idx, col_idx = coordinate
            row_name = f"Row {row_idx}"
            col_name = header[col_idx]
            input_text += f"{row_name}, {col_name} | "

        inputs.append(input_text[:-2])
        targets.append(summary)

    res = tokenizer(inputs, text_target=targets, truncation=True, padding=True)
    return res

def flatten_table(table: Dict, row_index: int) -> str:
    header = table.get('header', [])
    rows = table.get('rows', [])
    title = table.get('title', [])

    flattened_rows = []
    for i, row in enumerate(rows):
        row_text = f"Row {i}, " + ",".join([f"{col}:{val}" for col, val in zip(header, row)])
        row_text += " ##"
        flattened_rows.append(row_text)

    flattened_table = f"Title: {' '.join(map(str, title))}" + " " + " ".join(flattened_rows)
    return flattened_table

tokenized_dataset_train = train_df.map(tokenization_with_answer, batched=True)
tokenized_dataset_test = test_df.map(tokenization_with_answer, batched=True)

processed_data_train = tokenized_dataset_train.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
processed_data_test = tokenized_dataset_test.remove_columns(['table','summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

In [6]:
def k_fold_split(dataset, num_folds=5):
    fold_size = len(dataset) // num_folds
    folds = []
    for i in range(num_folds):
        start = i * fold_size
        end = start + fold_size if i < num_folds - 1 else len(dataset)
        folds.append(dataset.select(range(start, end)))
    return folds

In [7]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import evaluate

def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_predictions, decoded_labels = postprocess_text(decoded_predictions, decoded_labels)

    rouge = evaluate.load('rouge')
    bleu = evaluate.load('bleu')

    # Compute ROUGE scores
    rouge_results = rouge.compute(predictions=decoded_predictions, references=decoded_labels)

    # Compute BLEU scores
    bleu_results = bleu.compute(predictions=decoded_predictions, references=decoded_labels)

    # Combine ROUGE and BLEU results
    results = {
        **rouge_results,
        **bleu_results
    }

    return rouge_results

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model= model)

train_args = Seq2SeqTrainingArguments(
    output_dir="./train_weights",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    num_train_epochs=30,
    evaluation_strategy="epoch",
    predict_with_generate=True,
    overwrite_output_dir= True
)

trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=processed_data_train,
    eval_dataset=processed_data_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=metric_fn
)

2024-03-25 02:28:35.421860: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [8]:
folds = k_fold_split(train_df, num_folds=10)

for i in range(len(folds)):
    val_fold = folds[i]
    train_folds = [folds[j] for j in range(len(folds)) if j != i]
    train_dataset = concatenate_datasets(train_folds)

    tokenized_train = train_dataset.map(tokenization_with_answer, batched=True)
    tokenized_val = val_fold.map(tokenization_with_answer, batched=True)

    # Remove unnecessary columns
    processed_train = tokenized_train.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])
    processed_val = tokenized_val.remove_columns(['table', 'summary', 'row_ids', 'example_id', 'query', 'answers', 'coordinates'])

    # Update your trainer's train_dataset and eval_dataset
    trainer.train_dataset = processed_train
    trainer.eval_dataset = processed_val

    # Train your model
    trainer.train()
    trainer.evaluate()

Map: 100%|██████████| 200/200 [00:00<00:00, 312.98 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,17.754456,0.007624,0.000625,0.007038,0.007576
2,No log,7.351467,0.037667,0.005832,0.032283,0.034273
3,No log,3.925401,0.044816,0.007878,0.040288,0.041608
4,No log,3.527853,0.067336,0.014341,0.058198,0.062549
5,No log,2.698954,0.011416,0.00241,0.009195,0.010256
6,No log,1.683148,0.005598,0.001403,0.004321,0.005045
7,No log,1.318545,0.01681,0.002911,0.013015,0.014287
8,No log,1.168857,0.079134,0.022862,0.066077,0.070745
9,No log,1.100531,0.131548,0.041424,0.106006,0.114972
10,No log,1.048469,0.159777,0.051519,0.128959,0.139215


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:03<00:00, 566.18 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 636.68 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.697579,0.225396,0.092726,0.1803,0.198289
2,No log,0.693701,0.226204,0.091991,0.182359,0.197475
3,No log,0.689994,0.226827,0.094734,0.183583,0.200033
4,No log,0.687188,0.224602,0.093618,0.182707,0.198456
5,No log,0.684823,0.227173,0.098521,0.186744,0.200977
6,No log,0.682577,0.226627,0.094567,0.185503,0.200227
7,No log,0.680417,0.223363,0.094012,0.182413,0.197723
8,No log,0.678354,0.222986,0.093237,0.182142,0.197723
9,No log,0.677413,0.221925,0.092061,0.181276,0.196326
10,No log,0.675666,0.223531,0.090738,0.18127,0.198381


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:02<00:00, 677.07 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 624.81 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.737694,0.242611,0.104137,0.195824,0.216876
2,No log,0.736907,0.241059,0.102784,0.194941,0.215387
3,No log,0.736028,0.238169,0.102166,0.191756,0.213666
4,No log,0.735818,0.239234,0.102554,0.192959,0.213676
5,No log,0.734183,0.239935,0.103009,0.192746,0.21507
6,No log,0.73373,0.237241,0.100122,0.192351,0.212154
7,No log,0.732643,0.236011,0.101046,0.19129,0.211573
8,No log,0.73214,0.239076,0.101581,0.194092,0.21447
9,No log,0.731597,0.235402,0.099372,0.190284,0.210553
10,No log,0.730586,0.239611,0.101249,0.193609,0.214942


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:02<00:00, 679.41 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 628.77 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.767965,0.218978,0.088499,0.176685,0.193804
2,No log,0.768185,0.219001,0.087934,0.176596,0.193769
3,No log,0.768819,0.22037,0.090793,0.176579,0.195252
4,No log,0.768892,0.218662,0.090308,0.175082,0.193791
5,No log,0.769327,0.22198,0.091427,0.177298,0.195989
6,No log,0.768847,0.223366,0.093412,0.17977,0.19879
7,No log,0.768689,0.219768,0.091216,0.178274,0.195421
8,No log,0.768715,0.222593,0.093845,0.18063,0.198618
9,No log,0.769394,0.223531,0.092297,0.180447,0.198349
10,No log,0.769976,0.222485,0.093061,0.180635,0.198305


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:02<00:00, 605.55 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 624.66 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.722099,0.234914,0.10756,0.196056,0.21128
2,No log,0.722912,0.238874,0.111593,0.200581,0.215168
3,No log,0.723439,0.237313,0.108931,0.19779,0.212176
4,No log,0.724109,0.23846,0.110341,0.199409,0.214449
5,No log,0.724739,0.236001,0.108861,0.197969,0.212814
6,No log,0.725051,0.237963,0.10852,0.197008,0.213106
7,No log,0.725122,0.237577,0.10957,0.196802,0.213872
8,No log,0.725062,0.23893,0.111229,0.198903,0.213898
9,No log,0.725577,0.237805,0.110191,0.198205,0.214113
10,No log,0.726027,0.235774,0.109269,0.19496,0.211047


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:02<00:00, 669.96 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 672.70 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.495378,0.254784,0.130609,0.214958,0.231544
2,No log,0.496932,0.253847,0.132266,0.216499,0.231264
3,No log,0.497601,0.251235,0.130482,0.212507,0.229256
4,No log,0.498672,0.252288,0.132791,0.214886,0.231033
5,No log,0.499129,0.251666,0.130786,0.214689,0.230418
6,No log,0.499007,0.252083,0.131368,0.214583,0.230609
7,No log,0.49953,0.254151,0.132302,0.2152,0.232416
8,No log,0.500271,0.251724,0.130246,0.213829,0.230057
9,No log,0.500667,0.252592,0.132505,0.214181,0.231254
10,No log,0.501006,0.253635,0.130412,0.214852,0.231833


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:02<00:00, 634.48 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 603.40 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.488045,0.23854,0.111773,0.201269,0.21653
2,No log,0.490444,0.235527,0.107028,0.198576,0.214143
3,No log,0.490275,0.240716,0.114224,0.202575,0.218145
4,No log,0.491298,0.241376,0.114238,0.203149,0.218821
5,No log,0.491222,0.2356,0.111157,0.199014,0.21431
6,No log,0.491301,0.235395,0.109229,0.199556,0.213899
7,No log,0.492192,0.238257,0.113315,0.203314,0.218301
8,No log,0.492335,0.23765,0.113373,0.202781,0.217167
9,No log,0.492319,0.235673,0.112747,0.200218,0.214711
10,No log,0.492887,0.233825,0.111297,0.198423,0.212367


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:02<00:00, 666.29 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 632.68 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.572282,0.261779,0.126976,0.219503,0.233668
2,No log,0.573431,0.261823,0.12606,0.216832,0.23212
3,No log,0.573668,0.261917,0.126851,0.218967,0.23348
4,No log,0.574676,0.260966,0.125752,0.218144,0.232731
5,No log,0.576039,0.261309,0.125393,0.216256,0.233696
6,No log,0.576701,0.261395,0.124433,0.216745,0.233571
7,No log,0.575854,0.258151,0.125344,0.216727,0.231939
8,No log,0.576906,0.257187,0.123479,0.215563,0.232524
9,No log,0.577012,0.257405,0.123115,0.215513,0.230647
10,No log,0.576986,0.256184,0.123549,0.21385,0.22916


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:02<00:00, 629.42 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 644.08 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.388776,0.258095,0.131744,0.220923,0.2373
2,No log,0.389561,0.257637,0.129372,0.219064,0.235772
3,No log,0.38991,0.257286,0.129918,0.21906,0.235478
4,No log,0.391133,0.260275,0.130371,0.220848,0.237908
5,No log,0.392618,0.255791,0.12811,0.216663,0.232956
6,No log,0.391443,0.256762,0.132023,0.220366,0.23593
7,No log,0.392807,0.256037,0.127212,0.215683,0.232574
8,No log,0.392822,0.255633,0.126018,0.217081,0.233623
9,No log,0.39323,0.258358,0.131585,0.220339,0.235883
10,No log,0.393746,0.260736,0.130423,0.22111,0.237401


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Map: 100%|██████████| 1800/1800 [00:02<00:00, 671.77 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 644.19 examples/s]


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,0.411646,0.270603,0.137313,0.226498,0.244938
2,No log,0.413617,0.266802,0.133927,0.22263,0.240472
3,No log,0.414207,0.265138,0.131917,0.220495,0.238542
4,No log,0.414998,0.266342,0.134652,0.223864,0.240301
5,No log,0.415911,0.266962,0.133729,0.22446,0.242237
6,No log,0.415977,0.274576,0.136801,0.229399,0.247508
7,No log,0.41682,0.273152,0.136907,0.228066,0.24561
8,No log,0.417825,0.27201,0.134884,0.227251,0.245619
9,No log,0.4179,0.266417,0.128708,0.22126,0.240497
10,No log,0.418224,0.265506,0.13076,0.22332,0.24054


Checkpoint destination directory ./train_weights/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


In [9]:
model.save_pretrained("Flan-with-answer-10fold")
tokenizer.save_pretrained("Flan-with-answer-10fold")

('Flan-with-answer-10fold/tokenizer_config.json',
 'Flan-with-answer-10fold/special_tokens_map.json',
 'Flan-with-answer-10fold/spiece.model',
 'Flan-with-answer-10fold/added_tokens.json',
 'Flan-with-answer-10fold/tokenizer.json')