### Paths to data files

In [1]:
# Input datasets and the output location for the generated predictions.
train_data_file = "train_data.json"
# Point this at the test file when producing test-set predictions.
val_data_file = "valid_data.json"
pred_out_file = "prediction_out.json"

### Importing all dependencies

In [5]:
# Import all dependencies
import json
import re

import numpy as np
import pandas as pd
import spacy
from tqdm.auto import tqdm  # progress bar used by the conclusion-generation cell
# from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration

# spaCy English pipeline.
# NOTE(review): `nlp` is not referenced by the cells visible here - confirm it is still needed.
nlp = spacy.load("en_core_web_sm")

from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer

# Single source of truth for the checkpoint, so the model and its tokenizer
# can never be loaded from two different checkpoints by accident.
MODEL_CHECKPOINT = "t5-base"

model = TFAutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


### Function with model to generate the conclusion

In [6]:
def conclusion_generate(argument):
    """Generate a conclusion for one argument by summarizing it with the loaded model.

    Uses the module-level `tokenizer` and `model`.

    Args:
        argument: The argument text (a string) to summarize.

    Returns:
        The generated conclusion as a string, with all special tokens removed
        and surrounding whitespace stripped.
    """
    # The "summarize: " prefix is prepended to the text; inputs longer than
    # 512 tokens are truncated.
    inputs = tokenizer.encode(
        "summarize: " + argument,
        return_tensors="tf",
        max_length=512,
        truncation=True,
    )

    # Beam search (4 beams) producing between 10 and 150 tokens.
    outputs = model.generate(
        inputs,
        max_length=150,
        min_length=10,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )

    # Let the tokenizer drop every special token. Stripping only "<pad>" by hand
    # leaves other markers (e.g. the "</s>" end-of-sequence token) in the output.
    gen_conclusion = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return gen_conclusion.strip()

### Reading and preprocessing the data file

In [7]:
# Load both splits; only the validation split is cleaned in this cell.
df_train_init = pd.read_json(train_data_file)
df_val_init = pd.read_json(val_data_file)

# Drop everything from the first "This is a footnote" marker to the end of the text.
# DOTALL lets `.` cross newlines, so a multi-line footnote is removed as a whole.
df_val_init['argument'] = df_val_init['argument'].apply(
    lambda x: re.sub('This is a footnote.*$', '', x, flags=re.DOTALL).strip()
)

# Remove "gt" markers: the `&gt;` entity and standalone runs such as "gt" / "gtgt".
# The word boundaries stop the pattern from eating "gt" inside real words
# ("length", "strength", "Washington"), which a bare 'gt' pattern would corrupt.
# NOTE(review): markers glued to the following word (e.g. "gtSome text") are not
# matched - check the data and extend the pattern if that form occurs.
df_val_init['argument'] = df_val_init['argument'].apply(
    lambda x: re.sub(r'&gt;|\b(?:gt)+\b', '', x).strip()
)

val_args_list = df_val_init['argument'].to_list()

In [None]:
# Generate one conclusion per validation argument, showing a progress bar

conclusion_list = list(map(conclusion_generate, tqdm(val_args_list)))

### Generating the predictions file

In [None]:
# Pair each validation id with its generated conclusion (same order as the arguments).
val_data_id_list = df_val_init['id'].tolist()
pred_val = {
    sample_id: conclusion
    for sample_id, conclusion in zip(val_data_id_list, conclusion_list)
}

# Persist the {id: conclusion} mapping as JSON.
with open(pred_out_file, 'w') as fp:
    fp.write(json.dumps(pred_val))