In [1]:
# requirements: python >= 3.7
from pathlib import Path
import re
import json
from concurrent.futures import ThreadPoolExecutor, as_completed

# Also required: sqlite3 (standard library) and tqdm (third-party)
import sqlite3
from tqdm import tqdm

In [2]:
# SQLite metadata databases to scan: metadata112.sqlite through metadata121.sqlite.
metadata = [f"metadata{number}.sqlite" for number in range(112, 122)]

# Folder holding both the metadata databases and the story files they point to.
data_path = Path('/mnt/e/datasets/fanfic/updateablefanfic')
# JSON file with the categories to query, looked up in the working directory.
categories_path = Path.cwd() / 'updateablefanfic_categories.json'

# Outputs: the cached query results and the final training text file.
records_output_path = Path.cwd() / '../data/records.json'
dataset_output_path = Path.cwd() / '../data/dataset.txt'

In [3]:
# Notebook-level placeholders. `records` is rebound to the real query results
# once get_records() is called; `i` is not referenced by the visible cells.
records = list()
i = 0

def query(md, conn, cat):
    """Return the paths of up to 8 suitable stories for one category.

    Args:
        md: Name of the metadata database being queried. Only used to make
            error messages traceable.
        conn: Open sqlite3 connection exposing a `metadata` table.
        cat: Raw category text (unescaped). A story matches when its
            `category` column contains this text.

    Returns:
        A list of 1-tuples `(path,)`, shortest stories first. An empty list
        is returned when the query fails, so callers can always extend their
        results with the return value.
    """
    # Selects English stories within the given category, at most 8 per call,
    # with more than 100 and fewer than 15000 words. No poetry.
    # The category is bound as a parameter instead of being spliced into the
    # SQL text, so apostrophes in category names need no manual escaping.
    # (A single-chapter filter, `CAST(chapters AS integer) < 2`, was left
    # disabled in the original query and stays disabled here.)
    sqlite_select_query = """
        SELECT path
        FROM metadata
        WHERE language LIKE 'english'
            AND CAST(REPLACE(words, ',', '') AS int) > 100
            AND CAST(REPLACE(words, ',', '') AS int) < 15000
            AND genre NOT LIKE '%Poetry%'
            AND category LIKE '%' || ? || '%'
        ORDER BY CAST(REPLACE(words, ',', '') AS int) ASC
        LIMIT 8;"""

    try:
        return conn.execute(sqlite_select_query, (cat,)).fetchall()
    except sqlite3.Error as e:
        # Best effort: report which database/category failed and carry on.
        print(f"Query for category {cat!r} failed on {md}: {e}")
        return []


def get_records():
    """Query every metadata database for every category and pool the results.

    Each database listed in `metadata` is copied into an in-memory SQLite
    database, then one `query` call per category is submitted to a thread
    pool. The progress bar counts 500 ticks per database load and 1 tick per
    finished query.

    Returns:
        A flat list of the rows returned by all `query` calls, in completion
        order.
    """
    with open(categories_path, encoding="utf8") as fcats:
        cats = json.load(fcats)

    # Weight of one database load on the progress bar, relative to one query.
    load_ticks = 500
    max_i = len(cats) * len(metadata) + len(metadata) * load_ticks

    records = []
    connections = []
    try:
        with ThreadPoolExecutor(max_workers=64) as executor, tqdm(total=max_i) as pbar:
            futures = []

            for md in metadata:
                # Load the DB into memory, then release the on-disk handle.
                source = sqlite3.connect(data_path.joinpath(md))
                try:
                    # check_same_thread=False: the connection is created here
                    # but used by the pool's worker threads.
                    conn = sqlite3.connect(':memory:', check_same_thread=False)
                    connections.append(conn)
                    source.backup(conn)
                finally:
                    source.close()
                pbar.update(load_ticks)

                for cat in cats:
                    # Pass the raw category through: `query` is responsible
                    # for quoting it, and escaping it here as well would
                    # double-escape apostrophes so those categories never match.
                    futures.append(executor.submit(query, md, conn, cat))

            for f in as_completed(futures):
                # A failed query may yield None; treat that as "no rows".
                records.extend(f.result() or [])
                pbar.update(1)
    finally:
        # The executor has shut down by now, so no worker still uses these.
        for conn in connections:
            conn.close()

    return records

In [4]:
def write_records(records):
    """Save `records` as JSON at `records_output_path`, replacing any old file.

    Args:
        records: JSON-serializable query results (as returned by get_records).
    """
    # The target folder may not exist yet; create it rather than failing
    # after the long-running queries have already finished.
    records_output_path.parent.mkdir(parents=True, exist_ok=True)

    # Plain "w" is enough: the file is only written here, never read back.
    with open(records_output_path, "w", encoding="utf8") as rfile:
        json.dump(records, rfile)

def load_records():
    """Read back the JSON record list stored at `records_output_path`."""
    with records_output_path.open(mode="r", encoding="utf8") as source:
        payload = json.load(source)
    return payload

In [5]:
# Run all category queries against every metadata database. This is the slow
# step: the recorded run below took about five minutes.
records = get_records()

# Cache the results as JSON so create_dataset() can reload them via
# load_records() without repeating the queries.
write_records(records)

100%|██████████| 10400/10400 [04:59<00:00, 34.68it/s]


In [6]:
# Anything outside this whitelist is stripped from a kept line.
_DISALLOWED_CHARS_RE = re.compile(r"[^A-Za-z0-9,.?!\-—():; '\"\*]+")
# A run of one repeated mark (— , - ' " ( )) is collapsed to a single mark.
# '/' needs no entry: it is already removed by the whitelist above.
_REPEATED_MARK_RE = re.compile(r"([—,\-'\"()])\1+")
# Lines of 10 characters or fewer are only kept if they open with a quoted phrase.
_LEADING_QUOTE_RE = re.compile(r'^".*"')


def _clean_line(line):
    """Return the cleaned text of one story line, or None to drop the line."""
    # Tab-indented lines are skipped outright.
    if line.startswith("\t"):
        return None

    line = line.strip()
    if not line or line == 'End file.':
        return None

    # Weeding out lines with no written words, e.g. "---".
    if not any(mark in line for mark in ".?!"):
        return None
    if len(line) <= 10 and not _LEADING_QUOTE_RE.search(line):
        return None

    line = _DISALLOWED_CHARS_RE.sub('', line)
    return _REPEATED_MARK_RE.sub(r'\1', line)


def parse_file(file_path):
    """Extract the cleaned story text from one story file.

    Everything up to and including the first line starting with "Summary:"
    is treated as header information and skipped. Each remaining line is
    filtered and cleaned by `_clean_line`.

    Args:
        file_path: Path of the UTF-8 encoded story file.

    Returns:
        The kept lines, each terminated by a newline, joined into one string
        (possibly empty). None if the file cannot be read or decoded; the
        error is printed together with the file path.
    """
    try:
        with open(file_path, 'r', encoding="utf8") as file:
            lines = file.read().splitlines()
    except (OSError, ValueError) as e:
        # ValueError covers UnicodeDecodeError for files that are not UTF-8.
        print(f"Could not parse {file_path}: {e}")
        return None

    kept = []
    metadata_skipped = False

    for line in lines:
        # Skip first few rows of genre information etc.
        if not metadata_skipped:
            metadata_skipped = line.startswith("Summary:")
            continue

        cleaned = _clean_line(line)
        if cleaned is not None:
            kept.append(cleaned + "\n")

    # Join once at the end instead of growing a string line by line.
    return "".join(kept)

In [7]:
def create_dataset(max_workers=1024, append=True):
    """Parse every recorded story and write it to the dataset text file.

    The records saved by write_records() are reloaded, each story file is
    parsed on a thread pool, and every successfully parsed story is written
    to `dataset_output_path` wrapped in GPT-2 start/end tokens. Stories are
    written in completion order.

    Args:
        max_workers: Size of the thread pool. Try lowering it if the run
            gets stuck.
        append: When True (the default, matching the original behaviour) the
            output file is appended to, so running this twice stores every
            story twice. Pass False to overwrite the file instead.
    """
    records = load_records()

    # Make sure the output folder exists before the parsing work starts.
    dataset_output_path.parent.mkdir(parents=True, exist_ok=True)
    mode = "a" if append else "w"

    with open(dataset_output_path, mode, newline="", encoding="utf8") as output_file:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Progress bar
            with tqdm(total=len(records)) as pbar:
                futures = [
                    executor.submit(parse_file, data_path.joinpath(r[0]))
                    for r in records
                ]

                for f in as_completed(futures):
                    parsed = f.result()

                    # parse_file yields None for files it could not read;
                    # skip those instead of failing on the concatenation.
                    if parsed is not None:
                        # Append start and end tokens for GPT-2
                        output_file.write(
                            '<|startoftext|>\n' + parsed + '<|endoftext|>\n')

                    pbar.update(1)
                

In [11]:
# Build the GPT-2 training text file from the cached records.
# Note: the output file is opened in append mode, so re-running this cell
# adds every story to the file again.
create_dataset()

100%|██████████| 27069/27069 [00:51<00:00, 529.61it/s]
