In [1]:
!pip install --quiet transformers
!pip install --quiet sentencepiece
!pip install --quiet datasets
!pip install --quiet rouge_score
! pip install --quiet  evaluate


[notice] A new release of pip is available: 23.3.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 23.3.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 23.3.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 23.3.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 23.3.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
from datasets import load_dataset
import pandas as pd

languages = ["telugu", "urdu", "marathi", "hindi", "tamil", "bengali", "english"]

dfs = []
for lang in languages:
    dataset = load_dataset("csebuetnlp/xlsum", lang, split="train[:2000]")
    df = dataset.to_pandas()
    dfs.append(df)

df = pd.concat(dfs, ignore_index=True)

print(df.shape)
print(df.head())

(14000, 5)
                       id                                                url  \
0  international-53649907  https://www.bbc.com/telugu/international-53649907   
1          india-46550604          https://www.bbc.com/telugu/india-46550604   
2          india-43404438          https://www.bbc.com/telugu/india-43404438   
3  international-54671956  https://www.bbc.com/telugu/international-54671956   
4                53723894                https://www.bbc.com/telugu/53723894   

                                               title  \
0  పాకిస్తాన్ ఎయిర్‌లైన్స్‌లో నకిలీ లైసెన్సుల పైల...   
1  తెలంగాణ ముఖ్యమంత్రిగా కేసీఆర్ రెండోసారి ప్రమాణ...   
2  ‘అధికారం కొన్ని కులాల గుప్పిట్లోనే ఉండాలా? కుద...   
3  పోలండ్‌లో కొత్త అబార్షన్ చట్టాలను వ్యతిరేకిస్త...   
4  దిల్లీ అల్లర్లపై పరస్పర విరుద్ధ నివేదికలు... ఏ...   

                                             summary  \
0  పాకిస్తాన్ విమానయాన రంగంలో కొత్త సంక్షోభం మొదల...   
1  తెలంగాణ ముఖ్యమంత్రిగా కల్వకుంట్ల చంద్రశేఖర్ రా...   
2  

In [3]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AdamW, get_scheduler

import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

from tqdm.auto import tqdm

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import evaluate



In [4]:
tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")

model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [5]:
class SummaryDataset(Dataset):
    def __init__(
        self,
        data=df,
        tokenizer=tokenizer,
        text_max_token_len = 200,
        summary_max_token_len = 12
    ):
        self.tokenizer = tokenizer
        self.data = data
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]

        text = data_row['text']

        text_encoding = tokenizer(
            text,
            max_length=self.text_max_token_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        summary_encoding = tokenizer(
            data_row['summary'],
            max_length=self.summary_max_token_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        labels = summary_encoding['input_ids']
        labels[labels == tokenizer.pad_token_id] = -100

        return dict(
            input_ids=text_encoding['input_ids'].flatten(),
            attention_mask=text_encoding['attention_mask'].flatten(),
            labels=labels.flatten(),
            decoder_attention_mask=summary_encoding['attention_mask'].flatten()
        )

In [6]:
df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)

train_dataset = SummaryDataset(data=df_train)
test_dataset = SummaryDataset(data=df_test)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=10)
eval_dataloader = DataLoader(test_dataset, batch_size=10)

In [7]:
num_epochs = 10

num_training_steps = num_epochs * len(train_dataloader)

optimizer = AdamW(model.parameters())
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

progress_bar = tqdm(range(num_training_steps))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
#         logits = outputs.logits
#         predictions = torch.argmax(logits, dim=-1)
#         print(predictions)
#         print(batch["labels"])

        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()

        optimizer.zero_grad()
        progress_bar.update()

    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, f'./t5-multi.pth')

    print(f'epoch: {epoch + 1} -- loss: {loss}')

  0%|          | 0/11200 [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


epoch: 1 -- loss: 2.863598585128784
epoch: 2 -- loss: 2.653280735015869
epoch: 3 -- loss: 1.7017065286636353
epoch: 4 -- loss: 1.02516508102417
epoch: 5 -- loss: 0.6968614459037781
epoch: 6 -- loss: 0.38551315665245056
epoch: 7 -- loss: 0.33853280544281006
epoch: 8 -- loss: 0.20748813450336456
epoch: 9 -- loss: 0.093267060816288
epoch: 10 -- loss: 0.15542356669902802


In [11]:
import torch
import evaluate

rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

model.eval()

all_predictions = []
all_references = []

for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # Replace -100 with pad_token_id before decoding
    labels = batch["labels"]
    labels[labels == -100] = tokenizer.pad_token_id

    decoded_references = tokenizer.batch_decode(labels, skip_special_tokens=True)

    all_predictions.extend(decoded_predictions)
    all_references.extend(decoded_references)

rouge_score = rouge.compute(predictions=all_predictions, references=all_references)
bleu_score = bleu.compute(predictions=all_predictions, references=all_references)

print("ROUGE Score:", rouge_score)
print("BLEU Score:", bleu_score)


ROUGE Score: {'rouge1': 0.04901744473553138, 'rouge2': 0.012986807776723739, 'rougeL': 0.0481451259684343, 'rougeLsum': 0.048057580824318505}
BLEU Score: {'bleu': 0.07339079467665911, 'precisions': [0.24043138084161556, 0.09229823482919118, 0.04994760740482012, 0.034193879295606085], 'brevity_penalty': 0.9353619257662512, 'length_ratio': 0.9373637264618434, 'translation_length': 14187, 'reference_length': 15135}
