In [None]:
# To install all required packages, run this cell (can be left out otherwise)
!pip install pandas zstandard

In [None]:
import pandas as pd

from os import path
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, lit, sentences

In [None]:
# Start (or reuse) the spark session for this notebook and keep a handle on its context
spark = (
    SparkSession.builder
    .master("yarn")
    .appName("the-pile-embeddings")
    .getOrCreate()
)
sc = spark.sparkContext

In [None]:
# Locations of the data set files, expressed as file:// URIs
THE_PILE_BASE_PATH = path.join(
    "file:///", "mnt", "ceph", "storage", "corpora", "corpora-thirdparty", "the-pile"
)
val_data_path, test_data_path = (
    path.join(THE_PILE_BASE_PATH, f"{split}.jsonl.zst") for split in ("val", "test")
)
# Training files are numbered with two digits: train/00.jsonl.zst ... train/29.jsonl.zst
train_data_paths = [
    path.join(THE_PILE_BASE_PATH, "train", f"{shard:02d}.jsonl.zst")
    for shard in range(30)
]

## Data loading and preparation

In [None]:
# Data set selection: each subset name is paired with a flag saying whether to keep it.
# Flip a flag to include or exclude a subset; the order of the kept names is preserved.
data_selection = [
    set_name
    for set_name, keep in (
        ("OpenWebText2", True),
        ("PubMed Abstracts", True),
        ("StackExchange", True),
        ("Github", False),  # currently ignored because we don't want the code
        ("Enron Emails", True),
        ("FreeLaw", True),
        ("USPTO Backgrounds", True),
        ("Pile-CC", True),
        ("Wikipedia (en)", True),
        ("Books3", True),
        ("PubMed Central", True),
        ("HackerNews", True),
        ("Gutenberg (PG-19)", True),
        ("DM Mathematics", False),  # currently ignored because we don't want math formulas
        ("NIH ExPorter", True),
        ("ArXiv", True),
        ("BookCorpus2", True),
        ("OpenSubtitles", True),
        ("YoutubeSubtitles", True),
        ("Ubuntu IRC", True),
        ("EuroParl", False),  # currently ignored because we focus on English text for now
        ("PhilPapers", True),
    )
    if keep
]

_For now, we use pandas to read the data, as there seem to be some issues with spark reading zstandard-compressed files (which "The Pile" uses)._

_This also means that, for now, we can only load parts of the data until this issue is fixed (it is probably a server-side issue)._

In [None]:
# Read the validation split with pandas and add a "meta_str" column holding the
# "pile_set_name" entry of each row's "meta" value, which is easier to filter on
val_data = pd.read_json(val_data_path, lines=True, compression="zstd").assign(
    meta_str=lambda frame: frame["meta"].apply(lambda meta: meta["pile_set_name"])
)

In [None]:
# Keep only the rows from the selected data sets, and only the columns we need
is_selected = val_data["meta_str"].isin(data_selection)
filtered_data = val_data.loc[is_selected, ["text", "meta_str"]]

# Hand the filtered pandas DataFrame over to spark
val_data_spark = spark.createDataFrame(filtered_data)

In [None]:
# If necessary, we can also transform our data into a dataframe of sentences.
# First, we split each document into a list of sentences, which are lists of tokens.
# The result column gets an explicit name: the auto-generated one
# (e.g. "sentences(text, en, )") is an implementation detail that may differ between
# Spark versions, so looking it up by that string is fragile.
nested_sentences = val_data_spark.select(
    sentences(string=val_data_spark.text, language=lit("en")).alias("sentences")
)

# Afterwards, we can flatten the list of lists by "exploding" each outer list,
# which gives one row per sentence in the "sents" column.
flattened_sentences = nested_sentences.select(explode(col("sentences")).alias("sents"))