In [1]:
%pwd

'/home/tiva/PycharmProjects/HeadlineGenerator/notebook'

In [2]:
import os

In [3]:
# Move from the notebook/ folder to the project root so relative paths
# (e.g. the `artifacts` directory created later) resolve from the root.
# Guarded so re-running this cell is a no-op instead of climbing one more
# directory each time.
if os.path.basename(os.getcwd()) == "notebook":
    os.chdir("../")

In [4]:
%pwd

'/home/tiva/PycharmProjects/HeadlineGenerator'

In [5]:
from dataclasses import dataclass
from pathlib import Path

In [6]:
@dataclass(frozen=True)
class DataPreparationConfig:
    """Immutable settings for the data-preparation stage.

    Fields, in positional order:
        root_dir: working directory for the stage (created before construction).
        data_dir: directory containing the input category CSV files.
        save_dir: destination the prepared dataset splits are saved to.

    NOTE(review): fields are annotated as ``Path``, but dataclasses do not
    convert or validate types; callers may pass raw config values (e.g. str).
    """

    # Stage workspace directory.
    root_dir: Path
    # Where the raw input CSVs live.
    data_dir: Path
    # Where the prepared splits are written.
    save_dir: Path

In [7]:
from headlineGenerator.constants import *
from headlineGenerator.utils.common import read_yaml, create_directories

from datasets import concatenate_datasets, DatasetDict

[2024-03-17 06:19:12,430 || PyTorch version 2.2.0+cpu available.]


In [8]:
class ConfigurationManager:
    """Loads the project YAML files and builds per-stage config objects.

    The constructor eagerly reads the config file, then the params file, and
    then ensures ``artifacts_root`` (from the config file) exists on disk.

    Args:
        config_filepath: path of the main config YAML; defaults to CONFIG_FILEPATH.
        params_filepath: path of the params YAML; defaults to PARAMS_FILEPATH.
    """

    def __init__(self, config_filepath=CONFIG_FILEPATH, params_filepath=PARAMS_FILEPATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_data_preparation_config(self) -> DataPreparationConfig:
        """Return the data-preparation settings from the ``data_preparation`` section.

        Side effect: creates the section's ``root_dir`` before returning.
        """
        section = self.config.data_preparation
        create_directories([section.root_dir])
        return DataPreparationConfig(
            root_dir=section.root_dir,
            data_dir=section.data_dir,
            save_dir=section.save_dir,
        )

In [9]:
import os
import re
from datasets import load_dataset, concatenate_datasets, DatasetDict
from headlineGenerator.logging import logger

In [10]:
class DataPreparation:
    """Build train/validation/test splits from the per-category news CSVs.

    Intended call order: ``load_data()`` -> ``clean_data()`` -> ``split_and_save()``.

    Args:
        config: stage settings; ``data_dir`` locates the input CSVs and
            ``save_dir`` receives the saved splits.
        seed: seed used for the initial shuffle and for both train/test splits.
        max_story_length: stories with this many words or more are dropped.
    """

    def __init__(self, config: "DataPreparationConfig", seed: int = 84, max_story_length: int = 1000):
        self.config = config
        self.seed = seed
        self.max_story_length = max_story_length
        # Set by load_data(); None means nothing has been loaded yet.
        self.news_dataset = None
        self.business_file = os.path.join(self.config.data_dir, "business_contents.csv")
        self.entertainment_file = os.path.join(self.config.data_dir, "entertainment_contents.csv")
        self.sports_file = os.path.join(self.config.data_dir, "sports_contents.csv")

    def _require_loaded(self):
        """Raise a clear error when a processing step runs before load_data()."""
        if self.news_dataset is None:
            raise RuntimeError("No data loaded; call load_data() first.")

    def add_category(self, example, category):
        """Map function: tag a row with its source ``category`` label."""
        return {"category": category}

    def build_full_story(self, example):
        """Map function: prepend the summary (when present) to the main story."""
        summary = example["summary"]
        # Missing or empty summaries would otherwise leave a stray leading space.
        if not summary:
            return {"full_story": example["main_story"]}
        return {"full_story": summary + " " + example["main_story"]}

    def edit_writer(self, example):
        """Map function: normalise the writer name.

        Lower-cases the name and removes the standalone word ``by`` (optionally
        followed by ``.`` or ``:``), e.g. ``"By. Jane Doe"`` -> ``"jane doe"``.
        Missing names, or names that are empty after cleaning, become ``"UNK"``.
        """
        writer = example["writer"]
        if writer is None:
            return {"writer": "UNK"}
        # Word boundaries keep names that merely contain "by" ("byron", "abby")
        # intact; the old pattern "by." stripped "by" plus any next character.
        cleaned = re.sub(r"\bby\b[.:]?\s*", "", writer.lower()).strip()
        return {"writer": cleaned or "UNK"}

    def lowercase_content(self, example):
        """Map function: lower-case the headline and the full story."""
        return {"headline": example["headline"].lower(), "full_story": example["full_story"].lower()}

    def compute_story_length(self, example):
        """Map function: count whitespace-separated words in the full story."""
        return {"story_length": len(example["full_story"].split())}

    def load_data(self):
        """Load the three category CSVs, tag each row's category, concatenate and shuffle."""
        sources = {
            "business": self.business_file,
            "entertainment": self.entertainment_file,
            "sports": self.sports_file,
        }
        tagged = []
        for category, path in sources.items():
            dataset = load_dataset("csv", data_files=path, split="train")
            tagged.append(dataset.map(self.add_category, fn_kwargs={"category": category}))

        self.news_dataset = concatenate_datasets(tagged).shuffle(seed=self.seed)

    def clean_data(self):
        """Filter, encode and normalise the loaded dataset in sequence.

        Raises:
            RuntimeError: if load_data() has not been called.
        """
        self._require_loaded()
        dataset = self.news_dataset

        # Rows without a story or a headline are unusable; a None headline
        # would also crash lowercase_content().
        dataset = dataset.filter(lambda x: x["main_story"] is not None and x["headline"] is not None)

        # Encode the `category` column as class labels.
        dataset = dataset.class_encode_column("category")

        dataset = dataset.map(self.build_full_story)
        dataset = dataset.map(self.edit_writer)
        dataset = dataset.map(self.lowercase_content)
        dataset = dataset.map(self.compute_story_length)

        # Keep only stories strictly shorter than max_story_length words.
        limit = self.max_story_length
        dataset = dataset.filter(lambda x: x["story_length"] < limit)

        self.news_dataset = dataset

    def split_and_save(self):
        """Drop working columns, split 70/12/18 (train/validation/test) and save to disk.

        Raises:
            RuntimeError: if load_data() has not been called.
        """
        self._require_loaded()
        news_dataset_clean = self.news_dataset.remove_columns(
            ["last_update", "summary", "editor", "writer", "main_story"]
        )

        # 70% train; the held-out 30% is split 40/60 into validation/test.
        split_data = news_dataset_clean.train_test_split(train_size=0.7, seed=self.seed)
        held_out = split_data["test"].train_test_split(train_size=0.4, seed=self.seed)
        split_data["validation"] = held_out["train"]
        split_data["test"] = held_out["test"]

        split_data.save_to_disk(self.config.save_dir)
        logger.info(f"saved dataset splits to {self.config.save_dir}")

In [11]:
# Run the data-preparation stage end to end: configure -> load -> clean -> split/save.
try:
    config = ConfigurationManager()
    data_preparation_config = config.get_data_preparation_config()
    data_preparation = DataPreparation(config=data_preparation_config)
    data_preparation.load_data()
    data_preparation.clean_data()
    data_preparation.split_and_save()
except Exception:
    # Record the failure in the project log, then re-raise with the original
    # traceback (a bare `raise`, unlike `raise e`, adds no extra frame).
    logger.exception("Data preparation stage failed")
    raise

[2024-03-17 06:19:54,284 || yaml file : config.yaml loaded successfully]
[2024-03-17 06:19:54,354 || yaml file : params.yaml loaded successfully]
[2024-03-17 06:19:54,368 || created directory at artifacts]
[2024-03-17 06:19:54,370 || created directory at artifacts/data_preparation]


Map:   0%|          | 0/193 [00:00<?, ? examples/s]

Map:   0%|          | 0/230 [00:00<?, ? examples/s]

Map:   0%|          | 0/312 [00:00<?, ? examples/s]

Filter:   0%|          | 0/735 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/724 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/724 [00:00<?, ? examples/s]

Map:   0%|          | 0/724 [00:00<?, ? examples/s]

Map:   0%|          | 0/724 [00:00<?, ? examples/s]

Map:   0%|          | 0/724 [00:00<?, ? examples/s]

Map:   0%|          | 0/724 [00:00<?, ? examples/s]

Filter:   0%|          | 0/724 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/501 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/129 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/86 [00:00<?, ? examples/s]