In [1]:
import torch
from typing import List
from collections import defaultdict

In [2]:
# Read the raw training split as a list of lines with surrounding whitespace
# stripped. The encoding is passed explicitly so that decoding does not depend
# on the locale of the machine running the notebook.
with open("/export/home/Data/WikiText-2/wikitext-2-raw/wiki.train.raw", encoding="utf-8") as f:
    full_data = [d.strip() for d in f]

In [3]:
def partition_wikitext2_data(lines: List[str]) -> List[List[List[str]]]:
    """Partition WikiText-2 lines into articles, and each article into sections.

    A heading is only recognised when the line before it is empty:

    * ``= Title =`` starts a new article (and the first section of it).
    * ``= = Title = =`` starts a new section inside the current article.
    * ``= = = Title = = =`` (or deeper) stays inside the current section.

    The result has the following shape::

        [
            # Article 0
            [
                [line, line, ...],  # Section 0
                [line, line, ...],  # Section 1
                ...
            ],
            # Article 1
            [
                [line, line, ...],  # Section 0
                ...
            ],
            ...
        ]

    ``lines[0]`` itself is never stored. Whatever comes before the first
    article heading forms the first entry of the result, so when ``lines[1]``
    is an article heading the result starts with the placeholder ``[[]]``.

    Args:
        lines: Lines of the corpus, already stripped. The first line must be
            the empty string.

    Returns:
        A list of articles; each article is a list of sections; each section
        is the list of its lines (heading line included).

    Raises:
        ValueError: If ``lines`` is empty or its first line is not empty.
    """
    if not lines:
        raise ValueError("`lines` must not be empty")
    if lines[0] != "":
        raise ValueError(
            f"expected the first line to be empty, got {lines[0]!r}")

    current_section = []   # lines of the section being collected
    current_article = []   # finished sections of the article being collected
    articles = []          # finished articles
    previous_line = lines[0]
    for line in lines[1:]:
        after_blank = previous_line == ""
        # Every "= = =" line also starts with "= =" and "=", so the branches
        # below must test the longest prefix first.
        is_subsection_heading = line.startswith("= = =")
        is_section_heading = line.startswith("= =")
        # Lines starting with "= De dezas a espwa" are explicitly excluded from
        # being article headings (presumably body text that merely begins with
        # "=" -- verify against the corpus).
        is_article_heading = (
            line.startswith("=")
            and not line.startswith("= De dezas a espwa"))

        if after_blank and is_subsection_heading:
            # Subsections are not split out; they stay in the current section.
            current_section.append(line)
        elif after_blank and is_section_heading:
            current_article.append(current_section)
            current_section = [line]
        elif after_blank and is_article_heading:
            # Close the running section and article, then start a fresh
            # article whose first section opens with the title line.
            current_article.append(current_section)
            articles.append(current_article)
            current_article = []
            current_section = [line]
        else:
            current_section.append(line)

        previous_line = line

    # Flush whatever is still open at the end of the input.
    current_article.append(current_section)
    articles.append(current_article)
    return articles

In [4]:
articles = partition_wikitext2_data(full_data)
# Number of partitions returned. This count includes the first entry, which
# holds whatever precedes the first article heading.
len(articles)

601

In [5]:
def join_sections(sections: List[List[str]]) -> str:
    """Render one article as a single string.

    ``sections`` is the list of sections of one article, each section being a
    list of lines. The lines of every section are joined with newlines, and
    the resulting section texts are joined with newlines as well.
    """
    section_texts = []
    for section_lines in sections:
        section_texts.append("\n".join(section_lines))
    return "\n".join(section_texts)


def join_articles(articles: List[List[List[str]]]) -> str:
    """Render several articles as a single string.

    Each article (a list of sections) is rendered with ``join_sections`` and
    the rendered articles are joined with newlines.
    """
    return "\n".join(map(join_sections, articles))

In [6]:
# Sanity check: joining the partitions back together must reproduce the input.
# The comparison uses full_data[:-1] -- presumably because the last line of the
# file is empty and str.splitlines() does not yield a trailing empty line
# (TODO confirm).
round_trip_ok = join_articles(articles).splitlines() == full_data[:-1]
# Fail loudly instead of merely displaying False, so that later cells never run
# on a broken partition.
assert round_trip_ok, "re-joined articles do not match the original lines"
round_trip_ok

True

In [7]:
import os
from copy import deepcopy
from tqdm import trange
base_directory = "/export/home/Data/WikiText-2/articles/"
# Make sure the output directory exists before anything is written into it.
os.makedirs(base_directory, exist_ok=True)

# The first partition is expected to be the empty placeholder that precedes
# the first article heading; anything else means the partitioning went wrong.
if articles[0] != [[]]:
    raise ValueError(
        "expected the first partition to be the empty placeholder [[]]")

cleaned_articles = articles[1:]
for i in trange(len(cleaned_articles)):
    # Leave-one-out split: article i on its own, and every other article.
    # Nothing below mutates the articles, so slicing is sufficient and avoids
    # deep-copying the whole corpus on every iteration.
    cleaned_article_i = cleaned_articles[i]
    cleaned_articles_copy = cleaned_articles[:i] + cleaned_articles[i + 1:]
    if len(cleaned_articles_copy) != len(cleaned_articles) - 1:
        raise ValueError(
            f"leave-one-out split for article {i} has the wrong size")

    article_i_file_name = os.path.join(
        base_directory, f"article-{i}.txt")
    article_no_i_file_name = os.path.join(
        base_directory, f"article-no-{i}.txt")
    article_i_string = join_sections(cleaned_article_i)
    article_no_i_string = join_articles(cleaned_articles_copy)
    # Write with an explicit encoding so the output does not depend on the
    # locale of the machine running the notebook.
    with open(article_i_file_name, "w", encoding="utf-8") as f:
        f.write(article_i_string)
    with open(article_no_i_file_name, "w", encoding="utf-8") as f:
        f.write(article_no_i_string)

100%|██████████| 600/600 [01:05<00:00,  9.19it/s]
