Skip to content
Permalink
8047bfa1d9
Switch branches/tags
Go to file
 
 
Cannot retrieve contributors at this time
452 lines (384 sloc) 14.2 KB
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating the C4 dataset."""
import functools
import gzip
import hashlib
import io
import re
import threading
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow_datasets.public_api as tfds
# --- WET file constants ------------------------------------------------------
# Line that starts every record in a WET file.
_PAGE_DELIMITER = "WARC/1.0"
# Header-line prefixes whose remainder is copied into the page features.
_URL_KEY = "WARC-Target-URI:"
_URL_DATE = "WARC-Date:"
_CONTENT_TYPE = "Content-Type:"
_CONTENT_LEN = "Content-Length:"
# Lines starting with any of these prefixes are metadata, not page text.
_METADATA_PREFIXES = ("WARC", "CONTENT-", "Content-")

# --- Filters -----------------------------------------------------------------
_MIN_WORDS_PER_LINE = 5
_MIN_NUM_SENTENCES = 3
_MAX_WORD_LENGTH = 1000
# A kept line must end with one of these marks...
_END_MARKS = (".", "?", "!", "\"")
# ...but must not end with an ellipsis.
_ELLIPSIS = "..."
# Lines whose lower-cased text contains any of these substrings are dropped.
_POLICY_SUBSTRINGS = [
    "terms of use",
    "privacy policy",
    "cookie policy",
    "uses cookies",
    "use of cookies",
    "use cookies",
]

# Memoized sentence tokenizer; populated lazily on first use.
_SENTENCE_TOKENIZER = None
def get_counter_inc_fn(namespace):
  """Builds a counter-increment callback bound to a Beam metrics namespace.

  Args:
    namespace: string, the Beam metrics namespace for every counter.

  Returns:
    A function `(counter, amt=1)` that increments the Beam counter named
    `counter` in `namespace` by `amt`.
  """

  def counter_inc_fn(counter, amt=1):
    # apache_beam is resolved lazily, only when a counter is bumped.
    metrics = tfds.core.lazy_imports.apache_beam.metrics.Metrics
    metrics.counter(namespace, counter).inc(amt)

  return counter_inc_fn
def get_hashed_url_filter_fn(predicate_fn):
  """Returns a filter that applies `predicate_fn` to the MD5 hash of a URL.

  Args:
    predicate_fn: function taking the URL's MD5 digest as a (non-negative)
      integer and returning whether to keep the element.

  Returns:
    A function taking a `(url, features)` element and returning
    `predicate_fn(int(md5_hex(url), 16))`.
  """

  def filter_fn(el):
    url, _ = el
    # Reuse the module's shared hashing helper so URLs and lines are hashed the
    # same way (UTF-8 text -> MD5 hex digest).
    return predicate_fn(int(_hash_text(url), 16))

  return filter_fn
# Shared lock serializing the punkt download across threads of this process.
# A lock created inside the function would be a fresh object on every call and
# therefore never contended, providing no mutual exclusion at all.
_NLTK_DOWNLOAD_LOCK = threading.Lock()


def _load_sentence_tokenizer():
  """Returns the NLTK punkt English sentence tokenizer.

  Calls `nltk.download("punkt")` first, then loads the English punkt model.
  """
  nltk = tfds.core.lazy_imports.nltk
  # Lock to avoid a race-condition in the creation of the download directory.
  # NOTE(review): a threading lock only guards threads within one process; it
  # does not coordinate separate worker processes.
  with _NLTK_DOWNLOAD_LOCK:
    nltk.download("punkt")
    return nltk.data.load("nltk:tokenizers/punkt/english.pickle")
def _get_sentences(text):
  """Splits `text` into a list of sentences.

  The tokenizer is loaded on the first call and cached in the module-level
  `_SENTENCE_TOKENIZER`.
  """
  global _SENTENCE_TOKENIZER
  if _SENTENCE_TOKENIZER:
    tokenizer = _SENTENCE_TOKENIZER
  else:
    tokenizer = _SENTENCE_TOKENIZER = _load_sentence_tokenizer()
  sentences = tokenizer.tokenize(tf.compat.as_text(text))
  return list(sentences)
def is_language(page, language, min_probability=0.99):
  """Returns True iff text is in `language` with at least `min_probability`.

  Args:
    page: `(url, features)` pair; only `features["text"]` is inspected.
    language: string, the language code that must be the top prediction of
      `langdetect.detect_langs`.
    min_probability: float, minimum probability of the top prediction.

  Returns:
    Whether the top langdetect prediction is `language` with probability of at
    least `min_probability`. Returns False when langdetect raises or returns no
    predictions.
  """
  _, features = page
  text = features["text"]
  counter_inc_fn = get_counter_inc_fn("detected-lang")
  langdetect = tfds.core.lazy_imports.langdetect

  # Fix the seed so that langdetect's predictions are deterministic.
  langdetect.DetectorFactory.seed = 0
  try:
    predictions = langdetect.detect_langs(text)
  except langdetect.lang_detect_exception.LangDetectException:
    counter_inc_fn("langdetect-exception")
    return False

  if not predictions:
    counter_inc_fn("page-filtered-nolangpredictions")
    return False

  top = predictions[0]
  if top.prob < min_probability:
    counter_inc_fn("page-filtered-lowlangdetectconf")
    return False
  if top.lang != language:
    counter_inc_fn("page-filtered-ignoredlang")
    counter_inc_fn("page-filtered-ignoredlang-%s" % (top.lang))
    return False

  counter_inc_fn("page-emited-%s" % top.lang)
  return True
def get_clean_page_fn(badwords=None):
  """Returns `clean_page` with pre-compiled badword and citation regexes.

  Args:
    badwords: optional iterable of words. A page is dropped by `clean_page`
      when one of them appears in a (lower-cased) line, delimited by a
      non-`[a-z]` character or by the start/end of the line. When empty or
      None, no badword filtering is applied.

  Returns:
    A `functools.partial` of `clean_page` with `citation_regex` and
    `badwords_regex` bound.
  """
  # Used to filter citation from Wikipedia pages (among others).
  citation_regex = re.compile(r"\[\d*\]|\[edit\]|\[citation needed\]")
  badwords_regex = None
  if badwords:
    # Escape each word so regex metacharacters in the list are matched
    # literally. Accept the start/end of the line as a boundary, so a badword
    # that opens (or closes) a line is caught as well.
    alternatives = "|".join(re.escape(word) for word in badwords)
    badwords_regex = re.compile(
        "(?:^|[^a-z])({})(?:[^a-z]|$)".format(alternatives))
  return functools.partial(
      clean_page, citation_regex=citation_regex, badwords_regex=badwords_regex)
def clean_page(url_and_features,
               citation_regex,
               badwords_regex=None,
               counter_inc_fn=None,
               min_words_per_line=_MIN_WORDS_PER_LINE,
               min_num_sentences=_MIN_NUM_SENTENCES,
               max_word_length=_MAX_WORD_LENGTH):
  """Cleans a CommonCrawl page, yielding nothing if it should be skipped.

  Cleaning removes lines that contain a too-long word, have no end mark (or
  end with an ellipsis), have too few words, mention "javascript", or contain
  a policy phrase. Citations matching `citation_regex` are stripped from each
  line. The whole page is skipped if a line that survived the earlier per-line
  checks contains "lorem ipsum", a "{", or a badword. After line filtering,
  pages are skipped if the kept lines hold fewer than `min_num_sentences`
  sentences, as counted by the sentence tokenizer.

  Args:
    url_and_features: tuple(string, dict), the url and features of the page.
    citation_regex: Regex to use for finding Wikipedia-like citations to filter.
    badwords_regex: Regex to use for finding badwords. Default None, which means
      don't apply badwords filtering.
    counter_inc_fn: function, a function taking the name of a counter to be
      incremented and the (optional) amount. Defaults to a beam Metric counter.
    min_words_per_line: int, the minimum number of words a line needs to not be
      removed.
    min_num_sentences: int, the minimum number of sentences a page needs to not
      be skipped.
    max_word_length: int, the maximum number of characters allowed in a word.
      Lines containing a word with too many characters are removed.

  Yields:
    The url and a copy of the page features whose "text" is the cleaned text.
    The input features dict is left unmodified.
  """
  url, features = url_and_features
  text = features["text"]
  if not counter_inc_fn:
    counter_inc_fn = get_counter_inc_fn("clean-page")

  def line_has_too_long_word(line):
    return any(len(word) > max_word_length for word in line.split())

  valid_lines = []
  num_sentences = 0
  for line in text.splitlines():
    line = line.strip()
    if line_has_too_long_word(line):
      counter_inc_fn("lines-with-too-long-word")
      continue
    line = citation_regex.sub("", line)
    if not line.endswith(_END_MARKS) or line.endswith(_ELLIPSIS):
      counter_inc_fn("lines-no-endmark")
      continue
    if len(line.split()) < min_words_per_line:
      counter_inc_fn("lines-too-short")
      continue
    line_lower = line.lower()
    # Remove documents which contain lorem ipsum.
    if "lorem ipsum" in line_lower:
      counter_inc_fn("filtered-page-loremipsum")
      return
    # Remove "javascript must be enabled" notices.
    if "javascript" in line_lower:
      counter_inc_fn("lines-javascript")
      continue
    # Remove docs which probably contain javascript code.
    if "{" in line:
      counter_inc_fn("filtered-page-squigglybracket")
      return
    # Remove policy lines.
    if any(p in line_lower for p in _POLICY_SUBSTRINGS):
      counter_inc_fn("lines-policy")
      continue
    # If any badword appears on its own in the line, skip this doc.
    if badwords_regex and badwords_regex.search(line_lower) is not None:
      counter_inc_fn("filtered-page-badword")
      return
    num_sentences += len(_get_sentences(line))
    valid_lines.append(line)
    counter_inc_fn("lines-valid")

  if num_sentences < min_num_sentences:
    counter_inc_fn("filtered-page-toofewsentences")
    return
  counter_inc_fn("emitted-clean-pages")
  # Emit a copy instead of mutating the input: Beam transforms must not modify
  # their input elements, which may be shared with other consumers.
  cleaned_features = features.copy()
  cleaned_features["text"] = "\n".join(valid_lines).strip()
  yield url, cleaned_features
def _hash_text(text):
  """Returns the hex MD5 digest of `text`, converted to text and UTF-8 encoded."""
  encoded = tf.compat.as_text(text).encode("utf-8")
  digest = hashlib.md5(encoded)
  return digest.hexdigest()
def _emit_url_to_lines(page):
  """Yields `(hashed_line, url)` for every newline-separated line of the page.

  Each line is stripped and lower-cased before hashing, so lines that differ
  only in case or surrounding whitespace share a hash.
  """
  page_url, page_features = page
  for raw_line in page_features["text"].split("\n"):
    normalized = raw_line.strip().lower()
    yield _hash_text(normalized), page_url
def _remove_lines_from_text(el, counter_inc_fn, min_num_sentences):
  """Keeps only the page lines whose hash was selected for this URL.

  Processes the result of a join holding exactly one value under 'features'
  and zero or more values under 'lines'. Each value in 'lines' is the hash of a
  lower-cased, stripped line that this URL was chosen to keep. Lines whose hash
  is not selected are dropped, as are repeats of an already-kept line.

  Args:
    el: `(string, {'features': [features_dict], 'lines': [string]})`, element
      containing the result of a join on URL with both the page features and
      the lower-cased, hashed lines to keep.
    counter_inc_fn: function, a function taking the name of a counter to be
      incremented and the (optional) amount.
    min_num_sentences: int, the minimum number of sentences a page needs to not
      be skipped. A falsy value disables the check.

  Yields:
    url: The URL of the page.
    features: A copy of the page features with the filtered text.
  """
  url, join_values = el
  page_features = join_values["features"]
  assert len(page_features) == 1, "Invalid page count (%d) for %s" % (
      len(page_features), url)
  page = page_features[0]

  selected_hashes = set(join_values["lines"])
  seen_hashes = set()
  kept_lines = []
  for raw_line in page["text"].split("\n"):
    line_hash = _hash_text(raw_line.strip().lower())
    if line_hash not in selected_hashes:
      # Another URL was selected as the owner of this line.
      counter_inc_fn("filtered-lines-global_duplicate")
      continue
    if line_hash in seen_hashes:
      # The same line already appeared earlier on this page.
      counter_inc_fn("filtered-lines-local_duplicate")
      continue
    seen_hashes.add(line_hash)
    kept_lines.append(raw_line)

  deduped_text = "\n".join(kept_lines)
  if not deduped_text:
    counter_inc_fn("filtered-deduped_page-empty")
    return
  if min_num_sentences and len(
      _get_sentences(deduped_text)) < min_num_sentences:
    counter_inc_fn("filtered-deduped_page-toofewsentences")
    return

  deduped_features = page.copy()
  deduped_features["text"] = deduped_text
  yield (url, deduped_features)
def remove_duplicate_text(pages, min_num_sentences=_MIN_NUM_SENTENCES):
  """Utility to remove duplicate lines across text documents.

  Every (lower-cased, stripped, hashed) line is assigned to a single URL: the
  one with the smallest `_hash_text` value, so the choice is not biased by
  domain. Each page then keeps only the lines assigned to it.

  Args:
    pages: PCollection of `(url, features)` pairs.
    min_num_sentences: int, minimum number of sentences a deduplicated page
      needs to be kept.

  Returns:
    PCollection of `(url, features)` with duplicate lines removed.
  """
  beam = tfds.core.lazy_imports.apache_beam

  # (hashed line, [selected url]). Hashing the URLs before comparison avoids
  # biasing the selection by domain; `reverse=True` keeps the smallest hash.
  selected_url_per_line = (
      pages
      | beam.FlatMap(_emit_url_to_lines)
      | beam.combiners.Top.PerKey(1, key=_hash_text, reverse=True))

  # (url, hashed line to keep)
  kept_line_per_url = selected_url_per_line | beam.Map(
      lambda x: (x[1][0], x[0]))

  # url -> {"features": [features], "lines": [hashed lines]}
  grouped = {
      "features": pages,
      "lines": kept_line_per_url,
  } | "group_features_and_lines_by_url" >> beam.CoGroupByKey()

  return grouped | beam.FlatMap(
      _remove_lines_from_text,
      counter_inc_fn=get_counter_inc_fn("dedupe-lines"),
      min_num_sentences=min_num_sentences)
def split_wet_file(wet_file_path, counter_inc_fn=None):
  """Split a WET file into separate pages.

  The file is read through `tf.io.gfile.GFile`, gunzipped, and decoded as
  UTF-8. Records are separated by `_PAGE_DELIMITER` lines. Header lines fill in
  the url, timestamp, content type, and content length. Every other non-blank
  line that does not start with a metadata prefix becomes page content.

  Args:
    wet_file_path: string, path to the gzip-compressed WET file.
    counter_inc_fn: function, a function taking the name of a counter to be
      incremented and the (optional) amount. Defaults to a beam Metric counter.

  Yields:
    `(url, features)` for each record with both a url and some content, where
    `features` has keys "text", "content-type", "content-length", "timestamp",
    and "url".
  """
  logging.info("Splitting file: %s", wet_file_path)
  if not counter_inc_fn:
    counter_inc_fn = get_counter_inc_fn("split-wet-file")
  counter_inc_fn("wet-file")

  # Header-line prefix -> record field filled with the rest of that line.
  header_fields = (
      (_URL_KEY, "url"),
      (_URL_DATE, "timestamp"),
      (_CONTENT_TYPE, "content_type"),
      (_CONTENT_LEN, "content_len"),
  )
  # Record field -> counter bumped when the field was reset but left empty.
  empty_field_counters = (
      ("url", "page-filtered-nourl"),
      ("content", "page-filtered-nocontent"),
      ("content_type", "page-nocontenttype"),
      ("content_len", "page-nocontentlen"),
      ("timestamp", "page-notimestamp"),
  )

  def _record_to_page(record):
    """Returns a (url, {features}) page for `record`, or None to skip it."""
    for field, counter in empty_field_counters:
      value = record[field]
      # None means the record was never started (before the first delimiter);
      # only a started-but-empty field is counted.
      if value is not None and not value:
        counter_inc_fn(counter)
    if not (record["content"] and record["url"]):
      return None
    counter_inc_fn("page-emitted")
    return (record["url"], {
        "text": "\n".join(record["content"]),
        "content-type": record["content_type"],
        "content-length": record["content_len"],
        "timestamp": record["timestamp"],
        "url": record["url"],
    })

  with tf.io.gfile.GFile(wet_file_path, "rb") as f, gzip.GzipFile(
      fileobj=f) as g:
    record = dict.fromkeys(
        ("url", "content", "content_len", "content_type", "timestamp"))
    for raw_line in io.TextIOWrapper(g, encoding="utf-8"):  # pytype: disable=wrong-arg-types
      line = raw_line.strip()
      if not line:
        continue
      if line == _PAGE_DELIMITER:
        # A new record begins: flush the previous one and reset the fields.
        page = _record_to_page(record)
        if page:
          yield page
        record = {
            "url": "",
            "content": [],
            "content_len": "",
            "content_type": "",
            "timestamp": "",
        }
      for prefix, field in header_fields:
        if line.startswith(prefix):
          record[field] = line[len(prefix):].strip()
      if line.startswith(_METADATA_PREFIXES):
        continue
      record["content"].append(line)  # pytype: disable=attribute-error
    # Flush the last record in the file.
    page = _record_to_page(record)
    if page:
      yield page
def dedupe_urls(el, counter_inc_fn=None):
  """Returns the first value for a given URL.

  The remaining values are counted as duplicates and dropped. Note that the
  order of grouped values is not guaranteed, so "first" is an arbitrary but
  single representative.

  Args:
    el: `(url, vals)` pair, where `vals` is an iterable of values for `url`.
    counter_inc_fn: function, a function taking the name of a counter to be
      incremented and the (optional) amount. Defaults to a beam Metric counter.

  Returns:
    `(url, first_value)`, or `(url, None)` if `vals` is empty.
  """
  if not counter_inc_fn:
    counter_inc_fn = get_counter_inc_fn("dedupe-urls")
  url, vals = el
  vals_iter = iter(vals)
  first = next(vals_iter, None)
  # Count the remaining values; this is 0 (never negative) for 0 or 1 values.
  num_duplicates = sum(1 for _ in vals_iter)
  counter_inc_fn("filtered-url-duplicate", num_duplicates)
  counter_inc_fn("unique-url")
  return url, first
def is_valid_length(el, max_length=1.9e5, counter_inc_fn=None):
  """Returns False iff page's text is too long.

  Args:
    el: `(url, features)` pair; only `features["text"]` is inspected.
    max_length: number, the maximum allowed `len` of the text.
    counter_inc_fn: function, a function taking the name of a counter to be
      incremented and the (optional) amount. Defaults to a beam Metric counter.

  Returns:
    True if `len(features["text"]) <= max_length`, else False.
  """
  if not counter_inc_fn:
    counter_inc_fn = get_counter_inc_fn("is-valid-length")
  _, page = el
  if len(page["text"]) > max_length:
    counter_inc_fn("filtered-page-contenttoolong")
    return False
  counter_inc_fn("valid-length")
  return True
def is_realnews_domain(el, realnews_domains):
  """Returns False iff page's (sub)domain is not allowed.

  Args:
    el: `(url, features)` pair; only the url is inspected.
    realnews_domains: mapping from "domain.suffix" to its allowed subdomains.
      When the value is a list, the URL's subdomain must be in it; any other
      value allows every subdomain.

  Returns:
    Whether the URL's registered domain (and, if restricted, its subdomain) is
    allowed.
  """
  counter_inc_fn = get_counter_inc_fn("is-realnews-domain")
  url, _ = el
  parts = tfds.core.lazy_imports.tldextract.extract(url)
  registered_domain = ".".join((parts.domain, parts.suffix))

  if registered_domain not in realnews_domains:
    counter_inc_fn("filtered-url-invaliddomain")
    return False

  allowed = realnews_domains[registered_domain]
  subdomain_ok = not isinstance(allowed, list) or parts.subdomain in allowed
  if not subdomain_ok:
    counter_inc_fn("filtered-url-invalidsubdomain")
    return False

  counter_inc_fn("realnews-domain")
  return True
def filter_by_webtextlike(el, counter_inc_fn=None):
  """Yields only pages with a matching WebText-like URL.

  Args:
    el: `(url, {"text": [features], "webtextlike_urls": [...]})`, presumably
      the result of a CoGroupByKey on URL (TODO confirm against the caller).
    counter_inc_fn: function, a function taking the name of a counter to be
      incremented and the (optional) amount. Defaults to a beam Metric counter.

  Yields:
    `(url, features)` when the URL is WebText-like and has exactly one page.
  """
  if not counter_inc_fn:
    counter_inc_fn = get_counter_inc_fn("filter-by-webtextlike")
  url, join_values = el
  text = join_values["text"]
  webtextlike = join_values["webtextlike_urls"]
  if not webtextlike:
    counter_inc_fn("filtered-url-notwebtextlike")
    return
  if not text:
    counter_inc_fn("missing-webtextlike")
    return
  # URLs are expected to be deduplicated upstream, leaving a single page.
  assert len(text) == 1
  counter_inc_fn("found-webtextlike")
  yield url, text[0]
def normalize_url(el):
  """Normalizes the URL key of a `(url, value)` pair.

  Removes every "http://" or "https://" (with an optional following "www.")
  anywhere in the URL. Drops everything from the first "?utm_", "?ref", or
  "?feed" onward, then strips trailing slashes.

  Returns:
    `(normalized_url, value)` with `value` passed through unchanged.
  """
  url, val = el
  normalized = tf.compat.as_text(url)
  # Scheme and leading "www." (every occurrence, not only the prefix).
  normalized = re.sub(r"https?:\/\/(www\.)?", "", normalized)
  # Tracking / referral / feed query strings.
  normalized = re.sub(r"\?(utm_|ref|feed).*", "", normalized)
  return normalized.rstrip("/"), val