## T5 Model

### Importing libraries

In [1]:
import pandas as pd
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from evaluate import load

  from .autonotebook import tqdm as notebook_tqdm


### Importing the data

In [2]:
# Read the training split and discard the unnamed leading column
# ('Unnamed: 0', presumably a row index saved into the CSV).
df_train = (
    pd.read_csv('../../Datasets/with_ctext/df_train.csv')
    .drop(columns = 'Unnamed: 0')
)
df_train.head()

Unnamed: 0,headlines,ctext
0,Chhattisgarh to start ambulance service for cows,The Chhattisgarh government will begin an amb...
1,Trucks dumping debris on wetlands seized in Mu...,Ten trucks and an excavator machine that were ...
2,Modi pays homage to Indian World War I heroes ...,On the last day of his three-day trip to Israe...
3,Delhi's domesticated elephants may be shifted ...,Delhi could soon lose all its seven elephants ...
4,Ranchi civic body uses 'Sholay' climax to prom...,The Ranchi Nagar Nigam has upped the ante with...


In [3]:
# Read the evaluation split and discard the unnamed leading column
# ('Unnamed: 0', presumably a row index saved into the CSV).
df_eval = (
    pd.read_csv('../../Datasets/with_ctext/df_eval.csv')
    .drop(columns = 'Unnamed: 0')
)
df_eval.head()

Unnamed: 0,headlines,ctext
0,Delhi taxi driver returns lost bag with valuab...,A 24-year-old kaali-peeli taxi driver Debendra...
1,Recall what happened in 1971: Venkaiah Naidu t...,"Hitting out at Pakistan-sponsored terrorism, N..."
2,"Bihar minister abuses PM Modi, calls him a dacoit",Bihar's minister for excise and prohibition Ab...
3,6 arrested for blackmailing makers over Baahub...,The cyber crime police of Hyderabad have arres...
4,Indrani forged Peter?s signature on bank docum...,A special CBI court on Wednesday asked banks t...


In [6]:
# Read the test split and discard the unnamed leading column
# ('Unnamed: 0', presumably a row index saved into the CSV).
df_test = (
    pd.read_csv('../../Datasets/with_ctext/df_test.csv')
    .drop(columns = 'Unnamed: 0')
)
df_test.head()

Unnamed: 0,headlines,ctext
0,Ex-Australian PM sends signed bat to Modi thro...,These days if you just happen to wait outside ...
1,Nearly 400 judicial officers transferred in Ut...,"Allahabad, Apr 29 (PTI) The Allahabad High Cou..."
2,"Big B complains about Vodafone on Twitter, RJi...",Bollywood actor Amitabh Bachchan has at least ...
3,No interference in Jayalalithaa's treatment: A...,Apollo Hospitals said on Tuesday there was ?no...
4,Varun's pants tear while dancing with contesta...,Varun Dhawan and Alia Bhatt have been frequent...


In [7]:
# Sanity check: length of the first test headline vs. its article.
# len() on a string counts characters, so these numbers are NOT token counts.
print(f"{len(df_test['headlines'][0])} :: {len(df_test['ctext'][0])}")

60 :: 1704


In [8]:
# Keep the test headlines as a pandas Series while df_test is still a DataFrame
# (it is rebound to a Dataset further down). Presumably these serve as the
# reference summaries for later evaluation -- confirm against the scoring cells.
ref_summary = df_test['headlines']

In [10]:
def _as_dataset(frame):
    """Return `frame` as a `Dataset`, converting it only if it is not one already."""
    return frame if isinstance(frame, Dataset) else Dataset.from_pandas(frame)


# The df_* names are rebound from pandas DataFrames to Datasets because the
# tokenization cells call `.map` on them. The guard in `_as_dataset` makes this
# cell safe to re-run: without it, a second execution would hand an existing
# Dataset to `Dataset.from_pandas` and fail.
df_train = _as_dataset(df_train)
df_eval = _as_dataset(df_eval)
df_test = _as_dataset(df_test)

### Model and mapping

In [11]:
# A single checkpoint name drives both loads so the tokenizer always matches the model.
model_name = 't5-small'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [15]:
# Batch collator built from the tokenizer and the model; it is handed to the
# Trainer through its `data_collator` argument.
data_collator = DataCollatorForSeq2Seq(tokenizer, model = model)

In [16]:
def preprocess(data, max_input_length = 1800, max_target_length = 30, prefix = 'summarize: '):
    """Tokenize articles and headlines into model inputs and labels.

    Parameters
    ----------
    data : mapping
        Must provide 'ctext' (source articles) and 'headlines' (target
        summaries). With `Dataset.map(..., batched = True)` each value is a
        list of strings; a single string per key is accepted as well.
    max_input_length : int, default 1800
        Articles are truncated to this many tokens.
    max_target_length : int, default 30
        Headlines are truncated to this many tokens.
    prefix : str, default 'summarize: '
        Task prefix prepended to every article. T5 checkpoints are trained
        with task prefixes, so the model needs it to know which task to run.
        Pass '' to disable.

    Returns
    -------
    The tokenizer output for the articles, with an extra 'labels' entry that
    holds the token ids of the headlines.
    """
    articles = data['ctext']
    if isinstance(articles, str):
        # Non-batched call: a bare string must not be iterated character by character.
        articles = prefix + articles
    else:
        articles = [prefix + article for article in articles]

    # NOTE(review): 1800 is kept from the original code. It is close to the
    # character length printed earlier in the notebook (1704), so it may have
    # been chosen in characters rather than tokens -- confirm the intended budget.
    model_inputs = tokenizer(
        articles, max_length = max_input_length, truncation = True
    )
    # `text_target` tells the tokenizer these strings are targets, so any
    # target-specific processing is applied.
    labels = tokenizer(
        text_target = data['headlines'], max_length = max_target_length, truncation = True
    )

    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [17]:
# Tokenize the training split. With batched = True, `preprocess` receives a
# batch of rows per call instead of a single row.
train_data = df_train.map(preprocess, batched = True)

Map: 100%|██████████| 3000/3000 [00:01<00:00, 2116.00 examples/s]


In [18]:
# Apply the same tokenization to the evaluation and test splits so all three
# splits share one preprocessing path.
eval_data = df_eval.map(preprocess, batched = True)
test_data = df_test.map(preprocess, batched = True)

Map: 100%|██████████| 801/801 [00:00<00:00, 1785.03 examples/s]
Map: 100%|██████████| 595/595 [00:00<00:00, 2055.67 examples/s]


### Trainer and training arguments

In [19]:
# Run configuration consumed by the Trainer, grouped by concern.
training_args = TrainingArguments(
    # Output locations and checkpoint retention.
    output_dir='./results_with_ctext',
    logging_dir='./logs_with_ctext',
    save_total_limit=2,
    # Schedule and optimisation.
    num_train_epochs=1,
    learning_rate=2e-5,
    eval_strategy='epoch',
    do_predict=True,
    # Per-device batch sizes.
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
)

In [20]:
# Wire the model, run configuration, tokenized splits and collator together.
# The tokenizer is passed as `processing_class`: the `tokenizer` keyword of
# `Trainer` is deprecated in recent transformers releases, and the warning this
# cell emitted at the `trainer = Trainer(` line looks like that deprecation
# notice (confirm against the installed transformers version).
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = train_data,
    eval_dataset = eval_data,
    data_collator = data_collator,
    processing_class = tokenizer,
)

  trainer = Trainer(


In [None]:
# Launch fine-tuning with the trainer configured above.
trainer.train()

