Skip to content

Commit

Permalink
fix test and update documentation
Browse files — browse the repository at this point in the history
  • Loading branch information
ryokan0123 committed Oct 10, 2022
1 parent d6b72c7 commit 2484818
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 11 deletions.
6 changes: 3 additions & 3 deletions luke/pretraining/dataset.py
Expand Up @@ -60,7 +60,7 @@ def build_wikipedia_pretraining_dataset(
tokenizer_name: str,
entity_vocab_file: str,
output_dir: str,
language: str,
language: Optional[str],
sentence_splitter: str,
**kwargs
):
Expand Down Expand Up @@ -171,7 +171,7 @@ def build(
sentence_splitter: SentenceSplitter,
entity_vocab: EntityVocab,
output_dir: str,
language: str,
language: Optional[str],
max_seq_length: int,
max_entity_length: int,
max_mention_length: int,
Expand Down Expand Up @@ -255,7 +255,7 @@ def _initialize_worker(
tokenizer: PreTrainedTokenizer,
sentence_splitter: SentenceSplitter,
entity_vocab: EntityVocab,
language: str,
language: Optional[str],
max_num_tokens: int,
max_entity_length: int,
max_mention_length: int,
Expand Down
14 changes: 7 additions & 7 deletions luke/utils/entity_vocab.py
Expand Up @@ -6,7 +6,7 @@
from contextlib import closing
from multiprocessing.pool import Pool
from pathlib import Path
from typing import Dict, List, TextIO
from typing import Dict, List, TextIO, Optional

import click
from tqdm import tqdm
Expand Down Expand Up @@ -37,7 +37,7 @@
@click.option("--white-list-only", is_flag=True)
@click.option("--pool-size", default=multiprocessing.cpu_count())
@click.option("--chunk-size", default=100)
def build_entity_vocab(dump_db_file: str, white_list: List[TextIO], language: str, **kwargs):
def build_entity_vocab(dump_db_file: str, white_list: List[TextIO], language: Optional[str], **kwargs):
    """CLI entry point: build an entity vocabulary from a Wikipedia dump database.

    Args:
        dump_db_file: Path used to open the ``DumpDB`` to scan for entities.
        white_list: Open file handles; every line of every file is a title
            that must be included in the vocabulary.
        language: Language code stored with each entity; ``None`` for a
            monolingual vocabulary.
        **kwargs: Remaining click options, forwarded to ``EntityVocab.build``.
    """
    dump_db = DumpDB(dump_db_file)
    # Flatten all white-list files into one list of titles, stripping newlines.
    white_list = [line.rstrip() for f in white_list for line in f]
    EntityVocab.build(dump_db, white_list=white_list, language=language, **kwargs)
Expand Down Expand Up @@ -138,21 +138,21 @@ def __getitem__(self, key: str):
def __iter__(self):
return iter(self.vocab)

def contains(self, title: str, language: str = None):
def contains(self, title: str, language: Optional[str] = None):
    """Return True if the entity (title, language) is in the vocabulary."""
    entity = Entity(title, language)
    return entity in self.vocab

def get_id(self, title: str, language: str = None, default: int = None) -> int:
def get_id(self, title: str, language: Optional[str] = None, default: Optional[int] = None) -> Optional[int]:
    """Return the vocabulary id of the entity (title, language).

    Args:
        title: Entity title to look up.
        language: Language code of the entity; ``None`` for a monolingual
            vocabulary.
        default: Value returned when the entity is not in the vocabulary.

    Returns:
        The entity id, or ``default`` (``None`` unless overridden) when the
        entity is unknown.
    """
    # default/return were annotated as plain `int` although `None` is the
    # default and is returned on a miss — same implicit-Optional bug this
    # commit fixes for `language`; annotations corrected, behavior unchanged.
    try:
        return self.vocab[Entity(title, language)]
    except KeyError:
        return default

def get_title_by_id(self, id_: int, language: str = None) -> str:
def get_title_by_id(self, id_: int, language: Optional[str] = None) -> Optional[str]:
    """Return the title of the entity with id ``id_`` in ``language``.

    Args:
        id_: Entity id to resolve.
        language: Language code to match against each candidate entity.

    Returns:
        The matching title, or ``None`` when no entity with that id has the
        requested language. (The original fell through implicitly while
        declaring ``-> str``; the ``None`` case is now explicit and the
        annotation is ``Optional[str]``.)
    """
    for entity in self.inv_vocab[id_]:
        if entity.language == language:
            return entity.title
    return None

def get_count_by_title(self, title: str, language: str = None) -> int:
def get_count_by_title(self, title: str, language: Optional[str] = None) -> int:
    """Return the recorded occurrence count of the entity, or 0 if unseen."""
    return self.counter.get(Entity(title, language), 0)

Expand Down Expand Up @@ -188,7 +188,7 @@ def build(
white_list_only: bool,
pool_size: int,
chunk_size: int,
language: str,
language: Optional[str],
):
counter = Counter()
with tqdm(total=dump_db.page_size(), mininterval=0.5) as pbar:
Expand Down
3 changes: 2 additions & 1 deletion pretraining.md
Expand Up @@ -116,7 +116,8 @@ python luke/cli.py \
<BASE_MODEL_NAME> \
mluke_entity_vocab.jsonl \
"mluke_pretraining_dataset/${l}" \
--sentence-splitter=$l
--sentence-splitter=$l \
--language $l
```
## 5. Compute the number of training steps
Expand Down
Empty file added tests/examples/__init__.py
Empty file.
1 change: 1 addition & 0 deletions tests/pretraining/test_dataset.py
Expand Up @@ -29,6 +29,7 @@ def test_build_and_read_dataset():
sentence_tokenizer,
entity_vocab,
temp_directory_path,
language=None,
max_seq_length=512,
max_entity_length=128,
max_mention_length=30,
Expand Down

0 comments on commit 2484818

Please sign in to comment.