<a href="https://colab.research.google.com/github/zetavg/LLM-Research/blob/main/Wikipedia_Random_Page_Summaries_Dataset_Generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wikipedia Random Page Summaries Dataset Generator

Collects random Wikipedia page summaries and saves them as a dataset for further training. Optionally uploads the dataset to the Hugging Face Hub.

In [None]:
# @title Settings { display-mode: "form", run: "auto" }

# Wikipedia language code; passed to `wikipedia.set_lang` before pages are fetched.
wikipedia_lang = "zh-tw"  # @param {type:"string"}
# Number of random-page fetches to run (the `count` given to `data_generator`);
# failed fetches are dropped, so the dataset may end up with fewer rows.
number_of_pages = 5000  # @param {type:"integer"}

# Passed as `private=` to `ds.push_to_hub` when uploading.
hf_dataset_set_as_private = False  # @param {type:"boolean"}
# @markdown Leave blank to skip uploading to Hugging Face. **Note: an existing dataset with this name may be overwritten.**
hf_dataset_name = "zetavg/wikipedia_random_page_summaries_zh_tw_5k"  # @param {type:"string"}

In [None]:
!pip install datasets huggingface_hub gradio

In [None]:
# @title Hugging Face Login { display-mode: "form" }
# @markdown Make sure to have this done before running the "Upload the dataset to Hugging Face Hub" section.

from huggingface_hub import HfFolder, whoami, notebook_login

# Upload only when a non-blank dataset name is configured. A whitespace-only
# name left in the form field would otherwise enable an upload to an invalid repo.
upload_to_hf = bool(hf_dataset_name and hf_dataset_name.strip())


def check_login_status():
    """Return the Hugging Face username for the locally stored token.

    Returns:
        The ``"name"`` field from ``whoami`` for the stored token, or ``None``
        if no token is stored or the lookup fails for any reason (any
        exception from ``whoami`` or a response without ``"name"``).
    """
    token = HfFolder().get_token()
    if token is None:
        return None

    try:
        return whoami(token)["name"]
    except Exception:
        # An invalid/expired token is treated the same as "not logged in".
        return None


# Reuse the stored login when it still resolves to a user; otherwise fall
# back to `notebook_login()`. Skipped entirely when no upload is planned.
if upload_to_hf:
    username = check_login_status()
    if username:
        print(f"Already login as {username}")
    else:
        notebook_login()

In [None]:
# @markdown Install necessary packages.

# @markdown A patched version of the wikipedia package to support zh-tw.
!pip install git+https://github.com/zetavg/python_wikipedia

# @markdown 「盤古之白」(https://github.com/vinta/pangu.js).
!pip install pangu

## Prepare the dataset

In [None]:
# @markdown Test if we can get pages from wikipedia

import wikipedia
import pangu

# The language is a module-level setting of the `wikipedia` package, so set it
# once up front instead of on every loop iteration.
wikipedia.set_lang(wikipedia_lang)

for _ in range(10):
    try:
        page_title = wikipedia.random()
        print("Page Title: ", page_title)
        # @markdown Note: the space between CJK and English text (pangu spacing, "盤古之白" XD) plays an important role here: with most tokenizers, an English word without a preceding space is encoded as a different token.
        print(pangu.spacing_text(
            wikipedia.summary(page_title)
        ))
        # @markdown Note: some "Page Title" values may be in Simplified Chinese. This is expected: they are Wikipedia's original, unconverted page titles.
    except Exception as e:
        # @markdown Note: some errors caused by disambiguation pages can be ignored.
        print("Error: ", e)
    print()

In [None]:
# @markdown Define how the data should be loaded. See: https://shareg.pt/kx4UbKd.

import queue
from wikipedia import DisambiguationError
from concurrent.futures import ThreadPoolExecutor
from threading import Semaphore
from tqdm.auto import tqdm


def get_random_page_summary(sentences=0, max_retries=10):
    """Fetch one random Wikipedia page and return its pangu-spaced summary.

    Args:
        sentences: Maximum number of summary sentences to request; 0 means
            unlimited. Forwarded to ``wikipedia.summary``.
        max_retries: How many additional random pages to draw when the drawn
            page is a disambiguation page.

    Returns:
        ``{'page_title': ..., 'page_summary': ...}`` on success, or ``None``
        if any other error occurs or every attempt hit a disambiguation page.
    """
    # Iterate instead of recursing so repeated disambiguation hits cannot
    # grow the call stack.
    for _ in range(max_retries + 1):
        try:
            page_title = wikipedia.random()
            page_summary = wikipedia.summary(page_title, sentences=sentences)
            page_summary = pangu.spacing_text(page_summary)
            return {
                'page_title': page_title,
                'page_summary': page_summary,
            }
        except DisambiguationError:  # "..." may refer to "..." or "..."
            # Just retry with another random page.
            continue
        except Exception as e:  # Ignore other errors
            print(e)
            return None
    print(f"Gave up after {max_retries + 1} attempts: every random page was a disambiguation page.")
    return None


def fetch_page_and_update_queue(
        sentences, fetched_pages_queue, fetch_semaphore):
    """Worker task: fetch one random page summary and enqueue the result.

    Acquires one permit from ``fetch_semaphore`` first, which bounds how many
    fetched-but-unconsumed results can pile up; the consumer in
    ``data_generator`` releases the permit after taking an item.

    Exactly one item is always put on ``fetched_pages_queue`` (the summary
    dict, or ``None`` on failure). The consumer does one blocking ``get()``
    per submitted task, so a missing item would make it wait forever.
    """
    fetch_semaphore.acquire()
    result = None
    try:
        result = get_random_page_summary(sentences)
    finally:
        # Posted even if the fetch raised; the exception still propagates.
        fetched_pages_queue.put(result)


def data_generator(
        count=1000,
        sentences=0,  # 0: Unlimited
        max_workers=50,
        max_buffer_size=500):
    """Yield random page summaries fetched concurrently by a thread pool.

    One fetch task is submitted per requested page. Results are consumed in
    the order workers enqueue them; failed fetches (``None``) are counted but
    not yielded, so fewer than ``count`` rows may be produced.

    Args:
        count: Number of fetch tasks to run.
        sentences: Forwarded to each fetch; 0 means no sentence limit.
        max_workers: Size of the thread pool.
        max_buffer_size: Upper bound on fetched-but-not-yet-consumed results;
            workers wait on a semaphore once this many are outstanding.

    Yields:
        ``{'page_title': ..., 'page_summary': ...}`` dicts.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        fetched_pages_queue = queue.Queue()
        fetch_semaphore = Semaphore(max_buffer_size)

        futures = [
            executor.submit(fetch_page_and_update_queue,
                            sentences,
                            fetched_pages_queue, fetch_semaphore)
            for _ in range(count)
        ]

        progress_bar = tqdm(total=count)
        yielded_count = 0
        errored_count = 0

        try:
            # Each submitted task is expected to enqueue exactly one item.
            for _ in range(count):
                result = fetched_pages_queue.get()
                # Free a buffer slot so a waiting worker can fetch the next page.
                fetch_semaphore.release()

                if result:
                    yielded_count += 1
                    yield result
                else:
                    errored_count += 1

                fetched_count = (yielded_count + errored_count
                                 + fetched_pages_queue.qsize())
                description = f"Pages processed/fetched: {yielded_count}/{fetched_count}"
                if errored_count > 0:
                    description = f"Pages processed/errored/fetched: {yielded_count}/{errored_count}/{fetched_count}"
                progress_bar.set_description(description)
                progress_bar.update(1)
        finally:
            # If consumption stops early (generator closed, exception,
            # KeyboardInterrupt), leaving the `with` block waits for every
            # submitted task. Cancel tasks that have not started yet, and
            # hand out enough permits that any task already blocked on the
            # semaphore can finish instead of waiting forever. At most `count`
            # tasks exist, so `count` permits always suffice. Extra permits on
            # a normal exit are harmless because the semaphore is local.
            for future in futures:
                future.cancel()
            for _ in range(count):
                fetch_semaphore.release()
            progress_bar.close()

In [None]:
# @title Create the dataset

from datasets import Dataset

import tempfile

import wikipedia

# The fetch helpers never set the language themselves. Set it here instead of
# relying on the optional smoke-test cell above having been run.
wikipedia.set_lang(wikipedia_lang)

ds = Dataset.from_generator(
    data_generator,
    gen_kwargs={
        'count': number_of_pages
    },
    # `datasets` may reuse cached generator output based on a hash of the
    # generator and gen_kwargs. That hash captures neither the language
    # setting nor the randomness, so a re-run could silently return a stale
    # sample. A fresh cache directory forces a new fetch on every run.
    cache_dir=tempfile.mkdtemp())

In [None]:
# @title Preview the created dataset
import json

print("features: ", ds.features)
print("num_rows: ", ds.num_rows)

print("preview: ")
# Failed fetches are dropped, so the dataset can have fewer than 10 rows;
# cap the preview at the row count to avoid indexing past the end.
for i in range(min(10, ds.num_rows)):
    item = ds[i]
    print(json.dumps(item, indent=2, ensure_ascii=False))

## Upload the dataset to Hugging Face Hub

In [None]:
# Push the generated dataset to the Hub only when a dataset name was configured.
if not upload_to_hf:
    print("Upload skipped.")
else:
    print(f"Uploading {'private' if hf_dataset_set_as_private else 'public'} dataset "
          f"'{hf_dataset_name}' to Hugging Face...")
    ds.push_to_hub(hf_dataset_name, private=hf_dataset_set_as_private)
    print(f"Dataset uploaded: https://huggingface.co/datasets/{hf_dataset_name}.")

In [None]:
# @title Preview the uploaded dataset
import json
from datasets import load_dataset

if upload_to_hf:
    ds_from_hf = load_dataset(hf_dataset_name)['train']

    print("features: ", ds_from_hf.features)
    print("num_rows: ", ds_from_hf.num_rows)

    print("preview: ")
    # The uploaded dataset can have fewer than 10 rows (failed fetches are
    # dropped); cap the preview to avoid indexing past the end.
    for i in range(min(10, ds_from_hf.num_rows)):
        item = ds_from_hf[i]
        print(json.dumps(item, indent=2, ensure_ascii=False))