In [None]:
!pip install datasets transformers rouge-score nltk py7zr

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
%cd /content/drive/MyDrive/NLP Project

/content/drive/MyDrive/NLP Project


In [5]:
# This notebook needs the modified versions of these two transformers files to run properly, so copy them over the installed ones.
!cp "/content/drive/MyDrive/NLP Project/tokenization_utils_base.py" /usr/local/lib/python3.8/dist-packages/transformers/tokenization_utils_base.py
!cp "/content/drive/MyDrive/NLP Project/trainer_seq2seq.py" /usr/local/lib/python3.8/dist-packages/transformers/trainer_seq2seq.py

# Fine-tuning a model on a summarization task

In [6]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers import AutoTokenizer
from datasets import load_dataset, load_metric
import nltk
import numpy as np
import torch
from torch.utils.data import DataLoader
from modeling_bart import BartForConditionalGeneration  # Custom coref bart
from modeling_t5 import T5ForConditionalGeneration    # Custom coref t5


## Loading the dataset

In [7]:
# Load the "samsum" dataset (all splits).
raw_datasets = load_dataset("samsum")
# Drop every example whose dialogue is an empty string.
raw_datasets = raw_datasets.filter(lambda data:data['dialogue'] != "")
# ROUGE scorer; compute_metrics reads this module-level name.
metric = load_metric("rouge")



  0%|          | 0/3 [00:00<?, ?it/s]

  metric = load_metric("rouge")


In [8]:
def compute_metrics(eval_pred):
    """Score generated summaries against the references with ROUGE.

    Relies on the notebook-level ``tokenizer`` and ``metric`` objects.

    :param eval_pred: pair ``(predictions, labels)`` of token-id arrays.
    :return: dict holding each ROUGE key's mid F-measure scaled to 0-100, plus
        ``gen_len`` (mean number of non-pad tokens per prediction); every value
        is rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    pad_id = tokenizer.pad_token_id

    def _one_sentence_per_line(texts):
        # Rouge expects a newline after each sentence
        return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

    pred_texts = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # -100 marks ignored label positions and cannot be decoded, so swap in the pad id.
    labels = np.where(labels != -100, labels, pad_id)
    label_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

    pred_texts = _one_sentence_per_line(pred_texts)
    label_texts = _one_sentence_per_line(label_texts)

    rouge_scores = metric.compute(predictions=pred_texts, references=label_texts, use_stemmer=True)
    # Keep only the mid F-measure of each score, as a percentage.
    result = {name: score.mid.fmeasure * 100 for name, score in rouge_scores.items()}

    # Mean generated length, counted in non-pad tokens.
    result["gen_len"] = np.mean([np.count_nonzero(pred != pad_id) for pred in predictions])

    return {name: round(value, 4) for name, value in result.items()}

In [9]:
import torch.nn as nn
class CustomTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer subclass that spells out how the training loss is obtained."""

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run the model on ``inputs`` and return its loss.

        When the trainer has a label smoother and ``inputs`` contains "labels",
        the labels are removed from ``inputs`` before the forward pass and the
        smoother computes the loss; otherwise the loss comes from the model output.

        :param model: the model to call with ``**inputs``.
        :param inputs: dict of model inputs (mutated when labels are popped).
        :param return_outputs: if True, return ``(loss, outputs)`` instead of ``loss``.
        :raises ValueError: if the model output is a dict without a "loss" key
            and no label smoothing is applied.
        """
        smoothing_active = self.label_smoother is not None and "labels" in inputs
        labels = inputs.pop("labels") if smoothing_active else None
        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            outputs_are_dict = isinstance(outputs, dict)
            if outputs_are_dict and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if outputs_are_dict else outputs[0]
        elif unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            # Models listed in the causal-LM mapping get their labels shifted by the smoother.
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)

        if return_outputs:
            return (loss, outputs)
        return loss

# Coreference functions

## Preprocess functions

In [10]:
def save_processed_coref_dataset(prefix, datasets):
    """
    Write the processed coreference data of each split to ``data/<prefix>-<split>.source``.

    Every example becomes one line of four ``#``-separated fields: the example
    id, the ``input_ids`` list, and the first and second entries of
    ``coref_information``, each rendered with ``str()``.

    :param prefix: prefix arbitrary name
    :param datasets: mapping with 'train', 'test' and 'validation' entries; each
        entry is an iterable of examples with 'id', 'input_ids' and
        'coref_information' keys.
    :return: None; directly write a file per split.
    """
    for name in ('train', 'test', 'validation'):
        output_path = "data/" + prefix + '-' + name + '.source'
        # The context manager closes the file even if an example is malformed
        # and raises part-way through the split.
        with open(output_path, 'w', encoding='utf-8') as output_fp:
            for data in datasets[name]:
                coref_information = data['coref_information']
                output_fp.write(data['id'] + '#' + str(data['input_ids']) + '#'
                                + str(coref_information[0]) + '#' + str(coref_information[1]) + '\n')


def convert_str_list_to_list(str_list):
    """Parse the text form of an integer list (e.g. ``"[1, 2, 3]"``) into a list of ints.

    All square brackets are discarded before splitting on commas, so nested
    lists come back flattened; empty pieces are ignored.
    """
    text = str_list.strip()
    for bracket in ("[", "]"):
        text = text.replace(bracket, "")

    numbers = []
    for piece in text.split(","):
        piece = piece.strip()
        if piece:
            numbers.append(int(piece))
    return numbers


def load_processed_coref_dataset(prefix):
    """
    Read ``data/<prefix>-<split>.source`` for the 'train', 'test' and 'validation' splits.

    Each non-blank line is split on ``#``; field 1 is parsed as the input ids
    and fields 2 and 3 as the two coreference lists (field 0, the example id,
    is not kept). Blank lines are skipped.

    :param prefix: prefix arbitrary name
    :return: dict mapping each split name to
        ``{'input_ids': [...], 'coref_information': [(list, list), ...]}``,
        in file order.
    """
    datasets = {}
    for name in ('train', 'test', 'validation'):
        input_ids_list = []
        coref_information_list = []
        input_path = "data/" + prefix + '-' + name + '.source'
        # Context manager so the handle is closed after each split; the file
        # was previously left open.
        with open(input_path, 'r', encoding='utf-8') as input_fp:
            for line in input_fp:
                if not line.strip():
                    # A blank (e.g. trailing) line carries no example.
                    continue
                fields = line.split('#')
                input_ids_list.append(convert_str_list_to_list(fields[1]))
                coref_information_list.append(
                    (convert_str_list_to_list(fields[2]), convert_str_list_to_list(fields[3])))
        datasets[name] = {
            'input_ids': input_ids_list,
            'coref_information': coref_information_list,
        }

    return datasets


# preprocess functions
def _add_preprocessed_data(x, indice, split):
    """Copy the cached features of ``split`` at position ``indice`` onto example ``x``.

    Reads the notebook-level ``coref_datasets`` and sets 'input_ids', an
    all-ones 'attention_mask' of the same length, and 'coref_information'.
    Returns the same (mutated) ``x``.
    """
    split_data = coref_datasets[split]
    input_ids = split_data['input_ids'][indice]
    x['input_ids'] = input_ids
    x['attention_mask'] = [1] * len(input_ids)
    x['coref_information'] = split_data['coref_information'][indice]
    return x


def add_preprocessed_data_train(x, indice):
    """Attach the 'train' coreference features at position ``indice`` to ``x``."""
    return _add_preprocessed_data(x, indice, 'train')


def add_preprocessed_data_test(x, indice):
    """Attach the 'test' coreference features at position ``indice`` to ``x``."""
    return _add_preprocessed_data(x, indice, 'test')


def add_preprocessed_data_validation(x, indice):
    """Attach the 'validation' coreference features at position ``indice`` to ``x``."""
    return _add_preprocessed_data(x, indice, 'validation')


max_input_length = 1024
max_target_length = 128


def preprocess_function(examples):
    """Add tokenized summaries as ``labels`` to a batch that already holds its inputs.

    The batch is modified in place and returned; summaries are truncated to
    ``max_target_length`` tokens by the notebook-level ``tokenizer``.
    """
    # Tokenize the targets only; the inputs are left untouched.
    target_encoding = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)
    examples["labels"] = target_encoding["input_ids"]
    return examples


def preprocess_function_t5(examples):
    """Prepare a batch for T5: add the task prefix, shift coreference indices, add labels.

    For every example the first stored input id is dropped and replaced by the
    first two ids of the tokenized "summarize: " prefix followed by the first
    id of the tokenized " #", so each sequence grows by two positions. The
    attention masks get two extra leading ones and every coreference index is
    increased by two to match. The batch is modified in place and returned.
    """
    shift = 2  # net number of positions added in front of every sequence

    # replacement for the first token of the original input ids
    hash_ids = tokenizer(" #")['input_ids'][:1]
    prefix_ids = tokenizer("summarize: ")['input_ids'][:shift]
    leading_ids = prefix_ids + hash_ids

    # replace the original input_ids
    examples['input_ids'] = [leading_ids + ids[1:] for ids in examples['input_ids']]

    # update attention mask
    examples['attention_mask'] = [[1, 1] + mask for mask in examples['attention_mask']]

    # update coreference information
    examples['coref_information'] = [
        [[pos + shift for pos in source_list], [pos + shift for pos in target_list]]
        for source_list, target_list in examples['coref_information']
    ]

    # setup the tokenizer for targets
    target_encoding = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)
    examples["labels"] = target_encoding["input_ids"]
    return examples


# save function
def save_result(result, path):
    """Write every string of ``result`` to ``path`` (UTF-8), stripped, one per line."""
    with open(path, "w", encoding="utf-8") as out_file:
        out_file.writelines(text.strip() + "\n" for text in result)

## Coreference information extraction functions

In [11]:
# Coreference information extraction takes a lot of time. Thus, we only extracted once and saved the result for future usage.

In [12]:
# !pip install allennlp allennlp_models

In [13]:
# # code from https://github.com/seq-to-mind/coref_dial_summ modified.
# import re
# import pickle
# from tqdm import tqdm
# from allennlp.predictors.predictor import Predictor
# import allennlp_models.coref
# from allennlp_models import pretrained


# class NeuralCoreferenceProcessing:
#     def __init__(self, gpu_id=-1):
#         """ download and indicate the path of pre-trained coref-spanbert model """
#         # self.predictor = Predictor.from_path("/content/drive/MyDrive/coref-spanbert-large-2021.03.10.tar.gz.tar", cuda_device=gpu_id)
#         self.predictor = pretrained.load_predictor("coref-spanbert", cuda_device=gpu_id)

#     def process(self, input_list, batch_size=4):
#         # output_list = {"dot": [], "sharp": [], "newline": [], "semicolon": []}
#         output_list = []
#         dataloader = DataLoader(input_list, batch_size=batch_size, shuffle=False)

#         # for tmp_content in tqdm(input_list):
#         # for tmp_content in input_list:
#         for tmp_content in dataloader:
#             tmp_content = [dialogue.replace("#", " ").replace("\r\n", " # ").replace("\n", " ").replace("🙂", " ") 
#                             for dialogue in tmp_content]
#             tmp_content = [re.sub("\s+", " ", dialogue).strip() for dialogue in tmp_content]

#             """ here we replace the sentence segmenter, to obtain multiple coreference resolution outputs """
#             tmp_res_with_dot_batch = self.predictor.predict_batch_json([{"document": dialogue.replace("#", ".")} for dialogue in tmp_content])
#             tmp_res_with_sharp_batch = self.predictor.predict_batch_json({"document": dialogue} for dialogue in tmp_content)
#             tmp_res_with_newline_batch = self.predictor.predict_batch_json([{"document": dialogue.replace("#", "\n")} for dialogue in tmp_content])
#             tmp_res_with_semicolon_batch = self.predictor.predict_batch_json([{"document": dialogue.replace("#", ";")} for dialogue in tmp_content])

#             # # """ check the length of multiple coreference resolution outputs are the same """
#             # assert len(tmp_res_with_dot['document']) == len(tmp_res_with_sharp['document'])
#             # assert len(tmp_res_with_newline['document']) == len(tmp_res_with_sharp['document'])
#             # assert len(tmp_res_with_semicolon['document']) == len(tmp_res_with_sharp['document'])

#             for (dialogue, tmp_res_with_dot, tmp_res_with_sharp, tmp_res_with_newline, tmp_res_with_semicolon) in \
#                 zip(tmp_content, tmp_res_with_dot_batch, tmp_res_with_sharp_batch, tmp_res_with_newline_batch, tmp_res_with_semicolon_batch):
#                 # """ check the length of multiple coreference resolution outputs are the same """
#                 assert len(tmp_res_with_dot['document']) == len(tmp_res_with_sharp['document'])
#                 assert len(tmp_res_with_newline['document']) == len(tmp_res_with_sharp['document'])
#                 assert len(tmp_res_with_semicolon['document']) == len(tmp_res_with_sharp['document'])
                
#                 tmp_res_with_dot['document'] = tmp_res_with_sharp['document']
#                 tmp_res_with_newline['document'] = tmp_res_with_sharp['document']
#                 tmp_res_with_semicolon['document'] = tmp_res_with_sharp['document']

#                 """ ensemble multiple coreference resolution outputs """
#                 output_list.append({"dot": (dialogue, tmp_res_with_dot),
#                                     "sharp": (dialogue, tmp_res_with_sharp),
#                                     "newline": (dialogue, tmp_res_with_newline),
#                                     "semicolon": (dialogue, tmp_res_with_semicolon)}, )
#             # output_list["dot"].extend(tmp_res_with_dot_batch)
#             # output_list["sharp"].extend(tmp_res_with_sharp_batch)
#             # output_list["newline"].extend(tmp_res_with_newline_batch)
#             # output_list["semicolon"].extend(tmp_res_with_semicolon_batch)

#         return output_list

In [14]:
# import copy
# from tqdm import tqdm
# import re
# import numpy as np
# import pickle
# from transformers import AutoTokenizer


# def Prev_Coreference_Matrix(token_length, src_list, tgt_list):
#     """ build the prev-linked coreference matrix """
#     coref_matrix = np.zeros([token_length, token_length], dtype=float)
#     assert len(src_list) == len(tgt_list)
#     for i in range(len(src_list)):
#         coref_matrix[src_list[i]][tgt_list[i]] = 1
#     for i in range(token_length):
#         if sum(coref_matrix[i]) == 0:
#             coref_matrix[i][i] = 1
#     return coref_matrix


# class BuildSampleWithCoreferenceInfo:
#     def __init__(self, tokenizer):
#         """ Here we use the tokenizer from BART """
#         self.global_tokenizer = tokenizer

#     def build_sample_with_coref_to_file(self, input_list, aux_condition_name_file=None, conditional_file_path=None, debug=False):
#         """
#         :param aux_condition_name_file: each row will contain the speaker roles / personal named entities
#         :param input_list: the list of conversations
#         :param conditional_file_path: the list of conversations with conditional planning
#         :param debug: for debug print
#         :return: directly write a file.
#         """

#         if conditional_file_path is not None:
#             conditional_line_list = open(conditional_file_path, encoding="utf-8").readlines()
#             assert len(conditional_line_list) == len(input_list)

#         output_list = {'input_ids': [], 'coref_information': []}

#         if aux_condition_name_file is not None:
#             aux_name_list = open(aux_condition_name_file, encoding="utf-8").readlines()
#             assert len(aux_name_list) == len(input_list)
#         else:
#             aux_name_list = None

#         tmp_line_idx = 0
#         # for tmp_k, tmp_dict_node in tqdm(enumerate(input_list)):
#         for tmp_k, tmp_dict_node in enumerate(input_list):
#             """ we use the multiple coreference resolution outputs """
#             for coref_idx, coref_type in enumerate(['newline', 'dot', 'sharp', 'semicolon']):
#                 tmp_i = tmp_dict_node[coref_type]
#                 tmp_token_list = tmp_i[1]["document"]
#                 tmp_clusters = tmp_i[1]["clusters"]

#                 # print(tmp_i[0])
#                 raw_coref_cluster_info = []
#                 for k, i in enumerate(tmp_clusters):
#                     one_list = [" ".join(tmp_token_list[j[0]:j[1] + 1]) for j in i]
#                     raw_coref_cluster_info.append((k, one_list))

#                 """ Tackle the issue that some speaker names are not included in coreference chains """
#                 tmp_new_coref_cluster_info = copy.deepcopy(raw_coref_cluster_info)

#                 tmp_titled_speakers = set([i[:-1] for i in tmp_i[0].split() if i[-1] == ":" and i.istitle()])
#                 tmp_speaker_label_dict = {}
#                 for k, v in enumerate(tmp_titled_speakers):
#                     tmp_cluster_res = [j[0] for j in tmp_new_coref_cluster_info if v in j[1]]
#                     if len(tmp_cluster_res) < 1 and v not in tmp_speaker_label_dict.keys():
#                         tmp_speaker_label_dict[v] = len(tmp_new_coref_cluster_info) + 30
#                         tmp_new_coref_cluster_info.append((len(tmp_new_coref_cluster_info) + 30, [v]))
#                     else:
#                         if len(tmp_cluster_res) == 1:
#                             tmp_speaker_label_dict[v] = tmp_cluster_res[0]
#                         if len(tmp_cluster_res) > 1:
#                             """ Here we select the first found token as the cluster label """
#                             q_list = [(q, tmp_clusters[q][tmp_new_coref_cluster_info[q][1].index(v)][0]) for q in tmp_cluster_res]
#                             q_list = sorted(q_list, key=lambda x: x[1])
#                             tmp_speaker_label_dict[v] = q_list[0][0]

#                 if aux_name_list is not None:
#                     assert len(re.findall("\}\s+\#", aux_name_list[tmp_k])) == 1
#                     aux_one_cond_name_set = set(re.sub("[\#\.\|\{\}]]", " ", aux_name_list[tmp_k].split("}")[0]).split())
#                 else:
#                     aux_one_cond_name_set = set()

#                 continue_flag = False
#                 for tmp_item in tmp_new_coref_cluster_info:
#                     tmp_small_set = set([i.strip().split()[0] for i in tmp_item[1]])
#                     intersection = tmp_small_set & (set(tmp_titled_speakers) | set(aux_one_cond_name_set))
#                     if len(intersection) > 1:
#                         if len(set([i[:2] for i in intersection])) > 1:
#                             if coref_idx == 3:
#                                 print("one plausible coreference chain.")
#                             continue_flag = True

#                 if continue_flag is False:
#                     break

#             """ Further add titled words to increase coverage """
#             tmp_titled_other_tokens = set([i for i in tmp_i[1]["document"] if len(i) > 2 and i.istitle()])
#             for k, v in enumerate(tmp_titled_other_tokens):
#                 tmp_cluster_res = [j[0] for j in tmp_new_coref_cluster_info if v in j[1]]
#                 if len(tmp_cluster_res) < 1 and v not in tmp_speaker_label_dict.keys():
#                     tmp_cluster_res = [(j[0], j[1].count(v)) for j in tmp_new_coref_cluster_info if v in " ".join(j[1]).split()]
#                     tmp_cluster_res = sorted(tmp_cluster_res, key=lambda x: x[1], reverse=True)
#                     if len(tmp_cluster_res) > 0:
#                         tmp_speaker_label_dict[v] = tmp_cluster_res[0][0]
#                     else:
#                         tmp_speaker_label_dict[v] = len(tmp_new_coref_cluster_info) + 100
#                         tmp_new_coref_cluster_info.append((len(tmp_new_coref_cluster_info) + 100, [v]))

#             """ Adding spaces in tokenized list, to recover the same tokenization via BART """
#             tmp_doc = copy.deepcopy(" " + tmp_i[0])

#             tmp_token_list_with_space = []
#             for k, v in enumerate(tmp_token_list):
#                 find_idx = str(tmp_doc).index(v)
#                 if find_idx > 0 and tmp_doc[find_idx - 1] == " ":
#                     tmp_token_list_with_space.append([" " + v, -1])
#                 else:
#                     tmp_token_list_with_space.append([v, -1])
#                 tmp_doc = tmp_doc[find_idx + len(v):]

#             """ Labeling the token list with the coreference cluster labels """
#             """ From the longer spans to shorter spans, to avoiding labels to be re-changed """
#             tmp_span_len_list = []
#             for i in tmp_clusters:
#                 tmp_span_len_list.extend([j[-1] + 1 - j[0] for j in i])

#             tmp_span_len_list = list(set(tmp_span_len_list))
#             tmp_span_len_list = sorted(tmp_span_len_list, reverse=True)

#             for one_len in tmp_span_len_list:
#                 for i in range(len(tmp_clusters)):
#                     for j in tmp_clusters[i]:
#                         if (j[1] + 1 - j[0]) == one_len:
#                             for e in j:
#                                 tmp_token_list_with_space[e][1] = i

#             """ Tackle the issue that speaker names do not have coreference """
#             assert len(tmp_token_list_with_space) == len(tmp_token_list)
#             for k in range(len(tmp_token_list_with_space)):
#                 if (k == len(tmp_token_list_with_space) - 1 or tmp_token_list_with_space[k + 1][0].strip() == ":") \
#                         and tmp_token_list_with_space[k][0].strip() in tmp_speaker_label_dict.keys() \
#                         and tmp_token_list_with_space[k][1] == -1:
#                     tmp_token_list_with_space[k][1] = tmp_speaker_label_dict[tmp_token_list_with_space[k][0].strip()]
#                     # print(tmp_token_list_with_space)

#             """ Merge the token list with the same coreference cluster """
#             merged_tmp_token_list_with_space = []
#             current_merge_set = []
#             current_cluster_id_to_merge = -999
#             for i in range(len(tmp_token_list_with_space)):
#                 if tmp_token_list_with_space[i][1] == current_cluster_id_to_merge:
#                     current_merge_set.append(tmp_token_list_with_space[i])
#                     current_cluster_id_to_merge = tmp_token_list_with_space[i][1]
#                 else:
#                     if len(current_merge_set) > 0:
#                         merged_tmp_token_list_with_space.append([[j[0] for j in current_merge_set], current_merge_set[0][1]])
#                         current_merge_set = []
#                     current_merge_set.append(tmp_token_list_with_space[i])
#                     current_cluster_id_to_merge = tmp_token_list_with_space[i][1]
#                 if i == len(tmp_token_list_with_space) - 1 and len(current_merge_set) > 0:
#                     merged_tmp_token_list_with_space.append([[j[0] for j in current_merge_set], current_merge_set[0][1]])
#                     current_merge_set = []

#             # print(merged_tmp_token_list_with_space)

#             """ Using the BART tokenizer to process the new token list """
#             for i in range(len(merged_tmp_token_list_with_space)):
#                 merged_tmp_token_list_with_space[i][0] = self.global_tokenizer.tokenize("".join(merged_tmp_token_list_with_space[i][0]))

#             """ V2 only point to the first token of spans """
#             tmp_token_list_with_cluster_ids = []
#             for i in merged_tmp_token_list_with_space:
#                 for j in range(len(i[0])):
#                     if j == 0:
#                         tmp_token_list_with_cluster_ids.append([i[0][j], i[1]])
#                     else:
#                         tmp_token_list_with_cluster_ids.append([i[0][j], -1])

#             if debug:
#                 tmp_t = " ".join(self.global_tokenizer.tokenize(tmp_i[0])).strip()
#                 tmp_c = " ".join([j[0] for j in tmp_token_list_with_cluster_ids]).strip()
#                 print("\n", tmp_t, "\n", tmp_c)

#             """ Adding coreference of the conditional personal names """
#             if conditional_file_path is not None:
#                 assert len(re.findall("\}\s+\#", conditional_line_list[tmp_k])) == 1
#                 conditional_names = (conditional_line_list[tmp_k].split("}")[0] + "} #").split()
#                 conditional_names = [[i.strip(), -1] for i in conditional_names]

#                 for k, v in enumerate(conditional_names):
#                     if v[0] not in ["{", "}", "#", "|"]:
#                         tmp_cluster_res = [(j[0], j[1].count(v[0])) for j in tmp_new_coref_cluster_info if v[0] in j[1]]
#                         if len(tmp_cluster_res) > 0:
#                             conditional_names[k][1] = tmp_cluster_res[0][0]
#                             # if len(tmp_cluster_res) > 1:
#                             #     print(tmp_cluster_res)
#                         else:
#                             """ splitting every name in the cluster keys, then find more names """
#                             tmp_cluster_res = [(j[0], j[1].count(v[0])) for j in tmp_new_coref_cluster_info if v[0] in " ".join(j[1]).split()]

#                             if len(tmp_cluster_res) > 0:
#                                 conditional_names[k][1] = tmp_cluster_res[0][0]
#                             else:
#                                 """ To tackle the exception of names are not included """
#                                 tmp_new_coref_cluster_info.append((len(tmp_new_coref_cluster_info) + 50, [v[0]]))
#                                 for n, i in enumerate(tmp_token_list_with_cluster_ids):
#                                     if i[0][1:] == v[0] and i[1] == -1:
#                                         tmp_token_list_with_cluster_ids[n][1] = tmp_new_coref_cluster_info[-1][0]
#                                         conditional_names[k][1] = tmp_new_coref_cluster_info[-1][0]
#                                 print("\n\n\n")
#                                 print(tmp_i[0])
#                                 print(v[0])
#                                 print(tmp_token_list_with_cluster_ids)
#                                 pass

#                 print(conditional_names)
#                 condition_prefix = conditional_names
#                 tmp_prefix = []
#                 for i in condition_prefix:
#                     tmp_t = self.global_tokenizer.tokenize(" " + i[0])
#                     for j in range(len(tmp_t)):
#                         if j == 0:
#                             tmp_prefix.append((tmp_t[j], i[1]))
#                         else:
#                             tmp_prefix.append((tmp_t[j], -1))

#                 tmp_token_list_with_cluster_ids = tmp_prefix + tmp_token_list_with_cluster_ids

#             else:
#                 tmp_token_list_with_cluster_ids = [['#', -1]] + tmp_token_list_with_cluster_ids

#             """ Build the src list ang tgt list for DGL GNN implementation """
#             src_list = []
#             tgt_list = []
#             text_input_list = []
#             for k, v in enumerate(tmp_token_list_with_cluster_ids):
#                 text_input_list.append(v[0])
#                 if v[1] != -1:
#                     find_precedent = [j for j in range(k) if tmp_token_list_with_cluster_ids[j][1] == v[1]]
#                     if len(find_precedent) > 0:
#                         src_list.append(k)
#                         tgt_list.append(max(find_precedent))

#             assert len(src_list) == len(tgt_list)

#             tmp_line_idx += 1

#             """ Truncating the lengthy samples """
#             if len(text_input_list) > 1023:
#                 print("Truncate the lengthy sample:", len(text_input_list))
#                 cut_num = len([i for i in src_list if i > 1022])
#                 print(cut_num)
#                 src_list = src_list[:-cut_num]
#                 tgt_list = tgt_list[:-cut_num]
#                 print(src_list)
#                 print(tgt_list)
#                 assert len(src_list) == len(tgt_list)

#             """ We write all information as a text file """
#             text_input_list = text_input_list[:1023]
#             # output_fp.write(" ".join(text_input_list) + " ##### " + str(self.global_tokenizer.convert_tokens_to_ids(text_input_list)) + \
#             #                 " ##### " + str(src_list) + " ##### " + str(tgt_list) + " ##### " + str(len(text_input_list)) + "\n")
#             output_list["input_ids"].append(self.global_tokenizer.convert_tokens_to_ids(text_input_list))
#             output_list["coref_information"].append((src_list, tgt_list))

#             assert len(text_input_list) == len(self.global_tokenizer.convert_tokens_to_ids(text_input_list))
#             assert len(self.global_tokenizer.convert_tokens_to_ids(text_input_list)) < 1024

#             if debug:
#                 for k, v in enumerate(text_input_list):
#                     if k in src_list:
#                         print(">>>>>>>> ", v.replace("Ġ", ""), k, tgt_list[src_list.index(k)])
#                     else:
#                         print(v.replace("Ġ", ""), k, "X")
#                 tmp_matrix = Prev_Coreference_Matrix(len(text_input_list), src_list, tgt_list)
#                 print(tmp_matrix)

#         return output_list

In [15]:
# max_target_length = 128
# def preprocess_function(examples):
#     # model_inputs = tokenizer(examples["dialogue"], max_length=max_input_length, truncation=True)
#     result = coref_model.process(examples["dialogue"], batch_size=4)
#     model_inputs = coref_build.build_sample_with_coref_to_file(result)

#     # Setup the tokenizer for targets
#     labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

#     model_inputs["labels"] = labels["input_ids"]
#     return model_inputs

In [16]:
# processed_dataset = raw_datasets.map(preprocess_function, batched=True)

In [17]:
# # save coreference dataset, prefix: ['bart', 't5-base']
# prefix = 'bart'
# coref_datasets = save_processed_coref_dataset(prefix, processed_datasets)

# Bart

## Load saved coreference dataset and preprocess

In [18]:
# load coreference dataset, prefix: ['bart', 't5-base']
prefix = 'bart'
coref_datasets = load_processed_coref_dataset(prefix)

model_checkpoint = "facebook/bart-base"

In [19]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [20]:
# with_indices=True passes each example's position to the helper, which uses it
# to look up the row at the same position in coref_datasets for that split.
# NOTE(review): this assumes the saved .source files are in the same order as
# the filtered raw splits — confirm.
processed_train_dataset = raw_datasets['train'].map(add_preprocessed_data_train, with_indices=True)
processed_test_dataset = raw_datasets['test'].map(add_preprocessed_data_test, with_indices=True)
processed_validation_dataset = raw_datasets['validation'].map(add_preprocessed_data_validation, with_indices=True)



  0%|          | 0/14731 [00:00<?, ?ex/s]

Exception ignored in: <function tqdm.__del__ at 0x7f8465529ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/tqdm/std.py", line 1161, in __del__
    def __del__(self):
KeyboardInterrupt: 


  0%|          | 0/819 [00:00<?, ?ex/s]

  0%|          | 0/818 [00:00<?, ?ex/s]

In [21]:
import copy
processed_datasets = copy.deepcopy(raw_datasets)

In [22]:
processed_datasets['train'] = processed_train_dataset
processed_datasets['test'] = processed_test_dataset
processed_datasets['validation'] = processed_validation_dataset

In [None]:
tokenized_coref_datasets = processed_datasets.map(preprocess_function, batched=True)

## Fine-tuning

In [None]:
model = BartForConditionalGeneration.from_pretrained(model_checkpoint, output_hidden_states=False)

In [None]:
batch_size = 16

args = Seq2SeqTrainingArguments(
    "16-coref-bart-dialogue-summarization",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # gradient_accumulation_steps=2,
    weight_decay=0.01,
    # save_total_limit=2,
    num_train_epochs=5,
    logging_steps = 10, ## added
    predict_with_generate=True,
    fp16=True,
    report_to="none",
    generation_max_length=max_target_length,
)

In [None]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)

In [None]:
tokenized_coref_datasets_train = tokenized_coref_datasets['train']
tokenized_coref_dataset_val = tokenized_coref_datasets['validation']

In [None]:
trainer = CustomTrainer(
    model,
    args,
    train_dataset=tokenized_coref_datasets_train,
    eval_dataset=tokenized_coref_dataset_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
import nltk
# Fetch NLTK's 'punkt' data before any evaluation runs.
# NOTE(review): presumably required by compute_metrics for sentence
# splitting — confirm against its definition.
nltk.download('punkt')

In [None]:
# Baseline: evaluate on the trainer's eval_dataset (the validation split)
# with the weights as loaded, so later scores have a reference point.
trainer.evaluate() # before training

In [None]:
# Fine-tune with the arguments configured above; evaluation_strategy="epoch"
# means the validation split is evaluated at the end of every epoch.
trainer.train()

In [None]:
# Evaluate again on the validation split now that training has finished.
trainer.evaluate()

## Save result

In [None]:
# Run prediction over the validation split with max_length=128.
# NOTE(review): the training arguments use
# generation_max_length=max_target_length — confirm the hard-coded 128 here
# is meant to differ from (or equal) that value.
prediction = trainer.predict(tokenized_coref_dataset_val, max_length=128)

In [None]:
# Decode the predicted token ids back to text, dropping special tokens,
# then hand the decoded summaries to save_result together with the target path.
batch_result = tokenizer.batch_decode(
    prediction.predictions,
    skip_special_tokens=True,
)
path = "result/bart-coref-16-generation.txt"
save_result(batch_result, path)

# T5

## Load saved coreference dataset and preprocess

In [18]:
# load coreference dataset, prefix: ['bart', 't5-base']
prefix = 't5-base'
coref_datasets = load_processed_coref_dataset(prefix)

model_checkpoint = "t5-base"

In [None]:
# Tokenizer matching the T5 checkpoint. The keyword must be spelled
# `model_max_length`; the previous `model_max_lenght` was not a recognised
# option, so the intended max_input_length limit was silently not applied.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=max_input_length)

In [None]:
# Run the per-split preprocessing callbacks over each raw split.
# with_indices=True makes `map` pass each example's row index to the
# callback as an extra argument.
processed_train_dataset = raw_datasets["train"].map(
    add_preprocessed_data_train,
    with_indices=True,
)
processed_test_dataset = raw_datasets["test"].map(
    add_preprocessed_data_test,
    with_indices=True,
)
processed_validation_dataset = raw_datasets["validation"].map(
    add_preprocessed_data_validation,
    with_indices=True,
)

In [21]:
import copy
# Work on a deep copy so that replacing the splits afterwards does not
# modify `raw_datasets` itself.
processed_datasets = copy.deepcopy(raw_datasets)

In [22]:
# Swap every split of the copied dataset dict for its preprocessed counterpart.
processed_datasets.update(
    {
        "train": processed_train_dataset,
        "test": processed_test_dataset,
        "validation": processed_validation_dataset,
    }
)

In [None]:
# Tokenize all splits in one pass with the T5-specific preprocessing function;
# batched=True hands it a batch of examples per call.
tokenized_coref_datasets = processed_datasets.map(
    preprocess_function_t5,
    batched=True,
)

## Fine-tuning

In [24]:
# Instantiate the project's T5 variant (imported from modeling_t5) from the
# selected checkpoint, with output_hidden_states switched off.
model = T5ForConditionalGeneration.from_pretrained(
    model_checkpoint,
    output_hidden_states=False,
)

In [25]:
# Per-device batch size shared by training and evaluation.
batch_size = 8

# Training configuration for the coreference T5 run.
# NOTE(review): neither save_strategy nor save_total_limit is set; the
# training log of this section shows a checkpoint written every 500 steps
# and none pruned — confirm this is intended.
args = Seq2SeqTrainingArguments(
    output_dir="16-coref-t5-dialogue-summarization",
    # --- optimisation ---
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # --- evaluation / generation ---
    evaluation_strategy="epoch",
    predict_with_generate=True,
    generation_max_length=max_target_length,
    # --- logging ---
    logging_steps=10,
    report_to="none",
    # --- left disabled by the author ---
    # gradient_accumulation_steps=2,
    # save_total_limit=2,
    # fp16=True,
)

In [26]:
# Batch collator for seq2seq training; padding=True enables padding of the
# examples when a batch is assembled.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
)

In [27]:
# Short aliases for the two splits handed to the trainer.
tokenized_coref_datasets_train, tokenized_coref_dataset_val = (
    tokenized_coref_datasets["train"],
    tokenized_coref_datasets["validation"],
)

In [28]:
# Build the trainer from the project's CustomTrainer class, wiring in the
# model, the training arguments, the tokenized train/validation splits, the
# batch collator, the tokenizer and the compute_metrics callback.
trainer = CustomTrainer(
    model,
    args,
    train_dataset=tokenized_coref_datasets_train,
    eval_dataset=tokenized_coref_dataset_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [29]:
import nltk
# Fetch NLTK's 'punkt' data before any evaluation runs.
# NOTE(review): presumably required by compute_metrics for sentence
# splitting — confirm against its definition.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
# Baseline: evaluate on the trainer's eval_dataset (the validation split)
# with the weights as loaded, so later scores have a reference point.
trainer.evaluate() # before training

In [None]:
# Fine-tune with the arguments configured above; evaluation_strategy="epoch"
# means the validation split is evaluated at the end of every epoch.
trainer.train()

The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: dialogue, summary, id. If dialogue, summary, id are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 14731
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 9210
  Number of trainable parameters = 222903552
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,1.6268,1.470832,47.7309,23.0264,39.5677,43.8753,23.1993
2,1.5699,1.432313,49.0595,24.0731,40.4141,44.9689,23.7494
3,1.593,1.416316,49.7809,24.694,40.9114,45.7035,24.2311
4,1.4503,1.410689,49.947,24.986,41.2414,46.0481,24.6773


Saving model checkpoint to 16-coref-t5-dialogue-summarization/checkpoint-500
Configuration saved in 16-coref-t5-dialogue-summarization/checkpoint-500/config.json
Model weights saved in 16-coref-t5-dialogue-summarization/checkpoint-500/pytorch_model.bin
tokenizer config file saved in 16-coref-t5-dialogue-summarization/checkpoint-500/tokenizer_config.json
Special tokens file saved in 16-coref-t5-dialogue-summarization/checkpoint-500/special_tokens_map.json
Saving model checkpoint to 16-coref-t5-dialogue-summarization/checkpoint-1000
Configuration saved in 16-coref-t5-dialogue-summarization/checkpoint-1000/config.json
Model weights saved in 16-coref-t5-dialogue-summarization/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in 16-coref-t5-dialogue-summarization/checkpoint-1000/tokenizer_config.json
Special tokens file saved in 16-coref-t5-dialogue-summarization/checkpoint-1000/special_tokens_map.json
Saving model checkpoint to 16-coref-t5-dialogue-summarization/checkpoint-1500

In [None]:
# Evaluate again on the validation split now that training has finished.
trainer.evaluate()

## Save result

In [None]:
# Run prediction over the validation split with max_length=128.
# NOTE(review): the training arguments use
# generation_max_length=max_target_length — confirm the hard-coded 128 here
# is meant to differ from (or equal) that value.
prediction = trainer.predict(tokenized_coref_dataset_val, max_length=128)

The following columns in the test set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: dialogue, summary, id. If dialogue, summary, id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 819
  Batch size = 8


In [None]:
# Decode the predicted token ids back to text, dropping special tokens,
# and save them. The file name follows the "<model>-coref-16" scheme used by
# the BART section and by this run's output directory, so the coreference T5
# summaries are not written under the generic "t5-generation.txt" name.
path = "result/t5-coref-16-generation.txt"
batch_result = tokenizer.batch_decode(prediction.predictions, skip_special_tokens=True)
save_result(batch_result, path)