A complete example (contact extraction). See GH-24.
kmike committed Aug 4, 2017
1 parent 5a3f39e commit 8fb60d3
Showing 9 changed files with 424 additions and 1 deletion.
5 changes: 4 additions & 1 deletion .gitignore
@@ -51,4 +51,7 @@ notebooks/*.marisa
notebooks/*.wapiti
notebooks/*.crfsuite
webstruct_data/corpus/random_pages/wa/*.html
webstruct_data/corpus/us_contact_pages/cleaned
example/_data/*
example/*.joblib
example/*.html
44 changes: 44 additions & 0 deletions example/README.rst
@@ -0,0 +1,44 @@
Contact extraction using Webstruct
==================================

This repository contains code to train a model for contact and address
extraction. The result is a ``.joblib`` file with a pickled ``webstruct.NER`` object.

Currently the example requires Python 3.5+.

Training
--------

To train a model, first build gazetteers using the ``build_gazetteers`` script::

    python3 -m ner.build_gazetteers

It will create a ``_data`` folder with city/state geonames data. The script
uses several GBs of RAM.

To train a CRF model run::

    python3 -m ner.train

The model uses training data from the open-source webstruct package
(mostly contact pages of US, CA and GB small business websites)
and provides 'ORG', 'TEL', 'FAX', 'HOURS', 'STREET', 'CITY', 'STATE',
'ZIPCODE', 'COUNTRY', and 'SUBJ' entities.

The script should produce a "contact-extractor.joblib" file with a saved
webstruct.NER object and a "crf-features.html" file with debugging
information about the model.

Usage
-----

To use the saved model, the code in this repository is not needed.
Make sure joblib, sklearn-crfsuite and webstruct are installed,
then load the model::

    import joblib
    ner = joblib.load('contact-extractor.joblib')
    print(ner.extract_groups_from_url('<some URL>'))

See https://webstruct.readthedocs.io/en/latest/ref/model.html#webstruct.model.NER
for the API.
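
For a slightly fuller sketch: ``extract_groups_from_url`` returns groups of
related entities; per the API docs linked above, each group is a list of
``(entity_text, entity_type)`` pairs. The URL below is only a placeholder::

    import joblib

    ner = joblib.load('contact-extractor.joblib')
    groups = ner.extract_groups_from_url('http://example.com/contact')
    for group in groups:
        print('; '.join('{} [{}]'.format(text, label) for text, label in group))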
1 change: 1 addition & 0 deletions example/ner/__init__.py
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
88 changes: 88 additions & 0 deletions example/ner/build_gazetteers.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
import argparse
from pathlib import Path

import requests

from webstruct.gazetteers.geonames import read_geonames_zipped, to_dawg


FILES = [
    ('http://download.geonames.org/export/dump/allCountries.zip', 'allCountries.zip'),
    ('http://download.geonames.org/export/dump/cities1000.zip', 'cities1000.zip'),
    ('http://download.geonames.org/export/dump/cities5000.zip', 'cities5000.zip'),
    ('http://download.geonames.org/export/dump/cities15000.zip', 'cities15000.zip'),
]

DATA_ROOT = Path('_data')


def download_geonames():
    """ Download geonames files if they don't exist in ./_data folder. """
    DATA_ROOT.mkdir(exist_ok=True)
    for url, name in FILES:
        path = (DATA_ROOT / name)
        if path.exists():
            continue
        print("downloading {}".format(url))
        path.write_bytes(requests.get(url).content)


def _compile_cities(path: Path, lowercase: bool=False):
    out_path = path.with_suffix('.dafsa')
    # if out_path.exists():
    #     return
    print("reading {}".format(path))
    df = read_geonames_zipped(str(path))
    if lowercase:
        df = _to_lower(df)
    print("compiling {}".format(out_path))
    dawg = to_dawg(df)
    dawg.save(str(out_path))


def _to_lower(df):
    return df.assign(
        main_name=df.main_name.str.lower(),
        asciiname=df.asciiname.str.lower(),
        alternatenames=df.alternatenames.str.lower(),
    )


def _read_full():
    path = DATA_ROOT / 'allCountries.zip'
    print("reading {}".format(path))
    return read_geonames_zipped(str(path))


def _compile_adm(df):
    codes = ['ADM1', 'ADM2', 'ADM3', 'ADM4']
    out_paths = [DATA_ROOT / "{}.dafsa".format(code.lower()) for code in codes]
    # if all(p.exists() for p in out_paths):
    #     return
    for code, out_path in zip(codes, out_paths):
        # if out_path.exists():
        #     continue
        print("compiling {}".format(out_path))
        df_adm = df[df.feature_code == code]
        dawg = to_dawg(df_adm)
        dawg.save(str(out_path))


def compile_gazetteers_contacts(lowercase=False):
    """ Compile geonames data downloaded by ``download_geonames``. """
    for name in ['cities1000.zip', 'cities5000.zip', 'cities15000.zip']:
        _compile_cities(DATA_ROOT / name, lowercase=lowercase)
    df = _read_full()
    if lowercase:
        df = _to_lower(df)
    _compile_adm(df)


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('--lower', action="store_true")
    args = p.parse_args()

    download_geonames()
    compile_gazetteers_contacts(args.lower)
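
A usage note, inferred from the argparse definition above rather than from
separate docs: passing ``--lower`` lowercases the geonames names before the
gazetteers are compiled, which allows case-insensitive lookups later on::

    python3 -m ner.build_gazetteers --lower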
26 changes: 26 additions & 0 deletions example/ner/cv.py
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
import numpy as np


def crf_cross_val_predict(pipe, X, y, cv, groups=None, n_folds=None):
    """
    Split data into folds according to cv iterator, do train/test prediction
    on first n_folds (or on all folds if n_folds is None).
    """
    X, y = np.array(X), np.array(y)
    y_pred = []
    y_true = []

    for idx, (train_idx, test_idx) in enumerate(cv.split(X, y, groups)):
        if n_folds and idx >= n_folds:
            break

        X_train, X_dev = X[train_idx], X[test_idx]
        y_train, y_dev = y[train_idx], y[test_idx]
        pipe.fit(X_train, y_train, X_dev=X_dev, y_dev=y_dev)
        y_true.append(y_dev)
        y_pred.append(pipe.predict(X_dev))

    y_pred = np.hstack(y_pred)
    y_true = np.hstack(y_true)
    return y_pred, y_true
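
A hypothetical usage sketch (``pipe``, ``X``, ``y`` and ``groups`` are
placeholders coming from the training code; it assumes a pipeline whose
``fit`` accepts the ``X_dev``/``y_dev`` keywords used above, and a
scikit-learn ``GroupKFold`` so that, for example, pages from one site stay in
the same fold)::

    from sklearn.model_selection import GroupKFold
    from sklearn_crfsuite import metrics

    from ner.cv import crf_cross_val_predict

    # evaluate on the first 2 of 5 group-aware folds
    cv = GroupKFold(n_splits=5)
    y_pred, y_true = crf_cross_val_predict(pipe, X, y, cv=cv,
                                           groups=groups, n_folds=2)
    print(metrics.flat_classification_report(y_true, y_pred, digits=3))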
56 changes: 56 additions & 0 deletions example/ner/data.py
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from itertools import chain
from pathlib import Path
from typing import List, Tuple, Any, Set

import webstruct

from .utils import pages_progress


WEBSTRUCT_DATA = Path(__name__).parent / ".." / "webstruct_data"
GAZETTEER_DATA = Path(__name__).parent / "_data"


KNOWN_ENTITIES = [
    'ORG', 'TEL', 'FAX', 'HOURS',
    'STREET', 'CITY', 'STATE', 'ZIPCODE', 'COUNTRY',
    'EMAIL', 'PER', 'FUNC', 'SUBJ'
]
CONTACT_ENTITIES = [
    'ORG', 'TEL', 'FAX', 'HOURS',
    'STREET', 'CITY', 'STATE', 'ZIPCODE', 'COUNTRY',
    'SUBJ',
]
ADDRESS_ENTITIES = [
    'STREET', 'CITY', 'STATE', 'ZIPCODE', 'COUNTRY',
]


def load_webstruct_data() -> List:
    """
    Load training data from webstruct repository.
    It is a mess: there are two folders which have OK data, one
    is stored in WebAnnotator format, another is stored in GATE format.
    """
    wa_loader = webstruct.WebAnnotatorLoader(known_entities=KNOWN_ENTITIES)
    gate_loader = webstruct.GateLoader(known_entities=KNOWN_ENTITIES)

    trees1 = webstruct.load_trees(
        str(WEBSTRUCT_DATA / "corpus/business_pages/wa/*.html"),
        loader=wa_loader,
    )

    trees2 = webstruct.load_trees(
        str(WEBSTRUCT_DATA / "corpus/us_contact_pages/annotated/*.xml"),
        loader=gate_loader
    )
    trees = chain(trees1, trees2)
    return list(pages_progress(trees, desc="Loading webstruct default annotated data"))


def load_countries() -> Set[str]:
    countries_path = WEBSTRUCT_DATA / 'gazetteers/countries/countries.txt'
    return set(countries_path.read_text().splitlines())
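
A rough illustration of how these loaders might be combined (the actual
training entry point, ``ner.train``, is part of this commit but not shown
here)::

    from ner.data import load_webstruct_data, load_countries, CONTACT_ENTITIES

    trees = load_webstruct_data()   # annotated lxml trees from both corpora
    countries = load_countries()    # country names, e.g. for gazetteer features
    print(len(trees), "annotated pages;", len(countries), "known countries")
    print("target entities:", CONTACT_ENTITIES)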
