## CNN DAILY/MAIL DATASET

In [1]:
import config
from datasets import load_dataset
import os

os.makedirs("data")

In [9]:
def preprocess_function_cnn_dailymail(
        examples,
        tokenizer,
        max_input_length: int = config.MAX_INPUT_LENGTH,
        max_target_length: int = config.MAX_TARGET_LENGTH,
    ):
        prefix = "summarize: "
        inputs = [prefix + doc for doc in examples["article"]]
        model_inputs = tokenizer(
            inputs, max_length=max_input_length, truncation=True, padding=True
        )

        # Setup the tokenizer for targets
        labels = tokenizer(
            text_target=examples["highlights"],
            max_length=max_target_length,
            truncation=True,
        )

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

# data_dir = "data"

cnn_data_train = load_dataset(
    "cnn_dailymail",'3.0.0', split=f"train[:{config.PERCENT_DATA}%]"
)
cnn_data_test = load_dataset(
    "cnn_dailymail",'3.0.0' , split=f"test[:{config.PERCENT_DATA}%]"
)
cnn_data_val = load_dataset(
    "cnn_dailymail",'3.0.0',split=f"validation[:{config.PERCENT_DATA}%]"
)

Downloading data: 100%|██████████| 257M/257M [00:07<00:00, 32.7MB/s] 
Downloading data: 100%|██████████| 257M/257M [00:06<00:00, 40.6MB/s] 
Downloading data: 100%|██████████| 259M/259M [00:07<00:00, 35.2MB/s] 
Downloading data: 100%|██████████| 34.7M/34.7M [00:01<00:00, 22.7MB/s]
Downloading data: 100%|██████████| 30.0M/30.0M [00:00<00:00, 37.9MB/s]
Generating train split: 100%|██████████| 287113/287113 [00:03<00:00, 81410.03 examples/s]
Generating validation split: 100%|██████████| 13368/13368 [00:00<00:00, 91425.24 examples/s]
Generating test split: 100%|██████████| 11490/11490 [00:00<00:00, 81273.17 examples/s]


In [11]:
import torch
from transformers import AutoTokenizer
from transformers import T5ForConditionalGeneration

model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name, legacy=False)
def preprocess_function_summary(examples,max_input_length:int=config.MAX_INPUT_LENGTH,max_target_length:int=config.MAX_TARGET_LENGTH):
        prefix = "summarize: "
        inputs = [prefix + doc for doc in examples["article"]]
        model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True,padding=True)

        # Setup the tokenizer for targets
        labels = tokenizer(text_target=examples["highlights"], max_length=max_target_length, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

tokenized_datasets_train = cnn_data_train.map(preprocess_function_summary, batched=True,remove_columns=['article','highlights','id'],batch_size=config.TOKENIZE_BATCH_SIZE)
tokenized_datasets_val = cnn_data_val.map(preprocess_function_summary, batched=True,remove_columns=['article','highlights','id'],batch_size=config.TOKENIZE_BATCH_SIZE)
tokenized_datasets_test = cnn_data_test.map(preprocess_function_summary, batched=True,remove_columns=['article','highlights','id'],batch_size=config.TOKENIZE_BATCH_SIZE)

Map: 100%|██████████| 287113/287113 [02:20<00:00, 2039.67 examples/s]
Map: 100%|██████████| 13368/13368 [00:06<00:00, 2051.93 examples/s]
Map: 100%|██████████| 11490/11490 [00:05<00:00, 2004.15 examples/s]


## ARXIV DATASET (TOO BIG PLANNING TO DISCARD)

In [20]:
# from datasets import load_dataset
arxiv_link = "https://github.com/armancohan/long-summarization/tree/master?tab=readme-ov-file"
# dataset = load_dataset("arxiv_dataset",trust_remote_code=True)

## PubMed DATASET

In [23]:
# import gdown
# pubmed_url = "https://archive.org/download/armancohan-long-summarization-paper-code/pubmed-dataset.zip"
# https://huggingface.co/datasets/scientific_papers?row=0
# output = 'pubmed.zip'
# gdown.download(pubmed_url, output, quiet=False)

In [3]:
import os
pubmed_dir = "pubmed-dataset"

for file in os.listdir(pubmed_dir):
    # print(file)
    if file.endswith(".txt"):
        txt_file = os.path.join(pubmed_dir, file)

test.txt
vocab
val.txt
train.txt


In [4]:
from datasets import load_dataset

pubmed_dataset = load_dataset(
    "scientific_papers","pubmed"
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Downloading builder script: 100%|██████████| 5.35k/5.35k [00:00<00:00, 21.5MB/s]
Downloading readme: 100%|██████████| 8.27k/8.27k [00:00<00:00, 16.4MB/s]
Downloading data: 100%|██████████| 3.62G/3.62G [01:48<00:00, 33.4MB/s] 
Downloading data: 100%|██████████| 880M/880M [00:22<00:00, 39.6MB/s] 
Generating train split: 100%|██████████| 119924/119924 [00:32<00:00, 3697.11 examples/s]
Generating validation split: 100%|██████████| 6633/6633 [00:02<00:00, 2674.18 examples/s]
Generating test split: 100%|██████████| 6658/6658 [00:01<00:00, 5007.96 examples/s]


In [5]:
pubmed_dataset

DatasetDict({
    train: Dataset({
        features: ['article', 'abstract', 'section_names'],
        num_rows: 119924
    })
    validation: Dataset({
        features: ['article', 'abstract', 'section_names'],
        num_rows: 6633
    })
    test: Dataset({
        features: ['article', 'abstract', 'section_names'],
        num_rows: 6658
    })
})

## MEDIA SUM DATASET (TOO LARGE PLANNING TO DISCARD)

Link to [data](https://aclanthology.org/2021.naacl-main.474/)

## MULTI-NEWS DATASET

In [1]:
from datasets import load_dataset

multi_news_dataset = load_dataset("multi_news")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
multi_news_dataset

DatasetDict({
    train: Dataset({
        features: ['document', 'summary'],
        num_rows: 44972
    })
    validation: Dataset({
        features: ['document', 'summary'],
        num_rows: 5622
    })
    test: Dataset({
        features: ['document', 'summary'],
        num_rows: 5622
    })
})

## WMT DATASET

In [4]:
# train_en_link = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en"
# train_de_link = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de"

# test_en_link = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en"
# test_de_link = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de"

wmt_dataset = load_dataset("stas/wmt14-en-de-pre-processed")

In [7]:
wmt_dataset['test']['translation'][0]['en']

'Obama receives Netanyahu'

## TRIVIA-QA DATASET

In [2]:
from transformers import T5Tokenizer
from datasets import load_dataset


dataset = load_dataset("trivia_qa", "unfiltered.nocontext")


Downloading readme:   0%|          | 0.00/26.7k [00:00<?, ?B/s]

Downloading and preparing dataset None/rc to C:/Users/saise/.cache/huggingface/datasets/parquet/rc-cfb291df38bbf033/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/308M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/290M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/444M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/461M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/474M [00:00<?, ?B/s]

Downloading data: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/324M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/329M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/336M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/400M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/370M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/341M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/327M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/310M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/157M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/136M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/200M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/180M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/150M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/153M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/147M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/157M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/154M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/158M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/55.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/224M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/229M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/210M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/311M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/392M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/344M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/339M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/315M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/240M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/243M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/250M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/336M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/271M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/264M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/192M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/387M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/30.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/240M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/261M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/319M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/266M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/240M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/259M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/253M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/25.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/215M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/279M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/250M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/243M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/224M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/231M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/247M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/245M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/410M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/386M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/367M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/350M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/310M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/336M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/397M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/372M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/320M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/351M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/314M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/270M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/201M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/271M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/252M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/278M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/258M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/252M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/261M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/273M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/264M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/269M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/338M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/350M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/284M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/271M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/257M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/284M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/251M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/196M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/198M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/198M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/349M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/361M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/33.2M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/327M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/296M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/184M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/129M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.34M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/225M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/237M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/235M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.04M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/235M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.31M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/212M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/308M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/226M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/259M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/229M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.39M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/307M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/288M [00:00<?, ?B/s]

OSError: [Errno 28] No space left on device

In [None]:
tokenizer = T5Tokenizer.from_pretrained("t5-base")

def preprocess_data(examples):
    input_texts = [f"question: {q} context: {c}" for q, c in zip(examples['question'], examples['entity_pages']['wiki_context'])]
    target_texts = [answer['value'] for answer in examples['answer']]
    model_inputs = tokenizer(input_texts, max_length=512, truncation=True, padding="max_length", return_tensors="pt")
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(target_texts, max_length=128, truncation=True, padding="max_length", return_tensors="pt").input_ids
    model_inputs["labels"] = labels
    return model_inputs

# Map the preprocessing function over the dataset
train_dataset = dataset["train"].map(preprocess_data, batched=True)


## FINAL FUNCTIONS

In [None]:

dataset_type = ["cnn_dailymail","pubmed","multi_news","wmt14","triviaqa"]

def preprocess_function(examples,dataset_name:str,max_input_length:int=config.MAX_INPUT_LENGTH,max_target_length:int=config.MAX_TARGET_LENGTH):
        if dataset_name == "cnn_dailymail":
            prefix = "summarize: "
            inputs = [prefix + doc for doc in examples["article"]]
            model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True,padding=True)

            # Setup the tokenizer for targets
            labels = tokenizer(text_target=examples["highlights"], max_length=max_target_length, truncation=True)

            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
        elif dataset_name == "pubmed":
            prefix = "summarize: "
            NotImplementedError()
        elif dataset_name == "multi_news":
            prefix = "summarize: "
            inputs = [prefix + doc for doc in examples["document"]]
            model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True,padding=True)

            # Setup the tokenizer for targets
            labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)

            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
        elif dataset_name == "wmt14":
            prefix = "translate german to english: "
            inputs = [prefix + doc for doc in examples["translation"]['en']]
            model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True,padding=True)

            # Setup the tokenizer for targets
            text_targets = [ex['de'] for ex in examples["translation"]]
            labels = tokenizer(text_target=text_targets, max_length=max_target_length, truncation=True)

            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
tokenized_datasets_train = cnn_data_train.map(preprocess_function, batched=True,remove_columns=['article','highlights','id'],batch_size=config.TOKENIZE_BATCH_SIZE)
tokenized_datasets_val = cnn_data_val.map(preprocess_function, batched=True,remove_columns=['article','highlights','id'],batch_size=config.TOKENIZE_BATCH_SIZE)
tokenized_datasets_test = cnn_data_test.map(preprocess_function, batched=True,remove_columns=['article','highlights','id'],batch_size=config.TOKENIZE_BATCH_SIZE)