Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add user dic support #72

Merged
merged 3 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,38 @@ In [3]: pyopenjtalk.g2p("こんにちは", kana=True)
Out[3]: 'コンニチワ'
```

### Create/Apply user dictionary

1. Create a CSV file (e.g. `user.csv`) and write custom words like below:

```csv
GNU,,,1,名詞,一般,*,*,*,*,GNU,グヌー,グヌー,2/3,*
```

2. Call `mecab_dict_index` to compile the CSV file.

```python
In [1]: import pyopenjtalk

In [2]: pyopenjtalk.mecab_dict_index("user.csv", "user.dic")
reading user.csv ... 1
emitting double-array: 100% |###########################################|

done!
```

3. Call `update_global_jtalk_with_user_dict` to apply the user dictionary.

```python
In [3]: pyopenjtalk.g2p("GNU")
Out[3]: 'j i i e n u y u u'

In [4]: pyopenjtalk.update_global_jtalk_with_user_dict("user.dic")

In [5]: pyopenjtalk.g2p("GNU")
Out[5]: 'g u n u u'
```

### About `run_marine` option

Since v0.3.0, the `run_marine` option has been available for estimating Japanese accent with a DNN-based method (see [marine](https://github.com/6gsn/marine)). To use this feature, install pyopenjtalk as shown below:
Expand Down
42 changes: 42 additions & 0 deletions pyopenjtalk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .htsengine import HTSEngine
from .openjtalk import OpenJTalk
from .openjtalk import mecab_dict_index as _mecab_dict_index
from .utils import merge_njd_marine_features

# Dictionary directory
Expand Down Expand Up @@ -224,3 +225,44 @@ def make_label(njd_features):
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.make_label(njd_features)


def mecab_dict_index(path, out_path, dn_mecab=None):
    """Create a user dictionary by compiling a user CSV file.

    Args:
        path (str): path to user csv
        out_path (str): path to output dictionary
        dn_mecab (optional. str or bytes): path to mecab dictionary.
            ``OPEN_JTALK_DICT_DIR`` is used when omitted.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        RuntimeError: if the underlying ``mecab_dict_index`` call reports
            a failure (non-zero return value).
    """
    # Validate the input first so that a wrong path fails fast, before any
    # lazy initialization work is triggered.
    if not exists(path):
        raise FileNotFoundError("no such file or directory: %s" % path)
    if _global_jtalk is None:
        _lazy_init()
    if dn_mecab is None:
        dn_mecab = OPEN_JTALK_DICT_DIR
    elif isinstance(dn_mecab, str):
        # The compiled binding declares its arguments as ``bytes``; encode a
        # str path (the documented type) the same way as the other paths.
        dn_mecab = dn_mecab.encode("utf-8")
    r = _mecab_dict_index(dn_mecab, path.encode("utf-8"), out_path.encode("utf-8"))

    # NOTE: loading a mecab dictionary returns 1 on success, whereas
    # mecab_dict_index uses the opposite convention (0 on success).
    if r != 0:
        raise RuntimeError("Failed to create user dictionary")


def update_global_jtalk_with_user_dict(path):
    """Update global openjtalk instance with the user dictionary

    Note that this will change the global state of the openjtalk module:
    the module-level ``_global_jtalk`` instance is replaced by a new
    ``OpenJTalk`` built with ``OPEN_JTALK_DICT_DIR`` and the given user
    dictionary.

    Args:
        path (str): path to user dictionary

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    global _global_jtalk
    # Check the argument before doing any initialization so that a wrong
    # path fails fast and leaves the module state untouched.
    if not exists(path):
        raise FileNotFoundError("no such file or directory: %s" % path)
    if _global_jtalk is None:
        _lazy_init()
    _global_jtalk = OpenJTalk(
        dn_mecab=OPEN_JTALK_DICT_DIR, userdic=path.encode("utf-8")
    )
60 changes: 56 additions & 4 deletions pyopenjtalk/openjtalk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ np.import_array()

cimport cython
from libc.stdlib cimport calloc
from libc.string cimport strlen

from .openjtalk.mecab cimport Mecab, Mecab_initialize, Mecab_load, Mecab_analysis
from .openjtalk.mecab cimport Mecab_get_feature, Mecab_get_size, Mecab_refresh, Mecab_clear
from .openjtalk.mecab cimport createModel, Model, Tagger, Lattice
from .openjtalk.mecab cimport mecab_dict_index as _mecab_dict_index
from .openjtalk.njd cimport NJD, NJD_initialize, NJD_refresh, NJD_print, NJD_clear
from .openjtalk cimport njd as _njd
from .openjtalk.jpcommon cimport JPCommon, JPCommon_initialize,JPCommon_make_label
Expand Down Expand Up @@ -116,18 +119,52 @@ cdef feature2njd(_njd.NJD* njd, features):
_njd.NJDNode_set_chain_flag(node, feature_node["chain_flag"])
_njd.NJD_push_node(njd, node)

# based on Mecab_load in impl. from mecab.cpp
# Load the MeCab dictionary `dicdir` into `m`, additionally applying the user
# dictionary `userdic`. Returns 1 when model, tagger and lattice were all
# created, 0 on any failure. When `userdic` is NULL or empty, the call is
# delegated to Mecab_load and its return value is passed through.
cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic):
# No user dictionary given: fall back to the plain loader.
if userdic == NULL or strlen(userdic) == 0:
return Mecab_load(m, dicdir)

# Reject a missing Mecab struct or a NULL/empty dictionary directory.
if m == NULL or dicdir == NULL or strlen(dicdir) == 0:
return 0

# NOTE(review): presumably releases anything previously loaded into `m`
# before new objects are installed -- confirm against mecab.cpp.
Mecab_clear(m)

# Same arguments as the command line `mecab -d <dicdir> -u <userdic>`.
cdef (char*)[5] argv = ["mecab", "-d", dicdir, "-u", userdic]
cdef Model *model = createModel(5, argv)

if model == NULL:
return 0
m.model = model

# The model is stored in `m` before the tagger/lattice are created, so the
# Mecab_clear(m) calls on the failure paths below see it.
cdef Tagger *tagger = model.createTagger()
if tagger == NULL:
Mecab_clear(m)
return 0
m.tagger = tagger

cdef Lattice *lattice = model.createLattice()
if lattice == NULL:
Mecab_clear(m)
return 0
m.lattice = lattice

return 1


cdef class OpenJTalk(object):
"""OpenJTalk

Args:
dn_mecab (bytes): Dictionary path for MeCab.
userdic (bytes): Dictionary path for MeCab userdic.
This option is ignored when empty bytestring is given.
Default is empty.
"""
cdef Mecab* mecab
cdef NJD* njd
cdef JPCommon* jpcommon

def __cinit__(self, bytes dn_mecab=b"/usr/local/dic"):
def __cinit__(self, bytes dn_mecab=b"/usr/local/dic", bytes userdic=b""):
self.mecab = new Mecab()
self.njd = new NJD()
self.jpcommon = new JPCommon()
Expand All @@ -136,7 +173,7 @@ cdef class OpenJTalk(object):
NJD_initialize(self.njd)
JPCommon_initialize(self.jpcommon)

r = self._load(dn_mecab)
r = self._load(dn_mecab, userdic)
if r != 1:
self._clear()
raise RuntimeError("Failed to initalize Mecab")
Expand All @@ -147,8 +184,8 @@ cdef class OpenJTalk(object):
NJD_clear(self.njd)
JPCommon_clear(self.jpcommon)

def _load(self, bytes dn_mecab):
return Mecab_load(self.mecab, dn_mecab)
def _load(self, bytes dn_mecab, bytes userdic):
return Mecab_load_with_userdic(self.mecab, dn_mecab, userdic)


def run_frontend(self, text):
Expand Down Expand Up @@ -231,3 +268,18 @@ cdef class OpenJTalk(object):
del self.mecab
del self.njd
del self.jpcommon

# Python-callable wrapper around the C function `mecab_dict_index`.
# Builds an argv equivalent to the command line
#   mecab-dict-index -d <dn_mecab> -u <out_path> -f utf-8 -t utf-8 <path>
# and returns the integer result of the C call unchanged.
# All three arguments must be `bytes`; they are placed into the char* array
# as-is.
# NOTE(review): `-f`/`-t` presumably select the input/output charset of the
# dictionary -- confirm against the mecab-dict-index documentation.
def mecab_dict_index(bytes dn_mecab, bytes path, bytes out_path):
cdef (char*)[10] argv = [
"mecab-dict-index",
"-d",
dn_mecab,
"-u",
out_path,
"-f",
"utf-8",
"-t",
"utf-8",
path
]
# argc (10) must match the number of entries in `argv` above.
return _mecab_dict_index(10, argv)
11 changes: 11 additions & 0 deletions pyopenjtalk/openjtalk/mecab.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,14 @@ cdef extern from "mecab.h":
char **Mecab_get_feature(Mecab *m)
cdef int Mecab_refresh(Mecab *m)
cdef int Mecab_clear(Mecab *m)
cdef int mecab_dict_index(int argc, char **argv)

# Declarations from the C++ `MeCab` namespace of mecab.h. Only the members
# used by this package are declared.
cdef extern from "mecab.h" namespace "MeCab":
# Tagger and Lattice are declared without members; they are only handled
# through pointers here.
cdef cppclass Tagger:
pass
cdef cppclass Lattice:
pass
# Model acts as the factory for Tagger and Lattice objects.
cdef cppclass Model:
Tagger *createTagger()
Lattice *createLattice()
# Creates a Model from command-line style arguments (argc/argv).
cdef Model *createModel(int argc, char **argv)
2 changes: 2 additions & 0 deletions tests/test_data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
28 changes: 28 additions & 0 deletions tests/test_openjtalk.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import pyopenjtalk


Expand Down Expand Up @@ -80,3 +82,29 @@ def test_g2p_phone():
]:
p = pyopenjtalk.g2p(text, kana=False)
assert p == pron


def test_userdic():
    """Check that applying a user dictionary changes the g2p output."""
    cases = [
        ("nnmn", "n a n a m i N"),
        ("GNU", "g u n u u"),
    ]

    # Without the user dictionary none of the custom readings are produced.
    for word, reading in cases:
        assert pyopenjtalk.g2p(word) != reading

    data_dir = Path(__file__).parent / "test_data"
    user_csv = str(data_dir / "user.csv")
    user_dic = str(data_dir / "user.dic")

    entries = [
        "nnmn,,,1,名詞,一般,*,*,*,*,nnmn,ナナミン,ナナミン,1/4,*\n",
        "GNU,,,1,名詞,一般,*,*,*,*,GNU,グヌー,グヌー,2/3,*\n",
    ]
    with open(user_csv, "w", encoding="utf-8") as csv_file:
        csv_file.writelines(entries)

    # Compile the CSV, then swap the global instance to one using the result.
    pyopenjtalk.mecab_dict_index(user_csv, user_dic)
    pyopenjtalk.update_global_jtalk_with_user_dict(user_dic)

    # With the user dictionary applied the custom readings must be produced.
    for word, reading in cases:
        assert pyopenjtalk.g2p(word) == reading
Loading