diff --git a/deep_reference_parser/__init__.py b/deep_reference_parser/__init__.py
index c18d70b..18a418c 100644
--- a/deep_reference_parser/__init__.py
+++ b/deep_reference_parser/__init__.py
@@ -2,9 +2,9 @@
 # distracting on the command line. These lines here (while undesireable)
 # reduce the level of verbosity.
 
+import os
 import sys
 import warnings
-import os
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 
@@ -19,21 +19,16 @@
 from .common import download_model_artefact
 from .deep_reference_parser import DeepReferenceParser
-from .logger import logger
-from .model_utils import get_config
-from .reference_utils import (
-    break_into_chunks,
-    labels_to_prodigy,
-    load_data,
+from .io import (
     load_tsv,
-    prodigy_to_conll,
-    prodigy_to_lists,
     read_jsonl,
     read_pickle,
-    write_json,
     write_jsonl,
     write_pickle,
     write_to_csv,
-    write_txt,
+    write_tsv,
 )
+from .logger import logger
+from .model_utils import get_config
+from .reference_utils import break_into_chunks
 from .tokens_to_references import tokens_to_references
diff --git a/deep_reference_parser/__version__.py b/deep_reference_parser/__version__.py
index 0a989eb..8627b22 100644
--- a/deep_reference_parser/__version__.py
+++ b/deep_reference_parser/__version__.py
@@ -1,5 +1,5 @@
 __name__ = "deep_reference_parser"
-__version__ = "2020.3.0"
+__version__ = "2020.3.1"
 __description__ = "Deep learning model for finding and parsing references"
 __url__ = "https://github.com/wellcometrust/deep_reference_parser"
 __author__ = "Wellcome Trust DataLabs Team"
diff --git a/deep_reference_parser/deep_reference_parser.py b/deep_reference_parser/deep_reference_parser.py
index eeaa0c0..b9b28cd 100644
--- a/deep_reference_parser/deep_reference_parser.py
+++ b/deep_reference_parser/deep_reference_parser.py
@@ -47,7 +47,7 @@
     save_confusion_matrix,
     word2vec_embeddings,
 )
-from .reference_utils import load_tsv, read_pickle, write_pickle, write_to_csv
+from .io import load_tsv, read_pickle, write_pickle, write_to_csv
 
 
 class DeepReferenceParser:
@@ -456,7 +456,7 @@ def build_model(
 
         self.model = model
 
-        logger.debug(self.model.summary(line_length=150))
+# logger.debug(self.model.summary(line_length=150))
 
     def train_model(
         self, epochs=25, batch_size=100, early_stopping_patience=5, metric="val_f1"
diff --git a/deep_reference_parser/io/__init__.py b/deep_reference_parser/io/__init__.py
index 613c7e6..4e7eaba 100644
--- a/deep_reference_parser/io/__init__.py
+++ b/deep_reference_parser/io/__init__.py
@@ -1 +1,2 @@
-from .io import read_jsonl, write_jsonl
+from .io import (load_tsv, read_jsonl, read_pickle, write_jsonl, write_pickle,
+                 write_to_csv, write_tsv)
diff --git a/deep_reference_parser/io/io.py b/deep_reference_parser/io/io.py
index afa2cd4..639b43d 100644
--- a/deep_reference_parser/io/io.py
+++ b/deep_reference_parser/io/io.py
@@ -6,9 +6,74 @@
 """
 
 import json
+import pickle
+import csv
+import os
+import pandas as pd
 
 from ..logger import logger
 
+def _split_list_by_linebreaks(tokens):
+    """Cycle through a list of tokens (or labels) and split them into lists
+    based on the presence of Nones or more likely math.nan caused by converting
+    pd.DataFrame columns to lists.
+    """
+    out = []
+    tokens_gen = iter(tokens)
+    while True:
+        try:
+            token = next(tokens_gen)
+            if isinstance(token, str) and token:
+                out.append(token)
+            else:
+                yield out
+                out = []
+        except StopIteration:
+            if out:
+                yield out
+            break
+
+def load_tsv(filepath, split_char="\t"):
+    """
+    Load and return the data stored in the given path.
+
+    Expects data in the following format (tab separations).
+
+    References   o       o
+                 o       o
+    1            o       o
+    .            o       o
+                 o       o
+    WHO          title   b-r
+    treatment    title   i-r
+    guidelines   title   i-r
+    for          title   i-r
+    drug         title   i-r
+    -            title   i-r
+    resistant    title   i-r
+    tuberculosis title   i-r
+    ,            title   i-r
+    2016         title   i-r
+
+
+
+    Args:
+        filepath (str): Path to the data.
+        split_char(str): Character to be used to split each line of the
+            document.
+
+    Returns:
+        a series of lists depending on the number of label columns provided in
+        filepath.
+
+    """
+
+    df = pd.read_csv(filepath, delimiter=split_char, header=None, skip_blank_lines=False)
+    out = [list(_split_list_by_linebreaks(column)) for _, column in df.iteritems()]
+
+    logger.info("Loaded %s training examples", len(out[0]))
+
+    return tuple(out)
 
 def write_jsonl(input_data, output_file):
     """
@@ -61,3 +126,71 @@ def read_jsonl(input_file):
     logger.debug("Read %s lines from %s", len(out), input_file)
 
     return out
+
+
+def write_to_csv(filename, columns, rows):
+    """
+    Create a .csv file from data given as columns and rows
+
+    Args:
+        filename(str): Path and name of the .csv file, without csv extension
+        columns(list): Columns of the csv file (First row of the file)
+        rows: Data to write into the csv file, given per row
+    """
+
+    with open(filename, "w") as csvfile:
+        wr = csv.writer(csvfile, quoting=csv.QUOTE_ALL)
+        wr.writerow(columns)
+
+        for i, row in enumerate(rows):
+            wr.writerow(row)
+    logger.info("Wrote results to %s", filename)
+
+
+def write_pickle(input_data, output_file, path=None):
+    """
+    Write an object to pickle
+
+    Args:
+        input_data(dict): A dict to be written to json.
+        output_file(str): A filename or path to which the json will be saved.
+        path(str): A string which will be prepended onto `output_file` with
+            `os.path.join()`. Obviates the need for lengthy `os.path.join`
+            statements each time this function is called.
+    """
+
+    if path:
+
+        output_file = os.path.join(path, output_file)
+
+    with open(output_file, "wb") as fb:
+        pickle.dump(input_data, fb)
+
+
+def read_pickle(input_file, path=None):
+    """Create a list from a jsonl file
+
+    Args:
+        input_file(str): File to be loaded.
+        path(str): A string which will be prepended onto `input_file` with
+            `os.path.join()`. Obviates the need for lengthy `os.path.join`
+            statements each time this function is called.
+    """
+
+    if path:
+        input_file = os.path.join(path, input_file)
+
+    with open(input_file, "rb") as fb:
+        out = pickle.load(fb)
+
+    logger.debug("Read data from %s", input_file)
+
+    return out
+
+def write_tsv(token_label_pairs, output_path):
+    """
+    Write tsv files to disk
+    """
+    with open(output_path, "w") as fb:
+        writer = csv.writer(fb, delimiter="\t")
+        writer.writerows(token_label_pairs)
diff --git a/deep_reference_parser/prodigy/__init__.py b/deep_reference_parser/prodigy/__init__.py
index f90ce43..c582cc9 100644
--- a/deep_reference_parser/prodigy/__init__.py
+++ b/deep_reference_parser/prodigy/__init__.py
@@ -6,3 +6,5 @@
 from .reach_to_prodigy import ReachToProdigy, reach_to_prodigy
 from .reference_to_token_annotations import TokenTagger, reference_to_token_annotations
 from .spacy_doc_to_prodigy import SpacyDocToProdigy
+from .misc import prodigy_to_conll
+from .labels_to_prodigy import labels_to_prodigy
diff --git a/deep_reference_parser/prodigy/labels_to_prodigy.py b/deep_reference_parser/prodigy/labels_to_prodigy.py
new file mode 100644
index 0000000..b6107d4
--- /dev/null
+++ b/deep_reference_parser/prodigy/labels_to_prodigy.py
@@ -0,0 +1,57 @@
+def labels_to_prodigy(tokens, labels):
+    """
+    Converts a list of tokens and labels like those used by Rodrigues et al,
+    and converts to prodigy format dicts.
+
+    Args:
+        tokens (list): A list of tokens.
+        labels (list): A list of labels relating to `tokens`.
+
+    Returns:
+        A list of prodigy format dicts containing annotated data.
+    """
+
+    prodigy_data = []
+
+    all_token_index = 0
+
+    for line_index, line in enumerate(tokens):
+        prodigy_example = {}
+
+        tokens = []
+        spans = []
+        token_start_offset = 0
+
+        for token_index, token in enumerate(line):
+
+            token_end_offset = token_start_offset + len(token)
+
+            tokens.append(
+                {
+                    "text": token,
+                    "id": token_index,
+                    "start": token_start_offset,
+                    "end": token_end_offset,
+                }
+            )
+
+            spans.append(
+                {
+                    "label": labels[line_index][token_index : token_index + 1][0],
+                    "start": token_start_offset,
+                    "end": token_end_offset,
+                    "token_start": token_index,
+                    "token_end": token_index,
+                }
+            )
+
+            prodigy_example["text"] = " ".join(line)
+            prodigy_example["tokens"] = tokens
+            prodigy_example["spans"] = spans
+            prodigy_example["meta"] = {"line": line_index}
+
+            token_start_offset = token_end_offset + 1
+
+        prodigy_data.append(prodigy_example)
+
+    return prodigy_data
diff --git a/deep_reference_parser/prodigy/misc.py b/deep_reference_parser/prodigy/misc.py
new file mode 100644
index 0000000..c1f8d5c
--- /dev/null
+++ b/deep_reference_parser/prodigy/misc.py
@@ -0,0 +1,38 @@
+import spacy
+
+
+def _join_prodigy_tokens(text):
+    """Return all prodigy tokens in a single string
+    """
+
+    return "\n".join([str(i) for i in text])
+
+
+def prodigy_to_conll(docs):
+    """
+    Expect list of jsons loaded from a jsonl
+    """
+
+    nlp = spacy.load("en_core_web_sm")
+    texts = [doc["text"] for doc in docs]
+    docs = list(nlp.tokenizer.pipe(texts))
+
+    out = [_join_prodigy_tokens(i) for i in docs]
+
+    out_str = "DOCSTART\n\n" + "\n\n".join(out)
+
+    return out_str
+
+
+def prodigy_to_lists(docs):
+    """
+    Expect list of jsons loaded from a jsonl
+    """
+
+    nlp = spacy.load("en_core_web_sm")
+    texts = [doc["text"] for doc in docs]
+    docs = list(nlp.tokenizer.pipe(texts))
+
+    out = [[str(token) for token in doc] for doc in docs]
+
+    return out
diff --git a/deep_reference_parser/prodigy/prodigy_to_tsv.py b/deep_reference_parser/prodigy/prodigy_to_tsv.py
index 50488fa..41a8716 100644
--- a/deep_reference_parser/prodigy/prodigy_to_tsv.py
+++ b/deep_reference_parser/prodigy/prodigy_to_tsv.py
@@ -20,7 +20,7 @@
 
 msg = Printer()
 
-ROWS_TO_PRINT=15
+ROWS_TO_PRINT = 15
 
 
 class TokenLabelPairs:
@@ -375,8 +375,6 @@ def prodigy_to_tsv(
     with open(output_file, "w") as fb:
         writer = csv.writer(fb, delimiter="\t")
 
-        # Write DOCSTART and a blank line
-        # writer.writerows([("DOCSTART", None), (None, None)])
         writer.writerows(merged_pairs)
 
     # Print out the first ten rows as a sense check
diff --git a/deep_reference_parser/reference_utils.py b/deep_reference_parser/reference_utils.py
index 3a2fcd5..fc8e8ab 100644
--- a/deep_reference_parser/reference_utils.py
+++ b/deep_reference_parser/reference_utils.py
@@ -4,426 +4,9 @@
 """
 """
 
-import csv
-import json
-import os
-import pickle
-
-import spacy
-
 from .logger import logger
 
 
-def load_data(filepath):
-    """
-    Load and return the data stored in the given path.
-
-    Adapted from: https://github.com/dhlab-epfl/LinkedBooksDeepReferenceParsing
-
-    The data is structured as follows:
-    * Each line contains four columns separated by a single space.
-    * Each word has been put on a separate line and there is an empty line
-      after each sentence.
-    * The first item on each line is a word, the second, third and fourth are
-      tags related to the word.
-
-    Example:
-
-    The sentence "L. Antonielli, Iprefetti dell' Italia napoleonica, Bologna
-    1983." is represented in the dataset as:
-
-    ```
-    L author b-secondary b-r
-    . author i-secondary i-r
-    Antonielli author i-secondary i-r
-    , author i-secondary i-r
-    Iprefetti title i-secondary i-r
-    dell title i-secondary i-r
-    ’ title i-secondary i-r
-    Italia title i-secondary i-r
-    napoleonica title i-secondary i-r
-    , title i-secondary i-r
-    Bologna publicationplace i-secondary i-r
-    1983 year e-secondary i-r
-    . year e-secondary e-r
-    ```
-
-    Args:
-        filepath (str): Path to the data.
-
-    Returns:
-        four lists: The first contains tokens, the next three contain
-        corresponding labels.
-
-    """
-
-    # Arrays to return
-    words = []
-    tags_1 = []
-    tags_2 = []
-    tags_3 = []
-
-    word = tags1 = tags2 = tags3 = []
-    with open(filepath, "r") as file:
-        for line in file:
-            # Do not take the first line into consideration
-
-            if "DOCSTART" not in line:
-                # Check if empty line
-
-                if line in ["\n", "\r\n"]:
-                    # Append line
-
-                    words.append(word)
-                    tags_1.append(tags1)
-                    tags_2.append(tags2)
-                    tags_3.append(tags3)
-
-                    # Reset
-                    word = []
-                    tags1 = []
-                    tags2 = []
-                    tags3 = []
-
-                else:
-                    # Split the line into words, tag #1
-                    w = line[:-1].split(" ")
-
-                    word.append(w[0])
-                    tags1.append(w[1])
-                    tags2.append(w[2])
-                    tags3.append(w[3])
-
-    logger.info("Loaded %s training examples", len(words))
-
-    return words, tags_1, tags_2, tags_3
-
-
-def load_tsv(filepath, split_char="\t"):
-    """
-    Load and return the data stored in the given path.
-
-    Adapted from: https://github.com/dhlab-epfl/LinkedBooksDeepReferenceParsing
-
-    NOTE: In the current implementation in deep_reference_parser, only one set
-    of tags is used. The others will be used in a later PR.
-
-    The data is structured as follows:
-    * Each line contains four columns separated by a single space.
-    * Each word has been put on a separate line and there is an empty line
-      after each sentence.
-    * The first item on each line is a word, the second, third and fourth are
-      tags related to the word.
-
-    Args:
-        filepath (str): Path to the data.
-        split_char(str): Character to be used to split each line of the
-            document.
-
-    Returns:
-        two lists: The first contains tokens, the second contains corresponding
-        labels.
-
-    """
-
-    # Arrays to return
-    words = []
-    tags_1 = []
-
-    word = []
-    tags1 = []
-
-    with open(filepath, "r") as file:
-        for line in file:
-            # Check if empty line
-
-            if line in ["\n", "\r\n", "\t\n"]:
-                # Append line
-
-                words.append(word)
-                tags_1.append(tags1)
-
-                # Reset
-                word = []
-                tags1 = []
-
-            else:
-
-                # Split the line into words, tag #1
-
-                w = line[:-1].split(split_char)
-                word.append(w[0])
-
-                # If tags are passed, (for training) then also add
-
-                if len(w) == 2:
-
-                    tags1.append(w[1])
-
-    logger.info("Loaded %s training examples", len(words))
-
-    return words, tags_1
-
-
-def prodigy_to_conll(docs):
-    """
-    Expect list of jsons loaded from a jsonl
-    """
-
-    nlp = spacy.load("en_core_web_sm")
-    texts = [doc["text"] for doc in docs]
-    docs = list(nlp.tokenizer.pipe(texts))
-
-    out = [_join_prodigy_tokens(i) for i in docs]
-
-    out_str = "DOCSTART\n\n" + "\n\n".join(out)
-
-    return out_str
-
-
-def prodigy_to_lists(docs):
-    """
-    Expect list of jsons loaded from a jsonl
-    """
-
-    nlp = spacy.load("en_core_web_sm")
-    texts = [doc["text"] for doc in docs]
-    docs = list(nlp.tokenizer.pipe(texts))
-
-    out = [[str(token) for token in doc] for doc in docs]
-
-    return out
-
-
-def _join_prodigy_tokens(text):
-    """Return all prodigy tokens in a single string
-    """
-
-    return "\n".join([str(i) for i in text])
-
-
-def write_json(input_data, output_file, path=None):
-    """
-    Write a dict to json
-
-    Args:
-        input_data(dict): A dict to be written to json.
-        output_file(str): A filename or path to which the json will be saved.
-        path(str): A string which will be prepended onto `output_file` with
-            `os.path.join()`. Obviates the need for lengthy `os.path.join`
-            statements each time this function is called.
-    """
-
-    if path:
-
-        output_file = os.path.join(path, output_file)
-
-    logger.info("Writing data to %s", output_file)
-
-    with open(output_file, "w") as fb:
-        fb.write(json.dumps(input_data))
-
-
-def write_jsonl(input_data, output_file, path=None):
-    """
-    Write a dict to jsonl (line delimited json)
-
-    Output format will look like:
-
-    ```
-    {"a": 0}
-    {"b": 1}
-    {"c": 2}
-    {"d": 3}
-    ```
-
-    Args:
-        input_data(dict): A dict to be written to json.
-        output_file(str): A filename or path to which the json will be saved.
-        path(str): A string which will be prepended onto `output_file` with
-            `os.path.join()`. Obviates the need for lengthy `os.path.join`
-            statements each time this function is called.
-    """
-
-    if path:
-
-        output_file = os.path.join(path, output_file)
-
-    with open(output_file, "w") as fb:
-
-        # Check if a dict (and convert to list if so)
-
-        if isinstance(input_data, dict):
-            input_data = [value for key, value in input_data.items()]
-
-        # Write out to jsonl file
-
-        logger.info("Writing %s lines to %s", len(input_data), output_file)
-
-        for i in input_data:
-            json_ = json.dumps(i) + "\n"
-            fb.write(json_)
-
-
-def read_jsonl(input_file, path=None):
-    """Create a list from a jsonl file
-
-    Args:
-        input_file(str): File to be loaded.
-        path(str): A string which will be prepended onto `input_file` with
-            `os.path.join()`. Obviates the need for lengthy `os.path.join`
-            statements each time this function is called.
-    """
-
-    if path:
-        input_file = os.path.join(path, input_file)
-
-    out = []
-    with open(input_file, "r") as fb:
-
-        logger.info("Reading contents of %s", input_file)
-
-        for i in fb:
-            out.append(json.loads(i))
-
-    logger.info("Read %s lines from %s", len(out), input_file)
-
-    return out
-
-
-def write_txt(input_data, output_file):
-    """Write a text string to a file
-
-    Args:
-        input_file (str): String to be written
-        output_file (str): File to be saved to
-    """
-
-    with open(output_file, "w") as fb:
-        fb.write(input_data)
-
-    logger.info("Read %s characters to file: %s", len(input_data), output_file)
-
-
-def labels_to_prodigy(tokens, labels):
-    """
-    Converts a list of tokens and labels like those used by Rodrigues et al,
-    and converts to prodigy format dicts.
-
-    Args:
-        tokens (list): A list of tokens.
-        labels (list): A list of labels relating to `tokens`.
-
-    Returns:
-        A list of prodigy format dicts containing annotated data.
-    """
-
-    prodigy_data = []
-
-    all_token_index = 0
-
-    for line_index, line in enumerate(tokens):
-        prodigy_example = {}
-
-        tokens = []
-        spans = []
-        token_start_offset = 0
-
-        for token_index, token in enumerate(line):
-
-            token_end_offset = token_start_offset + len(token)
-
-            tokens.append(
-                {
-                    "text": token,
-                    "id": token_index,
-                    "start": token_start_offset,
-                    "end": token_end_offset,
-                }
-            )
-
-            spans.append(
-                {
-                    "label": labels[line_index][token_index : token_index + 1][0],
-                    "start": token_start_offset,
-                    "end": token_end_offset,
-                    "token_start": token_index,
-                    "token_end": token_index,
-                }
-            )
-
-            prodigy_example["text"] = " ".join(line)
-            prodigy_example["tokens"] = tokens
-            prodigy_example["spans"] = spans
-            prodigy_example["meta"] = {"line": line_index}
-
-            token_start_offset = token_end_offset + 1
-
-        prodigy_data.append(prodigy_example)
-
-    return prodigy_data
-
-
-def write_to_csv(filename, columns, rows):
-    """
-    Create a .csv file from data given as columns and rows
-
-    Args:
-        filename(str): Path and name of the .csv file, without csv extension
-        columns(list): Columns of the csv file (First row of the file)
-        rows: Data to write into the csv file, given per row
-    """
-
-    with open(filename, "w") as csvfile:
-        wr = csv.writer(csvfile, quoting=csv.QUOTE_ALL)
-        wr.writerow(columns)
-
-        for i, row in enumerate(rows):
-            wr.writerow(row)
-    logger.info("Wrote results to %s", filename)
-
-
-def write_pickle(input_data, output_file, path=None):
-    """
-    Write an object to pickle
-
-    Args:
-        input_data(dict): A dict to be written to json.
-        output_file(str): A filename or path to which the json will be saved.
-        path(str): A string which will be prepended onto `output_file` with
-            `os.path.join()`. Obviates the need for lengthy `os.path.join`
-            statements each time this function is called.
-    """
-
-    if path:
-
-        output_file = os.path.join(path, output_file)
-
-    with open(output_file, "wb") as fb:
-        pickle.dump(input_data, fb)
-
-
-def read_pickle(input_file, path=None):
-    """Create a list from a jsonl file
-
-    Args:
-        input_file(str): File to be loaded.
-        path(str): A string which will be prepended onto `input_file` with
-            `os.path.join()`. Obviates the need for lengthy `os.path.join`
-            statements each time this function is called.
-    """
-
-    if path:
-        input_file = os.path.join(path, input_file)
-
-    with open(input_file, "rb") as fb:
-        out = pickle.load(fb)
-
-    logger.debug("Read data from %s", input_file)
-
-    return out
-
-
 def yield_token_label_pairs(tokens, labels):
     """
     Convert matching lists of tokens and labels to tuples of (token, label) but
@@ -443,15 +26,6 @@ def yield_token_label_pairs(tokens, labels):
     yield (None, None)
 
 
-def write_tsv(token_label_pairs, output_path):
-    """
-    Write tsv files to disk
-    """
-    with open(output_path, "w") as fb:
-        writer = csv.writer(fb, delimiter="\t")
-        writer.writerows(token_label_pairs)
-
-
 def break_into_chunks(doc, max_words=250):
     """
     Breaks a list into lists of lists of length max_words
diff --git a/requirements.txt b/requirements.txt
index 7e84c93..ca18031 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -42,9 +42,9 @@ sklearn-crfsuite==0.3.6
 spacy==2.1.7
 srsly==1.0.1
 tabulate==0.8.6
-tensorboard==1.14.0
+tensorboard==1.15.0
 tensorflow==1.15.2
-tensorflow-estimator==1.14.0
+tensorflow-estimator==1.15.1
 termcolor==1.1.0
 thinc==7.0.8
 tqdm==4.42.1
diff --git a/tests/common.py b/tests/common.py
index 21cfb4e..2bf6107 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -13,3 +13,4 @@ def get_path(p):
 TEST_REFERENCES = get_path("test_data/test_references.txt")
 TEST_TSV_PREDICT = get_path("test_data/test_tsv_predict.tsv")
 TEST_TSV_TRAIN = get_path("test_data/test_tsv_train.tsv")
+TEST_LOAD_TSV = get_path("test_data/test_load_tsv.tsv")
diff --git a/tests/test_labels_to_prodigy.py b/tests/prodigy/test_labels_to_prodigy.py
similarity index 97%
rename from tests/test_labels_to_prodigy.py
rename to tests/prodigy/test_labels_to_prodigy.py
index 53b8d77..0ef67c8 100644
--- a/tests/test_labels_to_prodigy.py
+++ b/tests/prodigy/test_labels_to_prodigy.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # coding: utf-8
 
-from deep_reference_parser.reference_utils import labels_to_prodigy
+from deep_reference_parser.prodigy import labels_to_prodigy
 
 
 def test_labels_to_prodigy():
diff --git a/tests/prodigy/test_misc.py b/tests/prodigy/test_misc.py
new file mode 100644
index 0000000..5ed18df
--- /dev/null
+++ b/tests/prodigy/test_misc.py
@@ -0,0 +1,19 @@
+from deep_reference_parser.prodigy import prodigy_to_conll
+
+
+def test_prodigy_to_conll():
+
+    before = [
+        {"text": "References",},
+        {"text": "37. No single case of malaria reported in"},
+        {
+            "text": "an essential requirement for the correct labelling of potency for therapeutic"
+        },
+        {"text": "EQAS, quality control for STI"},
+    ]
+
+    after = "DOCSTART\n\nReferences\n\n37\n.\nNo\nsingle\ncase\nof\nmalaria\nreported\nin\n\nan\nessential\nrequirement\nfor\nthe\ncorrect\nlabelling\nof\npotency\nfor\ntherapeutic\n\nEQAS\n,\nquality\ncontrol\nfor\nSTI"
+
+    out = prodigy_to_conll(before)
+
+    assert after == out
diff --git a/tests/test_data/test_load_tsv.tsv b/tests/test_data/test_load_tsv.tsv
new file mode 100644
index 0000000..ad64c5d
--- /dev/null
+++ b/tests/test_data/test_load_tsv.tsv
@@ -0,0 +1,18 @@
+the i-r a
+focus i-r a
+in i-r a
+Daloa i-r a
+, i-r a
+Côte i-r a
+d’Ivoire]. i-r a
+
+Bulletin i-r a
+de i-r a
+la i-r a
+Société i-r a
+de i-r a
+Pathologie i-r a
+
+Exotique i-r a
+et i-r a
+
diff --git a/tests/test_io.py b/tests/test_io.py
index 3799131..dd86061 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -2,13 +2,19 @@
 # coding: utf-8
 
 import os
-import tempfile
 
 import pytest
 
-from deep_reference_parser.io import read_jsonl, write_jsonl
+from deep_reference_parser.io.io import (
+    read_jsonl,
+    write_jsonl,
+    load_tsv,
+    write_tsv,
+    _split_list_by_linebreaks,
+)
+from deep_reference_parser.reference_utils import yield_token_label_pairs
 
-from .common import TEST_JSONL
+from .common import TEST_JSONL, TEST_TSV_TRAIN, TEST_TSV_PREDICT, TEST_LOAD_TSV
 
 
 @pytest.fixture(scope="module")
@@ -16,6 +22,212 @@ def tmpdir(tmpdir_factory):
     return tmpdir_factory.mktemp("data")
 
 
+def test_write_tsv(tmpdir):
+
+    expected = (
+        [
+            [],
+            ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
+            ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
+            ["Exotique", "et"],
+        ],
+        [
+            [],
+            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+            ["i-r", "i-r"],
+        ],
+    )
+
+    token_label_tuples = list(yield_token_label_pairs(expected[0], expected[1]))
+
+    PATH = os.path.join(tmpdir, "test_tsv.tsv")
+    write_tsv(token_label_tuples, PATH)
+    actual = load_tsv(os.path.join(PATH))
+
+    assert expected == actual
+
+
+def test_load_tsv_train():
+    """
+    Text of TEST_TSV_TRAIN:
+
+    ```
+    the i-r
+    focus i-r
+    in i-r
+    Daloa i-r
+    , i-r
+    Côte i-r
+    d’Ivoire]. i-r
+
+    Bulletin i-r
+    de i-r
+    la i-r
+    Société i-r
+    de i-r
+    Pathologie i-r
+
+    Exotique i-r
+    et i-r
+    ```
+    """
+
+    expected = (
+        [
+            ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
+            ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
+            ["Exotique", "et"],
+        ],
+        [
+            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+            ["i-r", "i-r"],
+        ],
+    )
+
+    actual = load_tsv(TEST_TSV_TRAIN)
+
+    assert len(actual[0][0]) == len(expected[0][0])
+    assert len(actual[0][1]) == len(expected[0][1])
+    assert len(actual[0][2]) == len(expected[0][2])
+
+    assert len(actual[1][0]) == len(expected[1][0])
+    assert len(actual[1][1]) == len(expected[1][1])
+    assert len(actual[1][2]) == len(expected[1][2])
+
+    assert actual == expected
+
+
+def test_load_tsv_predict():
+    """
+    Text of TEST_TSV_PREDICT:
+
+    ```
+    the
+    focus
+    in
+    Daloa
+    ,
+    Côte
+    d’Ivoire].
+
+    Bulletin
+    de
+    la
+    Société
+    de
+    Pathologie
+
+    Exotique
+    et
+    ```
+    """
+
+    expected = (
+        [
+            ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
+            ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
+            ["Exotique", "et"],
+        ],
+    )
+
+    actual = load_tsv(TEST_TSV_PREDICT)
+
+    assert actual == expected
+
+
+def test_load_tsv_train_multiple_labels():
+    """
+    Text of TEST_TSV_TRAIN:
+
+    ```
+    the i-r
+    focus i-r
+    in i-r
+    Daloa i-r
+    , i-r
+    Côte i-r
+    d’Ivoire]. i-r
+
+    Bulletin i-r
+    de i-r
+    la i-r
+    Société i-r
+    de-r
+    Pathologie i-r
+
+    Exotique i-r
+    et i-r
+    ```
+    """
+
+    expected = (
+        [
+            ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
+            ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
+            ["Exotique", "et"],
+        ],
+        [
+            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+            ["i-r", "i-r"],
+        ],
+        [
+            ["a", "a", "a", "a", "a", "a", "a"],
+            ["a", "a", "a", "a", "a", "a"],
+            ["a", "a"],
+        ],
+    )
+
+    actual = load_tsv(TEST_LOAD_TSV)
+
+    assert actual == expected
+
+
+def test_yield_toke_label_pairs():
+
+    tokens = [
+        [],
+        ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
+        ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
+        ["Exotique", "et"],
+    ]
+
+    labels = [
+        [],
+        ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+        ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+        ["i-r", "i-r"],
+    ]
+
+    expected = [
+        (None, None),
+        ("the", "i-r"),
+        ("focus", "i-r"),
+        ("in", "i-r"),
+        ("Daloa", "i-r"),
+        (",", "i-r"),
+        ("Côte", "i-r"),
+        ("d’Ivoire].", "i-r"),
+        (None, None),
+        ("Bulletin", "i-r"),
+        ("de", "i-r"),
+        ("la", "i-r"),
+        ("Société", "i-r"),
+        ("de", "i-r"),
+        ("Pathologie", "i-r"),
+        (None, None),
+        ("Exotique", "i-r"),
+        ("et", "i-r"),
+        (None, None),
+    ]
+
+    actual = list(yield_token_label_pairs(tokens, labels))
+
+    assert expected == actual
+
+
 def test_read_jsonl():
 
     expected = [
@@ -76,3 +288,27 @@ def test_write_jsonl(tmpdir):
     actual = read_jsonl(temp_file)
 
     assert expected == actual
+
+
+def test_split_list_by_linebreaks():
+
+    lst = ["a", "b", "c", None, "d"]
+    expected = [["a", "b", "c"], ["d"]]
+
+    actual = _split_list_by_linebreaks(lst)
+
+
+def test_list_by_linebreaks_ending_in_None():
+
+    lst = ["a", "b", "c", float("nan"), "d", None]
+    expected = [["a", "b", "c"], ["d"]]
+
+    actual = _split_list_by_linebreaks(lst)
+
+
+def test_list_by_linebreaks_starting_in_None():
+
+    lst = [None, "a", "b", "c", None, "d"]
+    expected = [["a", "b", "c"], ["d"]]
+
+    actual = _split_list_by_linebreaks(lst)
diff --git a/tests/test_reference_utils.py b/tests/test_reference_utils.py
index 118e399..c6d2091 100644
--- a/tests/test_reference_utils.py
+++ b/tests/test_reference_utils.py
@@ -1,192 +1,9 @@
 #!/usr/bin/env python3
 # coding: utf-8
 
-import os
-import tempfile
-
 import pytest
 
-from deep_reference_parser.reference_utils import (
-    break_into_chunks,
-    load_tsv,
-    prodigy_to_conll,
-    write_tsv,
-    yield_token_label_pairs,
-)
-
-from .common import TEST_TSV_PREDICT, TEST_TSV_TRAIN
-
-
-def test_prodigy_to_conll():
-
-    before = [
-        {"text": "References",},
-        {"text": "37. No single case of malaria reported in"},
-        {
-            "text": "an essential requirement for the correct labelling of potency for therapeutic"
-        },
-        {"text": "EQAS, quality control for STI"},
-    ]
-
-    after = "DOCSTART\n\nReferences\n\n37\n.\nNo\nsingle\ncase\nof\nmalaria\nreported\nin\n\nan\nessential\nrequirement\nfor\nthe\ncorrect\nlabelling\nof\npotency\nfor\ntherapeutic\n\nEQAS\n,\nquality\ncontrol\nfor\nSTI"
-
-    out = prodigy_to_conll(before)
-
-    assert after == out
-
-
-def test_load_tsv_train():
-    """
-    Text of TEST_TSV_TRAIN:
-
-    ```
-    the i-r
-    focus i-r
-    in i-r
-    Daloa i-r
-    , i-r
-    Côte i-r
-    d’Ivoire]. i-r
-
-    Bulletin i-r
-    de i-r
-    la i-r
-    Société i-r
-    de i-r
-    Pathologie i-r
-
-    Exotique i-r
-    et i-r
-    ```
-    """
-
-    expected = (
-        [
-            ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
-            ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
-            ["Exotique", "et"],
-        ],
-        [
-            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
-            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
-            ["i-r", "i-r"],
-        ],
-    )
-
-    actual = load_tsv(TEST_TSV_TRAIN)
-
-    assert actual == expected
-
-
-def test_load_tsv_predict():
-    """
-    Text of TEST_TSV_PREDICT:
-
-    ```
-    the
-    focus
-    in
-    Daloa
-    ,
-    Côte
-    d’Ivoire].
-
-    Bulletin
-    de
-    la
-    Société
-    de
-    Pathologie
-
-    Exotique
-    et
-    ```
-    """
-
-    expected = (
-        [
-            ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
-            ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
-            ["Exotique", "et"],
-        ],
-        [[], [], [],],
-    )
-
-    actual = load_tsv(TEST_TSV_PREDICT)
-
-    assert actual == expected
-
-
-def test_yield_toke_label_pairs():
-
-    tokens = [
-        [],
-        ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
-        ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
-        ["Exotique", "et"],
-    ]
-
-    labels = [
-        [],
-        ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
-        ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
-        ["i-r", "i-r"],
-    ]
-
-    expected = [
-        (None, None),
-        ("the", "i-r"),
-        ("focus", "i-r"),
-        ("in", "i-r"),
-        ("Daloa", "i-r"),
-        (",", "i-r"),
-        ("Côte", "i-r"),
-        ("d’Ivoire].", "i-r"),
-        (None, None),
-        ("Bulletin", "i-r"),
-        ("de", "i-r"),
-        ("la", "i-r"),
-        ("Société", "i-r"),
-        ("de", "i-r"),
-        ("Pathologie", "i-r"),
-        (None, None),
-        ("Exotique", "i-r"),
-        ("et", "i-r"),
-        (None, None),
-    ]
-
-    actual = list(yield_token_label_pairs(tokens, labels))
-
-    assert expected == actual
-
-
-def test_write_tsv():
-
-    expected = (
-        [
-            [],
-            ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
-            ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
-            ["Exotique", "et"],
-        ],
-        [
-            [],
-            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
-            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
-            ["i-r", "i-r"],
-        ],
-    )
-
-    _, path = tempfile.mkstemp()
-
-    token_label_tuples = list(yield_token_label_pairs(expected[0], expected[1]))
-
-    write_tsv(token_label_tuples, path)
-    actual = load_tsv(path)
-
-    assert expected == actual
-
-    os.remove(path)
+from deep_reference_parser.reference_utils import break_into_chunks
 
 
 def test_break_into_chunks():