In [67]:
import os
import re
from datasets import DatasetDict, Dataset
from multiprocessing import cpu_count
from transformers import AutoTokenizer, BartForConditionalGeneration, pipeline

index_mapping = {
    "url": 0,
    "text": 1,
    "answer_masked": 2,
    "correct_entity": 3,
    "entity_mapping": 4
}

In [20]:
cnn_path = "../../lit-data/datasets/cnn/questions"
cnn_datasetdict = DatasetDict()
splits = ["training", "test", "validation"]

for split in splits:
    split_path = f"{cnn_path}/{split}"
    split_list = []
    for file_name in os.listdir(split_path):
        with open(f"{split_path}/{file_name}", "r") as f:
            parts = f.read().split("\n\n")
            split_list.append({
                "url": parts[0],
                "text": parts[1],
                "answer_masked": parts[2],
                "correct_entity": parts[3],
                "entity_mapping": parts[4]
            })
    cnn_datasetdict[split] = Dataset.from_list(split_list)

print(cnn_datasetdict)

DatasetDict({
    validation: Dataset({
        features: ['url', 'text', 'answer_masked', 'correct_entity', 'entity_mapping'],
        num_rows: 3924
    })
    training: Dataset({
        features: ['url', 'text', 'answer_masked', 'correct_entity', 'entity_mapping'],
        num_rows: 380298
    })
    test: Dataset({
        features: ['url', 'text', 'answer_masked', 'correct_entity', 'entity_mapping'],
        num_rows: 3198
    })
})


In [21]:
cnn_datasetdict.save_to_disk("../../lit-data/datasets/cnn_dataset")

Saving the dataset (0/1 shards):   0%|          | 0/3924 [00:00<?, ? examples/s]

Saving the dataset (0/4 shards):   0%|          | 0/380298 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3198 [00:00<?, ? examples/s]

In [None]:
dailymail_path = "../../lit-data/datasets/dailymail/questions"
dailymail_datasetdict = DatasetDict()
splits = ["training", "test", "validation"]

for split in splits:
    split_path = f"{dailymail_path}/{split}"
    split_list = []
    for file_name in os.listdir(split_path):
        with open(f"{split_path}/{file_name}", "r") as f:
            parts = f.read().split("\n\n")
            split_list.append({
                "url": parts[0],
                "text": parts[1],
                "answer_masked": parts[2],
                "correct_entity": parts[3],
                "entity_mapping": parts[4]
            })
    dailymail_datasetdict[split] = Dataset.from_list(split_list)

print(dailymail_datasetdict)

In [None]:
dailymail_datasetdict.save_to_disk("../../lit-data/datasets/dailymail_dataset")

In [61]:
def parse_entity_mapping(mapping_str):
    mappings_list = mapping_str.split("\n")
    mappings = dict()
    for mapping in mappings_list:
        mapping_splits = mapping.split(":")
        mappings[mapping_splits[0]] = mapping_splits[1]
    return mappings

def preprocess(record):
    entity_mapping = parse_entity_mapping(record["entity_mapping"])

    # unmask text
    text = record["text"]
    for placeholder, entity in entity_mapping.items():
        text = re.sub(placeholder, entity, text)
    split_text = text.split(" ")
    source = split_text[1]
    text = " ".join(split_text[3:])

    # unmask correct entity
    correct_entity = entity_mapping[record["correct_entity"]]

    # get entites in answer
    answer_entites_list = []
    answer = record["answer_masked"]
    split_answer = answer.split(" ")
    for token in split_answer:
        if token.startswith("@"):
            try:
                answer_entites_list.append(entity_mapping[token])
            except KeyError:
                pass
    answer_entities = " ".join(answer_entites_list)

    # unmask answer
    answer = record["answer_masked"]
    for placeholder, entity in entity_mapping.items():
        answer = re.sub(placeholder, entity, answer)

    return {"text": text, "correct_entity": correct_entity, "source": source, "answer_entities": answer_entities, "answer": answer }

sample = cnn_datasetdict["validation"][0]
preprocess(sample)

{'text': 'could it get any cuter than seal pup kisses ? the USGS and the U.S. Department of Interior this week shared a photo of a Weddell seal nuzzling up to what looked to be its mom in CNN2 , CNN3 . the expression of the mother is priceless . the photo was taken in october by USGS scientist Link . link , a statistician , was helping researchers tag newborn seal pups . he confirmed friday that the adult seal was the baby \'s mom . it \'s hard to know what she was thinking when her baby nuzzled up to her in this photo , but Link said the animals flare their noses when disturbed , " so this Mamma was pretty relaxed , " Link told CNN friday . " i have a great shot a few seconds later where Mamma yawned hugely . she looked utterly content , to me . " the agency \'s public affairs department had asked scientists for interesting images to post on social media . as the USGS1 caption notes , the Weddell seals of CNN2 have been studied extensively for over 40 years . " because of its isolatio

In [62]:
dailymail_datasetdict = DatasetDict.load_from_disk("../../lit-data/datasets/dailymail_dataset")
cnn_datasetdict = DatasetDict.load_from_disk("../../lit-data/datasets/cnn_dataset")

In [85]:
test_set = cnn_datasetdict["validation"].select(range(32)).map(preprocess, num_proc=cpu_count())
test_set.to_csv("./test.csv")

Map (num_proc=10):   0%|          | 0/32 [00:00<?, ? examples/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

170301

In [86]:
bart_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/distilbart-cnn-12-6")

In [89]:
def add_summary_column(records, model: BartForConditionalGeneration, tokenizer: AutoTokenizer):
    texts = records["text"]
    inputs = tokenizer(texts, max_length=1024, truncation=True, padding=True, return_tensors="pt")
    summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=100)
    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return { "summary": summaries }

with_summaries = test_set.map(add_summary_column, fn_kwargs={"model": bart_model, "tokenizer": bart_tokenizer}, batched=True, batch_size=32)
with_summaries.to_csv("./test.csv")

Map:   0%|          | 0/32 [00:00<?, ? examples/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

181138