# Test the Datasets Class

In [2]:
from lexos.io.dataset import Dataset

## Configure Dataset Type

- headerlines lines
- lines with headers
- headerless csv
- csv
- json

In [62]:
dataset_types = {
    "raw_str_no_headers": "Test\nTest",
    "local_path": "../test_data/datasets/Austen.txt",
    "local_csv": "../test_data/datasets/Austen.csv",
    "local_tsv": "../test_data/datasets/Austen_valid_headers.tsv",
    "local_tsv_no_headers": "../test_data/datasets/Austen.tsv",
    "local_tsv_invalid": "../test_data/datasets/Austen_invalid_headers.tsv",
    "local_json": "../test_data/datasets/Austen.json",
    "local_jsonl": "../test_data/datasets/Austen_nl.jsonl",
    "local_json_invalid": "../test_data/datasets/Austen_invalid_fields.json",
    "zipped_lineated_text": "../test_data/datasets/lineated_text.zip"
}

source = dataset_types["local_path"]

In [13]:
dataset = Dataset.parse_lines(source)

In [None]:
dataset

In [66]:
"""dataset2.py.

This class currently supports single files of the following formats:

    - lineated text files, csv, tsv, json, and jsonl files
    - lineated text strings

To Do:

    - Test support for multiple files/strings (it should work as long as all files are of the same format)
    - Add support for csv, tsv, json, and jsonl strings
"""

from typing import Any, Dict, List, Optional, Union
from IPython.display import display # Remove for production
from pathlib import Path
import mimetypes
import tempfile
import zipfile
import pandas as pd
from pydantic import BaseModel
from smart_open import open


class Dataset(BaseModel):
    """Dataset class."""

    df: Optional[pd.DataFrame] = None
    data: Optional[List[Dict[str, str]]] = None
    
    class Config:
        arbitrary_types_allowed = True

#     @property
#     def df(self) -> pd.DataFrame:
#         """Return the dataframe of the object data.

#         Returns:
#             pd.DataFrame: The dataframe of the object data.
#         """
#         if self.data and not self.df:
#             return pd.DataFrame(self.data)

    @property
    def locations(self) -> List[str]:
        """Return the locations of the object data.

        Returns:
            List[str]: The locations of the object data.
        """
        if "locations" in self.df.columns:
            return self.df["locations"].values.tolist()
        else:
            return None

    @property
    def names(self) -> List[str]:
        """Return the names of the object data.

        Returns:
            List[str]: The names of the object data.
        """
        return self.df["title"].values.tolist()

    @property
    def texts(self) -> List[str]:
        """Return the texts of the object data.

        Returns:
            List[str]: The texts of the object data.
        """
        return self.df["text"].values.tolist()

    
class DatasetLoader:
    """Loads a dataset.
    
    Usage:
        loader = DatasetLoader(source)
        dataset = loader.dataset
        
    Notes:
        Different types of data may require different keyword parameters. Error messages
        provide some help in identifying what keywords are required.
    """
    
    def __init__(
        self,
        source: Any, 
        labels: List[str] = None,
        locations: Optional[List[str]] = None,
        title_col: Optional[str] = None,
        text_col: Optional[str] = None,
        location_col: Optional[str] = None,
        **kwargs: Dict[str, str],
    ) -> Union[Dataset, List[Dataset]]:
        """Initialise the loader.
        
        Args:
            source (Any): The source type to detect.
            labels (List[str]): The labels to use.
            locations (Optional[List[str]]): The locations of the texts.
            title_col (str): The name of the column containing the titles.
            text_col (str): The name of the column containing the texts.
            location_col (str): The name of the column containing the locations.

        Return:
            Union[Dataset, List[Dataset]]: A Dataset object.
        """
        if isinstance(source, list):
            self.data = [
                Dataset(df=self.load(item, labels, locations, title_col, text_col, location_col, **kwargs))
                for item in source
            ]
        else:
            df = self.load(source, labels, locations, title_col, text_col, location_col, **kwargs)
            self.data = Dataset(df=df)

    def _detect_source_type(
        self,
        source: Any,
        labels: List[str] = None,
        locations: Optional[List[str]] = None,
        title_col: Optional[str] = None,
        text_col: Optional[str] = None,
        location_col: Optional[str] = None,
        **kwargs: Dict[str, str],
    ):
        """
        Detect the source type of the given source type.

        Args:
            source (Any): The source type to detect.
            labels (List[str]): The labels to use.
            locations (Optional[List[str]]): The locations of the texts.
            title_col (str): The name of the column containing the titles.
            text_col (str): The name of the column containing the texts.
            location_col (str): The name of the column containing the locations.

        Return:
            str: The detected source type.
        """
        if isinstance(source, list):
            print("Lists are not supported yet.")
            return "list"
        elif isinstance(source, str):
            lines = source.split("\n")
            if len(lines) > 1:
                df = self.load_string(lines, labels)
            else:
                df = self.load_file(str(source), labels, locations, title_col, text_col, location_col, **kwargs)
            return df
        else:
            raise Exception(f"Unknown source type: {source}")

    def load(
        self,
        source: Any,
        labels: List[str] = None,
        locations: Optional[List[str]] = None,
        title_col: Optional[str] = None,
        text_col: Optional[str] = None,
        location_col: Optional[str] = None,
        **kwargs: Dict[str, str]
    ) -> pd.DataFrame:
        """Load the given file.

        Args:
            source (Any): The source the data to load.
            labels (List[str]): The labels to use.
            locations (Optional[List[str]]): The locations of the texts.
            title_col (str): The name of the column containing the titles.
            text_col (str): The name of the column containing the texts.
            location_col (str): The name of the column containing the locations.

        Returns:
            pd.DataFrame: The loaded data.
        """
        df = self._detect_source_type(source, labels, locations, title_col, text_col, location_col, **kwargs)
        return df

    def load_file(
        self,
        file_path: str,
        labels: List[str] = None,
        locations: Optional[List[str]] = None,
        title_col: Optional[str] = None,
        text_col: Optional[str] = None,
        location_col: Optional[str] = None,
        **kwargs: Dict[str, str]
    ) -> pd.DataFrame:
        """Load the given file.

        Args:
            file_path (str): The path to the file to load.
            labels (List[str]): The labels to use.
            locations (Optional[List[str]]): The locations of the texts.
            title_col (str): The name of the column containing the titles.
            text_col (str): The name of the column containing the texts.
            location_col (str): The name of the column containing the locations.

        Returns:
            pd.DataFrame: The loaded data.
        """
        if Path(file_path).exists():
            mime_type, _ = mimetypes.guess_type(file_path)
            # Lineated text
            if mime_type == "text/plain":
                with open(file_path) as f:
                    return self.load_string(list(f.readlines()), labels)
            # CSV/TSV
            elif mime_type in ["text/csv", "text/tsv", "application/vnd.ms-excel", "text/tab-separated-values"]:
                if labels:
                    df = pd.read_csv(file_path, header=None, **kwargs)
                    df.columns = labels
                else:
                    df = pd.read_csv(file_path, **kwargs)
                if title_col:
                    df = df.rename(columns={title_col: "title"})
                if text_col:
                    df = df.rename(columns={text_col: "text"})
                if location_col:
                    df["locations"] = [file_path] * df.shape[1]
                if "title" not in df.columns or "text" not in df.columns:
                    err = (
                        "CSV or TSV files must contain columns named `title` and `text`. ",
                        "You can convert the names of existing column to these with the ",
                        "`title_col` and `text_col` parameters. If your file has no column ",
                        "headers, you can supply a list with the `labels` parameter.",
                    )
                    raise Exception("".join(err))
                return df
            # JSON/JSONL
            elif mime_type == "application/json" or file_path.endswith(".jsonl"):
                df = pd.read_json(file_path, **kwargs)
                if title_col:
                    df = df.rename(columns={title_col: "title"})
                if text_col:
                    df = df.rename(columns={text_col: "text"})
                if location_col:
                    df["locations"] = [file_path] * df.shape[1]
                if "title" not in df.columns or "text" not in df.columns:
                    err = (
                        "JSON and JSONL files must contain fields named `title` and `text`. ",
                        "You can convert the names of existing fields to these with the ",
                        "`title_col` and `text_col` parameters.",
                    )
                    raise Exception("".join(err))
                return df
            # Zip
            elif mime_type in ["application/zip", "application/x-zip-compressed"]:
                return self.load_zip(file_path, labels, locations, title_col, text_col, location_col, **kwargs)
            else:
                raise Exception(f"Unknown file type: {mime_type}")

    def load_string(self, source: str, labels: List[str]) -> pd.DataFrame:
        """
        Load the given string.

        Args:
            source (str): The string to load.
            labels (List[str]): The labels to use.

        Returns:
            pd.DataFrame: The loaded data.
        """
        if not labels:
            raise Exception("Please use the `labels` argument to provide a list of labels for each row in your data.")
        elif len(labels) != len(source):
            raise Exception("The number of labels does not match the number of lines in your data.")
        else:
            data = [{"title": labels[i], "text": line} for i, line in enumerate(source)]
            df = pd.DataFrame(data, columns=["title", "text"])
            return df

    def load_zip(
        self,
        file_path: str,
        labels: List[str] = None,
        locations: Optional[List[str]] = None,
        title_col: Optional[str] = None,
        text_col: Optional[str] = None,
        location_col: Optional[str] = None,
        **kwargs: Dict[str, str],
    ) -> pd.DataFrame:
        """
        Load a zip file.

        Args:
            file_path (str): The path to the file to load.
            labels (List[str]): The labels to use.
            locations (Optional[List[str]]): The locations of the texts.
            title_col (str): The name of the column containing the titles.
            text_col (str): The name of the column containing the texts.
            location_col (str): The name of the column containing the locations.

        Returns:
            pd.DataFrame: The loaded data.
        """
        df = pd.DataFrame([], columns=["title", "text"])
        with open(file_path, "rb") as f:
            with zipfile.ZipFile(f) as zip:
                with tempfile.TemporaryDirectory() as tempdir:
                    zip.extractall(tempdir)
                    new_dfs = []
                    for tmp_path in Path(tempdir).glob("**/*"):
                        if (
                            tmp_path.is_file()
                            and not tmp_path.suffix == ""
                            and not str(tmp_path).startswith("__MACOSX")
                            and not str(tmp_path).startswith(".ds_store")
                        ):
                            new_dfs.append(load_file(tmp_path, labels, locations, title_col, text_col, location_col, **kwargs))
                    df = pd.concat(new_dfs, ignore_index=True, sort=False)
        return df

In [67]:
loader = DatasetLoader(source, labels=["text1", "text2"])
loader.data
# for dataset in loader.data:
#     display(dataset.df)
# dataset = Dataset(id=1, name="Joe")
# dataset

Unnamed: 0,title,text
0,text1,Pride and Prejudice by Jane Austen Chapter 1 I...
1,text2,SENSE AND SENSIBILITY by Jane Austen (1811) CH...


Dataset(data=None)

In [57]:
data = [{"title": "text1", "text": "Test"}, {"title": "text1", "text": "Test"}]
doink = Dataset(df=loader[0].data.df)
doink.df

TypeError: 'DatasetLoader' object is not subscriptable