Skip to content

yagays/pytorch_bert_japanese

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

11 Commits
 
 
 
 
 
 

Repository files navigation

PytorchでBERTの日本語学習済みモデルを利用する

これはPytorchで日本語の学習済みBERTモデルを読み込み、文章ベクトル(Sentence Embedding)を計算するためのコードです。

詳細は下記ブログを参考ください。

PytorchでBERTの日本語学習済みモデルを利用する - 文章埋め込み編

環境

準備

日本語の学習済みBERTモデル

京都大学の黒橋・河原研究室が公開している「BERT日本語Pretrainedモデル」を利用します。下記ウェブページからモデルファイルをダウンロードして解凍してください。

BERT日本語Pretrainedモデル - KUROHASHI-KAWAHARA LAB

Juman++

Juman++をインストールします。インストール方法については、下記の公式レポジトリを参照ください。

ku-nlp/jumanpp: Juman++ (a Morphological Analyzer Toolkit)

なお、macOSならばHomebrewを使って下記のように簡単にインストールできます。

$ brew install jumanpp

Pythonパッケージ

pytorch-pretrained-bertおよびpyknpをインストールします。

$ pip install pytorch-pretrained-bert
$ pip install pyknp

なお、ここではPytorchをBERT実装に利用するので、Pytorchはインストールされているものとします。

PyTorch

実行する

本レポジトリのbert_juman.pyからBertWithJumanModelクラスをインポートします。クラスの引数には、ダウンロードした日本語の学習済みBERTモデルのディレクトリを指定します。必要なファイルはpytorch_model.binvocab.txtのみです。

In []: from bert_juman import BertWithJumanModel

In []: bert = BertWithJumanModel("/path/to/Japanese_L-12_H-768_A-12_E-30_BPE")

In []: bert.get_sentence_embedding("吾輩は猫である。")
Out[]:
array([ 2.22642735e-01, -2.40221739e-01,  1.09303640e-02, -1.02307117e+00,
        1.78834641e+00, -2.73566216e-01, -1.57942638e-01, -7.98571169e-01,
       -2.77438164e-02, -8.05811465e-01,  3.46736580e-01, -7.20409870e-01,
        1.03382647e-01, -5.33944130e-01, -3.25344890e-01, -1.02880754e-01,
        2.26500735e-01, -8.97880018e-01,  2.52314955e-01, -7.09809303e-01,
[...]        

またget_sentence_embedding()の引数には、文章ベクトルを計算するのに利用するBERTの隠れ層の位置pooling_layerと、プーリングの方法pooling_strategyが指定できます。pooling_layer-1で最終層、-2で最終層の手前の層となります。また、pooling_strategyには

  • REDUCE_MEAN: 要素ごとにaverage-pooling
  • REDUCE_MAX: 要素ごとにmax-pooling
  • REDUCE_MEAN_MAX: REDUCE_MEANREDUCE_MAXを結合したもの
  • CLS_TOKEN: [CLS]トークンのベクトルをそのまま利用

が選択できます。

In []: bert.get_sentence_embedding("吾輩は猫である。",
   ...:                             pooling_layer=-1,
   ...:                             pooling_strategy="REDUCE_MAX")
   ...:
Out[]:
array([ 1.2089624 ,  0.6267309 ,  0.7243419 , -0.12712255,  1.8050476 ,
        0.43929055,  0.605848  ,  0.5058241 ,  0.8335829 , -0.26000524,
[...]        

これらのパラメータはhanxiao/bert-as-serviceを参考にしています。

GPU Option

In []: bert = BertWithJumanModel("../Japanese_L-12_H-768_A-12_E-30_BPE", use_cuda=True)

In []: bert.get_sentence_embedding("吾輩は猫である。")
Out[]:
array([-4.25627649e-01, -3.42006773e-01, -7.15175271e-02, -1.09820020e+00,
        1.08186746e+00, -2.35576674e-01, -1.89862609e-01, -5.50959229e-01,

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages